import math
from dataclasses import dataclass

from app.repositories.qdrant_repository import QdrantRepository
from app.repositories.redis_cache import RedisCacheRepository
from app.services.embedding_service import EmbeddingService
from app.services.reranker_service import RerankerService


@dataclass
class RetrievedChunk:
    """One retrieved chunk together with its relevance score.

    Instances are built by ``RetrievalService.retrieve`` from search-hit
    payloads, and rebuilt from cached dicts via ``RetrievedChunk(**d)``, so
    the field names double as the keys of the cached representation.
    """

    # Identifier fields are copied from the hit payload by RetrievalService.
    chunk_id: str
    call_id: str
    agent_id: str
    speaker: str
    # Chunk text; this is what RetrievalService passes to the reranker.
    text: str
    # Relevance score; RetrievalService sorts results by it, highest first.
    score: float
    # The hit payload dict the other fields were extracted from.
    metadata: dict


class RetrievalService:
    """Vector retrieval with optional keyword blending, reranking and caching.

    ``retrieve`` embeds the query, searches the vector store, optionally mixes
    in a keyword-overlap score, blends the result with reranker scores and
    caches the final top-k list.
    """

    # Weights used to blend the first-stage score with the reranker score.
    FIRST_STAGE_WEIGHT = 0.6
    RERANK_WEIGHT = 0.4

    def __init__(
        self,
        qdrant_repo: QdrantRepository,
        emb_service: EmbeddingService,
        reranker_service: RerankerService,
        cache_repo: RedisCacheRepository,
        retrieval_cache_ttl: int,
        candidate_k: int,
        default_top_k: int,
        keyword_weight: float,
    ):
        """Store collaborators and tuning parameters.

        Args:
            qdrant_repo: Repository used for the vector search.
            emb_service: Service used to embed the query text.
            reranker_service: Service that scores (query, text) candidates.
            cache_repo: Cache used to store and look up retrieval results.
            retrieval_cache_ttl: TTL value handed to ``cache_repo.set_json``
                (units are defined by the cache repository).
            candidate_k: Minimum number of candidates requested from the
                vector search; the actual limit is ``max(candidate_k, top_k)``.
            default_top_k: Result size used when ``top_k`` is falsy.
            keyword_weight: Weight of the keyword score in the first-stage
                blend; the vector score gets ``1 - keyword_weight``.
        """
        self.qdrant = qdrant_repo
        self.emb = emb_service
        self.reranker = reranker_service
        self.cache = cache_repo
        self.retrieval_cache_ttl = retrieval_cache_ttl
        self.candidate_k = candidate_k
        self.default_top_k = default_top_k
        self.keyword_weight = keyword_weight

    @staticmethod
    def _keyword_score(query: str, text: str) -> float:
        """Return the fraction of query tokens that occur in *text*.

        The query is lower-cased and split on whitespace; each token is
        matched as a case-insensitive substring of *text*. Repeated tokens
        are counted each time they appear in the query.

        Returns:
            A value in ``[0.0, 1.0]``; ``0.0`` when the query has no tokens
            or *text* is empty or ``None``.
        """
        # str.split() with no argument never yields empty strings.
        q_tokens = query.lower().split()
        if not q_tokens or not text:
            return 0.0
        haystack = text.lower()
        hits = sum(1 for token in q_tokens if token in haystack)
        return hits / len(q_tokens)

    def _to_chunk(self, hit, query: str, use_keyword: bool) -> RetrievedChunk:
        """Convert one search hit into a ``RetrievedChunk`` with its first-stage score."""
        payload = hit.payload or {}
        # ``or ''`` also covers a payload that stores an explicit null text.
        text = payload.get('text') or ''
        score = float(hit.score)
        if use_keyword:
            kw = self._keyword_score(query, text)
            score = ((1.0 - self.keyword_weight) * score) + (self.keyword_weight * kw)
        return RetrievedChunk(
            chunk_id=payload.get('chunk_id', ''),
            call_id=payload.get('call_id', ''),
            agent_id=payload.get('agent_id', ''),
            speaker=payload.get('speaker', 'unknown'),
            text=text,
            score=score,
            metadata=payload,
        )

    def retrieve(self, tenant_id: str, query: str, tags: list[str], top_k: int | None, use_keyword: bool) -> list[RetrievedChunk]:
        """Return the best-scoring chunks for *query*, using the cache when possible.

        Args:
            tenant_id: Tenant passed to the cache key and the vector search.
            query: Query text to embed, keyword-match and rerank against.
            tags: Tags passed to the cache key and the vector search.
            top_k: Maximum number of results; a falsy value (``None`` or 0)
                selects ``default_top_k``.
            use_keyword: When true, blend the vector score with the keyword
                score using ``keyword_weight``.

        Returns:
            Up to ``top_k`` chunks sorted by descending blended score.
        """
        top_k = top_k or self.default_top_k
        cache_key = self.cache.retrieval_key(tenant_id, query, tags, top_k, use_keyword)
        cached = self.cache.get_json(cache_key)
        # Truthiness check: an empty cached list is treated like a miss and
        # the retrieval is recomputed.
        if cached:
            return [RetrievedChunk(**c) for c in cached]

        query_vec = self.emb.embed_text(query)
        hits = self.qdrant.search(query_vec, tenant_id=tenant_id, tags=tags, limit=max(self.candidate_k, top_k))
        candidates = [self._to_chunk(h, query, use_keyword) for h in hits]

        if candidates:
            rerank_scores = self.reranker.rerank(query, [c.text for c in candidates])
            # zip() pairs scores with candidates without indexing, so a
            # length mismatch cannot raise IndexError. float() keeps the
            # blended score a plain Python float even if the reranker returns
            # another numeric type, so it stays serialisable for the cache.
            for chunk, rerank_score in zip(candidates, rerank_scores):
                chunk.score = (
                    self.FIRST_STAGE_WEIGHT * chunk.score
                    + self.RERANK_WEIGHT * float(rerank_score)
                )

        # list.sort is stable, so ties keep the search order.
        candidates.sort(key=lambda c: c.score, reverse=True)
        result = candidates[:top_k]
        self.cache.set_json(cache_key, [vars(r) for r in result], self.retrieval_cache_ttl)
        return result
