"""Probe WAV recording duration from a URL without calling STT providers."""

from __future__ import annotations

import logging
import os
import struct
import tempfile
import wave
from typing import Optional
from urllib.parse import urlparse

import requests

logger = logging.getLogger(__name__)


def _env_int(name: str, default: int, minimum: int) -> int:
    """Read an integer setting from the environment without failing at import.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset, blank, or not an integer.
        minimum: Lower bound applied to the returned value.

    Returns:
        The parsed (or default) value, clamped to be at least ``minimum``.
    """
    raw = os.getenv(name)
    # A variable that is present but blank (e.g. ``FOO=`` in an env file) is
    # treated like an unset one instead of crashing the module import.
    if raw is None or not raw.strip():
        return max(minimum, default)
    try:
        value = int(raw)
    except ValueError:
        logger.warning("Ignoring invalid %s=%r; using default %d", name, raw, default)
        value = default
    return max(minimum, value)


# Size of the leading byte range requested when probing (never below 4096).
_PROBE_RANGE_BYTES = _env_int("MIN_DURATION_PROBE_RANGE_BYTES", 65536, 4096)
# Per-request timeout in seconds passed to ``requests`` (never below 3).
_PROBE_TIMEOUT_S = _env_int("MIN_DURATION_PROBE_TIMEOUT_S", 20, 3)
# Whether an un-ranged GET is attempted when the ranged probe yields nothing;
# enabled unless the variable is "0", "false" or "no" (case-insensitive).
_PROBE_FULL_FALLBACK = os.getenv("MIN_DURATION_PROBE_FULL_FALLBACK", "1").strip().lower() not in (
    "0",
    "false",
    "no",
)


def _wav_duration_from_path(path: str) -> Optional[float]:
    """Return the duration in seconds of the WAV file at ``path``.

    The ``wave`` module is tried first. If it raises for any reason, the first
    65536 bytes of the file are handed to ``_wav_duration_from_header_bytes``
    instead.

    Args:
        path: Filesystem path of the (possibly truncated) WAV file.

    Returns:
        Frame count divided by frame rate, ``0.0`` when the reported frame
        rate is not positive, or whatever the header-based fallback returns
        (which may be ``None``).

    Raises:
        OSError: If the fallback cannot open or read ``path``.
    """
    try:
        with wave.open(path, "rb") as wf:
            rate = wf.getframerate()
            if rate <= 0:
                return 0.0
            return wf.getnframes() / float(rate)
    except Exception:
        # Best-effort: any failure of the wave module falls back to manual
        # header parsing. The file is opened in a context manager so the
        # handle is closed deterministically instead of being leaked.
        with open(path, "rb") as fh:
            head = fh.read(65536)
        return _wav_duration_from_header_bytes(head)


def _wav_duration_from_header_bytes(data: bytes) -> Optional[float]:
    """Parse standard PCM WAV header when wave module cannot read a partial file.

    Walks the RIFF chunk list looking for the ``fmt `` chunk and the header of
    the ``data`` chunk. Only the *declared* size of the data chunk is used, so
    ``data`` does not need to contain the audio samples themselves.

    Args:
        data: Leading bytes of a WAV file (at least the 44-byte canonical header).

    Returns:
        Duration in seconds, or ``None`` when ``data`` is not a RIFF/WAVE
        header, no ``data`` chunk header is found, or the format fields needed
        to compute a duration are missing or zero.
    """
    if len(data) < 44 or data[0:4] != b"RIFF" or data[8:12] != b"WAVE":
        return None
    offset = 12
    sample_rate = None
    byte_rate = None
    block_align = None
    data_bytes = None
    while offset + 8 <= len(data):
        chunk_id = data[offset : offset + 4]
        (chunk_size,) = struct.unpack_from("<I", data, offset + 4)
        chunk_start = offset + 8
        chunk_end = chunk_start + chunk_size
        if chunk_id == b"fmt " and chunk_end <= len(data) and chunk_size >= 16:
            # fmt payload layout: format(2) channels(2) sample_rate(4)
            # byte_rate(4) block_align(2) bits_per_sample(2).
            sample_rate, byte_rate, block_align = struct.unpack_from(
                "<IIH", data, chunk_start + 4
            )
        elif chunk_id == b"data":
            data_bytes = chunk_size
            break
        # RIFF chunks are word-aligned: odd-sized chunks carry one pad byte.
        offset = chunk_end + (chunk_size % 2)
    if data_bytes is None:
        return None
    # Seconds = payload bytes / bytes-per-second. Dividing by the sample rate
    # alone overstates the duration by the frame size (2x for 16-bit mono).
    if byte_rate:
        return data_bytes / float(byte_rate)
    # Some writers leave byte_rate at zero; rebuild it from rate * frame size.
    if sample_rate and block_align:
        return data_bytes / float(sample_rate * block_align)
    return None


def _fetch_recording_bytes(url: str) -> Optional[bytes]:
    """Download the leading bytes of a recording, falling back to a full GET.

    A ranged GET for the first ``_PROBE_RANGE_BYTES`` bytes is tried first.
    If that does not return 200/206 (including 416 from servers that reject
    the range) or raises, and ``_PROBE_FULL_FALLBACK`` is enabled, exactly one
    un-ranged GET is attempted.

    Args:
        url: Recording URL; anything not starting with ``http://`` or
            ``https://`` is rejected without a network call.

    Returns:
        The response body, or ``None`` when the URL is not http(s) or no
        request produced an acceptable status. Request errors are logged at
        debug level and never raised.
    """
    url = str(url or "").strip()
    if not url.startswith(("http://", "https://")):
        return None

    range_headers = {"Range": f"bytes=0-{_PROBE_RANGE_BYTES - 1}"}
    try:
        resp = requests.get(url, headers=range_headers, timeout=_PROBE_TIMEOUT_S, stream=True)
        try:
            # NOTE(review): a server that ignores Range answers 200 with the
            # whole body, so ``content`` is not capped at _PROBE_RANGE_BYTES.
            if resp.status_code in (200, 206):
                return resp.content
        finally:
            # With stream=True an unread body keeps the connection checked
            # out of the pool until the response is closed explicitly.
            resp.close()
    except Exception as exc:
        logger.debug("WAV probe fetch failed for %s: %s", url, exc)

    if not _PROBE_FULL_FALLBACK:
        return None

    # Single full-download fallback; this also covers the 416 case, so the
    # file is never downloaded twice.
    try:
        resp = requests.get(url, timeout=_PROBE_TIMEOUT_S)
        try:
            if resp.status_code == 200:
                return resp.content
        finally:
            resp.close()
    except Exception as exc:
        logger.debug("WAV probe full fetch failed for %s: %s", url, exc)
    return None


def probe_wav_duration_from_url(url: str) -> Optional[float]:
    """
    Return recording duration in seconds, or None when the URL is unreachable
    or the payload is not a readable WAV.

    The fetched bytes are written to a temporary ``.wav`` file and measured
    with ``_wav_duration_from_path``. If that fails or yields nothing, the
    bytes are parsed directly with ``_wav_duration_from_header_bytes``.
    Negative results are clamped to ``0.0``.
    """
    data = _fetch_recording_bytes(url)
    if not data:
        return None

    tmp_path = None
    try:
        # No explicit ``dir``: let tempfile pick the platform temp directory
        # (honours TMPDIR and works where "/tmp" does not exist).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            # Record the name before writing so a failed write is still
            # cleaned up by the ``finally`` below (delete=False).
            tmp_path = tmp.name
            tmp.write(data)
        duration = _wav_duration_from_path(tmp_path)
        if duration is not None:
            return max(0.0, float(duration))
    except Exception as exc:
        # Best-effort: fall through to in-memory header parsing, but leave a
        # trace instead of swallowing the failure silently.
        logger.debug("WAV probe via temp file failed for %s: %s", url, exc)
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

    header_duration = _wav_duration_from_header_bytes(data)
    if header_duration is not None:
        return max(0.0, float(header_duration))
    return None


def recording_url_probe_ready(url: str) -> bool:
    """Report whether ``url`` parses as an http(s) URL with a network location.

    Falsy input is treated as an empty string and surrounding whitespace is
    ignored. No network access is performed.
    """
    candidate = str(url or "").strip()
    parts = urlparse(candidate)
    if parts.scheme not in ("http", "https"):
        return False
    return bool(parts.netloc)
