#!/usr/bin/env python3
"""
Unified Call Processing Pipeline.

This module combines the polling-based sync from call_processor.py
and the orchestration logic from orchestrate_pipeline.py into a single
RabbitMQ-powered system.

Modes:
  --mode orchestrator : Scans for new calls, syncs them, and queues tasks in RabbitMQ.
  --mode worker       : Consumes tasks from RabbitMQ and performs Transcription/Analysis.

Usage:
  python3 unified_pipeline.py --mode orchestrator --bid 1713
  python3 unified_pipeline.py --mode orchestrator --continuous --interval 60
  python3 unified_pipeline.py --mode worker
"""

import argparse
import json
import logging
import os
import signal
import sys
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Set

import pika
import pymysql
from pymysql.cursors import DictCursor
from dotenv import load_dotenv

# Bootstrap paths: put this file's directory first on sys.path so the project
# modules imported below (config, db_handler, stt, agent_runner) are found
# there, and chdir into it so relative paths used at runtime resolve against
# BACKEND_DIR instead of the caller's working directory.
BACKEND_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, BACKEND_DIR)
os.chdir(BACKEND_DIR)

# Load environment variables before the project modules below are imported.
# NOTE(review): with no path argument python-dotenv locates the .env file on
# its own; presumably the one in BACKEND_DIR is intended — confirm.
load_dotenv()

from config import Config
from db_handler import DatabaseHandler
from stt import get_stt_provider
from agent_runner import AgentRunner

# Logging configuration: records at INFO and above go both to
# unified_pipeline.log inside BACKEND_DIR and to a StreamHandler (stderr by
# default). NOTE(review): logging.basicConfig is a no-op when the root logger
# already has handlers (e.g. if a project module imported above configured
# logging first) — confirm this call is the one that takes effect.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
    handlers=[
        logging.FileHandler(os.path.join(BACKEND_DIR, "unified_pipeline.log")),
        logging.StreamHandler(),
    ],
)
# Module logger shared by every class and function in this file.
logger = logging.getLogger("unified_pipeline")

# Graceful shutdown: the handler only raises a module-level flag; loops that
# want to stop cleanly poll ``_shutdown`` between units of work. Because
# SIGINT is re-bound below, Ctrl+C no longer raises KeyboardInterrupt.
_shutdown = False


def _handle_signal(sig, frame):
    """Record that SIGTERM or SIGINT was received.

    Args:
        sig: Signal number passed in by the interpreter (unused).
        frame: Stack frame active when the signal arrived (unused).
    """
    global _shutdown
    logger.info("Shutdown signal received - finishing current task...")
    _shutdown = True


for _signum in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_signum, _handle_signal)
del _signum

class SharedPipelineBase:
    """Base class for the shared DB handler and RabbitMQ connection.

    Attributes:
        config: project ``Config`` instance.
        db: ``DatabaseHandler`` built on top of ``config``.
        rabbitmq_conn / rabbitmq_channel: set by :meth:`connect_rabbitmq`;
            ``None`` before that and again after :meth:`close_rabbitmq`.
    """

    def __init__(self):
        self.config = Config()

        # Adapter giving DatabaseHandler both dict-style ``get(key, default)``
        # and plain attribute access on top of the Config object.
        class ConfigWrapper:
            def __init__(self, cfg):
                self._cfg = cfg

            def get(self, key, default=None):
                return getattr(self._cfg, key, default)

            def __getattr__(self, key):
                return getattr(self._cfg, key)

        self.db = DatabaseHandler(ConfigWrapper(self.config))
        self.rabbitmq_conn = None
        self.rabbitmq_channel = None

    def connect_rabbitmq(self):
        """Open a blocking RabbitMQ connection and declare the stage queues.

        Declares the durable queues ``RABBITMQ_QUEUE_PREFIX + stage`` for the
        ``transcription`` and ``analysis`` stages. The connection and channel
        are stored on ``self`` only once everything succeeded.

        Raises:
            Exception: re-raises whatever failed while connecting, opening the
                channel or declaring the queues. A connection that was already
                open when the failure happened is closed first.
        """
        conn = None
        try:
            credentials = pika.PlainCredentials(self.config.RABBITMQ_USER, self.config.RABBITMQ_PASS)
            parameters = pika.ConnectionParameters(
                host=self.config.RABBITMQ_HOST,
                credentials=credentials,
                heartbeat=self.config.RABBITMQ_HEARTBEAT,
                blocked_connection_timeout=self.config.RABBITMQ_TIMEOUT,
            )
            conn = pika.BlockingConnection(parameters)
            channel = conn.channel()

            # Declare common queues
            for stage in ("transcription", "analysis"):
                channel.queue_declare(
                    queue=f"{self.config.RABBITMQ_QUEUE_PREFIX}{stage}", durable=True
                )
        except Exception as e:
            logger.error(f"RabbitMQ connection failed: {e}")
            # Don't leak a half-initialised connection (e.g. queue_declare
            # failed after the AMQP handshake had already succeeded).
            if conn is not None:
                self._close_quietly(conn)
            raise

        self.rabbitmq_conn = conn
        self.rabbitmq_channel = channel
        logger.info(f"Connected to RabbitMQ at {self.config.RABBITMQ_HOST}")

    @staticmethod
    def _close_quietly(conn):
        """Close *conn* if it is not already closed; log instead of raising."""
        if conn.is_closed:
            return
        try:
            conn.close()
        except Exception as e:
            # close() may itself fail when the underlying socket is already
            # broken; callers are tearing the connection down anyway.
            logger.warning(f"Error while closing RabbitMQ connection: {e}")

    def close_rabbitmq(self):
        """Close the RabbitMQ connection (best effort) and drop the references."""
        conn = self.rabbitmq_conn
        self.rabbitmq_conn = None
        self.rabbitmq_channel = None
        if conn is not None:
            self._close_quietly(conn)

class Orchestrator(SharedPipelineBase):
    """Syncs new calls into ``{bid}_call_records`` and queues stage tasks.

    Each task is published as persistent JSON ``{"bid", "callid", "stage",
    "timestamp"}`` to the queue ``RABBITMQ_QUEUE_PREFIX + stage``.
    """

    def __init__(self):
        super().__init__()
        self._open_publisher_channel()

    def _open_publisher_channel(self):
        """Connect to RabbitMQ and enable publisher confirms on the channel.

        With confirms enabled, ``basic_publish`` on a blocking channel waits
        for the broker and raises if the message is rejected, instead of the
        failure going unnoticed.
        """
        self.connect_rabbitmq()
        self.rabbitmq_channel.confirm_delivery()

    def _ensure_rabbitmq(self):
        """Return once the publishing channel is usable, reconnecting if needed.

        A BlockingConnection only services heartbeats while it performs I/O.
        Between orchestrator cycles (source sync, ``--interval`` sleep) it
        performs none, so the broker may have dropped the connection meanwhile.
        Without this check such a drop made every later publish fail, because
        the connection was never re-established.
        """
        conn, channel = self.rabbitmq_conn, self.rabbitmq_channel
        try:
            if conn is not None and conn.is_open and channel is not None and channel.is_open:
                # Pump pending I/O (heartbeats, close frames) so a dead
                # connection is detected here rather than mid-publish.
                conn.process_data_events(time_limit=0)
                if conn.is_open and channel.is_open:
                    return
        except (pika.exceptions.AMQPError, OSError) as exc:
            logger.warning(f"RabbitMQ connection lost ({exc}) - reconnecting")

        if conn is not None:
            try:
                if not conn.is_closed:
                    conn.close()
            except Exception:
                # Best effort: the stale connection is being replaced anyway.
                logger.debug("Ignoring error while closing stale RabbitMQ connection", exc_info=True)
        self._open_publisher_channel()

    def run_bid(self, bid: str, limit: int = 50):
        """Run one orchestrator cycle for *bid*.

        Steps: sync new calls from the source DB (skipped when the pipeline
        config is missing or ``pipeline_enabled`` is false), queue
        transcription tasks, queue analysis tasks, then recover stuck calls.
        """
        logger.info(f"[{bid}] Starting orchestrator run")

        # 0. Sync source DB
        cfg = self.db.get_pipeline_config(bid)
        if not cfg or not cfg.get("pipeline_enabled", True):
            logger.warning(f"[{bid}] Pipeline disabled or missing config - skipping sync")
        else:
            self.sync_calls(bid, cfg, limit)

        # 1. Queue Transcription Tasks (status='pending')
        self.queue_tasks(bid, "transcription", limit=limit)

        # 2. Queue Analysis Tasks (status='transcribed')
        self.queue_tasks(bid, "analysis", limit=limit)

        # 3. Recover stuck calls
        self.recover_stuck_calls(bid)

    def _get_reconstructed_url(self, bid: str, filename: str, starttime: Any) -> str:
        """Turn a stored recording reference into a full recording URL.

        - empty/None -> ``""``
        - already starts with ``http`` -> returned unchanged
        - contains ``/`` -> treated as a path under ``https://recordings.mcube.com/``
        - bare filename -> placed under ``.../classic/{YYYY}/{MM}/{bid}/inbound/``
          using the year/month of *starttime* (a datetime or ISO string)

        If the date cannot be derived, the error is logged and *filename* is
        returned unchanged.
        """
        if not filename:
            return ""
        if filename.startswith('http'):
            return filename
        if '/' in filename:
            path = filename.lstrip('/')
            return f"https://recordings.mcube.com/{path}"

        try:
            dt = datetime.fromisoformat(starttime) if isinstance(starttime, str) else starttime
            year, month = dt.strftime("%Y"), dt.strftime("%m")
            return f"https://recordings.mcube.com/mcubefiles112/classic/{year}/{month}/{bid}/inbound/{filename}"
        except Exception as e:
            logger.error(f"[{bid}] Error reconstructing URL for {filename}: {e}")
            return filename

    def sync_calls(self, bid: str, pipeline_cfg: Dict, limit: int):
        """Pull new calls from the source DB into ``{bid}_call_records``.

        Delegates the copy to ``call_processor.stage_sync``. If it reports
        inserted rows, the ``file_url`` of every pending call that does not
        start with ``http`` is rewritten via :meth:`_get_reconstructed_url`.

        Args:
            bid: business id; selects the ``{bid}_call_records`` table.
            pipeline_cfg: pipeline config returned by ``db.get_pipeline_config``.
            limit: accepted for interface compatibility; not passed to stage_sync.
        """
        from call_processor import stage_sync
        # Call the existing sync logic
        inserted = stage_sync(bid, self.db, pipeline_cfg)
        if inserted <= 0:
            return

        # Post-sync URL reconstruction for bare filenames
        table = f"{bid}_call_records"
        updated = 0
        with self.db.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                f"SELECT callid, file_url, call_start FROM `{table}` "
                f"WHERE status = 'pending' AND file_url NOT LIKE 'http%'"
            )
            to_repair = cursor.fetchall() or []
            for row in to_repair:
                repaired = self._get_reconstructed_url(bid, row['file_url'], row['call_start'])
                if repaired != row['file_url']:
                    cursor.execute(
                        f"UPDATE `{table}` SET file_url = %s WHERE callid = %s",
                        (repaired, row['callid']),
                    )
                    updated += 1
            # NOTE(review): explicit commit in case the connection is not in
            # autocommit mode and its context manager does not commit on exit
            # (PyMySQL >= 1.0 only closes); harmless otherwise.
            conn.commit()
        logger.info(f"[{bid}] URL reconstruction complete: {updated}/{len(to_repair)} records updated")

    def queue_tasks(self, bid: str, stage: str, limit: int):
        """Publish up to *limit* eligible calls of *bid* to the *stage* queue.

        ``stage == "transcription"`` takes calls from
        ``db.get_calls_to_transcribe`` and marks them ``transcribing``; any
        other stage takes ``db.get_calls_to_analyze`` and marks them
        ``analyzing``. These temporary "queued" statuses keep other
        orchestrator runs from publishing the same call twice.

        Raises:
            Exception: whatever the publish raised; the failed call's status is
                restored first so a later cycle can queue it again.
        """
        queue_name = f"{self.config.RABBITMQ_QUEUE_PREFIX}{stage}"

        # prev_status mirrors the statuses run_bid documents for each query.
        if stage == "transcription":
            calls = self.db.get_calls_to_transcribe(bid, batch=limit)
            prev_status, queued_status = "pending", "transcribing"
        else:
            calls = self.db.get_calls_to_analyze(bid, batch=limit)
            prev_status, queued_status = "transcribed", "analyzing"

        if not calls:
            return

        self._ensure_rabbitmq()
        logger.info(f"[{bid}] Queuing {len(calls)} calls for {stage}")
        for call in calls:
            callid = call["callid"]
            # Claim the call BEFORE publishing. Updating the status after the
            # publish raced with workers: a fast worker could finish the task
            # and store its result status first, which was then overwritten.
            self.db.set_call_status(bid, callid, queued_status)
            payload = {
                "bid": bid,
                "callid": callid,
                "stage": stage,
                "timestamp": datetime.now().isoformat(),
            }
            try:
                self.rabbitmq_channel.basic_publish(
                    exchange="",
                    routing_key=queue_name,
                    body=json.dumps(payload),
                    properties=pika.BasicProperties(delivery_mode=2),
                )
            except Exception:
                # Not accepted by the broker: release the claim so a later
                # cycle picks the call up again, then let the caller log it.
                self.db.set_call_status(bid, callid, call.get("status") or prev_status)
                raise

    def recover_stuck_calls(self, bid: str):
        """Reset calls left in an in-progress status for more than 4 hours.

        - ``transcribing`` with an empty/NULL transcript -> ``pending``
        - ``analyzing`` -> ``transcribed``

        Calls in ``transcribing`` that already have a transcript are left as-is.
        NOTE(review): relies on ``updated_at`` tracking the last status change
        and on DB timestamps using this host's local time (naive
        ``datetime.now()``) — confirm both.
        """
        table = f"{bid}_call_records"
        stuck_time = datetime.now() - timedelta(hours=4)

        with self.db.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(f"""
                UPDATE `{table}`
                SET status = 'pending'
                WHERE status = 'transcribing'
                AND (transcript IS NULL OR transcript = '')
                AND updated_at < %s
            """, (stuck_time,))
            t_recovered = cursor.rowcount

            cursor.execute(f"""
                UPDATE `{table}`
                SET status = 'transcribed'
                WHERE status = 'analyzing'
                AND updated_at < %s
            """, (stuck_time,))
            a_recovered = cursor.rowcount
            # See sync_calls: make sure the resets are persisted.
            conn.commit()

        if t_recovered or a_recovered:
            logger.info(f"[{bid}] Recovered {t_recovered} transcription and {a_recovered} analysis tasks")

class Worker(SharedPipelineBase):
    """Consumes stage tasks from RabbitMQ and runs transcription/analysis."""

    # Max seconds a single event-pump call may block before ``_shutdown`` is
    # checked again.
    _POLL_SECONDS = 1.0

    def __init__(self):
        super().__init__()
        self.agent_runner = AgentRunner(self.db)
        self.connect_rabbitmq()

    def start(self):
        """Consume the transcription and analysis queues until shutdown.

        ``basic_qos(prefetch_count=1)`` makes the broker hand this worker one
        unacknowledged task at a time, so several worker processes share the
        backlog instead of the first consumer buffering the whole queue.

        The module's SIGINT/SIGTERM handler only sets ``_shutdown`` (Ctrl+C no
        longer raises KeyboardInterrupt), and ``start_consuming()`` never looks
        at that flag, so it would never return. Events are instead pumped in
        short slices and the flag is tested between them. Callbacks run inside
        ``process_data_events``, so a task in progress always finishes first.
        """
        channel = self.rabbitmq_channel
        channel.basic_qos(prefetch_count=1)
        for stage in ["transcription", "analysis"]:
            channel.basic_consume(
                queue=f"{self.config.RABBITMQ_QUEUE_PREFIX}{stage}",
                on_message_callback=self.process_message,
            )

        logger.info("Worker started - waiting for tasks...")
        try:
            while not _shutdown:
                self.rabbitmq_conn.process_data_events(time_limit=self._POLL_SECONDS)
        except KeyboardInterrupt:
            # Reachable only when SIGINT is not routed to the module handler.
            pass
        logger.info("Worker stopped consuming")

    def process_message(self, ch, method, properties, body):
        """Decode one task message, run its stage and ack/nack the delivery.

        The message body is JSON with ``bid``, ``callid`` and ``stage``.
        Successful tasks are acked. Failures and malformed messages are nacked
        with ``requeue=False``; retries are driven by the DB status that the
        orchestrator re-queues from.

        NOTE(review): the task runs on the connection's I/O loop, so pika sends
        no heartbeats while it runs. Tasks that outlast the negotiated
        heartbeat timeout may lose the connection (and the ack). If that
        happens, consider a worker thread plus ``add_callback_threadsafe``, or
        a larger RABBITMQ_HEARTBEAT.
        """
        try:
            task = json.loads(body)
            bid = task["bid"]
            callid = task["callid"]
            stage = task["stage"]

            logger.info(f"[{bid}] Processing {stage} for call {callid}")

            if stage == "transcription":
                success = self.do_transcription(bid, callid)
            elif stage == "analysis":
                success = self.do_analysis(bid, callid)
            else:
                logger.warning(f"[{bid}] Unknown stage {stage!r} for call {callid}")
                success = False

            if success:
                ch.basic_ack(delivery_tag=method.delivery_tag)
                logger.info(f"[{bid}] Successfully processed {stage} for {callid}")
            else:
                ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
                logger.error(f"[{bid}] Failed to process {stage} for {callid}")
        except Exception as e:
            logger.error(f"Critical error in worker callback: {e}", exc_info=True)
            if method.delivery_tag:
                ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)

    def do_transcription(self, bid: str, callid: str) -> bool:
        """Transcribe one call's audio and store the result.

        Mirrors ``call_processor.stage_transcribe``. Returns True when a
        non-empty transcript was saved. A missing record/URL or a
        transcription error is recorded via ``db.fail_call(stage="transcribe")``.
        """
        cfg = self.db.get_pipeline_config(bid)
        if not cfg:
            logger.error(f"[{bid}] No pipeline config - cannot transcribe {callid}")
            return False

        stt_provider_name = cfg.get("stt_provider") or "sarvam"
        from call_processor import _decrypt_stt_key
        api_key = _decrypt_stt_key(self.db, cfg)

        try:
            stt = get_stt_provider(stt_provider_name, api_key)
        except Exception as e:
            logger.error(f"[{bid}] STT provider init failed: {e}")
            return False

        call = self.db.get_call_record_detail(bid, callid)
        if not call or not call.get("file_url"):
            self.db.fail_call(bid, callid, stage="transcribe", reason="No audio URL or record missing")
            return False

        self.db.set_call_status(bid, callid, "transcribing")
        try:
            result = stt.transcribe(call["file_url"], callid)
            if not (result and result.transcript):
                raise RuntimeError("Empty transcript returned")
            self.db.save_call_transcription(bid, callid, result)
            return True
        except Exception as exc:
            logger.error(f"[{bid}] Transcription failed for {callid}: {exc}")
            self.db.fail_call(bid, callid, stage="transcribe", reason=str(exc))
            return False

    def do_analysis(self, bid: str, callid: str) -> bool:
        """Run the agent analysis on one call's transcript and store it.

        Mirrors ``call_processor.stage_analyze``. Returns True when a non-empty
        analysis was saved. A missing transcript, an empty analysis or an
        analysis error is recorded via ``db.fail_call(stage="analyze")``, as
        ``do_transcription`` does for its failures. Without that, such calls
        stayed ``analyzing`` and the stuck-call recovery kept sending them back
        for another futile attempt.
        """
        call = self.db.get_call_record_detail(bid, callid)
        if not call:
            logger.error(f"[{bid}] Call record {callid} not found for analysis")
            return False

        transcript = call.get("transcript") or ""
        if not transcript:
            logger.warning(f"[{bid}] Call {callid} has no transcript for analysis")
            self.db.fail_call(bid, callid, stage="analyze", reason="No transcript available for analysis")
            return False

        # get_call_record_detail already deserializes JSON columns
        speaker_segments = call.get("speaker_segments") or []

        call_meta = {
            "agent_name": call.get("agent_name") or "",
            "customer_phone": call.get("customer_phone") or "",
            "call_start": str(call.get("call_start") or ""),
            "call_duration_s": str(call.get("call_duration_s") or ""),
        }

        self.db.set_call_status(bid, callid, "analyzing")
        try:
            analysis = self.agent_runner.run(
                bid=bid,
                callid=callid,
                transcript=transcript,
                speaker_segments=speaker_segments,
                call_metadata=call_meta,
            )
            if not analysis:
                raise RuntimeError("Empty analysis result")
            self.db.save_call_analysis(bid, callid, analysis)
            return True
        except Exception as exc:
            logger.error(f"[{bid}] Analysis failed for {callid}: {exc}")
            self.db.fail_call(bid, callid, stage="analyze", reason=str(exc))
            return False

def _sleep_until_shutdown(seconds):
    """Sleep for up to *seconds*, returning early once ``_shutdown`` is set.

    The signal handler only sets a flag, and ``time.sleep`` resumes after a
    handler returns (PEP 475). One long sleep would therefore postpone a
    requested shutdown by up to the full interval; sleeping in slices of at
    most 1 s bounds that delay.
    """
    deadline = time.monotonic() + seconds
    while not _shutdown:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return
        time.sleep(min(1.0, remaining))


def main():
    """Parse CLI arguments and run the orchestrator or the worker.

    Orchestrator mode processes ``--bid`` (or every bid returned by
    ``db.get_enabled_pipeline_bids()``) once. With ``--continuous`` it repeats
    until a shutdown signal, pausing ``--interval`` seconds between cycles.
    Worker mode runs ``Worker.start()``. In both modes the RabbitMQ connection
    is closed on every exit path, including exceptions.
    """
    parser = argparse.ArgumentParser(description="Unified Call Processing Pipeline")
    parser.add_argument("--mode", choices=["orchestrator", "worker"], required=True,
                        help="Mode: orchestrator (queue tasks) or worker (process tasks)")
    parser.add_argument("--bid", help="Orchestrator mode: only this bid")
    parser.add_argument("--limit", type=int, default=50, help="Batch limit for sync and queuing")
    parser.add_argument("--interval", type=int, default=60, help="Loop interval for orchestrator")
    parser.add_argument("--continuous", action="store_true", help="Loop orchestrator indefinitely")

    args = parser.parse_args()

    if args.mode == "orchestrator":
        orchestrator = Orchestrator()
        try:
            while True:
                try:
                    bids = [args.bid] if args.bid else orchestrator.db.get_enabled_pipeline_bids()
                except Exception as e:
                    if not args.continuous:
                        raise
                    # A transient DB error must not end the long-running loop;
                    # log it and try again on the next cycle.
                    logger.error(f"Failed to load enabled pipeline bids: {e}", exc_info=True)
                    bids = []

                for bid in bids:
                    if _shutdown:
                        break
                    try:
                        orchestrator.run_bid(bid, limit=args.limit)
                    except Exception as e:
                        logger.error(f"Orchestrator error for bid {bid}: {e}", exc_info=True)

                if not args.continuous or _shutdown:
                    break
                _sleep_until_shutdown(args.interval)
                if _shutdown:
                    break
        finally:
            orchestrator.close_rabbitmq()

    elif args.mode == "worker":
        worker = Worker()
        try:
            worker.start()
        finally:
            worker.close_rabbitmq()


if __name__ == "__main__":
    main()
