#!/usr/bin/env python3
"""
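Keep RabbitMQ STT workers running for the shared stt_jobs queue.

All pipeline-enabled BIDs (6004, 8329, 8398, …) publish to the same queue; workers
process jobs for any BID based on the job payload. The supervisor runs at least
STT_WORKER_COUNT workers and scales up to STT_WORKER_MAX as the queue backlog
grows. Run via systemd or cron.

  python stt_worker_supervisor.py            # loop every STT_SUPERVISOR_INTERVAL_SEC
  python stt_worker_supervisor.py --once     # single sync
  python stt_worker_supervisor.py --restart  # force-restart all workers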
"""
from __future__ import annotations

import fcntl
import logging
import os
import signal
import subprocess
import sys
import time

from dotenv import load_dotenv

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(BASE_DIR)
STT_DIR = os.path.join(PROJECT_ROOT, "call-proccessing", "stt_pipeline")

# Load the .env files before reading ANY setting from the environment.
# load_dotenv does not override variables that are already set, so the real
# process environment wins, then BASE_DIR/.env, then the stt_pipeline .env.
load_dotenv(os.path.join(BASE_DIR, ".env"))
load_dotenv(os.path.join(STT_DIR, ".env"), override=False)

# Read after load_dotenv so STT_SUPERVISOR_LOCK may also be set in a .env file
# (it used to be read first, which silently ignored .env values).
SUPERVISOR_LOCK = os.getenv("STT_SUPERVISOR_LOCK", "/tmp/stt_worker_supervisor.lock")

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger("stt_worker_supervisor")

def _env_int(name: str, default: str, floor: int) -> int:
    """Parse integer env var *name* (falling back to *default*), clamped to >= *floor*."""
    return max(floor, int(os.getenv(name, default)))


# Baseline number of workers (at least 1).
STT_WORKER_COUNT = _env_int("STT_WORKER_COUNT", "4", 1)
# Upper bound when scaling for backlog; never below STT_WORKER_COUNT.
STT_WORKER_MAX = _env_int("STT_WORKER_MAX", "8", STT_WORKER_COUNT)
# Queue depth at which scale-up starts (at least 5).
STT_QUEUE_SCALE_THRESHOLD = _env_int("STT_QUEUE_SCALE_THRESHOLD", "15", 5)
# One extra worker per this many queued messages (at least 5, so never zero).
STT_QUEUE_WORKERS_PER_MESSAGES = _env_int("STT_QUEUE_WORKERS_PER_MESSAGES", "12", 5)
# Seconds between sync passes in loop mode (at least 10).
STT_SUPERVISOR_INTERVAL_SEC = _env_int("STT_SUPERVISOR_INTERVAL_SEC", "60", 10)
# Stuck-call healing stays on unless the flag is one of these (case-insensitive).
_HEAL_DISABLED_VALUES = frozenset({"0", "false", "no"})
STT_STUCK_HEAL_ENABLED = (
    os.getenv("STT_STUCK_HEAL_ENABLED", "1").lower() not in _HEAL_DISABLED_VALUES
)
# `pgrep -f` patterns (extended regex) identifying STT worker processes: the
# absolute script path this supervisor launches, plus relative-cwd and legacy
# invocations.
_WORKER_SCRIPT = os.path.join(STT_DIR, "run.py")
WORKER_PGREP_PATTERNS = (
    f"{_WORKER_SCRIPT} --worker",
    "stt_pipeline/run.py --worker",
    "stt_pipeline/venv/bin/python run.py --worker",
    "stt_pipeline/venv/bin/python.*run.py --worker",
    r"stt_pipeline.*run\.py --worker",
)


def acquire_supervisor_lock() -> int:
    """Take the single-instance flock on SUPERVISOR_LOCK, or exit quietly.

    The lock file is created if missing and locked exclusively without
    blocking. The lock is held for as long as the returned descriptor stays
    open.

    Returns:
        The raw file descriptor that holds the lock.

    If the lock is already held, an error is logged and the process exits via
    ``sys.exit(0)`` (a zero status, not a failure code).
    """
    lock_fd = os.open(SUPERVISOR_LOCK, os.O_RDWR | os.O_CREAT)
    try:
        fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
    except BlockingIOError:
        logger.error("Another STT supervisor is already running (%s); exiting", SUPERVISOR_LOCK)
        sys.exit(0)
    else:
        return lock_fd


def worker_python() -> str:
    """Return the interpreter used to launch STT workers.

    Prefers ``<STT_DIR>/venv/bin/python``, then ``<BASE_DIR>/venv/bin/python``,
    and falls back to plain ``python3`` (resolved via PATH) if neither exists.
    """
    venv_pythons = (
        os.path.join(root, "venv", "bin", "python") for root in (STT_DIR, BASE_DIR)
    )
    return next((path for path in venv_pythons if os.path.isfile(path)), "python3")


def list_worker_pids() -> list[int]:
    """Return the sorted PIDs of running STT worker processes.

    Every pattern in WORKER_PGREP_PATTERNS is matched against full command
    lines with ``pgrep -f``. The results are unioned, so a process matching
    several patterns is reported once.

    This is best-effort: a pattern whose pgrep call fails is skipped. Failures
    are logged rather than swallowed, because callers decide how many workers
    to start from this list. A missing or hung ``pgrep`` that silently returns
    nothing would be indistinguishable from "no workers running".
    """
    pids: set[int] = set()
    for pattern in WORKER_PGREP_PATTERNS:
        try:
            result = subprocess.run(
                ["pgrep", "-f", pattern],
                capture_output=True,
                text=True,
                timeout=10,
            )
        except Exception as exc:
            logger.warning("pgrep failed for pattern %r: %s", pattern, exc)
            continue
        # pgrep exits 0 when something matched and 1 when nothing did; any
        # other status means the pattern or pgrep itself is broken.
        if result.returncode == 0:
            # isdecimal() guarantees int() accepts the token.
            pids.update(int(tok) for tok in result.stdout.split() if tok.isdecimal())
        elif result.returncode != 1:
            logger.warning(
                "pgrep exited %s for pattern %r: %s",
                result.returncode,
                pattern,
                result.stderr.strip(),
            )
    return sorted(pids)


def stop_worker(pid: int, *, force: bool = False) -> bool:
    """Send SIGTERM (or SIGKILL when *force* is set) to a single worker PID.

    Returns:
        True if the signal was delivered. False if the PID is not positive,
        the process no longer exists, or sending failed (the last case is
        logged).
    """
    # os.kill treats 0 and negative PIDs as group/broadcast targets (0 = our own
    # process group, -1 = every process we may signal), so never pass them on.
    if pid <= 0:
        logger.warning("Refusing to signal non-positive pid=%s", pid)
        return False
    try:
        os.kill(pid, signal.SIGKILL if force else signal.SIGTERM)
        return True
    except ProcessLookupError:
        return False
    except Exception as exc:
        logger.warning("Failed to stop worker pid=%s: %s", pid, exc)
        return False


def stop_all_workers(force: bool = False) -> list[int]:
    """Stop every STT worker PID we can find (including legacy relative-path workers).

    With ``force=False``, each worker first gets SIGTERM. Anything still
    matching after 2 seconds is sent SIGKILL, followed by a further 1 second
    pause. With ``force=True``, SIGKILL is sent immediately and nothing is
    waited for.

    Returns:
        The PIDs that were successfully signalled. Each PID is listed once,
        even if it needed both SIGTERM and SIGKILL.
    """
    stopped: list[int] = []
    for pid in list_worker_pids():
        if stop_worker(pid, force=force):
            stopped.append(pid)
    if stopped and not force:
        time.sleep(2)  # grace period after SIGTERM
        survivors = list_worker_pids()
        if survivors:
            logger.warning(
                "STT workers still running after SIGTERM (%s); sending SIGKILL",
                survivors,
            )
            for pid in survivors:
                # Always signal, but don't report the same PID twice.
                if stop_worker(pid, force=True) and pid not in stopped:
                    stopped.append(pid)
            time.sleep(1)
    return stopped


def rabbitmq_queue_stats() -> dict:
    """Passively inspect the STT queue and report its depth and consumer count.

    Connection settings come from RABBITMQ_HOST/PORT/USER/PASSWORD, and the
    queue name from RABBITMQ_QUEUE (default ``stt_jobs``). The declare is
    passive, so the queue is never created here.

    Returns:
        ``{"queue", "messages", "consumers"}`` on success. On any failure,
        including pika not being installed or a bad port value,
        ``{"error": str}``.

    The connection is always closed, even when opening the channel or the
    passive declare raises. Previously it leaked on every failed call.
    """
    conn = None
    try:
        import pika

        host = os.getenv("RABBITMQ_HOST", "localhost")
        port = int(os.getenv("RABBITMQ_PORT", "5672"))
        user = os.getenv("RABBITMQ_USER", "guest")
        password = os.getenv("RABBITMQ_PASSWORD", "guest")
        queue = os.getenv("RABBITMQ_QUEUE", "stt_jobs")
        conn = pika.BlockingConnection(
            pika.ConnectionParameters(
                host=host,
                port=port,
                credentials=pika.PlainCredentials(user, password),
            )
        )
        ch = conn.channel()
        q = ch.queue_declare(queue=queue, durable=True, passive=True)
        return {
            "queue": queue,
            "messages": int(q.method.message_count),
            "consumers": int(q.method.consumer_count),
        }
    except Exception as exc:
        return {"error": str(exc)}
    finally:
        if conn is not None and conn.is_open:
            try:
                conn.close()
            except Exception:
                # Best-effort cleanup; the result (stats or error) is already decided.
                pass


def start_worker(slot: int) -> dict:
    """Launch one STT worker (``run.py --worker``) in a new session.

    Both stdout and stderr are appended to ``/tmp/stt_worker_<slot>.log``. The
    child receives a copy of this process's environment with
    ``PYTHONUNBUFFERED=1``, so its output is not held in Python's buffers.

    Returns:
        ``{"started": True, "pid", "log", "slot"}`` on success, or
        ``{"started": False, "error", "slot"}`` if the log file could not be
        opened or the process could not be spawned. Failures are also logged.
    """
    interpreter = worker_python()
    log_path = f"/tmp/stt_worker_{slot}.log"
    child_env = dict(os.environ)
    child_env["PYTHONUNBUFFERED"] = "1"
    try:
        with open(log_path, "ab", buffering=0) as log_file:
            worker = subprocess.Popen(
                [interpreter, os.path.join(STT_DIR, "run.py"), "--worker"],
                cwd=STT_DIR,
                env=child_env,
                stdin=subprocess.DEVNULL,
                stdout=log_file,
                stderr=subprocess.STDOUT,
                close_fds=True,
                start_new_session=True,
            )
        logger.info("Started STT worker slot=%s pid=%s log=%s", slot, worker.pid, log_path)
        return {"started": True, "pid": worker.pid, "log": log_path, "slot": slot}
    except Exception as exc:
        logger.error("Failed to start STT worker slot=%s: %s", slot, exc)
        return {"started": False, "error": str(exc), "slot": slot}


def _trim_workers_to(target: int) -> list[int]:
    """SIGKILL surplus STT workers until no more than *target* remain.

    The first *target* PIDs from list_worker_pids() (ascending order) are kept
    and the rest are killed. This repeats for at most 5 rounds, pausing 1
    second after each kill round.

    Returns:
        Every PID that was successfully signalled. A PID can appear more than
        once if it survives a round.
    """
    killed: list[int] = []
    rounds_left = 5
    while rounds_left > 0:
        rounds_left -= 1
        current = list_worker_pids()
        if len(current) <= target:
            break
        for surplus_pid in current[target:]:
            if stop_worker(surplus_pid, force=True):
                killed.append(surplus_pid)
        time.sleep(1)
    return killed


def effective_worker_target(queue_stats: dict | None = None) -> int:
    """Return how many workers should run for the current stt_jobs backlog.

    *queue_stats* is a rabbitmq_queue_stats() result; when it is falsy, fresh
    stats are fetched. If the stats carry an error, or the reported message
    count is below STT_QUEUE_SCALE_THRESHOLD, the base STT_WORKER_COUNT is
    returned. Otherwise one extra worker is added per
    STT_QUEUE_WORKERS_PER_MESSAGES messages, capped at STT_WORKER_MAX.
    """
    stats = queue_stats if queue_stats else rabbitmq_queue_stats()
    if stats.get("error"):
        return STT_WORKER_COUNT
    backlog = int(stats.get("messages") or 0)
    if backlog >= STT_QUEUE_SCALE_THRESHOLD:
        extra = max(0, backlog // STT_QUEUE_WORKERS_PER_MESSAGES)
        return min(STT_WORKER_MAX, max(STT_WORKER_COUNT, STT_WORKER_COUNT + extra))
    return STT_WORKER_COUNT


class _AttrConfigAdapter:
    """Wrap a Config object so it supports both ``cfg.KEY`` and ``cfg.get(KEY, default)``."""

    def __init__(self, config):
        self._config = config

    def get(self, key, default=None):
        return getattr(self._config, key, default)

    def __getattr__(self, key):
        return getattr(self._config, key)


def heal_stuck_calls_if_enabled() -> dict:
    """Run the stuck-queued-call healer unless STT_STUCK_HEAL_ENABLED is off.

    Returns:
        ``{"skipped": True}`` when disabled. Otherwise, whatever
        ``heal_stuck_queued_calls`` returns, or ``{"error": str}`` (also
        logged) if importing or running the healer raises.
    """
    if STT_STUCK_HEAL_ENABLED:
        try:
            from config import Config
            from db_handler import DatabaseHandler
            from stt_queue_healer import heal_stuck_queued_calls

            # NOTE(review): a new DatabaseHandler is built on every call and never
            # explicitly closed; confirm it does not keep a connection open.
            db = DatabaseHandler(_AttrConfigAdapter(Config()))
            return heal_stuck_queued_calls(db)
        except Exception as exc:
            logger.warning("STT stuck-call heal failed: %s", exc)
            return {"error": str(exc)}
    return {"skipped": True}


def sync_workers(*, force_restart: bool = False) -> dict:
    """Reconcile running STT workers with the backlog-derived target, once.

    Steps:
      1. Query the queue and compute the target via effective_worker_target().
      2. With *force_restart*, SIGKILL every worker. Otherwise SIGKILL any
         surplus above the target.
      3. Start workers until the target is reached, then trim again in case
         the count now exceeds it.
      4. If RabbitMQ reports fewer consumers than the workers that were
         already running before this pass, restart everything once. This is a
         recursive call with force_restart=True, which skips this check. The
         check is skipped entirely when queue stats are unavailable. Freshly
         started workers are not counted, since they may still be connecting.
      5. Run the stuck-call healer. This is done exactly once per call chain,
         because the restart path returns the recursive call's result.

    Returns:
        A summary dict for logging. ``running_before`` is the number of worker
        PIDs seen at the start of this pass.
    """
    queue_stats = rabbitmq_queue_stats()
    target = effective_worker_target(queue_stats)
    pids = list_worker_pids()
    running_before = len(pids)
    stopped: list[int] = []
    if force_restart and pids:
        stopped.extend(stop_all_workers(force=True))
        pids = []
    else:
        stopped.extend(_trim_workers_to(target))
        pids = list_worker_pids()

    started: list[dict] = []
    for slot in range(len(pids), target):
        result = start_worker(slot + 1)
        if result.get("started"):
            started.append(result)
        time.sleep(1)  # pause between launches

    # Re-trim in case the worker count now exceeds the target.
    stopped.extend(_trim_workers_to(target))
    pids_after = list_worker_pids()
    if started and not force_restart:
        time.sleep(5)  # give newly started workers a moment to connect
    queue = rabbitmq_queue_stats() or {}

    # Without stats we cannot tell whether workers are consuming; never kill
    # healthy workers just because the broker query failed.
    if not force_restart and not queue.get("error"):
        consumers = int(queue.get("consumers") or 0)
        # Only workers that were already running are expected to be consuming;
        # ones launched in this pass may still be starting up.
        expected = min(len(pids_after), target) - len(started)
        if consumers < expected:
            logger.warning(
                "RabbitMQ consumers (%s) < running workers (%s); restarting workers",
                consumers,
                len(pids_after),
            )
            return sync_workers(force_restart=True)

    heal = heal_stuck_calls_if_enabled()
    return {
        "target_workers": target,
        "base_workers": STT_WORKER_COUNT,
        "max_workers": STT_WORKER_MAX,
        "running_before": running_before,
        "running_after": len(pids_after),
        "stopped": stopped,
        "started": started,
        "worker_pids": pids_after,
        "queue": queue,
        "heal": heal,
    }


def main():
    """Entry point: take the single-instance lock, then sync workers once or forever.

    Command-line flags (read from ``sys.argv``):
      --once     run a single sync pass and exit.
      --restart  force-restart every worker. This applies only to the first
                 successful pass; later passes in loop mode sync normally.

    In loop mode, a failed pass is logged and retried after
    STT_SUPERVISOR_INTERVAL_SEC seconds.
    """
    # Keep the descriptor referenced: the flock lasts as long as it stays open.
    _lock_fd = acquire_supervisor_lock()
    once = "--once" in sys.argv
    force_restart = "--restart" in sys.argv
    logger.info(
        "STT worker supervisor (%s%s); base_workers=%s max_workers=%s python=%s stt_dir=%s",
        "once" if once else f"every {STT_SUPERVISOR_INTERVAL_SEC}s",
        ", restart-all" if force_restart else "",
        STT_WORKER_COUNT,
        STT_WORKER_MAX,
        worker_python(),
        STT_DIR,
    )
    while True:
        try:
            summary = sync_workers(force_restart=force_restart)
            # --restart is a one-shot request. Leaving it set would kill and
            # restart every worker on every interval.
            force_restart = False
            logger.info(
                "Sync: running %s/%s workers (started %s); queue=%s",
                summary["running_after"],
                summary["target_workers"],
                len(summary["started"]),
                summary.get("queue"),
            )
        except Exception as exc:
            logger.exception("STT worker supervisor sync failed: %s", exc)
        if once:
            break
        time.sleep(STT_SUPERVISOR_INTERVAL_SEC)


if __name__ == "__main__":
    # Start the supervisor only when executed as a script, not when imported.
    main()
