#!/usr/bin/env python3
"""
Comprehensive Transcription Processor for All Branches
- Processes calls from raw_calls table
- Uses Sarvam AI for transcription
- Supports all branches
"""

import json
import logging
import os
import re
import sys
import time

import pymysql
import requests
import urllib3
from dotenv import load_dotenv
from pymysql.cursors import DictCursor

# Make modules that live next to this script (e.g. config.py, imported below)
# importable regardless of the current working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Silences urllib3's InsecureRequestWarning. NOTE(review): every HTTP call in
# this file uses verify=True (explicitly or by default), so this looks
# unnecessary -- confirm nothing else relies on it before removing.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# Populate os.environ from .env before Config is imported/instantiated and
# before SARVAM_SUBSCRIPTION_KEY is read below (presumably Config also reads
# environment variables -- confirm in config.py).
load_dotenv()

from config import Config

# Log INFO and above both to a file in the current working directory and to
# stderr (StreamHandler's default stream).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('comprehensive_transcription.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Sarvam API key; None when unset (main() checks for this before doing work).
SUBSCRIPTION_KEY = os.getenv('SARVAM_SUBSCRIPTION_KEY')
# Scratch directory for downloaded recordings, relative to the working directory.
AUDIO_DIR = 'storage/audio/'

# Create audio directory if it doesn't exist (happens at import time).
os.makedirs(AUDIO_DIR, exist_ok=True)

# Shared settings object; get_db_connection() reads its DB_* attributes.
config = Config()

def get_db_connection():
    """Open a new PyMySQL connection using the module-level ``config`` settings.

    Rows come back as dicts (``DictCursor``) and autocommit is disabled, so
    callers are responsible for ``commit()``/``rollback()`` and for closing
    the returned connection.
    """
    connect_kwargs = dict(
        host=config.DB_HOST,
        port=config.DB_PORT,
        user=config.DB_USER,
        password=config.DB_PASSWORD,
        database=config.DB_NAME,
        charset='utf8mb4',
        cursorclass=DictCursor,
        autocommit=False,
    )
    return pymysql.connect(**connect_kwargs)


def parse_diarized_transcript(result):
    """Convert a Sarvam batch STT result into transcript and speaker segments.

    Args:
        result: Decoded JSON from the job's output file. The plain text is
            read from ``transcript``; diarization entries are read from
            ``diarized_transcript.entries`` (each may carry ``speaker_id``,
            ``transcript``, ``start_time_seconds`` and ``end_time_seconds``).

    Returns:
        dict with keys ``full_transcript`` (str), ``speaker_segments`` (list
        of segment dicts with speaker/speaker_id/text/start_time/end_time/
        role), ``num_speakers`` (int) and ``duration`` (largest numeric
        segment end time, rounded to 2 decimals). When there are no usable
        diarization entries, or parsing fails, the plain ``transcript`` is
        returned with empty segments and zero counts.
    """
    plain_transcript = result.get('transcript', '') if isinstance(result, dict) else ''
    fallback = {
        'full_transcript': plain_transcript,
        'speaker_segments': [],
        'num_speakers': 0,
        'duration': 0,
    }

    try:
        # `or {}` / `or []` also cover keys that are present but null.
        diarized_data = result.get('diarized_transcript') or {}
        entries = diarized_data.get('entries') or []
        if not entries:
            return fallback

        # Normalise every speaker id up front so roles can be decided from the
        # TOTAL number of speakers, not from how many have been seen so far.
        speaker_ids = []
        for entry in entries:
            raw_id = entry.get('speaker_id')
            speaker_ids.append('unknown' if raw_id is None else str(raw_id))
        speakers_set = set(speaker_ids)

        # Roles only make sense for a two-party call. Heuristic: the speaker
        # whose id contains 'speaker_0' is treated as the agent.
        assign_roles = len(speakers_set) <= 2

        speaker_segments = []
        for entry, speaker_id in zip(entries, speaker_ids):
            segment = {
                'speaker': speaker_id.replace('_', ' ').title(),
                'speaker_id': speaker_id,
                'text': entry.get('transcript') or '',
                'start_time': entry.get('start_time_seconds', 0),
                'end_time': entry.get('end_time_seconds', 0),
            }
            if assign_roles:
                is_agent = 'speaker_0' in speaker_id.lower()
                segment['role'] = 'agent' if is_agent else 'customer'
            else:
                segment['role'] = None
            speaker_segments.append(segment)

        # Ignore missing/null end times instead of letting max() fail.
        end_times = [
            s['end_time'] for s in speaker_segments
            if isinstance(s['end_time'], (int, float))
        ]
        duration = max(end_times, default=0)
        full_transcript = ' '.join(s['text'] for s in speaker_segments if s['text'])

        logger.info(f"✅ Parsed diarization: {len(speakers_set)} speakers, {len(speaker_segments)} segments, {duration:.1f}s")

        return {
            'full_transcript': full_transcript or plain_transcript,
            'speaker_segments': speaker_segments,
            'num_speakers': len(speakers_set),
            'duration': round(duration, 2),
        }

    except Exception as e:
        # Best effort: malformed diarization data must not lose the plain transcript.
        logger.error(f"❌ Error parsing diarized transcript: {e}")
        return fallback


def init_sarvam_job(timeout=30):
    """Create a new Sarvam speech-to-text-translate batch job.

    Args:
        timeout: HTTP timeout in seconds passed to ``requests``. Without one,
            a stalled connection would block this process indefinitely.

    Returns:
        The decoded JSON body when the API answers HTTP 202 (the caller reads
        ``job_id``, ``input_storage_path`` and ``output_storage_path`` from
        it), or ``None`` on any other status, network error or undecodable
        body. Failures are logged.
    """
    headers = {
        'API-Subscription-Key': SUBSCRIPTION_KEY,
        'Content-Type': 'application/json'
    }
    try:
        response = requests.post(
            'https://api.sarvam.ai/speech-to-text-translate/job/init',
            headers=headers,
            json={},
            verify=True,
            timeout=timeout,
        )
        if response.status_code == 202:
            return response.json()
        logger.error(f"❌ Job init failed: {response.status_code} - {response.text}")
        return None
    # ValueError covers JSON decode errors on requests versions where they do
    # not derive from RequestException.
    except (requests.exceptions.RequestException, ValueError) as e:
        logger.error(f"❌ Request exception during job init: {str(e)}")
        return None


def upload_to_azure_blob(azure_url, local_file, timeout=300):
    """Upload ``local_file`` to ``azure_url`` as a block blob with HTTP PUT.

    Args:
        azure_url: Destination blob URL, including any query-string
            credentials.
        local_file: Path of the file to upload; the open file object is
            passed to ``requests`` as the request body.
        timeout: HTTP timeout in seconds passed to ``requests`` so a stalled
            upload cannot hang the worker forever.

    Returns:
        True on HTTP 200/201, otherwise False. Any exception (unreadable
        file, network error, ...) is logged and also yields False.
    """
    try:
        with open(local_file, 'rb') as audio:
            response = requests.put(
                azure_url,
                data=audio,
                headers={'x-ms-blob-type': 'BlockBlob'},
                verify=True,
                timeout=timeout,
            )
    except Exception as e:
        logger.error(f"❌ Exception during Azure upload: {str(e)}")
        return False

    if response.status_code in (200, 201):
        logger.info("✅ File uploaded to Azure successfully")
        return True
    logger.error(f"❌ Azure upload failed: {response.status_code} - {response.text}")
    return False


def start_sarvam_job(job_id, timeout=30):
    """Start a previously initialised Sarvam job with speaker diarization on.

    Args:
        job_id: Identifier returned by the job-init call.
        timeout: HTTP timeout in seconds passed to ``requests``.

    Returns:
        True when the API answers HTTP 200, otherwise False (logged).
    """
    headers = {
        'API-Subscription-Key': SUBSCRIPTION_KEY,
        'Content-Type': 'application/json'
    }
    payload = {
        "job_id": job_id,
        "job_parameters": {"with_diarization": True}
    }
    try:
        response = requests.post(
            'https://api.sarvam.ai/speech-to-text-translate/job',
            headers=headers,
            json=payload,
            verify=True,
            timeout=timeout,
        )
    except requests.exceptions.RequestException as e:
        logger.error(f"❌ Request exception during job start: {str(e)}")
        return False

    if response.status_code == 200:
        return True
    logger.error(f"❌ Failed to start Sarvam job: {response.status_code} - {response.text}")
    return False


def poll_sarvam_status(job_id, output_url, max_attempts=120, poll_interval=10, timeout=30):
    """Poll a Sarvam batch job until it finishes, then download its result.

    The job status is checked every ``poll_interval`` seconds, at most
    ``max_attempts`` times (defaults: 120 x 10 s = 20 minutes). When the job
    reports success, the function waits one more ``poll_interval`` and then
    fetches ``0.json`` from the job's output storage location.

    Args:
        job_id: Job identifier returned by the job-init call.
        output_url: Output storage URL from job init. Any query string
            (presumably a SAS token) is kept on the result URL.
        max_attempts: Maximum number of status checks.
        poll_interval: Seconds to sleep before each status check and before
            fetching the result.
        timeout: Per-request HTTP timeout in seconds.

    Returns:
        The decoded result JSON, or ``None`` if the job ended without
        success, polling ran out of attempts, or the result could not be
        fetched or decoded. Network errors are logged and polling continues.
    """
    headers = {'API-Subscription-Key': SUBSCRIPTION_KEY}
    status_url = f'https://api.sarvam.ai/speech-to-text-translate/job/{job_id}/status'

    # Result file: <output_url>/0.json, preserving any query string.
    base_url, sep, query = output_url.partition('?')
    result_url = f"{base_url.rstrip('/')}/0.json{sep}{query}"

    for _ in range(max_attempts):
        time.sleep(poll_interval)

        try:
            response = requests.get(status_url, headers=headers, verify=True, timeout=timeout)
        except requests.exceptions.RequestException as e:
            logger.error(f"❌ Exception while polling job {job_id}: {e}")
            continue

        if response.status_code != 200:
            logger.warning(f"⚠️  Status check for job {job_id} returned {response.status_code}")
            continue

        try:
            data = response.json()
        except ValueError as e:
            logger.error(f"❌ Exception while polling job {job_id}: {e}")
            continue

        job_state = data.get('job_state')
        job_details = data.get('job_details') or []
        if job_details:
            detail_state = job_details[0].get('state')
            logger.info(f"Job {job_id} state: {job_state} - {detail_state}")
            job_success = job_state == 'Completed' and detail_state == 'Success'
        else:
            logger.info(f"Job {job_id} state: {job_state}")
            job_success = job_state == 'Completed'

        if not job_success:
            # A completed job whose task did not succeed will not change state
            # again, so stop instead of burning the whole polling budget.
            # 'Failed' is assumed to be terminal too -- confirm against the
            # Sarvam job-status API.
            if job_state in ('Completed', 'Failed'):
                logger.error(f"❌ Job {job_id} ended without success (state: {job_state})")
                return None
            continue

        logger.info(f"✅ Job {job_id} completed successfully, fetching result...")
        # Grace period before reading the output file (presumably so it is
        # fully written by the time we fetch it).
        time.sleep(poll_interval)

        try:
            final_response = requests.get(result_url, verify=True, timeout=timeout)
        except requests.exceptions.RequestException as e:
            # Treat like a failed status check: retry on the next attempt.
            logger.error(f"❌ Exception while polling job {job_id}: {e}")
            continue

        if final_response.status_code != 200 or not final_response.text:
            logger.error(f"❌ Failed to get content for job {job_id}")
            return None

        try:
            parsed = final_response.json()
        except ValueError:
            # ValueError is the common base of json / simplejson / requests
            # decode errors.
            logger.error(f"❌ Failed to decode JSON for job {job_id}")
            return None

        logger.info(f"✅ Got final result for job {job_id}")
        return parsed

    logger.warning(f"❌ Timeout polling job {job_id}")
    return None


def get_calls_needing_transcription(bid, limit=20):
    """Fetch answered, recorded calls that have no transcript yet.

    Selects rows from ``{bid}_raw_calls`` that have a non-empty ``fileurl``,
    ``call_status = 'ANSWER'`` and a duration over 10 seconds, and that have
    no matching ``callid`` in ``{bid}_sarvamresponse``. Newest calls come
    first.

    Args:
        bid: Business/branch ID used as the table-name prefix. It must
            consist only of ASCII letters, digits and underscores, because
            table names cannot be bound as SQL parameters and are therefore
            interpolated into the query text.
        limit: Maximum number of rows to return.

    Returns:
        The fetched rows, or an empty list if ``bid`` is invalid or the query
        fails (both cases are logged). A failure to connect is not caught
        here and propagates to the caller.
    """
    # Guard against SQL injection through the table-name prefix.
    if not re.fullmatch(r'[A-Za-z0-9_]+', str(bid)):
        logger.error(f"Invalid business ID {bid!r}: only letters, digits and underscores are allowed")
        return []

    conn = get_db_connection()
    try:
        query = f"""
            SELECT
                r.callid,
                r.bid,
                r.fileurl,
                r.agentname,
                r.groupname,
                r.call_starttime,
                r.direction,
                r.call_status,
                TIMESTAMPDIFF(SECOND, r.call_starttime, r.call_endtime) as duration_seconds
            FROM {bid}_raw_calls r
            LEFT JOIN {bid}_sarvamresponse s ON r.callid = s.callid
            WHERE
                r.fileurl IS NOT NULL
                AND r.fileurl != ''
                AND r.call_status = 'ANSWER'
                AND TIMESTAMPDIFF(SECOND, r.call_starttime, r.call_endtime) > 10
                AND s.callid IS NULL
            ORDER BY r.call_starttime DESC
            LIMIT %s
        """

        with conn.cursor() as cursor:
            cursor.execute(query, (limit,))
            return cursor.fetchall()

    except Exception as e:
        logger.error(f"Error fetching calls needing transcription: {e}")
        return []
    finally:
        conn.close()


def _download_audio(file_url, local_file, timeout):
    """Stream the recording at ``file_url`` to ``local_file`` in chunks.

    Returns the number of bytes written, or None (after logging) when the
    server does not answer HTTP 200. Network errors propagate to the caller.
    """
    with requests.get(file_url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            logger.error(f"❌ Failed to download file. Status code: {response.status_code}")
            return None
        written = 0
        with open(local_file, 'wb') as out:
            for chunk in response.iter_content(chunk_size=64 * 1024):
                out.write(chunk)
                written += len(chunk)
    return written


def _blob_url(container_url, filename):
    """Append ``filename`` to a storage container URL, keeping any query string."""
    base_url, sep, query = container_url.partition('?')
    return f"{base_url.rstrip('/')}/{filename}{sep}{query}"


def _save_transcription(bid, callid, job_id, result, diarized_data):
    """Upsert the transcript row and update the raw call's status in one transaction.

    Returns True after commit, or False after rolling back on a database
    error (logged).
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(f"""
                INSERT INTO {bid}_sarvamresponse
                (callid, transcript, speaker_segments, num_speakers, duration, request_id, language, raw_response, status, created_at, updated_at)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), NOW())
                ON DUPLICATE KEY UPDATE
                    transcript = VALUES(transcript),
                    speaker_segments = VALUES(speaker_segments),
                    num_speakers = VALUES(num_speakers),
                    duration = VALUES(duration),
                    request_id = VALUES(request_id),
                    language = VALUES(language),
                    raw_response = VALUES(raw_response),
                    status = VALUES(status),
                    updated_at = NOW()
            """, (
                callid,
                diarized_data['full_transcript'],
                json.dumps(diarized_data['speaker_segments']),
                diarized_data['num_speakers'],
                diarized_data['duration'],
                job_id,
                result.get('language_code', 'unknown'),
                json.dumps(result),
                1
            ))

            # Update transcription status in raw_calls
            # (presumably 2 == "transcribed"; confirm against the schema).
            cursor.execute(f"""
                UPDATE {bid}_raw_calls
                SET transcription_status = 2
                WHERE callid = %s
            """, (callid,))

        conn.commit()
        logger.info(f"✅ Successfully transcribed and saved {callid}")
        return True

    except Exception as e:
        logger.error(f"❌ Database error: {e}")
        conn.rollback()
        return False
    finally:
        conn.close()


def transcribe_call(call, bid, download_timeout=120):
    """Transcribe one call end-to-end and store the result.

    Steps:
    1. Download the recording into AUDIO_DIR.
    2. Create a Sarvam job and upload the audio to the job's input storage.
    3. Start the job and poll for the result.
    4. Parse the diarized transcript.
    5. Upsert it into ``{bid}_sarvamresponse`` and set
       ``transcription_status = 2`` on the matching ``{bid}_raw_calls`` row.

    The local audio file is always removed afterwards.

    NOTE(review): failures are not recorded anywhere, so a failing call has no
    ``{bid}_sarvamresponse`` row and will be selected again in the next batch.

    Args:
        call: Row dict with at least ``callid``, ``fileurl``, ``agentname``,
            ``groupname`` and ``call_starttime``.
        bid: Business ID used as the table-name prefix. It must contain only
            ASCII letters, digits and underscores, because it is interpolated
            into SQL.
        download_timeout: Connect/read timeout (seconds) for downloading the
            recording.

    Returns:
        True if the transcript was saved, False on any failure (logged).
    """
    callid = call['callid']
    file_url = call['fileurl']

    logger.info(f"\n{'=' * 60}")
    logger.info(f"📞 Transcribing Call: {callid}")
    logger.info(f"   Agent: {call['agentname']}")
    logger.info(f"   Location: {call['groupname']}")
    logger.info(f"   Time: {call['call_starttime']}")
    logger.info(f"{'=' * 60}")

    # Table names cannot be bound as parameters; refuse anything unsafe.
    if not re.fullmatch(r'[A-Za-z0-9_]+', str(bid)):
        logger.error(f"❌ Invalid business ID {bid!r}")
        return False

    local_file = None

    try:
        # Stream to disk rather than holding the whole recording in memory.
        logger.info("⬇️  Downloading audio file...")
        local_file = os.path.join(AUDIO_DIR, f'transcribe_{callid}.wav')
        size = _download_audio(file_url, local_file, download_timeout)
        if size is None:
            return False
        logger.info(f"✅ Downloaded {size} bytes")

        logger.info("🚀 Initializing Sarvam job...")
        job = init_sarvam_job()
        if not job:
            logger.error("❌ Failed to initialize Sarvam job")
            return False

        job_id = job['job_id']
        output_url = job['output_storage_path']
        azure_url = _blob_url(job['input_storage_path'], os.path.basename(local_file))

        logger.info("☁️  Uploading to Azure...")
        if not upload_to_azure_blob(azure_url, local_file):
            logger.error("❌ Failed to upload to Azure")
            return False

        logger.info("▶️  Starting transcription job...")
        if not start_sarvam_job(job_id):
            logger.error("❌ Failed to start Sarvam job")
            return False

        logger.info("⏳ Waiting for transcription to complete...")
        result = poll_sarvam_status(job_id, output_url)
        if not result:
            logger.error("❌ Failed to get transcription result")
            return False

        logger.info("📝 Parsing transcript...")
        diarized_data = parse_diarized_transcript(result)

        logger.info("💾 Saving to database...")
        return _save_transcription(bid, callid, job_id, result, diarized_data)

    except Exception as e:
        logger.error(f"❌ Exception: {e}", exc_info=True)
        return False

    finally:
        # Clean up local file (also removes a partially downloaded one).
        if local_file and os.path.exists(local_file):
            try:
                os.remove(local_file)
                logger.info("🧹 Cleaned up local file")
            except Exception as e:
                logger.warning(f"⚠️  Could not delete file: {e}")


def main():
    """Command-line entry point: transcribe pending calls for one business ID.

    Each batch handles up to ``--limit`` calls. With ``--continuous``, the
    script keeps checking for new calls every ``--interval`` seconds until it
    is interrupted.

    Returns:
        Exit status: 0 when finished, 1 if SARVAM_SUBSCRIPTION_KEY is unset.
    """
    import argparse

    def positive_int(text):
        """argparse type: an integer > 0 (LIMIT must be positive in SQL)."""
        value = int(text)
        if value <= 0:
            raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
        return value

    def non_negative_int(text):
        """argparse type: an integer >= 0 (time.sleep rejects negatives)."""
        value = int(text)
        if value < 0:
            raise argparse.ArgumentTypeError(f"must be a non-negative integer, got {value}")
        return value

    parser = argparse.ArgumentParser(description='Comprehensive Transcription Processor')
    parser.add_argument('--bid', type=str, default='7987', help='Business ID')
    parser.add_argument('--limit', type=positive_int, default=20, help='Number of calls to process per run')
    parser.add_argument('--continuous', action='store_true', help='Run continuously')
    parser.add_argument('--interval', type=non_negative_int, default=300, help='Interval between runs (seconds)')

    args = parser.parse_args()

    logger.info("\n" + "=" * 80)
    logger.info("   COMPREHENSIVE TRANSCRIPTION PROCESSOR")
    logger.info("=" * 80)
    logger.info(f"   Business ID: {args.bid}")
    logger.info(f"   Batch Size: {args.limit}")
    logger.info(f"   Continuous: {args.continuous}")
    logger.info("=" * 80 + "\n")

    if not SUBSCRIPTION_KEY:
        logger.error("❌ SARVAM_SUBSCRIPTION_KEY not found in environment variables")
        logger.error("   Please set it in your .env file")
        return 1

    iteration = 0

    while True:
        iteration += 1

        logger.info(f"\n{'#' * 80}")
        logger.info(f"# Iteration {iteration}")
        logger.info(f"{'#' * 80}\n")

        calls = get_calls_needing_transcription(args.bid, args.limit)

        if not calls:
            logger.info("✅ No calls need transcription at this time")
            if not args.continuous:
                break
            logger.info(f"⏳ Waiting {args.interval} seconds...")
            time.sleep(args.interval)
            continue

        logger.info(f"📋 Found {len(calls)} calls needing transcription\n")

        # Calls are processed sequentially; each one polls Sarvam until done.
        success_count = 0
        failed_count = 0

        for i, call in enumerate(calls, 1):
            logger.info(f"\n[{i}/{len(calls)}]")

            if transcribe_call(call, args.bid):
                success_count += 1
            else:
                failed_count += 1

        logger.info("\n" + "=" * 80)
        logger.info("   BATCH SUMMARY")
        logger.info("=" * 80)
        logger.info(f"   Total: {len(calls)}")
        logger.info(f"   ✅ Success: {success_count}")
        logger.info(f"   ❌ Failed: {failed_count}")
        logger.info("=" * 80 + "\n")

        if not args.continuous:
            break

        logger.info(f"⏳ Waiting {args.interval} seconds before next batch...")
        time.sleep(args.interval)

    logger.info("\n✅ Transcription processor finished\n")
    return 0


if __name__ == '__main__':
    # Exit with 1 unless main() returns normally. SystemExit (e.g. raised by
    # argparse on bad arguments) derives from BaseException rather than
    # Exception, so neither handler below intercepts it.
    exit_code = 1
    try:
        exit_code = main()
    except KeyboardInterrupt:
        logger.info("\n\n⚠️  Interrupted by user")
    except Exception as e:
        logger.error(f"\n\n❌ Fatal error: {e}", exc_info=True)
    sys.exit(exit_code)
