from django.db import connections
from django.core.files.storage import FileSystemStorage
from django.utils import timezone
from rest_framework import status, viewsets
from rest_framework.response import Response

from .models import Chunk, Document, KnowledgeBase
from .serializers import ChunkSerializer, DocumentSerializer, KnowledgeBaseSerializer
from apps.agents.models import AgentConfig
from apps.cluster.dynamic_tables import _q_ident, ensure_business_cluster_tables


def _cluster_kb_insert_row(
    table: str,
    business_id_int: int,
    bot_id_val,
    doc_id,
    doc_name: str,
    doc_url: str,
) -> int:
    """
    Insert one document row into the cluster table `table` and return its id.

    The table is probed first with a zero-row SELECT so that only columns it
    actually has are written. `created_at` / `updated_at` are filled with the
    current time when those columns exist, rather than relying on a column
    DEFAULT.

    Returns the cursor's `lastrowid` after the INSERT.
    Raises ValueError when none of the candidate columns exist in the table.
    """
    cluster = connections["cluster"]
    quoted_table = _q_ident(table)

    # LIMIT 0 fetches no rows; it is issued only to read the column names.
    with cluster.cursor() as probe:
        probe.execute(f"SELECT * FROM {quoted_table} LIMIT 0")
        existing = {desc[0] for desc in probe.description}

    stamp = timezone.now()
    candidates: dict = {
        "doc_id": str(doc_id or ""),
        "doc_name": doc_name,
        "doc_url": str(doc_url or ""),
        "bot_id": bot_id_val,
        "business_id": business_id_int,
    }
    # Both timestamps share one value so a fresh row has created_at == updated_at.
    for ts_col in ("created_at", "updated_at"):
        if ts_col in existing:
            candidates[ts_col] = stamp

    # `id` is never written explicitly; the returned value comes from lastrowid.
    targets = [col for col in candidates if col != "id" and col in existing]
    if not targets:
        raise ValueError("knowledgebase table has no insertable columns")

    column_list = ", ".join(map(_q_ident, targets))
    value_slots = ", ".join("%s" for _ in targets)
    statement = f"INSERT INTO {quoted_table} ({column_list}) VALUES ({value_slots})"
    values = [candidates[col] for col in targets]

    with cluster.cursor() as writer:
        writer.execute(statement, values)
        return writer.lastrowid


def _normalize_cluster_kb_row(row: dict) -> dict:
    """Shape cluster `{bid}_knowledgebase` rows like `Document` API responses."""
    out = dict(row)
    out["name"] = out.get("doc_name") or ""
    out["file_path"] = out.get("doc_url") or ""
    if out.get("doc_url") and str(out["doc_url"]).startswith(("http://", "https://")):
        out["web_url"] = out["doc_url"]
    return out


class KnowledgeBaseViewSet(viewsets.ModelViewSet):
    # Stock ModelViewSet over `KnowledgeBase`; no actions are overridden here.
    serializer_class = KnowledgeBaseSerializer
    # Most recently updated knowledge bases come first.
    queryset = KnowledgeBase.objects.order_by("-updated_at")


class DocumentViewSet(viewsets.ModelViewSet):
    # CRUD endpoints for knowledge documents.
    #
    # Requests that carry `business_id` are served from the per-business cluster
    # table `{business_id}_knowledgebase` with raw SQL on the "cluster"
    # connection. All other requests fall through to the stock ModelViewSet
    # behaviour on the master `Document` model.
    queryset = Document.objects.select_related("knowledge_base").all().order_by("-updated_at")
    serializer_class = DocumentSerializer

    @staticmethod
    def _invalid_field_response(field):
        """Return a 400 response that flags `field` as invalid (per-field error list)."""
        return Response({field: [f"Invalid {field}."]}, status=status.HTTP_400_BAD_REQUEST)

    @staticmethod
    def _cluster_error_response(message, exc):
        """Return the 500 payload shared by all cluster-side failures."""
        return Response(
            {"message": message, "error": str(exc)},
            status=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )

    @staticmethod
    def _store_upload(uploaded, bucket, doc_id):
        """
        Save `uploaded` under `kb_uploads/{bucket}/` and return the storage URL.

        The stored name is prefixed with `doc_id`, or with the current Unix
        timestamp when no `doc_id` was supplied.
        """
        # NOTE(review): doc_id is client-supplied and goes into the path as-is;
        # presumably the storage layer rejects paths escaping its root, but
        # nested sub-directories inside it look possible — confirm.
        prefix = str(doc_id or int(timezone.now().timestamp()))
        storage = FileSystemStorage()
        saved_name = storage.save(f"kb_uploads/{bucket}/{prefix}_{uploaded.name}", uploaded)
        return storage.url(saved_name)

    def list(self, request, *args, **kwargs):
        """
        With `business_id` + `bot_id`, list rows from cluster `{business_id}_knowledgebase`.
        Otherwise list master `Document` rows.

        Non-integer `business_id` / `bot_id` yields 400; a cluster failure yields 500.
        """
        bid_raw = request.query_params.get("business_id")
        bot_raw = request.query_params.get("bot_id")
        if bid_raw in (None, "") or bot_raw in (None, ""):
            return super().list(request, *args, **kwargs)

        # Malformed identifiers are a client error, not a server failure.
        try:
            bid = int(bid_raw)
        except (TypeError, ValueError):
            return self._invalid_field_response("business_id")
        try:
            bot_id = int(bot_raw)
        except (TypeError, ValueError):
            return self._invalid_field_response("bot_id")

        try:
            ensure_business_cluster_tables(bid)
            table = f"{bid}_knowledgebase"
            with connections["cluster"].cursor() as cur:
                cur.execute(
                    f"SELECT * FROM {_q_ident(table)} WHERE bot_id = %s ORDER BY id DESC",
                    [bot_id],
                )
                cols = [c[0] for c in cur.description]
                rows = [_normalize_cluster_kb_row(dict(zip(cols, r))) for r in cur.fetchall()]
        except Exception as e:
            return self._cluster_error_response(
                "Failed to load knowledge documents from cluster.", e
            )
        return Response(rows)

    def retrieve(self, request, *args, **kwargs):
        """
        With `business_id`, fetch one row by id from cluster `{business_id}_knowledgebase`.
        Otherwise retrieve the master `Document`.

        Returns 404 when the cluster row does not exist, 400 for a non-integer
        `business_id`, and 500 on a cluster failure.
        """
        bid_raw = request.query_params.get("business_id")
        pk = kwargs.get("pk")
        if bid_raw in (None, "") or pk is None:
            return super().retrieve(request, *args, **kwargs)

        try:
            bid = int(bid_raw)
        except (TypeError, ValueError):
            return self._invalid_field_response("business_id")

        try:
            ensure_business_cluster_tables(bid)
            table = f"{bid}_knowledgebase"
            with connections["cluster"].cursor() as cur:
                cur.execute(f"SELECT * FROM {_q_ident(table)} WHERE id = %s", [pk])
                row = cur.fetchone()
                cols = [c[0] for c in cur.description]
        except Exception as e:
            return self._cluster_error_response("Failed to load document from cluster.", e)

        if not row:
            return Response({"detail": "Not found."}, status=status.HTTP_404_NOT_FOUND)
        return Response(_normalize_cluster_kb_row(dict(zip(cols, row))))

    def destroy(self, request, *args, **kwargs):
        """
        With `business_id` (query string or body), delete one row by id from cluster
        `{business_id}_knowledgebase`. Otherwise delete the master `Document`.

        Returns 204 on success, 404 when no row matched, 400 for a non-integer
        `business_id`, and 500 on a cluster failure.
        """
        bid_raw = request.query_params.get("business_id") or request.data.get("business_id")
        pk = kwargs.get("pk")
        if bid_raw in (None, "") or pk is None:
            return super().destroy(request, *args, **kwargs)

        try:
            bid = int(bid_raw)
        except (TypeError, ValueError):
            return self._invalid_field_response("business_id")

        try:
            ensure_business_cluster_tables(bid)
            table = f"{bid}_knowledgebase"
            with connections["cluster"].cursor() as cur:
                cur.execute(f"DELETE FROM {_q_ident(table)} WHERE id = %s", [pk])
                deleted = cur.rowcount
        except Exception as e:
            return self._cluster_error_response("Failed to delete document from cluster.", e)

        # Only report 404 when the driver positively says nothing was deleted;
        # this matches `retrieve` and the master-DB destroy for a missing id.
        if deleted == 0:
            return Response({"detail": "Not found."}, status=status.HTTP_404_NOT_FOUND)
        return Response(status=status.HTTP_204_NO_CONTENT)

    def create(self, request, *args, **kwargs):
        """
        With `business_id`: store metadata + file path only in cluster `{business_id}_knowledgebase`
        (livekitvoicebot_cluster). Skips master `Document` / `KnowledgeBase` rows.

        Without `business_id`: legacy master-DB behavior.
        """
        business_id_raw = request.data.get("business_id")
        if business_id_raw not in (None, ""):
            return self._create_in_cluster(request, business_id_raw)
        return self._create_in_master(request)

    def _create_in_cluster(self, request, business_id_raw):
        """Validate the ids, store any upload, insert the cluster row and return it with 201."""
        # The payload is only read on this path, so it is used directly:
        # copying form data would deep-copy the uploaded file objects as well.
        data = request.data

        try:
            business_id_int = int(business_id_raw)
        except (TypeError, ValueError):
            return self._invalid_field_response("business_id")

        # Validate bot_id before touching storage so that a bad value cannot
        # leave an orphaned upload behind.
        bot_id_raw = data.get("bot_id")
        if str(bot_id_raw or "").strip() == "":
            bot_id_val = None
        else:
            try:
                bot_id_val = int(bot_id_raw)
            except (TypeError, ValueError):
                return self._invalid_field_response("bot_id")

        doc_id = data.get("doc_id")
        doc_name = str(data.get("name") or "")
        uploaded = request.FILES.get("file")

        ensure_business_cluster_tables(business_id_int)
        table = f"{business_id_int}_knowledgebase"

        if uploaded is not None:
            doc_url = self._store_upload(uploaded, str(business_id_int), doc_id)
        else:
            # No file: fall back to whichever location field the client sent.
            doc_url = str(
                data.get("file_path") or data.get("web_url") or data.get("doc_url") or ""
            )

        new_id = _cluster_kb_insert_row(
            table,
            business_id_int,
            bot_id_val,
            doc_id,
            doc_name,
            doc_url,
        )
        with connections["cluster"].cursor() as cur:
            cur.execute(f"SELECT * FROM {_q_ident(table)} WHERE id = %s", [new_id])
            fetched = cur.fetchone()
            cols = [c[0] for c in cur.description]

        if fetched is not None:
            record = dict(zip(cols, fetched))
        else:
            # The insert succeeded but the row could not be read back by id
            # (e.g. no usable lastrowid); answer with the values just written
            # instead of failing on zip(cols, None).
            record = {
                "id": new_id,
                "doc_id": str(doc_id or ""),
                "doc_name": doc_name,
                "doc_url": doc_url,
                "bot_id": bot_id_val,
                "business_id": business_id_int,
            }

        row = _normalize_cluster_kb_row(record)
        row["source_type"] = data.get("source_type") or ("file" if uploaded else "web_url")
        row["knowledge_base"] = None
        headers = self.get_success_headers({"id": row.get("id")})
        return Response(row, status=status.HTTP_201_CREATED, headers=headers)

    def _create_in_master(self, request):
        """Legacy path: create a master `Document`, resolving or creating its `KnowledgeBase`."""
        # The payload is mutated below, hence the copy.
        # NOTE(review): for form/multipart requests this is presumably a deep
        # copy that includes the uploaded file; confirm large uploads survive it.
        data = request.data.copy()
        kb_val = data.get("knowledge_base")
        bot_id_raw = data.get("bot_id")
        doc_id = data.get("doc_id")

        if kb_val in (None, "", 0, "0"):
            kb_obj = None
            if bot_id_raw not in (None, "", 0, "0"):
                # An unknown or malformed bot id just means "no existing KB";
                # other errors (e.g. database failures) are allowed to surface.
                try:
                    bot = AgentConfig.objects.select_related("knowledge_base").get(
                        pk=int(bot_id_raw)
                    )
                    kb_obj = bot.knowledge_base
                except (AgentConfig.DoesNotExist, TypeError, ValueError):
                    kb_obj = None
            if kb_obj is None:
                kb_obj = KnowledgeBase.objects.create(name=f"KB {timezone.now().isoformat()}")
            data["knowledge_base"] = kb_obj.pk

        uploaded = request.FILES.get("file")
        if uploaded is not None:
            # This path is only reached without a business_id, so uploads
            # always land in the "unknown" bucket.
            data["file_path"] = self._store_upload(uploaded, "unknown", doc_id)
            data["mime_type"] = data.get("mime_type") or getattr(uploaded, "content_type", "") or ""
            data["file_size"] = data.get("file_size") or getattr(uploaded, "size", None)

        serializer = self.get_serializer(data=data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        headers = self.get_success_headers(serializer.data)
        return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)


class ChunkViewSet(viewsets.ModelViewSet):
    # Stock ModelViewSet over `Chunk`; no actions are overridden here.
    serializer_class = ChunkSerializer
    # Join the parent knowledge base and document in the same query, and
    # return the highest ids first.
    queryset = Chunk.objects.select_related("knowledge_base", "document").order_by("-id")
