#!/usr/bin/env python3
"""Run a short direct STT pass for one BID, bypassing RabbitMQ queue order."""
from __future__ import annotations

import argparse
import logging
import os
import signal
import sys

import pymysql
from dotenv import load_dotenv


# Directory that contains this script.
BACKEND_DIR = os.path.dirname(os.path.abspath(__file__))
# Parent of BACKEND_DIR.
REPO_DIR = os.path.dirname(BACKEND_DIR)
# Location of the STT pipeline sources under REPO_DIR. The directory name is
# spelled exactly as written here; do not "correct" it without renaming the
# directory on disk as well.
STT_DIR = os.path.join(REPO_DIR, "call-proccessing", "stt_pipeline")
# Named logger for this script.
logger = logging.getLogger("dedicated_stt_pass")


def _load_env() -> None:
    """Load both ``.env`` files and put the STT pipeline on ``sys.path``.

    The backend ``.env`` is read first, then the STT pipeline ``.env`` with
    ``override=False`` so it does not replace variables that are already set.
    Finally ``STT_DIR`` is inserted at the front of ``sys.path``.
    """
    backend_env_file = os.path.join(BACKEND_DIR, ".env")
    stt_env_file = os.path.join(STT_DIR, ".env")

    load_dotenv(backend_env_file)
    load_dotenv(stt_env_file, override=False)

    sys.path.insert(0, STT_DIR)


def _db_connection():
    """Open a new PyMySQL connection from the project's ``settings`` object.

    The connection uses ``DictCursor`` (rows come back as dicts) and has
    autocommit disabled.
    """
    # Imported lazily; presumably ``config`` is resolved through the STT_DIR
    # entry that _load_env() adds to sys.path -- confirm before hoisting.
    from config.settings import settings as cfg

    connect_kwargs = {
        "host": cfg.db_host,
        "port": cfg.db_port,
        "user": cfg.db_user,
        "password": cfg.db_password,
        "database": cfg.db_name,
        "cursorclass": pymysql.cursors.DictCursor,
        "autocommit": False,
    }
    return pymysql.connect(**connect_kwargs)


def _fetch_calls(bid: str, limit: int):
    raw_table = f"`{bid}_raw_calls`"
    with _db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                f"""
                SELECT callid, fileurl
                FROM {raw_table}
                WHERE status = 1
                  AND fileurl IS NOT NULL
                  AND fileurl != ''
                ORDER BY id DESC
                LIMIT %s
                """,
                (limit,),
            )
            return cur.fetchall() or []


def main() -> int:
    """Fetch calls for one BID and push each through the worker's job handler.

    Command-line options:
        --bid           BID whose raw-calls table is read (required).
        --limit         Maximum number of calls to fetch (default 7).
        --call-timeout  Per-call alarm in seconds (default 600, minimum 1).

    Returns:
        0 after every fetched call has been handed to the worker. An exception
        raised while processing a call, including the ``TimeoutError`` raised
        by the alarm handler, is not caught here and ends the run.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--bid", required=True)
    cli.add_argument("--limit", type=int, default=7)
    cli.add_argument("--call-timeout", type=int, default=600)
    opts = cli.parse_args()

    # Environment and sys.path must be prepared before the imports below.
    _load_env()
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        handlers=[stdout_handler],
    )

    from stt.factory import get_stt_provider
    from workers.rabbitmq_transcription_worker import RabbitMQTranscriptionWorker

    # Allocate the worker without running __init__; only the two attributes
    # assigned here are set on it.
    worker = RabbitMQTranscriptionWorker.__new__(RabbitMQTranscriptionWorker)
    worker.stt = get_stt_provider()
    worker.max_retries = 0

    bid = str(opts.bid)
    pending = _fetch_calls(bid, int(opts.limit))
    print(f"Dedicated STT pass for BID {bid}: {len(pending)} call(s)")

    def _raise_on_alarm(signum, frame):
        raise TimeoutError(f"Call exceeded {opts.call_timeout}s")

    # The alarm length does not change between calls, so compute it once.
    alarm_seconds = max(1, int(opts.call_timeout))

    signal.signal(signal.SIGALRM, _raise_on_alarm)
    for record in pending:
        call_id = str(record["callid"])
        print(f"Processing {bid}/{call_id}", flush=True)
        signal.alarm(alarm_seconds)
        try:
            job = {
                "bid": bid,
                "call_id": call_id,
                "recording_url": record["fileurl"],
            }
            worker._process_job(job)
        finally:
            # Always cancel the pending alarm, whether the job succeeded or not.
            signal.alarm(0)

    print("Dedicated STT pass complete")
    return 0


# Script entry point: exit with the status code returned by main().
if __name__ == "__main__":
    sys.exit(main())
