import json
from typing import Optional, Callable, Awaitable
from fastapi import WebSocket
from fastapi.websockets import WebSocketDisconnect

from services.log_utils import Log


class ConnectionState:
    """
    Per-connection session data for a single Mcube call.

    Attributes:
        stream_id: Identifier of the active Mcube media stream; ``None`` until
            one has been recorded.
        call_id: Identifier of the active Mcube call; ``None`` until one has
            been recorded.

    All audio buffering/playback bookkeeping is owned by AudioService. The two
    hook methods below are therefore deliberate no-ops, kept so that existing
    callers can keep invoking them. Note that neither hook touches
    ``stream_id`` or ``call_id``.
    """

    def __init__(self):
        # Session identifiers start out unknown.
        self.stream_id: Optional[str] = None
        self.call_id: Optional[str] = None

    def reset_stream_state(self) -> None:
        """Hook called when a new stream begins; intentionally does nothing.

        Stream-scoped audio state lives in AudioService, so there is nothing
        left to reset on this object.
        """

    def clear_response_state(self) -> None:
        """Hook called on an interruption; intentionally does nothing.

        Response-scoped audio state lives in AudioService, so there is nothing
        left to clear on this object.
        """


class WebSocketConnectionManager:
    """
    Orchestrates WebSocket communication between Mcube and ElevenLabs for Service Type 4.

    - Holds the Mcube WebSocket (FastAPI) for the call and closes it on request.
    - Routes incoming messages to the handlers for ``media``, ``start`` and
      ``playedStream`` events; any other event is ignored.
    - Sends messages to Mcube, tracks stream/call IDs in a ConnectionState, and
      builds checkpoint/clear messages via AudioService.

    This is the main interface for real-time, bidirectional communication between Mcube and ElevenLabs.
    """

    # Lower-cased substrings identifying errors that mean "the socket is already
    # closed / was never accepted". These are expected around teardown and are
    # swallowed instead of being logged as failures.
    _DISCONNECT_ERROR_MARKERS = ("not connected", "closed", "accept")

    def __init__(self, mcube_ws: WebSocket):
        self.mcube_ws = mcube_ws
        self.state = ConnectionState()

    @classmethod
    def _is_disconnect_error(cls, exc: Exception) -> bool:
        """Return True if *exc*'s message matches an expected disconnect marker."""
        error_str = str(exc).lower()
        return any(marker in error_str for marker in cls._DISCONNECT_ERROR_MARKERS)

    def _is_connected(self) -> bool:
        """Return True while the Mcube socket's client state is CONNECTED."""
        return self.mcube_ws.client_state.name == 'CONNECTED'

    async def send_to_mcube(self, message: dict) -> None:
        """Send *message* to Mcube as JSON.

        The send is skipped when the socket's client state is not CONNECTED.
        Errors whose text looks like "not connected / closed / accept" are
        swallowed silently. Any other error is logged as a warning together
        with the message and the traceback. Exceptions from the send are not
        propagated to the caller.
        """
        try:
            # Skip sends once the socket has left the CONNECTED state (e.g. during teardown).
            if not self._is_connected():
                return
            await self.mcube_ws.send_json(message)
        except Exception as e:
            if self._is_disconnect_error(e):
                # Peer already gone: expected, nothing to report.
                return
            import traceback
            Log.warning(f"⚠️ Error sending message to Mcube: {e}")
            Log.warning(f"Message: {message}")
            Log.warning(f"Full traceback: {traceback.format_exc()}")

    @staticmethod
    def _parse_message(message: str) -> Optional[dict]:
        """Decode one text frame; return the JSON object, or None (with a warning) if it is not one."""
        try:
            data = json.loads(message)
        except json.JSONDecodeError as e:
            # Truncate: frames may carry large media payloads.
            Log.warning(f"⚠️ Ignoring non-JSON message from Mcube: {e} (payload: {message[:200]!r})")
            return None
        if not isinstance(data, dict):
            Log.warning(f"⚠️ Ignoring non-object JSON message from Mcube: {type(data).__name__}")
            return None
        return data

    async def _handle_start(
        self,
        data: dict,
        on_start: Callable[[dict], Awaitable[None]],
    ) -> None:
        """Record stream/call IDs from a ``start`` event, log them, then call *on_start* with the full payload."""
        start_info = data.get('start')
        # Tolerate a missing, null or non-object "start" field instead of crashing the loop.
        if not isinstance(start_info, dict):
            start_info = {}
        stream_id = start_info.get('streamId')
        call_id = start_info.get('callId')
        self.state.stream_id = stream_id
        self.state.call_id = call_id

        # Log the callId if available, otherwise log the incoming start payload for debugging
        if call_id:
            Log.event("Mcube Start", {
                "streamId": stream_id,
                "callId": call_id
            })
        else:
            try:
                Log.event("Mcube Start (no callId)", start_info)
            except Exception:
                Log.error("Mcube start payload (no callId) and failed to serialize start_info.")

        self.state.reset_stream_state()
        # Pass the full data to the handler so it can extract business ID and other metadata
        await on_start(data)

    async def receive_from_mcube(
        self,
        on_media: Callable[[dict], Awaitable[None]],
        on_start: Callable[[dict], Awaitable[None]],  # Receives the full start message
        on_played_stream: Callable[[str], Awaitable[None]]
    ) -> None:
        """
        Receive messages from Mcube and route them to appropriate handlers.

        Frames that are not a JSON object are skipped with a warning, so one bad
        frame no longer ends the session. Messages with a missing or unknown
        ``event`` are ignored. The loop ends on disconnect. It also ends when a
        handler raises: expected-disconnect errors are swallowed and all other
        errors are logged.

        Args:
            on_media: Handler for media events (receives the full message)
            on_start: Handler for call start events (receives the full message)
            on_played_stream: Handler for playedStream events (receives the ``name`` field, if non-empty)
        """
        try:
            # Check WebSocket state before iterating
            if not self._is_connected():
                return

            async for message in self.mcube_ws.iter_text():
                # Check state again inside loop (connection might close during iteration)
                if not self._is_connected():
                    break

                data = self._parse_message(message)
                if data is None:
                    continue

                event = data.get('event')
                if event == 'media':
                    await on_media(data)
                elif event == 'start':
                    await self._handle_start(data, on_start)
                elif event == 'playedStream':
                    name = data.get('name')
                    if name:
                        await on_played_stream(name)

        except WebSocketDisconnect:
            Log.info("Mcube WebSocket disconnected (expected)")
        except Exception as e:
            # "not connected"/"accept"/"closed" errors are an expected teardown race; stay quiet.
            if not self._is_disconnect_error(e):
                Log.error(f"Error in receive_from_mcube: {e}")

    async def send_checkpoint_to_mcube(self, name: str) -> None:
        """Send a checkpoint event named *name* to Mcube; no-op until a stream ID is known."""
        if self.state.stream_id:
            # Import here to avoid circular imports
            from services.audio_service import AudioService
            audio_service = AudioService()
            checkpoint_event = audio_service.create_checkpoint_message(self.state.stream_id, name)
            await self.send_to_mcube(checkpoint_event)

    async def clear_mcube_audio(self) -> None:
        """Send a clear-audio event to Mcube; no-op until a stream ID is known."""
        if self.state.stream_id:
            # Import here to avoid circular imports
            from services.audio_service import AudioService
            audio_service = AudioService()
            clear_event = audio_service.create_clear_message(self.state.stream_id)
            await self.send_to_mcube(clear_event)

    async def close_mcube_connection(self, code: int = 1000, reason: Optional[str] = None) -> None:
        """Close the Mcube WebSocket connection gracefully.

        This ends the call on Mcube's side. Failures are logged, not raised.

        Args:
            code: WebSocket close code (1000 = normal closure).
            reason: Close reason text; defaults to "normal closure".
        """
        try:
            await self.mcube_ws.close(code=code, reason=reason or "normal closure")
            Log.info("Closed Mcube WebSocket connection")
        except Exception as e:
            Log.error(f"Failed to close Mcube WebSocket: {e}")
