#!/usr/bin/env python3
"""
Continuous RAG transcript ingestion worker.

- Scans all *_sarvamresponse tables.
- Processes business IDs passed via --priority-bids (e.g. 6004) before all others; none are prioritized by default.
- Writes run progress to MySQL.
"""

import argparse
import json
import logging
import re
import time
from datetime import datetime, timezone
from typing import Dict, List

import pymysql
from pymysql.cursors import DictCursor

from config import Config
from rag_handler import RAGHandler


# Process-wide logging: INFO and above, timestamped, tagged with logger name.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name="rag_ingestion_worker")


class RagIngestionWorker:
    def __init__(
        self,
        config: Dict,
        limit_per_bid: int = 2000,
        presales_only: bool = True,
        overwrite_existing: bool = False,
        interval_seconds: int = 300,
        priority_bids: List[str] | None = None,
        run_once: bool = False,
    ):
        self.config = config
        self.limit_per_bid = int(limit_per_bid)
        self.presales_only = bool(presales_only)
        self.overwrite_existing = bool(overwrite_existing)
        self.interval_seconds = int(interval_seconds)
        self.priority_bids = [str(b).strip() for b in (priority_bids or []) if str(b).strip()]
        self.run_once = bool(run_once)
        self.rag = RAGHandler(config)
        self.db_config = {
            "host": config.get("DB_HOST", "127.0.0.1"),
            "port": int(config.get("DB_PORT", 3306)),
            "user": config.get("DB_USER", "admin"),
            "password": config.get("DB_PASSWORD", ""),
            "database": config.get("DB_NAME", "voicebot_cluster"),
            "charset": "utf8mb4",
            "cursorclass": DictCursor,
            "autocommit": True,
        }
        self._ensure_progress_tables()

    def _conn(self):
        return pymysql.connect(**self.db_config)

    def _ensure_progress_tables(self):
        conn = self._conn()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS rag_ingestion_progress (
                    id BIGINT AUTO_INCREMENT PRIMARY KEY,
                    bid VARCHAR(50) NOT NULL,
                    status ENUM('idle','running','success','error') DEFAULT 'idle',
                    last_run_started_at DATETIME NULL,
                    last_run_finished_at DATETIME NULL,
                    last_duration_ms BIGINT DEFAULT 0,
                    processed_calls INT DEFAULT 0,
                    ingested_documents INT DEFAULT 0,
                    ingested_chunks INT DEFAULT 0,
                    skipped INT DEFAULT 0,
                    last_error TEXT,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                    UNIQUE KEY uniq_bid (bid),
                    INDEX idx_updated_at (updated_at)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
                """
            )
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS rag_ingestion_runs (
                    id BIGINT AUTO_INCREMENT PRIMARY KEY,
                    bid VARCHAR(50) NOT NULL,
                    started_at DATETIME NOT NULL,
                    finished_at DATETIME NULL,
                    duration_ms BIGINT DEFAULT 0,
                    status ENUM('running','success','error') DEFAULT 'running',
                    processed_calls INT DEFAULT 0,
                    ingested_documents INT DEFAULT 0,
                    ingested_chunks INT DEFAULT 0,
                    skipped INT DEFAULT 0,
                    details JSON,
                    error TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    INDEX idx_bid_created (bid, created_at),
                    INDEX idx_status_created (status, created_at)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
                """
            )
        finally:
            conn.close()

    def _discover_bids(self) -> List[str]:
        conn = self._conn()
        try:
            cursor = conn.cursor()
            cursor.execute("SHOW TABLES LIKE %s", ("%_sarvamresponse",))
            rows = cursor.fetchall()
        finally:
            conn.close()
        bids = []
        for row in rows:
            table_name = list(row.values())[0]
            match = re.match(r"^([A-Za-z0-9_]+)_sarvamresponse$", table_name)
            if match:
                bids.append(match.group(1))
        bids = sorted(set(bids))
        return bids

    def _schedule_bids(self, bids: List[str]) -> List[str]:
        priority = [b for b in self.priority_bids if b in bids]
        rest = [b for b in bids if b not in priority]
        return priority + rest

    def _mark_running(self, bid: str, started_at: datetime) -> int:
        conn = self._conn()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO rag_ingestion_progress (bid, status, last_run_started_at, last_error)
                VALUES (%s, 'running', %s, NULL)
                ON DUPLICATE KEY UPDATE
                    status='running',
                    last_run_started_at=VALUES(last_run_started_at),
                    last_error=NULL
                """,
                (str(bid), started_at),
            )
            cursor.execute(
                """
                INSERT INTO rag_ingestion_runs (bid, started_at, status)
                VALUES (%s, %s, 'running')
                """,
                (str(bid), started_at),
            )
            return int(cursor.lastrowid)
        finally:
            conn.close()

    def _mark_finished(self, bid: str, run_id: int, started_at: datetime, status: str, result: Dict, error: str | None = None):
        finished_at = datetime.utcnow()
        duration_ms = int((finished_at - started_at).total_seconds() * 1000)
        processed_calls = int(result.get("processed_calls", 0) or 0)
        ingested_docs = int(result.get("ingested_documents", 0) or 0)
        ingested_chunks = int(result.get("ingested_chunks", 0) or 0)
        skipped = int(result.get("skipped", 0) or 0)

        conn = self._conn()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                UPDATE rag_ingestion_progress
                SET status=%s,
                    last_run_finished_at=%s,
                    last_duration_ms=%s,
                    processed_calls=%s,
                    ingested_documents=%s,
                    ingested_chunks=%s,
                    skipped=%s,
                    last_error=%s
                WHERE bid=%s
                """,
                (
                    status,
                    finished_at,
                    duration_ms,
                    processed_calls,
                    ingested_docs,
                    ingested_chunks,
                    skipped,
                    error,
                    str(bid),
                ),
            )
            cursor.execute(
                """
                UPDATE rag_ingestion_runs
                SET finished_at=%s,
                    duration_ms=%s,
                    status=%s,
                    processed_calls=%s,
                    ingested_documents=%s,
                    ingested_chunks=%s,
                    skipped=%s,
                    details=%s,
                    error=%s
                WHERE id=%s
                """,
                (
                    finished_at,
                    duration_ms,
                    status,
                    processed_calls,
                    ingested_docs,
                    ingested_chunks,
                    skipped,
                    json.dumps(result or {}, ensure_ascii=True),
                    error,
                    int(run_id),
                ),
            )
        finally:
            conn.close()

    def run_iteration(self):
        bids = self._schedule_bids(self._discover_bids())
        if not bids:
            logger.info("No *_sarvamresponse tables found.")
            return
        logger.info("Discovered %s bids to ingest: %s", len(bids), ", ".join(bids))

        for bid in bids:
            started_at = datetime.utcnow()
            run_id = self._mark_running(bid, started_at)
            logger.info("Starting bid=%s (run_id=%s)", bid, run_id)
            try:
                result = self.rag.backfill_transcripts(
                    bid=bid,
                    presales_only=self.presales_only,
                    limit=self.limit_per_bid,
                    overwrite_existing=self.overwrite_existing,
                )
                self._mark_finished(bid, run_id, started_at, "success", result=result, error=None)
                logger.info(
                    "Completed bid=%s calls=%s docs=%s chunks=%s skipped=%s",
                    bid,
                    result.get("processed_calls", 0),
                    result.get("ingested_documents", 0),
                    result.get("ingested_chunks", 0),
                    result.get("skipped", 0),
                )
            except Exception as exc:
                err = str(exc)
                self._mark_finished(
                    bid,
                    run_id,
                    started_at,
                    "error",
                    result={"processed_calls": 0, "ingested_documents": 0, "ingested_chunks": 0, "skipped": 0},
                    error=err,
                )
                logger.exception("Ingestion failed for bid=%s: %s", bid, err)

    def run_forever(self):
        while True:
            self.run_iteration()
            if self.run_once:
                return
            logger.info("Sleeping %s seconds before next scan.", self.interval_seconds)
            time.sleep(self.interval_seconds)


def main():
    """Parse command-line options, build the ingestion worker and run it.

    Blank entries in ``--priority-bids`` are dropped. The worker keeps looping
    until interrupted unless ``--run-once`` is given.
    """
    parser = argparse.ArgumentParser(description="Continuous transcript ingestion into RAG")
    option_specs = (
        ("--limit-per-bid", {"type": int, "default": 2000}),
        ("--presales-only", {"action": "store_true"}),
        ("--overwrite-existing", {"action": "store_true"}),
        ("--interval-seconds", {"type": int, "default": 300}),
        ("--priority-bids", {"default": "", "help": "Comma separated bids to prioritize"}),
        ("--run-once", {"action": "store_true"}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    priority_bids = []
    for token in str(args.priority_bids).split(","):
        token = token.strip()
        if token:
            priority_bids.append(token)

    worker = RagIngestionWorker(
        config=Config.__dict__,
        limit_per_bid=args.limit_per_bid,
        presales_only=args.presales_only,
        overwrite_existing=args.overwrite_existing,
        interval_seconds=args.interval_seconds,
        priority_bids=priority_bids,
        run_once=args.run_once,
    )
    worker.run_forever()


if __name__ == "__main__":
    # Start the worker only when this file is executed as a script, not on import.
    main()
