
import pymysql
import os
import json
import time
import logging
import subprocess
from dotenv import load_dotenv
from pymysql.cursors import DictCursor
import sys
from datetime import datetime, timedelta

# Since we are in dashboard-backend/, we can import directly
from config import Config
from analyze_calls_with_parameters import CallAnalyzer
from db_handler import DatabaseHandler
from stt.sarvam import SarvamSTT

load_dotenv()

# Root logging: INFO and above go both to orchestration.log and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler('orchestration.log'), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# Dedicated logger for analytics records, written to analytics_updates.log.
# Propagation is left at its default, so these records also reach whatever
# handlers the root logger has.
ah = logging.FileHandler('analytics_updates.log')
ah.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
analytics_logger = logging.getLogger('analytics_updates')
analytics_logger.setLevel(logging.INFO)
analytics_logger.addHandler(ah)

class Orchestrator:
    """Run the PCAA call pipeline (ingest -> transcribe -> analyze) for one BID.

    Progress of each call is tracked in ``{bid}_raw_calls.status``:
       0 = ingested, waiting for transcription
       1 = transcription in progress
       2 = transcribed, waiting for analytics
       3 = analyzed
      -1 = recording URL failed validation
      -2 = transient STT failure (re-queued to 0 on the next run)
    """

    class _ConfigWrapper:
        """Wrapper handed to DatabaseHandler / CallAnalyzer for compatibility.

        Supports both attribute access (``cfg.DB_HOST``) and dict-style
        ``cfg.get(key, default)`` on top of a ``Config`` instance.
        """

        def __init__(self, config):
            self._config = config

        def get(self, key, default=None):
            return getattr(self._config, key, default)

        def __getattr__(self, key):
            # __getattr__ only runs for names not found normally. Guard
            # '_config' so a half-initialised instance (e.g. during copy or
            # unpickling) raises AttributeError instead of recursing forever.
            if key == '_config':
                raise AttributeError(key)
            return getattr(self._config, key)

    # Calls that started before this date are never ingested, transcribed or
    # analyzed. It is also the ingest watermark used when raw_calls is empty.
    _PROCESSING_START_DATE = '2026-02-01'

    # Seconds allowed for the HEAD request that validates a recording URL.
    _URL_CHECK_TIMEOUT_SECONDS = 20

    def __init__(self, bid):
        """Create an orchestrator for business id ``bid``.

        ``bid`` is interpolated into SQL table names (``{bid}_raw_calls``,
        ``{bid}_callhistory``, ...), so it may contain only ASCII letters,
        digits and underscores.

        Raises:
            ValueError: if ``bid`` contains any other character or is empty.
        """
        bid_text = str(bid)
        if not (bid_text.isascii() and bid_text.replace('_', '').isalnum()):
            raise ValueError(
                f"Invalid bid {bid!r}: only ASCII letters, digits and '_' are allowed"
            )
        self.bid = bid
        self.config = Config()

        self.config_wrapped = self._ConfigWrapper(self.config)
        self.db_handler = DatabaseHandler(self.config_wrapped)
        self.analyzer = CallAnalyzer(self.config_wrapped)

        # Inline Sarvam STT (replaces the RabbitMQ worker dependency).
        sarvam_key = os.getenv('SARVAM_PIPELINE_KEY') or os.getenv('SARVAM_SUBSCRIPTION_KEY', '')
        self.stt = SarvamSTT(api_key=sarvam_key) if sarvam_key else None

    # ------------------------------------------------------------------ #
    # Connections
    # ------------------------------------------------------------------ #

    def get_db_connection(self):
        """Open an autocommit DictCursor connection to the local (destination) DB."""
        return pymysql.connect(
            host=self.config.DB_HOST,
            port=self.config.DB_PORT,
            user=self.config.DB_USER,
            password=self.config.DB_PASSWORD,
            database=self.config.DB_NAME,
            cursorclass=DictCursor,
            autocommit=True
        )

    def get_source_db_connection(self):
        """Open a DictCursor connection to the sync source DB (read-only use here)."""
        return pymysql.connect(
            host=self.config.SYNC_SOURCE_DB_HOST,
            port=self.config.SYNC_SOURCE_DB_PORT,
            user=self.config.SYNC_SOURCE_DB_USER,
            password=self.config.SYNC_SOURCE_DB_PASSWORD,
            database=self.config.SYNC_SOURCE_DB_NAME,
            cursorclass=DictCursor
        )

    def _set_call_status(self, call_id, status):
        """Set ``{bid}_raw_calls.status`` for one call using a fresh connection."""
        conn = self.get_db_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(
                    f"UPDATE `{self.bid}_raw_calls` SET status = %s WHERE callid = %s",
                    (status, call_id),
                )
        finally:
            conn.close()

    # ------------------------------------------------------------------ #
    # URL helpers
    # ------------------------------------------------------------------ #

    def _get_full_url(self, filename, starttime):
        """Reconstruct the full recording URL when only a filename/path is stored.

        - empty/None -> ``""``
        - already starts with ``http`` -> returned unchanged
        - contains ``/`` -> treated as a path under https://recordings.mcube.com/
        - bare filename -> rebuilt as
          ``.../mcubefiles112/classic/{YYYY}/{MM}/{bid}/inbound/{filename}``
          using ``starttime`` (a ``"%Y-%m-%d %H:%M:%S"`` string or a datetime).
          If that fails, the error is logged and ``filename`` is returned as-is.
        """
        if not filename:
            return ""

        if filename.startswith('http'):
            return filename

        if '/' in filename:
            # Drop leading slashes so the join does not produce '//'.
            path = filename.lstrip('/')
            return f"https://recordings.mcube.com/{path}"

        try:
            if isinstance(starttime, str):
                dt = datetime.strptime(starttime, "%Y-%m-%d %H:%M:%S")
            else:
                dt = starttime

            year = dt.strftime("%Y")
            month = dt.strftime("%m")

            # mcubefiles112 is used as the standard cluster for bare filenames.
            return f"https://recordings.mcube.com/mcubefiles112/classic/{year}/{month}/{self.bid}/inbound/{filename}"
        except Exception as e:
            logger.error(f"Error reconstructing URL: {e}")
            return filename

    def validate_url(self, url, timeout=None):
        """Return True only if a HEAD request to ``url`` answers HTTP 200.

        The status code is read from curl's ``%{http_code}`` write-out rather
        than by searching the headers for ``'200 OK'``: HTTP/2 status lines
        (``HTTP/2 200``) carry no reason phrase, so the text search rejected
        valid recordings. Redirects are not followed, so a 3xx counts as
        invalid.

        Args:
            url: Recording URL to probe. Empty/None returns False.
            timeout: Seconds allowed for the request. Defaults to
                ``_URL_CHECK_TIMEOUT_SECONDS``.
        """
        if not url:
            return False
        if timeout is None:
            timeout = self._URL_CHECK_TIMEOUT_SECONDS
        cmd = [
            'curl', '-s', '-I',
            '-o', os.devnull,            # discard the headers themselves
            '-w', '%{http_code}',        # print only the final status code
            '--max-time', str(timeout),
            url,
        ]
        try:
            # The outer timeout is a backstop in case curl ignores --max-time.
            res = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout + 5)
        except (OSError, subprocess.SubprocessError) as exc:
            logger.warning(f"URL check failed for {url}: {exc}")
            return False
        return res.stdout.strip() == '200'

    # ------------------------------------------------------------------ #
    # Ingestion
    # ------------------------------------------------------------------ #

    # Maps source `source` column values to a direction string.
    # 1713_callhistory has no direction field. Direction is inferred from the
    # call type recorded by Mcube: calltrack and ivrs are both inbound flows
    # (the customer dials the landing number). Extend this dict if other BIDs
    # have outbound source types (e.g. 'click2call', 'preview', 'progressive').
    _SOURCE_DIRECTION_MAP = {
        'calltrack': 'inbound',
        'ivrs':      'inbound',
    }

    @staticmethod
    def _derive_direction(source_value):
        """Map a source `source` value (case/whitespace-insensitive) to a direction.

        Unknown or empty values log a warning and default to ``'inbound'``.
        """
        src = (source_value or '').strip().lower()
        direction = Orchestrator._SOURCE_DIRECTION_MAP.get(src)
        if direction is None:
            logger.warning(f"Unknown source type '{source_value}' — defaulting direction to 'inbound'. "
                           "Add it to _SOURCE_DIRECTION_MAP if it is outbound.")
            direction = 'inbound'
        return direction

    # How far back to rewind the watermark to catch near-real-time delays and
    # minor clock skew between the source and destination DBs.
    _WATERMARK_OVERLAP_MINUTES = 10

    def ingest_calls(self, limit=0, ignore_watermark=False):
        """Fetch answered calls from the source DB and upsert them into local raw_calls.

        limit=0 means no limit: all new calls since the watermark are ingested
        in a single pass. This is intentional. Ingest is a cheap DB-to-DB copy
        and should never leave new calls behind. Rate limiting happens in the
        transcription and analytics phases instead.

        Args:
            limit: Max rows to fetch from the source. 0 or negative = unlimited.
            ignore_watermark: Re-read everything since ``_PROCESSING_START_DATE``
                instead of starting from the local watermark.

        Returns:
            The number of rows the upsert reported as affected. Returns 0 on any
            error; errors are logged with a traceback, not raised.
        """
        current_time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        logger.info(f"Orchestration initiated at {current_time_str}")
        logger.info("Logs are being updated in orchestration.log and analytics_updates.log")
        logger.info(f"Ingest limit: {'unlimited' if not limit else limit}")

        try:
            if ignore_watermark:
                logger.info("Ignoring local watermark for ingestion.")
                watermark = None
            else:
                watermark = self._get_ingest_watermark()

            calls = self._fetch_source_calls(watermark, limit)

            logger.info(f"Records fetched from source: {len(calls)}")
            if not calls:
                return 0
            logger.info(f"Call IDs to ingest: {', '.join(str(c['callid']) for c in calls)}")

            return self._upsert_raw_calls(calls)
        except Exception as e:
            # Ingest is best-effort. Calls already in raw_calls must still be
            # transcribed and analyzed even if the source DB is unreachable.
            logger.exception(f"Error during ingestion: {e}")
            return 0

    def _get_ingest_watermark(self):
        """Return the source start time after which calls are (re)fetched.

        The value is MAX(call_starttime) from local raw_calls, rewound by
        ``_WATERMARK_OVERLAP_MINUTES``. This catches calls that reach the
        source DB with a slight delay; the ON DUPLICATE KEY UPDATE upsert makes
        re-reading the overlap idempotent. When raw_calls is empty, the value
        falls back to ``_PROCESSING_START_DATE``.
        """
        dest_conn = self.get_db_connection()
        try:
            with dest_conn.cursor() as dest_cursor:
                dest_cursor.execute(f"SELECT MAX(call_starttime) as last_start FROM `{self.bid}_raw_calls`")
                result = dest_cursor.fetchone()
        finally:
            dest_conn.close()

        if result and result['last_start']:
            raw_watermark = result['last_start']
            watermark = raw_watermark - timedelta(minutes=self._WATERMARK_OVERLAP_MINUTES)
            logger.info(f"Local watermark: {raw_watermark} → rewound by "
                        f"{self._WATERMARK_OVERLAP_MINUTES} min to {watermark}")
        else:
            watermark = datetime.strptime(self._PROCESSING_START_DATE, "%Y-%m-%d")
            logger.info(f"No local records found. Starting from: {watermark}")
        return watermark

    def _fetch_source_calls(self, watermark, limit):
        """Return answered source calls newer than ``watermark``, oldest first.

        Field mapping (confirmed against live 1713_callhistory data):
          c.callfrom      = customer's mobile number (the party that dialled in)
          c.callto        = agent's internal extension (the party that answered)
          c.source        = call type (calltrack / ivrs), used to derive direction
          c.landingnumber = the DID/toll-free number the customer dialled
        """
        limit = int(limit or 0)
        limit_clause = f"LIMIT {limit}" if limit > 0 else ""
        where_clause = "WHERE c.starttime > %s" if watermark else "WHERE 1=1"
        params = (watermark,) if watermark else None

        query = f"""
            SELECT
                c.callid,
                c.bid,
                e.empname        AS agentname,
                g.groupname      AS groupname,
                c.starttime,
                c.endtime,
                c.dialstatus,
                c.filename,
                c.source,
                c.landingnumber,
                c.callfrom       AS customer_phone,
                c.callto         AS agent_phone
            FROM `{self.bid}_callhistory` c
            LEFT JOIN `{self.bid}_employee` e ON c.assignto = e.eid
            LEFT JOIN `{self.bid}_groups`   g ON c.gid      = g.gid
            {where_clause}
            AND c.dialstatus IN ('ANSWER')
            AND c.starttime >= '{self._PROCESSING_START_DATE}'
            ORDER BY c.starttime ASC
            {limit_clause}
        """
        source_conn = self.get_source_db_connection()
        try:
            with source_conn.cursor() as source_cursor:
                source_cursor.execute(query, params)
                return source_cursor.fetchall()
        finally:
            source_conn.close()

    def _raw_call_row(self, call):
        """Build the raw_calls INSERT parameter tuple for one source row.

        agent_callinfo    <- c.callto   (agent's extension)
        customer_callinfo <- c.callfrom (customer's mobile)
        direction         <- derived from c.source (not hardcoded)
        """
        extra = json.dumps({
            'landing_number': call.get('landingnumber') or '',
            'source_type':    call.get('source') or '',
        })
        return (
            call['bid'],
            call['callid'],
            self._get_full_url(call['filename'], call['starttime']),
            0,                                                   # status = ingested
            str(call['agentname']) if call['agentname'] else '',
            str(call['groupname']) if call['groupname'] else '',
            call['starttime'],
            call['endtime'],
            call['dialstatus'] or '',
            call.get('agent_phone') or '',
            call.get('customer_phone') or '',
            self._derive_direction(call.get('source')),
            extra,
            None,  # transcription_requested
            None,  # transcription_status
            None,  # selected_for_processing
        )

    def _upsert_raw_calls(self, calls):
        """Insert or refresh ``calls`` in raw_calls and return the affected-row count.

        The ON DUPLICATE KEY UPDATE clause does not touch ``status`` or the
        transcription columns. Re-ingesting an already-processed call therefore
        never moves it backwards in the pipeline.
        """
        insert_query = f"""
            INSERT INTO `{self.bid}_raw_calls`
            (bid, callid, fileurl, status, agentname, groupname, call_starttime, call_endtime,
             call_status, agent_callinfo, customer_callinfo, direction, extra_fields,
             transcription_requested, transcription_status, selected_for_processing)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
            fileurl           = VALUES(fileurl),
            agentname         = VALUES(agentname),
            groupname         = VALUES(groupname),
            call_starttime    = VALUES(call_starttime),
            call_endtime      = VALUES(call_endtime),
            call_status       = VALUES(call_status),
            agent_callinfo    = VALUES(agent_callinfo),
            customer_callinfo = VALUES(customer_callinfo),
            direction         = VALUES(direction),
            extra_fields      = VALUES(extra_fields)
        """
        inserted = 0
        dest_conn = self.get_db_connection()
        try:
            with dest_conn.cursor() as dest_cursor:
                for call in calls:
                    dest_cursor.execute(insert_query, self._raw_call_row(call))
                    # Counts rows MySQL reports as affected (new or changed).
                    # Unchanged overlap rows may report 0, depending on client flags.
                    if dest_cursor.rowcount >= 1:
                        inserted += 1
        finally:
            dest_conn.close()

        logger.info(f"Successfully ingested {inserted} calls into local DB.")
        return inserted

    # ------------------------------------------------------------------ #
    # Transcription
    # ------------------------------------------------------------------ #

    def trigger_transcription(self, call):
        """Transcribe one raw_calls row inline via Sarvam and store the result.

        Flow:
          1. If a sarvamresponse row already exists, advance to status 2 and return True.
          2. If no STT provider is configured, log and return False (status unchanged).
          3. Set status 1 and call Sarvam. On an exception, set status -2 and return False.
          4. Upsert the transcript into ``{bid}_sarvamresponse``, set status 2,
             bump the daily counter and return True.

        Errors while saving in step 4 propagate to the caller. The call is left
        at status 1 with no response row, which Phase 0a recovers next run.
        """
        call_id = str(call['callid'])
        recording_url = call['fileurl']
        resp_table = f"`{self.bid}_sarvamresponse`"
        raw_table = f"`{self.bid}_raw_calls`"

        logger.info(f"[{call_id}] Triggering transcription...")

        conn = self.get_db_connection()
        try:
            with conn.cursor() as cursor:
                # 1. Idempotency check: already transcribed?
                cursor.execute(f"SELECT id FROM {resp_table} WHERE callid = %s", (call_id,))
                if cursor.fetchone():
                    logger.info(f"[{call_id}] Transcript already exists. Advancing to status=2.")
                    cursor.execute(f"UPDATE {raw_table} SET status = 2 WHERE callid = %s", (call_id,))
                    return True

                if not self.stt:
                    logger.error(f"[{call_id}] No STT provider configured (missing SARVAM_PIPELINE_KEY).")
                    return False

                # 2. Mark as in progress (status=1).
                cursor.execute(f"UPDATE {raw_table} SET status = 1 WHERE callid = %s", (call_id,))
        finally:
            conn.close()

        # 3. Run Sarvam transcription inline.
        try:
            result = self.stt.transcribe(audio_url=recording_url, callid=call_id)
        except Exception as exc:
            logger.error(f"[{call_id}] Transcription failed (transient, will retry): {exc}")
            # -2 = transient STT failure (rate limit, network error); Phase 0b retries it next run.
            self._set_call_status(call_id, -2)
            return False

        # 4. Save the transcript to sarvamresponse and advance status to 2.
        segments = result.speaker_segments or []
        segs_json = json.dumps(segments) if segments else None
        num_speakers = len({s.get('speaker_id', '') for s in segments})
        # A successful job with an empty transcript (silent or very short
        # recording) is stored as 'completed'. Phase 0c treats that as a
        # permanent result; storing '1' made 0c delete the row and re-send
        # the call to Sarvam on every run.
        response_status = '1' if result.transcript else 'completed'

        conn = self.get_db_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(
                    f"""INSERT INTO {resp_table}
                        (callid, transcript, speaker_segments, num_speakers, duration, stt_provider, request_id, raw_response, status)
                        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                        ON DUPLICATE KEY UPDATE
                            transcript = VALUES(transcript),
                            speaker_segments = VALUES(speaker_segments),
                            num_speakers = VALUES(num_speakers),
                            duration = VALUES(duration),
                            stt_provider = VALUES(stt_provider),
                            status = VALUES(status),
                            updated_at = NOW()""",
                    (
                        call_id,
                        result.transcript,
                        segs_json,
                        num_speakers,
                        result.duration,
                        result.provider,
                        call_id,
                        '',
                        response_status,
                    ),
                )
                cursor.execute(f"UPDATE {raw_table} SET status = 2 WHERE callid = %s", (call_id,))
            logger.info(f"[{call_id}] Transcription saved. Ready for analytics.")
            self._increment_daily_count()
            return True
        finally:
            conn.close()

    # ------------------------------------------------------------------ #
    # Analytics
    # ------------------------------------------------------------------ #

    def trigger_analytics(self, call_id):
        """Analyze one transcribed call and advance it to status 3.

        The transcript is read through ``DatabaseHandler.get_call_transcript``.
        Returns True on success and False otherwise. Every error is logged with
        a traceback and never raised; the call's status is then left unchanged,
        so a later run retries it.
        """
        logger.info(f"[{call_id}] Triggering analytics...")
        try:
            transcript_data = self.db_handler.get_call_transcript(self.bid, call_id)
            if not transcript_data or not transcript_data.get('transcript'):
                logger.error(f"[{call_id}] No transcript found in response table.")
                return False

            speaker_segments = transcript_data.get('speaker_segments')
            if speaker_segments and isinstance(speaker_segments, str):
                speaker_segments = json.loads(speaker_segments)
            duration = transcript_data.get('duration')

            result = self.analyzer.analyze_call(
                bid=self.bid,
                callid=call_id,
                transcript=transcript_data['transcript'],
                speaker_segments=speaker_segments or [],
                actual_duration=float(duration) if duration else None
            )
            quality_score = result.get('quality_score')
            logger.info(f"[{call_id}] Analytics complete. Quality Score: {quality_score}%")

            self._set_call_status(call_id, 3)  # 3 = analyzed

            analytics_logger.info(f"SUCCESS: Analytics created/updated in {self.bid}_callanalytics for callid={call_id} with Quality Score={quality_score}")
            logger.info(f"Analytics is completed for record {call_id}")
            return True
        except Exception as e:
            logger.exception(f"[{call_id}] Analytics failed: {e}")
            return False

    # ------------------------------------------------------------------ #
    # Daily transcription counter (file-based, per BID and calendar day)
    # ------------------------------------------------------------------ #

    def _daily_count_path(self):
        """Path of today's transcription-counter file for this BID."""
        return f"/tmp/pcaa_{self.bid}_daily_{datetime.now().strftime('%Y%m%d')}.txt"

    def _get_daily_count(self):
        """Return today's transcription count (0 if the file is missing or unreadable)."""
        path = self._daily_count_path()
        try:
            with open(path) as f:
                return int(f.read().strip() or 0)
        except FileNotFoundError:
            return 0
        except ValueError:
            logger.warning(f"Corrupt daily counter file {path}; treating count as 0.")
            return 0

    def _increment_daily_count(self):
        """Add one to today's transcription count and return the new value."""
        path = self._daily_count_path()
        count = self._get_daily_count() + 1
        with open(path, 'w') as f:
            f.write(str(count))
        return count

    # ------------------------------------------------------------------ #
    # Orchestration
    # ------------------------------------------------------------------ #

    def run(self, ingest_limit=0, transcribe_limit=20, analyze_limit=20,
            max_per_day=0, ignore_watermark=False, skip_ingest=False):
        """Run one full orchestration cycle.

        Phase limits are intentionally separate:
          ingest_limit     — 0 = unlimited (ingest everything new in one pass)
          transcribe_limit — per-run cap to respect Sarvam STT rate limits
          analyze_limit    — per-run cap to respect LLM rate limits
          max_per_day      — 0 = unlimited daily transcription cap

        If the daily cap is already reached when the run starts, both
        transcription and analytics are skipped. If the cap is reached during
        Phase 1, the remaining calls stay queued and Phase 2 still runs.
        """
        logger.info(
            f"Starting orchestration for BID {self.bid} | "
            f"ingest_limit={'unlimited' if not ingest_limit else ingest_limit} | "
            f"transcribe_limit={transcribe_limit} | analyze_limit={analyze_limit} | "
            f"max_per_day={max_per_day} | "
            f"ignore_watermark={ignore_watermark} | skip_ingest={skip_ingest}"
        )

        # 0. Ingestion (skipped for Mcube 2.0 BIDs, whose records are pushed via the Call Sync tab).
        if not skip_ingest:
            self.ingest_calls(limit=ingest_limit, ignore_watermark=ignore_watermark)
        else:
            logger.info("Skipping ingest step (--skip-ingest). Processing records already in raw_calls table.")

        conn = self.get_db_connection()
        try:
            with conn.cursor() as cursor:
                self._recover_stuck_calls(cursor)

                if max_per_day and max_per_day > 0:
                    daily_count = self._get_daily_count()
                    if daily_count >= max_per_day:
                        logger.info(f"Daily limit reached ({daily_count}/{max_per_day}). Skipping transcription and analytics for this run.")
                        return

                self._run_transcription_phase(cursor, transcribe_limit, max_per_day)
                self._run_analytics_phase(cursor, analyze_limit)
        finally:
            conn.close()

        logger.info("This instance of the orchestrator job is completed.")

    def _recover_stuck_calls(self, cursor):
        """Phase 0: move calls stuck in intermediate states back to status 0."""
        resp_table = f"`{self.bid}_sarvamresponse`"

        # 0a. Reset status=1 calls that have no sarvamresponse row (worker was
        #     down, or the transcript save failed).
        cursor.execute(f"""
            UPDATE `{self.bid}_raw_calls` r
            LEFT JOIN {resp_table} s ON r.callid = s.callid
            SET r.status = 0
            WHERE r.status = 1 AND s.callid IS NULL
        """)
        recovered = cursor.rowcount
        if recovered:
            logger.info(f"Recovered {recovered} stuck status=1 calls back to status=0 for retry.")

        # 0b. Reset status=-2 (transient STT failure) calls back to 0 for retry.
        cursor.execute(f"UPDATE `{self.bid}_raw_calls` SET status = 0 WHERE status = -2")
        retried = cursor.rowcount
        if retried:
            logger.info(f"Queued {retried} transient-STT-failed (status=-2) calls back to status=0 for retry.")

        # 0c. Fix status=2 calls whose sarvamresponse transcript is empty AND
        #     whose job did not complete successfully (old RabbitMQ placeholder
        #     rows or jobs that errored before producing output).
        #     status='completed' is excluded on purpose: a clean job that found
        #     no speech is accepted as permanent, to avoid an infinite
        #     re-transcription loop.
        cursor.execute(f"""
            SELECT r.callid FROM `{self.bid}_raw_calls` r
            JOIN {resp_table} s ON r.callid = s.callid
            WHERE r.status = 2
              AND (s.transcript IS NULL OR s.transcript = '')
              AND (s.status IS NULL OR s.status != 'completed')
        """)
        empty_transcript_calls = [row['callid'] for row in cursor.fetchall()]
        if empty_transcript_calls:
            placeholders = ','.join(['%s'] * len(empty_transcript_calls))
            cursor.execute(f"DELETE FROM {resp_table} WHERE callid IN ({placeholders}) AND (transcript IS NULL OR transcript = '')", empty_transcript_calls)
            cursor.execute(f"UPDATE `{self.bid}_raw_calls` SET status = 0 WHERE callid IN ({placeholders})", empty_transcript_calls)
            logger.info(f"Recovered {len(empty_transcript_calls)} status=2 calls with empty transcripts back to status=0 for re-transcription.")

    def _run_transcription_phase(self, cursor, transcribe_limit, max_per_day):
        """Phase 1: validate status-0 recording URLs and transcribe the valid ones.

        Truncated URLs are repaired in place. URLs that fail validation are
        marked status -1. The daily cap is re-checked before every
        transcription so a single run cannot overshoot ``max_per_day``.
        """
        raw_table = f"`{self.bid}_raw_calls`"
        limit = int(transcribe_limit or 0)
        t_limit_clause = f"LIMIT {limit}" if limit > 0 else ""
        cursor.execute(
            f"SELECT * FROM {raw_table} "
            f"WHERE status = 0 AND fileurl IS NOT NULL AND fileurl != '' "
            f"AND call_starttime >= '{self._PROCESSING_START_DATE}' "
            f"ORDER BY call_starttime ASC {t_limit_clause}"
        )
        candidates = cursor.fetchall()

        valid_calls = []
        for call in candidates:
            current_url = call['fileurl']
            repaired_url = self._get_full_url(current_url, call['call_starttime'])
            if repaired_url != current_url:
                logger.info(f"[{call['callid']}] Repaired truncated URL: {repaired_url}")
                cursor.execute(f"UPDATE {raw_table} SET fileurl = %s WHERE callid = %s", (repaired_url, call['callid']))
                call['fileurl'] = repaired_url
            if self.validate_url(call['fileurl']):
                valid_calls.append(call)
            else:
                logger.warning(f"[{call['callid']}] Invalid file URL: {call['fileurl']}")
                cursor.execute(f"UPDATE {raw_table} SET status = -1 WHERE callid = %s", (call['callid'],))

        logger.info(f"Found {len(valid_calls)} valid calls to queue for transcription.")
        for call in valid_calls:
            if max_per_day and max_per_day > 0 and self._get_daily_count() >= max_per_day:
                logger.info(f"Daily limit reached ({max_per_day}) during this run. Remaining calls stay queued.")
                break
            try:
                self.trigger_transcription(call)
            except Exception:
                # Keep going with the other calls. A call left at status=1
                # without a response row is recovered by Phase 0a next run.
                logger.exception(f"[{call['callid']}] Unexpected error during transcription")
            time.sleep(0.5)

    def _run_analytics_phase(self, cursor, analyze_limit):
        """Phase 2: run analytics on transcribed (status 2) calls, oldest first.

        Only calls with a non-empty transcript are selected. Otherwise
        permanently empty transcripts would fail first on every run and could
        use up ``analyze_limit``, starving every other call.
        """
        raw_table = f"`{self.bid}_raw_calls`"
        resp_table = f"`{self.bid}_sarvamresponse`"
        limit = int(analyze_limit or 0)
        a_limit_clause = f"LIMIT {limit}" if limit > 0 else ""
        cursor.execute(
            f"SELECT r.callid FROM {raw_table} r "
            f"WHERE r.status = 2 AND r.call_starttime >= '{self._PROCESSING_START_DATE}' "
            f"AND EXISTS (SELECT 1 FROM {resp_table} s WHERE s.callid = r.callid "
            f"AND s.transcript IS NOT NULL AND s.transcript != '') "
            f"ORDER BY r.call_starttime ASC {a_limit_clause}"
        )
        analyzable_calls = cursor.fetchall()

        logger.info(f"Found {len(analyzable_calls)} calls ready for analytics.")
        for call in analyzable_calls:
            call_id = str(call['callid'])
            logger.info(f"Transcription is completed for record {call_id}")
            self.trigger_analytics(call_id)

if __name__ == "__main__":
    import argparse

    # Command-line entry point: one orchestration cycle for a single BID.
    cli = argparse.ArgumentParser(description="PCAA orchestration pipeline")
    cli.add_argument("--bid", default="1713")

    # Integer per-phase / per-day caps, registered in help-output order.
    int_options = (
        ("--ingest-limit", 0,
         "Max calls to fetch from source per run. 0 = unlimited (default). "
         "Ingest is a fast DB-to-DB copy — leave unlimited unless debugging."),
        ("--transcribe-limit", 20,
         "Max calls to send to STT per run (default 20). Keeps Sarvam within rate limits."),
        ("--analyze-limit", 20,
         "Max calls to send to the LLM analyzer per run (default 20)."),
        ("--max-per-day", 0,
         "Max calls to transcribe per calendar day. 0 = unlimited (default)."),
    )
    for flag, default_value, help_text in int_options:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)

    # Boolean switches.
    cli.add_argument("--ignore-watermark", action="store_true",
                     help="Ignore local watermark and re-fetch from 2026-02-01 (use for recovery).")
    cli.add_argument("--skip-ingest", action="store_true",
                     help="Skip source DB ingestion — process only records already in raw_calls "
                          "(use for Mcube 2.0 BIDs where records arrive via Call Sync tab).")

    opts = cli.parse_args()

    pipeline = Orchestrator(opts.bid)
    pipeline.run(
        ingest_limit=opts.ingest_limit,
        transcribe_limit=opts.transcribe_limit,
        analyze_limit=opts.analyze_limit,
        max_per_day=opts.max_per_day,
        ignore_watermark=opts.ignore_watermark,
        skip_ingest=opts.skip_ingest,
    )
