"""
ElevenLabs ConvAI Server
Provides WebSocket endpoint for real-time voice conversations with ElevenLabs
"""

import asyncio
import json
import os
import logging
from datetime import datetime
from typing import Dict, Set
from aiohttp import web, WSMsgType
import aiohttp

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("elevenlabs_server")

# --- Configuration -----------------------------------------------------------
# ElevenLabs API key taken from the environment (None when the variable is unset).
ELEVENLABS_API_KEY = os.environ.get("ELEVENLABS_API_KEY")
# Base URL shared by every ConvAI endpoint this server talks to.
ELEVENLABS_CONVAI_URL = "https://api.elevenlabs.io/v1/convai"

# --- Process-local state (nothing here survives a restart) -------------------
active_sessions: Dict[str, dict] = dict()   # session id -> status/timestamps record
session_metrics: Dict[str, list] = dict()   # session id -> recorded upstream messages
metrics_history: list = list()              # recent /metrics reports

# --- Pricing -----------------------------------------------------------------
# Editable rate card. TTS rates are USD per 1,000 characters; the STT rate is
# USD per minute. LLM usage is billed by ElevenLabs itself, so no local rates.
ELEVENLABS_PRICING = {
    "tts": dict(
        eleven_monolingual=0.03,    # USD / 1K characters
        eleven_multilingual=0.06,   # USD / 1K characters
    ),
    "stt": dict(
        ultra=0.04,                 # USD / minute
    ),
    "llm": dict(),
}

# Running cost/usage totals, keyed by session id.
session_costs: Dict[str, dict] = dict()


class ElevenLabsConvAIHandler:
    """Relay client WebSocket sessions to the ElevenLabs ConvAI WebSocket.

    Each accepted client gets its own upstream connection. Text and binary
    frames are forwarded in both directions. JSON text frames arriving from
    ElevenLabs are also recorded in ``session_metrics`` and folded into the
    per-session totals kept in ``session_costs``.
    """

    def __init__(self):
        # Client sockets currently being relayed.
        self.active_connections: Set[web.WebSocketResponse] = set()
        # Counter used to mint "session_<n>" ids; restarts at 0 with the process.
        self.session_count = 0

    @staticmethod
    def _redact(text: str) -> str:
        """Return *text* with the ElevenLabs API key masked out.

        The upstream URL carries the key as a query parameter, and aiohttp
        handshake errors include the request URL in their message. Without
        this, the key would reach the logs and the ``error`` field that
        /metrics publishes. Only the literal key text is replaced.
        """
        if ELEVENLABS_API_KEY:
            return text.replace(ELEVENLABS_API_KEY, "***")
        return text

    @staticmethod
    def _track_message(session_id: str, raw: str) -> None:
        """Record one upstream text frame and update the session's usage totals.

        Frames that are not JSON objects are skipped. This is best-effort
        bookkeeping: a payload with an unexpected shape is logged and never
        interrupts the relay.
        """
        try:
            data = json.loads(raw)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            return  # not JSON: relay it untouched, record nothing
        if not isinstance(data, dict):
            return

        try:
            # NOTE(review): every JSON frame is retained for the life of the
            # process, so memory grows with traffic; consider capping this.
            session_metrics.setdefault(session_id, []).append({
                "timestamp": datetime.now().isoformat(),
                "type": data.get("type", "unknown"),
                "data": data
            })

            costs = session_costs[session_id]

            # Usage may arrive as a dedicated "usage" frame or embedded in another frame.
            if data.get("type") == "usage" or "usage" in data:
                usage = data.get("usage") or {}

                stt_chars = usage.get("character_count", 0)
                if stt_chars > 0:
                    costs["stt_characters"] += stt_chars
                    # NOTE(review): the STT rate is documented as $/minute but is
                    # applied per 1K characters here; the units disagree, so
                    # confirm against ElevenLabs' usage payload.
                    costs["stt_cost"] += (stt_chars / 1000) * ELEVENLABS_PRICING["stt"]["ultra"]

                tts_chars = usage.get("tts_character_count", 0)
                if tts_chars > 0:
                    costs["tts_characters"] += tts_chars
                    costs["tts_cost"] += (tts_chars / 1000) * ELEVENLABS_PRICING["tts"]["eleven_multilingual"]

                # Tokens are counted only; no LLM rate is configured locally.
                llm_tokens = usage.get("llm_token_count", 0)
                if llm_tokens > 0:
                    costs["llm_tokens"] += llm_tokens

                costs["total_cost"] = costs["stt_cost"] + costs["tts_cost"] + costs["llm_cost"]

            if data.get("type") in ("agent_response", "turn_end"):
                costs["turn_count"] += 1
        except Exception:
            logger.warning("Could not record metrics for %s", session_id, exc_info=True)

    @staticmethod
    async def _relay_until_first_exit(*relays) -> None:
        """Run the relay coroutines concurrently and stop all of them when one finishes.

        ``asyncio.gather`` would keep the surviving direction running after its
        peer hung up. For example, the upstream socket would stay open after the
        client left. If the relay that finished raised, its exception is re-raised.
        """
        tasks = [asyncio.create_task(relay) for relay in relays]
        try:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel()  # no-op for tasks that already finished
            # Let cancelled relays unwind; this also marks their exceptions as retrieved.
            await asyncio.gather(*tasks, return_exceptions=True)
        for task in done:
            task.result()

    async def handle_websocket(self, request: web.Request) -> web.WebSocketResponse:
        """Serve one client WebSocket by bridging it to ElevenLabs ConvAI.

        The session is registered in ``active_sessions`` and ``session_costs``,
        then the upstream socket is opened. Frames are relayed until either side
        closes or fails. On exit the session is marked ``disconnected`` (or left
        as ``error`` if the relay failed), its cost totals are attached to the
        session record, and the client socket is closed.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        session_id = f"session_{self.session_count}"
        self.session_count += 1
        active_sessions[session_id] = {
            "started_at": datetime.now().isoformat(),
            "status": "connecting"
        }

        # Zeroed running totals, updated by _track_message as usage frames arrive.
        session_costs[session_id] = {
            "stt_cost": 0,
            "tts_cost": 0,
            "llm_cost": 0,
            "total_cost": 0,
            "turn_count": 0,
            "stt_characters": 0,
            "tts_characters": 0,
            "llm_tokens": 0,
        }

        self.active_connections.add(ws)
        logger.info(f"New WebSocket connection: {session_id}")

        # The API key travels in the query string, so any text derived from this
        # URL (such as exception messages) goes through _redact before it is exposed.
        # NOTE(review): when ELEVENLABS_API_KEY is unset this sends the literal "None".
        elevenlabs_ws_url = f"{ELEVENLABS_CONVAI_URL}/ws?api_key={ELEVENLABS_API_KEY}"

        try:
            async with aiohttp.ClientSession() as session:
                async with session.ws_connect(elevenlabs_ws_url) as elevenlabs_ws:
                    active_sessions[session_id]["status"] = "connected"

                    async def forward_to_elevenlabs():
                        """Relay client frames upstream until the client stops sending."""
                        async for msg in ws:
                            if msg.type == WSMsgType.TEXT:
                                await elevenlabs_ws.send_str(msg.data)
                            elif msg.type == WSMsgType.BINARY:
                                await elevenlabs_ws.send_bytes(msg.data)
                            elif msg.type == WSMsgType.ERROR:
                                break

                    async def forward_to_client():
                        """Relay upstream frames to the client, recording text frames."""
                        async for msg in elevenlabs_ws:
                            if msg.type == WSMsgType.TEXT:
                                self._track_message(session_id, msg.data)
                                await ws.send_str(msg.data)
                            elif msg.type == WSMsgType.BINARY:
                                await ws.send_bytes(msg.data)
                            elif msg.type == WSMsgType.ERROR:
                                break

                    await self._relay_until_first_exit(
                        forward_to_elevenlabs(),
                        forward_to_client(),
                    )

        except Exception as e:
            error_text = self._redact(str(e))
            logger.error(f"WebSocket error: {error_text}")
            active_sessions[session_id]["status"] = "error"
            active_sessions[session_id]["error"] = error_text
        finally:
            self.active_connections.discard(ws)
            session_info = active_sessions.get(session_id)
            if session_info is not None:
                # Keep an "error" status visible in /metrics instead of overwriting it.
                if session_info.get("status") != "error":
                    session_info["status"] = "disconnected"
                session_info["ended_at"] = datetime.now().isoformat()
                # Attach the final cost totals to the session record.
                session_info["costs"] = session_costs.get(session_id, {})
            logger.info(f"WebSocket closed: {session_id}")
            # If upstream went away first, tell the client instead of leaving it hanging.
            if not ws.closed:
                await ws.close()

        return ws


async def health_check(request: web.Request) -> web.Response:
    """Liveness probe: always answers ``ok`` along with the server's local time."""
    payload = {"status": "ok"}
    payload["timestamp"] = datetime.now().isoformat()
    return web.json_response(payload)


async def metrics_handler(request: web.Request) -> web.Response:
    """Return session counts, average latency and a cost breakdown as JSON.

    ``average_end_to_end_latency_ms`` averages every numeric top-level
    ``latency`` field in the recorded upstream messages; it is 0 when there are
    none. Cost totals are summed from ``session_costs``. Per-turn averages are
    0 until at least one turn has been counted. Each report is also appended
    to ``metrics_history``, which keeps only the last 100 entries.
    """
    current_time = datetime.now()

    # Latency samples come from upstream payloads. A missing or non-numeric
    # value is skipped rather than turning this endpoint into a 500.
    latencies = []
    for metrics in session_metrics.values():
        for metric in metrics:
            data = metric.get("data")
            latency = data.get("latency") if isinstance(data, dict) else None
            if isinstance(latency, (int, float)) and not isinstance(latency, bool):
                latencies.append(latency)
    avg_latency = sum(latencies) / len(latencies) if latencies else 0

    # Build grand totals and the per-session breakdown in a single pass.
    total_stt_cost = 0
    total_tts_cost = 0
    total_llm_cost = 0
    total_turns = 0
    session_cost_breakdown = {}
    for session_id, cost_data in session_costs.items():
        stt_cost = cost_data.get("stt_cost", 0)
        tts_cost = cost_data.get("tts_cost", 0)
        llm_cost = cost_data.get("llm_cost", 0)
        turn_count = cost_data.get("turn_count", 0)

        total_stt_cost += stt_cost
        total_tts_cost += tts_cost
        total_llm_cost += llm_cost
        total_turns += turn_count

        session_cost_breakdown[session_id] = {
            "stt_cost": round(stt_cost, 6),
            "tts_cost": round(tts_cost, 6),
            "llm_cost": round(llm_cost, 6),
            "total_cost": round(cost_data.get("total_cost", 0), 6),
            "turn_count": turn_count,
            "stt_characters": cost_data.get("stt_characters", 0),
            "tts_characters": cost_data.get("tts_characters", 0),
        }

    total_cost = total_stt_cost + total_tts_cost + total_llm_cost

    def per_turn(amount):
        """Average *amount* over all counted turns (0 before the first turn)."""
        return round(amount / total_turns, 6) if total_turns > 0 else 0

    metrics_data = {
        "active_sessions": len([s for s in active_sessions.values() if s.get("status") == "connected"]),
        "total_sessions": len(active_sessions),
        "total_turns": total_turns,
        "average_end_to_end_latency_ms": round(avg_latency, 2),
        "timestamp": current_time.isoformat(),
        "cost_breakdown": {
            "total": {
                "stt": round(total_stt_cost, 6),
                "tts": round(total_tts_cost, 6),
                "llm": round(total_llm_cost, 6),
                "total": round(total_cost, 6),
                "currency": "USD"
            },
            "per_turn_average": {
                "stt": per_turn(total_stt_cost),
                "tts": per_turn(total_tts_cost),
                "llm": per_turn(total_llm_cost),
                "total": per_turn(total_cost),
            }
        },
        "sessions": active_sessions,
        "session_costs": session_cost_breakdown
    }

    # Keep history (last 100 entries).
    # NOTE(review): "sessions" is the live active_sessions dict, not a copy, so
    # older history entries show current session state; snapshot it if the
    # history is meant to be read.
    metrics_history.append(metrics_data)
    if len(metrics_history) > 100:
        metrics_history.pop(0)

    return web.json_response(metrics_data)


async def list_llms(request: web.Request) -> web.Response:
    """Proxy ``GET {ELEVENLABS_CONVAI_URL}/llm/list`` and relay its JSON result.

    The upstream HTTP status is passed through, so callers can tell an auth or
    rate-limit failure from a success. Transport errors and non-JSON upstream
    bodies are reported as ``{"error": ...}`` with status 500.
    """
    url = f"{ELEVENLABS_CONVAI_URL}/llm/list"
    # The key is sent only when configured; without it the upstream decides how to respond.
    headers = {"xi-api-key": ELEVENLABS_API_KEY} if ELEVENLABS_API_KEY else {}

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers) as response:
                data = await response.json()
                return web.json_response(data, status=response.status)
    except Exception as e:
        logger.exception(f"Error fetching LLM list: {e}")
        return web.json_response({"error": str(e)}, status=500)


async def calculate_llm_usage(request: web.Request) -> web.Response:
    """Proxy ``POST {ELEVENLABS_CONVAI_URL}/llm-usage/calculate``.

    The client's JSON body is forwarded as-is; an empty or unparseable body is
    sent as ``{}``. The upstream status and JSON result are relayed. Transport
    errors and non-JSON upstream bodies become ``{"error": ...}`` with status 500.
    """
    try:
        data = await request.json()
    except (ValueError, LookupError):
        # ValueError covers json.JSONDecodeError (including an empty body) and
        # UnicodeDecodeError; LookupError covers an unknown charset. Other
        # failures, such as an oversized body, are left for aiohttp to report.
        data = {}

    url = f"{ELEVENLABS_CONVAI_URL}/llm-usage/calculate"
    headers = {"xi-api-key": ELEVENLABS_API_KEY} if ELEVENLABS_API_KEY else {}

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=data, headers=headers) as response:
                result = await response.json()
                return web.json_response(result, status=response.status)
    except Exception as e:
        logger.exception(f"Error calculating LLM usage: {e}")
        return web.json_response({"error": str(e)}, status=500)


async def index_handler(request: web.Request) -> web.Response:
    """Describe the service: its name, version, route map and run state."""
    endpoints = dict(
        websocket="/ws/",
        health="/health",
        metrics="/metrics",
        list_llms="/llm/list",
        calculate_usage="/llm-usage/calculate",
    )
    info = dict(
        name="ElevenLabs ConvAI Server",
        version="1.0.0",
        endpoints=endpoints,
        status="running",
    )
    return web.json_response(info)


def create_app() -> web.Application:
    """Build the aiohttp application with every HTTP and WebSocket route registered."""
    application = web.Application()
    convai = ElevenLabsConvAIHandler()

    # Route table, registered in the same order as the endpoint list served at "/".
    application.add_routes([
        web.get('/', index_handler),
        web.get('/health', health_check),
        web.get('/metrics', metrics_handler),
        web.get('/llm/list', list_llms),
        web.post('/llm-usage/calculate', calculate_llm_usage),
        web.get('/ws/', convai.handle_websocket),
    ])
    return application


if __name__ == '__main__':
    # The listening port comes from $PORT and falls back to 8080; bind on all interfaces.
    listen_port = int(os.environ.get('PORT', '8080'))
    logger.info(f"Starting ElevenLabs ConvAI Server on port {listen_port}")
    web.run_app(create_app(), host='0.0.0.0', port=listen_port)
