import asyncio
import logging
import os
from typing import Dict

from aiohttp import ClientError, ClientSession, WSMsgType, web

from .env_load import load_agent_runtime_dotenv

load_agent_runtime_dotenv()

# Module logger; records are emitted under the "mcube.proxy" name.
log = logging.getLogger("mcube.proxy")

# Upstream targets, each overridable through the environment.
WEBHOOK_UPSTREAM = os.environ.get("MCUBE_WEBHOOK_UPSTREAM", "http://127.0.0.1:8002")
WS_UPSTREAM_HOST = os.environ.get("MCUBE_WS_UPSTREAM_HOST", "127.0.0.1")
WS_UPSTREAM_PORT = int(os.environ.get("MCUBE_WS_UPSTREAM_PORT", "9001"))

# Single local listening port for both HTTP and WebSocket traffic, so that one
# ngrok tunnel is enough.
PROXY_PORT = int(os.environ.get("MCUBE_PROXY_PORT", "8088"))


# Headers that describe a single connection hop and must not be copied across
# the proxy (lower-cased for case-insensitive lookup).
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)

# "host" is derived from the upstream URL and "content-length" is recomputed by
# the HTTP client from the buffered body.
_SKIP_REQUEST_HEADERS = _HOP_BY_HOP_HEADERS | {"host", "content-length"}

# The HTTP client hands back an already-decoded body, so the upstream's
# "content-encoding"/"content-length" no longer describe the bytes we return.
_SKIP_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {"content-encoding", "content-length"}


async def _proxy_http(request: web.Request) -> web.StreamResponse:
    """
    Forward an HTTP request to the webhook server upstream and relay its reply.

    The request method, path + query, end-to-end headers and body are sent to
    ``WEBHOOK_UPSTREAM``; redirects are not followed, so 3xx replies are passed
    back to the caller as-is. The upstream status, end-to-end headers and body
    are returned.

    Registered by ``create_app`` for the /webhooks/... and /api/mcube/... routes.

    Returns a 502 response when the upstream cannot be reached or the exchange
    fails at the HTTP-client level.
    """
    upstream_url = f"{WEBHOOK_UPSTREAM}{request.rel_url}"

    # The whole body is buffered; fine for webhook-sized payloads.
    body = await request.read()

    headers: Dict[str, str] = {
        name: value
        for name, value in request.headers.items()
        if name.lower() not in _SKIP_REQUEST_HEADERS
    }

    try:
        async with ClientSession() as session:
            async with session.request(
                request.method,
                upstream_url,
                headers=headers,
                data=body if body else None,
                allow_redirects=False,
            ) as resp:
                resp_body = await resp.read()
                # Keep (name, value) pairs rather than a dict so repeated
                # headers such as Set-Cookie are all preserved.
                resp_headers = [
                    (name, value)
                    for name, value in resp.headers.items()
                    if name.lower() not in _SKIP_RESPONSE_HEADERS
                ]
                return web.Response(
                    status=resp.status,
                    headers=resp_headers,
                    body=resp_body,
                )
    except (ClientError, asyncio.TimeoutError) as exc:
        log.warning(
            "proxy http upstream failed method=%s url=%s error=%r",
            request.method,
            upstream_url,
            exc,
        )
        return web.Response(status=502, text="Bad Gateway: upstream unavailable")


async def _pump_ws(source, sink, direction: str) -> None:
    """
    Copy TEXT and BINARY frames from ``source`` to ``sink``, then close ``sink``.

    Stops when ``source`` stops yielding messages, when a CLOSE message is seen,
    or when a send to ``sink`` fails. ``direction`` is only used in log lines.
    """
    try:
        async for msg in source:
            try:
                if msg.type == WSMsgType.TEXT:
                    await sink.send_str(msg.data)
                elif msg.type == WSMsgType.BINARY:
                    await sink.send_bytes(msg.data)
                elif msg.type == WSMsgType.CLOSE:
                    return
            except Exception as exc:
                # The receiving side may have closed first; stop proxying
                # quietly instead of crashing the handler.
                log.debug("proxy ws send failed direction=%s error=%r", direction, exc)
                return
    finally:
        # Whatever ended this direction, make sure the other side is closed so
        # the opposite pump is released too.
        try:
            await sink.close()
        except Exception as exc:
            log.debug("proxy ws close failed direction=%s error=%r", direction, exc)


async def _proxy_ws(request: web.Request) -> web.StreamResponse:
    """
    Proxy WebSocket traffic from the public URL to the MCube WS bridge.

    Completes the WebSocket handshake with the caller, opens a client
    WebSocket to ``ws://WS_UPSTREAM_HOST:WS_UPSTREAM_PORT`` using the same
    path + query, and relays TEXT/BINARY frames in both directions until
    either side closes.

    Expects paths like: /bid/websocket/{callId}

    If the upstream connection cannot be established, the caller's socket is
    closed with code 1011 (internal error) and the failure is logged.
    """
    ws_server = web.WebSocketResponse()
    await ws_server.prepare(request)

    # Preserve the full path/query when proxying.
    rel_url = str(request.rel_url)  # includes path + query
    upstream_ws_url = f"ws://{WS_UPSTREAM_HOST}:{WS_UPSTREAM_PORT}{rel_url}"
    log.info("proxy ws connect path=%s upstream=%s", rel_url, upstream_ws_url)

    async with ClientSession() as session:
        try:
            ws_upstream = await session.ws_connect(upstream_ws_url)
        except (ClientError, asyncio.TimeoutError) as exc:
            log.warning(
                "proxy ws upstream connect failed upstream=%s error=%r",
                upstream_ws_url,
                exc,
            )
            await ws_server.close(code=1011, message=b"upstream unavailable")
            return ws_server

        try:
            outcomes = await asyncio.gather(
                _pump_ws(ws_server, ws_upstream, "client->upstream"),
                _pump_ws(ws_upstream, ws_server, "upstream->client"),
                return_exceptions=True,
            )
        finally:
            await ws_upstream.close()

        # gather() captured pump failures instead of raising; surface them in
        # the log so they are not lost.
        for outcome in outcomes:
            if isinstance(outcome, BaseException):
                log.warning("proxy ws pump failed path=%s error=%r", rel_url, outcome)

    await ws_server.close()
    return ws_server


def create_app() -> web.Application:
    """
    Build the aiohttp application with every proxied HTTP and WebSocket route.

    HTTP routes accept any method and go to ``_proxy_http``; WebSocket routes
    accept GET only and go to ``_proxy_ws``.
    """
    http_paths = (
        "/webhooks/mcube",
        "/webhooks/{tail:.*}",
        "/api/mcube/outbound-call",
        "/api/mcube/{tail:.*}",
    )

    # MCube environments vary in their WebSocket path layout:
    #   /bid/websocket/{callId}
    #   /ws/{callId}
    #   /<business_id>/websocket/<call_id>   e.g. /8028/websocket/test_cb_01
    ws_bases = (
        "/bid/websocket/{call_id}",
        "/ws/{call_id}",
        "/{business_id}/websocket/{call_id}",
    )

    application = web.Application()
    router = application.router

    for path in http_paths:
        router.add_route("*", path, _proxy_http)

    # Each base is registered bare and with a wildcard tail, in that order.
    for base in ws_bases:
        router.add_route("GET", base, _proxy_ws)
        router.add_route("GET", base + "/{tail:.*}", _proxy_ws)

    return application


async def main() -> None:
    """
    Start the proxy on ``PROXY_PORT`` (all interfaces) and serve until cancelled.

    Configures INFO-level logging, builds the app via ``create_app`` and blocks
    forever. The runner is cleaned up on the way out (e.g. when ``asyncio.run``
    cancels this coroutine on Ctrl-C) so listening sockets and open
    connections are released.
    """
    logging.basicConfig(level=logging.INFO)
    runner = web.AppRunner(create_app())
    await runner.setup()
    try:
        site = web.TCPSite(runner, host="0.0.0.0", port=PROXY_PORT)
        await site.start()
        log.info("mcube proxy listening on :%s", PROXY_PORT)

        # Run forever: this event is never set, so only cancellation ends the wait.
        await asyncio.Event().wait()
    finally:
        await runner.cleanup()


# Script entry point: run the proxy's main() coroutine in a fresh event loop
# when this module is executed directly rather than imported.
if __name__ == "__main__":
    asyncio.run(main())

