"""Probe WAV duration from recording URLs (no STT provider calls)."""

from __future__ import annotations

import io
import os
import struct
import tempfile
import wave
from typing import Optional

import requests

# Number of leading bytes requested via the HTTP Range header when probing.
# Overridable through MIN_DURATION_PROBE_RANGE_BYTES; never less than 4096.
_PROBE_RANGE_BYTES = max(
    4096,
    int(os.environ.get("MIN_DURATION_PROBE_RANGE_BYTES", "65536")),
)
# Timeout in seconds passed to each probe request.
# Overridable through MIN_DURATION_PROBE_TIMEOUT_S; never less than 3.
_PROBE_TIMEOUT_S = max(
    3,
    int(os.environ.get("MIN_DURATION_PROBE_TIMEOUT_S", "20")),
)


def _wav_duration_from_header_bytes(data: bytes) -> Optional[float]:
    """Estimate a WAV file's duration in seconds from its leading bytes.

    Walks the RIFF chunk list looking for the ``fmt `` chunk and the header
    of the ``data`` chunk. Only the chunk headers are needed, so *data* may
    be a truncated prefix of the file: the size declared in the ``data``
    chunk header is used, not the number of payload bytes actually present.

    Args:
        data: The first bytes of the file (at least 44 bytes).

    Returns:
        The declared data size divided by the stream's bytes-per-second, or
        ``None`` when the buffer does not start with a RIFF/WAVE header, no
        ``data`` chunk header lies within the buffer, or no usable rate
        field was read from a ``fmt `` chunk.
    """
    if len(data) < 44 or data[0:4] != b"RIFF" or data[8:12] != b"WAVE":
        return None

    sample_rate = 0
    byte_rate = 0
    block_align = 0
    data_bytes = None
    offset = 12  # first sub-chunk starts right after "RIFF"<size>"WAVE"
    while offset + 8 <= len(data):
        chunk_id, chunk_size = struct.unpack_from("<4sI", data, offset)
        chunk_start = offset + 8
        chunk_end = chunk_start + chunk_size
        if chunk_id == b"fmt " and chunk_size >= 16 and chunk_end <= len(data):
            # fmt layout: format tag, channels, samples/sec, avg bytes/sec,
            # block align (bytes per frame), bits per sample.
            _, _, sample_rate, byte_rate, block_align, _ = struct.unpack_from(
                "<HHIIHH", data, chunk_start
            )
        elif chunk_id == b"data":
            data_bytes = chunk_size
            break
        # Chunks are word-aligned: an odd-sized chunk is followed by a pad byte.
        offset = chunk_end + (chunk_size % 2)

    if data_bytes is None:
        return None
    if byte_rate > 0:
        # Seconds = payload bytes / bytes per second. Dividing by the sample
        # rate alone is only right when a frame is exactly one byte.
        return data_bytes / float(byte_rate)
    if sample_rate > 0:
        # No byte rate recorded: rebuild it from frames/sec * bytes/frame,
        # treating a missing block align as one byte per frame.
        frame_bytes = block_align if block_align > 0 else 1
        return data_bytes / float(sample_rate * frame_bytes)
    return None


def _fetch_wav_prefix(url: str) -> bytes:
    """Download at most ``_PROBE_RANGE_BYTES`` leading bytes of *url*.

    The first request asks for just the leading byte range. If the server
    answers 416 (range not satisfiable) the request is repeated once without
    a ``Range`` header. Responses are streamed and reading stops at the cap,
    so a server that ignores ``Range`` and replies 200 with the whole
    recording does not get the entire file pulled into memory.

    Returns ``b""`` when no attempt produced an acceptable status. Network
    errors propagate to the caller.
    """
    attempts = (
        ({"Range": f"bytes=0-{_PROBE_RANGE_BYTES - 1}"}, (200, 206)),
        (None, (200,)),
    )
    for headers, accepted in attempts:
        with requests.get(
            url, headers=headers, timeout=_PROBE_TIMEOUT_S, stream=True
        ) as resp:
            if resp.status_code in accepted:
                buf = bytearray()
                for piece in resp.iter_content(chunk_size=8192):
                    buf.extend(piece)
                    if len(buf) >= _PROBE_RANGE_BYTES:
                        break
                return bytes(buf[:_PROBE_RANGE_BYTES])
            if resp.status_code != 416:
                # Only a rejected range is worth retrying without the header.
                break
    return b""


def probe_wav_duration_from_url(url: str) -> Optional[float]:
    """Return the duration in seconds of the WAV recording at *url*.

    Only the leading bytes of the resource are fetched; the duration is
    derived from the sizes declared in the WAV header, not from the amount
    of audio actually downloaded.

    Args:
        url: An ``http://`` or ``https://`` URL. Any other value (including
            ``None`` or an empty string) yields ``None``.

    Returns:
        A non-negative duration in seconds, or ``None`` when the URL is not
        HTTP(S), the download fails or is empty, or the bytes cannot be
        interpreted as a WAV header. This function does not raise.
    """
    url = str(url or "").strip()
    if not url.startswith(("http://", "https://")):
        return None

    try:
        data = _fetch_wav_prefix(url)
    except Exception:
        # Deliberate best-effort: any transport or URL error means "unknown".
        return None
    if not data:
        return None

    try:
        # The wave module reads chunk headers only, so a truncated prefix is
        # enough; parsing from memory avoids a temporary file on disk.
        with wave.open(io.BytesIO(data), "rb") as wf:
            rate = wf.getframerate()
            if rate > 0:
                return max(0.0, wf.getnframes() / float(rate))
    except Exception:
        # Deliberate best-effort: whatever the wave module rejects is handed
        # to the manual header parser below.
        pass

    header_duration = _wav_duration_from_header_bytes(data)
    return max(0.0, float(header_duration)) if header_duration is not None else None
