#!/usr/bin/env python3
"""Backfill LeadSquared activities for already analyzed calls.

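The normal orchestration pushes activities to LeadSquared (LSQ) only once,
immediately after analytics completes. Use this utility to retry the push for
analyzed calls that were skipped at that point.

Candidates are the newest raw calls for the given --bid that have status 3 and
an analytics row (up to --limit, or a single --callid). The script does not
check whether a call was already pushed, so --execute may re-push such calls.
Without --execute it is a dry run that only logs the payloads it would send.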
"""
from __future__ import annotations

import argparse
import json
import logging
import os
import sys
from typing import Any, Dict, Iterable, Optional

import pymysql
from dotenv import load_dotenv
from pymysql.cursors import DictCursor

# Absolute directory containing this script (the backend package root).
BACKEND_DIR = os.path.dirname(os.path.abspath(__file__))
# Put the backend directory first on the import path so the sibling modules
# imported below (config, db_handler, leadsquared_*) resolve from it.
sys.path.insert(0, BACKEND_DIR)
# Run from the backend directory, presumably so any cwd-relative paths used by
# the backend modules behave as in the normal service — TODO confirm needed.
os.chdir(BACKEND_DIR)
# Load a .env file into os.environ; _db_connect() reads the DB_* settings from
# the environment.
load_dotenv()

# The project imports are deliberately placed after the path/.env setup above
# (hence the import-not-at-top pattern). Presumably Config also relies on the
# .env values being loaded at import time — confirm before reordering.
from config import Config
from db_handler import DatabaseHandler
from leadsquared_service import LeadSquaredService
from leadsquared_activity_push import (
    _lsq_build_activity_payload,
    _lsq_resolve_prospect_id,
    _push_leadsquared_activities,
)

# Named logger for this script; handlers are configured in main() via
# logging.basicConfig.
logger = logging.getLogger("push_lsq_backfill")


def _config_dict() -> Dict[str, Any]:
    """Snapshot every all-uppercase attribute of ``Config`` into a plain dict.

    This is the settings mapping handed to ``DatabaseHandler``.
    """
    setting_names = (name for name in dir(Config) if name.isupper())
    return dict((name, getattr(Config, name)) for name in setting_names)


def _parse_json(value: Any) -> Dict[str, Any]:
    if isinstance(value, dict):
        return value
    if not value:
        return {}
    try:
        parsed = json.loads(value)
        return parsed if isinstance(parsed, dict) else {}
    except Exception:
        return {}


def _analysis_from_row(row: Dict[str, Any]) -> Dict[str, Any]:
    """Build the analysis dict passed to the LSQ helpers from a joined row.

    Starts from the JSON stored in ``raw_response`` and fills in any of the
    scalar analytics columns that the JSON lacks (or holds as null). When the
    row has a non-empty ``summary`` and the JSON has no ``overall_summary``
    (absent or null), the summary is mirrored into ``overall_summary`` too.

    The parsed JSON is shallow-copied first, so the caller's row is never
    mutated, even when ``raw_response`` is already a dict.
    """
    # _parse_json returns dict inputs as-is; copy so the fills below don't
    # write back into row["raw_response"].
    analysis = dict(_parse_json(row.get("raw_response")))
    for key in (
        "summary",
        "quality_score",
        "sentiment",
        "call_purpose",
        "objections_concerns",
        "objection_type",
        "talk_listen_ratio",
    ):
        # Column values only fill gaps; values present in the JSON win.
        if row.get(key) is not None and analysis.get(key) is None:
            analysis[key] = row.get(key)
    if row.get("summary") and analysis.get("overall_summary") is None:
        analysis["overall_summary"] = row.get("summary")
    return analysis


def _call_from_row(row: Dict[str, Any]) -> Dict[str, Any]:
    return {
        "callid": row.get("callid"),
        "direction": row.get("direction") or "inbound",
        "agentname": row.get("agentname") or "",
        "callfrom": row.get("agent_callinfo"),
        "callto": row.get("customer_callinfo"),
        "clicktocalldid": row.get("customer_callinfo"),
        "customer_phone": row.get("customer_callinfo"),
        "agent_callinfo": row.get("agent_callinfo"),
        "customer_callinfo": row.get("customer_callinfo"),
        "call_starttime": row.get("call_starttime"),
    }


def _db_connect() -> pymysql.Connection:
    """Open a new MySQL connection configured from the DB_* environment variables.

    Rows are returned as dicts (``DictCursor``) and text uses ``utf8mb4``.
    The caller owns the returned connection and is responsible for closing it,
    e.g. ``with _db_connect() as conn: ...``.
    """
    return pymysql.connect(
        host=os.getenv("DB_HOST"),
        # ``or`` instead of a getenv default: an empty ``DB_PORT=`` entry in
        # .env falls back to 3306 instead of crashing in int("").
        port=int(os.getenv("DB_PORT") or "3306"),
        user=os.getenv("DB_USER"),
        password=os.getenv("DB_PASSWORD"),
        database=os.getenv("DB_NAME"),
        charset="utf8mb4",
        cursorclass=DictCursor,
    )


def _load_candidates(bid: str, limit: int, callid: Optional[str] = None) -> Iterable[Dict[str, Any]]:
    """Fetch the newest raw calls (status 3) of ``bid`` that have an analytics row.

    Args:
        bid: Prefix of the per-bid tables ``<bid>_raw_calls`` and
            ``<bid>_callanalytics``.
        limit: Maximum number of rows returned (MySQL ``LIMIT``).
        callid: When given, restrict the query to that single call id.

    Returns:
        Rows as dicts (``DictCursor``), ordered by ``call_starttime``
        descending, combining raw-call columns with their analytics columns.
        Returns an empty list when nothing matches.
    """

    def quote_ident(name: str) -> str:
        # Table names cannot be bound as query parameters, so quote them per
        # MySQL rules: wrap in backticks and double any embedded backtick.
        # This keeps a hostile ``bid`` from breaking out of the identifier.
        return "`" + name.replace("`", "``") + "`"

    raw_table = quote_ident(f"{bid}_raw_calls")
    analytics_table = quote_ident(f"{bid}_callanalytics")
    params: list[Any] = []
    where = "r.status = 3"
    if callid:
        where += " AND r.callid = %s"
        params.append(callid)
    # LIMIT is the last placeholder in the query, so it is appended last.
    params.append(limit)

    query = f"""
        SELECT
            r.callid, r.agentname, r.agent_callinfo, r.customer_callinfo, r.direction,
            r.call_starttime, a.raw_response, a.summary, a.quality_score,
            a.sentiment, a.call_purpose, a.objections_concerns,
            a.objection_type, a.talk_listen_ratio
        FROM {raw_table} r
        INNER JOIN {analytics_table} a ON a.callid = r.callid
        WHERE {where}
        ORDER BY r.call_starttime DESC
        LIMIT %s
    """
    with _db_connect() as conn:
        with conn.cursor() as cursor:
            cursor.execute(query, params)
            return cursor.fetchall() or []


def _log_dry_run_payloads(
    call: Dict[str, Any],
    analysis: Dict[str, Any],
    prospect_id: Any,
    activities: Dict[str, Any],
) -> None:
    """Log the payload each configured activity would send, without pushing it."""
    logger.info("[%s] dry-run prospect=%s", call["callid"], prospect_id)
    for activity_key, activity_cfg in activities.items():
        payload = _lsq_build_activity_payload(activity_cfg, analysis, call, prospect_id)
        # Activities whose builder returns a falsy payload are not logged.
        if payload:
            logger.info(
                "[%s] dry-run activity=%s payload=%s",
                call["callid"],
                activity_key,
                json.dumps(payload, ensure_ascii=False),
            )


def main() -> int:
    """Backfill LeadSquared activities for analyzed calls of one ``--bid``.

    Without ``--execute`` this is a dry run. Prospects are still resolved,
    presumably via the LSQ service, but the activity payloads are only logged.
    With ``--execute`` each matched call is sent through
    ``_push_leadsquared_activities``.

    An exception on one call is logged with its traceback and the run moves
    on to the next call.

    Returns:
        0 when every candidate was handled, 1 when at least one raised.
    """
    parser = argparse.ArgumentParser(
        description="Backfill LeadSquared activities for already analyzed calls.",
    )
    parser.add_argument("--bid", required=True, help="id whose call tables and LSQ integration are used")
    parser.add_argument("--callid", help="only process this call id")
    parser.add_argument("--limit", type=int, default=10, help="maximum number of calls to process (default: 10)")
    parser.add_argument("--execute", action="store_true", help="actually push; default is a dry run")
    args = parser.parse_args()
    # Reject values that would give MySQL an empty or invalid LIMIT.
    if args.limit < 1:
        parser.error("--limit must be a positive integer")

    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
    db = DatabaseHandler(_config_dict())
    integration = db.get_crm_integration(args.bid, "leadsquared") or {}
    if not integration:
        logger.warning("No LeadSquared integration found for bid=%s", args.bid)
    config = integration.get("config") or {}
    activities = config.get("activities") or {}
    creds = db.get_crm_credentials(args.bid, "leadsquared") or {}
    service = LeadSquaredService(
        access_key=creds.get("access_key"),
        secret_key=creds.get("secret_key"),
        api_host=creds.get("api_host"),
        timeout=30,
    )

    pushed = 0
    skipped = 0
    failed = 0
    for row in _load_candidates(args.bid, args.limit, args.callid):
        call = _call_from_row(row)
        # Per-call error boundary: one bad call must not abort the whole backfill.
        try:
            analysis = _analysis_from_row(row)
            prospect_id = _lsq_resolve_prospect_id(service, call, analysis, db=db, bid=args.bid)
            if not prospect_id:
                skipped += 1
                logger.info("[%s] skip: no LSQ prospect match for phone=%s", call["callid"], call.get("customer_phone"))
                continue

            if not args.execute:
                _log_dry_run_payloads(call, analysis, prospect_id, activities)
                continue

            _push_leadsquared_activities(args.bid, call, analysis, db)
            pushed += 1
        except Exception:
            failed += 1
            logger.exception("[%s] failed; continuing with next call", call["callid"])

    logger.info(
        "Done. pushed=%s skipped=%s failed=%s dry_run=%s",
        pushed,
        skipped,
        failed,
        not args.execute,
    )
    return 1 if failed else 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
