"""
Re-analyze calls that already have analytics to fix the 0:100 talk-to-listen ratio.

Root cause: old speaker_segments stored in {bid}_sarvamresponse have no 'role' or
'speaker_id' fields, and the speaker display name was stored as "Speaker 1" (1-indexed)
instead of "Speaker 0" due to a +1 offset bug. The talk_listen_calculator therefore
could not identify the agent and assigned all time to customer → 0:100.

This script:
  1. Fetches up to --limit answered calls (most recent first) that have both a
     transcript and existing analytics.
  2. Normalizes old-format speaker_segments before passing to the analyzer.
  3. Re-runs analyze_call(), which upserts (overwrites) the analytics row.

Usage:
    python reanalyze_talk_listen.py --bid 6004
    python reanalyze_talk_listen.py --bid 6004 --limit 100 --delay 0.5
    python reanalyze_talk_listen.py --bid 6004 --dry-run
"""
import argparse
import json
import logging
import time

from config import Config
from db_handler import DatabaseHandler
from analyze_calls_with_parameters import CallAnalyzer

# Script-wide logging: INFO and above, one "timestamp - LEVEL - message" line per record.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger; main() writes all progress and error output through it.
logger = logging.getLogger(__name__)


class ConfigWrapper:
    """Expose a config object through both dict-style ``get()`` and attributes.

    The wrapped object keeps its settings as attributes; this adapter adds a
    ``get(key, default)`` accessor while still forwarding plain attribute
    access to the wrapped object.
    """

    def __init__(self, config):
        """Store ``config``, the object whose attributes hold the settings."""
        self._config = config

    def get(self, key, default=None):
        """Return attribute ``key`` of the wrapped config, or ``default`` if absent."""
        return getattr(self._config, key, default)

    def __getattr__(self, key):
        """Forward attribute lookups that fail on the wrapper to the wrapped config.

        Raises:
            AttributeError: if the wrapped config has no such attribute, or if
                the wrapper itself has not been initialised yet.
        """
        # __getattr__ only runs when normal lookup fails. If '_config' itself is
        # missing (instance created without __init__, as copy/pickle do),
        # reading self._config below would re-enter this method forever.
        if key == '_config':
            raise AttributeError(key)
        return getattr(self._config, key)


def fetch_already_analyzed_calls(db_handler, bid, limit):
    """Return callid rows for answered calls that have a transcript AND analytics.

    Args:
        db_handler: object whose ``get_connection()`` returns a context manager
            yielding a DB-API connection.
        bid: business ID used as the table-name prefix
            (``{bid}_raw_calls``, ``{bid}_sarvamresponse``, ``{bid}_callanalytics``).
        limit: maximum number of rows to return; newest calls come first.

    Returns:
        Whatever ``cursor.fetchall()`` yields for the single ``callid`` column.

    Raises:
        ValueError: if ``bid`` contains anything other than ASCII letters,
            digits or underscores.
    """
    # Table names cannot be bound as query parameters, so the prefix is
    # interpolated into the SQL text. Restrict it to identifier characters so a
    # crafted --bid value cannot break out of the backtick quoting.
    bid_str = str(bid)
    if (not bid_str or not bid_str.isascii()
            or not all(ch.isalnum() or ch == '_' for ch in bid_str)):
        raise ValueError(f"Invalid business ID: {bid!r}")

    query = f"""
        SELECT r.callid
        FROM `{bid_str}_raw_calls` r
        JOIN `{bid_str}_sarvamresponse` s ON r.callid = s.callid
        JOIN `{bid_str}_callanalytics` a ON r.callid = a.callid
        WHERE r.call_status = 'ANSWER'
          AND s.transcript IS NOT NULL
          AND s.transcript != ''
        ORDER BY r.call_starttime DESC
        LIMIT %s
    """
    with db_handler.get_connection() as conn:
        cursor = conn.cursor()
        try:
            cursor.execute(query, (limit,))
            return cursor.fetchall()
        finally:
            # Release the cursor even when execute/fetch raises.
            cursor.close()


def _speaker_id_from_display(display):
    """Derive a 'speaker_N' id from a display name like 'Speaker 1', else 'unknown'."""
    # Stored names may be None or non-string; coerce so .strip() cannot fail.
    text = str(display or '').strip().lower()
    for prefix in ('speaker_', 'speaker '):
        if text.startswith(prefix):
            try:
                return f'speaker_{int(text[len(prefix):])}'
            except ValueError:
                return 'unknown'
    return 'unknown'


def _speaker_number(speaker_id):
    """Return the integer N of a 'speaker_N' id, or None when it cannot be parsed."""
    try:
        return int(speaker_id.replace('speaker_', ''))
    except (ValueError, AttributeError, TypeError):
        return None


def normalize_segments(speaker_segments):
    """
    Normalize speaker_segments so that 'role' is correctly assigned.

    The agent is always the speaker with the lowest numeric speaker ID.
    Sarvam may use 0-indexed (speaker_0, speaker_1) or 1-indexed (speaker_1, speaker_2).
    We never trust the stored 'role' because it was set with a hardcoded speaker_0
    assumption that fails when Sarvam uses 1-indexed IDs.

    Args:
        speaker_segments: iterable of segment mappings (None is treated as
            empty). Each may carry 'speaker_id' and/or a display name under
            'speaker'.

    Returns:
        A new list of shallow-copied segment dicts; the inputs are not
        mutated. Every copy has 'speaker_id' (kept if already truthy,
        otherwise derived from 'speaker', or 'unknown') and 'role' set to
        'agent' or 'customer'. Segments whose speaker number cannot be parsed
        are always 'customer'.
    """
    normalized = []
    numbers = []
    for seg in speaker_segments or []:
        seg = dict(seg)  # don't mutate original
        if not seg.get('speaker_id'):
            seg['speaker_id'] = _speaker_id_from_display(seg.get('speaker'))
        normalized.append(seg)
        # Parse once and remember the result alongside the segment.
        numbers.append(_speaker_number(seg['speaker_id']))

    known = [n for n in numbers if n is not None]
    # None (not a magic number) marks "no agent identifiable", so an
    # unparseable id can never compare equal to a real speaker number.
    agent_number = min(known) if known else None

    for seg, number in zip(normalized, numbers):
        is_agent = number is not None and number == agent_number
        seg['role'] = 'agent' if is_agent else 'customer'

    return normalized


def _parse_speaker_segments(raw_segments, callid):
    """Return stored speaker segments as a list, decoding JSON text if needed."""
    if isinstance(raw_segments, (str, bytes, bytearray)):
        try:
            raw_segments = json.loads(raw_segments)
        except (json.JSONDecodeError, TypeError, UnicodeDecodeError):
            return []
    if not raw_segments:
        return []
    if not isinstance(raw_segments, (list, tuple)):
        # e.g. a JSON object or scalar: iterating it would not yield segments.
        logger.warning("  Call %s has speaker_segments of unexpected type %s, ignoring.",
                       callid, type(raw_segments).__name__)
        return []
    return raw_segments


def main():
    """Command-line entry point: re-run call analysis with normalized segments.

    Parses --bid/--limit/--delay/--dry-run, fetches the calls that already have
    analytics, normalizes each call's speaker segments and passes them to
    ``CallAnalyzer.analyze_call`` (which, per the module docstring, overwrites
    the analytics row). A failure on one call is logged with its traceback and
    counted; processing continues with the next call.
    """
    parser = argparse.ArgumentParser(
        description='Re-analyze calls to fix 0:100 talk-to-listen ratio'
    )
    parser.add_argument('--bid', required=True, help='Business ID (e.g. 6004)')
    parser.add_argument('--limit', type=int, default=500,
                        help='Max calls to re-analyze (default: 500)')
    parser.add_argument('--delay', type=float, default=0.0,
                        help='Seconds to sleep between calls (default: 0)')
    parser.add_argument('--dry-run', action='store_true',
                        help='Fetch and normalize segments but do not call the AI or save')
    args = parser.parse_args()

    config = Config()
    config_wrapped = ConfigWrapper(config)
    db_handler = DatabaseHandler(config_wrapped)
    analyzer = CallAnalyzer(config_wrapped)

    logger.info("=" * 70)
    logger.info("RE-ANALYZE TALK-TO-LISTEN RATIO")
    logger.info("BID: %s | limit: %s | dry-run: %s", args.bid, args.limit, args.dry_run)
    logger.info("=" * 70)

    calls = fetch_already_analyzed_calls(db_handler, args.bid, args.limit)
    if not calls:
        logger.info("No analyzed calls found for BID %s.", args.bid)
        return

    logger.info("Found %s calls to re-analyze.", len(calls))

    success = 0
    failed = 0
    skipped = 0

    for idx, row in enumerate(calls, 1):
        try:
            callid = row['callid']
        except TypeError:
            # Positional (tuple) row: callid is the only selected column.
            callid = row[0]
        logger.info("[%s/%s] Processing call %s", idx, len(calls), callid)

        try:
            call_data = db_handler.get_raw_call_details(args.bid, callid)
            if not call_data:
                logger.warning("  Call %s not found, skipping.", callid)
                skipped += 1
                continue

            # NOTE(review): key is plural 'transcripts' while the SQL column is
            # 'transcript' — confirm against get_raw_call_details().
            transcript = call_data.get('transcripts')
            if not transcript:
                logger.warning("  Call %s has no transcript, skipping.", callid)
                skipped += 1
                continue

            raw_segments = _parse_speaker_segments(
                call_data.get('speaker_segments'), callid
            )
            speaker_segments = normalize_segments(raw_segments)
            actual_duration = call_data.get('duration') or call_data.get('duration_seconds')

            # Quick sanity-check: log agent vs customer segment counts
            agent_segs = sum(1 for s in speaker_segments if s.get('role') == 'agent')
            cust_segs = sum(1 for s in speaker_segments if s.get('role') == 'customer')
            logger.info("  Segments — agent: %s, customer: %s, duration: %ss",
                        agent_segs, cust_segs, actual_duration)

            if args.dry_run:
                logger.info("  [dry-run] Skipping AI call and DB save.")
                success += 1
                continue

            analyzer.analyze_call(
                bid=args.bid,
                callid=callid,
                transcript=transcript,
                speaker_segments=speaker_segments,
                actual_duration=actual_duration
            )
            success += 1

        except Exception as exc:
            # Top-level per-call boundary: keep the batch going, but record the
            # traceback so the failure can be diagnosed afterwards.
            logger.exception("  Failed for call %s: %s", callid, exc)
            failed += 1

        # Throttle after a failed attempt as well as a successful one; skipped
        # and dry-run calls 'continue' above and never reach this point.
        if args.delay > 0 and idx < len(calls):
            time.sleep(args.delay)

    logger.info("=" * 70)
    logger.info("DONE — success: %s | skipped: %s | failed: %s", success, skipped, failed)
    logger.info("=" * 70)


# Run the re-analysis only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
