import asyncio
import logging
import json
import os
from dataclasses import dataclass

import aio_pika


# AMQP broker URL; the fallback points at a local RabbitMQ using guest credentials.
RABBITMQ_URL = os.environ.get("RABBITMQ_URL", "amqp://guest:guest@127.0.0.1:5672/")

# Name of the shared queue carrying AI utterances.
AI_UTTERANCES_QUEUE = os.environ.get("AI_UTTERANCES_QUEUE", "ai.utterances")

# Per-call queue prefixes: the concrete queue name is f"{prefix}{call_id}".
TTS_QUEUE_PREFIX = os.environ.get("TTS_QUEUE_PREFIX", "tts.")
CONTROL_QUEUE_PREFIX = os.environ.get("CONTROL_QUEUE_PREFIX", "mcube.control.")

log = logging.getLogger("mcube.mq")


@dataclass(frozen=True)
class McubeMessage:
    """Immutable envelope around a JSON-serialisable message payload."""

    body: dict  # payload passed to json.dumps by to_bytes()

    def to_bytes(self) -> bytes:
        """Serialise ``body`` with ``json.dumps`` and encode the text as UTF-8."""
        serialized = json.dumps(self.body)
        return bytes(serialized, "utf-8")


async def connect(*, max_attempts: int | None = None) -> aio_pika.RobustConnection:
    """
    Connect to RabbitMQ at ``RABBITMQ_URL``, retrying with capped exponential backoff.

    Local dev can start consumers before RabbitMQ is fully ready; in that case aio-pika
    may raise (e.g. "Server connection unexpectedly closed") and crash the whole process.
    Keeping a retry loop here prevents websocket disconnects caused by consumer shutdown.

    After the n-th consecutive failure the loop sleeps
    ``min(RABBITMQ_CONNECT_RETRY_MAX_S, RABBITMQ_CONNECT_RETRY_BASE_S * 2 ** (n - 1))``
    seconds (exponent capped at 6). With the defaults that is 1, 2, 4, 8, 10, 10, ... s.

    Args:
        max_attempts: Optional cap on the total number of connection attempts.
            ``None`` (the default) retries forever. Values below 1 behave like 1.

    Returns:
        The connection object returned by ``aio_pika.connect_robust``.

    Raises:
        Exception: The last connection error, re-raised once ``max_attempts``
            attempts have failed. Never raised when ``max_attempts`` is ``None``.
    """
    base_s = float(os.getenv("RABBITMQ_CONNECT_RETRY_BASE_S", "1.0"))
    max_s = float(os.getenv("RABBITMQ_CONNECT_RETRY_MAX_S", "10.0"))
    attempt = 0
    while True:
        try:
            return await aio_pika.connect_robust(RABBITMQ_URL)
        except Exception as e:
            attempt += 1
            if max_attempts is not None and attempt >= max_attempts:
                log.error(
                    "RabbitMQ connect failed (attempt=%s), giving up: %s",
                    attempt,
                    e,
                )
                raise
            # The first retry waits exactly base_s. The exponent is capped so the
            # float product cannot overflow however long the broker stays down.
            delay = min(max_s, base_s * (2 ** min(attempt - 1, 6)))
            log.warning(
                "RabbitMQ connect failed (attempt=%s delay_s=%.1f): %s",
                attempt,
                delay,
                e,
            )
            await asyncio.sleep(delay)


async def get_channel(
    connection: aio_pika.RobustConnection,
    *,
    prefetch_count: int | None = None,
) -> aio_pika.abc.AbstractChannel:
    """Open a channel on *connection* and apply a QoS prefetch limit to it.

    The limit is *prefetch_count* when supplied; otherwise it is read from the
    ``RABBITMQ_PREFETCH`` environment variable (default ``"10"``). Either value
    is converted with ``int()`` and clamped to a minimum of 1. Presumably the
    clamp exists so that 0, which RabbitMQ treats as "no limit", cannot be
    configured by accident.
    """
    opened = await connection.channel()
    if prefetch_count is None:
        requested = int(os.getenv("RABBITMQ_PREFETCH", "10"))
    else:
        requested = int(prefetch_count)
    await opened.set_qos(prefetch_count=requested if requested > 1 else 1)
    return opened


async def declare_durable_queue(
    channel: aio_pika.abc.AbstractChannel, queue_name: str
) -> aio_pika.abc.AbstractQueue:
    """Declare *queue_name* on *channel* with ``durable=True`` and return the queue."""
    queue = await channel.declare_queue(queue_name, durable=True)
    return queue


def tts_queue_name(call_id: str) -> str:
    """Return the per-call TTS queue name, ``TTS_QUEUE_PREFIX`` + *call_id*.

    The name depends only on *call_id* and no randomness is involved, so every
    process that knows the call id derives the same queue. RabbitMQ limits
    queue names to 255 bytes of UTF-8. A name longer than that is rejected
    here with a clear error instead of failing later at declare/publish time.

    Raises:
        ValueError: If the resulting queue name exceeds 255 UTF-8 bytes.
    """
    name = f"{TTS_QUEUE_PREFIX}{call_id}"
    if len(name.encode("utf-8")) > 255:
        raise ValueError(
            f"TTS queue name for call_id={call_id!r} exceeds the 255-byte limit"
        )
    return name


def control_queue_name(call_id: str) -> str:
    """Return the per-call control queue name, ``CONTROL_QUEUE_PREFIX`` + *call_id*.

    The name depends only on *call_id* and no randomness is involved, so every
    process that knows the call id derives the same queue. RabbitMQ limits
    queue names to 255 bytes of UTF-8. A name longer than that is rejected
    here with a clear error instead of failing later at declare/publish time.

    Raises:
        ValueError: If the resulting queue name exceeds 255 UTF-8 bytes.
    """
    name = f"{CONTROL_QUEUE_PREFIX}{call_id}"
    if len(name.encode("utf-8")) > 255:
        raise ValueError(
            f"Control queue name for call_id={call_id!r} exceeds the 255-byte limit"
        )
    return name


async def publish_json(
    channel: aio_pika.abc.AbstractChannel,
    queue_name: str,
    payload: dict,
) -> None:
    """Publish *payload* as a persistent ``application/json`` message.

    The payload is serialised with ``json.dumps`` and encoded as UTF-8. It is
    then published on the channel's default exchange, using *queue_name* as
    the routing key.
    """
    target_exchange = channel.default_exchange
    encoded_body = bytes(json.dumps(payload), "utf-8")
    outgoing = aio_pika.Message(
        body=encoded_body,
        content_type="application/json",
        delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
    )
    await target_exchange.publish(outgoing, routing_key=queue_name)


async def safe_get_message_text(msg: "aio_pika.IncomingMessage") -> dict:
    """Decode a message body into a JSON object without raising on bad input.

    The body is decoded as UTF-8, with undecodable bytes replaced by U+FFFD,
    and then parsed with ``json.loads``. The parsed dict is returned only when
    the top-level JSON value is an object. Otherwise the result is
    ``{"_raw": <decoded text>}``. That covers:

    - malformed JSON;
    - JSON nested too deeply to parse;
    - a top-level list, string, number, boolean or ``null``.

    Callers can therefore always treat the result as a ``dict``.
    """
    raw = msg.body.decode("utf-8", errors="replace")
    try:
        parsed = json.loads(raw)
    except (json.JSONDecodeError, RecursionError):
        # RecursionError: pathologically nested input (e.g. "[[[[...") from the wire.
        return {"_raw": raw}
    # json.loads also yields lists/scalars/None; honour the ``-> dict`` contract.
    if isinstance(parsed, dict):
        return parsed
    return {"_raw": raw}

