"""Per-BID minimum call duration gate for STT workers."""

from __future__ import annotations

from datetime import datetime
from typing import Any, Dict, Optional, Tuple

from db.connection import get_connection


def _parse_dt(value) -> Optional[datetime]:
    """Coerce *value* into a naive ``datetime``; return ``None`` on failure.

    Accepts ``None``, a ``datetime`` instance, or anything whose ``str()`` is
    an ISO-8601 timestamp (a ``Z`` designator is rewritten to ``+00:00``
    before parsing).

    Timezone information is dropped, not converted: the wall-clock fields are
    kept as they are and ``tzinfo`` is cleared.
    NOTE(review): this makes values with different UTC offsets compare by
    wall-clock time only -- confirm that matches how timestamps are stored.
    """
    if value is None:
        return None

    if isinstance(value, datetime):
        moment = value
    else:
        text = str(value).replace("Z", "+00:00")
        try:
            moment = datetime.fromisoformat(text)
        except Exception:
            # Best effort: anything unparseable is treated as "no timestamp".
            return None

    if moment.tzinfo:
        return moment.replace(tzinfo=None)
    return moment


def load_min_duration_config(bid: str) -> Dict[str, Any]:
    """Fetch the minimum-call-duration settings stored for *bid*.

    Looks up one row in ``business_pipeline_config`` keyed by ``bid``.

    Returns a dict with two keys:
      * ``min_call_duration_s`` -- a non-negative int; ``0`` when the row or
        the column value is missing/falsy, and negative values are clamped
        to ``0``.
      * ``min_call_duration_effective_at`` -- the stored value passed through
        untouched, or ``None`` when absent.

    NOTE(review): relies on the cursor yielding mapping-style rows (``.get``
    is called on the fetched row) -- confirm against ``get_connection``.
    """
    sql = """
                SELECT min_call_duration_s, min_call_duration_effective_at
                FROM business_pipeline_config
                WHERE bid = %s
                LIMIT 1
                """
    with get_connection() as conn, conn.cursor() as cur:
        cur.execute(sql, (str(bid),))
        # No matching row -> behave as if every column were NULL.
        record = cur.fetchone() or {}

    seconds = int(record.get("min_call_duration_s") or 0)
    if seconds < 0:
        seconds = 0

    return {
        "min_call_duration_s": seconds,
        "min_call_duration_effective_at": record.get("min_call_duration_effective_at"),
    }


def min_duration_applies(call_starttime, effective_at) -> bool:
    """Tell whether the minimum-duration gate covers a call.

    * No usable ``effective_at`` -> the gate is off (``False``).
    * Usable ``effective_at`` but no usable ``call_starttime`` -> the gate is
      applied (``True``).
    * Otherwise the gate applies to calls that started at or after
      ``effective_at``.

    Both arguments are normalised with ``_parse_dt`` before comparison.
    """
    threshold = _parse_dt(effective_at)
    if threshold is None:
        return False

    started = _parse_dt(call_starttime)
    return started is None or started >= threshold


def should_skip_stt(
    bid: str,
    call_id: str,
    recording_url: str,
) -> Tuple[bool, str, Optional[float]]:
    """
    Return (skip, reason, probed_seconds).
    Skips before Sarvam when probed audio is below the configured minimum.

    Steps:
      1. Load the per-BID config; a minimum of 0 (or less) disables the gate.
      2. If no effective date is configured the gate cannot apply, so return
         without touching the per-BID calls table.
      3. Read ``call_starttime`` for *call_id* from ``<bid>_raw_calls`` and
         ask ``min_duration_applies`` whether the gate covers this call.
      4. Probe the recording duration; skip only when the probe returned a
         value strictly below the minimum.

    ``reason`` is an empty string whenever ``skip`` is False.
    ``probed_seconds`` is ``None`` when the recording was not probed or the
    probe returned ``None``.
    """
    from stt.recording_probe import probe_wav_duration_from_url

    cfg = load_min_duration_config(bid)
    min_s = int(cfg.get("min_call_duration_s") or 0)
    if min_s <= 0:
        return False, "", None

    effective_at = cfg.get("min_call_duration_effective_at")
    if effective_at is None:
        # min_duration_applies() returns False for a missing effective date
        # regardless of the call start time, so skip the extra DB round-trip.
        return False, "", None

    # Table names cannot be bound as query parameters. Double any embedded
    # backtick so the value cannot terminate the quoted identifier.
    table = f"{bid}_raw_calls".replace("`", "``")
    with get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                f"SELECT call_starttime FROM `{table}` WHERE callid = %s LIMIT 1",
                (str(call_id),),
            )
            # No matching call row -> start time is unknown (None).
            row = cur.fetchone() or {}
    call_starttime = row.get("call_starttime")

    if not min_duration_applies(call_starttime, effective_at):
        return False, "", None

    probed = probe_wav_duration_from_url(recording_url)
    if probed is None:
        # No duration available: let the call through rather than skip it.
        return False, "", None
    if probed < min_s:
        return True, f"actual audio {probed:.1f}s < min {min_s}s", probed
    return False, "", probed
