#!/usr/bin/env python3
"""
Re-transcribe calls whose speaker_segments look like chunked-sync fake diarization.

Usage:
  python3 repair_bad_diarization.py --bid 6004 [--limit 20] [--dry-run]
"""
from __future__ import annotations

import argparse
import json
import os
import sys
import time

import pymysql
from dotenv import load_dotenv

# Directory containing this script; its own .env is loaded first.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Sibling STT pipeline checkout. NOTE(review): "call-proccessing" is presumably the
# real on-disk directory name (spelling included) — confirm before "fixing" it.
STT_DIR = os.path.join(os.path.dirname(BASE_DIR), "call-proccessing", "stt_pipeline")
# Put the pipeline first on sys.path so its `stt` package is importable; this must
# happen before the `from stt...` imports below.
sys.path.insert(0, STT_DIR)

# Load this script's .env without touching variables that are already set, then
# load the pipeline's .env with override=True so its values win over both the
# first file and the existing process environment.
load_dotenv(os.path.join(BASE_DIR, ".env"))
load_dotenv(os.path.join(STT_DIR, ".env"), override=True)

from stt.sarvam import validate_diarization, DiarizationQualityError, _looks_like_chunk_aligned_blocks  # noqa: E402
from stt.factory import get_stt_provider  # noqa: E402


def db_conn():
    """Open a MySQL connection configured from environment variables.

    Reads ``DB_HOST`` (falling back to ``"127.0.0.1"``), ``DB_USER``,
    ``DB_PASSWORD`` and ``DB_NAME``. Cursors created from the returned
    connection yield rows as dicts (pymysql ``DictCursor``).
    """
    settings = {
        "host": os.getenv("DB_HOST", "127.0.0.1"),
        "user": os.getenv("DB_USER"),
        "password": os.getenv("DB_PASSWORD"),
        "database": os.getenv("DB_NAME"),
    }
    return pymysql.connect(cursorclass=pymysql.cursors.DictCursor, **settings)


def parse_segments(raw) -> list:
    """Decode a stored ``speaker_segments`` value into a list.

    Accepts an already-decoded list, or JSON text given as ``str``,
    ``bytes`` or ``bytearray`` (a binary/BLOB column may be returned by the
    driver as bytes rather than str — decode those too instead of dropping
    them).

    Args:
        raw: The raw column value.

    Returns:
        The decoded list. Returns ``[]`` when *raw* is empty/None, is not
        valid JSON (or not valid text), or does not decode to a JSON array.
    """
    if not raw:
        return []
    if isinstance(raw, (str, bytes, bytearray)):
        try:
            # json.loads detects the encoding of bytes input itself; invalid
            # byte sequences surface as UnicodeDecodeError.
            raw = json.loads(raw)
        except (json.JSONDecodeError, UnicodeDecodeError):
            return []
    return raw if isinstance(raw, list) else []


def is_bad_diarization(segments: list, duration: float) -> bool:
    """Return True when stored speaker segments fail the diarization check.

    * A non-positive *duration* is never judged bad.
    * With no segments at all, the call is bad only if it lasts more than
      10 seconds.
    * Otherwise the segments are wrapped in an ``STTResult`` (placeholder
      transcript, provider ``"sarvam"``) and handed to
      ``validate_diarization``. A ``DiarizationQualityError`` means bad;
      returning normally means good.
    """
    if duration <= 0:
        verdict = False
    elif not segments:
        verdict = duration > 10
    else:
        from stt.base import STTResult

        try:
            probe = STTResult(
                transcript="x",
                speaker_segments=segments,
                duration=duration,
                provider="sarvam",
            )
            validate_diarization(probe, duration, "check")
        except DiarizationQualityError:
            verdict = True
        else:
            verdict = False
    return verdict


def _to_seconds(value) -> float:
    """Convert a duration/timestamp value to float; missing or malformed values become 0.0."""
    try:
        return float(value or 0)
    except (TypeError, ValueError):
        return 0.0


def _last_segment_end(segments: list) -> float:
    """Return the ``end`` (or ``end_time``) of the final segment, or 0.0 if unavailable."""
    last = segments[-1]
    if not isinstance(last, dict):
        return 0.0
    return _to_seconds(last.get("end") or last.get("end_time"))


def collect_bad_calls(cur, bid: str, limit: int) -> list:
    """Return up to *limit* answered calls whose stored diarization looks fake.

    Examines the ``limit * 10`` most recent (by ``r.call_starttime``)
    answered calls that have a non-empty ``fileurl`` and a row in
    ``{bid}_sarvamresponse``. A call is kept when ``is_bad_diarization`` or
    ``_looks_like_chunk_aligned_blocks`` flags its segments.

    Args:
        cur: Open cursor returning rows as dicts (rows are read with
            ``row.get``).
        bid: Prefix of the per-bid tables. It is interpolated into the SQL
            as a backtick-quoted identifier, so it must come from a trusted
            source.
        limit: Maximum number of calls to return.

    Returns:
        Row dicts with ``callid``, ``speaker_segments``, ``duration``,
        ``fileurl`` and ``status``. Empty when the response table does not
        exist or *limit* is not positive.
    """
    if limit <= 0:
        # LIMIT with a negative value is a SQL syntax error; zero finds nothing anyway.
        return []
    table = f"{bid}_sarvamresponse"
    cur.execute("SHOW TABLES LIKE %s", (table,))
    # "_" is a LIKE wildcard, so require an exact name match among the hits.
    hits = cur.fetchall() or ()
    if not any(table in (hit.values() if isinstance(hit, dict) else hit) for hit in hits):
        return []
    cur.execute(
        f"""
        SELECT s.callid, s.speaker_segments, s.duration, r.fileurl, r.status
        FROM `{table}` s
        INNER JOIN `{bid}_raw_calls` r ON r.callid = s.callid
        WHERE r.call_status = 'ANSWER'
          AND r.fileurl IS NOT NULL AND r.fileurl != ''
        ORDER BY r.call_starttime DESC
        LIMIT %s
        """,
        (limit * 10,),
    )
    bad = []
    for row in cur.fetchall() or []:
        segs = parse_segments(row.get("speaker_segments"))
        dur = _to_seconds(row.get("duration"))
        if dur <= 0 and segs:
            # No usable stored duration: fall back to the last segment's end time.
            dur = _last_segment_end(segs)
        if is_bad_diarization(segs, dur) or _looks_like_chunk_aligned_blocks(segs, dur):
            bad.append(row)
            if len(bad) >= limit:
                break
    return bad


def save_result(cur, bid: str, callid: str, result) -> None:
    """Overwrite the stored STT output for *callid* with a fresh result.

    Updates the call's row in ``{bid}_sarvamresponse``:
    transcript, segments (JSON), speaker count, duration, provider,
    a synthetic ``request_id``, and a stub ``raw_response`` marked with
    ``diarization_repair``. It then sets ``status=2`` on the matching
    ``{bid}_raw_calls`` row. It does not commit; the caller owns the
    transaction.

    Args:
        cur: Open DB cursor.
        bid: Prefix of the per-bid tables. It is interpolated as a quoted
            identifier, so it must be trusted.
        callid: Call to update.
        result: STT result object. ``transcript``, ``speaker_segments``
            (presumably a list of dicts), ``duration`` and ``provider`` are read.
    """
    segments = result.speaker_segments or []
    # Segments with no speaker label must not count as an extra speaker.
    speaker_count = len({s.get("speaker_id") for s in segments} - {None})
    cur.execute(
        f"""
        UPDATE `{bid}_sarvamresponse`
        SET transcript=%s, speaker_segments=%s, num_speakers=%s, duration=%s,
            stt_provider=%s, request_id=%s, raw_response=%s, updated_at=NOW()
        WHERE callid=%s
        """,
        (
            result.transcript,
            json.dumps(segments),
            speaker_count,
            result.duration,
            result.provider,
            f"diarize_repair_{callid}_{int(time.time())}",
            json.dumps({"transcript": result.transcript, "diarization_repair": True}),
            callid,
        ),
    )
    # NOTE(review): status=2 presumably means "transcribed/processed" — confirm against the pipeline.
    cur.execute(f"UPDATE `{bid}_raw_calls` SET status=2 WHERE callid=%s", (callid,))


def main():
    """Command-line entry point: list, and optionally repair, badly diarized calls.

    Prints the calls selected by ``collect_bad_calls`` for ``--bid``. Unless
    ``--dry-run`` is given or nothing was found, each recording is
    re-transcribed with the provider from ``get_stt_provider()`` and stored
    via ``save_result``. A commit follows each call, so one failure does not
    undo earlier repairs.
    """
    parser = argparse.ArgumentParser(
        description="Re-transcribe calls whose speaker_segments look like chunked-sync fake diarization.",
    )
    parser.add_argument("--bid", required=True, help="ID used as the per-bid table-name prefix (e.g. 6004)")
    parser.add_argument("--limit", type=int, default=20, help="maximum number of calls to repair (default: 20)")
    parser.add_argument("--dry-run", action="store_true", help="only list the calls that would be repaired")
    parser.add_argument(
        "--sleep",
        type=float,
        default=2.0,
        help="seconds to pause between STT requests (default: 2)",
    )
    args = parser.parse_args()
    # A negative LIMIT is a SQL syntax error; reject it with a usage message instead.
    if args.limit < 0:
        parser.error("--limit must not be negative")
    if args.sleep < 0:
        parser.error("--sleep must not be negative")

    conn = db_conn()
    try:
        with conn.cursor() as cur:
            bad = collect_bad_calls(cur, args.bid, args.limit)
        print(f"BID {args.bid}: found {len(bad)} call(s) with bad diarization")
        for row in bad:
            print(f"  - {row['callid']} (status={row.get('status')})")

        if args.dry_run or not bad:
            return

        stt = get_stt_provider()
        fixed = 0
        for position, row in enumerate(bad):
            if position and args.sleep:
                # Throttle between STT requests; there is no need to wait after the last one.
                time.sleep(args.sleep)
            callid = row["callid"]
            url = row["fileurl"]
            print(f"Re-transcribing {callid}…")
            try:
                # NOTE(review): the new result is saved without re-running the
                # diarization check here — confirm stt.transcribe validates it.
                result = stt.transcribe(url, callid)
                # Transcription can take a while; reconnect if the server dropped the idle connection.
                conn.ping(reconnect=True)
                with conn.cursor() as cur:
                    save_result(cur, args.bid, callid, result)
                conn.commit()
                fixed += 1
                print(f"  OK: {len(result.speaker_segments)} segments")
            except Exception as exc:  # per-call boundary: report and continue with the next call
                print(f"  FAIL: {exc}")
                try:
                    conn.rollback()
                except pymysql.MySQLError as rollback_exc:
                    # A dead connection must not abort the remaining repairs.
                    print(f"  rollback failed: {rollback_exc}")
        print(f"Fixed {fixed}/{len(bad)} calls")
    finally:
        conn.close()


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
