#!/usr/bin/env python3
"""
Generic multi-bid call processing pipeline.

Replaces the hardcoded ``pipeline_6004.py`` with a DB-driven design.
Pipeline configuration is loaded from ``business_pipeline_config``; AI agent
prompts from ``business_agent_config``; all results land in
``{bid}_call_records``.

Stages per bid
--------------
1. SYNC  — pull new ANSWER calls from source DB into ``{bid}_call_records``
           (watermarked; optionally filtered to CRM-known phones)
2. TRANSCRIBE — fetch audio, call STT provider, save transcript
3. ANALYZE    — run all enabled AI agents via AgentRunner, save analysis

Usage
-----
    python3 call_processor.py                  # process all enabled bids once
    python3 call_processor.py --continuous     # loop forever
    python3 call_processor.py --bid 6004       # single bid
    python3 call_processor.py --bid 6004 --stage sync
    python3 call_processor.py --interval 120
"""
from __future__ import annotations

import argparse
import json
import logging
import os
import signal
import sys
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Set

import pymysql
from pymysql.cursors import DictCursor

# ── bootstrap: must run from dashboard-backend dir ───────────────────────────
# Anchor the script to this file's own directory (presumably dashboard-backend):
# it is put first on sys.path so the project modules imported below
# (db_handler, stt, agent_runner, …) are found there, and it becomes the
# working directory so any relative path resolves against it no matter where
# the script was launched from.  This must stay ABOVE those project imports.
BACKEND_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, BACKEND_DIR)
os.chdir(BACKEND_DIR)

from dotenv import load_dotenv
# Load variables from a .env file (python-dotenv) into os.environ before the
# project modules below are imported; presumably they read settings from the
# environment at import/constructor time — confirm.
load_dotenv()

from db_handler import DatabaseHandler
from leadsquared_activity_push import push_leadsquared_activities as _push_leadsquared_activities
from stt import get_stt_provider
from agent_runner import AgentRunner

# ── logging ───────────────────────────────────────────────────────────────────
# Log to a file next to this script and to stderr (StreamHandler's default).
# The file is opened as UTF-8 explicitly: log lines carry non-ASCII text (the
# "—" in the format string, "…" in messages, arbitrary exception text from the
# STT/DB layers), and FileHandler otherwise uses the locale's default
# encoding, where such characters can fail to encode — logging then prints an
# encoding-error traceback to stderr instead of writing the line to the file.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
    handlers=[
        logging.FileHandler(
            os.path.join(BACKEND_DIR, "call_processor.log"),
            encoding="utf-8",
        ),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger("call_processor")

# ── graceful shutdown ─────────────────────────────────────────────────────────
# Graceful-shutdown flag.  The handler below only flips it; the per-call stage
# loops, run_bid() and main() poll it so the work in flight can finish.
_shutdown = False


def _handle_signal(sig, frame):
    """SIGTERM/SIGINT handler: request a stop once the in-flight work is done."""
    global _shutdown
    logger.info("Shutdown signal received — finishing current batch…")
    _shutdown = True


for _signum in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_signum, _handle_signal)
del _signum


# ── phone normalisation (mirrors db_handler._normalize_phone_variants) ────────

def _phone_variants(phone: Any) -> Set[str]:
    """Return every spelling under which *phone* may be stored.

    Non-digit characters are dropped first.  The result holds the bare digit
    string, the same with a leading ``+``, and its last ten digits; when those
    last ten digits form a full 10-digit number, the ``91``, ``+91`` and ``0``
    prefixed forms are added too.  Input with no digits (including falsy
    values such as ``None`` or ``""``) yields an empty set.
    """
    digit_str = "".join(filter(str.isdigit, str(phone or "")))
    if not digit_str:
        return set()
    # A negative slice already returns the whole string when it has <= 10 chars.
    last_ten = digit_str[-10:]
    spellings = {digit_str, "+" + digit_str, last_ten}
    if len(last_ten) == 10:
        spellings.update(prefix + last_ten for prefix in ("91", "+91", "0"))
    return spellings


# ── source-DB helpers ─────────────────────────────────────────────────────────

def _open_source_conn(cfg: Dict) -> pymysql.Connection:
    """Open a connection to the bid's source MySQL database.

    Reads ``source_db_host``, ``source_db_user`` and ``source_db_name``
    (required — a missing key raises ``KeyError``), plus ``source_db_port``
    (default 3306) and ``source_db_password`` (default empty).  Rows come back
    as dicts (``DictCursor``), text is ``utf8mb4`` and the connect attempt
    gives up after 10 seconds.
    """
    connect_kwargs = dict(
        host=cfg["source_db_host"],
        port=int(cfg.get("source_db_port") or 3306),
        user=cfg["source_db_user"],
        password=cfg.get("source_db_password") or "",
        database=cfg["source_db_name"],
        charset="utf8mb4",
        cursorclass=DictCursor,
        connect_timeout=10,
    )
    return pymysql.connect(**connect_kwargs)


def _detect_source_table(conn: pymysql.Connection, bid: str) -> Optional[str]:
    """Return the name of *bid*'s call-history table in the source DB, or ``None``.

    Candidate names are tried in priority order — ``{bid}_callhistory``,
    ``{bid}_call_history``, ``{bid}_callarchive``, ``{bid}_call_archive`` —
    and the first one listed by ``SHOW TABLES`` wins.
    """
    candidates = (
        f"{bid}_callhistory",
        f"{bid}_call_history",
        f"{bid}_callarchive",
        f"{bid}_call_archive",
    )
    # Close the cursor deterministically instead of leaving it to the GC.
    with conn.cursor() as cursor:
        cursor.execute("SHOW TABLES")
        # Rows are single-column dicts (DictCursor; MySQL names the column
        # ``Tables_in_<db>``), so take each row's only value.
        existing = {next(iter(row.values())) for row in cursor.fetchall()}
    return next((name for name in candidates if name in existing), None)


def _detect_source_columns(conn: pymysql.Connection, table: str) -> Set[str]:
    """Return the set of column names of *table* (from ``SHOW COLUMNS``).

    Rows are read as dicts (the connection uses ``DictCursor``); the column
    name is in each row's ``Field`` entry.
    """
    # Identifiers cannot be bound as query parameters, so escape any embedded
    # backtick by doubling it before quoting the name.
    quoted = table.replace("`", "``")
    with conn.cursor() as cursor:
        cursor.execute(f"SHOW COLUMNS FROM `{quoted}`")
        return {row["Field"] for row in cursor.fetchall()}


def _parse_pipeline_datetime(value) -> Optional[datetime]:
    """Coerce a stored timestamp into a naive ``datetime``, or ``None``.

    ``datetime`` objects are used as-is; anything else is stringified and
    parsed with ``datetime.fromisoformat`` (a trailing ``Z`` is rewritten to
    ``+00:00`` first).  Any timezone is dropped WITHOUT converting the wall
    clock time.  ``None`` or an unparseable value yields ``None``.
    """
    if value is None:
        return None
    if isinstance(value, datetime):
        parsed = value
    else:
        try:
            parsed = datetime.fromisoformat(str(value).replace("Z", "+00:00"))
        except Exception:
            return None
    return parsed.replace(tzinfo=None) if parsed.tzinfo else parsed


from min_duration_util import is_below_min_duration as shared_is_below_min_duration


# ── Stage 1: SYNC ─────────────────────────────────────────────────────────────

def _pick_customer_phone(call: Dict) -> str:
    """Return the customer-side phone number of a source call row ("" if none).

    NOTE(review): ``stage_sync`` called this helper but no definition existed
    in this module, so every row passing the duration filter raised NameError
    and the whole batch was lost to the stage's error handler.  The column
    mapping below is an assumption — confirm against the source schema: the
    customer is taken to be ``callto`` on outbound calls (``direction``
    starting with "out") and ``callfrom`` otherwise, with the other column as
    a fallback.  A number matching the agent's ``emp_phone`` is never returned.
    """
    direction = str(call.get("direction") or "inbound").lower()
    if direction.startswith("out"):
        candidates = (call.get("callto"), call.get("callfrom"))
    else:
        candidates = (call.get("callfrom"), call.get("callto"))
    agent_variants = _phone_variants(call.get("emp_phone"))
    for candidate in candidates:
        phone = str(candidate or "").strip()
        if phone and _phone_variants(phone).isdisjoint(agent_variants):
            return phone
    return ""


def _build_call_record(bid: str, call: Dict, customer_phone: str) -> Dict[str, Any]:
    """Map one source-table row onto the payload for ``db.upsert_call_record``.

    Rows without a recording URL cannot be transcribed, so they are stored
    with status ``failed`` at the ``sync`` stage instead of ``pending``.
    """
    start = call.get("starttime")
    end = call.get("endtime")
    duration_s: Optional[int] = None
    if start and end:
        try:
            duration_s = int((end - start).total_seconds())
        except (TypeError, AttributeError):
            # Non-datetime column values: leave the duration unknown.
            duration_s = None

    file_url = str(call.get("filename") or "").strip()
    has_audio = bool(file_url)
    return {
        "bid": str(call.get("bid") or bid),
        "callid": str(call["callid"]).strip(),
        "file_url": file_url,
        "status": "pending" if has_audio else "failed",
        "fail_stage": None if has_audio else "sync",
        "fail_reason": None if has_audio else "No audio file URL",
        "agent_name": str(call.get("agentname") or ""),
        "group_name": str(call.get("groupname") or ""),
        "direction": (call.get("direction") or "inbound").lower(),
        "call_start": start,
        "call_end": end,
        "call_duration_s": duration_s,
        "call_status": str(call.get("dialstatus") or ""),
        "agent_phone": str(call.get("emp_phone") or call.get("callfrom") or "").strip(),
        "customer_phone": customer_phone,
    }


def _next_sync_watermark(
    previous: datetime,
    latest_fetched: datetime,
    batch_full: bool,
    overlap: timedelta = timedelta(hours=1),
) -> datetime:
    """Compute the watermark to store after a sync batch.

    Args:
        previous: Watermark the batch was fetched with (``starttime > previous``).
        latest_fetched: Newest ``starttime`` among ALL fetched rows.
        batch_full: Whether the fetch hit its LIMIT (more rows are waiting).
        overlap: Look-back applied once caught up, to re-read late arrivals.

    When the batch was not full we have caught up, so step back by *overlap*
    to re-read late-arriving rows (this may move the watermark backwards; the
    rows are upserted, presumably idempotently — confirm in db_handler).

    When the batch was full, stepping back a whole *overlap* can land at or
    before *previous* (e.g. the batch spans less than an hour), and the next
    run would refetch the same batch forever.  Only step back one second then
    — enough to re-read rows sharing the last timestamp that the LIMIT may
    have cut — and, if even that would not progress, jump to *latest_fetched*.
    """
    if not batch_full:
        return latest_fetched - overlap
    candidate = latest_fetched - timedelta(seconds=1)
    return candidate if candidate > previous else latest_fetched


def stage_sync(bid: str, db: DatabaseHandler, pipeline_cfg: Dict) -> int:
    """Pull new ANSWER calls from the source DB and upsert them into ``{bid}_call_records``.

    Config keys read from *pipeline_cfg*: ``sync_batch`` (rows fetched per run,
    default 500), ``min_call_duration_s`` / ``min_call_duration_effective_at``
    (duration filter), ``lead_filter_enabled`` (default True),
    ``crm_provider`` (default ``"leadsquared"``), ``lookback_days`` (initial
    window when no watermark is stored, default 90), plus the ``source_db_*``
    connection settings.

    Rows newer than the stored watermark are read in ``starttime`` order.
    Rows below the minimum duration, or (when the lead filter is on) whose
    customer phone is not in the CRM phone cache, are skipped.  The watermark
    is advanced past every *fetched* row — skipped ones included — so a batch
    made up entirely of filtered rows can no longer stall the sync.

    Returns:
        Number of records upserted this run.  Returns 0 on early exits.  Once
        the source connection is open, errors are logged rather than raised.
    """
    logger.info("[%s] === STAGE 1: SYNC ===", bid)

    batch_size: int = int(pipeline_cfg.get("sync_batch") or 500)
    min_duration: int = int(pipeline_cfg.get("min_call_duration_s") or 0)
    effective_at = pipeline_cfg.get("min_call_duration_effective_at")
    if min_duration > 0 and not effective_at:
        effective_at = db.ensure_min_duration_effective_at(bid)
    lead_filter: bool = bool(pipeline_cfg.get("lead_filter_enabled", True))
    crm_provider: str = pipeline_cfg.get("crm_provider") or "leadsquared"
    lookback_days: int = int(pipeline_cfg.get("lookback_days") or 90)

    # ── Ensure destination table exists ──────────────────────────────────────
    db.ensure_call_records_table(bid)

    # ── Load phone allow-list from CRM cache ─────────────────────────────────
    phone_set: Set[str] = set()
    if lead_filter:
        phone_set = db.get_lead_phone_set(bid, crm_provider)
        if not phone_set:
            logger.warning(
                "[%s] Lead filter enabled but crm_leads_cache is empty — "
                "run sync_crm_leads.py first (or disable lead_filter_enabled)",
                bid,
            )
            return 0
        logger.info("[%s] Loaded %d phone variants from CRM cache", bid, len(phone_set))

    # ── Sync watermark ────────────────────────────────────────────────────────
    watermark_key = f"call_sync_{bid}"
    raw_wm = db.get_sync_watermark(bid, watermark_key)
    # Parse leniently: an unreadable stored value used to raise out of this
    # stage on every run; it now falls back to the lookback window instead.
    watermark = _parse_pipeline_datetime(raw_wm) if raw_wm else None
    if watermark is None:
        if raw_wm:
            logger.warning(
                "[%s] Unparseable sync watermark %r — falling back to %d-day lookback",
                bid, raw_wm, lookback_days,
            )
        watermark = datetime.now() - timedelta(days=lookback_days)
    logger.info("[%s] Fetching ANSWER calls since %s", bid, watermark.strftime("%Y-%m-%d %H:%M:%S"))

    # ── Connect to source ─────────────────────────────────────────────────────
    try:
        src_conn = _open_source_conn(pipeline_cfg)
    except Exception as exc:
        logger.error("[%s] Source DB connection failed: %s", bid, exc)
        return 0

    inserted = 0

    try:
        src_table = _detect_source_table(src_conn, bid)
        if not src_table:
            logger.error("[%s] No source call table found in source DB", bid)
            return 0
        logger.info("[%s] Source table: %s", bid, src_table)

        avail_cols = _detect_source_columns(src_conn, src_table)
        fixed_cols = ["callid", "bid", "agentname", "groupname", "starttime", "endtime",
                      "dialstatus", "direction", "filename", "emp_phone"]
        optional_cols = ["callto", "callfrom", "clicktocalldid"]
        select_cols = fixed_cols + [c for c in optional_cols if c in avail_cols]
        cols_sql = ", ".join(f"`{c}`" for c in select_cols)

        with src_conn.cursor() as src_cursor:
            src_cursor.execute(
                f"SELECT {cols_sql} FROM `{src_table}` "
                f"WHERE dialstatus = 'ANSWER' AND starttime > %s "
                f"ORDER BY starttime ASC LIMIT %s",
                (watermark, batch_size),
            )
            calls = src_cursor.fetchall() or []

        if not calls:
            logger.info("[%s] No new ANSWER calls since watermark", bid)
            return 0

        logger.info("[%s] Fetched %d ANSWER calls from source", bid, len(calls))

        skipped = 0
        for call in calls:
            # ── Duration filter ───────────────────────────────────────────────
            if shared_is_below_min_duration(call, min_duration, effective_at):
                skipped += 1
                continue

            # ── Phone filter ──────────────────────────────────────────────────
            customer_phone = _pick_customer_phone(call)
            if lead_filter and _phone_variants(customer_phone).isdisjoint(phone_set):
                skipped += 1
                continue

            db.upsert_call_record(bid, _build_call_record(bid, call, customer_phone))
            inserted += 1

        logger.info(
            "[%s] Sync: %d upserted, %d skipped out of %d fetched",
            bid, inserted, skipped, len(calls),
        )

        # ── Advance watermark ─────────────────────────────────────────────────
        # Use the newest row FETCHED, not the newest row upserted: otherwise a
        # batch in which every row was filtered out never moves the watermark.
        latest_fetched = max(
            (c["starttime"] for c in calls if c.get("starttime")),
            default=None,
        )
        if latest_fetched is not None:
            new_wm = _next_sync_watermark(
                watermark, latest_fetched, batch_full=len(calls) >= batch_size,
            )
            db.set_sync_watermark(bid, watermark_key, new_wm.isoformat())
            logger.info("[%s] Watermark advanced to %s", bid, new_wm)

    except Exception as exc:
        logger.error("[%s] Sync error: %s", bid, exc, exc_info=True)
    finally:
        src_conn.close()

    return inserted


# ── Stage 2: TRANSCRIBE ───────────────────────────────────────────────────────

def stage_transcribe(bid: str, db: DatabaseHandler, pipeline_cfg: Dict) -> int:
    """Transcribe a batch of pending calls and store the transcripts.

    Reads ``stt_provider`` (default ``"sarvam"``), ``stt_api_key`` and
    ``transcribe_batch`` (default 3) from *pipeline_cfg*.  The minimum-duration
    settings are forwarded to ``db.get_calls_to_transcribe``.

    Each call is marked ``transcribing`` before the STT request.  A failure —
    no audio URL, an STT error, or an empty/whitespace-only transcript — is
    recorded via ``db.fail_call`` with stage ``"transcribe"``, and processing
    continues with the next call.  Stops early once a shutdown signal has
    been received.

    Returns:
        Number of calls transcribed successfully.  Returns 0 if the STT
        provider cannot be loaded.
    """
    logger.info("[%s] === STAGE 2: TRANSCRIBE ===", bid)

    stt_provider_name: str = pipeline_cfg.get("stt_provider") or "sarvam"
    stt_api_key: str = pipeline_cfg.get("stt_api_key") or ""
    batch: int = int(pipeline_cfg.get("transcribe_batch") or 3)

    try:
        stt = get_stt_provider(stt_provider_name, stt_api_key)
    except ValueError as exc:
        logger.error("[%s] Cannot load STT provider: %s", bid, exc)
        return 0

    min_duration: int = int(pipeline_cfg.get("min_call_duration_s") or 0)
    effective_at = pipeline_cfg.get("min_call_duration_effective_at")
    if min_duration > 0 and not effective_at:
        effective_at = db.ensure_min_duration_effective_at(bid)
    pending = db.get_calls_to_transcribe(
        bid,
        batch=batch,
        min_duration_s=min_duration,
        effective_at=effective_at,
    )
    if not pending:
        logger.info("[%s] No calls pending transcription", bid)
        return 0

    logger.info("[%s] %d call(s) to transcribe", bid, len(pending))
    success = 0

    for call in pending:
        if _shutdown:
            break
        callid = call["callid"]
        file_url = call.get("file_url") or ""

        if not file_url:
            db.fail_call(bid, callid, stage="transcribe", reason="No audio URL")
            continue

        logger.info("[%s] Transcribing %s", bid, callid)
        db.set_call_status(bid, callid, "transcribing")

        try:
            stt_result = stt.transcribe(file_url, callid)
            transcript = str(stt_result.transcript or "")
            # Whitespace-only output carries nothing for the analyze stage, so
            # reject it the same way as an empty transcript.
            if not transcript.strip():
                raise RuntimeError("Empty transcript returned")
            db.save_call_transcription(bid, callid, stt_result)
            success += 1
            logger.info("[%s] Transcribed %s — %d chars", bid, callid, len(transcript))
        except Exception as exc:
            logger.error("[%s] Transcription failed for %s: %s", bid, callid, exc)
            db.fail_call(bid, callid, stage="transcribe", reason=str(exc))

    logger.info("[%s] Transcription: %d/%d succeeded", bid, success, len(pending))
    return success


# ── Stage 3: ANALYZE ──────────────────────────────────────────────────────────

def _coerce_speaker_segments(raw: Any) -> List[Dict]:
    """Normalise a stored ``speaker_segments`` value into a list of segments.

    The value may arrive as JSON text, raw JSON bytes, or an already-decoded
    object.  Anything that is empty, fails to decode, or does not decode to a
    list/tuple yields ``[]``.
    """
    if not raw:
        return []
    if isinstance(raw, (str, bytes, bytearray)):
        try:
            raw = json.loads(raw)
        except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
            return []
    if isinstance(raw, (list, tuple)):
        return list(raw)
    return []


def stage_analyze(
    bid: str,
    db: DatabaseHandler,
    agent_runner: AgentRunner,
    batch: int = 3,
) -> int:
    """Run the configured AI agents on transcribed calls and save the analysis.

    Args:
        bid: Business id whose call records are processed.
        db: Database handler used to fetch calls and persist results.
        agent_runner: Runs the agents for one call.
        batch: Maximum number of calls analysed per invocation (default 3,
            the value previously hard-coded here).

    Each call is marked ``analyzing`` first.  When the runner reports no
    agents (falsy ``agents_run``), an empty analysis is saved.  Otherwise the
    analysis is saved and pushed to LeadSquared; a push failure is only
    logged.  A runner or save failure marks the call failed at stage
    ``"analyze"``.  Stops early once a shutdown signal has been received.

    Returns:
        Number of calls processed without error.
    """
    logger.info("[%s] === STAGE 3: ANALYZE ===", bid)

    pending = db.get_calls_to_analyze(bid, batch=batch)
    if not pending:
        logger.info("[%s] No calls pending analysis", bid)
        return 0

    logger.info("[%s] %d call(s) to analyze", bid, len(pending))
    success = 0

    for call in pending:
        if _shutdown:
            break
        callid = call["callid"]
        transcript = call.get("transcript") or ""
        # Previously bytes (or JSON decoding to a non-list, e.g. "null") were
        # handed to the agent runner unchanged.
        speaker_segments = _coerce_speaker_segments(call.get("speaker_segments"))

        call_meta = {
            "agent_name": call.get("agent_name") or "",
            "customer_phone": call.get("customer_phone") or "",
            "call_start": str(call.get("call_start") or ""),
            "call_duration_s": str(call.get("call_duration_s") or ""),
        }

        logger.info("[%s] Analyzing %s", bid, callid)
        db.set_call_status(bid, callid, "analyzing")

        try:
            analysis = agent_runner.run(
                bid=bid,
                callid=callid,
                transcript=transcript,
                speaker_segments=speaker_segments,
                call_metadata=call_meta,
            )

            if not analysis.get("agents_run"):
                # No agents configured — mark done without analysis
                db.save_call_analysis(bid, callid, {})
            else:
                db.save_call_analysis(bid, callid, analysis)
                try:
                    _push_leadsquared_activities(bid, call, analysis, db)
                except Exception as exc:
                    # Best-effort CRM push: the analysis is already saved.
                    logger.error("[%s] LeadSquared push failed for %s: %s", bid, callid, exc)

            success += 1
            score = analysis.get("quality_score")
            logger.info(
                "[%s] Analyzed %s — score=%s agents=%s",
                bid, callid, score, analysis.get("agents_run"),
            )
        except Exception as exc:
            logger.error("[%s] Analysis failed for %s: %s", bid, callid, exc)
            db.fail_call(bid, callid, stage="analyze", reason=str(exc))

    logger.info("[%s] Analysis: %d/%d succeeded", bid, success, len(pending))
    return success


# ── Per-bid pipeline run ───────────────────────────────────────────────────────

def _decrypt_secret(db: DatabaseHandler, cfg: Dict, enc_field: str, *env_vars: str) -> str:
    """Return the decrypted ``cfg[enc_field]``, else the first non-empty env var.

    If a stored value cannot be decrypted (``db._decrypt_text`` raises or
    returns nothing), a warning is logged and the value is treated as
    missing.  Previously it was silently replaced.  Only the exception type
    is logged, so no secret material reaches the log.  Returns ``""`` when no
    source yields a value.
    """
    encrypted = cfg.get(enc_field) or ""
    if encrypted:
        try:
            plain = db._decrypt_text(encrypted) or ""
        except Exception as exc:
            logger.warning(
                "Could not decrypt %s (%s) — falling back to environment",
                enc_field, type(exc).__name__,
            )
        else:
            if plain:
                return plain
            logger.warning(
                "Decrypting %s gave an empty value — falling back to environment",
                enc_field,
            )
    return next((value for value in map(os.getenv, env_vars) if value), "")


def _decrypt_stt_key(db: DatabaseHandler, cfg: Dict) -> str:
    """Decrypt the STT API key stored in pipeline config, falling back to env.

    Environment fallback order: ``SARVAM_SUBSCRIPTION_KEY``, then
    ``SARVAM_PIPELINE_KEY``.
    """
    return _decrypt_secret(
        db, cfg, "stt_api_key_enc", "SARVAM_SUBSCRIPTION_KEY", "SARVAM_PIPELINE_KEY",
    )


def _decrypt_source_password(db: DatabaseHandler, cfg: Dict) -> str:
    """Decrypt the source-DB password from pipeline config.

    Falls back to the ``SYNC_SOURCE_DB_PASSWORD`` environment variable.
    """
    return _decrypt_secret(db, cfg, "source_db_password_enc", "SYNC_SOURCE_DB_PASSWORD")


def run_bid(
    bid: str,
    db: DatabaseHandler,
    agent_runner: AgentRunner,
    stage_filter: Optional[str] = None,
) -> Dict[str, int]:
    """Run the pipeline stages for one bid — all of them, or only *stage_filter*.

    The STT API key and source DB password are decrypted into a copy of the
    bid's pipeline config before any stage runs.  A shutdown signal noticed
    between stages skips the remaining ones.

    Returns:
        Counts keyed by ``"synced"``, ``"transcribed"`` and ``"analyzed"``,
        present only for stages that actually ran.  ``{}`` when the bid has no
        pipeline config or its pipeline is disabled.
    """
    stored_cfg = db.get_pipeline_config(bid)
    if not stored_cfg:
        logger.warning("[%s] No pipeline config found — skipping", bid)
        return {}
    if not stored_cfg.get("pipeline_enabled", True):
        logger.info("[%s] Pipeline disabled — skipping", bid)
        return {}

    # Decrypt into a copy so the mapping from get_pipeline_config is untouched.
    cfg = dict(stored_cfg)
    cfg["stt_api_key"] = _decrypt_stt_key(db, cfg)
    cfg["source_db_password"] = _decrypt_source_password(db, cfg)

    # (stage name accepted by --stage, result key, runner) in execution order.
    stages = (
        ("sync", "synced", lambda: stage_sync(bid, db, cfg)),
        ("transcribe", "transcribed", lambda: stage_transcribe(bid, db, cfg)),
        ("analyze", "analyzed", lambda: stage_analyze(bid, db, agent_runner)),
    )

    counts: Dict[str, int] = {}
    for position, (stage_name, result_key, run_stage) in enumerate(stages):
        # The shutdown flag is checked between stages, never before the first.
        if position and _shutdown:
            break
        if not stage_filter or stage_filter == stage_name:
            counts[result_key] = run_stage()
    return counts


# ── Main ──────────────────────────────────────────────────────────────────────

def _build_config() -> Dict[str, Any]:
    """Snapshot every all-uppercase attribute of ``config.Config`` into a dict.

    The result is the settings mapping handed to ``DatabaseHandler``; keys
    keep the order ``dir(Config)`` lists them in.
    """
    from config import Config
    return {name: getattr(Config, name) for name in dir(Config) if name.isupper()}


def main() -> None:
    """Command-line entry point.

    Parses ``--bid`` / ``--stage`` / ``--continuous`` / ``--interval`` and
    ensures the pipeline/agent config tables exist.  It then runs each
    selected bid once — or, with ``--continuous``, repeatedly, waiting
    ``--interval`` seconds between runs — until a shutdown signal sets
    ``_shutdown``.  Errors for one bid, or for a whole iteration, are logged
    and do not end the loop.
    """
    parser = argparse.ArgumentParser(description="Generic multi-bid call pipeline")
    parser.add_argument("--bid", help="Process only this bid (default: all enabled)")
    parser.add_argument(
        "--stage",
        choices=["sync", "transcribe", "analyze"],
        help="Run a single stage only",
    )
    parser.add_argument("--continuous", action="store_true", help="Loop indefinitely")
    parser.add_argument(
        "--interval", type=int, default=120,
        help="Seconds between pipeline runs (default: 120)",
    )
    args = parser.parse_args()

    db = DatabaseHandler(_build_config())
    runner = AgentRunner(db)

    # Meta-tables holding the per-bid pipeline and agent configuration.
    db.ensure_business_pipeline_config_table()
    db.ensure_business_agent_config_table()

    iteration = 0
    while not _shutdown:
        iteration += 1

        logger.info("")
        logger.info("=" * 70)
        logger.info("  PIPELINE RUN #%d — %s", iteration, datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
        logger.info("=" * 70)

        try:
            bids = [args.bid] if args.bid else db.get_enabled_pipeline_bids()
            if bids:
                for bid in bids:
                    if _shutdown:
                        break
                    try:
                        counts = run_bid(bid, db, runner, stage_filter=args.stage)
                        logger.info("[%s] Run complete: %s", bid, counts)
                    except Exception as exc:
                        logger.error("[%s] Unhandled pipeline error: %s", bid, exc, exc_info=True)
            else:
                logger.info("No enabled pipeline configs found")
        except Exception as exc:
            logger.error("Pipeline iteration %d error: %s", iteration, exc, exc_info=True)

        if not args.continuous:
            break

        logger.info("Waiting %ds before next run…", args.interval)
        # Sleep in 1-second steps so a shutdown signal is noticed promptly.
        seconds_left = args.interval
        while seconds_left > 0 and not _shutdown:
            time.sleep(1)
            seconds_left -= 1

    logger.info("call_processor stopped.")


# Importing this module still runs the bootstrap, logging and signal setup
# above; only the pipeline loop itself is gated on direct execution.
if __name__ == "__main__":
    main()
