from __future__ import annotations
import logging
import time
from typing import Any, Dict, List, Optional
from config.settings import settings
from db import raw_calls, stt_jobs, bid_config
from stt.factory import get_stt_provider

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class TranscriptionWorker:
    """Poll for call recordings and run them through the configured STT provider.

    One cycle (``_run_one_cycle``) resets stale "processing" jobs, registers
    newly discovered calls as jobs for every active bid, then processes
    pending jobs followed by retryable failed jobs.
    """

    # Passed to ``stt_jobs.reset_stale_processing_jobs`` at the start of each cycle.
    STALE_AFTER_MINUTES = 30
    # Discovery asks ``raw_calls`` for up to ``batch * DISCOVERY_LIMIT_FACTOR`` calls per bid.
    DISCOVERY_LIMIT_FACTOR = 5
    # Column names used when a bid config does not provide its own.
    DEFAULT_ID_COL = "id"
    DEFAULT_URL_COL = "recording_url"

    def __init__(self):
        """Load the STT provider and settings, then prepare the bid config table.

        Both preparation steps are best-effort: their failures are logged and
        do not prevent the worker from being constructed.
        """
        self.stt = get_stt_provider()
        self.batch_size = settings.batch_size
        self.max_retries = settings.max_retries
        self.bid_whitelist = settings.bid_whitelist
        self._ensure_bid_config_table()
        self._auto_register_bids()

    def _ensure_bid_config_table(self) -> None:
        """Call ``bid_config.ensure_table()``; log a warning instead of raising."""
        try:
            bid_config.ensure_table()
        except Exception as exc:
            logger.warning("Could not ensure stt_pipeline_bid_config table: %s", exc)

    def _auto_register_bids(self) -> None:
        """Discover all *_raw_calls tables and insert disabled rows for unknown bids.

        Every bid returned by ``raw_calls.get_all_bids()`` is handed to
        ``bid_config.ensure_bid_registered``. The logged count is the number of
        bids discovered, not necessarily the number newly inserted. Failures
        are logged as warnings and never raised.
        """
        try:
            all_bids = raw_calls.get_all_bids()
            for bid in all_bids:
                bid_config.ensure_bid_registered(bid)
            if all_bids:
                logger.info("Auto-registered %d bid(s) in stt_pipeline_bid_config", len(all_bids))
        except Exception as exc:
            logger.warning("Auto-register bids failed: %s", exc)

    def run_forever(self):
        """Run cycles in an endless loop, sleeping between them.

        An exception escaping a cycle is logged with its traceback and the
        loop continues after the usual sleep.
        """
        logger.info("Worker started | provider=%s | poll=%ds | batch=%d",
                    self.stt.provider_name, settings.poll_interval_seconds, self.batch_size)
        while True:
            try:
                self._run_one_cycle()
            except Exception as exc:
                logger.exception("Unexpected error in worker cycle: %s", exc)
            self._log_stats()
            logger.info("Sleeping %ds…", settings.poll_interval_seconds)
            time.sleep(settings.poll_interval_seconds)

    def run_once(self):
        """Run a single cycle and log job statistics; cycle errors propagate."""
        self._run_one_cycle()
        self._log_stats()

    def _whitelist_configs(self) -> List[Dict[str, Any]]:
        """Build default bid configs from ``self.bid_whitelist``.

        Returns:
            One config dict per whitelisted bid using the default column names
            and the worker's batch size; an empty list when the whitelist is
            empty or ``None``.
        """
        return [
            {"bid": b, "raw_calls_id_col": self.DEFAULT_ID_COL,
             "raw_calls_url_col": self.DEFAULT_URL_COL,
             "batch_size": self.batch_size}
            for b in (self.bid_whitelist or [])
        ]

    def _get_active_bid_configs(self) -> List[Dict[str, Any]]:
        """Return enabled bids from DB. Falls back to env BID_WHITELIST if DB fails.

        The whitelist is also used when the DB query succeeds but returns no
        enabled bids. Returns an empty list when neither source yields a bid.
        """
        # Only the DB call is guarded: it is the one step expected to fail.
        try:
            enabled = bid_config.get_enabled_bids()
        except Exception as exc:
            logger.warning("DB bid config query failed (%s); falling back to BID_WHITELIST", exc)
            return self._whitelist_configs()

        if enabled:
            return enabled
        # Nothing enabled in DB yet — fall back to whitelist if set
        if self.bid_whitelist:
            logger.info("No DB-enabled bids found; using BID_WHITELIST fallback: %s",
                        self.bid_whitelist)
            return self._whitelist_configs()
        return []

    def _discover_for_config(self, cfg: Dict[str, Any]) -> int:
        """Run discovery for one bid config, isolating its failures.

        Missing or falsy column names and batch size fall back to the defaults.

        Returns:
            The number of jobs inserted for this bid, or 0 if discovery raised.
        """
        try:
            bid = cfg["bid"]
            id_col = cfg.get("raw_calls_id_col", self.DEFAULT_ID_COL) or self.DEFAULT_ID_COL
            url_col = cfg.get("raw_calls_url_col", self.DEFAULT_URL_COL) or self.DEFAULT_URL_COL
            batch = cfg.get("batch_size") or self.batch_size
            return self._discover_new_calls(bid, id_col=id_col, url_col=url_col, batch=batch)
        except Exception:
            # A single broken bid must not block the other bids or the
            # processing of jobs that are already queued.
            logger.exception("Discovery failed for bid %s; skipping it this cycle",
                             cfg.get("bid"))
            return 0

    def _run_one_cycle(self):
        """Execute one full cycle: reset stale jobs, discover calls, process jobs.

        Discovery errors are contained per bid. Errors from the job store
        outside ``_process_job``'s own handling propagate to the caller.
        """
        reset = stt_jobs.reset_stale_processing_jobs(stale_after_minutes=self.STALE_AFTER_MINUTES)
        if reset:
            logger.warning("Reset %d stale processing jobs to pending", reset)

        active_configs = self._get_active_bid_configs()
        if not active_configs:
            logger.info("No enabled bids — nothing to discover this cycle")
        total_new = sum(self._discover_for_config(cfg) for cfg in active_configs)
        if total_new:
            logger.info("Registered %d new call(s) as pending", total_new)

        for job in stt_jobs.get_pending_jobs(limit=self.batch_size):
            self._process_job(job)

        for job in stt_jobs.get_retryable_failed_jobs(max_retries=self.max_retries, limit=self.batch_size):
            # Flip the failed job back to pending before re-running it.
            # NOTE(review): this reaches into a private helper of stt_jobs.
            stt_jobs._update_status(job["id"], stt_jobs.STATUS_PENDING)
            self._process_job(job)

    @staticmethod
    def _build_metadata(call: Dict[str, Any]) -> Dict[str, str]:
        """Return the call's extra fields as strings.

        ``call_id`` and ``recording_url`` are excluded, as are fields whose
        value is ``None``.
        """
        return {k: str(v) for k, v in call.items()
                if k not in ("call_id", "recording_url") and v is not None}

    def _discover_new_calls(self, bid: str, id_col: str = "id",
                            url_col: str = "recording_url", batch: int = 0) -> int:
        """Insert a job for each not-yet-seen call of ``bid`` that has a recording URL.

        Args:
            bid: Identifier of the bid whose calls are scanned.
            id_col: Name of the id column, forwarded to ``raw_calls.get_new_calls``.
            url_col: Name of the URL column, forwarded to ``raw_calls.get_new_calls``.
            batch: Per-bid batch size; 0 (or any falsy value) means use
                ``self.batch_size``.

        Returns:
            The number of jobs inserted.
        """
        limit = (batch or self.batch_size) * self.DISCOVERY_LIMIT_FACTOR
        already_seen = stt_jobs.get_all_seen_call_ids(bid)
        new_calls = raw_calls.get_new_calls(
            bid=bid, already_seen_ids=already_seen, limit=limit,
            id_col=id_col, url_col=url_col,
        )
        count = 0
        for call in new_calls:
            # Rows are read through the "call_id" / "recording_url" keys
            # regardless of the configured column names.
            if not call.get("recording_url"):
                continue
            stt_jobs.insert_job(bid=bid, call_id=call["call_id"],
                                recording_url=call["recording_url"],
                                metadata=self._build_metadata(call))
            count += 1
        return count

    @staticmethod
    def _count_speakers(segments: Optional[List[Dict[str, Any]]]) -> int:
        """Return the number of distinct ``speaker_id`` values in ``segments``.

        ``None`` or an empty sequence yields 0.
        """
        return len({seg["speaker_id"] for seg in (segments or ())})

    @staticmethod
    def _attempt_number(job: Dict[str, Any]) -> int:
        """Return the 1-based number of the current attempt for ``job``.

        A missing or ``None`` ``retry_count`` is treated as 0.
        """
        return (job.get("retry_count") or 0) + 1

    def _process_job(self, job: Dict[str, Any]):
        """Transcribe one job's recording and store the outcome.

        The job is marked processing, then transcribed. On success it is
        marked done with the transcript, segments, speaker count, duration and
        provider. Any exception raised after the job was marked processing is
        logged and the job is marked failed with its retry counter incremented.

        Args:
            job: Job row; must contain ``id``, ``bid``, ``call_id`` and
                ``recording_url``. ``retry_count`` is optional.
        """
        job_id, bid, call_id = job["id"], job["bid"], job["call_id"]
        recording_url = job["recording_url"]
        logger.info("[%s/%s] Starting STT (job_id=%d)", bid, call_id, job_id)
        stt_jobs.mark_processing(job_id)
        try:
            result = self.stt.transcribe(audio_url=recording_url, callid=call_id)
            # A result without segments must not turn a successful
            # transcription into a failure.
            segments = result.speaker_segments or []
            speaker_count = self._count_speakers(segments)
            stt_jobs.mark_done(job_id=job_id, transcript=result.transcript,
                               speaker_segments=segments,
                               speaker_count=speaker_count, duration=result.duration,
                               stt_provider=result.provider)
            logger.info("[%s/%s] Done — %.1fs, %d speakers, %d segments",
                        bid, call_id, result.duration, speaker_count, len(segments))
        except Exception as exc:
            # The attempt number is computed None-safely so that this handler
            # cannot itself raise before the job is marked failed.
            logger.error("[%s/%s] Failed (retry %d/%d): %s",
                         bid, call_id, self._attempt_number(job), self.max_retries, exc)
            stt_jobs.mark_failed(job_id, error=str(exc), increment_retry=True)

    def _log_stats(self):
        """Log per-status job counts; a failure to fetch them is logged, not raised."""
        try:
            stats = stt_jobs.get_job_stats()
            logger.info("Stats: %s", " | ".join(f"{s}={c}" for s, c in sorted(stats.items())))
        except Exception as exc:
            logger.warning("Could not fetch job stats: %s", exc)
