#!/usr/bin/env python3
"""Re-transcribe calls for a BID/date with Sarvam and refresh analytics."""

import argparse
import json
import os
from datetime import datetime

import pymysql
from dotenv import load_dotenv
from pymysql.cursors import DictCursor

from analyze_calls_with_parameters import CallAnalyzer
from config import Config
from stt.sarvam import SarvamSTT


def db_connection(config):
    """Open a PyMySQL connection using the ``DB_*`` settings on ``config``.

    The connection uses the ``utf8mb4`` charset, returns rows as dicts
    (``DictCursor``) and has autocommit switched on.

    Args:
        config: Object exposing ``DB_HOST``, ``DB_PORT``, ``DB_USER``,
            ``DB_PASSWORD`` and ``DB_NAME`` attributes.

    Returns:
        The connection object returned by ``pymysql.connect``.
    """
    # Credentials and target database come from the config object.
    connect_kwargs = {
        "host": config.DB_HOST,
        "port": config.DB_PORT,
        "user": config.DB_USER,
        "password": config.DB_PASSWORD,
        "database": config.DB_NAME,
    }
    # Fixed session options shared by every caller in this script.
    connect_kwargs.update(
        charset="utf8mb4",
        cursorclass=DictCursor,
        autocommit=True,
    )
    return pymysql.connect(**connect_kwargs)


def get_calls(conn, bid, call_date, limit):
    """Fetch answered calls that have a recording for one BID and date.

    Args:
        conn: Open DB connection whose ``cursor()`` is usable as a context
            manager.
        bid: Business identifier; prefix of the ``<bid>_raw_calls`` table.
        call_date: Value compared against ``DATE(call_starttime)``.
        limit: Maximum number of rows to return.

    Returns:
        The result of ``cursor.fetchall()``: rows carrying ``callid``,
        ``fileurl`` and ``call_starttime``, newest call first.
    """
    # A table name cannot be sent as a bound parameter, so it has to be
    # interpolated. Doubling embedded backticks (MySQL's identifier escape)
    # stops a crafted BID from closing the quoted name and injecting SQL.
    table = f"{bid}_raw_calls".replace("`", "``")
    query = f"""
        SELECT callid, fileurl, call_starttime
        FROM `{table}`
        WHERE DATE(call_starttime) = %s
          AND fileurl IS NOT NULL
          AND fileurl != ''
          AND call_status = 'ANSWER'
        ORDER BY call_starttime DESC
        LIMIT %s
    """
    with conn.cursor() as cursor:
        cursor.execute(query, (call_date, limit))
        return cursor.fetchall()


def save_transcript(conn, bid, callid, result):
    """Upsert a Sarvam transcription for one call and mark the call completed.

    Writes one row to ``<bid>_sarvamresponse`` (insert, or update on duplicate
    key) and then sets ``transcription_status = 'completed'`` and
    ``status = 1`` on the matching ``<bid>_raw_calls`` row.

    Args:
        conn: Open DB connection whose ``cursor()`` is usable as a context
            manager.
        bid: Business identifier; prefix of the per-BID table names.
        callid: Identifier of the call being saved.
        result: Transcription result exposing ``transcript``,
            ``speaker_segments``, ``duration`` and ``provider``.
    """
    # Local import: the module header only imports ``datetime`` itself.
    from datetime import timezone

    segments = result.speaker_segments or []
    # When the provider returned no flat transcript, rebuild one from the
    # segments that actually carry text.
    transcript = result.transcript or "\n\n".join(
        f"{seg.get('speaker', 'Speaker')}: {seg.get('text', '')}".strip()
        for seg in segments
        if seg.get("text")
    )

    # Count distinct speaker labels, preferring ``speaker_id`` over
    # ``speaker``. Segments without any label are skipped so they are not
    # counted as an additional (None) speaker; an id of 0 is a valid label.
    labels = set()
    for seg in segments:
        for key in ("speaker_id", "speaker"):
            label = seg.get(key)
            if label is not None and label != "":
                labels.add(label)
                break
    num_speakers = len(labels) or None

    payload = {
        "provider": result.provider,
        "speaker_segments": segments,
        # Same naive-UTC ISO string as before, without the deprecated
        # ``datetime.utcnow()``.
        "processed_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
    }

    # Table names cannot be bound as parameters; doubling embedded backticks
    # keeps a crafted BID from escaping the quoted identifier.
    response_table = f"{bid}_sarvamresponse".replace("`", "``")
    calls_table = f"{bid}_raw_calls".replace("`", "``")

    with conn.cursor() as cursor:
        cursor.execute(
            f"""
            INSERT INTO `{response_table}`
                (callid, transcript, raw_transcript, speaker_segments, num_speakers,
                 duration, request_id, language, stt_provider, language_detected,
                 raw_response, status, created_at, updated_at)
            VALUES
                (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, '1', NOW(), NOW())
            ON DUPLICATE KEY UPDATE
                transcript = VALUES(transcript),
                raw_transcript = VALUES(raw_transcript),
                speaker_segments = VALUES(speaker_segments),
                num_speakers = VALUES(num_speakers),
                duration = VALUES(duration),
                request_id = VALUES(request_id),
                language = VALUES(language),
                stt_provider = VALUES(stt_provider),
                language_detected = VALUES(language_detected),
                raw_response = VALUES(raw_response),
                status = VALUES(status),
                updated_at = NOW()
            """,
            (
                callid,
                transcript,
                transcript,  # raw_transcript mirrors transcript
                json.dumps(segments, ensure_ascii=False),
                num_speakers,
                result.duration,
                f"sarvam_{callid}",
                "en",
                "sarvam",
                "en",
                json.dumps(payload, ensure_ascii=False),
            ),
        )
        cursor.execute(
            f"""
            UPDATE `{calls_table}`
            SET transcription_status = 'completed', status = 1
            WHERE callid = %s
            """,
            (callid,),
        )


class ConfigWrapper:
    """Expose a config object through both ``.get(key, default)`` and attributes.

    Attribute reads that are not found on the wrapper itself are forwarded to
    the wrapped config object.
    """

    def __init__(self, config):
        self._config = config

    def get(self, key, default=None):
        """Return attribute ``key`` of the wrapped config, or ``default``."""
        return getattr(self._config, key, default)

    def __getattr__(self, key):
        # __getattr__ only runs after normal lookup has failed. If ``_config``
        # itself is missing (instance created without __init__, as copy and
        # pickle do), forwarding would re-enter this method forever and end
        # in RecursionError, so report it as a plain missing attribute.
        if key == "_config":
            raise AttributeError(key)
        return getattr(self._config, key)


def analyze_call(analyzer, bid, callid, result):
    """Run analytics for a freshly transcribed call.

    When the result carries no flat transcript, one is rebuilt from the
    speaker segments using the same "<speaker>: <text>" layout that
    ``save_transcript`` stores, so analytics on a fresh transcription and a
    later ``--analysis-only`` run see the same text.

    Args:
        analyzer: Object providing ``analyze_call(bid=..., callid=...,
            transcript=..., speaker_segments=..., actual_duration=...)``.
        bid: Business identifier the call belongs to.
        callid: Identifier of the call.
        result: Transcription result exposing ``transcript``,
            ``speaker_segments`` and ``duration``.
    """
    segments = result.speaker_segments or []
    # Previously a missing transcript was forwarded as None/"" even though
    # the segments contained the spoken text.
    transcript = result.transcript or "\n\n".join(
        f"{seg.get('speaker', 'Speaker')}: {seg.get('text', '')}".strip()
        for seg in segments
        if seg.get("text")
    )
    analyzer.analyze_call(
        bid=bid,
        callid=callid,
        transcript=transcript,
        speaker_segments=segments,
        actual_duration=result.duration,
    )


def get_saved_transcripts(conn, bid, call_date, limit):
    """Fetch stored Sarvam transcripts for one BID and date.

    Joins ``<bid>_raw_calls`` with ``<bid>_sarvamresponse`` on ``callid`` and
    keeps only rows whose ``stt_provider`` is ``'sarvam'``.

    Args:
        conn: Open DB connection whose ``cursor()`` is usable as a context
            manager.
        bid: Business identifier; prefix of the per-BID table names.
        call_date: Value compared against ``DATE(call_starttime)``.
        limit: Maximum number of rows to return.

    Returns:
        The result of ``cursor.fetchall()``: rows carrying ``callid``,
        ``transcript``, ``speaker_segments`` and ``duration``, newest call
        first.
    """
    # Table names cannot be bound as parameters; doubling embedded backticks
    # keeps a crafted BID from escaping the quoted identifiers.
    calls_table = f"{bid}_raw_calls".replace("`", "``")
    response_table = f"{bid}_sarvamresponse".replace("`", "``")
    query = f"""
        SELECT r.callid, s.transcript, s.speaker_segments, s.duration
        FROM `{calls_table}` r
        JOIN `{response_table}` s ON r.callid = s.callid
        WHERE DATE(r.call_starttime) = %s
          AND s.stt_provider = 'sarvam'
        ORDER BY r.call_starttime DESC
        LIMIT %s
    """
    with conn.cursor() as cursor:
        cursor.execute(query, (call_date, limit))
        return cursor.fetchall()


def analyze_saved_transcript(analyzer, bid, row):
    """Re-run analytics for one stored transcript row.

    Args:
        analyzer: Object providing ``analyze_call(bid=..., callid=...,
            transcript=..., speaker_segments=..., actual_duration=...)``.
        bid: Business identifier the call belongs to.
        row: Mapping with ``callid`` and optionally ``transcript``,
            ``speaker_segments`` (a list or its JSON text) and ``duration``.
    """
    segments = row.get("speaker_segments")
    if isinstance(segments, (str, bytes, bytearray)):
        # Stored as JSON text. Empty, whitespace-only or malformed text
        # decodes to "no segments" instead of being forwarded as a string.
        try:
            segments = json.loads(segments)
        except (json.JSONDecodeError, UnicodeDecodeError):
            segments = []
    if not isinstance(segments, list):
        # Covers NULL columns and JSON that decoded to something other than
        # a list (e.g. ``null`` or an object).
        segments = []

    analyzer.analyze_call(
        bid=bid,
        callid=str(row["callid"]),
        transcript=row.get("transcript") or "",
        speaker_segments=segments,
        actual_duration=float(row.get("duration") or 0),
    )


def main():
    """CLI entry point: re-transcribe calls with Sarvam and refresh analytics.

    Flags:
        --bid            Business identifier (required).
        --date           Call date to process (required).
        --limit          Maximum number of calls to process (default 10).
        --analysis-only  Skip transcription and re-run analytics on the
                         Sarvam transcripts already stored for that date.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--bid", required=True)
    parser.add_argument("--date", required=True)
    parser.add_argument("--limit", type=int, default=10)
    parser.add_argument("--analysis-only", action="store_true")
    args = parser.parse_args()

    load_dotenv()
    config = Config()

    # Only the transcription path talks to Sarvam, so the client (and its
    # subscription key) is not needed for --analysis-only runs.
    stt = None
    if not args.analysis_only:
        sarvam_key = os.getenv("SARVAM_SUBSCRIPTION_KEY") or getattr(
            config, "SARVAM_SUBSCRIPTION_KEY", ""
        )
        stt = SarvamSTT(sarvam_key)
    analyzer = CallAnalyzer(ConfigWrapper(config))

    conn = db_connection(config)
    try:
        if args.analysis_only:
            rows = get_saved_transcripts(conn, args.bid, args.date, args.limit)
            print(f"Found {len(rows)} Sarvam transcripts for analytics")
            for index, row in enumerate(rows, 1):
                callid = str(row["callid"])
                print(f"[{index}/{len(rows)}] Refreshing analytics for {callid}")
                analyze_saved_transcript(analyzer, args.bid, row)
            return

        calls = get_calls(conn, args.bid, args.date, args.limit)
        print(f"Found {len(calls)} calls for BID {args.bid} on {args.date}")
        for index, call in enumerate(calls, 1):
            callid = str(call["callid"])
            print(f"[{index}/{len(calls)}] Sarvam transcribing {callid}")
            result = stt.transcribe(call["fileurl"], callid)
            # Persist first so the transcript is kept even if analytics fails.
            save_transcript(conn, args.bid, callid, result)
            print(f"[{index}/{len(calls)}] Saved Sarvam transcript for {callid}")
            analyze_call(analyzer, args.bid, callid, result)
            print(f"[{index}/{len(calls)}] Refreshed analytics for {callid}")
    finally:
        # Runs on the early return above as well as on errors.
        conn.close()


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
