#!/usr/bin/env python3
"""
One-shot script: transcribe a specific call via Sarvam and save to 6004_sarvamresponse.
Usage: python3 run_single_transcription.py [callid]
"""
import sys
import os
import json
import time
import tempfile
import requests
import pymysql
from pymysql.cursors import DictCursor
from dotenv import load_dotenv

# Pull variables from a local .env file into the environment before any
# os.getenv() lookups below run.
load_dotenv()

# Business id this one-shot script targets; the table names used in this
# script's SQL (6004_raw_calls, 6004_sarvamresponse) hard-code the same value.
BID = '6004'
# Call to transcribe: the first CLI argument, else a fixed sample call id.
CALLID = (sys.argv[1:] or ['97393902541773212620'])[0]
# Prefer the pipeline-specific key; fall back to the generic subscription key.
# May be None when neither variable is set.
SARVAM_API_KEY = os.getenv('SARVAM_PIPELINE_KEY') or os.getenv('SARVAM_SUBSCRIPTION_KEY')
# Maximum number of create-job + upload attempts made by transcribe_call.
MAX_UPLOAD_RETRIES = 5

# pymysql.connect() keyword arguments for the destination database. Every
# connection field can be overridden from the environment; rows come back as
# dicts and each statement autocommits unless a transaction is begun explicitly.
DEST_DB = dict(
    host=os.getenv('DB_HOST', '127.0.0.1'),
    user=os.getenv('DB_USER', 'admin'),
    password=os.getenv('DB_PASSWORD', ''),
    database=os.getenv('DB_NAME', 'voicebot_cluster'),
    charset='utf8mb4',
    cursorclass=DictCursor,
    autocommit=True,
)


def get_call_url(callid):
    """Return the recording URL stored for *callid* in 6004_raw_calls.

    Opens a short-lived connection using DEST_DB and always closes it.
    Returns the row's ``fileurl`` value, or None when no row matches.
    """
    connection = pymysql.connect(**DEST_DB)
    try:
        cur = connection.cursor()
        cur.execute("SELECT fileurl FROM 6004_raw_calls WHERE callid = %s", (callid,))
        match = cur.fetchone()
    finally:
        connection.close()

    if not match:
        return None
    return match['fileurl']


def _create_and_upload_job(client, audio_path):
    """Create a Sarvam translate job and upload *audio_path* to it.

    Makes up to MAX_UPLOAD_RETRIES attempts. A fresh job is created on every
    attempt. An attempt whose upload reports failure is retried immediately.
    A RuntimeError mentioning '403' is retried after a 1 s pause, except on
    the final attempt. Any other RuntimeError is re-raised.

    Returns the job on success, or None when every attempt's upload failed.
    """
    for attempt in range(MAX_UPLOAD_RETRIES):
        try:
            job = client.speech_to_text_translate_job.create_job(
                model='saaras:v2.5',
                with_diarization=True,
                num_speakers=2,
                prompt='Translate all speech to English'
            )
            if job.upload_files([audio_path]):
                print(f"      Job ID: {job.job_id}  (attempt {attempt+1})")
                return job
        except RuntimeError as e:
            if '403' in str(e) and attempt < MAX_UPLOAD_RETRIES - 1:
                print(f"      Attempt {attempt+1} failed (403), retrying...")
                time.sleep(1)
            else:
                raise
    return None


def _download_job_output(client, job, status):
    """Fetch and parse the JSON output file of a completed Sarvam job.

    Uses the first output file of the first successful job detail in
    *status*, falling back to '0.json' when none is listed.

    Raises:
        RuntimeError: if Sarvam returns no download URL, or the download does
            not answer HTTP 200.
    """
    output_file = '0.json'
    for detail in status.job_details or []:
        if detail.state == 'Success' and detail.outputs:
            output_file = detail.outputs[0].file_name
            break

    links = client.speech_to_text_translate_job.get_download_links(
        job_id=job.job_id, files=[output_file]
    )
    download_url = None
    if links.download_urls and output_file in links.download_urls:
        download_url = links.download_urls[output_file].file_url
    if not download_url:
        raise RuntimeError("No download URL returned")

    result_resp = requests.get(download_url, timeout=60)
    # Without this check an expired/forbidden link's error body would be
    # parsed (or fail to parse) as if it were the transcript.
    if result_resp.status_code != 200:
        raise RuntimeError(f"Failed to download transcript: HTTP {result_resp.status_code}")
    return result_resp.json()


def transcribe_call(audio_url, callid):
    """Transcribe and translate the audio at *audio_url* with a Sarvam batch job.

    Steps:
    1. Download the audio into a temporary .wav file.
    2. Create a diarized saaras:v2.5 translate job and upload the file, with retries.
    3. Wait up to 600 s for the job to finish.
    4. Download the job's JSON output.

    The temp file is removed on every exit path.

    Args:
        audio_url: HTTP(S) URL of the call recording.
        callid: Currently unused; kept for interface compatibility.

    Returns:
        The parsed JSON output of the Sarvam job.

    Raises:
        RuntimeError: if the audio or transcript download fails, every upload
            attempt fails, the job does not succeed, or no download URL is
            returned.
    """
    from sarvamai import SarvamAI
    client = SarvamAI(api_subscription_key=SARVAM_API_KEY)
    tmp = None
    try:
        print(f"[1/4] Downloading audio from: {audio_url}")
        resp = requests.get(audio_url, timeout=120)
        print(f"      HTTP {resp.status_code}, size={len(resp.content)/1024:.1f} KB")
        if resp.status_code != 200:
            raise RuntimeError(f"Failed to download audio: HTTP {resp.status_code}")

        # delete=False so the SDK can reopen the file by name; it is removed
        # explicitly in the finally clause below.
        tmp = tempfile.NamedTemporaryFile(suffix='.wav', delete=False, dir='/tmp')
        try:
            tmp.write(resp.content)
        finally:
            tmp.close()
        print(f"      Saved to {tmp.name}")

        print("[2/4] Uploading to Sarvam (model=saaras:v2.5, diarization=True)...")
        job = _create_and_upload_job(client, tmp.name)
        if not job:
            raise RuntimeError("All upload attempts failed")

        print("[3/4] Waiting for job completion (this can take 2-10 minutes)...")
        job.start()
        status = job.wait_until_complete(poll_interval=5, timeout=600)
        succeeded = job.is_successful()
        print(f"      Job finished: success={succeeded}")
        if not succeeded:
            raise RuntimeError("Sarvam job did not succeed")

        return _download_job_output(client, job, status)

    finally:
        if tmp is not None and os.path.exists(tmp.name):
            os.unlink(tmp.name)


def _parse_speaker_segments(entries):
    """Convert Sarvam diarized entries into the speaker-segment dicts stored in the DB.

    Entries with an empty transcript are dropped. speaker_id is normalised to
    'speaker_<n>'. speaker_0 gets role 'agent'; every other speaker gets
    'customer'. Missing or null timestamps become 0.
    """
    prefix = 'speaker_'
    segments = []
    for entry in entries:
        text = entry.get('transcript', '')
        if not text:
            continue
        raw_id = str(entry.get('speaker_id', '0'))
        speaker_id = raw_id if raw_id.startswith(prefix) else prefix + raw_id
        speaker = 'Speaker ' + speaker_id[len(prefix):]
        start = entry.get('start_time_seconds')
        end = entry.get('end_time_seconds')
        # Coerce null timestamps so duration math and formatting never see None.
        if start is None:
            start = 0
        if end is None:
            end = 0
        segments.append({
            'speaker': speaker, 'speaker_id': speaker_id, 'text': text,
            'start': start, 'end': end, 'start_time': start, 'end_time': end,
            'role': 'agent' if speaker_id == 'speaker_0' else 'customer'
        })
    return segments


def save_transcript(callid, result):
    """Persist a Sarvam result and mark the call as transcribed.

    Writes one transaction:
    - Upserts into 6004_sarvamresponse. On a duplicate key it refreshes
      transcript, speaker_segments, duration, num_speakers and raw_response;
      request_id, stt_provider and created_at keep their first values.
    - Sets transcription_status='completed', status=1 in 6004_raw_calls.

    Either both writes commit or neither does.

    Args:
        callid: Call identifier used as the row key in both tables.
        result: Parsed Sarvam JSON (reads 'transcript', 'request_id' and
            'diarized_transcript.entries').

    Returns:
        (transcript_text, speaker_segments)
    """
    transcript_text = result.get('transcript', '')
    request_id = result.get('request_id', f'sarvam_batch_{callid}')
    raw_response = json.dumps(result)

    diarized = result.get('diarized_transcript', {})
    entries = (diarized.get('entries') or []) if isinstance(diarized, dict) else []
    speaker_segments = _parse_speaker_segments(entries)

    # Default to 2 speakers (the job is created with num_speakers=2) when no segments exist.
    num_speakers = len({s['speaker'] for s in speaker_segments}) if speaker_segments else 2
    # Max end rather than the last segment's end: robust to unordered entries.
    duration = max((s['end'] for s in speaker_segments), default=0)

    conn = pymysql.connect(**DEST_DB)
    try:
        # DEST_DB enables autocommit, so open an explicit transaction to keep
        # the response row and the call-status update atomic.
        conn.begin()
        with conn.cursor() as cursor:
            cursor.execute("""
            INSERT INTO 6004_sarvamresponse
            (callid, transcript, speaker_segments, duration, num_speakers, request_id, raw_response, stt_provider, created_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, NOW())
            ON DUPLICATE KEY UPDATE
            transcript = VALUES(transcript), speaker_segments = VALUES(speaker_segments),
            duration = VALUES(duration), num_speakers = VALUES(num_speakers),
            raw_response = VALUES(raw_response), updated_at = NOW()
        """, (callid, transcript_text, json.dumps(speaker_segments), duration, num_speakers,
              request_id, raw_response, 'sarvam'))

            cursor.execute(
                "UPDATE 6004_raw_calls SET transcription_status = 'completed', status = 1 WHERE callid = %s",
                (callid,)
            )
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()

    print(f"[4/4] Saved: {len(speaker_segments)} segments, duration={duration:.1f}s, transcript={len(transcript_text)} chars")
    return transcript_text, speaker_segments


def main():
    """Transcribe CALLID end to end: look up its audio URL, run Sarvam, save the result.

    Prints errors to stderr and exits with status 1 in two cases: the call has
    no file URL, or Sarvam returns no transcript. Other failures propagate as
    exceptions. A copy of the raw Sarvam JSON is written to
    /tmp/sarvam_<callid>.json.
    """
    preview_chars = 2000

    print(f"{'='*60}")
    print(f"  SARVAM TRANSCRIPTION — BID {BID}")
    print(f"  Call ID: {CALLID}")
    print(f"{'='*60}")

    audio_url = get_call_url(CALLID)
    if not audio_url:
        print(f"ERROR: No fileurl found for callid {CALLID}", file=sys.stderr)
        sys.exit(1)
    print(f"  File URL: {audio_url}\n")

    result = transcribe_call(audio_url, CALLID)
    if not isinstance(result, dict) or not result.get('transcript'):
        print("ERROR: Transcription returned empty result", file=sys.stderr)
        sys.exit(1)

    # Save raw result
    raw_path = f'/tmp/sarvam_{CALLID}.json'
    with open(raw_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=2)
    print(f"\n  Raw JSON saved to {raw_path}")

    transcript, segments = save_transcript(CALLID, result)

    print(f"\n{'='*60}")
    print("  TRANSCRIPT PREVIEW (first 2000 chars)")
    print(f"{'='*60}")
    print(transcript[:preview_chars])
    if len(transcript) > preview_chars:
        print(f"\n  ... [{len(transcript)-preview_chars} more chars]")
    print(f"\n  Total segments: {len(segments)}")
    print("  SUCCESS — Transcript saved to 6004_sarvamresponse")


if __name__ == '__main__':
    main()
