from __future__ import annotations
import json
import logging
import time
import pika
from typing import Any, Dict, List, Optional
from config.settings import settings
from db import raw_calls, stt_jobs, bid_config

# Module-level logger, named after this module so log output can be filtered per module.
logger = logging.getLogger(__name__)

class RabbitMQJobProducer:
    """Discover new call recordings and publish them as STT jobs to RabbitMQ.

    Each cycle reads the enabled BID configurations (falling back to
    ``settings.bid_whitelist``), asks ``raw_calls`` for calls not yet seen by
    ``stt_jobs``, inserts one job row per call and publishes a persistent
    JSON message to the durable queue ``settings.rabbitmq_queue``.

    The producer owns a ``pika.BlockingConnection``. It can be used as a
    context manager so the connection is closed deterministically::

        with RabbitMQJobProducer() as producer:
            producer.run_once()
    """

    def __init__(self) -> None:
        """Prepare the BID config table, register known BIDs and connect.

        Raises:
            Exception: whatever ``pika`` raises when the broker cannot be
                reached or the queue cannot be declared. Failures of the two
                database preparation steps are only logged.
        """
        self.batch_size = settings.batch_size
        self.queue_name = settings.rabbitmq_queue

        # Defined before any I/O so close() is safe even if connecting fails.
        self.connection = None
        self.channel = None

        # Ensure config table exists
        self._ensure_bid_config_table()
        # Auto-register discovered BIDs
        self._auto_register_bids()

        self._connect()

        logger.info("RabbitMQ Producer initialized | queue=%s", self.queue_name)

    def __enter__(self):
        """Return the producer itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the RabbitMQ connection; exceptions are never suppressed."""
        self.close()

    def _connect(self) -> None:
        """Open a connection and channel and declare the durable queue.

        ``self.connection`` / ``self.channel`` are only replaced once the whole
        setup succeeded; a half-initialised connection is closed again so the
        socket does not leak.
        """
        connection = pika.BlockingConnection(
            pika.ConnectionParameters(
                host=settings.rabbitmq_host,
                port=settings.rabbitmq_port,
                credentials=pika.PlainCredentials(
                    settings.rabbitmq_user,
                    settings.rabbitmq_password,
                ),
            )
        )
        try:
            channel = connection.channel()
            channel.queue_declare(queue=self.queue_name, durable=True)
        except Exception:
            if connection.is_open:
                try:
                    connection.close()
                except Exception:
                    logger.debug(
                        "Ignoring error while closing half-initialized connection",
                        exc_info=True,
                    )
            raise

        self.connection = connection
        self.channel = channel

    def _ensure_connection(self) -> None:
        """Reconnect when the connection or the channel is no longer open."""
        connection = getattr(self, "connection", None)
        channel = getattr(self, "channel", None)
        if (
            connection is not None
            and connection.is_open
            and channel is not None
            and channel.is_open
        ):
            return

        logger.warning("RabbitMQ connection is not open; reconnecting")
        try:
            # Covers the "connection open, channel closed" case.
            self.close()
        except Exception:
            logger.debug(
                "Ignoring error while closing stale RabbitMQ connection",
                exc_info=True,
            )
        self._connect()

    def _sleep(self, seconds) -> None:
        """Wait ``seconds`` while keeping the AMQP connection serviced.

        ``time.sleep`` would block the BlockingConnection's I/O loop, so no
        heartbeats are exchanged and the broker may drop the connection during
        long poll intervals. ``BlockingConnection.sleep`` processes frames
        while waiting. If the connection is missing, closed, or fails while
        idle, the remaining time is spent in a plain ``time.sleep``.
        """
        deadline = time.monotonic() + seconds
        connection = getattr(self, "connection", None)
        if connection is not None and connection.is_open:
            try:
                connection.sleep(seconds)
            except Exception as exc:
                # The next publish attempt reconnects via _ensure_connection().
                logger.warning("RabbitMQ connection lost while idle: %s", exc)

        remaining = deadline - time.monotonic()
        if remaining > 0:
            time.sleep(remaining)

    def _ensure_bid_config_table(self) -> None:
        """Best-effort creation of the BID config table; failures are logged."""
        try:
            bid_config.ensure_table()
        except Exception as exc:
            logger.warning("Could not ensure stt_pipeline_bid_config table: %s", exc)

    def _auto_register_bids(self) -> None:
        """Best-effort registration of every BID found in ``raw_calls``."""
        try:
            all_bids = raw_calls.get_all_bids()
            for bid in all_bids:
                bid_config.ensure_bid_registered(bid)
            if all_bids:
                logger.debug(
                    "Auto-registered %d bid(s) in stt_pipeline_bid_config",
                    len(all_bids),
                )
        except Exception as exc:
            logger.warning("Auto-register bids failed: %s", exc)

    def run_forever(self) -> None:
        """Run producer cycles in an endless loop.

        Errors raised by a cycle are logged and the loop continues after
        waiting ``settings.poll_interval_seconds``.
        """
        logger.info("Producer started | poll=%ds | batch=%d",
                    settings.poll_interval_seconds, self.batch_size)
        while True:
            try:
                self._run_one_cycle()
            except Exception as exc:
                logger.exception("Unexpected error in producer cycle: %s", exc)

            logger.info("Sleeping %ds...", settings.poll_interval_seconds)
            self._sleep(settings.poll_interval_seconds)

    def run_once(self) -> None:
        """Run a single producer cycle; exceptions propagate to the caller."""
        logger.info("Running producer cycle once...")
        self._run_one_cycle()

    def _run_one_cycle(self) -> None:
        """Publish jobs for every active BID configuration once.

        Uses the enabled configurations from ``bid_config``; when there are
        none, default configurations are synthesised from
        ``settings.bid_whitelist``. Returns without doing anything when both
        are empty.
        """
        active_configs = bid_config.get_enabled_bids()
        if not active_configs:
            # Fallback to whitelist if nothing enabled in DB
            if not settings.bid_whitelist:
                logger.debug("No enabled BIDs and no whitelist — nothing to do")
                return
            active_configs = [
                {
                    "bid": b,
                    "raw_calls_id_col": "id",
                    "raw_calls_url_col": "recording_url",
                    "batch_size": self.batch_size,
                }
                for b in settings.bid_whitelist
            ]

        total_produced = 0
        for cfg in active_configs:
            # Empty / missing per-BID values fall back to the defaults.
            total_produced += self._discover_and_publish(
                cfg["bid"],
                cfg.get("raw_calls_id_col") or "id",
                cfg.get("raw_calls_url_col") or "recording_url",
                cfg.get("batch_size") or self.batch_size,
            )

        if total_produced > 0:
            logger.info("Cycle complete: Produced %d new job(s)", total_produced)
        else:
            logger.debug("Cycle complete: No new jobs found")

    def _discover_and_publish(self, bid: str, id_col: str, url_col: str, batch: int) -> int:
        """Insert and publish a job for every new call of ``bid``.

        Args:
            bid: Identifier whose calls are scanned.
            id_col: Column name passed to ``raw_calls.get_new_calls`` as ``id_col``.
            url_col: Column name passed to ``raw_calls.get_new_calls`` as ``url_col``.
            batch: Batch size; ``batch * 2`` candidate calls are requested and
                every candidate with a recording URL is published (no further
                cap is applied here).

        Returns:
            Number of jobs published.

        Raises:
            Exception: errors from the database helpers or from publishing.
        """
        already_seen = stt_jobs.get_all_seen_call_ids(bid)

        new_calls = raw_calls.get_new_calls(
            bid=bid,
            already_seen_ids=already_seen,
            limit=batch * 2,
            id_col=id_col,
            url_col=url_col,
        )

        count = 0
        for call in new_calls:
            recording_url = call.get("recording_url")
            if not recording_url:
                continue

            call_id = call["call_id"]
            # Every remaining, non-None field travels along as a string.
            metadata = {
                k: str(v)
                for k, v in call.items()
                if k not in ("call_id", "recording_url") and v is not None
            }

            # Check the broker link BEFORE inserting the job row, so a dead
            # connection does not leave rows behind that were never published.
            self._ensure_connection()

            job_id = stt_jobs.insert_job(
                bid=bid,
                call_id=call_id,
                recording_url=recording_url,
                metadata=metadata,
            )
            if not job_id:
                # NOTE(review): presumably the job already exists — confirm
                # against stt_jobs.insert_job.
                continue

            job_payload = {
                "job_id": job_id,
                "bid": bid,
                "call_id": call_id,
                "recording_url": recording_url,
                "metadata": metadata,
            }

            try:
                self.channel.basic_publish(
                    exchange='',
                    routing_key=self.queue_name,
                    body=json.dumps(job_payload),
                    properties=pika.BasicProperties(
                        delivery_mode=2,  # persistent message
                    ),
                )
            except Exception:
                # NOTE(review): the row now exists without a queued message and
                # is presumably reported as "seen" on later cycles — confirm
                # how such rows are recovered.
                logger.error(
                    "[%s/%s] job_id=%s was inserted but could not be published",
                    bid, call_id, job_id,
                )
                raise

            count += 1
            logger.debug("[%s/%s] Published job_id=%s", bid, call_id, job_id)

        return count

    def close(self) -> None:
        """Close the RabbitMQ connection if one exists and is still open."""
        connection = getattr(self, "connection", None)
        if connection is not None and connection.is_open:
            connection.close()
