"""
Test script for ElevenLabs ConvAI Server
Tests WebSocket connection, latency, STT, LLM, TTS
"""

import asyncio
import aiohttp
import time
import json
import os
from datetime import datetime

# Configuration: each value is read from the environment once, at import
# time, and falls back to a local-development default when unset.
SERVER_URL = os.environ.get("SERVER_URL", "http://localhost:8080")
WS_URL = os.environ.get("WS_URL", "ws://localhost:8080")
ELEVENLABS_API_KEY = os.environ.get("ELEVENLABS_API_KEY", "")


async def test_health_check():
    """Check the ``/health`` endpoint.

    Issues a GET to ``{SERVER_URL}/health``, prints the HTTP status and the
    decoded JSON body, and passes only when the body's ``"status"`` field
    equals ``"ok"``.

    Returns:
        True when the health payload reports ``"ok"``; False otherwise,
        including when the request or JSON decoding raises.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("TEST 1: Health Check")
    print(banner)

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{SERVER_URL}/health") as response:
                payload = await response.json()
                print(f"Status: {response.status}")
                print(f"Response: {payload}")

                healthy = payload.get("status") == "ok"
                verdict = "✓ Health check PASSED" if healthy else "✗ Health check FAILED"
                print(verdict)
                return healthy
    except Exception as exc:
        # Any failure (connection, decoding, unexpected payload shape)
        # is reported as a failed check rather than propagated.
        print(f"✗ Health check FAILED: {exc}")
        return False


def _print_cost_section(costs):
    """Print the STT/TTS/LLM/Total cost lines for one cost mapping."""
    for label, key in (("STT", "stt"), ("TTS", "tts"), ("LLM", "llm"), ("Total", "total")):
        # `or 0` lets a missing key or a JSON null format as 0 instead of
        # raising TypeError inside the float format spec.
        print(f"  {label}: ${costs.get(key) or 0:.6f}")


async def test_metrics():
    """Test metrics endpoint.

    Issues a GET to ``{SERVER_URL}/metrics`` and prints the session
    counters, the average end-to-end latency and, when the response carries
    a non-empty ``cost_breakdown``, the total and per-turn cost figures.

    Returns:
        True when the endpoint answers with a 2xx status and a JSON body
        that can be printed; False on any other status or on any exception.
    """
    print("\n" + "="*60)
    print("TEST 2: Metrics Endpoint")
    print("="*60)

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{SERVER_URL}/metrics") as response:
                print(f"Status: {response.status}")
                # An error response with a JSON body must not count as a pass.
                if not 200 <= response.status < 300:
                    print(f"✗ Metrics check FAILED: unexpected HTTP status {response.status}")
                    return False

                data = await response.json()
                print(f"Active Sessions: {data.get('active_sessions')}")
                print(f"Total Sessions: {data.get('total_sessions')}")
                print(f"Total Turns: {data.get('total_turns')}")
                print(f"Avg Latency: {data.get('average_end_to_end_latency_ms')}ms")

                cost_breakdown = data.get("cost_breakdown", {})
                if cost_breakdown:
                    print("\nCost Breakdown:")
                    # `or {}` tolerates a section that is present but null.
                    _print_cost_section(cost_breakdown.get("total") or {})

                    print("\nPer Turn Average:")
                    _print_cost_section(cost_breakdown.get("per_turn_average") or {})

                print("\n✓ Metrics check PASSED")
                return True
    except Exception as e:
        print(f"✗ Metrics check FAILED: {e}")
        return False


async def test_llm_list():
    """Test LLM list endpoint.

    Issues a GET to ``{SERVER_URL}/llm/list`` and pretty-prints the JSON
    body.

    Returns:
        True when the endpoint answers with a 2xx status and a JSON body;
        False on any other status or on any exception.
    """
    print("\n" + "="*60)
    print("TEST 3: LLM List")
    print("="*60)

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{SERVER_URL}/llm/list") as response:
                print(f"Status: {response.status}")
                # An error response with a JSON body must not count as a pass.
                if not 200 <= response.status < 300:
                    print(f"✗ LLM list check FAILED: unexpected HTTP status {response.status}")
                    return False

                data = await response.json()
                print(f"Response: {json.dumps(data, indent=2)}")
                print("\n✓ LLM list check PASSED")
                return True
    except Exception as e:
        print(f"✗ LLM list check FAILED: {e}")
        return False


async def test_websocket_conversation():
    """Test WebSocket conversation with latency measurement.

    Connects to ``{WS_URL}/ws/``, sends a single ``conversation_message``
    and reads text frames until ``max_responses`` have arrived, the server
    closes the socket, an error frame is received, or no frame arrives
    within ``receive_timeout`` seconds. Each text frame's elapsed time
    since the send is bucketed by substrings of its ``"type"`` field; the
    elapsed time of the last text frame is recorded as the total.

    Returns:
        True when at least one text response was received and no error
        frame was seen; False otherwise, including on any exception.
    """
    print("\n" + "="*60)
    print("TEST 4: WebSocket Conversation")
    print("="*60)

    latencies = {
        "stt": [],
        "llm": [],
        "tts": [],
        "total": []
    }
    max_responses = 10
    # Upper bound on the wait for any single frame, so a server that goes
    # quiet without closing the socket cannot hang the whole suite.
    receive_timeout = 30.0

    try:
        async with aiohttp.ClientSession() as session:
            # The context manager closes the socket even if the body raises.
            async with session.ws_connect(f"{WS_URL}/ws/") as ws:
                # Send a test message
                test_message = "Hello, how are you?"
                print(f"Sending: {test_message}")

                # perf_counter is monotonic, so wall-clock adjustments
                # cannot distort the measured intervals.
                start_time = time.perf_counter()

                await ws.send_str(json.dumps({
                    "type": "conversation_message",
                    "message": test_message
                }))

                response_count = 0
                last_elapsed = None
                failure = None

                while response_count < max_responses:
                    try:
                        msg = await asyncio.wait_for(ws.receive(), timeout=receive_timeout)
                    except asyncio.TimeoutError:
                        print(f"No message within {receive_timeout:.0f}s, stopping")
                        break

                    if msg.type == aiohttp.WSMsgType.TEXT:
                        data = json.loads(msg.data)
                        response_type = data.get("type", "unknown")

                        elapsed = (time.perf_counter() - start_time) * 1000  # ms
                        last_elapsed = elapsed

                        print(f"Received: {response_type} ({elapsed:.0f}ms)")

                        # Track latency based on response type
                        if "transcript" in response_type or "stt" in response_type:
                            latencies["stt"].append(elapsed)
                        elif "llm" in response_type or "model" in response_type:
                            latencies["llm"].append(elapsed)
                        elif "audio" in response_type or "tts" in response_type:
                            latencies["tts"].append(elapsed)

                        response_count += 1
                    elif msg.type == aiohttp.WSMsgType.ERROR:
                        failure = f"WebSocket error: {ws.exception()}"
                        break
                    elif msg.type in (
                        aiohttp.WSMsgType.CLOSE,
                        aiohttp.WSMsgType.CLOSING,
                        aiohttp.WSMsgType.CLOSED,
                    ):
                        # Server ended the conversation.
                        break
                    # Other frame types (e.g. binary) are skipped and do
                    # not count towards max_responses.

            # Time from send to the final text response of the turn.
            if last_elapsed is not None:
                latencies["total"].append(last_elapsed)

            # Calculate averages
            print("\nLatency Results:")
            for key, values in latencies.items():
                if values:
                    avg = sum(values) / len(values)
                    print(f"  {key.upper()}: {avg:.0f}ms (avg of {len(values)} responses)")
                else:
                    print(f"  {key.upper()}: No data")

            if failure is not None:
                print(f"\n✗ WebSocket conversation FAILED: {failure}")
                return False
            if response_count == 0:
                print("\n✗ WebSocket conversation FAILED: no responses received")
                return False

            print("\n✓ WebSocket conversation PASSED")
            return True

    except Exception as e:
        print(f"✗ WebSocket conversation FAILED: {e}")
        return False


async def test_llm_usage_calculate():
    """Test LLM usage calculation endpoint.

    POSTs a fixed example usage payload as JSON to
    ``{SERVER_URL}/llm-usage/calculate`` and pretty-prints the JSON body of
    the response.

    Returns:
        True when the endpoint answers with a 2xx status and a JSON body;
        False on any other status or on any exception.
    """
    print("\n" + "="*60)
    print("TEST 5: LLM Usage Calculate")
    print("="*60)

    try:
        async with aiohttp.ClientSession() as session:
            # Example usage data
            usage_data = {
                "character_count": 1000,
                "tts_character_count": 500,
                "llm_token_count": 100
            }

            async with session.post(
                f"{SERVER_URL}/llm-usage/calculate",
                json=usage_data
            ) as response:
                print(f"Status: {response.status}")
                # An error response with a JSON body must not count as a pass.
                if not 200 <= response.status < 300:
                    print(f"✗ LLM usage calculation FAILED: unexpected HTTP status {response.status}")
                    return False

                data = await response.json()
                print(f"Response: {json.dumps(data, indent=2)}")
                print("\n✓ LLM usage calculation PASSED")
                return True
    except Exception as e:
        print(f"✗ LLM usage calculation FAILED: {e}")
        return False


async def run_all_tests():
    """Run every test in order and print a pass/fail summary.

    The tests are awaited one after another, not concurrently, so their
    output does not interleave.

    Returns:
        True when no test reported a failure, False otherwise.
    """
    divider = "=" * 60

    print("\n" + divider)
    print("ELEVENLABS CONVAI SERVER TEST SUITE")
    print(divider)
    print(f"Server URL: {SERVER_URL}")
    print(f"WebSocket URL: {WS_URL}")
    print(f"Time: {datetime.now().isoformat()}")

    # (display name, coroutine function) in execution order.
    suite = (
        ("Health Check", test_health_check),
        ("Metrics", test_metrics),
        ("LLM List", test_llm_list),
        ("WebSocket Conversation", test_websocket_conversation),
        ("LLM Usage Calculate", test_llm_usage_calculate),
    )
    outcomes = [(label, await run_test()) for label, run_test in suite]

    print("\n" + divider)
    print("TEST SUMMARY")
    print(divider)

    for label, ok in outcomes:
        print(f"{label}: {'✓ PASSED' if ok else '✗ FAILED'}")

    passed = sum(1 for _, ok in outcomes if ok)
    failed = len(outcomes) - passed

    print(f"\nTotal: {passed} passed, {failed} failed")

    return failed == 0


if __name__ == "__main__":
    success = asyncio.run(run_all_tests())
    exit(0 if success else 1)
