"""
Sarvam AI Speech-to-Text provider.

Uses the `speech_to_text_translate_job` batch API (model saaras:v2.5) with
speaker diarisation.  Audio is downloaded to a temp file, uploaded to Sarvam,
polled until complete, and the result JSON is parsed into an STTResult.
"""
from __future__ import annotations

import logging
import os
import tempfile
import time
from typing import Any, Dict, List

import requests

from .base import BaseSTT, STTResult

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Upper bound on create-job + upload attempts made by SarvamSTT.transcribe.
_MAX_UPLOAD_RETRIES = 5
# Handed to the job's wait call as ``poll_interval`` (seconds, per the _S suffix).
_POLL_INTERVAL_S = 3
# Handed to the job's wait call as ``timeout``; logged as the maximum wait in seconds.
_JOB_TIMEOUT_S = 300


def _speaker_numeric_id(speaker_id: str) -> int | None:
    try:
        return int(str(speaker_id).replace("speaker_", ""))
    except (TypeError, ValueError):
        return None


def _apply_display_speaker_labels(segments: List[Dict[str, Any]]) -> None:
    """Rewrite each segment's ``"speaker"`` label in place as ``Speaker <rank>``.

    The numeric speaker ids found across *segments* are ranked in ascending
    order starting at 1, and every segment whose ``speaker_id`` parses to a
    number gets the label for its id's rank.  Segments without a parseable
    id are left untouched; if no segment has one, nothing is changed.
    """
    # Pair every segment with its parsed id up front so both passes agree.
    parsed = [
        (seg, _speaker_numeric_id(seg.get("speaker_id", "")))
        for seg in segments
    ]
    distinct_ids = {num for _, num in parsed if num is not None}
    if not distinct_ids:
        return

    rank_of = {}
    for position, num in enumerate(sorted(distinct_ids), start=1):
        rank_of[num] = position

    for seg, num in parsed:
        if num is not None:
            seg["speaker"] = f"Speaker {rank_of[num]}"


def _parse_sarvam_result(result: Dict[str, Any]) -> STTResult:
    """Convert raw Sarvam JSON output to a normalised STTResult.

    Args:
        result: Decoded job output.  ``transcript`` supplies the full text and
            ``diarized_transcript["entries"]`` the per-speaker entries; each
            entry is read for ``speaker_id``, ``transcript``,
            ``start_time_seconds`` and ``end_time_seconds``.

    Returns:
        An ``STTResult`` with ``provider="sarvam"``.  Each kept entry becomes
        a segment dict with ``speaker``, ``speaker_id``, ``text``, ``start``,
        ``end``, ``start_time``, ``end_time`` and ``role``.  Entries whose
        text is empty after stripping are dropped.  ``duration`` is the
        largest segment end time, or ``0.0`` when there are no segments.
    """
    transcript_text: str = result.get("transcript") or ""

    diarized = result.get("diarized_transcript")
    # A null/missing/non-dict block and a null "entries" value all mean
    # "no diarisation"; treat them alike instead of iterating over None.
    entries: List[Dict] = (
        (diarized.get("entries") or []) if isinstance(diarized, dict) else []
    )

    speaker_segments: List[Dict[str, Any]] = []
    for entry in entries:
        text = (entry.get("transcript") or "").strip()
        if not text:
            continue

        # An explicit null speaker_id falls back to "0", same as a missing
        # key, rather than being stringified to "None".
        raw_value = entry.get("speaker_id")
        raw_id = "0" if raw_value is None else str(raw_value)
        if raw_id.startswith("speaker_"):
            speaker_id = raw_id
            num = raw_id.replace("speaker_", "")
        else:
            speaker_id = "speaker_" + raw_id
            num = raw_id

        start = float(entry.get("start_time_seconds") or 0)
        end = float(entry.get("end_time_seconds") or 0)
        speaker_segments.append({
            # Provisional label; replaced below when the id is numeric.
            "speaker": f"Speaker {num}",
            "speaker_id": speaker_id,
            "text": text,
            "start": start,
            "end": end,
            # Same values under alternate key names.
            "start_time": start,
            "end_time": end,
        })

    # Assign roles: speaker with the lowest numeric ID is the agent.
    if speaker_segments:
        nums = [_speaker_numeric_id(seg["speaker_id"]) for seg in speaker_segments]
        nums = [n for n in nums if n is not None]
        min_num = min(nums) if nums else 0
        for seg in speaker_segments:
            n = _speaker_numeric_id(seg["speaker_id"])
            seg["role"] = "agent" if n == min_num else "customer"
        _apply_display_speaker_labels(speaker_segments)

    # Use the latest end time rather than the last entry's: overlapping or
    # out-of-order segments would otherwise under-report the duration.
    duration = max((seg["end"] for seg in speaker_segments), default=0.0)
    return STTResult(
        transcript=transcript_text,
        speaker_segments=speaker_segments,
        duration=duration,
        provider="sarvam",
    )


class SarvamSTT(BaseSTT):
    """Sarvam AI batch transcription provider.

    Drives the ``speech_to_text_translate_job`` client of the ``sarvamai``
    SDK: the audio is fetched over HTTP into a temporary file, uploaded to a
    newly created job, and the job's JSON output is parsed into an
    ``STTResult``.
    """

    def __init__(self, api_key: str):
        """Store the Sarvam subscription key.

        Raises:
            ValueError: If *api_key* is empty or otherwise falsy.
        """
        if not api_key:
            raise ValueError("SARVAM_SUBSCRIPTION_KEY is required")
        self._api_key = api_key

    @property
    def provider_name(self) -> str:
        """Identifier of this provider (``"sarvam"``)."""
        return "sarvam"

    def transcribe(self, audio_url: str, callid: str) -> STTResult:
        """Download *audio_url*, submit to Sarvam, wait for result.

        Args:
            audio_url: Location of the recording; fetched with a 60-second
                request timeout.
            callid: Caller-supplied identifier, used only to tag log lines.

        Returns:
            The transcription as parsed by ``_parse_sarvam_result``.

        Raises:
            RuntimeError: If the audio or result download does not return
                HTTP 200, no upload attempt succeeds, the job does not
                finish successfully, or no download URL is returned.
            ImportError: If the ``sarvamai`` package is not installed.

        Other exceptions raised by ``requests`` or the Sarvam SDK propagate
        unchanged.  The temporary audio file is removed on every exit path.
        """
        from sarvamai import SarvamAI  # lazy import – not always installed

        client = SarvamAI(api_subscription_key=self._api_key)
        tmp_path: str | None = None

        try:
            # ── 1. Download audio ────────────────────────────────────────────
            logger.info("[%s] Downloading audio: %s", callid, audio_url)
            resp = requests.get(audio_url, timeout=60)
            if resp.status_code != 200:
                raise RuntimeError(
                    f"Audio download failed: HTTP {resp.status_code} for {audio_url}"
                )

            with tempfile.NamedTemporaryFile(
                suffix=".wav", delete=False, dir="/tmp"
            ) as tmp:
                # Record the path before writing: the file is created with
                # delete=False, so a failed write must still be cleaned up
                # by the ``finally`` clause below.
                tmp_path = tmp.name
                tmp.write(resp.content)

            # ── 2. Create job and upload (retry on Azure 403 bug) ────────────
            # ``job`` is only bound once an upload has actually succeeded, so
            # a job whose upload was rejected can never be started.
            job = None
            for attempt in range(_MAX_UPLOAD_RETRIES):
                try:
                    candidate = client.speech_to_text_translate_job.create_job(
                        model="saaras:v2.5",
                        with_diarization=True,
                        num_speakers=2,
                        prompt="Translate all speech to English",
                    )
                    if candidate.upload_files([tmp_path]):
                        job = candidate
                        break
                    logger.warning(
                        "[%s] Upload attempt %d/%d reported failure",
                        callid, attempt + 1, _MAX_UPLOAD_RETRIES,
                    )
                except RuntimeError as exc:
                    # Only a 403 is retried, and only while attempts remain;
                    # everything else is surfaced to the caller.
                    if "403" in str(exc) and attempt < _MAX_UPLOAD_RETRIES - 1:
                        logger.warning(
                            "[%s] Upload attempt %d/%d failed (403), retrying…",
                            callid, attempt + 1, _MAX_UPLOAD_RETRIES,
                        )
                        time.sleep(1)
                    else:
                        raise

            if job is None:
                raise RuntimeError(f"All {_MAX_UPLOAD_RETRIES} upload attempts failed")

            # ── 3. Start job and wait ────────────────────────────────────────
            job.start()
            logger.info("[%s] Sarvam job started, waiting (max %ds)…", callid, _JOB_TIMEOUT_S)
            status = job.wait_until_complete(
                poll_interval=_POLL_INTERVAL_S, timeout=_JOB_TIMEOUT_S
            )

            if not job.is_successful():
                raise RuntimeError(f"Sarvam job did not succeed (status={status})")

            # ── 4. Retrieve result ───────────────────────────────────────────
            # Fall back to "0.json" when the status carries no successful
            # detail with outputs.
            output_file = "0.json"
            if status.job_details:
                for detail in status.job_details:
                    if detail.state == "Success" and detail.outputs:
                        output_file = detail.outputs[0].file_name
                        break

            links = client.speech_to_text_translate_job.get_download_links(
                job_id=job.job_id, files=[output_file]
            )
            dl_url = None
            if links.download_urls and output_file in links.download_urls:
                dl_url = links.download_urls[output_file].file_url

            if not dl_url:
                raise RuntimeError("No download URL in Sarvam job result")

            result_resp = requests.get(dl_url, timeout=60)
            if result_resp.status_code != 200:
                raise RuntimeError(f"Result download failed: HTTP {result_resp.status_code}")

            logger.info("[%s] Transcription complete.", callid)
            return _parse_sarvam_result(result_resp.json())

        finally:
            # The temp file was created with delete=False; remove it here.
            if tmp_path and os.path.exists(tmp_path):
                os.unlink(tmp_path)
