from __future__ import annotations
import logging
import os
import tempfile
import time
from typing import Any, Dict, List
import requests
from .base import BaseSTT, STTResult
from .concurrency import SarvamConcurrencySlot
from .audio_chunker import wav_duration_seconds

logger = logging.getLogger(__name__)

# --- Retry / polling knobs (everything but the upload retry count is env-overridable) ---
_MAX_UPLOAD_RETRIES = 5
_MAX_BATCH_RETRIES = int(os.getenv("SARVAM_BATCH_RETRIES", "5"))
_MAX_RATE_LIMIT_RETRIES = int(os.getenv("SARVAM_RATE_LIMIT_RETRIES", "8"))
_RATE_LIMIT_BACKOFF_S = float(os.getenv("SARVAM_RATE_LIMIT_BACKOFF_SECONDS", "8"))
_POLL_INTERVAL_S = int(os.getenv("SARVAM_POLL_INTERVAL_SECONDS", "10"))
# Job timeout: these env vars are consulted in order and the first non-empty one
# wins; "900" is used when none of them is set.
_JOB_TIMEOUT_S = int(
    next(
        filter(
            None,
            map(
                os.getenv,
                (
                    "SARVAM_JOB_TIMEOUT_SECONDS",
                    "SARVAM_JOB_TIMEOUT_S",
                    "SARVAM_PENDING_TIMEOUT_S",
                ),
            ),
        ),
        "900",
    )
)
# "batch" = Sarvam batch API with speaker diarization (required for proper Speaker 1/2 labels).
_STT_MODE = os.getenv("SARVAM_STT_MODE", "batch").strip().lower()
_MIN_AUDIO_SECONDS = float(os.getenv("SARVAM_MIN_AUDIO_SECONDS", "0.5"))
# Chunk length assumed by the fake (chunk-aligned) diarization heuristic.
_CHUNK_SECONDS = float(os.getenv("SARVAM_SYNC_CHUNK_SECONDS", "25"))
_MIN_SEGMENTS_LONG_CALL = int(os.getenv("SARVAM_MIN_DIARIZATION_SEGMENTS", "4"))
# Accepted truthy spellings: "1", "true", "yes" (case-insensitive, whitespace-trimmed).
_ACCEPT_WEAK_DIARIZATION = (
    os.getenv("SARVAM_ACCEPT_WEAK_DIARIZATION", "1").strip().lower()
    in {"1", "true", "yes"}
)


class DiarizationQualityError(RuntimeError):
    """Raised when an STT result fails the diarization quality checks.

    ``validate_diarization`` raises it for an empty transcript, missing
    speaker segments, chunk-aligned "fake" speaker blocks, too few segments
    on a long call, or a single speaker on a call longer than 30 seconds.

    It subclasses ``RuntimeError``, so a generic ``except RuntimeError``
    handler will also catch it.
    """


def _speaker_numeric_id(speaker_id: str) -> int | None:
    try:
        return int(str(speaker_id).replace("speaker_", ""))
    except (TypeError, ValueError):
        return None


def _apply_display_speaker_labels(segments: List[Dict[str, Any]]) -> None:
    """Relabel ``seg["speaker"]`` in place as ``Speaker 1``, ``Speaker 2``, …

    Distinct numeric speaker ids are ranked in ascending order and mapped to
    consecutive 1-based labels, so the display is the same whether Sarvam
    numbers speakers from 0 or from 1. Segments whose id is not numeric are
    left untouched.
    """
    distinct_ids = sorted({
        speaker_num
        for speaker_num in (_speaker_numeric_id(seg.get("speaker_id", "")) for seg in segments)
        if speaker_num is not None
    })
    if not distinct_ids:
        return
    rank = {speaker_num: position for position, speaker_num in enumerate(distinct_ids, start=1)}
    for seg in segments:
        speaker_num = _speaker_numeric_id(seg.get("speaker_id", ""))
        if speaker_num is not None:
            seg["speaker"] = f"Speaker {rank[speaker_num]}"


def _parse_sarvam_result(result: Dict[str, Any]) -> STTResult:
    """Convert a Sarvam batch-job JSON payload into an :class:`STTResult`.

    Reads ``result["transcript"]`` for the plain text and
    ``result["diarized_transcript"]["entries"]`` for per-speaker segments.
    Entries that are not dicts, or whose text is empty after stripping, are
    skipped. Each kept segment gets:

    * ``speaker_id`` normalized to ``speaker_<n>`` (a missing/null id becomes ``speaker_0``);
    * start/end seconds stored under both ``start``/``end`` and ``start_time``/``end_time``;
    * ``role``: ``"agent"`` for the lowest numeric speaker id, ``"customer"`` otherwise;
    * ``speaker``: a 1-based display label (``"Speaker N"``).

    Args:
        result: The decoded JSON returned by the Sarvam job output file.

    Returns:
        An ``STTResult`` whose ``duration`` is the latest segment end time
        (0.0 when there are no segments).
    """
    transcript_text: str = result.get("transcript") or ""
    diarized = result.get("diarized_transcript") or {}
    # If "entries" is absent, null, or not a list, treat it as empty rather
    # than crashing while iterating it.
    entries = diarized.get("entries") if isinstance(diarized, dict) else None
    if not isinstance(entries, list):
        entries = []

    speaker_segments: List[Dict[str, Any]] = []
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        text = (entry.get("transcript") or "").strip()
        if not text:
            continue
        raw = entry.get("speaker_id")
        # A null id would otherwise become the string "speaker_None".
        raw_id = "0" if raw is None else str(raw)
        speaker_id = raw_id if raw_id.startswith("speaker_") else "speaker_" + raw_id
        num = speaker_id.replace("speaker_", "")
        start = float(entry.get("start_time_seconds") or 0)
        end = float(entry.get("end_time_seconds") or 0)
        speaker_segments.append({
            "speaker": f"Speaker {num}",
            "speaker_id": speaker_id,
            "text": text,
            "start": start,
            "end": end,
            "start_time": start,
            "end_time": end,
        })

    if speaker_segments:
        ids = [_speaker_numeric_id(seg["speaker_id"]) for seg in speaker_segments]
        numeric_ids = [n for n in ids if n is not None]
        agent_id = min(numeric_ids) if numeric_ids else 0
        for seg, n in zip(speaker_segments, ids):
            # NOTE(review): assumes the lowest-numbered speaker is the agent — confirm
            # against how Sarvam orders speakers.
            seg["role"] = "agent" if n == agent_id else "customer"
        _apply_display_speaker_labels(speaker_segments)

    # Use the latest end time rather than the last entry's, in case entries
    # are not sorted by time.
    duration = max((seg["end"] for seg in speaker_segments), default=0.0)
    return STTResult(
        transcript=transcript_text,
        speaker_segments=speaker_segments,
        duration=duration,
        provider="sarvam",
    )


def _is_rate_limit_error(exc: Exception) -> bool:
    msg = str(exc).lower()
    return (
        "429" in msg
        or "rate limit" in msg
        or "rate_limit" in msg
        or getattr(exc, "status_code", None) == 429
    )


def _looks_like_chunk_aligned_blocks(segments: List[Dict[str, Any]], duration: float) -> bool:
    """Detect fake diarization from 25s chunked sync (starts at 0, 25, 50… with long blocks)."""
    if duration <= _CHUNK_SECONDS or len(segments) < 2:
        return False
    expected_chunks = max(2, int(duration / _CHUNK_SECONDS + 0.999))
    if len(segments) > expected_chunks + 1:
        return False
    aligned = 0
    long_blocks = 0
    for seg in segments:
        start = float(seg.get("start") or seg.get("start_time") or 0)
        end = float(seg.get("end") or seg.get("end_time") or start)
        block_len = end - start
        if abs(start % _CHUNK_SECONDS) < 1.0 or abs(start - round(start / _CHUNK_SECONDS) * _CHUNK_SECONDS) < 1.0:
            aligned += 1
        if block_len >= _CHUNK_SECONDS - 3:
            long_blocks += 1
    return aligned >= len(segments) * 0.75 and long_blocks >= len(segments) * 0.5


def validate_diarization(result: STTResult, audio_duration: float, callid: str) -> None:
    """Raise :class:`DiarizationQualityError` unless *result* shows real batch diarization.

    Checks run in this order; the first failure raises:

    1. the transcript text is non-empty;
    2. audio shorter than ``_MIN_AUDIO_SECONDS`` skips every remaining check;
    3. at least one speaker segment exists;
    4. the segments are not chunk-aligned blocks (``_looks_like_chunk_aligned_blocks``);
    5. audio longer than 45s has at least ``_MIN_SEGMENTS_LONG_CALL`` segments;
    6. audio longer than 30s has at least two distinct numeric speaker ids.
    """
    segments = result.speaker_segments or []
    if not (result.transcript or "").strip():
        raise DiarizationQualityError(f"[{callid}] Empty transcript")

    too_short_to_judge = audio_duration < _MIN_AUDIO_SECONDS
    if too_short_to_judge:
        return

    if not segments:
        raise DiarizationQualityError(
            f"[{callid}] No speaker segments (duration={audio_duration:.1f}s)"
        )

    count = len(segments)
    if _looks_like_chunk_aligned_blocks(segments, audio_duration):
        raise DiarizationQualityError(
            f"[{callid}] Chunk-aligned fake diarization ({count} blocks, "
            f"duration={audio_duration:.1f}s)"
        )

    is_long_call = audio_duration > 45
    if is_long_call and count < _MIN_SEGMENTS_LONG_CALL:
        raise DiarizationQualityError(
            f"[{callid}] Too few speaker segments ({count}) for "
            f"{audio_duration:.1f}s audio — likely missing diarization"
        )

    distinct_speakers = {
        _speaker_numeric_id(seg.get("speaker_id", "")) for seg in segments
    } - {None}
    if audio_duration > 30 and len(distinct_speakers) < 2:
        raise DiarizationQualityError(
            f"[{callid}] Only one speaker detected for {audio_duration:.1f}s call"
        )


class SarvamSTT(BaseSTT):
    """Speech-to-text provider backed by Sarvam's speech-to-text-translate batch jobs.

    Each transcription downloads the audio to a temporary ``.wav`` file, runs
    a ``saaras:v2.5`` job with ``with_diarization=True`` and ``num_speakers=2``,
    downloads the JSON result and parses it into an :class:`STTResult`.
    """

    def __init__(self, api_key: str):
        """Store the Sarvam subscription key.

        Raises:
            ValueError: If *api_key* is empty.
        """
        if not api_key:
            raise ValueError("SARVAM_SUBSCRIPTION_KEY is required")
        self._api_key = api_key

    @property
    def provider_name(self) -> str:
        """Identifier reported for this provider (``"sarvam"``)."""
        return "sarvam"

    def transcribe(self, audio_url: str, callid: str) -> STTResult:
        """Transcribe the audio at *audio_url* while holding a Sarvam concurrency slot."""
        # Local import: the SDK is only loaded when a transcription actually runs.
        from sarvamai import SarvamAI

        client = SarvamAI(api_subscription_key=self._api_key)
        with SarvamConcurrencySlot(callid):
            return self._transcribe_locked(client, audio_url, callid)

    @staticmethod
    def _redact_url(url: str) -> str:
        """Return *url* without userinfo, query string or fragment, for logging.

        Download URLs may embed credentials (e.g. presigned query parameters),
        which should not be written to logs verbatim.
        """
        from urllib.parse import urlsplit, urlunsplit

        try:
            parts = urlsplit(url)
        except ValueError:
            return "<unparseable url>"
        netloc = parts.netloc.rpartition("@")[2]
        return urlunsplit((parts.scheme, netloc, parts.path, "", ""))

    def _transcribe_batch_once(self, client, tmp_path: str, callid: str) -> STTResult:
        """Run one create/upload/start/wait/download cycle and parse the result.

        Upload is attempted up to ``_MAX_UPLOAD_RETRIES`` times with a fresh job
        each time; a 403 ``RuntimeError`` or a falsy ``upload_files`` return is
        retried after 1s.

        Raises:
            RuntimeError: If no upload succeeded, the job did not succeed, or the
                result could not be located or downloaded.
        """
        job = None
        uploaded = False
        for attempt in range(_MAX_UPLOAD_RETRIES):
            try:
                job = client.speech_to_text_translate_job.create_job(
                    model="saaras:v2.5",
                    with_diarization=True,
                    num_speakers=2,
                    prompt="Translate all speech to English",
                )
                if job.upload_files([tmp_path]):
                    uploaded = True
                    break
            except RuntimeError as exc:
                if "403" in str(exc) and attempt < _MAX_UPLOAD_RETRIES - 1:
                    logger.warning("[%s] Upload attempt %d failed (403), retrying…", callid, attempt + 1)
                    time.sleep(1)
                    continue
                raise
            # upload_files() returned a falsy value without raising: retry with a
            # fresh job. Previously the final falsy attempt left `job` set and the
            # un-uploaded job was started anyway.
            if attempt < _MAX_UPLOAD_RETRIES - 1:
                logger.warning("[%s] Upload attempt %d was not accepted, retrying…", callid, attempt + 1)
                time.sleep(1)
        if not uploaded or job is None:
            raise RuntimeError(f"All {_MAX_UPLOAD_RETRIES} upload attempts failed")

        job.start()
        logger.info("[%s] Sarvam batch job %s started (diarization=on), waiting…", callid, job.job_id)
        status = job.wait_until_complete(poll_interval=_POLL_INTERVAL_S, timeout=_JOB_TIMEOUT_S)
        if not job.is_successful():
            state = getattr(status, "job_state", None) or getattr(status, "state", None) or status
            raise RuntimeError(f"Sarvam job did not succeed (status={state})")

        # Fallback output name when no successful job detail lists an output file.
        output_file = "0.json"
        for detail in getattr(status, "job_details", None) or []:
            if detail.state == "Success" and detail.outputs:
                output_file = detail.outputs[0].file_name
                break
        links = client.speech_to_text_translate_job.get_download_links(
            job_id=job.job_id, files=[output_file]
        )
        dl_url = None
        if links.download_urls and output_file in links.download_urls:
            dl_url = links.download_urls[output_file].file_url
        if not dl_url:
            raise RuntimeError("No download URL in Sarvam job result")

        result_resp = requests.get(dl_url, timeout=120)
        if result_resp.status_code != 200:
            raise RuntimeError(f"Result download failed: HTTP {result_resp.status_code}")
        logger.info("[%s] Batch transcription download complete.", callid)
        return _parse_sarvam_result(result_resp.json())

    def _check_diarization(self, result: STTResult, audio_duration: float, callid: str) -> STTResult:
        """Validate diarization; return *result* if it passes or weak results are accepted.

        Raises:
            DiarizationQualityError: If validation fails and weak diarization is not
                accepted (``_ACCEPT_WEAK_DIARIZATION`` off, or an empty transcript).
        """
        segments = result.speaker_segments or []
        try:
            validate_diarization(result, audio_duration, callid)
        except DiarizationQualityError as diar_exc:
            if not (_ACCEPT_WEAK_DIARIZATION and (result.transcript or "").strip()):
                raise
            logger.warning(
                "[%s] Weak diarization accepted (%.1fs audio, %s segments): %s",
                callid,
                audio_duration,
                len(segments),
                diar_exc,
            )
        else:
            speaker_count = len(
                {_speaker_numeric_id(s.get("speaker_id", "")) for s in segments} - {None}
            )
            logger.info(
                "[%s] Diarization OK: %s segments, %.1fs, %s speaker(s)",
                callid,
                len(segments),
                result.duration or audio_duration,
                speaker_count,
            )
        return result

    def _transcribe_batch(self, client, tmp_path: str, callid: str, audio_duration: float) -> STTResult:
        """Run batch transcription with retries, then check diarization quality.

        The overall attempt budget is ``max(_MAX_BATCH_RETRIES, _MAX_RATE_LIMIT_RETRIES)``.
        Rate-limit failures are capped at ``_MAX_RATE_LIMIT_RETRIES`` attempts and
        use exponential backoff (capped at 90s). Other failures are capped at
        ``_MAX_BATCH_RETRIES`` attempts and use linear backoff (capped at 30s).
        Diarization validation runs outside the retry loop, so a
        :class:`DiarizationQualityError` is never retried.
        """
        attempts = max(_MAX_BATCH_RETRIES, _MAX_RATE_LIMIT_RETRIES)
        rate_limit_failures = 0
        other_failures = 0
        for attempt in range(attempts):
            try:
                result = self._transcribe_batch_once(client, tmp_path, callid)
            except DiarizationQualityError:
                raise
            except Exception as exc:
                is_last = attempt >= attempts - 1
                if _is_rate_limit_error(exc):
                    rate_limit_failures += 1
                    if not is_last and rate_limit_failures < _MAX_RATE_LIMIT_RETRIES:
                        wait_s = min(90.0, _RATE_LIMIT_BACKOFF_S * (2 ** attempt))
                        logger.warning(
                            "[%s] Batch rate limit (attempt %s/%s), retrying in %.0fs…",
                            callid,
                            attempt + 1,
                            attempts,
                            wait_s,
                        )
                        time.sleep(wait_s)
                        continue
                    raise
                # Non-rate-limit failures get their own (smaller) budget; previously
                # they were retried for the full max() budget.
                other_failures += 1
                if not is_last and other_failures < _MAX_BATCH_RETRIES:
                    wait_s = min(30.0, 3 * (attempt + 1))
                    logger.warning(
                        "[%s] Batch failed (attempt %s/%s): %s — retry in %.0fs",
                        callid,
                        attempt + 1,
                        attempts,
                        exc,
                        wait_s,
                    )
                    time.sleep(wait_s)
                    continue
                raise
            return self._check_diarization(result, audio_duration, callid)
        # Only reachable when the attempt budget is zero or negative.
        raise RuntimeError("Sarvam batch transcription failed")

    def _transcribe_locked(self, client, audio_url: str, callid: str) -> STTResult:
        """Download *audio_url* to a temp WAV file, sanity-check it and run batch STT.

        Called by ``transcribe`` while the Sarvam concurrency slot is held. The
        temporary file is removed on every exit path.

        Raises:
            RuntimeError: If the download fails or the audio is shorter than
                ``_MIN_AUDIO_SECONDS``; plus anything raised by batch transcription.
        """
        tmp_path = None
        try:
            logger.info("[%s] Sarvam STT mode=%s (diarization via batch API)", callid, _STT_MODE)
            logger.info("[%s] Downloading audio: %s", callid, self._redact_url(audio_url))
            # Stream straight to disk so long recordings are not buffered in memory.
            with requests.get(audio_url, timeout=120, stream=True) as resp:
                if resp.status_code != 200:
                    raise RuntimeError(f"Audio download failed: HTTP {resp.status_code}")
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    # Record the path before writing so `finally` cleans up even if a write fails.
                    tmp_path = tmp.name
                    for chunk in resp.iter_content(chunk_size=64 * 1024):
                        if chunk:
                            tmp.write(chunk)

            duration = wav_duration_seconds(tmp_path)
            if duration < _MIN_AUDIO_SECONDS:
                raise RuntimeError(f"Audio too short or empty ({duration:.2f}s)")

            if _STT_MODE != "batch":
                logger.warning(
                    "[%s] SARVAM_STT_MODE=%s ignored — batch+diarization is required",
                    callid,
                    _STT_MODE,
                )
            return self._transcribe_batch(client, tmp_path, callid, duration)
        finally:
            if tmp_path and os.path.exists(tmp_path):
                os.unlink(tmp_path)
