"""
Sarvam AI Speech-to-Text provider.

Uses the `speech_to_text_translate_job` batch API (model saaras:v2.5) with
speaker diarisation.  Audio is downloaded to a temp file, uploaded to Sarvam,
polled until complete, and the result JSON is parsed into an STTResult.
"""
from __future__ import annotations

import logging
import os
import tempfile
import time
from typing import Any, Dict, List

import requests

from .base import BaseSTT, STTResult

logger = logging.getLogger(__name__)

# Upper bound on create-job + upload attempts made by SarvamSTT.transcribe.
_MAX_UPLOAD_RETRIES = 5
# Passed as `poll_interval` to `job.wait_until_complete` (seconds, per the _S suffix).
_POLL_INTERVAL_S = 3
# Passed as `timeout` to `job.wait_until_complete`; logged as the max wait in seconds.
_JOB_TIMEOUT_S = 300


def _parse_sarvam_result(result: Dict[str, Any]) -> STTResult:
    """Convert raw Sarvam JSON output to a normalised STTResult.

    Args:
        result: Decoded JSON document of a Sarvam job output.  The keys read
            are ``transcript`` and ``diarized_transcript.entries``; from each
            entry ``speaker_id``, ``transcript``, ``start_time_seconds`` and
            ``end_time_seconds`` are used.

    Returns:
        An ``STTResult`` with ``provider="sarvam"``.  Each speaker segment is
        a dict with ``speaker``, ``speaker_id``, ``text``, ``start``/``end``
        (duplicated as ``start_time``/``end_time``) and ``role``.  Entries
        whose transcript is empty after stripping are dropped.  ``duration``
        is the largest segment end time, or ``0.0`` without segments.
    """
    prefix = "speaker_"
    transcript_text: str = result.get("transcript") or ""

    # Both the container and its "entries" value may be missing or null.
    diarized = result.get("diarized_transcript")
    entries = (diarized.get("entries") or []) if isinstance(diarized, dict) else []

    speaker_segments: List[Dict[str, Any]] = []
    # Parallel to speaker_segments: parsed speaker number, None if non-numeric.
    speaker_nums: List[Any] = []
    for entry in entries:
        if not isinstance(entry, dict):
            continue

        text = (entry.get("transcript") or "").strip()
        if not text:
            continue

        # A null speaker_id is treated like a missing one (speaker "0").
        raw = entry.get("speaker_id")
        raw_id = "0" if raw is None else str(raw)
        # Strip only the leading prefix; the id may or may not carry it.
        num = raw_id[len(prefix):] if raw_id.startswith(prefix) else raw_id

        try:
            speaker_nums.append(int(num))
        except ValueError:
            speaker_nums.append(None)

        start = float(entry.get("start_time_seconds") or 0)
        end = float(entry.get("end_time_seconds") or 0)
        speaker_segments.append({
            "speaker": f"Speaker {num}",
            "speaker_id": prefix + num,
            "text": text,
            "start": start,
            "end": end,
            "start_time": start,
            "end_time": end,
        })

    # Assign roles: speaker with the lowest numeric ID is the agent.
    # Sarvam may use 0-indexed (speaker_0, speaker_1) or 1-indexed
    # (speaker_1, speaker_2).  Non-numeric IDs are always "customer".
    numeric = [n for n in speaker_nums if n is not None]
    min_num = min(numeric) if numeric else 0
    for seg, n in zip(speaker_segments, speaker_nums):
        seg["role"] = "agent" if n == min_num else "customer"

    # Use the latest end time rather than the last entry's, so an entry with
    # a missing end time or out-of-order entries cannot shrink the duration.
    duration = max((seg["end"] for seg in speaker_segments), default=0.0)
    return STTResult(
        transcript=transcript_text,
        speaker_segments=speaker_segments,
        duration=duration,
        provider="sarvam",
    )


class SarvamSTT(BaseSTT):
    """Sarvam AI batch transcription provider.

    Uses the ``speech_to_text_translate_job`` API of the ``sarvamai`` SDK:
    audio is downloaded to a temporary file, uploaded to a newly created job,
    and the job's JSON output is parsed into an ``STTResult``.
    """

    def __init__(self, api_key: str):
        """Store the Sarvam subscription key.

        Raises:
            ValueError: If *api_key* is empty.
        """
        if not api_key:
            raise ValueError("SARVAM_SUBSCRIPTION_KEY is required")
        self._api_key = api_key

    @property
    def provider_name(self) -> str:
        """Return the identifier of this provider (``"sarvam"``)."""
        return "sarvam"

    def transcribe(self, audio_url: str, callid: str) -> STTResult:
        """Download *audio_url*, submit to Sarvam, wait for result.

        Args:
            audio_url: HTTP(S) location of the audio file to transcribe.
            callid: Identifier used only to tag log messages.

        Returns:
            The parsed transcription as an ``STTResult``.

        Raises:
            RuntimeError: If the audio or result download does not return
                HTTP 200, if no upload attempt succeeds, if the job does not
                finish successfully, or if no download URL is returned.
            Exception: Any non-retryable error raised by the Sarvam SDK or
                ``requests`` is propagated unchanged.
        """
        from sarvamai import SarvamAI  # lazy import – not always installed

        client = SarvamAI(api_subscription_key=self._api_key)
        tmp_path: str | None = None

        try:
            # ── 1. Download audio ────────────────────────────────────────────
            logger.info("[%s] Downloading audio: %s", callid, audio_url)
            resp = requests.get(audio_url, timeout=60)
            if resp.status_code != 200:
                raise RuntimeError(
                    f"Audio download failed: HTTP {resp.status_code} for {audio_url}"
                )

            # Preserve the real file extension so Sarvam parses the format correctly.
            # OGG files saved as .wav get silently truncated by the Sarvam parser.
            url_path = audio_url.split("?")[0]  # strip query params
            ext = os.path.splitext(url_path)[-1].lower() or ".wav"
            if ext not in {".wav", ".ogg", ".mp3", ".mp4", ".m4a", ".flac", ".opus"}:
                ext = ".wav"

            with tempfile.NamedTemporaryFile(
                suffix=ext, delete=False, dir="/tmp"
            ) as tmp:
                # Record the path before writing so the `finally` block removes
                # the file even when the write itself fails.
                tmp_path = tmp.name
                tmp.write(resp.content)

            # ── 2. Create job and upload (retry on rate limits and Azure 403 bug) ──
            job = None
            uploaded = False
            last_exc: Exception | None = None
            for attempt in range(_MAX_UPLOAD_RETRIES):
                is_last_attempt = attempt == _MAX_UPLOAD_RETRIES - 1
                last_exc = None
                try:
                    job = client.speech_to_text_translate_job.create_job(
                        model="saaras:v2.5",
                        with_diarization=True,
                        num_speakers=2,
                        prompt="Translate all speech to English",
                    )
                    if job.upload_files([tmp_path]):
                        uploaded = True
                        break
                    # A falsy return means the upload did not go through; the
                    # next iteration starts over with a fresh job.
                    logger.warning(
                        "[%s] Upload attempt %d/%d did not succeed.",
                        callid, attempt + 1, _MAX_UPLOAD_RETRIES
                    )
                except Exception as exc:
                    last_exc = exc
                    exc_str = str(exc)
                    logger.debug("[%s] Caught exception type: %s, message: %s", callid, type(exc).__name__, exc_str)

                    # Handle Rate Limits (429)
                    if "429" in exc_str or getattr(exc, "status_code", None) == 429:
                        if is_last_attempt:
                            # No point backing off when nothing follows.
                            logger.warning(
                                "[%s] Rate limited (429) on final attempt %d/%d.",
                                callid, attempt + 1, _MAX_UPLOAD_RETRIES
                            )
                            break
                        wait_time = (2 ** attempt) + 1
                        logger.warning(
                            "[%s] Rate limited (429). Attempt %d/%d. Waiting %ds...",
                            callid, attempt + 1, _MAX_UPLOAD_RETRIES, wait_time
                        )
                        time.sleep(wait_time)
                        continue

                    # Handle Azure 403 bug
                    if "403" in exc_str and not is_last_attempt:
                        logger.warning(
                            "[%s] Upload attempt %d/%d failed (403), retrying…",
                            callid, attempt + 1, _MAX_UPLOAD_RETRIES
                        )
                        time.sleep(1)
                        continue

                    raise

            # Gate on the upload outcome, not on `job`: a job object exists
            # even when its upload failed, and must not be started then.
            if not uploaded or job is None:
                raise RuntimeError(
                    f"All {_MAX_UPLOAD_RETRIES} upload attempts failed"
                ) from last_exc

            # ── 3. Start job and wait ────────────────────────────────────────
            job.start()
            logger.info("[%s] Sarvam job started, waiting (max %ds)…", callid, _JOB_TIMEOUT_S)
            status = job.wait_until_complete(
                poll_interval=_POLL_INTERVAL_S, timeout=_JOB_TIMEOUT_S
            )

            if not job.is_successful():
                raise RuntimeError(f"Sarvam job did not succeed (status={status})")

            # ── 4. Retrieve result ───────────────────────────────────────────
            # Fall back to "0.json" when no successful detail names an output.
            output_file = "0.json"
            if status.job_details:
                for detail in status.job_details:
                    if detail.state == "Success" and detail.outputs:
                        output_file = detail.outputs[0].file_name
                        break

            links = client.speech_to_text_translate_job.get_download_links(
                job_id=job.job_id, files=[output_file]
            )
            dl_url = None
            if links.download_urls and output_file in links.download_urls:
                dl_url = links.download_urls[output_file].file_url

            if not dl_url:
                raise RuntimeError("No download URL in Sarvam job result")

            result_resp = requests.get(dl_url, timeout=60)
            if result_resp.status_code != 200:
                raise RuntimeError(f"Result download failed: HTTP {result_resp.status_code}")

            logger.info("[%s] Transcription complete.", callid)
            return _parse_sarvam_result(result_resp.json())

        finally:
            # The temp file is created with delete=False, so remove it here.
            if tmp_path and os.path.exists(tmp_path):
                os.unlink(tmp_path)
