from sentence_transformers import CrossEncoder


class RerankerService:
    """Score texts against a query using a lazily constructed CrossEncoder.

    Nothing is loaded in ``__init__``; the model object is created the first
    time it is needed and then kept on ``self.model`` for later calls.
    """

    def __init__(self, model_name: str, device: str):
        """Record the model configuration; the model itself is not built here.

        Args:
            model_name: Forwarded as the first argument to ``CrossEncoder``.
            device: Forwarded to ``CrossEncoder`` as its ``device`` keyword.
        """
        self.model = None  # filled in on first use by _get_model()
        self.device = device
        self.model_name = model_name

    def _get_model(self):
        """Return the cached model, constructing and caching it if absent."""
        if self.model is not None:
            return self.model
        self.model = CrossEncoder(self.model_name, device=self.device)
        return self.model

    def rerank(self, query: str, texts: list[str]) -> list[float]:
        """Return the model's scores for ``[query, text]`` pairs as floats.

        Args:
            query: The string placed first in every pair.
            texts: The strings placed second, one pair per entry.

        Returns:
            The values produced by the model's ``predict`` call, each
            converted with ``float``. An empty/falsy ``texts`` gives ``[]``
            and the model is never requested.
        """
        if not texts:
            return []
        # Pairs are built before the model is fetched, as two-element lists.
        query_text_pairs = [[query, candidate] for candidate in texts]
        encoder = self._get_model()
        raw_scores = encoder.predict(query_text_pairs)
        return list(map(float, raw_scores))
