import logging

from sentence_transformers import SentenceTransformer

from app.repositories.redis_cache import RedisCacheRepository

logger = logging.getLogger(__name__)


class EmbeddingService:
    """Produce normalized sentence embeddings, memoized through a cache repository.

    The SentenceTransformer model is loaded lazily on first use, so constructing
    the service does not load model weights.
    """

    def __init__(self, model_name: str, device: str, cache_repo: RedisCacheRepository, cache_ttl: int):
        """Configure the service without loading the model.

        Args:
            model_name: SentenceTransformer model identifier. It is also passed to
                ``cache_repo.embedding_key`` together with each text.
            device: Device string handed to ``SentenceTransformer(..., device=...)``.
            cache_repo: Cache used to look up and store embedding vectors.
            cache_ttl: Expiry passed to ``cache_repo.set_json`` for every stored
                vector (presumably seconds — confirm against RedisCacheRepository).
        """
        self.model_name = model_name
        self.device = device
        self.model = None  # populated on first use by _get_model()
        self.cache = cache_repo
        self.cache_ttl = cache_ttl

    def _get_model(self):
        """Return the SentenceTransformer, instantiating it on the first call."""
        if self.model is None:
            self.model = SentenceTransformer(self.model_name, device=self.device)
        return self.model

    def _cache_key(self, text: str):
        """Return the cache key for ``text`` under the configured model name."""
        return self.cache.embedding_key(self.model_name, text)

    def embed_text(self, text: str) -> list[float]:
        """Return the embedding of a single text, consulting the cache first.

        On a cache miss the text is encoded with ``normalize_embeddings=True`` and
        the resulting vector is stored with ``self.cache_ttl``.
        """
        key = self._cache_key(text)
        cached = self.cache.get_json(key)
        # Truthiness (rather than `is not None`) deliberately treats an empty or
        # otherwise falsy cached payload as a miss, so it gets regenerated.
        if cached:
            return cached

        vec = self._get_model().encode(text, normalize_embeddings=True).tolist()
        self.cache.set_json(key, vec, self.cache_ttl)
        return vec

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding per input text, in input order.

        Each distinct text is looked up in the cache once. All distinct misses
        are encoded in a single ``encode`` call with ``normalize_embeddings=True``,
        and each new vector is written to the cache once. Duplicate texts in
        ``texts`` therefore cost no extra cache round-trips or model work.
        """
        # dict.fromkeys de-duplicates while preserving first-occurrence order.
        unique_texts = list(dict.fromkeys(texts))
        keys = {text: self._cache_key(text) for text in unique_texts}

        found: dict[str, list[float]] = {}
        missing: list[str] = []
        for text in unique_texts:
            cached = self.cache.get_json(keys[text])
            # Same falsy-means-miss rule as embed_text.
            if cached:
                found[text] = cached
            else:
                missing.append(text)

        if missing:
            logger.debug(
                "Embedding cache: %d of %d unique texts missing", len(missing), len(unique_texts)
            )
            generated = self._get_model().encode(missing, normalize_embeddings=True)
            for text, vec in zip(missing, generated):
                values = vec.tolist()
                found[text] = values
                self.cache.set_json(keys[text], values, self.cache_ttl)

        # Copy per position so duplicate texts never share (alias) one list object.
        return [list(found[text]) for text in texts]
