import json
from contextlib import contextmanager
from typing import Any

import pymysql
from pymysql.cursors import DictCursor


class MySQLRepository:
    """Thin persistence layer over MySQL for call ingestion, chunks and scores.

    Every public method opens a fresh connection via :meth:`conn`, runs a
    single statement and closes the connection again. Connections are created
    with ``autocommit=True``, so each write is committed as soon as it runs,
    and with ``DictCursor``, so read methods return rows as dicts keyed by
    column name.
    """

    def __init__(self, host: str, port: int, user: str, password: str, db: str):
        """Store the connection settings used by every subsequent call.

        No connection is opened here; one is created per operation by
        :meth:`conn`.
        """
        self.cfg = {
            'host': host,
            'port': port,
            'user': user,
            'password': password,
            'database': db,
            'cursorclass': DictCursor,
            'autocommit': True,
            'charset': 'utf8mb4',
        }

    @contextmanager
    def conn(self):
        """Open a new connection from ``self.cfg`` and close it on exit.

        Yields:
            A pymysql connection; it is closed even if the ``with`` body raises.
        """
        c = pymysql.connect(**self.cfg)
        try:
            yield c
        finally:
            c.close()

    def _execute(self, sql: str, params: tuple) -> None:
        """Run one write statement on a fresh (autocommitting) connection."""
        # Using the cursor as a context manager closes it deterministically
        # instead of leaving it for the garbage collector.
        with self.conn() as c, c.cursor() as cur:
            cur.execute(sql, params)

    def _fetch_all(self, sql: str, params: tuple) -> list[dict[str, Any]]:
        """Run one query on a fresh connection and return every row as a list."""
        with self.conn() as c, c.cursor() as cur:
            cur.execute(sql, params)
            # fetchall() is not guaranteed to return a list (it may be a
            # tuple, e.g. for an empty result); normalise to match the hint.
            return list(cur.fetchall())

    def save_call_ingestion(self, tenant_id: str, call_id: str, agent_id: str, chunk_count: int, tags: list[str]) -> None:
        """Insert or update the summary row for an ingested call.

        On a duplicate key, ``agent_id``, ``chunk_count`` and ``tags`` are
        overwritten and ``updated_at`` is set to the current timestamp.

        Args:
            tenant_id: Tenant owning the call.
            call_id: Call identifier.
            agent_id: Agent who handled the call.
            chunk_count: Number of transcript chunks ingested.
            tags: Tags stored as a single comma-joined string.
        """
        # NOTE(review): a tag that itself contains ',' cannot be split back
        # out of the stored string unambiguously.
        self._execute(
            """
            INSERT INTO calls (tenant_id, call_id, agent_id, chunk_count, tags)
            VALUES (%s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                agent_id = VALUES(agent_id),
                chunk_count = VALUES(chunk_count),
                tags = VALUES(tags),
                updated_at = CURRENT_TIMESTAMP
            """,
            (tenant_id, call_id, agent_id, chunk_count, ','.join(tags)),
        )

    def get_call_chunks_text(self, tenant_id: str, call_id: str) -> list[dict[str, Any]]:
        """Return all stored chunks of a call, ordered by ``chunk_index``.

        Returns:
            A list of dicts with keys ``chunk_id``, ``speaker``, ``text`` and
            ``timestamp``; empty if the call has no chunks.
        """
        return self._fetch_all(
            """
            SELECT chunk_id, speaker, text, timestamp
            FROM call_chunks
            WHERE tenant_id = %s AND call_id = %s
            ORDER BY chunk_index
            """,
            (tenant_id, call_id),
        )

    def save_chunk_metadata(self, tenant_id: str, call_id: str, agent_id: str, chunk_index: int, chunk_id: str, speaker: str, text: str, timestamp: Any) -> None:
        """Insert or update one transcript chunk of a call.

        On a duplicate key only ``speaker``, ``text`` and ``timestamp`` are
        overwritten; the other columns keep their existing values.
        """
        # NOTE(review): which columns form the unique key is defined by the
        # schema, not here — confirm it matches the intended upsert identity.
        self._execute(
            """
            INSERT INTO call_chunks (tenant_id, call_id, agent_id, chunk_index, chunk_id, speaker, text, timestamp)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                speaker = VALUES(speaker),
                text = VALUES(text),
                timestamp = VALUES(timestamp)
            """,
            (tenant_id, call_id, agent_id, chunk_index, chunk_id, speaker, text, timestamp),
        )

    def save_score(self, tenant_id: str, call_id: str, agent_id: str, score_json: dict) -> None:
        """Insert or update the score document for a call.

        ``score_json`` is serialized with :func:`json.dumps`, so the stored
        value is valid JSON. On a duplicate key the document is replaced and
        ``updated_at`` is set to the current timestamp.

        Raises:
            TypeError: if ``score_json`` contains values JSON cannot encode.
        """
        # str(dict) produced a Python repr (single quotes, True/None), which
        # is not JSON and cannot be parsed back with json.loads.
        payload = json.dumps(score_json)
        # NOTE(review): unlike save_call_ingestion, agent_id is not refreshed
        # on a duplicate key — confirm that is intended.
        self._execute(
            """
            INSERT INTO call_scores (tenant_id, call_id, agent_id, score_json)
            VALUES (%s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                score_json = VALUES(score_json),
                updated_at = CURRENT_TIMESTAMP
            """,
            (tenant_id, call_id, agent_id, payload),
        )

    def get_agent_latest_scores(self, tenant_id: str, agent_id: str, limit: int = 50) -> list[dict[str, Any]]:
        """Return up to ``limit`` of an agent's scores, most recently updated first.

        Returns:
            A list of dicts, each with a single ``score_json`` key holding the
            stored (unparsed) value.
        """
        return self._fetch_all(
            """
            SELECT score_json
            FROM call_scores
            WHERE tenant_id = %s AND agent_id = %s
            ORDER BY updated_at DESC
            LIMIT %s
            """,
            (tenant_id, agent_id, limit),
        )
