#!/usr/bin/env python3
"""
Transcribe recent answered calls from {bid}_raw_calls using Sarvam
and store results in {bid}_sarvamresponse.
"""

import argparse
import json
import logging
import os
import time
import urllib3

import pymysql
import requests
from dotenv import load_dotenv

from azure_upload import upload_to_azure_blob
from db_config import DB_CONFIG

# Silence urllib3's InsecureRequestWarning process-wide. None of the requests
# issued in this file sets verify=False, so this presumably exists for code
# imported from elsewhere — NOTE(review): confirm it is still needed.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# Must run before the os.getenv() call below so values from a .env file are visible.
load_dotenv()

# Root logging configuration shared by every logger in the process.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Sent as the 'API-Subscription-Key' header on Sarvam requests; main() refuses
# to run when it is unset.
SUBSCRIPTION_KEY = os.getenv('SARVAM_SUBSCRIPTION_KEY')
# Relative directory where downloaded recordings are staged; created by main().
AUDIO_DIR = 'storage/audio/'


def parse_diarized_transcript(result):
    """Flatten a Sarvam job result into a transcript plus per-speaker segments.

    Args:
        result: Decoded JSON payload of a finished job. Expected to be a dict
            that may carry a top-level ``transcript`` and a
            ``diarized_transcript`` object holding an ``entries`` list. Any
            other shape is treated as "no diarization available".

    Returns:
        A dict with the keys:
            ``full_transcript``  - segment texts joined with spaces, falling
                                   back to the top-level ``transcript``;
            ``speaker_segments`` - list of dicts with ``speaker``,
                                   ``speaker_id``, ``text``, ``start_time``,
                                   ``end_time`` and ``role``;
            ``num_speakers``     - number of distinct speaker ids;
            ``duration``         - largest segment end time, rounded to 2 dp.

    The function never raises: any parsing problem is logged and the
    "no diarization" fallback is returned instead.
    """
    # Resolve the fallback text up front so the error path below never has to
    # touch ``result`` again (it may not even be a dict).
    fallback_text = result.get('transcript', '') if isinstance(result, dict) else ''
    fallback = {
        'full_transcript': fallback_text,
        'speaker_segments': [],
        'num_speakers': 0,
        'duration': 0
    }

    if not isinstance(result, dict):
        logger.error(
            "❌ Error parsing diarized transcript: %s",
            f"expected a dict, got {type(result).__name__}"
        )
        return fallback

    try:
        # ``or`` also covers keys that are present but null in the payload.
        diarized_data = result.get('diarized_transcript') or {}
        entries = [e for e in (diarized_data.get('entries') or []) if isinstance(e, dict)]

        if not entries:
            return fallback

        speaker_segments = []
        speakers_set = set()

        for entry in entries:
            raw_id = entry.get('speaker_id')
            # Coerce to str so numeric or null ids do not abort the whole parse.
            speaker_id = 'unknown' if raw_id is None else str(raw_id)
            speakers_set.add(speaker_id)

            speaker_segments.append({
                'speaker': speaker_id.replace('_', ' ').title(),
                'speaker_id': speaker_id,
                'text': entry.get('transcript') or '',
                'start_time': entry.get('start_time_seconds') or 0,
                'end_time': entry.get('end_time_seconds') or 0
            })

        # Roles are decided only once every speaker is known, so all segments
        # of one call are labelled consistently (previously segments seen
        # before a third speaker appeared kept a role while later ones got None).
        assign_roles = len(speakers_set) <= 2
        for segment in speaker_segments:
            if assign_roles:
                lowered_id = segment['speaker_id'].lower()
                # NOTE(review): this substring test marks both speaker_0 and
                # speaker_1 as 'agent'; kept as-is because the intended
                # agent/customer mapping cannot be determined from this file.
                is_agent = 'speaker_0' in lowered_id or 'speaker_1' in lowered_id
                segment['role'] = 'agent' if is_agent else 'customer'
            else:
                segment['role'] = None

        duration = max((s['end_time'] for s in speaker_segments), default=0)
        full_transcript = ' '.join(s['text'] for s in speaker_segments)

        logger.info(
            "✅ Parsed diarization: %s speakers, %s segments, %.1fs",
            len(speakers_set),
            len(speaker_segments),
            duration
        )

        return {
            'full_transcript': full_transcript or fallback_text,
            'speaker_segments': speaker_segments,
            'num_speakers': len(speakers_set),
            'duration': round(duration, 2)
        }
    except Exception as e:
        # Deliberate best-effort: a malformed payload must not lose the plain transcript.
        logger.error("❌ Error parsing diarized transcript: %s", e)
        return fallback


def init_sarvam_job(timeout=30):
    """Create a new Sarvam speech-to-text-translate batch job.

    Args:
        timeout: Seconds to wait for the HTTP request before giving up.
            Without a timeout ``requests`` can block indefinitely on a
            stalled connection, which would hang the whole batch.

    Returns:
        The decoded JSON body when the API answers 202 (callers in this file
        read ``job_id``, ``input_storage_path`` and ``output_storage_path``
        from it), or ``None`` on any failure. Failures are logged, never raised.
    """
    headers = {
        'API-Subscription-Key': SUBSCRIPTION_KEY,
        'Content-Type': 'application/json'
    }
    try:
        response = requests.post(
            'https://api.sarvam.ai/speech-to-text-translate/job/init',
            headers=headers,
            json={},
            verify=True,
            timeout=timeout
        )
    except requests.exceptions.RequestException as e:
        logger.error("❌ Request exception during job init: %s", e)
        return None

    if response.status_code != 202:
        logger.error("❌ Job init failed: %s - %s", response.status_code, response.text)
        return None

    try:
        return response.json()
    except ValueError as e:
        # Every JSON decode error raised by requests is a ValueError subclass,
        # regardless of requests version or JSON backend.
        logger.error("❌ Job init returned invalid JSON: %s", e)
        return None


def start_sarvam_job(job_id, timeout=30):
    """Ask Sarvam to start processing a previously initialised job.

    The job is started with diarization enabled.

    Args:
        job_id: Identifier returned by the job-init call.
        timeout: Seconds to wait for the HTTP request before giving up.
            Without a timeout ``requests`` can block indefinitely.

    Returns:
        ``True`` when the API answers 200, ``False`` otherwise. Failures are
        logged, never raised.
    """
    headers = {
        'API-Subscription-Key': SUBSCRIPTION_KEY,
        'Content-Type': 'application/json'
    }
    data = {
        "job_id": job_id,
        "job_parameters": {"with_diarization": True}
    }
    try:
        response = requests.post(
            'https://api.sarvam.ai/speech-to-text-translate/job',
            headers=headers,
            json=data,
            verify=True,
            timeout=timeout
        )
    except requests.exceptions.RequestException as e:
        logger.error("❌ Request exception during job start: %s", e)
        return False

    if response.status_code == 200:
        return True
    logger.error("❌ Failed to start Sarvam job: %s - %s", response.status_code, response.text)
    return False


def poll_sarvam_status(job_id, output_url, max_attempts=120, poll_interval=10, request_timeout=30):
    """Poll a Sarvam job until it finishes, then download its first result file.

    Args:
        job_id: Identifier of the job to watch.
        output_url: Output storage URL returned at job init; ``0.json`` is
            appended to its path (any query string is preserved).
        max_attempts: Maximum number of status checks before giving up.
        poll_interval: Seconds to sleep before each status check, and once
            more before fetching the result.
        request_timeout: Per-request timeout in seconds, so a stalled
            connection cannot block the batch forever.

    Returns:
        The decoded JSON result, or ``None`` when the job fails, the result
        cannot be fetched/decoded, a request raises, or polling times out.
    """
    headers = {'API-Subscription-Key': SUBSCRIPTION_KEY}
    status_url = f'https://api.sarvam.ai/speech-to-text-translate/job/{job_id}/status'

    for _ in range(max_attempts):
        time.sleep(poll_interval)
        try:
            response = requests.get(
                status_url,
                headers=headers,
                verify=True,
                timeout=request_timeout
            )
            if response.status_code != 200:
                logger.error("❌ Failed to check job status for job %s", job_id)
                continue

            try:
                data = response.json()
            except ValueError:
                logger.error("❌ Failed to decode status JSON for job %s", job_id)
                continue
            if not isinstance(data, dict):
                logger.error("❌ Unexpected status payload for job %s", job_id)
                continue

            # .get() instead of indexing: a payload without these keys should
            # count as "not finished yet", not crash the caller with KeyError.
            job_state = data.get('job_state')
            job_details = data.get('job_details')
            if isinstance(job_details, list) and job_details:
                first_detail = job_details[0]
                file_state = first_detail.get('state') if isinstance(first_detail, dict) else None
                logger.info("Job %s state: %s - %s", job_id, job_state, file_state)
                job_success = job_state == 'Completed' and file_state == 'Success'
            else:
                logger.info("Job %s state: %s", job_id, job_state)
                job_success = job_state == 'Completed'

            if not job_success:
                # Assumes the API reports a terminal 'Failed' job state —
                # NOTE(review): confirm the exact value against the Sarvam docs.
                # Stopping here avoids polling a dead job until the timeout.
                if str(job_state).lower() == 'failed':
                    logger.error("❌ Job %s failed", job_id)
                    return None
                continue

            logger.info("✅ Job %s completed, fetching result...", job_id)
            time.sleep(poll_interval)
            if '?' in output_url:
                base_url, query = output_url.split('?', 1)
                final_url = f"{base_url.rstrip('/')}/0.json?{query}"
            else:
                final_url = output_url.rstrip('/') + '/0.json'

            final_response = requests.get(final_url, verify=True, timeout=request_timeout)
            if final_response.status_code == 200 and final_response.text:
                try:
                    return final_response.json()
                except ValueError:
                    # ValueError covers the decode error of every JSON backend
                    # requests may use, not just the stdlib json module.
                    logger.error("❌ Failed to decode JSON for job %s", job_id)
                    return None
            logger.error("❌ Failed to fetch output for job %s", job_id)
            return None
        except requests.exceptions.RequestException as e:
            logger.error("❌ Exception while polling job %s: %s", job_id, e)
            return None

    logger.warning("❌ Timeout polling job %s", job_id)
    return None


def update_status(cursor, raw_table, call_id, status):
    """Set ``transcription_status`` for one call in the raw-calls table.

    Args:
        cursor: DB-API cursor; only ``execute(query, params)`` is used.
        raw_table: Name of the table to update. It cannot be sent as a bound
            parameter, so it is quoted as an identifier here.
        call_id: Value matched against the ``callid`` column.
        status: New value for ``transcription_status``.

    The caller is responsible for committing.
    """
    # Double embedded backticks so the name cannot terminate the quoted
    # identifier, and double '%' because the driver applies %-formatting to
    # the statement when parameters are supplied.
    safe_name = str(raw_table).replace('`', '``').replace('%', '%%')
    cursor.execute(
        f"UPDATE `{safe_name}` SET transcription_status = %s WHERE callid = %s",
        (status, call_id)
    )


def process_single_call(call, bid, conn, cursor):
    """Download, transcribe and persist a single call recording.

    Steps: download the audio to ``AUDIO_DIR``, initialise a Sarvam job,
    upload the file to the job's input storage, start the job, poll for the
    result, then insert the parsed transcript into ``{bid}_sarvamresponse``
    and mark the row in ``{bid}_raw_calls`` as completed.

    Args:
        call: Row dict; ``callid`` is required, ``fileurl`` is read if present.
        bid: Business id used to derive both table names.
        conn: Open DB connection; used for commit/rollback.
        cursor: Cursor on ``conn``.

    Returns:
        ``True`` when the transcript was stored, ``False`` otherwise. On
        failure the call's ``transcription_status`` records the failing step.
        Database errors raised while recording a failure status propagate.
    """
    call_id = call['callid']
    file_url = call.get('fileurl')
    raw_table = f"{bid}_raw_calls"
    response_table = f"{bid}_sarvamresponse"

    def mark_failed(status):
        # Persist the failure reason immediately so it survives a later crash.
        update_status(cursor, raw_table, call_id, status)
        conn.commit()
        return False

    if not file_url:
        logger.error("❌ No fileurl for callid %s, skipping", call_id)
        return mark_failed('no_file')

    logger.info("📞 Processing call %s...", call_id)
    local_file = os.path.join(AUDIO_DIR, f'translate_{call_id}.wav')

    try:
        logger.info("⬇️  Downloading audio from %s", file_url)
        response = requests.get(file_url, timeout=30)

        if response.status_code != 200:
            logger.error("❌ Download failed for %s: %s", call_id, response.status_code)
            return mark_failed('download_failed')

        with open(local_file, 'wb') as f:
            f.write(response.content)

        job = init_sarvam_job()
        if not job:
            return mark_failed('job_init_failed')

        job_id = job['job_id']
        output_url = job['output_storage_path']
        input_path = job['input_storage_path']
        # Insert the file name between the storage path and its query string.
        # partition() keeps the complete query and tolerates a URL without
        # one, matching how the output URL is split when polling.
        base_path, query_sep, query = input_path.partition('?')
        azure_url = base_path.rstrip('/') + '/' + os.path.basename(local_file) + query_sep + query

        logger.info("☁️  Uploading to Azure...")
        if not upload_to_azure_blob(azure_url, local_file):
            return mark_failed('upload_failed')

        logger.info("▶️  Starting Sarvam job %s...", job_id)
        if not start_sarvam_job(job_id):
            return mark_failed('job_start_failed')

        update_status(cursor, raw_table, call_id, 'processing')
        conn.commit()

        result = poll_sarvam_status(job_id, output_url)
        if not result:
            return mark_failed('processing_failed')

        diarized_data = parse_diarized_transcript(result)

        cursor.execute(f"""
            INSERT INTO `{response_table}`
            (callid, transcript, speaker_segments, num_speakers, duration, request_id, language, raw_response, status, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), NOW())
        """, (
            call_id,
            diarized_data['full_transcript'],
            json.dumps(diarized_data['speaker_segments']),
            diarized_data['num_speakers'],
            diarized_data['duration'],
            job_id,
            result.get('language_code', 'unknown'),
            # NOTE(review): str() stores a Python repr, not JSON. Left unchanged
            # because existing consumers of raw_response may rely on it.
            str(result),
            1
        ))

        cursor.execute(
            f"UPDATE `{raw_table}` SET transcription_status = 'completed', transcription_requested = 1 WHERE callid = %s",
            (call_id,)
        )
        conn.commit()

        logger.info("✅ Successfully processed callid %s", call_id)
        return True

    except Exception as e:
        # logger.exception keeps the traceback, which the plain error log lost.
        logger.exception("❌ Exception for callid %s: %s", call_id, e)
        # Discard any half-finished INSERT/UPDATE so the commit below records
        # only the 'error' status and not a partial result row.
        conn.rollback()
        return mark_failed('error')

    finally:
        # Remove the staged audio on every exit path, not just on success;
        # otherwise each failed call leaves a recording on disk.
        try:
            os.remove(local_file)
        except FileNotFoundError:
            pass  # the download never got as far as writing the file
        except OSError as e:
            logger.warning("⚠️ Could not delete file %s: %s", local_file, e)


def main():
    """Command-line entry point: transcribe recent answered calls for one business.

    Parses ``--bid``, ``--limit`` and ``--sleep``, verifies that both
    ``{bid}_raw_calls`` and ``{bid}_sarvamresponse`` exist, selects the newest
    answered calls that have a file URL and no response row yet, and processes
    them one at a time.

    Raises:
        SystemExit: on invalid arguments, when ``SARVAM_SUBSCRIPTION_KEY`` is
            unset, or when either table does not exist.
    """
    parser = argparse.ArgumentParser(description="Transcribe recent answered raw calls for a business.")
    parser.add_argument('--bid', required=True, help='Business ID (numeric)')
    parser.add_argument('--limit', type=int, default=50, help='Number of calls to process')
    parser.add_argument('--sleep', type=int, default=5, help='Seconds to wait between calls')
    args = parser.parse_args()

    bid = str(args.bid).strip()
    # The bid is spliced into table identifiers and a LIKE pattern, neither of
    # which can be a bound parameter, so enforce the documented numeric form
    # (rejects backticks, quotes and the LIKE wildcard '%').
    if not bid or any(ch not in '0123456789' for ch in bid):
        parser.error("--bid must be numeric")
    if args.limit < 0:
        parser.error("--limit must not be negative")

    if not SUBSCRIPTION_KEY:
        raise SystemExit("SARVAM_SUBSCRIPTION_KEY is not set in the environment")

    os.makedirs(AUDIO_DIR, exist_ok=True)

    raw_table = f"{bid}_raw_calls"
    response_table = f"{bid}_sarvamresponse"

    conn = pymysql.connect(
        host=DB_CONFIG['host'],
        port=DB_CONFIG['port'],
        user=DB_CONFIG['user'],
        password=DB_CONFIG['password'],
        database=DB_CONFIG['database'],
        charset=DB_CONFIG['charset'],
        cursorclass=pymysql.cursors.DictCursor
    )
    cursor = conn.cursor()

    try:
        # Fail early with a clear message rather than a SQL error mid-run.
        for table in (raw_table, response_table):
            cursor.execute("SHOW TABLES LIKE %s", (table,))
            if not cursor.fetchone():
                raise SystemExit(f"Table {table} does not exist")

        # The LEFT JOIN ... IS NULL filter keeps only calls without a response row.
        cursor.execute(f"""
            SELECT r.callid, r.fileurl, r.call_starttime
            FROM `{raw_table}` r
            LEFT JOIN `{response_table}` s ON r.callid = s.callid
            WHERE r.call_status = 'ANSWER'
            AND s.callid IS NULL
            AND r.fileurl IS NOT NULL
            AND r.fileurl <> ''
            ORDER BY r.call_starttime DESC
            LIMIT %s
        """, (args.limit,))
        calls = cursor.fetchall()

        logger.info("📊 Found %s answered calls to process (bid=%s)", len(calls), bid)

        if not calls:
            return

        successful = 0
        failed = 0

        for idx, call in enumerate(calls, 1):
            logger.info("➡️  Processing %s/%s (callid=%s)", idx, len(calls), call['callid'])
            if process_single_call(call, bid, conn, cursor):
                successful += 1
            else:
                failed += 1

            # Pause between calls, but not after the last one.
            if idx < len(calls) and args.sleep > 0:
                time.sleep(args.sleep)

        logger.info("✅ Completed. Success: %s, Failed: %s", successful, failed)
    finally:
        cursor.close()
        conn.close()


# Run the batch only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
