"""Cross-process slot limiter for long-running Sarvam STT calls."""
from __future__ import annotations

import fcntl
import logging
import os
import time

logger = logging.getLogger(__name__)


def _env_int(name: str, default: int, minimum: int) -> int:
    """Read an integer setting from the environment, clamped to *minimum*.

    An unset or blank variable yields *default*; a non-numeric value is
    logged and also replaced by *default*, so a malformed environment does
    not raise ``ValueError`` while this module is being imported.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset, blank, or invalid.
        minimum: Lower bound applied to the final result.

    Returns:
        The parsed (or default) value, never smaller than *minimum*.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return max(minimum, default)
    try:
        value = int(raw)
    except ValueError:
        logger.warning("Ignoring invalid %s=%r; using default %s", name, raw, default)
        value = default
    return max(minimum, value)


# Number of slot lock files that may be held at once (at least 1).
_MAX_SLOTS = _env_int("STT_SARVAM_MAX_CONCURRENT", 2, 1)
# Directory holding one lock file per slot; a blank override falls back to the default path.
_LOCK_DIR = os.getenv("STT_SARVAM_LOCK_DIR") or "/tmp/sarvam_stt_slots"
# Minimum number of seconds between two "waiting for slot" log lines.
_WAIT_LOG_INTERVAL_S = 60
# Longest time, in seconds, to wait for a free slot (never below 60).
_MAX_SLOT_WAIT_S = _env_int("STT_SARVAM_SLOT_WAIT_SECONDS", 900, 60)


class SarvamSlotTimeoutError(TimeoutError):
    """Signal that no Sarvam concurrency slot became free within the wait limit.

    Subclasses ``TimeoutError``, so handlers written for the built-in
    exception also catch it.
    """


class SarvamConcurrencySlot:
    """Context manager that holds one of N global Sarvam slots.

    Each slot is a lock file ``slot_<n>.lock`` inside ``_LOCK_DIR``; holding a
    slot means holding an exclusive ``flock`` on that file, so the limit is
    shared by every process using the same lock directory. ``__enter__``
    polls until a slot is free and ``__exit__`` releases it.

    Args:
        callid: Identifier included in every log line and in the timeout
            error message.

    Raises:
        SarvamSlotTimeoutError: From ``__enter__`` when no slot could be
            locked within ``_MAX_SLOT_WAIT_S`` seconds.
    """

    # Seconds to sleep between sweeps over the slot files while all are busy.
    _POLL_INTERVAL_S = 2

    def __init__(self, callid: str):
        self.callid = callid
        # Descriptor of the locked slot file, or None while no slot is held.
        self._fd: int | None = None
        # Zero-based index of the held slot, or None while no slot is held.
        self._slot: int | None = None

    def _try_acquire(self) -> bool:
        """Sweep the slot files once and lock the first free one.

        Returns:
            True when a slot was locked (``_fd``/``_slot`` are set), False
            when every slot is currently locked by another holder.
        """
        for slot in range(_MAX_SLOTS):
            path = os.path.join(_LOCK_DIR, f"slot_{slot}.lock")
            fd = os.open(path, os.O_CREAT | os.O_RDWR)
            try:
                fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
            except BlockingIOError:
                # Slot is taken; drop our descriptor and try the next file.
                os.close(fd)
                continue
            except BaseException:
                # Any other failure (or an interrupt) must not leak the descriptor.
                os.close(fd)
                raise
            self._fd = fd
            self._slot = slot
            return True
        return False

    def __enter__(self) -> "SarvamConcurrencySlot":
        os.makedirs(_LOCK_DIR, exist_ok=True)
        if self._fd is not None:
            # A slot is already held by this instance; nothing to acquire.
            return self

        # Monotonic clock: the wait limit must not be affected by wall-clock jumps.
        wait_started = time.monotonic()
        last_log = None
        while True:
            if self._try_acquire():
                logger.info(
                    "[%s] Acquired Sarvam slot %s/%s",
                    self.callid,
                    (self._slot or 0) + 1,
                    _MAX_SLOTS,
                )
                return self

            now = time.monotonic()
            remaining = _MAX_SLOT_WAIT_S - (now - wait_started)
            if remaining <= 0:
                raise SarvamSlotTimeoutError(
                    f"[{self.callid}] Timed out after {_MAX_SLOT_WAIT_S}s waiting for Sarvam slot"
                )

            # Log on the first failed sweep, then at most once per interval.
            if last_log is None or now - last_log >= _WAIT_LOG_INTERVAL_S:
                logger.info(
                    "[%s] Waiting for Sarvam slot (all %s busy)…",
                    self.callid,
                    _MAX_SLOTS,
                )
                last_log = now

            # Never sleep past the deadline; the next sweep is the final attempt.
            time.sleep(min(self._POLL_INTERVAL_S, remaining))

    def __exit__(self, exc_type, exc, tb) -> None:
        if self._fd is None:
            return
        fd, slot = self._fd, self._slot
        # Clear state first so a failing unlock cannot leave a stale descriptor behind.
        self._fd = None
        self._slot = None
        try:
            fcntl.flock(fd, fcntl.LOCK_UN)
        finally:
            # Always close: the descriptor must not leak even if the explicit unlock failed.
            os.close(fd)
        logger.info("[%s] Released Sarvam slot %s", self.callid, (slot or 0) + 1)
