#!/usr/bin/env python3
from __future__ import annotations

import socket
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from urllib.parse import unquote, urlparse


@dataclass(frozen=True)
class DbTarget:
    """One database URL to be checked, labelled with the .env key it came from."""

    # .env key the URL was read from (e.g. "DATABASE_MASTER_URL").
    name: str
    # Raw connection URL exactly as found in .env; may contain a password.
    url: str


@dataclass(frozen=True)
class ParsedMysqlUrl:
    """Decoded components of a MySQL connection URL.

    ``password`` is excluded from ``repr()`` so that printing or logging an
    instance (for example in a traceback) does not expose the credential.
    Equality still compares every field, password included.
    """

    # Hostname as reported by urlparse (IPv6 literals without brackets).
    host: str
    # TCP port to connect to.
    port: int
    # Login user name (percent-decoded).
    user: str
    # Login password (percent-decoded); hidden from repr() on purpose.
    password: str = field(repr=False)
    # Database name taken from the URL path.
    db: str


def _read_dotenv(path: Path) -> dict[str, str]:
    out: dict[str, str] = {}
    for raw in path.read_text(encoding="utf-8").splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        if "=" not in line:
            continue
        k, v = line.split("=", 1)
        out[k.strip()] = v.strip().strip('"').strip("'")
    return out


def _parse_mysql_url(url: str) -> ParsedMysqlUrl:
    """Split a MySQL connection URL into host, port, user, password and database.

    Accepted schemes are ``mysql`` and SQLAlchemy-style ``mysql+<driver>``
    (e.g. ``mysql+pymysql``, ``mysql+aiomysql``); the driver part is ignored
    because only the connection parameters are needed here. The port defaults
    to 3306. User, password and database name are percent-decoded.

    Raises:
        ValueError: unsupported scheme, missing host, missing user or
            database name, or an invalid port (raised by ``urlparse().port``).
    """
    p = urlparse(url)
    dialect, plus, driver = p.scheme.partition("+")
    # Reject anything that is not "mysql" or "mysql+<non-empty driver>".
    if dialect != "mysql" or (plus and not driver):
        raise ValueError(f"unsupported scheme {p.scheme!r}")

    host = p.hostname or ""
    if not host:
        raise ValueError("missing host")
    port = int(p.port or 3306)

    user = unquote(p.username or "")
    password = unquote(p.password or "")
    # Path is "/<db>"; tolerate a trailing slash and decode %-escapes like the credentials.
    db = unquote((p.path or "").strip("/"))
    if not user or not db:
        raise ValueError("missing user or database name")

    return ParsedMysqlUrl(host=host, port=port, user=user, password=password, db=db)


def _mask_db_url(url: str) -> str:
    """Return a printable form of *url* with the password replaced by ``***``.

    The URL is validated with ``_parse_mysql_url`` first; anything it rejects
    is rendered as ``"<unparseable>"`` instead of being echoed, so a malformed
    URL cannot leak raw credentials into the output. The original scheme is
    kept, the effective port is always shown, and any query string or fragment
    is omitted.
    """
    try:
        cfg = _parse_mysql_url(url)
        scheme = urlparse(url).scheme
    except Exception:
        # Best-effort display helper: never let formatting abort the diagnostics.
        return "<unparseable>"
    pw_part = ":***" if cfg.password else ""
    # urlparse strips the brackets from IPv6 literals; put them back so the
    # host and port don't run together (e.g. "[::1]:3306", not "::1:3306").
    host = f"[{cfg.host}]" if ":" in cfg.host else cfg.host
    return f"{scheme}://{cfg.user}{pw_part}@{host}:{cfg.port}/{cfg.db}"


def _resolve(host: str) -> list[str]:
    infos = socket.getaddrinfo(host, None, proto=socket.IPPROTO_TCP)
    ips: list[str] = []
    for info in infos:
        sockaddr = info[4]
        ip = sockaddr[0]
        if ip not in ips:
            ips.append(ip)
    return ips


def _tcp_probe(host: str, port: int, timeout_s: float = 5.0) -> Optional[str]:
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.settimeout(timeout_s)
    try:
        s.connect((host, port))
        return None
    except Exception as e:
        return f"{type(e).__name__}: {e}"
    finally:
        try:
            s.close()
        except Exception:
            pass


def _mysql_probe_optional(cfg: ParsedMysqlUrl, timeout_s: int = 5) -> Optional[str]:
    """
    Optional MySQL auth probe.
    Runs only if `pymysql` is available in this Python environment.
    """
    try:
        import pymysql  # type: ignore
    except Exception:
        return "SKIP (pymysql not installed in this python)"

    try:
        conn = pymysql.connect(
            host=cfg.host,
            port=cfg.port,
            user=cfg.user,
            password=cfg.password,
            database=cfg.db,
            connect_timeout=timeout_s,
            read_timeout=timeout_s,
            write_timeout=timeout_s,
            cursorclass=pymysql.cursors.DictCursor,
        )
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT 1 AS ok")
                _ = cur.fetchone()
        finally:
            conn.close()
        return None
    except Exception as e:
        return f"{type(e).__name__}: {e}"


def _check_target(t: DbTarget) -> bool:
    """Run the parse → DNS → TCP → MySQL checks for one target, printing each result.

    Stops at the first failing stage, because every later stage depends on the
    earlier ones. A skipped MySQL probe (pymysql unavailable) is reported but
    does not count as a failure.

    Returns:
        True if any check failed, False otherwise.
    """
    print("=" * 80)
    print(f"{t.name} = {_mask_db_url(t.url)}")

    try:
        cfg = _parse_mysql_url(t.url)
    except Exception as e:
        print(f"PARSE: FAIL ({type(e).__name__}: {e})")
        return True

    # 1) DNS
    try:
        ips = _resolve(cfg.host)
    except Exception as e:
        print(f"DNS: FAIL ({type(e).__name__}: {e})")
        return True
    print(f"DNS: OK ({', '.join(ips)})")

    # 2) Raw TCP connect (tells us if firewall/routing blocks the port)
    tcp_err = _tcp_probe(cfg.host, cfg.port, timeout_s=5.0)
    if tcp_err:
        # If TCP fails, MySQL will fail too; no point continuing.
        print(f"TCP: FAIL ({cfg.host}:{cfg.port}) ({tcp_err})")
        return True
    print(f"TCP: OK ({cfg.host}:{cfg.port})")

    # 3) MySQL auth + simple query
    mysql_err = _mysql_probe_optional(cfg, timeout_s=5)
    if not mysql_err:
        print("MYSQL: OK (connected + SELECT 1)")
        return False
    if mysql_err.startswith("SKIP"):
        # The probe did not run (pymysql missing); that is not a connectivity failure.
        print(f"MYSQL: {mysql_err}")
        return False
    print(f"MYSQL: FAIL ({mysql_err})")
    return True


def main() -> int:
    """Check every database URL configured in the backend .env file.

    The .env file is looked up in the parent of the directory that contains
    this script. ``DATABASE_MASTER_URL`` and ``DATABASE_CLUSTER_URL`` are
    checked, whichever are present and non-empty.

    Returns:
        0 if every target passed, 1 if any check failed, 2 if the .env file
        or both URL keys are missing.
    """
    backend_dir = Path(__file__).resolve().parents[1]
    env_path = backend_dir / ".env"
    if not env_path.exists():
        print(f"ERROR: .env not found at {env_path}", file=sys.stderr)
        return 2

    env = _read_dotenv(env_path)

    targets: list[DbTarget] = []
    for key in ("DATABASE_MASTER_URL", "DATABASE_CLUSTER_URL"):
        url = env.get(key)
        if url:
            targets.append(DbTarget(name=key, url=url))

    if not targets:
        print("ERROR: No DATABASE_MASTER_URL / DATABASE_CLUSTER_URL found in backend/.env", file=sys.stderr)
        return 2

    any_fail = False
    # Plain loop (not any()) so every target is checked even after a failure.
    for t in targets:
        if _check_target(t):
            any_fail = True

    print("=" * 80)
    return 1 if any_fail else 0


if __name__ == "__main__":
    # Propagate main()'s return value (0 / 1 / 2) as the process exit status.
    sys.exit(main())