from __future__ import annotations

import asyncio
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, Awaitable
from dataclasses import dataclass
from types import TracebackType
from typing import Generic, Literal, TypeVar

from pydantic import BaseModel, ConfigDict, Field

from livekit import rtc

from ..types import NOT_GIVEN, NotGivenOr
from .chat_context import ChatContext, FunctionCall
from .tool_context import Tool, ToolChoice, ToolContext


@dataclass
class InputSpeechStartedEvent:
    """Field-less event signalling that input speech has started.

    NOTE(review): presumably the payload emitted with the
    ``"input_speech_started"`` event name — confirm against the emitters.
    """


@dataclass
class InputSpeechStoppedEvent:
    """Event signalling that input speech has stopped.

    NOTE(review): presumably the payload emitted with the
    ``"input_speech_stopped"`` event name — confirm against the emitters.
    """

    user_transcription_enabled: bool
    """whether user transcription is enabled (exact meaning is set by the emitter — TODO confirm)"""


@dataclass
class MessageGeneration:
    """One generated message, exposed as asynchronous text and audio streams."""

    message_id: str
    """id of the generated message"""
    text_stream: AsyncIterable[str]
    """asynchronous stream of text (could be io.TimedString)"""
    audio_stream: AsyncIterable[rtc.AudioFrame]
    """asynchronous stream of audio frames"""
    modalities: Awaitable[list[Literal["text", "audio"]]]
    """awaitable resolving to a list containing "text" and/or "audio" """


@dataclass
class GenerationCreatedEvent:
    """Event carrying the streams produced by one generation.

    ``RealtimeSession.generate_reply`` returns a future of this type.
    """

    message_stream: AsyncIterable[MessageGeneration]
    """asynchronous stream of the generated messages"""
    function_stream: AsyncIterable[FunctionCall]
    """asynchronous stream of the generated function calls"""
    user_initiated: bool
    """True if the message was generated by the user using generate_reply()"""
    response_id: str | None = None
    """The response ID associated with this generation, used for metrics attribution"""


class RealtimeModelError(BaseModel):
    """Pydantic model describing an error reported for a realtime model.

    NOTE(review): presumably the payload emitted with the ``"error"`` event
    name listed in ``EventTypes`` — confirm against the emitting code.
    """

    # Required so pydantic accepts the plain ``Exception`` annotation on ``error``.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Constant tag: the only accepted value is also the default.
    type: Literal["realtime_model_error"] = "realtime_model_error"
    # NOTE(review): presumably seconds since the epoch (time.time()) — confirm.
    timestamp: float
    # NOTE(review): presumably the ``RealtimeModel.label`` of the failing model — confirm.
    label: str
    # The underlying exception; required, and excluded from serialized output.
    error: Exception = Field(..., exclude=True)
    # NOTE(review): presumably whether the session can keep going after this error — confirm.
    recoverable: bool


@dataclass
class RealtimeCapabilities:
    """Boolean feature flags describing a realtime model.

    An instance is passed to ``RealtimeModel.__init__`` and exposed through
    the ``RealtimeModel.capabilities`` property.

    NOTE(review): the per-flag descriptions below are inferred from the field
    names; the visible part of this module never reads the flags — confirm
    the exact semantics with the code that consumes them.
    """

    message_truncation: bool
    """presumably: generated messages can be truncated (see RealtimeSession.truncate)"""
    turn_detection: bool
    """presumably: the model performs turn detection itself"""
    user_transcription: bool
    """presumably: the model can transcribe the user's input audio"""
    auto_tool_reply_generation: bool
    """presumably: a reply is generated automatically after a tool result"""
    audio_output: bool
    """presumably: the model can produce audio output"""
    manual_function_calls: bool
    """presumably: function calls can be supplied manually"""


class RealtimeError(Exception):
    """Exception type for realtime-session failures.

    ``RealtimeSession.update_chat_ctx`` and ``RealtimeSession.generate_reply``
    are annotated in this module as able to raise it on timeout.
    """

    def __init__(self, message: str) -> None:
        """Create the error with a single, mandatory ``message``."""
        super().__init__(message)


class RealtimeModel:
    """Base class for realtime models.

    Holds the model's capabilities and a label, offers overridable ``model``
    and ``provider`` names, and acts as an async context manager that closes
    the model on exit.

    Note: the class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers on ``session`` and ``aclose`` are not
    enforced when an instance is created; subclasses are nevertheless
    expected to override both.
    """

    def __init__(self, *, capabilities: RealtimeCapabilities) -> None:
        """Store ``capabilities`` and derive the label from the concrete class."""
        concrete = type(self)
        self._label = f"{concrete.__module__}.{concrete.__name__}"
        self._capabilities = capabilities

    @property
    def model(self) -> str:
        """Model name; ``"unknown"`` unless a subclass overrides it."""
        return "unknown"

    @property
    def provider(self) -> str:
        """Provider name; ``"unknown"`` unless a subclass overrides it."""
        return "unknown"

    @property
    def capabilities(self) -> RealtimeCapabilities:
        """The capabilities object given to the constructor."""
        return self._capabilities

    @property
    def label(self) -> str:
        """``"<module>.<class name>"`` of the instance's concrete type."""
        return self._label

    @abstractmethod
    def session(self) -> RealtimeSession:
        """Return a ``RealtimeSession`` for this model (to be overridden)."""

    @abstractmethod
    async def aclose(self) -> None:
        """Close the model (to be overridden)."""

    async def __aenter__(self) -> RealtimeModel:
        """Enter the async context and hand back the model itself."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Await ``aclose()`` on exit; any in-flight exception is not suppressed."""
        await self.aclose()


# Names of the built-in events. RealtimeSession's rtc.EventEmitter base is
# typed with this Literal, unioned with the TEvent type parameter.
EventTypes = Literal[
    "input_speech_started",  # serverside VAD (also used for interruptions)
    "input_speech_stopped",  # serverside VAD
    "input_audio_transcription_completed",
    "generation_created",
    "session_reconnected",
    "metrics_collected",
    "error",
]

# Type parameter of RealtimeSession; it is unioned with EventTypes in the
# session's rtc.EventEmitter base.
TEvent = TypeVar("TEvent")


@dataclass
class InputTranscriptionCompleted:
    """Transcription result for a piece of input audio.

    NOTE(review): presumably the payload emitted with the
    ``"input_audio_transcription_completed"`` event name — confirm against
    the emitters.
    """

    item_id: str
    """id of the item"""
    transcript: str
    """transcript of the input audio"""
    is_final: bool
    """whether the transcript is final (as opposed to an interim result — TODO confirm)"""


@dataclass
class RealtimeSessionReconnectedEvent:
    """Field-less event signalling that a realtime session has reconnected.

    NOTE(review): presumably the payload emitted with the
    ``"session_reconnected"`` event name — confirm against the emitters.
    """


class RealtimeSession(ABC, rtc.EventEmitter[EventTypes | TEvent], Generic[TEvent]):
    """Abstract interface of a session opened on a ``RealtimeModel``.

    The session is an ``rtc.EventEmitter`` whose event type is ``EventTypes``
    widened by the ``TEvent`` type parameter. Every method except
    ``start_user_activity`` is abstract and must be provided by subclasses.
    """

    def __init__(self, realtime_model: RealtimeModel) -> None:
        """Initialise the event emitter and remember the owning model."""
        super().__init__()
        self._realtime_model = realtime_model

    @property
    def realtime_model(self) -> RealtimeModel:
        """The model passed to the constructor."""
        return self._realtime_model

    @property
    @abstractmethod
    def chat_ctx(self) -> ChatContext:
        """Chat context of the session."""

    @property
    @abstractmethod
    def tools(self) -> ToolContext:
        """Tool context of the session."""

    @abstractmethod
    async def update_instructions(self, instructions: str) -> None:
        """Update the session's instructions."""

    @abstractmethod
    async def update_chat_ctx(self, chat_ctx: ChatContext) -> None:
        """Update the session's chat context.

        Raises:
            RealtimeError: can be raised on timeout.
        """

    @abstractmethod
    async def update_tools(self, tools: list[Tool]) -> None:
        """Update the session's tools."""

    @abstractmethod
    def update_options(self, *, tool_choice: NotGivenOr[ToolChoice | None] = NOT_GIVEN) -> None:
        """Update session options.

        ``tool_choice`` defaults to the ``NOT_GIVEN`` sentinel, which is
        distinct from an explicit ``None``.
        """

    @abstractmethod
    def push_audio(self, frame: rtc.AudioFrame) -> None:
        """Push one audio frame to the session."""

    @abstractmethod
    def push_video(self, frame: rtc.VideoFrame) -> None:
        """Push one video frame to the session."""

    @abstractmethod
    def generate_reply(
        self,
        *,
        instructions: NotGivenOr[str] = NOT_GIVEN,
    ) -> asyncio.Future[GenerationCreatedEvent]:
        """Request a reply and return a future of the resulting generation event.

        Raises:
            RealtimeError: can be raised on timeout.
        """

    @abstractmethod
    def commit_audio(self) -> None:
        """Commit the input audio buffer to the server."""

    @abstractmethod
    def clear_audio(self) -> None:
        """Clear the input audio buffer to the server."""

    @abstractmethod
    def interrupt(self) -> None:
        """Cancel the current generation (do nothing if no generation is in progress)."""

    @abstractmethod
    def truncate(
        self,
        *,
        message_id: str,
        modalities: list[Literal["text", "audio"]],
        audio_end_ms: int,
        audio_transcript: NotGivenOr[str] = NOT_GIVEN,
    ) -> None:
        """Truncate a message.

        ``message_id`` is the ID of the message to truncate (inside the
        ChatCtx).

        NOTE(review): ``audio_end_ms`` is presumably the audio position, in
        milliseconds, at which the message is cut — confirm with
        implementations.
        """

    @abstractmethod
    async def aclose(self) -> None:
        """Close the session."""

    def start_user_activity(self) -> None:
        """Notify the model that user activity has started.

        The base implementation does nothing.
        """
