from __future__ import annotations
import json
import logging
from typing import Any, Dict, List, Optional
from db.connection import get_connection

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Values written to the ``stt_jobs.status`` column by the functions below:
#   insert_job                  -> pending
#   mark_processing             -> processing
#   mark_done                   -> done
#   mark_failed                 -> failed
#   reset_stale_processing_jobs -> processing rows back to pending
STATUS_PENDING = "pending"
STATUS_PROCESSING = "processing"
STATUS_DONE = "done"
STATUS_FAILED = "failed"

def insert_job(bid: str, call_id: str, recording_url: str, metadata: Optional[Dict] = None) -> Optional[int]:
    """Enqueue a pending STT job and return its row id.

    The row is written with ``INSERT IGNORE`` and status ``STATUS_PENDING``.
    ``metadata`` is stored JSON-encoded, and ``{}`` is used when it is falsy
    (e.g. ``None``). When the insert is skipped (rowcount not positive), the
    id of an existing row with the same ``bid``/``call_id`` is looked up
    instead.

    Args:
        bid: Value for the ``bid`` column.
        call_id: Value for the ``call_id`` column.
        recording_url: Value for the ``recording_url`` column.
        metadata: Optional JSON-serializable mapping for the ``metadata`` column.

    Returns:
        The new row id, the id of a pre-existing ``(bid, call_id)`` row, or
        ``None`` when the insert was skipped and no such row exists.
    """
    with get_connection() as conn:
        with conn.cursor() as cur:
            payload = json.dumps(metadata or {})
            cur.execute(
                "INSERT IGNORE INTO stt_jobs "
                "(bid, call_id, recording_url, status, metadata, created_at, updated_at) "
                "VALUES (%s, %s, %s, %s, %s, NOW(), NOW())",
                (bid, call_id, recording_url, STATUS_PENDING, payload),
            )
            if cur.rowcount <= 0:
                # The insert was skipped. Presumably a unique key on (bid, call_id)
                # made it a duplicate. NOTE(review): confirm against the schema.
                # INSERT IGNORE also downgrades some non-duplicate errors to
                # warnings. In that case this lookup finds nothing and None is
                # returned.
                cur.execute("SELECT id FROM stt_jobs WHERE bid=%s AND call_id=%s", (bid, call_id))
                existing = cur.fetchone()
                if not existing:
                    return None
                return existing["id"]
            return cur.lastrowid

def mark_processing(job_id: int) -> None:
    """Move job ``job_id`` into ``STATUS_PROCESSING`` via ``_update_status``.

    The update is unconditional: the row's current status is not checked.
    """
    _update_status(job_id=job_id, status=STATUS_PROCESSING)

def mark_done(job_id: int, transcript: str, speaker_segments: List[Dict],
              speaker_count: int, duration: float, stt_provider: str) -> None:
    """Record a successful transcription for job ``job_id``.

    Sets the status to ``STATUS_DONE`` and stores the transcript, the
    JSON-encoded speaker segments, the speaker count, ``duration`` (in the
    ``duration_seconds`` column) and the STT provider name. It also clears
    any earlier ``error_message``. ``retry_count`` is not modified.
    """
    sql = (
        "UPDATE stt_jobs SET status=%s, transcript=%s, speaker_segments=%s, "
        "speaker_count=%s, duration_seconds=%s, stt_provider=%s, "
        "error_message=NULL, updated_at=NOW() WHERE id=%s"
    )
    with get_connection() as conn:
        with conn.cursor() as cur:
            segments_json = json.dumps(speaker_segments)
            params = (
                STATUS_DONE,
                transcript,
                segments_json,
                speaker_count,
                duration,
                stt_provider,
                job_id,
            )
            cur.execute(sql, params)

def mark_failed(job_id: int, error: str, increment_retry: bool = True) -> None:
    """Mark job ``job_id`` as ``STATUS_FAILED`` and store the error text.

    Args:
        job_id: Row id in ``stt_jobs``.
        error: Error description. Only its first 2000 characters are stored.
        increment_retry: When truthy (the default), ``retry_count`` is
            incremented by one. ``get_retryable_failed_jobs`` compares that
            column against ``max_retries``.
    """
    with get_connection() as conn:
        with conn.cursor() as cur:
            # The SET list only ever contains fixed fragments chosen here.
            # All caller data goes through placeholders.
            assignments = ["status=%s", "error_message=%s"]
            if increment_retry:
                assignments.append("retry_count = retry_count + 1")
            assignments.append("updated_at=NOW()")
            statement = "UPDATE stt_jobs SET " + ", ".join(assignments) + " WHERE id=%s"
            truncated_error = error[:2000]
            cur.execute(statement, (STATUS_FAILED, truncated_error, job_id))

def reset_stale_processing_jobs(stale_after_minutes: int = 30) -> int:
    """Return stuck ``processing`` jobs to the pending state.

    A job is considered stale when its status is still ``STATUS_PROCESSING``
    and its ``updated_at`` is older than ``stale_after_minutes`` minutes. The
    age is measured against the database's ``NOW()``. Only ``status`` and
    ``updated_at`` are changed, and ``retry_count`` is left as is.

    Returns:
        The row count reported by the cursor for the UPDATE.
    """
    with get_connection() as conn:
        with conn.cursor() as cur:
            params = (STATUS_PENDING, STATUS_PROCESSING, stale_after_minutes)
            cur.execute(
                "UPDATE stt_jobs SET status=%s, updated_at=NOW() "
                "WHERE status=%s AND updated_at < DATE_SUB(NOW(), INTERVAL %s MINUTE)",
                params,
            )
            affected = cur.rowcount
    return affected

def get_all_seen_call_ids(bid: str) -> List[str]:
    """Return the ``call_id`` of every job row whose ``bid`` equals ``bid``.

    All matching rows are fetched at once, with no ordering or limit applied.
    """
    with get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT call_id FROM stt_jobs WHERE bid=%s", (bid,))
            fetched = cur.fetchall()
    # Rows are already materialized, so they are unpacked after the connection is released.
    return [record["call_id"] for record in fetched]

def get_pending_jobs(bid: Optional[str] = None, limit: int = 10) -> List[Dict[str, Any]]:
    """Fetch up to ``limit`` jobs in ``STATUS_PENDING``, oldest ``created_at`` first.

    Args:
        bid: When not ``None``, only rows with this exact ``bid`` are returned.
            ``None`` (the default) applies no ``bid`` filter.
        limit: Maximum number of rows to return.

    Returns:
        The matching ``stt_jobs`` rows as returned by the cursor's ``fetchall()``.

    NOTE(review): rows are only read, not claimed (no status change or row
    lock happens here). Two concurrent callers can therefore receive the same
    job. Confirm that callers serialize around this.
    """
    # Test against None explicitly. A present-but-falsy bid such as "" must
    # still act as a filter instead of silently widening the query to every bid.
    if bid is not None:
        sql = "SELECT * FROM stt_jobs WHERE status=%s AND bid=%s ORDER BY created_at ASC LIMIT %s"
        params: tuple = (STATUS_PENDING, bid, limit)
    else:
        sql = "SELECT * FROM stt_jobs WHERE status=%s ORDER BY created_at ASC LIMIT %s"
        params = (STATUS_PENDING, limit)

    with get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(sql, params)
            return cur.fetchall()

def get_retryable_failed_jobs(max_retries: int, limit: int = 10) -> List[Dict[str, Any]]:
    """Fetch up to ``limit`` failed jobs that still have retries left.

    A job qualifies when its status is ``STATUS_FAILED`` and
    ``retry_count < max_retries``. The least recently updated jobs come first.
    NOTE(review): a NULL ``retry_count`` never satisfies ``<`` in SQL, so
    confirm the column is NOT NULL if such rows should be retried.
    """
    query = (
        "SELECT * FROM stt_jobs WHERE status=%s AND retry_count<%s "
        "ORDER BY updated_at ASC LIMIT %s"
    )
    with get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(query, (STATUS_FAILED, max_retries, limit))
            rows = cur.fetchall()
    return rows

def get_job_stats() -> Dict[str, int]:
    """Count ``stt_jobs`` rows per status.

    Returns:
        A mapping from each status value present in the table to its row
        count. Statuses with no rows are absent rather than mapped to 0.
    """
    with get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT status, COUNT(*) as cnt FROM stt_jobs GROUP BY status")
            grouped = cur.fetchall()
    return {entry["status"]: entry["cnt"] for entry in grouped}

def _update_status(job_id: int, status: str) -> None:
    """Set ``status`` on job ``job_id`` and refresh its ``updated_at``.

    Unconditional: the row's current status is not checked before overwriting.
    """
    statement = "UPDATE stt_jobs SET status=%s, updated_at=NOW() WHERE id=%s"
    with get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(statement, (status, job_id))
