import logging
from fastapi import APIRouter, Depends, HTTPException

from app.core.observability import INGESTED_CHUNKS, RETRIEVAL_QUERIES
from app.models.schemas import (
    AgentReportResponse,
    IngestCallRequest,
    IngestCallResponse,
    QueryRequest,
    QueryResponse,
    ScoreCallRequest,
    ScoreCallResponse,
    ContextChunk,
)
from app.services.ingestion_service import IngestionService
from app.services.retrieval_service import RetrievalService
from app.services.llm_service import LLMService
from app.services.scoring_service import ScoringService
from app.services.report_service import ReportService

# Module-level logger; used by score_call to record scoring failures.
logger = logging.getLogger(__name__)
# Router collecting this module's endpoints. NOTE(review): presumably included
# into the FastAPI application elsewhere — confirm.
router = APIRouter()


class ServiceContainer:
    """Bundle of the service objects the route handlers in this module use.

    Only type annotations are declared here and no ``__init__`` is defined, so
    each attribute must be assigned on an instance before the route handlers
    access it. The instance is registered via ``set_container`` and handed to
    handlers through the ``get_container`` dependency.
    """

    ingestion: IngestionService  # used by POST /ingest_call
    retrieval: RetrievalService  # used by POST /query to fetch context chunks
    llm: LLMService  # used by POST /query to generate the answer
    scoring: ScoringService  # used by POST /score_call
    report: ReportService  # used by GET /agent_report


# Module-level slot for the active container; stays None until set_container runs.
_container: ServiceContainer | None = None


def set_container(c: ServiceContainer):
    """Register ``c`` as the container that ``get_container`` will hand out.

    Replaces any previously registered container.
    """
    global _container
    _container = c


def get_container() -> ServiceContainer:
    """Return the registered container; used as a FastAPI dependency.

    Raises:
        RuntimeError: if ``set_container`` has not been called yet.
    """
    current = _container
    if current is not None:
        return current
    raise RuntimeError('Service container not initialized')


@router.post('/ingest_call', response_model=IngestCallResponse)
def ingest_call(payload: IngestCallRequest, c: ServiceContainer = Depends(get_container)):
    """Ingest one call transcript for a tenant.

    Each transcript turn is converted to a plain dict (``model_dump``) before
    being passed to the ingestion service. The chunk count that the service
    returns is added to the ``INGESTED_CHUNKS`` metric for the tenant and
    echoed back as ``chunks_ingested``.
    """
    tenant = payload.tenant_id
    turns = [turn.model_dump() for turn in payload.transcript]
    ingested = c.ingestion.ingest_call(
        tenant_id=tenant,
        call_id=payload.call_id,
        agent_id=payload.agent_id,
        timestamp=payload.timestamp,
        tags=payload.tags,
        transcript=turns,
    )
    INGESTED_CHUNKS.labels(tenant_id=tenant).inc(ingested)
    return IngestCallResponse(
        tenant_id=tenant,
        call_id=payload.call_id,
        chunks_ingested=ingested,
    )


@router.post('/query', response_model=QueryResponse)
def query(payload: QueryRequest, c: ServiceContainer = Depends(get_container)):
    """Answer a free-text query using retrieved chunks as LLM context.

    Steps:
      1. Retrieve chunks for the tenant (tag filter, ``top_k``, optional keyword mode).
      2. Count the query in the ``RETRIEVAL_QUERIES`` metric.
      3. Send the query plus the ``[speaker] text`` of every chunk to the LLM.
      4. Return the LLM answer together with the chunks used as context.
    """
    tenant = payload.tenant_id
    hits = c.retrieval.retrieve(
        tenant_id=tenant,
        query=payload.query,
        tags=payload.tags,
        top_k=payload.top_k,
        use_keyword=payload.use_keyword,
    )
    RETRIEVAL_QUERIES.labels(tenant_id=tenant).inc()

    # One "[speaker] text" entry per chunk, separated by blank lines.
    context_text = '\n\n'.join(f"[{hit.speaker}] {hit.text}" for hit in hits)
    prompt = f"Use this context to answer the user query.\n\nQuery: {payload.query}\n\nContext:\n{context_text}"
    answer = c.llm.answer(prompt)

    # Mirror each retrieved chunk into the response schema.
    context = []
    for hit in hits:
        context.append(
            ContextChunk(
                chunk_id=hit.chunk_id,
                call_id=hit.call_id,
                agent_id=hit.agent_id,
                speaker=hit.speaker,
                text=hit.text,
                score=hit.score,
                metadata=hit.metadata,
            )
        )

    return QueryResponse(answer=answer, context=context)


@router.post('/score_call', response_model=ScoreCallResponse)
def score_call(payload: ScoreCallRequest, c: ServiceContainer = Depends(get_container)):
    """Score a single call via the scoring service.

    Returns:
        ScoreCallResponse carrying the tenant/call ids and the scoring
        service's payload unchanged.

    Raises:
        HTTPException: an ``HTTPException`` raised by the scoring service is
            propagated unchanged. Any other failure is logged with its
            traceback and surfaced as a 500 with a generic detail; the
            original exception is chained as the cause.
    """
    try:
        score_payload = c.scoring.score_call(payload.tenant_id, payload.call_id, payload.agent_id)
    except HTTPException:
        # Deliberate HTTP errors from downstream keep their own status code
        # instead of being re-wrapped into a 500.
        raise
    except Exception as exc:
        logger.exception(
            'score_call failed tenant_id=%s call_id=%s agent_id=%s',
            payload.tenant_id,
            payload.call_id,
            payload.agent_id,
        )
        # Do not echo str(exc) to the client: internal error text can leak
        # implementation details. The full error is in the log above.
        raise HTTPException(status_code=500, detail='Failed to score call') from exc

    return ScoreCallResponse(tenant_id=payload.tenant_id, call_id=payload.call_id, score_payload=score_payload)


@router.get('/agent_report', response_model=AgentReportResponse)
def agent_report(tenant_id: str, agent_id: str, c: ServiceContainer = Depends(get_container)):
    """Return the report for one agent within a tenant.

    The report service's result is keyword-unpacked into
    ``AgentReportResponse``. NOTE(review): this assumes it returns a mapping
    whose keys match the response model's fields — confirm.
    """
    return AgentReportResponse(**c.report.agent_report(tenant_id, agent_id))
