from typing import Any

from qdrant_client import QdrantClient
from qdrant_client.http import models as qm


class QdrantRepository:
    """Wrapper around a ``QdrantClient`` bound to a single collection.

    Offers collection creation (if missing), point upserts, and a vector
    search that is always filtered by ``tenant_id`` and optionally by tags.
    """

    def __init__(self, host: str, port: int, collection_name: str, https: bool = False, api_key: str | None = None):
        """Create the underlying client and remember the target collection.

        Args:
            host: Qdrant server hostname.
            port: Qdrant server port.
            collection_name: Collection that every operation on this repository targets.
            https: Passed through to ``QdrantClient``; whether to connect over HTTPS.
            api_key: Optional API key passed through to ``QdrantClient``.
        """
        self.collection_name = collection_name
        self.client = QdrantClient(host=host, port=port, https=https, api_key=api_key)

    def _collection_exists(self) -> bool:
        """Return True if ``self.collection_name`` is listed by the server."""
        names = {c.name for c in self.client.get_collections().collections}
        return self.collection_name in names

    def ensure_collection(self, vector_size: int) -> None:
        """Create the collection (cosine distance) unless it already exists.

        An existing collection is left untouched; its vector configuration is
        not compared against ``vector_size``.

        Args:
            vector_size: Dimensionality of the vectors the collection stores.

        Raises:
            Exception: Whatever ``create_collection`` raises, unless the
                collection turns out to exist afterwards (e.g. it was created
                concurrently by another process).
        """
        if self._collection_exists():
            return

        try:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=qm.VectorParams(size=vector_size, distance=qm.Distance.COSINE),
            )
        except Exception:
            # The existence check above and the create call are not atomic:
            # another worker may have created the collection in between. If it
            # exists now, the goal of this method is met; otherwise the
            # failure is real and is propagated unchanged.
            if self._collection_exists():
                return
            raise

    def upsert_chunks(self, points: list[qm.PointStruct]) -> None:
        """Insert or overwrite ``points`` in the collection.

        An empty ``points`` list is a no-op and sends no request.
        """
        if not points:
            return
        self.client.upsert(collection_name=self.collection_name, points=points)

    def search(self, query_vector: list[float], tenant_id: str, tags: list[str] | None, limit: int) -> list[Any]:
        """Return up to ``limit`` nearest points belonging to ``tenant_id``.

        Args:
            query_vector: Vector to search with.
            tenant_id: Only points whose ``tenant_id`` payload equals this value match.
            tags: When non-empty, points must additionally match at least one of
                these values in their ``tags`` payload (``MatchAny``). An empty
                list or ``None`` disables the tag filter.
            limit: Maximum number of results.

        Returns:
            The client's search results, with payloads and without vectors.
        """
        filters = [qm.FieldCondition(key='tenant_id', match=qm.MatchValue(value=tenant_id))]
        if tags:
            filters.append(qm.FieldCondition(key='tags', match=qm.MatchAny(any=tags)))

        # NOTE(review): recent qdrant-client releases deprecate `search` in
        # favour of `query_points`; confirm against the pinned client version.
        return self.client.search(
            collection_name=self.collection_name,
            query_vector=query_vector,
            query_filter=qm.Filter(must=filters),
            limit=limit,
            with_payload=True,
            with_vectors=False,
        )
