#!/usr/bin/env python3
"""
Ensure orchestrator_loop_{bid}.sh is running for every business with pipeline_enabled=1.
Run via systemd (mcube-orchestrator-supervisor.service) or cron every minute.
"""
from __future__ import annotations

import logging
import os
import shlex
import subprocess
import sys
import time

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, BASE_DIR)

from dotenv import load_dotenv

load_dotenv(os.path.join(BASE_DIR, ".env"))

from config import Config
from db_handler import DatabaseHandler

logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("orchestrator_supervisor")


def _int_from_env(name: str, default: str) -> int:
    """Return environment variable *name* parsed as int, or *default* when unset."""
    return int(os.environ.get(name, default))


# Seconds each per-BID bash loop sleeps between orchestrate_pipeline.py runs.
LOOP_INTERVAL_SEC = _int_from_env("ORCHESTRATOR_LOOP_INTERVAL_SEC", "300")
# Value passed as --limit to orchestrate_pipeline.py.
ORCHESTRATE_LIMIT = _int_from_env("ORCHESTRATOR_RUN_LIMIT", "50")
# Seconds the supervisor waits between sync passes (ignored with --once).
SUPERVISOR_INTERVAL_SEC = _int_from_env("ORCHESTRATOR_SUPERVISOR_INTERVAL_SEC", "60")


class ConfigWrapper:
    """Expose an attribute-style config object through a dict-like ``get``.

    Both ``wrapper.get(key, default)`` and plain attribute access are forwarded
    to the wrapped object. NOTE(review): presumably DatabaseHandler expects the
    ``get`` interface — confirm against db_handler.
    """

    def __init__(self, config):
        self._config = config

    def get(self, key, default=None):
        """Return attribute *key* of the wrapped config, or *default* if absent."""
        return getattr(self._config, key, default)

    def __getattr__(self, key):
        # __getattr__ only runs when normal lookup fails. If ``_config`` itself is
        # missing (an instance created by copy/pickle/__new__ without __init__),
        # delegating would re-enter __getattr__ forever; raise AttributeError.
        if key == "_config":
            raise AttributeError(key)
        return getattr(self._config, key)


def orchestrator_python() -> str:
    """Return the interpreter used to run orchestrate_pipeline.py.

    Prefers the project virtualenv at ``<BASE_DIR>/venv/bin/python``; when that
    file does not exist, falls back to the bare command name ``python3``.
    """
    candidate = os.path.join(BASE_DIR, "venv", "bin", "python")
    if os.path.isfile(candidate):
        return candidate
    return "python3"


def loop_script_path(bid: str) -> str:
    """Return the path of the per-BID loop script inside BASE_DIR."""
    filename = "orchestrator_loop_{}.sh".format(bid)
    return os.path.join(BASE_DIR, filename)


def loop_is_running(bid: str) -> bool:
    """Return True if some process command line matches this BID's loop script.

    Runs ``pgrep -f`` with the pattern ``orchestrator_loop_<bid>\\.sh``. The
    dot before ``sh`` is escaped so it only matches a literal ``.``.
    NOTE(review): the BID itself is inserted unescaped, so a BID containing
    regex metacharacters could over- or under-match.

    Returns False when nothing matches. It also returns False, after logging a
    warning, when pgrep cannot be run or reports an error. Without the warning,
    such failures would look exactly like "not running".
    """
    pattern = f"orchestrator_loop_{bid}\\.sh"
    try:
        result = subprocess.run(
            ["pgrep", "-f", pattern],
            capture_output=True,
            text=True,
            timeout=5,
        )
    except Exception as exc:
        # Best-effort probe: keep the original "not running" answer, but say why.
        logger.warning("[%s] pgrep failed: %s", bid, exc)
        return False
    # pgrep exit status: 0 = matched, 1 = no match, >1 = usage/pattern/fatal error.
    if result.returncode > 1:
        logger.warning(
            "[%s] pgrep exited with status %s: %s",
            bid,
            result.returncode,
            (result.stderr or "").strip(),
        )
    return bool(result.stdout.strip())


def stop_loop(bid: str) -> bool:
    """Stop orchestrator_loop_{bid}.sh if it is running.

    Sends pkill's default signal to every process whose full command line
    matches ``orchestrator_loop_<bid>\\.sh``. NOTE(review): an in-flight
    ``orchestrate_pipeline.py`` child started by the loop does not match this
    pattern. It may keep running until it finishes on its own.

    Returns True only if pkill reports that it signalled at least one process.
    Returns False for an empty BID, when no loop is running, or on failure.
    """
    bid = str(bid).strip()
    if not bid:
        return False
    if not loop_is_running(bid):
        return False
    pattern = f"orchestrator_loop_{bid}\\.sh"
    try:
        result = subprocess.run(
            ["pkill", "-f", pattern],
            capture_output=True,
            text=True,
            timeout=10,
        )
    except Exception as exc:
        logger.warning("[%s] stop_loop failed: %s", bid, exc)
        return False
    if result.returncode != 0:
        # 1 = nothing matched (the loop exited after the pgrep check); >1 = error.
        if result.returncode > 1:
            logger.warning(
                "[%s] pkill exited with status %s: %s",
                bid,
                result.returncode,
                (result.stderr or "").strip(),
            )
        return False
    logger.info("Stopped orchestrator loop for BID %s", bid)
    return True


def ensure_loop_script(bid: str) -> str:
    """Create or refresh the per-BID orchestrator loop script and return its path.

    The generated bash script ``cd``s into BASE_DIR and repeats forever: run
    ``orchestrate_pipeline.py --bid <bid> --limit ORCHESTRATE_LIMIT``, then
    sleep LOOP_INTERVAL_SEC seconds. A non-zero exit from the pipeline is
    reported on stderr and does not end the loop. Under ``set -e``, an
    unguarded failing command would end it.

    The interpreter path, BASE_DIR and BID are shell-quoted, so spaces or shell
    metacharacters in them cannot break or alter the generated command.
    """
    script = loop_script_path(bid)
    q_py = shlex.quote(orchestrator_python())
    q_bid = shlex.quote(str(bid))
    content = (
        "#!/usr/bin/env bash\n"
        "set -euo pipefail\n"
        f"cd {shlex.quote(BASE_DIR)}\n"
        "while true; do\n"
        "  rc=0\n"
        # `cmd || rc=$?` keeps errexit from killing the loop on a failed run.
        f"  {q_py} orchestrate_pipeline.py --bid {q_bid} --limit {ORCHESTRATE_LIMIT} || rc=$?\n"
        '  if [ "$rc" -ne 0 ]; then\n'
        '    echo "orchestrate_pipeline.py exited with status $rc" >&2\n'
        "  fi\n"
        f"  sleep {LOOP_INTERVAL_SEC}\n"
        "done\n"
    )
    # Write to a sibling temp file and rename it into place. Bash reads scripts
    # incrementally, so rewriting a file a running loop is executing could
    # corrupt that loop. os.replace swaps the directory entry in one step.
    # The ".tmp" suffix keeps the file out of the "*.sh" scan in
    # sync_enabled_loops.
    tmp = f"{script}.tmp"
    with open(tmp, "w", encoding="utf-8") as fh:
        fh.write(content)
    os.chmod(tmp, 0o755)
    os.replace(tmp, script)
    return script


def start_loop(bid: str) -> dict:
    """Start orchestrator_loop_{bid}.sh in the background unless already running.

    Regenerates the loop script, then launches it with ``nohup`` through a login
    bash shell (``bash -lc``). Output is appended to ``/tmp/orch_{bid}.log``,
    and the background PID printed by ``echo $!`` is read back.

    Returns a dict that always has ``bid`` and ``started``:
      * success: ``started=True``, ``pid`` (None if unparsable), ``log``;
      * otherwise ``started=False`` and a ``reason``: ``invalid_bid``,
        ``already_running``, ``launch_failed`` or ``launch_timeout``.
    Launch problems are reported in the result rather than raised, so a caller
    looping over many BIDs is not aborted by one of them.
    """
    bid = str(bid).strip()
    if not bid:
        return {"bid": bid, "started": False, "reason": "invalid_bid"}
    if loop_is_running(bid):
        return {"bid": bid, "started": False, "reason": "already_running"}
    script = ensure_loop_script(bid)
    log_path = f"/tmp/orch_{bid}.log"
    launch = f"nohup bash {shlex.quote(script)} >>{shlex.quote(log_path)} 2>&1 & echo $!"
    try:
        proc = subprocess.Popen(
            ["bash", "-lc", launch],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            cwd=BASE_DIR,
        )
    except OSError as exc:
        logger.warning("[%s] failed to launch loop: %s", bid, exc)
        return {"bid": bid, "started": False, "reason": "launch_failed", "log": log_path}
    try:
        out, err = proc.communicate(timeout=10)
    except subprocess.TimeoutExpired:
        # Reap the launcher shell rather than leaving it behind.
        # NOTE(review): the background loop may still have been spawned; the
        # next pass's loop_is_running check will detect it.
        proc.kill()
        proc.communicate()
        logger.warning("[%s] loop launcher timed out", bid)
        return {"bid": bid, "started": False, "reason": "launch_timeout", "log": log_path}
    pid = None
    if out and out.strip():
        # `bash -l` may print profile output first; the PID is the last line.
        try:
            pid = int(out.strip().splitlines()[-1])
        except ValueError:
            pid = None
    if err:
        logger.warning("[%s] start stderr: %s", bid, err.strip())
    logger.info("Started orchestrator loop for BID %s (pid=%s)", bid, pid)
    return {"bid": bid, "started": True, "pid": pid, "log": log_path}


def sync_stt_workers() -> dict:
    """Keep STT_WORKER_COUNT RabbitMQ consumers running (shared stt_jobs queue).

    Delegates to ``stt_worker_supervisor.sync_workers``. The module is imported
    inside the try block, so an import failure is handled like any other error.
    Every exception is logged as a warning and returned as ``{"error": msg}``.
    """
    try:
        from stt_worker_supervisor import sync_workers as _sync

        outcome = _sync()
    except Exception as exc:
        logger.warning("STT worker sync failed: %s", exc)
        outcome = {"error": str(exc)}
    return outcome


def sync_analytics_workers() -> dict:
    """Keep ANALYTICS_WORKER_COUNT RabbitMQ consumers running (analytics_jobs queue).

    Delegates to ``analytics_worker_supervisor.sync_workers``. The module is
    imported inside the try block, so an import failure is handled like any
    other error. Every exception is logged as a warning and returned as
    ``{"error": msg}``.
    """
    try:
        from analytics_worker_supervisor import sync_workers as _sync

        outcome = _sync()
    except Exception as exc:
        logger.warning("Analytics worker sync failed: %s", exc)
        outcome = {"error": str(exc)}
    return outcome


def sync_enabled_loops() -> dict:
    """Reconcile running per-BID orchestrator loops with the pipeline-enabled BIDs.

    1. Read the enabled BIDs from the database (blank values are dropped).
    2. For every ``orchestrator_loop_<bid>.sh`` in BASE_DIR whose BID is no
       longer enabled, stop its loop.
    3. Start a loop for every enabled BID that is not already running.

    An exception while starting one BID is logged and recorded under
    ``failed``, so it cannot stop the remaining BIDs from being reconciled.

    Returns a dict of BID-string lists: ``enabled_bids``, ``started``,
    ``already_running``, ``stopped`` and ``failed``.
    NOTE(review): a new DatabaseHandler is built on every pass and never closed
    here — confirm it does not hold connections open.
    """
    db = DatabaseHandler(ConfigWrapper(Config()))
    db.ensure_business_pipeline_config_table()
    enabled = set()
    for raw in db.get_enabled_pipeline_bids() or []:
        bid = str(raw).strip()
        if bid:
            enabled.add(bid)

    started = []
    already = []
    stopped = []
    failed = []

    # Stop loops for businesses that are no longer pipeline_enabled.
    prefix, suffix = "orchestrator_loop_", ".sh"
    for name in sorted(os.listdir(BASE_DIR)):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        bid = name[len(prefix):-len(suffix)]
        if not bid or bid in enabled:
            continue
        if stop_loop(bid):
            stopped.append(bid)

    for bid in sorted(enabled):
        if loop_is_running(bid):
            already.append(bid)
            continue
        try:
            result = start_loop(bid)
        except Exception:
            logger.exception("[%s] failed to start orchestrator loop", bid)
            failed.append(bid)
            continue
        if result.get("started"):
            started.append(bid)
        elif result.get("reason") == "already_running":
            # The loop appeared between our check and start_loop's own check.
            already.append(bid)

    return {
        "enabled_bids": sorted(enabled),
        "started": started,
        "already_running": already,
        "stopped": stopped,
        "failed": failed,
    }


def _sync_loops_and_log() -> None:
    """Run one orchestrator-loop reconciliation pass and log its summary.

    Exceptions are logged, not raised, so the worker-pool sync still runs.
    """
    try:
        summary = sync_enabled_loops()
    except Exception as exc:
        logger.exception("Supervisor sync failed: %s", exc)
        return
    logger.info(
        "Sync complete: %s enabled, %s started, %s already running, %s stopped",
        len(summary["enabled_bids"]),
        len(summary["started"]),
        len(summary["already_running"]),
        len(summary.get("stopped") or []),
    )


def _sync_workers_and_log() -> None:
    """Reconcile the STT and analytics worker pools and log their state.

    Both sync functions return ``{"error": ...}`` instead of raising and log
    their own failures, so error results are not logged again here.
    """
    try:
        stt = sync_stt_workers()
        analytics = sync_analytics_workers()
        if stt and not stt.get("error"):
            logger.info(
                "STT workers: %s/%s running (base=%s max=%s); queue=%s heal=%s",
                stt.get("running_after"),
                stt.get("target_workers"),
                stt.get("base_workers"),
                stt.get("max_workers"),
                stt.get("queue"),
                (stt.get("heal") or {}).get("reset_total", 0),
            )
        if analytics and not analytics.get("error"):
            logger.info(
                "Analytics workers: %s/%s running; queue=%s",
                analytics.get("running_after"),
                analytics.get("target_workers"),
                analytics.get("queue"),
            )
    except Exception as exc:
        logger.exception("Supervisor sync failed: %s", exc)


def main():
    """Run the supervisor: one sync pass with ``--once``, otherwise forever.

    Each pass reconciles the per-BID orchestrator loops, then the STT and
    analytics worker pools. The two steps have separate error boundaries, so a
    failing loop sync (e.g. an unreachable database) no longer skips the worker
    pools. Between passes the supervisor sleeps SUPERVISOR_INTERVAL_SEC seconds.
    """
    once = "--once" in sys.argv
    logger.info(
        "Orchestrator supervisor (%s); enabled loops use %s, limit=%s",
        "once" if once else f"every {SUPERVISOR_INTERVAL_SEC}s",
        orchestrator_python(),
        ORCHESTRATE_LIMIT,
    )
    while True:
        _sync_loops_and_log()
        _sync_workers_and_log()
        if once:
            break
        time.sleep(SUPERVISOR_INTERVAL_SEC)


if __name__ == "__main__":
    main()
