from __future__ import annotations
import json
import logging
import pika
from typing import Dict, Any
import time

from config.settings import settings
from db.connection import get_connection
from db import stt_jobs
from stt.factory import get_stt_provider
from stt.concurrency import SarvamSlotTimeoutError

# Module-level logger used by RabbitMQTranscriptionWorker for all of its log output.
logger: logging.Logger = logging.getLogger(__name__)


class RabbitMQTranscriptionWorker:

    def __init__(self):
        """Create the worker with its STT provider and queue settings.

        No broker connection is opened here; ``connection`` and ``channel``
        stay ``None`` until ``_connect()`` runs (``start()`` calls it).
        """
        self.stt = get_stt_provider()
        self.max_retries = settings.max_retries
        self.queue_name = settings.rabbitmq_queue
        # Broker handles: populated by _connect(), reset to None by _close().
        self.connection = None
        self.channel = None

    def _connect(self):
        """(Re)open the consumer connection and declare the durable STT queue.

        Any previously held channel/connection is torn down first via
        ``_close()``. Afterwards ``self.connection`` and ``self.channel`` are
        set and ``self.queue_name`` has been declared with ``durable=True``.
        The connection negotiates a 600-second heartbeat.
        """
        self._close()
        broker_params = pika.ConnectionParameters(
            host=settings.rabbitmq_host,
            port=settings.rabbitmq_port,
            credentials=pika.PlainCredentials(
                settings.rabbitmq_user,
                settings.rabbitmq_password
            ),
            heartbeat=600,
        )
        self.connection = pika.BlockingConnection(broker_params)
        self.channel = self.connection.channel()
        self.channel.queue_declare(queue=self.queue_name, durable=True)

        provider_name = self.stt.provider_name
        logger.info(
            "RabbitMQ worker connected | queue=%s | provider=%s",
            self.queue_name,
            provider_name
        )

    def _close(self):
        """Best-effort teardown of the channel, then the connection.

        Each handle is closed only if it is set and reports ``is_open``.
        Errors raised while closing are ignored because the handles are
        discarded either way. Both attributes end up ``None``, so calling
        this repeatedly is harmless.
        """
        for handle in (self.channel, self.connection):
            try:
                if handle and handle.is_open:
                    handle.close()
            except Exception:
                pass
        self.channel = None
        self.connection = None

    # ---------------------------------------------------

    def start(self):
        """Consume the STT queue until interrupted, reconnecting after failures.

        Each pass does the following:

        1. Reconnects.
        2. Sets ``prefetch_count=1``, so at most one unacknowledged delivery
           is outstanding.
        3. Registers ``_callback``.
        4. Blocks in ``start_consuming()``.

        Any exception other than ``KeyboardInterrupt`` is logged. The broker
        handles are then torn down and a new pass starts after 5 seconds.
        ``KeyboardInterrupt`` ends the loop, and the connection is closed
        before returning.
        """
        retry_delay_s = 5
        keep_running = True
        while keep_running:
            try:
                self._connect()
                channel = self.channel
                channel.basic_qos(prefetch_count=1)
                channel.basic_consume(
                    queue=self.queue_name,
                    on_message_callback=self._callback,
                )
                logger.info("Waiting for messages...")
                channel.start_consuming()
            except KeyboardInterrupt:
                logger.info("Worker shutting down")
                keep_running = False
            except Exception as err:
                logger.exception("Worker connection lost, reconnecting in 5s: %s", err)
                self._close()
                time.sleep(retry_delay_s)
        self._close()

    # ---------------------------------------------------

    def _callback(self, ch, method, properties, body):
        """Handle one delivery from the STT queue.

        The body is decoded and the delivery acked *before* STT runs. RabbitMQ
        therefore never redelivers a job once processing has started; failures
        are recorded in raw_calls by ``_process_job`` / ``_handle_stt_failure``.

        Outcomes:

        * A body that is not a UTF-8 JSON object is nacked with
          ``requeue=False``. Such a message can never succeed, and requeueing
          it (the previous behaviour) made the broker redeliver it forever.
        * If the ack itself fails, a best-effort ``requeue=True`` nack is
          attempted and the job is skipped.
        * Exceptions escaping ``_process_job`` are logged and not re-raised,
          so one bad job cannot stop ``start_consuming()``.

        NOTE(review): this runs on the BlockingConnection's own thread, so the
        connection is not serviced while ``_process_job`` blocks. Jobs longer
        than the heartbeat may drop the connection and make ``start()``
        reconnect. Confirm this is acceptable.
        """
        try:
            job = json.loads(body.decode())
            if not isinstance(job, dict):
                raise ValueError(
                    "expected a JSON object, got %s" % type(job).__name__
                )
        except Exception as exc:
            logger.exception("Discarding unparseable message: %s", exc)
            try:
                ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
            except Exception as nack_exc:
                logger.warning("Failed to nack unparseable message: %s", nack_exc)
            return

        bid = job.get("bid")
        call_id = job.get("call_id")
        logger.info("[%s/%s] Received job", bid, call_id)

        # Ack before STT: Sarvam can run up to SARVAM_JOB_TIMEOUT_S (40 min) but
        # RabbitMQ closes the channel if delivery is unacked for ~30 min.
        try:
            ch.basic_ack(delivery_tag=method.delivery_tag)
        except Exception as exc:
            logger.exception("[%s/%s] Failed to ack message: %s", bid, call_id, exc)
            try:
                ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
            except Exception as nack_exc:
                # The channel is most likely gone already; RabbitMQ requeues
                # unacked deliveries when their channel closes.
                logger.warning("[%s/%s] Failed to nack message: %s", bid, call_id, nack_exc)
            return

        try:
            self._process_job(job)
        except Exception as exc:
            logger.exception(
                "[%s/%s] Job failed after ack: %s",
                bid,
                call_id,
                exc,
            )

    # ---------------------------------------------------

    def _raw_table(self, bid: str) -> str:
        """Return the backtick-quoted ``{bid}_raw_calls`` table identifier.

        ``bid`` originates from the queue message and is interpolated into
        SQL, where identifiers cannot be bound as parameters. Any backtick in
        it is therefore doubled (MySQL's escape inside a quoted identifier) so
        it cannot terminate the quote early.
        """
        safe_bid = str(bid).replace("`", "``")
        return f"`{safe_bid}_raw_calls`"

    def _response_table(self, bid: str) -> str:
        """Return the backtick-quoted ``{bid}_sarvamresponse`` table identifier.

        ``bid`` comes from the queue message and is interpolated into SQL, so
        backticks in it are doubled (MySQL identifier-quote escaping) to keep
        it from breaking out of the quoted name.
        """
        safe_bid = str(bid).replace("`", "``")
        return f"`{safe_bid}_sarvamresponse`"

    def _analytics_table(self, bid: str) -> str:
        """Return the backtick-quoted ``{bid}_callanalytics`` table identifier.

        ``bid`` comes from the queue message and is interpolated into SQL, so
        backticks in it are doubled (MySQL identifier-quote escaping) to keep
        it from breaking out of the quoted name.
        """
        safe_bid = str(bid).replace("`", "``")
        return f"`{safe_bid}_callanalytics`"

    def _is_webhook_ingest_enabled(self, bid: str) -> bool:
        """Read ``business_pipeline_config.webhook_ingest_enabled`` for *bid*.

        Any of the following counts as disabled and returns ``False``:

        * no config row;
        * a NULL/0 value;
        * a value ``int()`` cannot parse;
        * any database error (logged as a warning).
        """
        try:
            with get_connection() as conn:
                with conn.cursor() as cur:
                    cur.execute(
                        """
                        SELECT webhook_ingest_enabled
                        FROM business_pipeline_config
                        WHERE bid = %s
                        LIMIT 1
                        """,
                        (bid,),
                    )
                    config_row = cur.fetchone() or {}
                    flag = config_row.get("webhook_ingest_enabled") or 0
                    return bool(int(flag))
        except Exception as err:
            logger.warning("[%s] Could not read webhook_ingest_enabled: %s", bid, err)
            return False

    def _publish_analytics_job(self, bid: str, call_id: str) -> None:
        """Publish ``{"bid": ..., "call_id": ...}`` to the analytics queue, best effort.

        Opens a short-lived connection of its own, declares the durable
        analytics queue and publishes a persistent message. Failures are logged
        as warnings and swallowed, so a broker problem never fails the STT job
        that triggered the publish.

        The connection is closed in ``finally``. Previously an exception from
        ``channel()``, ``queue_declare`` or ``basic_publish`` skipped
        ``conn.close()`` and leaked the socket.
        """
        payload = {"bid": bid, "call_id": call_id}
        queue = settings.rabbitmq_analytics_queue
        conn = None
        try:
            conn = pika.BlockingConnection(
                pika.ConnectionParameters(
                    host=settings.rabbitmq_host,
                    port=settings.rabbitmq_port,
                    credentials=pika.PlainCredentials(
                        settings.rabbitmq_user,
                        settings.rabbitmq_password,
                    ),
                    heartbeat=600,
                )
            )
            channel = conn.channel()
            channel.queue_declare(queue=queue, durable=True)
            channel.basic_publish(
                exchange="",
                routing_key=queue,
                body=json.dumps(payload),
                # delivery_mode=2 marks the message persistent.
                properties=pika.BasicProperties(delivery_mode=2),
            )
            logger.info("[%s/%s] Published analytics job to %s", bid, call_id, queue)
        except Exception as exc:
            logger.warning("[%s/%s] Failed to publish analytics job: %s", bid, call_id, exc)
        finally:
            if conn is not None:
                try:
                    if conn.is_open:
                        conn.close()
                except Exception as close_exc:
                    logger.debug(
                        "[%s/%s] Ignoring error while closing analytics connection: %s",
                        bid,
                        call_id,
                        close_exc,
                    )

    def _should_enqueue_analytics(self, bid: str) -> bool:
        """Decide whether a transcribed call of *bid* should go to analytics.

        A truthy ``settings.stt_enqueue_analytics_always`` short-circuits to
        ``True`` without touching the database. Otherwise the per-business
        ``webhook_ingest_enabled`` flag decides.
        """
        always = settings.stt_enqueue_analytics_always
        return True if always else self._is_webhook_ingest_enabled(bid)

    def _maybe_enqueue_analytics(self, bid: str, call_id: str, raw_status: int) -> None:
        """Publish an analytics job for *call_id* when it is eligible.

        Only raw status 2 qualifies (compared after ``int()``, so ``"2"``
        works too). The ``_should_enqueue_analytics`` check is consulted only
        after that status test passes.
        """
        if int(raw_status) == 2 and self._should_enqueue_analytics(bid):
            self._publish_analytics_job(bid, call_id)

    def _existing_result_status(self, cur, bid: str, call_id: str):
        """Return the raw_calls status a call should carry if already transcribed.

        Looks in ``{bid}_sarvamresponse`` for a non-NULL, non-empty transcript:

        * no such transcript -> ``None``
        * transcript but no ``{bid}_callanalytics`` row -> ``2``
        * transcript and an analytics row -> ``3``

        Runs on the caller's cursor and does not commit.
        """
        lookup = (call_id,)
        transcripts = self._response_table(bid)
        analytics = self._analytics_table(bid)

        cur.execute(
            f"""
            SELECT 1
            FROM {transcripts}
            WHERE callid = %s
              AND transcript IS NOT NULL
              AND transcript != ''
            LIMIT 1
            """,
            lookup,
        )
        if cur.fetchone():
            cur.execute(
                f"SELECT 1 FROM {analytics} WHERE callid = %s LIMIT 1",
                lookup,
            )
            analytics_hit = cur.fetchone()
            return 3 if analytics_hit else 2
        return None

    def _set_raw_status(self, cur, bid: str, call_id: str, status: int):
        """Set ``status`` on the call's ``{bid}_raw_calls`` row.

        For statuses 2 and 3, ``transcription_status`` also becomes
        ``'completed'``; for any other status it is left unchanged. Uses the
        caller's cursor and does not commit.
        """
        # ``status`` is bound twice: once for the column, once for the CASE test.
        bound = (status, status, call_id)
        cur.execute(
            f"""
            UPDATE {self._raw_table(bid)}
            SET status = %s,
                transcription_status = CASE
                    WHEN %s IN (2, 3) THEN 'completed'
                    ELSE transcription_status
                END
            WHERE callid = %s
            """,
            bound,
        )

    def _purge_empty_transcript_rows(self, cur, bid: str, call_id: str) -> int:
        """Delete this call's ``{bid}_sarvamresponse`` rows with a NULL/blank transcript.

        Returns ``cur.rowcount`` as an int, or 0 when the driver reports a
        falsy rowcount. Uses the caller's cursor and does not commit.
        """
        statement = f"""
            DELETE FROM {self._response_table(bid)}
            WHERE callid = %s
              AND (transcript IS NULL OR TRIM(transcript) = '')
            """
        cur.execute(statement, (call_id,))
        deleted = cur.rowcount
        return int(deleted) if deleted else 0

    def _is_duplicate_key_error(self, exc: Exception) -> bool:
        """Return True if *exc* looks like a MySQL duplicate-key error (1062).

        Matches on the numeric code in ``exc.args[0]`` (assumes a
        pymysql/MySQLdb-style exception — confirm with the DB driver in use),
        falling back to the ``"Duplicate entry"`` server message.

        An exception with empty ``args`` now yields a plain answer instead of
        raising ``IndexError``. Previously that IndexError, raised inside the
        caller's ``except`` block, masked the original insert error.
        """
        args = getattr(exc, "args", None) or ()
        code = args[0] if args else None
        return code == 1062 or "Duplicate entry" in str(exc)

    def _is_unknown_call_starttime_error(self, exc: Exception) -> bool:
        """Return True if *exc* is MySQL error 1054 mentioning ``call_starttime``.

        ``_process_job`` uses this to detect response tables that lack the
        ``call_starttime`` column, so it can retry the INSERT without it.

        An exception with empty ``args`` now returns ``False`` instead of
        raising ``IndexError``.
        """
        args = getattr(exc, "args", None) or ()
        code = args[0] if args else None
        return code == 1054 and "call_starttime" in str(exc)

    def _parse_stt_retry_count(self, transcription_status) -> int:
        """Extract N from a ``"stt_retry:N"`` transcription_status value.

        Returns 0 in these cases:

        * the value is None or has any other prefix;
        * the part after ``"stt_retry:"`` is not an integer.
        """
        prefix = "stt_retry:"
        text = str(transcription_status or "")
        if not text.startswith(prefix):
            return 0
        try:
            return int(text[len(prefix):])
        except (TypeError, ValueError):
            return 0

    def _is_recording_not_ready_error(self, error_message: str) -> bool:
        """Return True when the audio download failed with HTTP 404 (case-insensitive).

        The more specific ``"audio download failed: http 404"`` text always
        contains ``"download failed: http 404"``, so one substring test covers
        both phrasings.
        """
        return "download failed: http 404" in error_message.lower()

    def _is_permanent_stt_failure(self, error_message: str) -> bool:
        """Return True when the error text matches a known non-retryable failure.

        A 404 "recording not ready" error is never permanent, even if it also
        contains a marker. Matching is a case-insensitive substring search.
        """
        if self._is_recording_not_ready_error(error_message):
            return False
        lowered = error_message.lower()
        for marker in (
            "audio too short",
            "audio download failed: http 403",
            "audio download failed: http 410",
            "empty transcript",
            "invalid audio",
            "corrupt",
            "file format",
            "unsupported",
        ):
            if marker in lowered:
                return True
        return False

    def _handle_stt_failure(self, cur, bid: str, call_id: str, error_message: str) -> None:
        """Record a failed STT attempt on the call's ``{bid}_raw_calls`` row.

        Outcomes are checked in this order; the first match wins:

        1. A non-empty transcript already exists: repair ``status`` to what
           ``_existing_result_status`` reports (2 or 3).
        2. The recording download returned HTTP 404 (file not ready): set
           ``status=0`` and ``transcription_status='url_wait'`` so the URL is
           re-checked.
        3. The error matches a permanent marker: terminal ``status=-2``,
           ``transcription_status='backlog_cleared'``.
        4. Anything else: bump the ``stt_retry:N`` counter. Once it reaches
           ``max_retries`` the row becomes terminal as in (3); otherwise it is
           reset to ``status=0`` for another attempt.

        Paths 2-4 also set ``transcription_requested = 0``. Nothing is
        committed here; the caller owns the transaction.

        Args:
            cur: Open cursor. Rows are read with ``.get``, so presumably a
                dict-style cursor — confirm.
            bid: Business id used to build the table name.
            call_id: ``callid`` of the raw call row.
            error_message: Text of the exception that ended the attempt.
        """
        raw_table = self._raw_table(bid)
        # 1) A transcript may already be stored (presumably the failure was raised
        #    after the INSERT); only the status needs repairing then.
        existing_status = self._existing_result_status(cur, bid, call_id)
        if existing_status is not None:
            self._set_raw_status(cur, bid, call_id, existing_status)
            logger.warning(
                "[%s/%s] Failure after transcript existed; repaired raw status=%s",
                bid,
                call_id,
                existing_status,
            )
            return

        # 2) 404 on the recording: wait for the URL instead of counting a retry.
        #    Note this overwrites any "stt_retry:N" counter with 'url_wait'.
        if self._is_recording_not_ready_error(error_message):
            cur.execute(
                f"""
                UPDATE {raw_table}
                SET status = 0,
                    transcription_status = 'url_wait',
                    transcription_requested = 0
                WHERE callid = %s
                """,
                (call_id,),
            )
            logger.warning(
                "[%s/%s] Recording file not ready (404); reset to status=0 for URL re-check",
                bid,
                call_id,
            )
            return

        # 3) Errors that retrying cannot fix go straight to the terminal state.
        if self._is_permanent_stt_failure(error_message):
            cur.execute(
                f"""
                UPDATE {raw_table}
                SET status = -2,
                    transcription_status = 'backlog_cleared',
                    transcription_requested = 0
                WHERE callid = %s
                """,
                (call_id,),
            )
            logger.warning(
                "[%s/%s] Permanent STT failure; marked terminal (backlog_cleared): %s",
                bid,
                call_id,
                error_message,
            )
            return

        # 4) Transient failure: the attempt counter is stored in
        #    transcription_status as "stt_retry:N".
        cur.execute(
            f"SELECT transcription_status FROM {raw_table} WHERE callid = %s LIMIT 1",
            (call_id,),
        )
        row = cur.fetchone() or {}
        retry_count = self._parse_stt_retry_count(row.get("transcription_status")) + 1
        # Falls back to 3 when max_retries is missing or falsy (e.g. 0); never below 1.
        max_retries = max(1, int(getattr(self, "max_retries", 3) or 3))

        # retry_count includes this attempt, so starting from no counter at most
        # max_retries attempts are made before the row is marked terminal.
        if retry_count >= max_retries:
            cur.execute(
                f"""
                UPDATE {raw_table}
                SET status = -2,
                    transcription_status = 'backlog_cleared',
                    transcription_requested = 0
                WHERE callid = %s
                """,
                (call_id,),
            )
            logger.warning(
                "[%s/%s] STT failed %s/%s times; marked terminal (backlog_cleared): %s",
                bid,
                call_id,
                retry_count,
                max_retries,
                error_message,
            )
            return

        cur.execute(
            f"""
            UPDATE {raw_table}
            SET status = 0,
                transcription_status = %s,
                transcription_requested = 0
            WHERE callid = %s
            """,
            (f"stt_retry:{retry_count}", call_id),
        )
        logger.warning(
            "[%s/%s] STT failed attempt %s/%s; reset to status=0 for retry: %s",
            bid,
            call_id,
            retry_count,
            max_retries,
            error_message,
        )

    # ---------------------------------------------------

    def _process_job(self, job: Dict[str, Any]):
        """Run one STT job end-to-end and record the outcome in the per-BID tables.

        Steps, in order:

        1. Mark the ``stt_jobs`` row as processing (only when ``job_id`` is set).
        2. Delete empty-transcript rows for the call from ``{bid}_sarvamresponse``.
           If a non-empty transcript already exists, repair
           ``{bid}_raw_calls.status`` (2 = transcript, 3 = transcript plus an
           analytics row), maybe enqueue analytics, and stop.
        3. Ask ``should_skip_stt`` (the min-duration gate). If it says skip,
           mark the raw row ``status=-2`` / ``'skipped_short'`` and stop.
        4. Transcribe. ``SarvamSlotTimeoutError`` resets the raw row to
           ``status=0`` unless a transcript appeared in the meantime.
        5. Insert the transcript, set raw ``status=2`` and commit. Then enqueue
           analytics and mark the ``stt_jobs`` row done.

        Any other exception is logged and reported via
        ``stt_jobs.mark_failed`` (best effort). It is then handed to
        ``_handle_stt_failure`` for retry/terminal bookkeeping.

        Args:
            job: Decoded queue message. Keys read are ``job_id`` (optional),
                ``bid``, ``call_id`` and ``recording_url``.
        """

        job_id = job.get("job_id")
        bid = job.get("bid")
        call_id = job.get("call_id")
        recording_url = job.get("recording_url")

        logger.info(
            "[%s/%s] Starting STT (job_id=%s)",
            bid,
            call_id,
            job_id
        )

        try:
            if job_id:
                stt_jobs.mark_processing(job_id)

            # raw_calls gets no separate "processing" update here: status 1
            # (presumably set when the call was queued) already covers it.
            with get_connection() as conn:
                with conn.cursor() as cur:
                    # Blank rows are not results; presumably removed so the
                    # INSERT below does not collide with them.
                    removed = self._purge_empty_transcript_rows(cur, bid, call_id)
                    if removed:
                        logger.info(
                            "[%s/%s] Removed %s empty sarvamresponse row(s) before STT",
                            bid,
                            call_id,
                            removed,
                        )
                    existing_status = self._existing_result_status(cur, bid, call_id)
                    if existing_status is not None:
                        self._set_raw_status(cur, bid, call_id, existing_status)
                        conn.commit()
                        logger.info(
                            "[%s/%s] Transcript already exists; set raw status=%s",
                            bid,
                            call_id,
                            existing_status,
                        )
                        self._maybe_enqueue_analytics(bid, call_id, existing_status)
                        return
                conn.commit()

            from stt.min_duration_gate import should_skip_stt

            skip_min, min_reason, probed_audio = should_skip_stt(bid, call_id, recording_url)
            if skip_min:
                with get_connection() as conn:
                    with conn.cursor() as cur:
                        raw_table = self._raw_table(bid)
                        cur.execute(
                            f"""
                            UPDATE {raw_table}
                            SET status = -2,
                                transcription_status = 'skipped_short',
                                transcription_requested = 0,
                                duration_seconds = %s
                            WHERE callid = %s
                            """,
                            (
                                max(0, int(round(probed_audio))) if probed_audio is not None else None,
                                call_id,
                            ),
                        )
                    conn.commit()
                logger.info(
                    "[%s/%s] STT skipped before provider call (%s)",
                    bid,
                    call_id,
                    min_reason,
                )
                # NOTE(review): stt_jobs (when job_id is set) is left as marked
                # by mark_processing on this path — confirm that is intended.
                return

            try:
                result = self.stt.transcribe(
                    audio_url=recording_url,
                    callid=call_id,
                )
            except SarvamSlotTimeoutError:
                with get_connection() as conn:
                    with conn.cursor() as cur:
                        existing_status = self._existing_result_status(cur, bid, call_id)
                        if existing_status is not None:
                            self._set_raw_status(cur, bid, call_id, existing_status)
                            conn.commit()
                            logger.info(
                                "[%s/%s] Slot wait timeout but transcript exists; status=%s",
                                bid,
                                call_id,
                                existing_status,
                            )
                            self._maybe_enqueue_analytics(bid, call_id, existing_status)
                            return
                        self._set_raw_status(cur, bid, call_id, 0)
                        conn.commit()
                logger.warning(
                    "[%s/%s] Sarvam slot wait timeout; reset to status=0 for re-queue",
                    bid,
                    call_id,
                )
                return

            speaker_count = len(
                {seg["speaker_id"] for seg in result.speaker_segments}
            )

            # Insert direct to sarvamresponse and update raw_calls status
            resp_table = self._response_table(bid)
            raw_table = self._raw_table(bid)
            # Column values shared by both INSERT variants below; the primary
            # variant additionally binds call_id for its WHERE clause.
            response_values = (
                call_id,
                result.transcript,
                json.dumps(result.speaker_segments),
                speaker_count,
                result.duration,
                result.provider,
                '1',
                f"job_{call_id}_{int(time.time())}",
                json.dumps({"transcript": result.transcript, "synced": True}),
            )

            with get_connection() as conn:
                with conn.cursor() as cur:
                    try:
                        cur.execute(
                            f"""
                            INSERT INTO {resp_table} 
                            (callid, transcript, speaker_segments, num_speakers, duration, 
                             stt_provider, status, request_id, raw_response, created_at, updated_at, call_starttime) 
                            SELECT 
                                %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), NOW(), call_starttime
                            FROM {raw_table} WHERE callid = %s
                            """,
                            response_values + (call_id,),
                        )
                    except Exception as insert_exc:
                        if self._is_duplicate_key_error(insert_exc):
                            existing_status = self._existing_result_status(cur, bid, call_id)
                            if existing_status is not None:
                                self._set_raw_status(cur, bid, call_id, existing_status)
                                # Persist the repair before announcing the call to analytics.
                                conn.commit()
                                logger.info(
                                    "[%s/%s] Duplicate transcript row; set raw status=%s",
                                    bid,
                                    call_id,
                                    existing_status,
                                )
                                self._maybe_enqueue_analytics(bid, call_id, existing_status)
                                return
                            raise
                        # Some BID response tables don't have call_starttime.
                        # Fallback insert keeps transcription flow unblocked.
                        if not self._is_unknown_call_starttime_error(insert_exc):
                            raise
                        cur.execute(
                            f"""
                            INSERT INTO {resp_table} 
                            (callid, transcript, speaker_segments, num_speakers, duration, 
                             stt_provider, status, request_id, raw_response, created_at, updated_at) 
                            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), NOW())
                            """,
                            response_values,
                        )

                    self._set_raw_status(cur, bid, call_id, 2)
                # Commit the transcript and status=2 explicitly, like every
                # other write path here, and before analytics is enqueued.
                conn.commit()

            self._maybe_enqueue_analytics(bid, call_id, 2)

            if job_id:
                stt_jobs.mark_done(
                    job_id,
                    result.transcript,
                    result.speaker_segments,
                    speaker_count,
                    result.duration,
                    result.provider,
                )

            logger.info(
                "[%s/%s] Done — %.1fs, %d speakers, %d segments",
                bid,
                call_id,
                result.duration,
                speaker_count,
                len(result.speaker_segments)
            )

        except Exception as exc:

            error_message = str(exc)
            logger.error(
                "[%s/%s] Failed: %s",
                bid,
                call_id,
                error_message
            )
            if job_id:
                # Best effort: a failure here must not prevent the raw_calls
                # bookkeeping below, or the call would keep its current status.
                try:
                    stt_jobs.mark_failed(job_id, error_message)
                except Exception as mark_exc:
                    logger.warning(
                        "[%s/%s] Could not mark stt job %s failed: %s",
                        bid,
                        call_id,
                        job_id,
                        mark_exc,
                    )

            with get_connection() as conn:
                with conn.cursor() as cur:
                    self._handle_stt_failure(cur, bid, call_id, error_message)
                conn.commit()
