#!/usr/bin/env python3
"""
Reset a poisoned stt_jobs queue and re-publish one job per call that still needs Sarvam STT.

Use when the queue has thousands of duplicate messages and workers cannot catch up.

  python3 repair_stt_backlog.py --dry-run
  python3 repair_stt_backlog.py
  python3 repair_stt_backlog.py --bid 6004
"""
from __future__ import annotations

import argparse
import json
import logging
import os
import sys

import pika
import pymysql
from dotenv import load_dotenv
from pymysql.cursors import DictCursor

# Absolute directory of this script. It is prepended to sys.path so the project
# modules imported inside functions below (`db_handler`, `config`) resolve no
# matter which directory the script is launched from.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, BASE_DIR)
# Load the sibling `.env` before anything reads the environment (main() reads
# RABBITMQ_QUEUE / RABBITMQ_HOST; presumably `Config` does too — confirm).
load_dotenv(os.path.join(BASE_DIR, ".env"))

# Script-wide logging: INFO and above, timestamped.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger("repair_stt_backlog")


def _has_transcript_sql(resp_table: str) -> str:
    return (
        f"EXISTS (SELECT 1 FROM `{resp_table}` s "
        f"WHERE s.callid = r.callid "
        f"AND s.transcript IS NOT NULL AND TRIM(s.transcript) != '')"
    )


def get_bids(only_bid: str | None) -> list[str]:
    if only_bid:
        return [only_bid.strip()]
    from db_handler import DatabaseHandler
    from config import Config

    class ConfigWrapper:
        def __init__(self, cfg):
            self._cfg = cfg

        def get(self, key, default=None):
            return getattr(self._cfg, key, default)

    handler = DatabaseHandler(ConfigWrapper(Config()))
    handler.ensure_business_pipeline_config_table()
    return handler.get_enabled_pipeline_bids()


def get_duration_filter(bid: str) -> tuple[int, datetime | None]:
    """Return ``(min_call_duration_s, effective_at)`` from *bid*'s pipeline config.

    ``(0, None)`` means no minimum-duration filter: the setting is missing, zero or
    negative. Otherwise ``effective_at`` is ``min_call_duration_effective_at`` as a
    naive datetime, or ``None`` when it is absent or cannot be parsed. A datetime
    value is used as-is; anything else is parsed with ``fromisoformat`` (a trailing
    ``Z`` is accepted).

    Any timezone offset is dropped without converting the wall-clock time.
    NOTE(review): this assumes the offset matches the DB's call timestamps — confirm.
    """
    from datetime import datetime

    from db_handler import DatabaseHandler
    from config import Config

    class ConfigWrapper:
        def __init__(self, cfg):
            self._cfg = cfg

        def get(self, key, default=None):
            return getattr(self._cfg, key, default)

    handler = DatabaseHandler(ConfigWrapper(Config()))
    handler.ensure_business_pipeline_config_table()
    cfg = handler.get_pipeline_config(bid) or {}

    raw = cfg.get("min_call_duration_s")
    min_s = max(0, int(raw)) if raw is not None else 0
    if min_s <= 0:
        return 0, None

    effective_at = cfg.get("min_call_duration_effective_at")
    if not effective_at:
        return min_s, None
    if not isinstance(effective_at, datetime):
        try:
            effective_at = datetime.fromisoformat(str(effective_at).replace("Z", "+00:00"))
        except (TypeError, ValueError):
            # Unparseable timestamp: keep the minimum, but apply it to all calls.
            return min_s, None
    return min_s, effective_at.replace(tzinfo=None)


def collect_calls(
    cur,
    bid: str,
    limit: int,
    min_duration_s: int = 0,
    effective_at=None,
) -> list[dict]:
    """Return up to *limit* calls of *bid* that have a recording but no transcript yet.

    Rows come from ``{bid}_raw_calls`` with status 0, 1 or -2 and a non-blank
    ``fileurl``, and have no non-blank transcript in ``{bid}_sarvamresponse``. They are
    ordered oldest ``call_starttime`` first. Each row is a dict with ``callid`` and
    ``fileurl`` (the cursor is assumed to be dict-returning).

    Args:
        cur: DB cursor used for the single SELECT.
        bid: Business id; selects the per-BID tables.
        limit: Maximum number of rows returned.
        min_duration_s: When > 0, skip calls shorter than this many seconds. The
            duration is ``duration_seconds``, or the start→end difference when that
            is 0/NULL. Calls with no usable duration are skipped.
        effective_at: Optional datetime. When given together with
            ``min_duration_s``, only calls whose ``call_starttime`` is at or after it
            are subject to the minimum; older calls are exempt. Calls with a NULL
            start time still get the duration check. (Assumed to mirror the
            pipeline's ``min_call_duration_effective_at`` semantics — confirm.)
    """
    raw = f"{bid}_raw_calls"
    resp = f"{bid}_sarvamresponse"
    has_tx = _has_transcript_sql(resp)
    duration_clause = ""
    query_params: list = []
    if min_duration_s > 0:
        # Prefer the stored duration; fall back to start/end timestamps when it is 0/NULL.
        duration_expr = """
            COALESCE(
              NULLIF(r.duration_seconds, 0),
              CASE
                WHEN r.call_starttime IS NOT NULL AND r.call_endtime IS NOT NULL
                THEN TIMESTAMPDIFF(SECOND, r.call_starttime, r.call_endtime)
                ELSE NULL
              END
            )"""
        if effective_at is not None:
            # Previously effective_at was accepted but ignored; calls that started
            # before the rule took effect are now exempt from it.
            duration_clause = f"AND (r.call_starttime < %s OR {duration_expr} >= %s)"
            query_params.extend([effective_at, min_duration_s])
        else:
            duration_clause = f"AND {duration_expr} >= %s"
            query_params.append(min_duration_s)
    cur.execute(
        f"""
        SELECT r.callid, r.fileurl
        FROM `{raw}` r
        WHERE r.fileurl IS NOT NULL AND TRIM(r.fileurl) != ''
          AND r.status IN (0, 1, -2)
          AND NOT ({has_tx})
          {duration_clause}
        ORDER BY r.call_starttime ASC
        LIMIT %s
        """,
        tuple(query_params + [limit]),
    )
    return list(cur.fetchall() or [])


def reset_queued_without_transcript(cur, bid: str, dry_run: bool) -> int:
    """Move ``{bid}_raw_calls`` rows in status 1/-2 that still lack a transcript back to 0.

    Only rows with a non-blank ``fileurl`` whose ``callid`` has no non-blank
    transcript in ``{bid}_sarvamresponse`` are affected.

    Args:
        cur: DB cursor (the dry-run COUNT assumes dict rows).
        bid: Business id selecting the per-BID tables.
        dry_run: When true, nothing is modified; the matching rows are only counted.

    Returns:
        The number of rows updated (``cur.rowcount``), or on a dry run the number
        of rows that would be updated.
    """
    raw = f"{bid}_raw_calls"
    has_tx = _has_transcript_sql(f"{bid}_sarvamresponse")
    # One predicate shared by the UPDATE and the dry-run COUNT so they cannot drift apart.
    where = (
        "r.status IN (1, -2) "
        "AND r.fileurl IS NOT NULL AND TRIM(r.fileurl) != '' "
        f"AND NOT ({has_tx})"
    )
    if dry_run:
        cur.execute(f"SELECT COUNT(*) AS n FROM `{raw}` r WHERE {where}")
        row = cur.fetchone() or {}
        return int(row.get("n") or 0)
    cur.execute(f"UPDATE `{raw}` r SET status = 0 WHERE {where}")
    return int(cur.rowcount or 0)


def purge_queue(channel, queue: str, dry_run: bool) -> int:
    """Empty *queue* and return how many messages it held.

    On a dry run the queue is left untouched: a passive declare reports its current
    message count instead.
    """
    frame = (
        channel.queue_declare(queue=queue, durable=True, passive=True)
        if dry_run
        else channel.queue_purge(queue=queue)
    )
    return int(frame.method.message_count)


def publish_jobs(channel, queue: str, jobs: list[dict], dry_run: bool) -> int:
    """Publish each job as a JSON message to *queue* through the default exchange.

    Messages are sent with ``delivery_mode=2`` (persistent). On a dry run *channel*
    is never touched.

    Returns:
        The number of jobs published, or that would have been published on a dry run.
    """
    count = 0
    for payload in jobs:
        if not dry_run:
            channel.basic_publish(
                exchange="",
                routing_key=queue,
                body=json.dumps(payload),
                properties=pika.BasicProperties(delivery_mode=2),
            )
        count += 1
    return count


def mark_queued(cur, bid: str, callids: list[str], dry_run: bool) -> int:
    """Set ``status = 1`` for the given *callids* in ``{bid}_raw_calls``.

    When *callids* is empty or this is a dry run, no SQL is executed and
    ``len(callids)`` is returned. Otherwise the cursor's ``rowcount`` is returned.
    """
    if callids and not dry_run:
        marks = ", ".join("%s" for _ in callids)
        sql = f"UPDATE `{bid}_raw_calls` SET status = 1 WHERE callid IN ({marks})"
        cur.execute(sql, tuple(callids))
        return int(cur.rowcount or 0)
    return len(callids)


def _collect_jobs(conn, bids: list[str], limit: int, dry_run: bool) -> list[dict]:
    """Reset, collect and mark-queued the calls of each BID; return one job dict per call.

    All BIDs share one transaction on *conn*. It is committed at the end, except on
    a dry run, where no raw-call rows are modified. It is rolled back on any error.
    """
    jobs: list[dict] = []
    try:
        with conn.cursor() as cur:
            for bid in bids:
                reset_n = reset_queued_without_transcript(cur, bid, dry_run)
                logger.info(
                    "BID %s: %s %s row(s) from status 1/-2 → 0",
                    bid,
                    "Would reset" if dry_run else "Reset",
                    reset_n,
                )
                # Apply the BID's configured min_call_duration_s so calls below it
                # are not re-queued (previously this filter was never applied here).
                min_s, effective_at = get_duration_filter(bid)
                if min_s > 0:
                    logger.info(
                        "BID %s: skipping calls shorter than %ss (effective from %s)",
                        bid,
                        min_s,
                        effective_at or "always",
                    )
                calls = collect_calls(cur, bid, limit, min_s, effective_at)
                callids = [str(call["callid"]) for call in calls]
                jobs.extend(
                    {"bid": bid, "call_id": callid, "recording_url": call["fileurl"]}
                    for callid, call in zip(callids, calls)
                )
                if not dry_run and callids:
                    mark_queued(cur, bid, callids, dry_run=False)
                logger.info("BID %s: %s call(s) to publish", bid, len(calls))
            if not dry_run:
                conn.commit()
    except Exception:
        conn.rollback()
        raise
    return jobs


def main() -> None:
    """Purge the STT queue and re-queue calls that still need transcription.

    Steps:
      1. Resolve the BIDs: ``--bid``, or all pipeline-enabled BIDs.
      2. Purge the RabbitMQ queue. On a dry run, only report its message count.
      3. For each BID, reset status 1/-2 rows without a transcript back to 0. Then
         collect up to ``--limit`` calls, honouring the BID's minimum call duration,
         and mark them status 1.
      4. Commit, then publish one persistent job per collected call.

    NOTE: the purge empties the whole queue. With ``--bid``, queued calls of other
    BIDs lose their messages and are not re-queued by this run; a warning is logged.
    """
    parser = argparse.ArgumentParser(description="Purge stt_jobs backlog and re-queue from DB")
    parser.add_argument("--bid", help="Only repair this BID (default: all pipeline_enabled)")
    parser.add_argument("--limit", type=int, default=5000, help="Max calls to re-queue per BID")
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    from config import Config

    cfg = Config()
    queue = os.getenv("RABBITMQ_QUEUE", "stt_jobs")
    host = os.getenv("RABBITMQ_HOST", "localhost")

    # Resolve BIDs before opening any connection so the early exit leaks nothing.
    bids = get_bids(args.bid)
    if not bids:
        logger.error("No BIDs to repair")
        sys.exit(1)
    if args.bid:
        logger.warning(
            "Purging %s drops queued jobs of every BID, but only BID(s) %s will be "
            "re-queued; run without --bid to repair all pipeline BIDs",
            queue,
            ", ".join(bids),
        )

    conn = pymysql.connect(
        host=cfg.DB_HOST,
        port=cfg.DB_PORT,
        user=cfg.DB_USER,
        password=cfg.DB_PASSWORD,
        database=cfg.DB_NAME,
        cursorclass=DictCursor,
        autocommit=False,
    )
    rmq = None
    try:
        rmq = pika.BlockingConnection(pika.ConnectionParameters(host=host))
        channel = rmq.channel()
        channel.queue_declare(queue=queue, durable=True)

        purged = purge_queue(channel, queue, args.dry_run)
        logger.info(
            "%s %s message(s) from %s",
            "Would purge" if args.dry_run else "Purged",
            purged,
            queue,
        )

        total_jobs = _collect_jobs(conn, bids, args.limit, args.dry_run)
        published = publish_jobs(channel, queue, total_jobs, args.dry_run)
    finally:
        conn.close()
        # Close RabbitMQ on every path; skip if the broker already dropped it,
        # since closing a closed connection would mask the original error.
        if rmq is not None and rmq.is_open:
            rmq.close()

    logger.info(
        "%s %s job(s) on %s for BID(s): %s",
        "Would publish" if args.dry_run else "Published",
        published,
        queue,
        ", ".join(bids),
    )
    if not args.dry_run:
        logger.info("Restart STT workers: sudo systemctl restart mcube-stt-worker")


if __name__ == "__main__":
    main()
