import base64
import io
from typing import Optional, Dict, Any
from dataclasses import dataclass
from config import Config
from services.log_utils import Log


@dataclass
class AudioMetadata:
    """
    Annotation record attached to a single chunk of audio in the pipeline.

    Every field is optional and defaults to None, so a producer can fill in
    only the pieces it knows (timing, owning item, stream, payload, format).
    """
    # Position of the chunk on the audio timeline (int; units are whatever the
    # producer uses — presumably milliseconds, TODO confirm).
    timestamp: Optional[int] = None
    # Identifier of the item the chunk belongs to.
    item_id: Optional[str] = None
    # Identifier of the stream carrying the chunk.
    stream_id: Optional[str] = None
    # Audio data for the chunk, held as a string.
    payload: Optional[str] = None
    # Label describing the payload's audio format.
    format_type: Optional[str] = None


class AudioFormatConverter:
    """
    Converts audio payloads for Mcube format.
    Ensures compatibility and provides a single place to update format logic if requirements change.
    """

    # MIME content type used for audio sent to MCube.
    MCUBE_FORMAT = "audio/x-mulaw"

    @staticmethod
    def wav_to_mulaw(wav_bytes: bytes) -> bytes:
        """
        Convert WAV audio bytes to μ-law format for MCube.

        The PCM frames are μ-law encoded at the WAV file's own sample rate; no
        resampling happens here. Stereo input is down-mixed to mono, and 8-bit
        input (stored unsigned in WAV) is re-centred to signed before encoding.

        Args:
            wav_bytes: Raw WAV audio data, including the RIFF header.

        Returns:
            μ-law encoded audio bytes. If the data cannot be parsed as WAV, or the
            encoder module cannot be imported, a warning is logged and the
            original ``wav_bytes`` are returned unchanged (best-effort behaviour
            kept for existing callers).
        """
        try:
            import wave
            import audioop  # NOTE: deprecated in 3.11, removed in Python 3.13

            # Read every header field while the reader is still open; the
            # context manager guarantees it is closed even if parsing fails.
            with wave.open(io.BytesIO(wav_bytes), 'rb') as wav_io:
                sample_width = wav_io.getsampwidth()
                channels = wav_io.getnchannels()
                frames = wav_io.readframes(wav_io.getnframes())

            # WAV stores 8-bit samples as unsigned, but audioop treats width-1
            # samples as signed; shift them so silence maps to 0.
            if sample_width == 1:
                frames = audioop.bias(frames, 1, -128)

            # Interleaved stereo would otherwise be encoded as twice as many
            # mono samples; average the two channels into one.
            if channels == 2:
                frames = audioop.tomono(frames, sample_width, 0.5, 0.5)

            return audioop.lin2ulaw(frames, sample_width)

        except Exception as e:
            # Keep the historical fallback, but make the failure visible.
            Log.warning(f"⚠️ WAV to μ-law conversion failed, returning original bytes: {e}")
            return wav_bytes

    @staticmethod
    def validate_audio_payload(payload: str) -> bool:
        """
        Validate that an audio payload is properly formatted base64.

        Characters outside the base64 alphabet and incorrect padding are
        rejected. An empty string decodes to empty bytes and counts as valid.

        Args:
            payload: Audio payload to validate (str or bytes).

        Returns:
            True if payload is valid base64, False otherwise (including None
            or other non-string input).
        """
        try:
            # Without validate=True, b64decode silently discards non-alphabet
            # characters, so garbage such as "!!!!" would be reported as valid.
            base64.b64decode(payload, validate=True)
        except (ValueError, TypeError):
            # binascii.Error subclasses ValueError; TypeError covers None and
            # other objects that are neither str nor bytes-like.
            return False
        return True


class AudioTimingManager:
    """
    Keeps the clock for outgoing audio so responses can be timed and cut short.

    It remembers the most recent audio timestamp it was handed, the timestamp
    at which the current response began, and that response's item ID. The
    difference between the first two is the elapsed response time used when
    handling interruptions.
    """

    def __init__(self):
        self.current_timestamp: int = 0
        self.response_start_timestamp: Optional[int] = None
        self.last_item_id: Optional[str] = None

    def update_timestamp(self, timestamp: int) -> None:
        """Record *timestamp* as the latest known audio position."""
        self.current_timestamp = timestamp

    def start_response_tracking(self, item_id: str) -> None:
        """
        Begin timing the response identified by *item_id*.

        The response is considered to start at the current timestamp; a debug
        line is printed when ``Config.SHOW_TIMING_MATH`` is enabled.

        Args:
            item_id: ID of the response item that is now being tracked.
        """
        now = self.current_timestamp
        self.response_start_timestamp = now
        self.last_item_id = item_id

        if not Config.SHOW_TIMING_MATH:
            return
        print(f"Starting response tracking for item {item_id} at {now}ms")

    def calculate_response_duration(self) -> Optional[int]:
        """
        Compute how long the tracked response has been running.

        Returns:
            ``current_timestamp - response_start_timestamp`` (labelled as
            milliseconds in the debug output), or None if nothing is tracked.
        """
        start = self.response_start_timestamp
        if start is None:
            return None

        now = self.current_timestamp
        elapsed = now - start

        if Config.SHOW_TIMING_MATH:
            print(f"Response duration: {now} - {start} = {elapsed}ms")

        return elapsed

    def reset_response_tracking(self) -> None:
        """Forget the tracked response (start timestamp and item ID)."""
        self.response_start_timestamp = self.last_item_id = None

    def should_item_be_tracked(self, item_id: str) -> bool:
        """
        Decide whether *item_id* is a new item that tracking should switch to.

        Args:
            item_id: ID of the candidate item.

        Returns:
            True when *item_id* differs from the currently tracked item ID.
        """
        is_same_item = item_id == self.last_item_id
        return not is_same_item

    def get_current_item_id(self) -> Optional[str]:
        """Return the ID of the tracked item, or None when nothing is tracked."""
        return self.last_item_id


class AudioBufferManager:
    """
    Holds pending synchronization marks and buffered audio chunks.

    Marks are handled first-in, first-out: ``add_mark`` appends to the back
    and ``remove_mark`` takes from the front. Each buffered audio chunk is a
    dict with ``chunk``, ``metadata`` and ``timestamp`` keys.
    """

    def __init__(self):
        self.mark_queue: list = []
        self.audio_buffer: list = []

    def add_mark(self, mark_name: str = "responsePart") -> None:
        """
        Queue a synchronization mark behind any already pending.

        Args:
            mark_name: Label identifying the mark.
        """
        self.mark_queue.append(mark_name)

    def remove_mark(self) -> Optional[str]:
        """
        Dequeue the oldest pending mark.

        Returns:
            The mark that was first in line, or None when the queue is empty.
        """
        if not self.mark_queue:
            return None
        return self.mark_queue.pop(0)

    def clear_marks(self) -> None:
        """Drop every pending mark."""
        self.mark_queue.clear()

    def has_pending_marks(self) -> bool:
        """Return True while at least one mark is still queued."""
        return bool(self.mark_queue)

    def add_audio_chunk(self, chunk: str, metadata: AudioMetadata) -> None:
        """
        Buffer *chunk* together with its metadata.

        The metadata's ``timestamp`` is copied into the entry alongside the
        metadata object itself.

        Args:
            chunk: The audio data to buffer.
            metadata: Metadata describing the chunk.
        """
        entry = dict(chunk=chunk, metadata=metadata, timestamp=metadata.timestamp)
        self.audio_buffer.append(entry)

    def clear_audio_buffer(self) -> None:
        """Discard every buffered audio chunk."""
        del self.audio_buffer[:]

    def get_buffer_size(self) -> int:
        """Return how many audio chunks are currently buffered."""
        return len(self.audio_buffer)


class AudioService:
    """
    Coordinates all audio processing operations for the application.
    Uses the format converter, timing manager, and buffer manager to process audio,
    manage synchronization, and handle interruptions for Service Type 4 (ElevenLabs).
    """

    def __init__(self):
        self.format_converter = AudioFormatConverter()
        self.timing_manager = AudioTimingManager()
        self.buffer_manager = AudioBufferManager()

    @staticmethod
    def _format_load_order(audio_bytes: bytes) -> tuple:
        """
        Return the pydub ``format`` hints to try for *audio_bytes*, most likely first.

        Data that begins with a RIFF/WAVE header (the Sarvam case) is tried as
        WAV first, which avoids a wasted and possibly mis-decoding MP3 attempt.
        Everything else keeps the MP3-first order (the ElevenLabs case).
        ``None`` (let pydub/ffmpeg probe the container) is always the last resort.
        """
        if audio_bytes[:4] == b"RIFF" and audio_bytes[8:12] == b"WAVE":
            return ("wav", "mp3", None)
        return ("mp3", "wav", None)

    def _load_audio_segment(self, audio_bytes: bytes):
        """
        Decode *audio_bytes* into a pydub ``AudioSegment``.

        Tries each format from ``_format_load_order`` in turn and returns the
        first successful decode.

        Returns:
            The decoded segment, or None (after logging every decoder error)
            when no format works.

        Raises:
            ImportError: if pydub is not installed.
        """
        from pydub import AudioSegment

        labels = {"mp3": "MP3", "wav": "WAV", None: "Auto"}
        errors = []
        for fmt in self._format_load_order(audio_bytes):
            try:
                # format=None lets pydub/ffmpeg auto-detect the container.
                return AudioSegment.from_file(io.BytesIO(audio_bytes), format=fmt)
            except Exception as exc:
                errors.append(f"{labels[fmt]} error: {exc}")

        Log.error(f"❌ Failed to load audio in any format. {', '.join(errors)}")
        return None

    def process_raw_audio_bytes(self, audio_bytes: bytes, stream_id: str) -> Optional[Dict[str, Any]]:
        """
        Process raw audio bytes (e.g., from Sarvam TTS) for MCube.

        The audio (MP3, WAV, or anything ffmpeg can probe) is decoded, resampled
        to ``Config.MCUBE_SAMPLE_RATE``, converted to mono 16-bit PCM, μ-law
        encoded, base64-encoded, and wrapped in a ``playAudio`` event.

        Args:
            audio_bytes: Raw audio data in bytes.
            stream_id: Mcube stream identifier. It is required, but it is not
                included in the returned message.

        Returns:
            Processed audio message for Mcube, or None if the input is empty,
            cannot be decoded, decodes to no audio, or any step fails (all such
            cases are logged).
        """
        if not audio_bytes:
            Log.warning(f"⚠️ No audio bytes provided for processing (stream_id: {stream_id})")
            return None

        if not stream_id:
            Log.warning(f"⚠️ No stream_id provided for audio processing (audio_bytes: {len(audio_bytes)} bytes)")
            return None

        try:
            import audioop  # NOTE: deprecated in 3.11, removed in Python 3.13
            import time

            pcm_audio = self._load_audio_segment(audio_bytes)
            if pcm_audio is None:
                return None

            # The target rate comes from configuration (e.g. 8000 or 16000 Hz μ-law).
            target_sample_rate = Config.MCUBE_SAMPLE_RATE
            pcm_bytes = (
                pcm_audio.set_frame_rate(target_sample_rate)
                .set_channels(1)
                .set_sample_width(2)
                .raw_data
            )
            if not pcm_bytes:
                # Nothing audible to send; an empty playAudio would be pointless.
                Log.warning(f"⚠️ Decoded audio is empty, nothing to send (stream_id: {stream_id})")
                return None

            # 16-bit signed PCM -> 8-bit μ-law, then base64 for the JSON message.
            mulaw_audio = audioop.lin2ulaw(pcm_bytes, 2)
            audio_base64 = base64.b64encode(mulaw_audio).decode('utf-8')

            return {
                "event": "playAudio",
                "media": {
                    "contentType": AudioFormatConverter.MCUBE_FORMAT,
                    "sampleRate": target_sample_rate,
                    "payload": audio_base64,
                    # Millisecond wall-clock name for tracking; two messages built
                    # within the same millisecond would share a name.
                    "name": f"audio_{int(time.time() * 1000)}",
                },
            }

        except Exception as e:
            import traceback
            Log.error(f"❌ Error processing raw audio bytes: {e}")
            Log.error(f"❌ Traceback: {traceback.format_exc()}")
            return None

    def create_checkpoint_message(self, stream_id: str, name: str) -> Dict[str, Any]:
        """
        Create a checkpoint message for audio synchronization.

        Side effect: *name* is queued as a pending mark on the buffer manager.

        Args:
            stream_id: Mcube stream identifier
            name: Name of the audio segment

        Returns:
            Mcube checkpoint message
        """
        self.buffer_manager.add_mark(name)
        return {
            "event": "checkpoint",
            "streamId": stream_id,
            "name": name
        }

    def create_clear_message(self, stream_id: str) -> Dict[str, Any]:
        """
        Create a clear message to clear audio buffer.

        Side effect: the local audio buffer is emptied. Pending marks are left untouched.

        Args:
            stream_id: Mcube stream identifier

        Returns:
            Mcube clearAudio message
        """
        self.buffer_manager.clear_audio_buffer()
        return {
            "event": "clearAudio",
            "streamId": stream_id
        }

    def handle_played_stream_event(self, name: str) -> None:
        """
        Handle a playedStream event from Mcube.

        The oldest pending mark is removed (FIFO) whatever its value; *name* is
        only used in the debug print.
        """
        removed_mark = self.buffer_manager.remove_mark()
        if Config.SHOW_TIMING_MATH and removed_mark:
            print(f"Processed playedStream: {name}")

    def calculate_interruption_timing(self) -> Optional[int]:
        """
        Calculate timing for audio interruption.

        Returns:
            Elapsed time for truncation, or None if no response is tracked
        """
        return self.timing_manager.calculate_response_duration()

    def should_handle_interruption(self) -> bool:
        """
        Determine if an interruption should be processed.

        Returns:
            True when a response item is tracked, it has a start timestamp, and
            at least one mark is still pending.
        """
        return (self.timing_manager.last_item_id is not None and
                self.buffer_manager.has_pending_marks() and
                self.timing_manager.response_start_timestamp is not None)

    def reset_interruption_state(self) -> None:
        """Reset all interruption-related state (response tracking and pending marks)."""
        self.timing_manager.reset_response_tracking()
        self.buffer_manager.clear_marks()

    def get_current_item_id(self) -> Optional[str]:
        """Get the ID of the currently tracked audio item."""
        return self.timing_manager.get_current_item_id()

    def _extract_mcube_payload(self, data: dict) -> Optional[str]:
        """Return ``data['media']['payload']``, or None if it is missing or *data* is malformed."""
        try:
            return data['media']['payload']
        except (KeyError, TypeError):
            return None

    def _extract_mcube_timestamp(self, data: dict) -> Optional[int]:
        """Return ``data['media']['timestamp']`` as an int, or None if it is missing or not numeric."""
        try:
            return int(data['media']['timestamp'])
        except (KeyError, TypeError, ValueError):
            return None
