#!/usr/bin/env python3
"""
Automated Pipeline for BID 6004 (MCUBE Sales)
Continuously syncs, transcribes, and analyzes calls.

Stage 1: SYNC - Pull new calls from source DB to destination DB
Stage 2: TRANSCRIBE - Use Sarvam speech_to_text_translate_job for English translation with diarization
Stage 3: ANALYZE - Use AWS Bedrock Nova for quality scoring, BANT, sentiment, talk-listen ratio

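Usage:
    python3 pipeline_6004.py                  # Run once
    python3 pipeline_6004.py --continuous     # Run continuously (default interval: 120s)
    python3 pipeline_6004.py --continuous --interval 60  # Custom interval
    python3 pipeline_6004.py --sync-only      # Run one stage only (also --transcribe-only / --analyze-only)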
"""
import sys
import os
import time
import json
import logging
import argparse
import tempfile
import signal
import pymysql
import requests
from datetime import datetime, timedelta
from pymysql.cursors import DictCursor

# Must run from dashboard-backend directory for imports: the directory holding
# this script is put first on sys.path (project modules such as `config` are
# imported later) and made the working directory.
BACKEND_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, BACKEND_DIR)
os.chdir(BACKEND_DIR)

# Load environment variables from a .env file (python-dotenv) before any
# configuration below is read via os.getenv().
from dotenv import load_dotenv
load_dotenv()

# Log to a file next to this script and to the console. The path is built from
# BACKEND_DIR, which was resolved *before* os.chdir(): re-resolving a relative
# __file__ after the chdir would point into the wrong directory.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(BACKEND_DIR, 'pipeline_6004.log')),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger('pipeline_6004')

# ─── Configuration ───────────────────────────────────────────────────────────

BID = '6004'

# Sarvam AI - using the translate job API for English output.
# The key is read from the environment only. Credentials must never be
# hard-coded in source control; rotate any key that has been committed.
SARVAM_API_KEY = (
    os.getenv('SARVAM_PIPELINE_KEY')
    or os.getenv('SARVAM_SUBSCRIPTION_KEY')
    or ''
)
if not SARVAM_API_KEY:
    logger.warning(
        'No Sarvam API key configured (set SARVAM_PIPELINE_KEY or '
        'SARVAM_SUBSCRIPTION_KEY); transcription requests are expected to fail'
    )
MAX_UPLOAD_RETRIES = 5

# Source DB (mcube call history)
SOURCE_DB = {
    'host': os.getenv('SYNC_SOURCE_DB_HOST', '127.0.0.1'),
    'user': os.getenv('SYNC_SOURCE_DB_USER', 'admin'),
    'password': os.getenv('SYNC_SOURCE_DB_PASSWORD', ''),
    'database': os.getenv('SYNC_SOURCE_DB_NAME', 'voicebot_cluster'),
    'charset': 'utf8mb4',
    'cursorclass': DictCursor
}

# Destination DB (voicebot_cluster)
DEST_DB = {
    'host': os.getenv('DB_HOST', '127.0.0.1'),
    'user': os.getenv('DB_USER', 'admin'),
    'password': os.getenv('DB_PASSWORD', ''),
    'database': os.getenv('DB_NAME', 'voicebot_cluster'),
    'charset': 'utf8mb4',
    'cursorclass': DictCursor
}

# Batch sizes
SYNC_BATCH = 500       # Max calls to fetch from source per page
TRANSCRIBE_BATCH = 1   # Max calls to transcribe per run
ANALYZE_BATCH = 1      # Max calls to analyze per run

# Only process answered calls above this duration (used in transcription/analysis stages)
MIN_CALL_DURATION_SECONDS = 120

# How many days back to look on the very first sync (empty raw_calls table)
SYNC_FIRST_RUN_LOOKBACK_DAYS = 90

# Graceful shutdown: set by signal_handler and polled between units of work
# (the transcription loop, run_pipeline and main's sleep loop).
shutdown_requested = False


def signal_handler(sig, frame):
    """Flag that the process should stop once the current batch completes.

    Installed for SIGTERM and SIGINT. It only sets the module-level flag; work
    already in progress is not interrupted.
    """
    global shutdown_requested
    logger.info('Shutdown signal received, finishing current batch...')
    shutdown_requested = True


for _sig in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_sig, signal_handler)
del _sig


# ─── Stage 1: SYNC ──────────────────────────────────────────────────────────

def _normalize_phone(phone):
    """Return a set of phone variants for a given number, matching db_handler logic."""
    digits = ''.join(ch for ch in str(phone or '') if ch.isdigit())
    if not digits:
        return set()
    core10 = digits[-10:] if len(digits) > 10 else digits
    variants = {digits, f'+{digits}', core10}
    if len(core10) == 10:
        variants.update({f'91{core10}', f'+91{core10}', f'0{core10}'})
    return {v for v in variants if v}


def _pick_customer_phone(call):
    """Extract the customer-side phone from a callhistory row."""
    direction = (call.get('direction') or 'inbound').lower()
    callto = str(call.get('callto') or '').strip()
    callfrom = str(call.get('callfrom') or '').strip()
    clicktocalldid = str(call.get('clicktocalldid') or '').strip()
    if direction == 'outbound' and callto:
        return callto
    if direction == 'inbound' and callfrom:
        return callfrom
    return clicktocalldid or callto or callfrom


def _parse_phone_variants(raw):
    """Decode a ``phone_variants`` cell into a list of non-empty stripped strings.

    Accepts a JSON array string or an already-decoded list/tuple. Malformed JSON,
    or JSON that is not an array (e.g. a bare string, which would otherwise be
    iterated character by character), yields []. ``None`` elements are skipped.
    """
    values = raw
    if not isinstance(values, (list, tuple)):
        try:
            values = json.loads(raw or '[]')
        except (TypeError, ValueError):
            return []
    if not isinstance(values, (list, tuple)):
        return []
    cleaned = (str(v).strip() for v in values if v is not None)
    return [v for v in cleaned if v]


def _load_lead_phone_set(dest_conn):
    """Collect every phone spelling for this BID's LeadSquared leads.

    Reads ``phone_primary`` and ``phone_variants`` from ``crm_leads_cache``.
    The primary number and each stored variant are expanded with
    _normalize_phone(), so they line up with the spellings generated for a
    call's customer number. Each stored variant's own text is kept as well.

    Args:
        dest_conn: An open pymysql connection to the destination DB (rows are
            read as dicts, per DEST_DB's DictCursor).

    Returns:
        A set of phone strings (empty when there are no matching rows).
    """
    cursor = dest_conn.cursor()
    try:
        cursor.execute(
            "SELECT phone_primary, phone_variants FROM crm_leads_cache WHERE bid = %s AND provider = 'leadsquared'",
            (BID,)
        )
        rows = cursor.fetchall() or []
    finally:
        cursor.close()

    phone_set = set()
    for row in rows:
        primary = str(row.get('phone_primary') or '').strip()
        if primary:
            phone_set.update(_normalize_phone(primary))
        for variant in _parse_phone_variants(row.get('phone_variants')):
            phone_set.add(variant)
            phone_set.update(_normalize_phone(variant))
    return phone_set


def _get_call_sync_watermark(dest_conn):
    """Return the start-time lower bound for the next call sync.

    Uses MAX(call_starttime) from ``<BID>_raw_calls`` minus a 1-hour overlap,
    so late-arriving records are picked up again. If the table is empty or the
    query fails, falls back to SYNC_FIRST_RUN_LOOKBACK_DAYS before now. A
    failure is logged, so a broken destination table does not silently trigger
    a full lookback re-scan.

    Args:
        dest_conn: An open pymysql connection to the destination DB.

    Returns:
        A naive ``datetime`` watermark.
    """
    cursor = dest_conn.cursor()
    try:
        cursor.execute("SELECT MAX(call_starttime) AS latest FROM `%s_raw_calls`" % BID)
        row = cursor.fetchone()
        latest = row['latest'] if row else None
    except Exception as e:
        logger.warning(
            'Could not read sync watermark from %s_raw_calls (%s); falling back to %d-day lookback',
            BID, e, SYNC_FIRST_RUN_LOOKBACK_DAYS
        )
        latest = None
    finally:
        cursor.close()

    if latest is None:
        return datetime.now() - timedelta(days=SYNC_FIRST_RUN_LOOKBACK_DAYS)
    # 1-hour overlap to catch late-arriving records
    return latest - timedelta(hours=1)


def _resolve_source_select(source_cursor):
    """Locate this BID's call-history table in the source DB and build its column list.

    Returns:
        ``(table_name, columns_sql)`` for the first candidate table that exists,
        or None when none exist. The optional columns (callto, callfrom,
        clicktocalldid, answeredtime) are selected only if the table has them.
    """
    candidates = [
        f'{BID}_callhistory',
        f'{BID}_call_history',
        f'{BID}_callarchive',
        f'{BID}_call_archive',
    ]
    source_cursor.execute("SHOW TABLES")
    existing_tables = {next(iter(r.values())) for r in source_cursor.fetchall()}
    table = next((t for t in candidates if t in existing_tables), None)
    if table is None:
        return None

    # callto/callfrom/... may not exist in every schema
    source_cursor.execute(f"SHOW COLUMNS FROM `{table}`")
    available_cols = {r['Field'] for r in source_cursor.fetchall()}
    select_cols = ['callid', 'bid', 'agentname', 'groupname', 'starttime', 'endtime',
                   'dialstatus', 'direction', 'filename', 'emp_phone']
    select_cols += [c for c in ('callto', 'callfrom', 'clicktocalldid', 'answeredtime')
                    if c in available_cols]
    return table, ', '.join(select_cols)


def _fetch_answer_calls(source_cursor, source_table, cols_sql, since, inclusive=False):
    """Fetch up to SYNC_BATCH ANSWER calls after ``since`` (or from it, if inclusive), oldest first."""
    op = '>=' if inclusive else '>'
    source_cursor.execute(
        f"SELECT {cols_sql} FROM `{source_table}` "
        f"WHERE dialstatus = 'ANSWER' AND starttime {op} %s "
        f"ORDER BY starttime ASC, callid ASC LIMIT %s",
        (since, SYNC_BATCH)
    )
    return list(source_cursor.fetchall() or [])


def _answered_duration(call):
    """Return a call's talk time in whole seconds, or None if it cannot be derived.

    ``answeredtime`` is authoritative when it is a positive number (int, numeric
    string, Decimal, or a ``timedelta`` such as a TIME column yields). Otherwise
    the duration falls back to ``endtime - starttime``.
    """
    raw = call.get('answeredtime')
    try:
        if isinstance(raw, timedelta):
            seconds = int(raw.total_seconds())
        else:
            seconds = int(float(raw)) if raw not in (None, '') else 0
    except (TypeError, ValueError, OverflowError):
        # A malformed value must not abort (and roll back) the whole sync batch.
        seconds = 0
    if seconds > 0:
        return seconds

    start_t = call.get('starttime')
    end_t = call.get('endtime')
    if not (start_t and end_t):
        return None
    try:
        return int((end_t - start_t).total_seconds())
    except (TypeError, AttributeError):
        return None


def sync_calls():
    """Sync new ANSWER calls from source DB to destination DB.

    Imports ALL ANSWER calls regardless of CRM match. CRM data from
    crm_leads_cache is used only to count how many calls match a known lead.
    Uses a watermark derived from the latest call already in raw_calls (with a
    1-hour overlap) and upserts each call into ``<BID>_raw_calls``.

    Returns:
        The number of newly inserted rows (0 on any error or when nothing is new).
    """
    logger.info('=== STAGE 1: SYNC ===')

    # Connect one at a time so a failure on the second connection does not
    # leak the first one.
    source_conn = dest_conn = None
    try:
        source_conn = pymysql.connect(**SOURCE_DB)
        dest_conn = pymysql.connect(**DEST_DB)
    except Exception as e:
        logger.error('DB connection failed: %s', e)
        if source_conn is not None:
            source_conn.close()
        return 0

    try:
        # ── Load lead phone set for logging / enrichment (not used to filter) ─
        phone_set = _load_lead_phone_set(dest_conn)
        logger.info('Loaded %d phone variants for BID %s from crm_leads_cache (used for enrichment only)', len(phone_set), BID)

        # ── Determine watermark ───────────────────────────────────────────
        watermark = _get_call_sync_watermark(dest_conn)
        logger.info('Fetching ANSWER calls since %s', watermark.strftime('%Y-%m-%d %H:%M:%S'))

        # ── Fetch from source ─────────────────────────────────────────────
        source_cursor = source_conn.cursor()
        resolved = _resolve_source_select(source_cursor)
        if resolved is None:
            logger.error('No source call table found for BID %s in source DB', BID)
            return 0
        source_table, cols_sql = resolved
        logger.info('Using source table: %s', source_table)

        calls = _fetch_answer_calls(source_cursor, source_table, cols_sql, watermark)
        if not calls:
            logger.info('No new ANSWER calls since watermark')
            return 0

        if len(calls) >= SYNC_BATCH:
            # The watermark re-reads a 1-hour overlap. If that window alone holds
            # SYNC_BATCH or more calls, every run would re-fetch the same rows and
            # the watermark would never advance. Fetch one more page starting at
            # the newest start time seen (inclusive, so ties at the page boundary
            # are not skipped) to guarantee forward progress.
            seen = {c['callid'] for c in calls}
            next_page = _fetch_answer_calls(source_cursor, source_table, cols_sql,
                                            calls[-1]['starttime'], inclusive=True)
            calls.extend(c for c in next_page if c['callid'] not in seen)

        logger.info('Fetched %d ANSWER calls from source, upserting into raw_calls...', len(calls))

        # ── Upsert into raw_calls ─────────────────────────────────────────
        dest_cursor = dest_conn.cursor()
        insert_q = """
            INSERT INTO `%s_raw_calls`
            (bid, callid, fileurl, status, agentname, groupname, call_starttime, call_endtime,
             call_status, agent_callinfo, customer_callinfo, direction, duration_seconds,
             transcription_requested, transcription_status, selected_for_processing)
            VALUES (%%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s)
            ON DUPLICATE KEY UPDATE
            fileurl = VALUES(fileurl),
            agentname = VALUES(agentname),
            groupname = VALUES(groupname),
            call_starttime = VALUES(call_starttime),
            call_endtime = VALUES(call_endtime),
            call_status = VALUES(call_status),
            agent_callinfo = CASE WHEN VALUES(agent_callinfo) != '' THEN VALUES(agent_callinfo) ELSE agent_callinfo END,
            customer_callinfo = CASE WHEN VALUES(customer_callinfo) != '' THEN VALUES(customer_callinfo) ELSE customer_callinfo END,
            direction = VALUES(direction),
            duration_seconds = CASE WHEN VALUES(duration_seconds) IS NOT NULL AND VALUES(duration_seconds) > 0 THEN VALUES(duration_seconds) ELSE duration_seconds END
        """ % BID

        inserted = updated = crm_matched = 0

        for call in calls:
            customer_phone = _pick_customer_phone(call)
            if _normalize_phone(customer_phone) & phone_set:
                crm_matched += 1

            direction = (call.get('direction') or 'inbound').lower()

            # Prefer emp_phone. Fall back to callfrom only when it is not the
            # customer's number (for inbound calls callfrom *is* the customer).
            # An empty value leaves any existing agent_callinfo untouched.
            agent_phone = str(call.get('emp_phone') or '').strip()
            callfrom = str(call.get('callfrom') or '').strip()
            if not agent_phone and callfrom != customer_phone:
                agent_phone = callfrom

            dest_cursor.execute(insert_q, (
                str(call.get('bid') or BID),
                call['callid'],
                call.get('filename') or '',
                0,
                call.get('agentname') or '',
                call.get('groupname') or '',
                call['starttime'],
                call['endtime'],
                call.get('dialstatus') or '',
                agent_phone,
                customer_phone,
                direction,
                _answered_duration(call),
                0,
                'pending',
                0,
            ))

            # MySQL reports 1 affected row for an insert, 2 for an update that
            # changed the row (0 when the duplicate row was left unchanged).
            if dest_cursor.rowcount == 1:
                inserted += 1
            elif dest_cursor.rowcount == 2:
                updated += 1

        dest_conn.commit()
        logger.info(
            'Sync complete: %d inserted, %d updated out of %d fetched (%d had CRM match)',
            inserted, updated, len(calls), crm_matched
        )
        return inserted

    except Exception as e:
        logger.error('Sync error: %s', e, exc_info=True)
        try:
            dest_conn.rollback()
        except Exception:
            pass
        return 0
    finally:
        source_conn.close()
        dest_conn.close()


# ─── Stage 2: TRANSCRIBE ────────────────────────────────────────────────────

def get_pending_transcriptions():
    """Return up to TRANSCRIBE_BATCH answered calls that still need a transcript.

    A call is eligible when it:

    - is an ANSWER call,
    - has a non-empty ``fileurl``,
    - has no row in ``<BID>_sarvamresponse``, and
    - has a start-to-end span of at least MIN_CALL_DURATION_SECONDS.

    Newest calls come first. Calls already marked
    ``transcription_status = 'failed'`` sort after every other candidate.
    With a batch size of 1, a permanently failing newest call (e.g. an audio
    URL that always errors) would otherwise be retried every run and starve
    the rest of the queue. Failed calls are still retried once nothing else is
    pending.

    Returns:
        The fetched rows (dicts with ``callid`` and ``fileurl``).
    """
    conn = pymysql.connect(**DEST_DB)
    try:
        cursor = conn.cursor()
        query = """
            SELECT r.callid, r.fileurl
            FROM %s_raw_calls r
            LEFT JOIN %s_sarvamresponse s ON r.callid = s.callid
            WHERE r.call_status = 'ANSWER'
              AND r.fileurl IS NOT NULL AND r.fileurl != ''
              AND s.callid IS NULL
              AND TIMESTAMPDIFF(SECOND, r.call_starttime, r.call_endtime) >= %d
            ORDER BY (COALESCE(r.transcription_status, '') = 'failed') ASC,
                     r.call_starttime DESC
            LIMIT %d
        """ % (BID, BID, MIN_CALL_DURATION_SECONDS, TRANSCRIBE_BATCH)
        cursor.execute(query)
        return cursor.fetchall()
    finally:
        conn.close()


def transcribe_call(audio_url, callid):
    """Transcribe one recording with Sarvam's batch speech-to-text-translate job.

    Steps:

    1. Download ``audio_url`` to a temporary .wav file.
    2. Upload it to a new ``saaras:v2.5`` translate job with diarization
       (2 speakers), retrying the upload up to MAX_UPLOAD_RETRIES times.
    3. Start the job and wait up to 300s for it to finish.
    4. Download the job's JSON output and return it.

    Args:
        audio_url: URL of the call recording.
        callid: Call identifier, used only for logging.

    Returns:
        The decoded JSON result, or None on any failure (each failure is logged).
    """
    from sarvamai import SarvamAI
    client = SarvamAI(api_subscription_key=SARVAM_API_KEY)

    tmp_path = None
    try:
        # Download audio
        audio_response = requests.get(audio_url, timeout=60)
        if audio_response.status_code != 200:
            logger.error('Failed to download audio for %s: HTTP %s', callid, audio_response.status_code)
            return None

        # Save to a temp file (delete=False: the SDK uploads it by path; removed in finally)
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False, dir='/tmp') as tmp:
            tmp_path = tmp.name
            tmp.write(audio_response.content)

        # Only keep a job whose upload actually succeeded. Otherwise a job whose
        # upload_files() returned False would be started with no audio.
        job = None
        for attempt in range(1, MAX_UPLOAD_RETRIES + 1):
            last_attempt = attempt == MAX_UPLOAD_RETRIES
            try:
                candidate = client.speech_to_text_translate_job.create_job(
                    model='saaras:v2.5',
                    with_diarization=True,
                    num_speakers=2,
                    prompt='Translate all speech to English'
                )
                if candidate.upload_files([tmp_path]):
                    job = candidate
                    break
                logger.warning('Upload attempt %d for %s did not succeed', attempt, callid)
            except RuntimeError as e:
                # The SDK intermittently raises an Azure 403 during upload: retry
                # those, re-raise anything else (or a 403 on the final attempt).
                if '403' not in str(e) or last_attempt:
                    raise
                logger.warning('Upload attempt %d failed for %s (403), retrying...', attempt, callid)
            if not last_attempt:
                time.sleep(1)

        if job is None:
            logger.error('All upload attempts failed for %s', callid)
            return None

        # Start and wait
        job.start()
        status = job.wait_until_complete(poll_interval=3, timeout=300)

        if not job.is_successful():
            logger.error('Job failed for %s', callid)
            return None

        # Output file name of the first successful job detail, defaulting to '0.json'
        output_file = next(
            (d.outputs[0].file_name for d in (status.job_details or [])
             if d.state == 'Success' and d.outputs),
            None,
        ) or '0.json'

        links = client.speech_to_text_translate_job.get_download_links(
            job_id=job.job_id,
            files=[output_file]
        )

        download_url = None
        if links.download_urls and output_file in links.download_urls:
            download_url = links.download_urls[output_file].file_url

        if not download_url:
            logger.error('No download URL for %s', callid)
            return None

        resp = requests.get(download_url, timeout=60)
        if resp.status_code != 200:
            logger.error('Failed to download transcript for %s: HTTP %s', callid, resp.status_code)
            return None

        return resp.json()

    except Exception as e:
        logger.error('Error transcribing %s: %s', callid, e)
        return None
    finally:
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)


def _build_speaker_segments(result):
    """Convert Sarvam ``diarized_transcript.entries`` into speaker-segment dicts.

    Steps:

    - Normalise ``speaker_id`` to ``speaker_<n>``.
    - Label ``speaker_0`` as the agent and every other speaker as the customer.
      This is an assumption carried over from the original pipeline.
    - Drop entries with empty text.
    - Treat a missing or null ``entries`` list as empty.
    """
    diarized = result.get('diarized_transcript', {})
    entries = (diarized.get('entries') or []) if isinstance(diarized, dict) else []

    segments = []
    for entry in entries:
        raw_id = str(entry.get('speaker_id', '0'))
        # Normalize speaker_id: "speaker_0" stays, "0" becomes "speaker_0"
        if raw_id.startswith('speaker_'):
            speaker_id = raw_id
            num = raw_id.replace('speaker_', '')
        else:
            speaker_id = 'speaker_' + raw_id
            num = raw_id
        text = entry.get('transcript', '')
        if not text:
            continue
        start = entry.get('start_time_seconds', 0)
        end = entry.get('end_time_seconds', 0)
        segments.append({
            'speaker': 'Speaker ' + num,  # "Speaker 0", "Speaker 1"
            'speaker_id': speaker_id,
            'text': text,
            'start': start,
            'end': end,
            'start_time': start,
            'end_time': end,
            'role': 'agent' if speaker_id == 'speaker_0' else 'customer'
        })
    return segments


def save_transcript(callid, result):
    """Persist a Sarvam translate-job result and mark the raw call as transcribed.

    In one transaction it:

    - upserts ``<BID>_sarvamresponse`` with the transcript, diarized segments
      (JSON), duration, speaker count, request id and the full raw JSON;
    - sets ``transcription_status = 'completed'`` and ``status = 1`` on the
      matching raw call.

    Args:
        callid: Call identifier (str or int).
        result: Parsed Sarvam output document.

    Returns:
        True on success, False (after a best-effort rollback) on any error.
    """
    conn = pymysql.connect(**DEST_DB)
    try:
        transcript_text = result.get('transcript', '')
        # str(): callid may be an int; the default is built eagerly either way.
        request_id = result.get('request_id', 'sarvam_batch_' + str(callid))
        raw_response = json.dumps(result)

        speaker_segments = _build_speaker_segments(result)
        num_speakers = len({s['speaker'] for s in speaker_segments}) if speaker_segments else 2
        # Use the latest end time rather than assuming the final entry ends last.
        duration = max((s['end'] or 0) for s in speaker_segments) if speaker_segments else 0

        cursor = conn.cursor()

        insert_q = """
            INSERT INTO %s_sarvamresponse
            (callid, transcript, speaker_segments, duration, num_speakers, request_id, raw_response, stt_provider, created_at)
            VALUES (%%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, NOW())
            ON DUPLICATE KEY UPDATE
            transcript = VALUES(transcript), speaker_segments = VALUES(speaker_segments),
            duration = VALUES(duration), num_speakers = VALUES(num_speakers),
            raw_response = VALUES(raw_response), updated_at = NOW()
        """ % BID

        cursor.execute(insert_q, (
            callid, transcript_text,
            json.dumps(speaker_segments) if speaker_segments else None,
            duration, num_speakers, request_id, raw_response, 'sarvam'
        ))

        update_q = "UPDATE %s_raw_calls SET transcription_status = 'completed', status = 1 WHERE callid = %%s" % BID
        cursor.execute(update_q, (callid,))
        conn.commit()
        return True
    except Exception as e:
        logger.error('Error saving transcript for %s: %s', callid, e)
        try:
            conn.rollback()
        except Exception as rollback_error:
            # Don't let a failed rollback (e.g. lost connection) mask the original error.
            logger.error('Rollback failed for %s: %s', callid, rollback_error)
        return False
    finally:
        conn.close()


def _mark_transcription_failed(callid):
    """Set ``transcription_status = 'failed'`` for ``callid`` in ``<BID>_raw_calls``.

    Best-effort: database errors are logged rather than raised, so a failed
    bookkeeping write cannot abort the transcription stage. The connection is
    always closed.
    """
    conn = None
    try:
        conn = pymysql.connect(**DEST_DB)
        cursor = conn.cursor()
        cursor.execute(
            "UPDATE %s_raw_calls SET transcription_status = 'failed' WHERE callid = %%s" % BID,
            (callid,)
        )
        conn.commit()
    except Exception as e:
        logger.error('Could not mark %s as failed: %s', callid, e)
    finally:
        if conn is not None:
            conn.close()


def transcribe_pending():
    """Transcribe the calls returned by get_pending_transcriptions().

    Each call is transcribed, then saved with save_transcript(). A call with no
    result, an empty transcript, or a failed save is marked
    ``transcription_status = 'failed'``. The loop stops early when a shutdown
    was requested, and waits 1s between calls.

    Returns:
        The number of calls transcribed and saved successfully.
    """
    logger.info('=== STAGE 2: TRANSCRIBE ===')

    pending = get_pending_transcriptions()
    if not pending:
        logger.info('No eligible calls for transcription (requires ANSWER and >= %ds)', MIN_CALL_DURATION_SECONDS)
        return 0

    total = len(pending)
    logger.info('Found %d calls to transcribe', total)
    success = 0
    failed = 0

    for idx, call in enumerate(pending, 1):
        if shutdown_requested:
            logger.info('Shutdown requested, stopping transcription')
            break

        callid = call['callid']
        logger.info('[%d/%d] Transcribing %s', idx, total, callid)

        result = transcribe_call(call['fileurl'], callid)
        transcript = result.get('transcript') if result else None
        if transcript:
            if save_transcript(callid, result):
                success += 1
                logger.info('  OK - %s...', transcript[:80])
            else:
                failed += 1
                _mark_transcription_failed(callid)
        else:
            failed += 1
            if result:
                logger.warning('  Empty transcript for %s', callid)
            _mark_transcription_failed(callid)

        if idx < total:
            time.sleep(1)

    logger.info('Transcription: %d success, %d failed out of %d', success, failed, total)
    return success


# ─── Stage 3: ANALYZE ───────────────────────────────────────────────────────

def latest_call_ready_for_analysis():
    """Tell whether the single newest raw call is ready for analysis.

    Only the row with the greatest ``call_starttime`` is examined. It qualifies
    when all of the following hold:

    - it is an ANSWER call;
    - it has a non-empty transcript in ``<BID>_sarvamresponse``;
    - it has no row in ``<BID>_callanalytics``;
    - it lasted at least MIN_CALL_DURATION_SECONDS.

    NOTE(review): because only the newest row is checked, a newest call that
    is short or not yet transcribed closes this gate until a newer qualifying
    call arrives. Confirm this gating is intended.

    Returns:
        True if the newest call qualifies, else False.
    """
    placeholders = (BID,) * 4 + (MIN_CALL_DURATION_SECONDS,)
    sql = """
            SELECT r.callid
            FROM %s_raw_calls r
            INNER JOIN %s_sarvamresponse s ON r.callid = s.callid
            LEFT JOIN %s_callanalytics a ON r.callid = a.callid
            WHERE r.callid = (
                SELECT callid
                FROM %s_raw_calls
                WHERE call_starttime IS NOT NULL
                ORDER BY call_starttime DESC
                LIMIT 1
            )
              AND r.call_status = 'ANSWER'
              AND s.transcript IS NOT NULL
              AND s.transcript != ''
              AND a.callid IS NULL
              AND TIMESTAMPDIFF(SECOND, r.call_starttime, r.call_endtime) >= %d
            LIMIT 1
        """ % placeholders

    conn = pymysql.connect(**DEST_DB)
    try:
        cur = conn.cursor()
        cur.execute(sql)
        found = cur.fetchone()
    finally:
        conn.close()
    return found is not None

def analyze_pending():
    """Run one analysis batch when the newest call is ready for it.

    All upper-case attributes of the project ``Config`` are passed to
    ``batch_analyze_calls``, which is expected to return a dict with ``success``
    and ``failed`` lists (each failure carrying ``callid`` and ``error``).

    Returns:
        The number of calls analyzed successfully. Returns 0 when the
        latest_call_ready_for_analysis() gate is closed.
    """
    logger.info('=== STAGE 3: ANALYZE ===')

    if not latest_call_ready_for_analysis():
        logger.info('Latest call not ready for analysis (answered + >= %ds + transcribed + not analyzed)', MIN_CALL_DURATION_SECONDS)
        return 0

    # Imported lazily: these project modules are only needed once analysis runs.
    from config import Config
    from analyze_calls_with_parameters import batch_analyze_calls

    settings = Config()
    config_dict = {name: getattr(settings, name) for name in dir(settings) if name.isupper()}

    outcome = batch_analyze_calls(BID, config_dict, limit=ANALYZE_BATCH)
    succeeded = outcome['success']
    failures = outcome['failed']
    n_ok, n_bad = len(succeeded), len(failures)

    if n_ok == 0 and n_bad == 0:
        logger.info('No calls need analysis')
        return n_ok

    logger.info('Analysis: %d success, %d failed', n_ok, n_bad)
    for failure in failures:
        logger.error('  FAILED: %s - %s', failure['callid'], failure['error'])
    return n_ok


# ─── Main Pipeline ───────────────────────────────────────────────────────────

def run_pipeline():
    """Execute sync → transcribe → analyze once and log a summary.

    If a shutdown has been requested after sync or after transcription, the
    remaining stages and the summary are skipped. Once analysis has run, the
    summary is always logged.
    """
    rule = '=' * 70
    logger.info('')
    logger.info(rule)
    logger.info('  PIPELINE RUN - %s', datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    logger.info(rule)

    stages = (sync_calls, transcribe_pending, analyze_pending)
    counts = []
    for position, stage in enumerate(stages, 1):
        counts.append(stage())
        # Only the stages before the last one honour a pending shutdown.
        if position < len(stages) and shutdown_requested:
            return
    synced, transcribed, analyzed = counts

    logger.info('')
    logger.info('Pipeline summary: synced=%d, transcribed=%d, analyzed=%d', synced, transcribed, analyzed)
    logger.info(rule)


def main():
    """Parse CLI options and run the pipeline once, or repeatedly on an interval.

    Options:
        --continuous: repeat until SIGINT/SIGTERM, sleeping ``--interval``
            seconds between runs.
        --interval N: seconds between continuous runs (default 120; must be >= 1).
        --sync-only / --transcribe-only / --analyze-only: run a single stage
            instead of the full pipeline. These flags are mutually exclusive.

    Errors raised by a run are logged and do not stop continuous mode.
    """
    parser = argparse.ArgumentParser(description='Automated pipeline for BID 6004')
    parser.add_argument('--continuous', action='store_true', help='Run continuously')
    parser.add_argument('--interval', type=int, default=120, help='Interval between runs in seconds (default: 120)')
    # Reject several stage flags at once instead of silently running only the first.
    stage_group = parser.add_mutually_exclusive_group()
    stage_group.add_argument('--sync-only', action='store_true', help='Only run sync stage')
    stage_group.add_argument('--transcribe-only', action='store_true', help='Only run transcription stage')
    stage_group.add_argument('--analyze-only', action='store_true', help='Only run analysis stage')
    args = parser.parse_args()

    if args.interval < 1:
        # range() over a non-positive interval would skip the sleep entirely and
        # turn --continuous into a tight loop hammering both databases.
        parser.error('--interval must be at least 1 second')

    logger.info('Pipeline started for BID %s (continuous=%s, interval=%ds)', BID, args.continuous, args.interval)

    iteration = 0
    while True:
        iteration += 1

        if shutdown_requested:
            logger.info('Shutting down gracefully...')
            break

        try:
            if args.sync_only:
                sync_calls()
            elif args.transcribe_only:
                transcribe_pending()
            elif args.analyze_only:
                analyze_pending()
            else:
                run_pipeline()
        except Exception as e:
            logger.error('Pipeline error in iteration %d: %s', iteration, e, exc_info=True)

        if not args.continuous:
            break

        logger.info('Waiting %d seconds before next run...', args.interval)
        # Sleep in small increments so we can respond to shutdown signals
        for _ in range(args.interval):
            if shutdown_requested:
                break
            time.sleep(1)

    logger.info('Pipeline stopped.')


if __name__ == '__main__':
    main()
