# ----------------------------------------------------------------------------
# Copyright (c) 2025, Gnaneshwar. All rights reserved.
# Purpose: Post Call Analysis
# Author: Gnaneshwar
# Description: This script performs post-call analysis, analyzing various metrics 
#              and insights from the call data. This can be used to evaluate agent 
#              performance and improve call center efficiency.
# ----------------------------------------------------------------------------

import os
import requests
import logging
import time
import pymysql
import json
from azure_upload import upload_to_azure_blob
from openai_helper import send_openai_analysis
from db_config import get_db_connection
from dotenv import load_dotenv
import urllib3

# Silence urllib3's InsecureRequestWarning for the whole process.
# NOTE(review): every requests call visible in this file either passes verify=True
# or leaves the default, so this suppression may only matter for helper modules
# (azure_upload / openai_helper) — confirm before removing.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# Load variables from a .env file so the os.getenv lookup below can see them.
load_dotenv()

# Configure the root logger at INFO; all functions below log via the logging module.
logging.basicConfig(level=logging.INFO)

# Sarvam API key sent as the 'API-Subscription-Key' header; None if the variable is unset.
SUBSCRIPTION_KEY = os.getenv('SARVAM_SUBSCRIPTION_KEY')
# Directory (relative to the working directory) where call audio is downloaded
# before being uploaded for transcription.
AUDIO_DIR = 'storage/audio/'


def parse_diarized_transcript(result):
    """
    Parse a Sarvam API response that may contain speaker diarization.

    Args:
        result: Decoded JSON response (dict). ``diarized_transcript.entries`` is
            read for per-speaker segments and the top-level ``transcript`` is
            used as a fallback. Anything that is not a dict is treated as an
            empty response.

    Returns:
        dict with:
        - full_transcript: Complete transcript text
        - speaker_segments: List of segments with speaker, speaker_id, text,
          start_time, end_time and role
        - num_speakers: Number of unique speakers
        - duration: Total duration in seconds (largest segment end time)

    This function does not raise: if the entries cannot be parsed, the error is
    logged and the plain-transcript fallback (no segments) is returned.
    """
    # Guard first so that neither the happy path nor the fallback can raise on
    # a None / non-dict payload.
    if not isinstance(result, dict):
        result = {}

    fallback = {
        'full_transcript': result.get('transcript', ''),
        'speaker_segments': [],
        'num_speakers': 0,
        'duration': 0
    }

    try:
        # `or` also covers keys that are present but explicitly null.
        diarized_data = result.get('diarized_transcript') or {}
        entries = diarized_data.get('entries') or []

        if not entries:
            # Fallback to regular transcript if no diarization
            return fallback

        # First pass: normalise speaker IDs and count distinct speakers over the
        # WHOLE call, so the "no roles for 3+ speakers" rule is applied to every
        # segment rather than only to those seen after the third speaker appears.
        speaker_ids = [
            'unknown' if entry.get('speaker_id') is None else str(entry.get('speaker_id'))
            for entry in entries
        ]
        speakers_set = set(speaker_ids)
        assign_roles = len(speakers_set) <= 2

        speaker_segments = []
        for entry, speaker_id in zip(entries, speaker_ids):
            if assign_roles:
                lowered = speaker_id.lower()
                # NOTE(review): both 'speaker_0' and 'speaker_1' are classified as
                # 'agent', so 'customer' is only reached for other IDs. Kept as-is
                # because the intended mapping is not visible here — confirm.
                is_agent = 'speaker_0' in lowered or 'speaker_1' in lowered
                role = 'agent' if is_agent else 'customer'
            else:
                role = None  # No role assignment for 3+ speakers

            speaker_segments.append({
                'speaker': speaker_id.replace('_', ' ').title(),
                'speaker_id': speaker_id,
                # Null fields are coerced so join()/max() below cannot fail on None.
                'text': entry.get('transcript') or '',
                'start_time': entry.get('start_time_seconds') or 0,
                'end_time': entry.get('end_time_seconds') or 0,
                'role': role
            })

        duration = max((segment['end_time'] for segment in speaker_segments), default=0)
        full_transcript = ' '.join(segment['text'] for segment in speaker_segments)

        logging.info(f"✅ Parsed diarization: {len(speakers_set)} speakers, {len(speaker_segments)} segments, {duration:.1f}s")

        return {
            'full_transcript': full_transcript or result.get('transcript', ''),
            'speaker_segments': speaker_segments,
            'num_speakers': len(speakers_set),
            'duration': round(duration, 2)
        }

    except Exception as e:
        # Deliberate best-effort: a malformed diarization payload must not stop
        # the caller from storing at least the plain transcript.
        logging.error(f"❌ Error parsing diarized transcript: {e}")
        return fallback


def _transcribe_pending_call(conn, cursor, call, bid, calls_table, sarvam_response_table):
    """Run one pending call through download, Sarvam transcription and DB insert.

    Returns True when the transcript was stored, False when a step failed (the
    failure is logged). Unexpected exceptions propagate to the caller. The
    downloaded audio file is removed on every exit path.
    """
    call_id = call['callid']
    file_url = call['fileUrl']
    local_file = os.path.join(AUDIO_DIR, f'translate_{call_id}.wav')

    try:
        # (connect, read) timeout so one unresponsive host cannot stall the batch.
        response = requests.get(file_url, timeout=(10, 120))
        if response.status_code != 200:
            logging.error(f"❌ Failed to download file for callid {call_id}. Status code: {response.status_code}")
            return False

        with open(local_file, 'wb') as f:
            f.write(response.content)

        if not os.path.exists(local_file):
            logging.error(f"❌ File missing after download for callid {call_id}")
            return False

        job = init_sarvam_job()
        if not job:
            logging.error(f"❌ Failed to initialize Sarvam job for callid {call_id}")
            return False

        job_id = job['job_id']
        output_url = job['output_storage_path']
        # Insert the file name between the container path and the query string.
        # partition() also copes with a path that has no '?' at all.
        base, sep, query = job['input_storage_path'].partition('?')
        azure_url = base.rstrip('/') + '/' + os.path.basename(local_file) + sep + query

        if not upload_to_azure_blob(azure_url, local_file):
            logging.error(f"❌ Failed to upload file to Azure for callid {call_id}")
            return False

        if not start_sarvam_job(job_id):
            logging.error(f"❌ Failed to start Sarvam job for jobid {job_id}")
            return False

        result = poll_sarvam_status(job_id, output_url)
        if not result:
            logging.error(f"❌ Failed to get result for jobid {job_id}")
            return False

        # Parse diarized transcript
        diarized_data = parse_diarized_transcript(result)

        cursor.execute(f"""
            INSERT INTO {sarvam_response_table}
            (callid, transcript, speaker_segments, num_speakers, duration, request_id, language, raw_response, status, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), NOW())
        """, (
            call_id,
            diarized_data['full_transcript'],
            json.dumps(diarized_data['speaker_segments']),
            diarized_data['num_speakers'],
            diarized_data['duration'],
            job_id,
            result.get('language_code', 'unknown'),
            str(result),  # NOTE(review): stored as a Python repr, not JSON
            1
        ))
        # Mark the call as processed (same as process_call_id) so the next batch
        # run, which selects status = 0, does not transcribe it again.
        cursor.execute(f"UPDATE {calls_table} SET status = 1 WHERE callid = %s", (call_id,))
        conn.commit()

        logging.info(f"✅ Translated transcript inserted for callid {call_id}")
        send_openai_analysis(call_id, bid)
        return True

    finally:
        if os.path.exists(local_file):
            try:
                os.remove(local_file)
                logging.info(f"🧹 Removed local file for callid {call_id}")
            except OSError as e:
                logging.warning(f"⚠️ Could not delete file {local_file}: {e}")


def send_to_sarvam(bid=None):
    """Transcribe every call with ``status = 0`` for one business and store the results.

    Args:
        bid: Business identifier used as the table-name prefix
            (``{bid}_calls`` / ``{bid}_sarvamresponse``). Falls back to the
            ``7987_*`` tables when falsy.

    Returns:
        ``{"message": "Translation batch completed"}`` once all pending rows were
        attempted. Per-call failures are logged and do not abort the batch.

    Raises:
        ValueError: if ``bid`` contains characters other than letters, digits
            and underscores (it is interpolated into SQL as an identifier).
    """
    # Use dynamic table name if bid is provided, otherwise use default
    if bid:
        # Identifiers cannot be bound as query parameters, so restrict the
        # prefix to characters that cannot terminate or extend the statement.
        bid_text = str(bid)
        if not bid_text.replace('_', '').isalnum():
            raise ValueError(f"Invalid bid for table lookup: {bid!r}")
        calls_table = f"{bid_text}_calls"
        sarvam_response_table = f"{bid_text}_sarvamresponse"
    else:
        calls_table = "7987_calls"  # Default fallback
        sarvam_response_table = "7987_sarvamresponse"

    # open(..., 'wb') does not create missing parent directories.
    os.makedirs(AUDIO_DIR, exist_ok=True)

    conn = get_db_connection()
    cursor = None
    try:
        cursor = conn.cursor(pymysql.cursors.DictCursor)
        cursor.execute(f"SELECT * FROM {calls_table} WHERE status = 0")
        calls = cursor.fetchall()

        for call in calls:
            try:
                _transcribe_pending_call(conn, cursor, call, bid, calls_table, sarvam_response_table)
            except Exception as e:
                # Best-effort batch: log and move on to the next call.
                logging.error(f"❌ Exception for callid {call.get('callid')}: {str(e)}")
                # Drop any uncommitted statement so it is not committed together
                # with the next call's transaction.
                try:
                    conn.rollback()
                except pymysql.MySQLError as rollback_error:
                    logging.warning(f"⚠️ Rollback failed after error on callid {call.get('callid')}: {rollback_error}")
    finally:
        # Release DB resources even when the initial SELECT fails.
        if cursor is not None:
            cursor.close()
        conn.close()

    return {"message": "Translation batch completed"}

def init_sarvam_job():
    """Create a new Sarvam speech-to-text-translate batch job.

    Returns:
        The decoded JSON body when the API answers 202 (callers read
        ``job_id``, ``input_storage_path`` and ``output_storage_path`` from it),
        or None when the request fails, times out, returns another status code,
        or the body is not valid JSON. Failures are logged.
    """
    headers = {
        'API-Subscription-Key': SUBSCRIPTION_KEY,
        'Content-Type': 'application/json'
    }
    try:
        # requests has no default timeout; without one a stalled connection
        # would block the caller indefinitely. Timeout is a RequestException.
        response = requests.post(
            'https://api.sarvam.ai/speech-to-text-translate/job/init',
            headers=headers,
            json={},
            verify=True,
            timeout=30,
        )
        if response.status_code != 202:
            logging.error(f"❌ Job init failed: {response.status_code} - {response.text}")
            return None
        return response.json()
    except requests.exceptions.RequestException as e:
        logging.error(f"❌ Request exception during job init: {str(e)}")
        return None
    except ValueError as e:
        # Older requests versions raise a plain ValueError for a non-JSON body.
        logging.error(f"❌ Job init returned invalid JSON: {str(e)}")
        return None

def start_sarvam_job(job_id):
    """Start a previously initialised Sarvam job with speaker diarization enabled.

    Args:
        job_id: Identifier returned by the job-init call.

    Returns:
        True when the API answers 200; False on any other status code, on a
        network error or on a timeout. Failures are logged.
    """
    headers = {
        'API-Subscription-Key': SUBSCRIPTION_KEY,
        'Content-Type': 'application/json'
    }
    data = {
        "job_id": job_id,
        "job_parameters": {"with_diarization": True}  # ENABLED: Speaker diarization
    }
    try:
        # requests has no default timeout; without one a stalled connection
        # would block the caller indefinitely. Timeout is a RequestException.
        response = requests.post(
            'https://api.sarvam.ai/speech-to-text-translate/job',
            headers=headers,
            json=data,
            verify=True,
            timeout=30,
        )
    except requests.exceptions.RequestException as e:
        logging.error(f"❌ Request exception during job start: {str(e)}")
        return False

    if response.status_code == 200:
        return True

    logging.error(f"❌ Failed to start Sarvam job: {response.status_code} - {response.text}")
    return False

def _fetch_sarvam_result(job_id, output_url):
    """Download and decode ``0.json`` from a finished job's output location.

    Returns the parsed JSON, or None (after logging) on a non-200 status, an
    empty body or undecodable JSON. Network errors propagate to the caller.
    """
    # The result file name goes between the container path and the query string.
    if '?' in output_url:
        base_url, query = output_url.split('?', 1)
        final_url = f"{base_url.rstrip('/')}/0.json?{query}"
    else:
        base_url = output_url
        final_url = output_url.rstrip('/') + '/0.json'

    final_response = requests.get(final_url, verify=True, timeout=(10, 60))
    # Log the path only: the query string presumably carries the storage access
    # signature and should not end up in log files.
    logging.info(f"✅ Final response job {final_response}:{base_url}")

    if final_response.status_code != 200:
        logging.error(f"❌ Failed to get content from output URL for job {job_id}. Response status: {final_response.status_code}")
        return None

    if not final_response.text:
        logging.error(f"❌ Empty response for job {job_id}. Response: {final_response.text}")
        return None

    try:
        # ValueError is the common base of every JSON decode error requests may raise.
        parsed = final_response.json()
    except ValueError:
        logging.error(f"❌ Failed to decode JSON for job {job_id}. Raw response: {final_response.text}")
        return None

    logging.info(f"✅ Final JSON for job {job_id}: {parsed}")
    return parsed


def poll_sarvam_status(job_id, output_url):
    """Poll a Sarvam job until it finishes, then return its decoded result.

    The status endpoint is queried every 10 seconds for at most 120 attempts.
    A job counts as successful when ``job_state`` is ``Completed`` and, if
    ``job_details`` is present, its first entry has state ``Success``.

    Args:
        job_id: Identifier of the started job.
        output_url: Output storage path returned by job init (may carry a
            query string).

    Returns:
        The parsed ``0.json`` result, or None when the job reports ``Failed``,
        polling times out, a network error occurs, or the result cannot be
        fetched or decoded. Non-200 status responses are retried.
    """
    headers = {'API-Subscription-Key': SUBSCRIPTION_KEY}
    status_url = f'https://api.sarvam.ai/speech-to-text-translate/job/{job_id}/status'

    for _ in range(120):
        time.sleep(10)
        try:
            response = requests.get(status_url, headers=headers, verify=True, timeout=30)
            if response.status_code != 200:
                logging.error(f"❌ Failed to check job status for job {job_id}. Status code: {response.status_code}")
                continue

            data = response.json()
            # .get() so an unexpected payload shape is retried instead of raising KeyError.
            job_state = data.get('job_state')
            job_details = data.get('job_details')

            if job_details:
                detail_state = job_details[0].get('state')
                logging.info(f"Job {job_id} state: {job_state} - {detail_state}")
                job_success = job_state == 'Completed' and detail_state == 'Success'
            else:
                logging.info(f"Job {job_id} state: {job_state} - no job details available")
                job_success = job_state == 'Completed'

            if job_state == 'Failed':
                # Terminal state: further polling cannot succeed, so stop now
                # instead of waiting out the remaining attempts.
                logging.error(f"❌ Job {job_id} failed, aborting polling.")
                return None

            if not job_success:
                logging.info(f"❌ Job {job_id} not completed or failed, retrying... Current state: {job_state}")
                continue

            logging.info(f"✅ Job {job_id} completed successfully, fetching result...")
            # Short grace period before reading the output file (kept from the
            # original flow).
            time.sleep(10)
            return _fetch_sarvam_result(job_id, output_url)

        except requests.exceptions.RequestException as e:
            logging.error(f"❌ Exception while polling job status for job {job_id}: {e}")
            return None
        except ValueError as e:
            # Older requests versions raise a plain ValueError for a non-JSON body.
            logging.error(f"❌ Invalid JSON in status response for job {job_id}: {e}")
            return None

    logging.warning(f"❌ Timeout polling job {job_id}")
    return None


def process_call_id(call_id, bid):
    """Transcribe a single call via Sarvam and store the result.

    The call row is looked up in ``{bid}_calls`` and must have ``status == 0``.
    Its audio is downloaded, uploaded to the Sarvam job's input storage and
    transcribed. The parsed transcript is inserted into ``{bid}_sarvamresponse``,
    the call is marked ``status = 1`` and ``send_openai_analysis`` is invoked.

    Args:
        call_id: Value of the ``callid`` column identifying the call.
        bid: Business identifier used as the table-name prefix.

    Returns:
        True when the transcript was stored and analysis was triggered; False
        when the bid is invalid, the call is missing or not pending, or any
        processing step fails (failures are logged). Errors from the initial
        lookup query propagate to the caller.
    """
    # Identifiers cannot be bound as query parameters, so restrict the prefix
    # to characters that cannot terminate or extend the SQL statement.
    bid_text = str(bid)
    if not bid_text.replace('_', '').isalnum():
        logging.error(f"❌ Invalid bid {bid!r}; refusing to build table names from it.")
        return False

    # Use dynamic table name based on bid
    calls_table = f"{bid_text}_calls"
    sarvam_response_table = f"{bid_text}_sarvamresponse"
    local_file = os.path.join(AUDIO_DIR, f'translate_{call_id}.wav')

    conn = get_db_connection()
    cursor = None
    try:
        cursor = conn.cursor(pymysql.cursors.DictCursor)

        # First, fetch the row by callid only so we can log the exact status/state
        cursor.execute(f"SELECT * FROM {calls_table} WHERE callid = %s", (call_id,))
        call = cursor.fetchone()
        if not call:
            logging.warning(f"⚠️ Call ID {call_id} not found in table {calls_table} (bid={bid}).")
            return False

        # Now enforce status = 0 in Python so we get clearer diagnostics
        status = call.get("status") if isinstance(call, dict) else None
        if status != 0:
            logging.warning(
                f"⚠️ Call ID {call_id} found in table {calls_table} (bid={bid}) "
                f"but status={status}, expected 0. Skipping."
            )
            return False

        try:
            file_url = call['fileUrl']
            # open(..., 'wb') does not create missing parent directories.
            os.makedirs(AUDIO_DIR, exist_ok=True)

            # (connect, read) timeout so an unresponsive host cannot hang the call.
            response = requests.get(file_url, timeout=(10, 120))
            if response.status_code != 200:
                logging.error(f"❌ Failed to download file for callid {call_id}. Status code: {response.status_code}")
                return False

            with open(local_file, 'wb') as f:
                f.write(response.content)

            if not os.path.exists(local_file):
                logging.error(f"❌ File missing after download for callid {call_id}")
                return False

            job = init_sarvam_job()
            if not job:
                logging.error(f"❌ Failed to initialize Sarvam job for callid {call_id}")
                return False

            job_id = job['job_id']
            output_url = job['output_storage_path']
            # Insert the file name between the container path and the query
            # string. partition() also copes with a path that has no '?'.
            base, sep, query = job['input_storage_path'].partition('?')
            azure_url = base.rstrip('/') + '/' + os.path.basename(local_file) + sep + query

            if not upload_to_azure_blob(azure_url, local_file):
                logging.error(f"❌ Failed to upload file to Azure for callid {call_id}")
                return False

            if not start_sarvam_job(job_id):
                logging.error(f"❌ Failed to start Sarvam job for jobid {job_id}")
                return False

            result = poll_sarvam_status(job_id, output_url)
            if not result:
                logging.error(f"❌ Failed to get result for jobid {job_id}")
                return False

            # Parse diarized transcript
            diarized_data = parse_diarized_transcript(result)

            cursor.execute(f"""
                INSERT INTO {sarvam_response_table}
                (callid, transcript, speaker_segments, num_speakers, duration, request_id, language, raw_response, status, created_at, updated_at)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), NOW())
            """, (
                call_id,
                diarized_data['full_transcript'],
                json.dumps(diarized_data['speaker_segments']),
                diarized_data['num_speakers'],
                diarized_data['duration'],
                job_id,
                result.get('language_code', 'unknown'),
                str(result),  # NOTE(review): stored as a Python repr, not JSON
                1
            ))

            cursor.execute(f"UPDATE {calls_table} SET status = 1 WHERE callid = %s", (call_id,))
            conn.commit()

            logging.info(f"✅ Translated transcript inserted and status updated for callid {call_id}")
            send_openai_analysis(call_id, bid)

            return True

        except Exception as e:
            logging.error(f"❌ Exception for callid {call_id}: {str(e)}")
            # Discard any statement that was executed but not committed.
            try:
                conn.rollback()
            except pymysql.MySQLError as rollback_error:
                logging.warning(f"⚠️ Rollback failed for callid {call_id}: {rollback_error}")
            return False

        finally:
            # Remove the downloaded audio on success AND on every failure path.
            if os.path.exists(local_file):
                try:
                    os.remove(local_file)
                    logging.info(f"🧹 Deleted local audio file for callid {call_id}")
                except OSError as e:
                    logging.warning(f"⚠️ Could not delete file {local_file}: {e}")

    finally:
        # Release DB resources on every path, including a failing lookup query.
        if cursor is not None:
            cursor.close()
        conn.close()

# NOTE: init_sarvam_job, start_sarvam_job and poll_sarvam_status are defined earlier in this file.
