#!/usr/bin/env python3
"""Diagnose transcription pipeline state for a BID (read-only)."""
import argparse
import os
import sys

import pymysql
import pika
from dotenv import load_dotenv
from pymysql.cursors import DictCursor

# Pull variables from a local .env file into os.environ before main() runs its
# os.getenv() lookups. override=False (the library default) means variables
# already set in the real environment take precedence over the .env file.
load_dotenv(override=False)


# Human-readable meaning of each raw_calls.status value, as reported by this tool.
_STATUS_LABELS = {
    -2: "STT failed",
    -1: "invalid URL",
    0: "ingested, not queued",
    1: "queued (waiting for STT worker)",
    2: "transcribed",
    3: "analytics done",
}


def _status_label(status):
    """Return the label for a raw_calls ``status`` value.

    NULL or non-numeric values map to "unknown" instead of raising.
    """
    try:
        return _STATUS_LABELS.get(int(status), "unknown")
    except (TypeError, ValueError):
        return "unknown"


def _table_exists(cur, name):
    """Return True if a table named exactly *name* exists in the current database.

    Uses an exact-match lookup instead of ``SHOW TABLES LIKE``. In a LIKE
    pattern ``_`` is a single-character wildcard, and every table name checked
    by this script contains underscores.
    """
    cur.execute(
        "SELECT 1 FROM information_schema.tables "
        "WHERE table_schema = DATABASE() AND table_name = %s LIMIT 1",
        (name,),
    )
    return cur.fetchone() is not None


def _count(cur, sql, params=None):
    """Run a query returning one ``COUNT(*) AS n`` row; return n as int (0 if absent)."""
    cur.execute(sql, params)
    row = cur.fetchone() or {}
    return int(row.get("n") or 0)


def _report_database(cur, raw, resp):
    """Print per-status raw_calls counts and sarvamresponse transcript coverage.

    Args:
        cur: DictCursor on the pipeline database.
        raw: Name of the BID's raw_calls table (caller has checked it exists).
        resp: Name of the BID's sarvamresponse table (may be missing).

    Returns:
        Number of raw_calls rows with status = 1.
    """
    print("raw_calls by status:")
    cur.execute(f"SELECT status, COUNT(*) AS n FROM `{raw}` GROUP BY status ORDER BY status")
    for row in cur.fetchall() or ():
        status = row["status"]
        shown = "NULL" if status is None else status
        print(f"  status {shown!s:>2} ({_status_label(status)}): {row['n']}")

    if not _table_exists(cur, resp):
        # Without the response table the transcript and "phantom" checks below
        # cannot run; querying it would raise a missing-table error.
        print(f"\nWARNING: Table {resp} does not exist.")
    else:
        total = _count(cur, f"SELECT COUNT(*) AS n FROM `{resp}`")
        with_tx = _count(
            cur,
            f"SELECT COUNT(*) AS n FROM `{resp}` "
            "WHERE transcript IS NOT NULL AND TRIM(transcript) != ''",
        )
        print(f"\nsarvamresponse: {with_tx}/{total} rows with non-empty transcript")

        # Calls still marked queued (status=1) although a transcript is stored.
        phantom = _count(
            cur,
            f"""
            SELECT COUNT(*) AS n FROM `{raw}` r
            WHERE r.status = 1
              AND EXISTS (
                SELECT 1 FROM `{resp}` s
                WHERE s.callid = r.callid
                  AND s.transcript IS NOT NULL
                  AND TRIM(s.transcript) != ''
              )
            """,
        )
        if phantom:
            print(f"\nWARNING: {phantom} call(s) status=1 but transcript already exists (run orchestrator repair).")

    return _count(cur, f"SELECT COUNT(*) AS n FROM `{raw}` WHERE status = 1")


def _report_rabbitmq(bid, queued):
    """Print STT queue depth and consumer count, plus a likely cause for stuck calls.

    This is best-effort. Any failure talking to RabbitMQ is printed and then
    ignored, so the database part of the report still stands.

    Args:
        bid: Business ID, used only in the printed message.
        queued: Number of raw_calls rows with status = 1 for this BID.
    """
    queue = os.getenv("RABBITMQ_QUEUE", "stt_jobs")
    host = os.getenv("RABBITMQ_HOST", "localhost")
    rmq = None
    try:
        rmq = pika.BlockingConnection(pika.ConnectionParameters(host=host))
        # passive=True only inspects an existing queue and never creates one,
        # so this script stays read-only.
        q = rmq.channel().queue_declare(queue=queue, durable=True, passive=True)
        depth = int(q.method.message_count)
        consumers = int(q.method.consumer_count)
    except Exception as e:
        print(f"\nCould not inspect RabbitMQ: {e}")
        return
    finally:
        # Close on every path, including a failed passive declare.
        if rmq is not None and rmq.is_open:
            try:
                rmq.close()
            except Exception:
                pass  # cleanup only; the result (or error) is already captured

    print(f"\nRabbitMQ {queue} @ {host}: messages={depth} consumers={consumers}")
    if not queued:
        return
    # Check for a missing worker first. With no consumers the backlog can never
    # clear, so "wait for the backlog" would be the wrong advice.
    if consumers == 0:
        print("\nLIKELY CAUSE: STT worker is not running. Start:")
        print("  cd call-proccessing/stt_pipeline && python run.py --worker")
    elif depth > 100:
        print(
            f"\nLIKELY CAUSE: {queued} call(s) for BID {bid} are queued (status=1) but the shared "
            f"STT queue has {depth} messages. One worker processes ~1 job at a time; "
            "transcription is delayed until the backlog clears."
        )


def main():
    """Print a read-only diagnostic of the transcription pipeline for one BID.

    Reads ``--bid`` from the command line and opens a DB connection using the
    settings from ``config.Config``. It reports on ``{bid}_raw_calls`` and
    ``{bid}_sarvamresponse`` with SELECT queries only. It then inspects the
    RabbitMQ queue named by ``RABBITMQ_QUEUE`` on ``RABBITMQ_HOST``.

    Exits with status 1 when ``{bid}_raw_calls`` does not exist. Exits with
    status 2 (argparse usage error) for an empty BID or one containing a
    backtick.
    """
    parser = argparse.ArgumentParser(description="Diagnose STT/transcription for a BID")
    parser.add_argument("--bid", required=True, help="Business ID, e.g. 8398")
    args = parser.parse_args()
    bid = str(args.bid).strip()
    # The BID is spliced into backtick-quoted table names. A backtick would
    # break the identifier, and an empty BID names no real table.
    if not bid or "`" in bid:
        parser.error("--bid must be non-empty and must not contain '`'")

    # Imported after argument parsing, so --help does not need the project config.
    from config import Config

    cfg = Config()
    raw = f"{bid}_raw_calls"
    resp = f"{bid}_sarvamresponse"
    conn = pymysql.connect(
        host=cfg.DB_HOST,
        port=cfg.DB_PORT,
        user=cfg.DB_USER,
        password=cfg.DB_PASSWORD,
        database=cfg.DB_NAME,
        cursorclass=DictCursor,
    )
    try:
        cur = conn.cursor()
        print(f"\n=== Transcription diagnostic: BID {bid} ===\n")
        if not _table_exists(cur, raw):
            print(f"ERROR: Table {raw} does not exist.")
            sys.exit(1)  # the finally clause still closes the connection
        queued = _report_database(cur, raw, resp)
    finally:
        conn.close()

    _report_rabbitmq(bid, queued)
    print()


if __name__ == "__main__":
    # Run the diagnostic when executed as a script. main() returns None, and
    # SystemExit(None) exits with status 0.
    raise SystemExit(main())
