import json
import logging
import math
import os
import re
import uuid
import hashlib
from collections import Counter
from datetime import datetime
from decimal import Decimal

import boto3
import pymysql
import requests
from pymysql.cursors import DictCursor
try:
    from qdrant_client import QdrantClient
    from qdrant_client.http import models as qm
except Exception:
    QdrantClient = None
    qm = None

logger = logging.getLogger(__name__)


class RAGHandler:
    """Multi-tenant RAG operations with conversation memory and lightweight profiling."""

    STOPWORDS = {
        "a", "an", "and", "are", "as", "at", "be", "but", "by", "for", "from", "has", "have",
        "he", "her", "his", "i", "if", "in", "is", "it", "its", "me", "my", "of", "on", "or", "our",
        "she", "that", "the", "their", "them", "they", "this", "to", "us", "was", "we", "were", "will",
        "with", "you", "your"
    }

    def __init__(self, config):
        self.config = config
        self.db_config = {
            "host": config.get("DB_HOST", "127.0.0.1"),
            "port": int(config.get("DB_PORT", 3306)),
            "user": config.get("DB_USER", "admin"),
            "password": config.get("DB_PASSWORD", ""),
            "database": config.get("DB_NAME", "voicebot_cluster"),
            "charset": "utf8mb4",
            "cursorclass": DictCursor,
            "autocommit": True,
        }

        self.rag_top_k = int(config.get("RAG_TOP_K", 8))
        self.rag_similarity_threshold = float(config.get("RAG_SIMILARITY_THRESHOLD", 0.2))
        self.rag_memory_messages = int(config.get("RAG_MEMORY_MESSAGES", 12))
        self.rag_use_qdrant = bool(config.get("RAG_USE_QDRANT", True))
        self.rag_qdrant_host = config.get("RAG_QDRANT_HOST", "127.0.0.1")
        self.rag_qdrant_port = int(config.get("RAG_QDRANT_PORT", 6333))
        self.rag_qdrant_collection_prefix = config.get("RAG_QDRANT_COLLECTION_PREFIX", "rag_chunks_")
        self.rag_qdrant_api_key = config.get("RAG_QDRANT_API_KEY", "") or None
        self.rag_qdrant_https = bool(config.get("RAG_QDRANT_HTTPS", False))
        self.rag_use_redis_cache = bool(config.get("RAG_USE_REDIS_CACHE", False))
        self.rag_embed_cache_ttl = int(config.get("RAG_EMBED_CACHE_TTL_SECONDS", 86400))
        self.rag_retrieval_cache_ttl = int(config.get("RAG_RETRIEVAL_CACHE_TTL_SECONDS", 300))
        self.rag_memory_ttl = int(config.get("RAG_MEMORY_TTL_SECONDS", 2592000))

        self.aws_region = os.getenv("AWS_REGION", config.get("AWS_REGION", "us-east-1"))
        self.aws_access_key_id = os.getenv("AWS_ACCESS_KEY_ID", config.get("AWS_ACCESS_KEY_ID", ""))
        self.aws_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY", config.get("AWS_SECRET_ACCESS_KEY", ""))
        self.rag_chat_model = os.getenv("RAG_CHAT_MODEL", config.get("RAG_CHAT_MODEL", config.get("AWS_NOVA_MODEL", "amazon.nova-lite-v1:0")))
        self.rag_embedding_model = os.getenv("RAG_EMBEDDING_MODEL", config.get("RAG_EMBEDDING_MODEL", "amazon.titan-embed-text-v2:0"))
        self.rag_chat_provider = os.getenv("RAG_CHAT_PROVIDER", config.get("RAG_CHAT_PROVIDER", "bedrock"))
        self.ollama_base_url = os.getenv("OLLAMA_BASE_URL", config.get("OLLAMA_BASE_URL", "http://127.0.0.1:11434")).rstrip("/")
        self.ollama_timeout_seconds = int(os.getenv("OLLAMA_TIMEOUT_SECONDS", str(config.get("OLLAMA_TIMEOUT_SECONDS", 180))))

        self.bedrock_runtime = None
        if self.aws_access_key_id and self.aws_secret_access_key:
            try:
                self.bedrock_runtime = boto3.client(
                    service_name="bedrock-runtime",
                    region_name=self.aws_region,
                    aws_access_key_id=self.aws_access_key_id,
                    aws_secret_access_key=self.aws_secret_access_key,
                )
            except Exception as exc:
                logger.warning("Failed to initialize Bedrock runtime: %s", exc)

        # Redis caching is intentionally disabled; MySQL remains source-of-truth for state.
        self.redis_client = None

        self.qdrant_client = None
        if self.rag_use_qdrant and QdrantClient is not None:
            try:
                self.qdrant_client = QdrantClient(
                    host=self.rag_qdrant_host,
                    port=self.rag_qdrant_port,
                    https=self.rag_qdrant_https,
                    api_key=self.rag_qdrant_api_key,
                    check_compatibility=False,
                )
            except Exception as exc:
                logger.warning("Qdrant unavailable, continuing with MySQL vector fallback: %s", exc)
                self.qdrant_client = None

        try:
            self.ensure_tables()
        except Exception as exc:
            logger.warning("Could not ensure RAG tables at startup: %s", exc)

    def get_connection(self):
        """Open and return a new PyMySQL connection built from ``self.db_config``."""
        settings = self.db_config
        return pymysql.connect(**settings)

    @staticmethod
    def _safe_identifier(name):
        if not re.match(r"^[A-Za-z0-9_]+$", str(name)):
            raise ValueError(f"Unsafe identifier: {name}")
        return str(name)

    def _table_exists(self, cursor, table_name):
        cursor.execute("SHOW TABLES LIKE %s", (table_name,))
        return cursor.fetchone() is not None

    def _table_columns(self, cursor, table_name):
        cursor.execute(f"SHOW COLUMNS FROM `{self._safe_identifier(table_name)}`")
        return {row["Field"] for row in cursor.fetchall()}

    def ensure_tables(self):
        """Create the RAG tables if they do not exist yet.

        Issues ``CREATE TABLE IF NOT EXISTS`` for ``rag_documents``,
        ``rag_chunks``, ``rag_conversations``, ``rag_messages`` and
        ``rag_user_profiles``.  Existing tables are left as they are; no column
        migrations are performed.  Any database error propagates to the caller
        (``__init__`` catches and logs it); the connection is always closed.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            # One row per (bid, source_id) source document; is_active is a soft flag.
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS rag_documents (
                    id BIGINT AUTO_INCREMENT PRIMARY KEY,
                    bid VARCHAR(50) NOT NULL,
                    source_id VARCHAR(255) NOT NULL,
                    title VARCHAR(500),
                    source_type VARCHAR(100),
                    source_uri TEXT,
                    metadata JSON,
                    is_active BOOLEAN DEFAULT TRUE,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                    UNIQUE KEY uniq_bid_source (bid, source_id),
                    INDEX idx_bid_active (bid, is_active),
                    INDEX idx_created_at (created_at)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
                """
            )
            # Chunk text plus its embedding stored as JSON, with dimension and norm
            # kept alongside for the MySQL-side vector fallback.
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS rag_chunks (
                    id BIGINT AUTO_INCREMENT PRIMARY KEY,
                    bid VARCHAR(50) NOT NULL,
                    document_id BIGINT NOT NULL,
                    chunk_id VARCHAR(255) NOT NULL,
                    chunk_index INT DEFAULT 0,
                    content LONGTEXT NOT NULL,
                    token_count INT,
                    metadata JSON,
                    embedding JSON NOT NULL,
                    embedding_dim INT NOT NULL,
                    embedding_norm DOUBLE NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                    UNIQUE KEY uniq_bid_doc_chunk (bid, document_id, chunk_id),
                    INDEX idx_bid_doc (bid, document_id),
                    INDEX idx_bid_created (bid, created_at)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
                """
            )
            # Conversation header rows, unique per (bid, conversation_id).
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS rag_conversations (
                    id BIGINT AUTO_INCREMENT PRIMARY KEY,
                    bid VARCHAR(50) NOT NULL,
                    conversation_id VARCHAR(100) NOT NULL,
                    user_id VARCHAR(100) NOT NULL,
                    metadata JSON,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                    UNIQUE KEY uniq_bid_conversation (bid, conversation_id),
                    INDEX idx_bid_user (bid, user_id),
                    INDEX idx_updated_at (updated_at)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
                """
            )
            # Append-only message log; no foreign key to rag_conversations is declared.
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS rag_messages (
                    id BIGINT AUTO_INCREMENT PRIMARY KEY,
                    bid VARCHAR(50) NOT NULL,
                    conversation_id VARCHAR(100) NOT NULL,
                    user_id VARCHAR(100) NOT NULL,
                    role ENUM('system','user','assistant') NOT NULL,
                    content LONGTEXT NOT NULL,
                    metadata JSON,
                    retrieved_chunk_ids JSON,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    INDEX idx_bid_conversation (bid, conversation_id),
                    INDEX idx_bid_user_created (bid, user_id, created_at)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
                """
            )
            # One JSON profile per (bid, user_id), versioned by update_user_profile.
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS rag_user_profiles (
                    id BIGINT AUTO_INCREMENT PRIMARY KEY,
                    bid VARCHAR(50) NOT NULL,
                    user_id VARCHAR(100) NOT NULL,
                    profile JSON NOT NULL,
                    profile_version INT DEFAULT 1,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                    UNIQUE KEY uniq_bid_user_profile (bid, user_id),
                    INDEX idx_bid_updated (bid, updated_at)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
                """
            )
            # Redundant with autocommit=True in db_config, but harmless.
            conn.commit()
        finally:
            conn.close()

    @staticmethod
    def _parse_json(value, fallback):
        if value is None:
            return fallback
        if isinstance(value, (dict, list)):
            return value
        if isinstance(value, str):
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                return fallback
        return fallback

    @staticmethod
    def _ensure_embedding_list(embedding):
        if embedding is None:
            return None
        if isinstance(embedding, str):
            try:
                embedding = json.loads(embedding)
            except json.JSONDecodeError:
                return None
        if not isinstance(embedding, list):
            return None

        vector = []
        for value in embedding:
            try:
                vector.append(float(value))
            except (TypeError, ValueError):
                return None
        return vector

    @staticmethod
    def _vector_norm(vector):
        return math.sqrt(sum(value * value for value in vector))

    @staticmethod
    def _dot_product(a, b):
        return sum(x * y for x, y in zip(a, b))

    def _json_safe(self, value):
        if isinstance(value, Decimal):
            return float(value)
        if isinstance(value, datetime):
            return value.isoformat()
        if isinstance(value, dict):
            return {str(k): self._json_safe(v) for k, v in value.items()}
        if isinstance(value, list):
            return [self._json_safe(v) for v in value]
        return value

    def _cache_get(self, key):
        if not self.redis_client:
            return None
        try:
            raw = self.redis_client.get(key)
            return json.loads(raw) if raw else None
        except Exception:
            return None

    def _cache_set(self, key, value, ttl_seconds):
        if not self.redis_client:
            return
        try:
            self.redis_client.setex(key, int(ttl_seconds), json.dumps(self._json_safe(value), ensure_ascii=True))
        except Exception:
            pass

    def _embed_cache_key(self, text):
        digest = hashlib.sha256(f"{self.rag_embedding_model}|{text}".encode("utf-8")).hexdigest()
        return f"rag:emb:{digest}"

    def _retrieval_cache_key(self, bid, vector, top_k, min_similarity):
        vector_sig = hashlib.sha256(",".join([f"{v:.6f}" for v in vector[:64]]).encode("utf-8")).hexdigest()
        return f"rag:ret:{bid}:{top_k}:{min_similarity}:{vector_sig}"

    def _conversation_cache_key(self, bid, conversation_id):
        return f"rag:conv:{bid}:{conversation_id}"

    def _conversation_resume_key(self, bid, user_id, metadata):
        scope = "global"
        call_id = "all"
        if isinstance(metadata, dict):
            scope = str(metadata.get("context_scope") or metadata.get("scope") or "global")
            context = metadata.get("context")
            if isinstance(context, dict):
                call_id = str(context.get("call_id") or "all")
        return f"rag:resume:{bid}:{user_id}:{scope}:{call_id}"

    def _resolve_conversation_id(self, bid, user_id, metadata, explicit_id=None):
        if explicit_id:
            return str(explicit_id)
        key = self._conversation_resume_key(bid, user_id, metadata)
        cached = self._cache_get(key)
        if isinstance(cached, str) and cached.strip():
            return cached.strip()
        return None

    def _save_conversation_resume(self, bid, user_id, metadata, conversation_id):
        if not conversation_id:
            return
        key = self._conversation_resume_key(bid, user_id, metadata)
        self._cache_set(key, str(conversation_id), self.rag_memory_ttl)

    def _push_cached_message(self, bid, conversation_id, message_obj):
        if not self.redis_client:
            return
        key = self._conversation_cache_key(bid, conversation_id)
        try:
            self.redis_client.rpush(key, json.dumps(self._json_safe(message_obj), ensure_ascii=True))
            self.redis_client.ltrim(key, -400, -1)
            self.redis_client.expire(key, self.rag_memory_ttl)
        except Exception:
            pass

    def _qdrant_collection_name(self, bid):
        return f"{self.rag_qdrant_collection_prefix}{self._safe_identifier(str(bid))}"

    def _qdrant_point_id(self, bid, source_id, chunk_id):
        digest = hashlib.sha256(f"{bid}|{source_id}|{chunk_id}".encode("utf-8")).hexdigest()
        return digest

    def _ensure_qdrant_collection(self, bid, vector_size):
        if not self.qdrant_client or qm is None:
            return False
        collection_name = self._qdrant_collection_name(bid)
        try:
            existing = [c.name for c in self.qdrant_client.get_collections().collections]
            if collection_name not in existing:
                self.qdrant_client.create_collection(
                    collection_name=collection_name,
                    vectors_config=qm.VectorParams(size=int(vector_size), distance=qm.Distance.COSINE),
                )
            return True
        except Exception as exc:
            logger.warning("Failed to ensure Qdrant collection %s: %s", collection_name, exc)
            return False

    def _upsert_document(self, cursor, bid, document):
        source_id = str(document.get("source_id") or document.get("id") or uuid.uuid4())
        title = document.get("title")
        source_type = document.get("source_type")
        source_uri = document.get("source_uri")
        metadata = json.dumps(self._json_safe(document.get("metadata") or {}), ensure_ascii=True)

        cursor.execute(
            """
            INSERT INTO rag_documents (bid, source_id, title, source_type, source_uri, metadata, is_active)
            VALUES (%s, %s, %s, %s, %s, %s, TRUE)
            ON DUPLICATE KEY UPDATE
                title = VALUES(title),
                source_type = VALUES(source_type),
                source_uri = VALUES(source_uri),
                metadata = VALUES(metadata),
                is_active = TRUE,
                updated_at = CURRENT_TIMESTAMP
            """,
            (str(bid), source_id, title, source_type, source_uri, metadata),
        )

        cursor.execute(
            "SELECT id FROM rag_documents WHERE bid = %s AND source_id = %s",
            (str(bid), source_id),
        )
        row = cursor.fetchone()
        return row["id"], source_id

    def ingest_documents(self, bid, documents):
        if not documents:
            return {"documents": 0, "chunks": 0, "skipped": 0}

        conn = self.get_connection()
        inserted_docs = 0
        inserted_chunks = 0
        skipped_chunks = 0
        qdrant_points = []
        qdrant_ready = False

        try:
            cursor = conn.cursor()
            for document in documents:
                chunks = document.get("chunks") or []
                if not chunks:
                    continue

                document_id, source_id = self._upsert_document(cursor, bid, document)
                inserted_docs += 1

                for idx, chunk in enumerate(chunks):
                    chunk_text = (chunk.get("content") or chunk.get("text") or "").strip()
                    if not chunk_text:
                        skipped_chunks += 1
                        continue

                    embedding = self._ensure_embedding_list(chunk.get("embedding"))
                    if not embedding:
                        skipped_chunks += 1
                        continue

                    norm = self._vector_norm(embedding)
                    if norm == 0:
                        skipped_chunks += 1
                        continue

                    chunk_id = str(chunk.get("chunk_id") or f"{source_id}-chunk-{idx}")
                    chunk_index = int(chunk.get("chunk_index", idx))
                    token_count = chunk.get("token_count")
                    metadata = json.dumps(self._json_safe(chunk.get("metadata") or {}), ensure_ascii=True)

                    cursor.execute(
                        """
                        INSERT INTO rag_chunks (
                            bid, document_id, chunk_id, chunk_index, content, token_count,
                            metadata, embedding, embedding_dim, embedding_norm
                        )
                        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                        ON DUPLICATE KEY UPDATE
                            chunk_index = VALUES(chunk_index),
                            content = VALUES(content),
                            token_count = VALUES(token_count),
                            metadata = VALUES(metadata),
                            embedding = VALUES(embedding),
                            embedding_dim = VALUES(embedding_dim),
                            embedding_norm = VALUES(embedding_norm),
                            updated_at = CURRENT_TIMESTAMP
                        """,
                        (
                            str(bid),
                            document_id,
                            chunk_id,
                            chunk_index,
                            chunk_text,
                            token_count,
                            metadata,
                            json.dumps(embedding, ensure_ascii=True),
                            len(embedding),
                            norm,
                        ),
                    )
                    inserted_chunks += 1

                    if self.qdrant_client and qm is not None:
                        if not qdrant_ready:
                            qdrant_ready = self._ensure_qdrant_collection(bid, len(embedding))
                        if qdrant_ready:
                            try:
                                payload = {
                                    "bid": str(bid),
                                    "document_id": int(document_id),
                                    "source_id": str(source_id),
                                    "chunk_id": str(chunk_id),
                                    "chunk_index": int(chunk_index),
                                    "content": chunk_text,
                                    "metadata": self._json_safe(chunk.get("metadata") or {}),
                                    "title": document.get("title"),
                                    "source_type": document.get("source_type"),
                                    "source_uri": document.get("source_uri"),
                                }
                                qdrant_points.append(
                                    qm.PointStruct(
                                        id=self._qdrant_point_id(bid, source_id, chunk_id),
                                        vector=embedding,
                                        payload=payload,
                                    )
                                )
                            except Exception as exc:
                                logger.warning("Failed to stage Qdrant point for %s/%s: %s", source_id, chunk_id, exc)

            conn.commit()
            if qdrant_points and self.qdrant_client:
                try:
                    self.qdrant_client.upsert(
                        collection_name=self._qdrant_collection_name(bid),
                        points=qdrant_points,
                    )
                except Exception as exc:
                    logger.warning("Qdrant upsert failed for bid %s; MySQL ingestion still succeeded: %s", bid, exc)
            return {
                "documents": inserted_docs,
                "chunks": inserted_chunks,
                "skipped": skipped_chunks,
            }
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def _get_or_create_conversation(self, cursor, bid, user_id, conversation_id=None, metadata=None):
        conv_id = str(conversation_id or uuid.uuid4())

        cursor.execute(
            """
            INSERT INTO rag_conversations (bid, conversation_id, user_id, metadata)
            VALUES (%s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                metadata = COALESCE(VALUES(metadata), metadata),
                updated_at = CURRENT_TIMESTAMP
            """,
            (str(bid), conv_id, str(user_id), json.dumps(metadata or {}, ensure_ascii=True)),
        )
        return conv_id

    def save_message(self, bid, conversation_id, user_id, role, content, metadata=None, retrieved_chunk_ids=None):
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            message_obj = {
                "role": role,
                "content": content,
                "metadata": metadata or {},
                "retrieved_chunk_ids": retrieved_chunk_ids or [],
                "created_at": datetime.utcnow().isoformat() + "Z",
            }
            cursor.execute(
                """
                INSERT INTO rag_messages (
                    bid, conversation_id, user_id, role, content, metadata, retrieved_chunk_ids
                )
                VALUES (%s, %s, %s, %s, %s, %s, %s)
                """,
                (
                    str(bid),
                    str(conversation_id),
                    str(user_id),
                    role,
                    content,
                    json.dumps(metadata or {}, ensure_ascii=True),
                    json.dumps(retrieved_chunk_ids or [], ensure_ascii=True),
                ),
            )
            cursor.execute(
                """
                UPDATE rag_conversations
                SET updated_at = CURRENT_TIMESTAMP
                WHERE bid = %s AND conversation_id = %s
                """,
                (str(bid), str(conversation_id)),
            )
            conn.commit()
            self._push_cached_message(bid, conversation_id, message_obj)
        finally:
            conn.close()

    def get_conversation_messages(self, bid, conversation_id, limit=200):
        if self.redis_client:
            key = self._conversation_cache_key(bid, conversation_id)
            try:
                cached = self.redis_client.lrange(key, -int(limit), -1)
                if cached:
                    rows = []
                    for item in cached:
                        try:
                            row = json.loads(item)
                            rows.append({
                                "role": row.get("role"),
                                "content": row.get("content"),
                                "metadata": row.get("metadata") or {},
                                "retrieved_chunk_ids": row.get("retrieved_chunk_ids") or [],
                                "created_at": row.get("created_at"),
                            })
                        except Exception:
                            continue
                    if rows:
                        return rows
            except Exception:
                pass

        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT role, content, metadata, retrieved_chunk_ids, created_at
                FROM rag_messages
                WHERE bid = %s AND conversation_id = %s
                ORDER BY id DESC
                LIMIT %s
                """,
                (str(bid), str(conversation_id), int(limit)),
            )
            rows = list(cursor.fetchall())
            rows.reverse()
            for row in rows:
                row["metadata"] = self._parse_json(row.get("metadata"), {})
                row["retrieved_chunk_ids"] = self._parse_json(row.get("retrieved_chunk_ids"), [])
                created_at = row.get("created_at")
                if created_at and hasattr(created_at, "isoformat"):
                    row["created_at"] = created_at.isoformat()
                self._push_cached_message(bid, conversation_id, row)
            return rows
        finally:
            conn.close()

    def get_user_profile(self, bid, user_id):
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT profile, profile_version, updated_at
                FROM rag_user_profiles
                WHERE bid = %s AND user_id = %s
                """,
                (str(bid), str(user_id)),
            )
            row = cursor.fetchone()
            if not row:
                return {
                    "user_id": str(user_id),
                    "profile": {
                        "interests": {},
                        "traits": {},
                        "last_topics": [],
                    },
                    "profile_version": 1,
                }

            profile = self._parse_json(row.get("profile"), {})
            return {
                "user_id": str(user_id),
                "profile": profile,
                "profile_version": row.get("profile_version", 1),
                "updated_at": row.get("updated_at").isoformat() if row.get("updated_at") else None,
            }
        finally:
            conn.close()

    def _extract_topics(self, text, max_terms=8):
        words = re.findall(r"[a-zA-Z0-9]{3,}", (text or "").lower())
        filtered = [w for w in words if w not in self.STOPWORDS]
        counts = Counter(filtered)
        return [term for term, _ in counts.most_common(max_terms)]

    def update_user_profile(self, bid, user_id, user_text, profile_updates=None):
        existing = self.get_user_profile(bid, user_id)
        profile = existing.get("profile") or {}
        interests = profile.get("interests") or {}
        traits = profile.get("traits") or {}
        last_topics = profile.get("last_topics") or []

        new_topics = self._extract_topics(user_text)
        for topic in new_topics:
            interests[topic] = int(interests.get(topic, 0)) + 1

        last_topics = (new_topics + last_topics)[:20]

        updates = profile_updates or {}
        for key, value in updates.get("traits", {}).items():
            traits[str(key)] = value

        profile["interests"] = interests
        profile["traits"] = traits
        profile["last_topics"] = last_topics

        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO rag_user_profiles (bid, user_id, profile, profile_version)
                VALUES (%s, %s, %s, %s)
                ON DUPLICATE KEY UPDATE
                    profile = VALUES(profile),
                    profile_version = profile_version + 1,
                    updated_at = CURRENT_TIMESTAMP
                """,
                (
                    str(bid),
                    str(user_id),
                    json.dumps(profile, ensure_ascii=True),
                    existing.get("profile_version", 1),
                ),
            )
            conn.commit()
        finally:
            conn.close()

        return profile

    def _chunk_transcript(self, transcript, speaker_segments, max_chars=900):
        transcript = (transcript or "").strip()
        chunks = []
        if not transcript:
            return chunks

        segments = self._parse_json(speaker_segments, [])
        if isinstance(segments, dict):
            segments = segments.get("segments", [])
        if not isinstance(segments, list):
            segments = []

        if segments:
            current_speaker = None
            current_text = []
            chunk_index = 0
            segment_start = None

            def flush():
                nonlocal current_text, chunk_index, segment_start, current_speaker
                text = " ".join([str(t).strip() for t in current_text if str(t).strip()]).strip()
                if text:
                    chunks.append(
                        {
                            "chunk_id": f"seg-{chunk_index}",
                            "chunk_index": chunk_index,
                            "text": text,
                            "metadata": {
                                "speaker": current_speaker or "unknown",
                                "segment_start": segment_start,
                                "segment_count": len(current_text),
                            },
                        }
                    )
                    chunk_index += 1
                current_text = []
                segment_start = None

            for segment in segments:
                if not isinstance(segment, dict):
                    continue
                speaker = str(segment.get("speaker") or segment.get("speaker_label") or "unknown")
                seg_text = str(segment.get("text") or segment.get("transcript") or "").strip()
                if not seg_text:
                    continue
                if segment_start is None:
                    segment_start = segment.get("start_time") or segment.get("start")

                projected = (" ".join(current_text + [seg_text])).strip()
                if current_speaker is None:
                    current_speaker = speaker
                if speaker != current_speaker or len(projected) > max_chars:
                    flush()
                    current_speaker = speaker
                    segment_start = segment.get("start_time") or segment.get("start")
                current_text.append(seg_text)
            flush()
            if chunks:
                return chunks

        # Fallback: paragraph-ish chunks from raw transcript lines
        words = transcript.split()
        chunk_words = []
        chunk_index = 0
        for word in words:
            chunk_words.append(word)
            if sum(len(w) + 1 for w in chunk_words) >= max_chars:
                chunks.append(
                    {
                        "chunk_id": f"txt-{chunk_index}",
                        "chunk_index": chunk_index,
                        "text": " ".join(chunk_words).strip(),
                        "metadata": {"speaker": "unknown"},
                    }
                )
                chunk_index += 1
                chunk_words = []
        if chunk_words:
            chunks.append(
                {
                    "chunk_id": f"txt-{chunk_index}",
                    "chunk_index": chunk_index,
                    "text": " ".join(chunk_words).strip(),
                    "metadata": {"speaker": "unknown"},
                }
            )
        return chunks

    def backfill_transcripts(
        self,
        bid,
        presales_only=True,
        limit=1000,
        overwrite_existing=False,
        source_type="call_transcript",
        callids=None,
    ):
        bid = self._safe_identifier(str(bid))
        calls_table = f"{bid}_calls"
        raw_calls_table = f"{bid}_raw_calls"
        sarvam_table = f"{bid}_sarvamresponse"
        analytics_table = f"{bid}_callanalytics"
        bant_table = f"{bid}_bant"

        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            if not self._table_exists(cursor, sarvam_table):
                return {
                    "bid": bid,
                    "processed_calls": 0,
                    "ingested_documents": 0,
                    "ingested_chunks": 0,
                    "skipped": 0,
                    "error": f"Table {sarvam_table} not found",
                }

            has_calls = self._table_exists(cursor, calls_table)
            has_raw_calls = self._table_exists(cursor, raw_calls_table)
            has_analytics = self._table_exists(cursor, analytics_table)
            has_bant = self._table_exists(cursor, bant_table)
            source_calls_table = calls_table if has_calls else raw_calls_table if has_raw_calls else None

            if not source_calls_table:
                return {
                    "bid": bid,
                    "processed_calls": 0,
                    "ingested_documents": 0,
                    "ingested_chunks": 0,
                    "skipped": 0,
                    "error": f"No source calls table found for {bid}",
                }

            source_columns = self._table_columns(cursor, source_calls_table)
            sarvam_columns = self._table_columns(cursor, sarvam_table)
            analytics_columns = self._table_columns(cursor, analytics_table) if has_analytics else set()
            bant_columns = self._table_columns(cursor, bant_table) if has_bant else set()
            col_agentname = "c.agentname" if "agentname" in source_columns else "NULL"
            col_customer = "c.customer_callinfo" if "customer_callinfo" in source_columns else "NULL"
            col_call_start = "c.call_starttime" if "call_starttime" in source_columns else "NULL"
            col_direction = "c.direction" if "direction" in source_columns else "NULL"
            col_sales_intent = "c.sales_intent" if "sales_intent" in source_columns else "NULL"
            col_call_purpose = "c.call_purpose" if "call_purpose" in source_columns else "NULL"
            order_expr = "c.call_starttime DESC" if "call_starttime" in source_columns else "s.created_at DESC"
            col_speaker_segments = "s.speaker_segments" if "speaker_segments" in sarvam_columns else "NULL"
            col_duration = "s.duration" if "duration" in sarvam_columns else "NULL"
            col_language = "s.language" if "language" in sarvam_columns else "NULL"
            col_analytics_summary = "a.summary" if "summary" in analytics_columns else "NULL"
            col_objection_type = "a.objection_type" if "objection_type" in analytics_columns else "NULL"
            col_sentiment = "a.sentiment" if "sentiment" in analytics_columns else "NULL"
            col_quality_score = "a.quality_score" if "quality_score" in analytics_columns else "NULL"
            col_bant_profile = "b.profile_json" if "profile_json" in bant_columns else "NULL"
            col_bant_summary = "b.profile_summary" if "profile_summary" in bant_columns else "NULL"

            where_parts = ["s.transcript IS NOT NULL", "s.transcript != ''"]
            params = []
            normalized_callids = [str(c).strip() for c in (callids or []) if str(c).strip()]
            if normalized_callids:
                placeholders = ",".join(["%s"] * len(normalized_callids))
                where_parts.append(f"c.callid IN ({placeholders})")
                params.extend(normalized_callids)
            if (
                presales_only
                and source_calls_table == calls_table
                and ("sales_intent" in source_columns or "call_purpose" in source_columns)
            ):
                where_parts.append(
                    "("
                    "LOWER(COALESCE(c.sales_intent,'')) IN ('high','medium','pre-sales','presales','pre_sales') "
                    "OR LOWER(COALESCE(c.call_purpose,'')) LIKE '%%sale%%' "
                    "OR LOWER(COALESCE(c.call_purpose,'')) LIKE '%%prospect%%' "
                    "OR LOWER(COALESCE(c.call_purpose,'')) LIKE '%%demo%%'"
                    ")"
                )

            query = f"""
                SELECT
                    c.callid,
                    {col_agentname} AS agentname,
                    {col_customer} AS customer_callinfo,
                    {col_call_start} AS call_starttime,
                    {col_direction} AS direction,
                    {col_sales_intent} AS sales_intent,
                    {col_call_purpose} AS call_purpose,
                    s.transcript,
                    {col_speaker_segments} AS speaker_segments,
                    {col_duration} AS duration,
                    {col_language} AS language,
                    { f"{col_analytics_summary} AS analytics_summary, {col_objection_type} AS objection_type, {col_sentiment} AS sentiment, {col_quality_score} AS quality_score" if has_analytics else "NULL AS analytics_summary, NULL AS objection_type, NULL AS sentiment, NULL AS quality_score" }
                    ,
                    { f"{col_bant_profile} AS bant_profile, {col_bant_summary} AS bant_summary" if has_bant else "NULL AS bant_profile, NULL AS bant_summary" }
                FROM `{source_calls_table}` c
                JOIN `{sarvam_table}` s ON s.callid COLLATE utf8mb4_unicode_ci = c.callid COLLATE utf8mb4_unicode_ci
                { f"LEFT JOIN `{analytics_table}` a ON a.callid COLLATE utf8mb4_unicode_ci = c.callid COLLATE utf8mb4_unicode_ci" if has_analytics else "" }
                { f"LEFT JOIN `{bant_table}` b ON b.callid COLLATE utf8mb4_unicode_ci = c.callid COLLATE utf8mb4_unicode_ci" if has_bant else "" }
                WHERE {" AND ".join(where_parts)}
                ORDER BY {order_expr}
                LIMIT %s
            """
            params.append(int(limit))
            cursor.execute(query, params)
            rows = cursor.fetchall()
        finally:
            conn.close()

        if not rows:
            return {
                "bid": bid,
                "processed_calls": 0,
                "ingested_documents": 0,
                "ingested_chunks": 0,
                "skipped": 0,
            }

        documents = []
        skipped = 0
        processed_calls = 0

        for row in rows:
            callid = str(row.get("callid"))
            source_id = f"{bid}-call-{callid}"
            transcript = row.get("transcript") or ""
            speaker_segments = row.get("speaker_segments")
            chunks = self._chunk_transcript(transcript, speaker_segments)
            if not chunks:
                skipped += 1
                continue

            analytics_summary = row.get("analytics_summary")
            quality_score = row.get("quality_score")
            sentiment = row.get("sentiment")
            objection_type = row.get("objection_type")
            call_purpose = row.get("call_purpose")
            sales_intent = row.get("sales_intent")
            bant_profile_raw = row.get("bant_profile")
            bant_profile = self._parse_json(bant_profile_raw, bant_profile_raw if isinstance(bant_profile_raw, dict) else None)
            bant_summary = row.get("bant_summary")

            analysis_lines = []
            if analytics_summary:
                analysis_lines.append(f"Summary: {analytics_summary}")
            if quality_score is not None:
                analysis_lines.append(f"Quality Score: {quality_score}")
            if sentiment:
                analysis_lines.append(f"Sentiment: {sentiment}")
            if objection_type:
                analysis_lines.append(f"Objection Type: {objection_type}")
            if call_purpose:
                analysis_lines.append(f"Call Purpose: {call_purpose}")
            if sales_intent:
                analysis_lines.append(f"Sales Intent: {sales_intent}")
            if bant_summary:
                analysis_lines.append(f"Customer Profile Summary: {bant_summary}")
            if bant_profile:
                try:
                    analysis_lines.append(f"Customer Profile (BANT): {json.dumps(self._json_safe(bant_profile), ensure_ascii=True)}")
                except Exception:
                    pass

            if analysis_lines:
                analysis_text = "CALL ANALYSIS CONTEXT\n" + "\n".join(analysis_lines)
                chunks.insert(
                    0,
                    {
                        "chunk_id": "analysis-0",
                        "chunk_index": 0,
                        "text": analysis_text,
                        "metadata": {
                            "speaker": "system",
                            "chunk_type": "analysis_context",
                        },
                    },
                )
                for i, chunk in enumerate(chunks):
                    chunk["chunk_index"] = i

            if not overwrite_existing:
                check_conn = self.get_connection()
                try:
                    check_cursor = check_conn.cursor()
                    check_cursor.execute(
                        "SELECT id FROM rag_documents WHERE bid = %s AND source_id = %s LIMIT 1",
                        (bid, source_id),
                    )
                    if check_cursor.fetchone():
                        skipped += 1
                        continue
                finally:
                    check_conn.close()

            chunk_texts = [c["text"] for c in chunks]
            embeddings = self._generate_embeddings(chunk_texts)

            final_chunks = []
            for chunk, emb in zip(chunks, embeddings):
                if not emb:
                    continue
                final_chunks.append(
                    {
                        "chunk_id": chunk["chunk_id"],
                        "chunk_index": chunk["chunk_index"],
                        "text": chunk["text"],
                        "embedding": emb,
                        "metadata": chunk.get("metadata") or {},
                    }
                )

            if not final_chunks:
                skipped += 1
                continue

            metadata = {
                "bid": bid,
                "callid": callid,
                "agentname": row.get("agentname"),
                "customer_callinfo": row.get("customer_callinfo"),
                "call_starttime": row.get("call_starttime").isoformat() if row.get("call_starttime") else None,
                "direction": row.get("direction"),
                "sales_intent": sales_intent,
                "call_purpose": call_purpose,
                "language": row.get("language"),
                "duration": row.get("duration"),
                "summary": analytics_summary,
                "objection_type": objection_type,
                "sentiment": sentiment,
                "quality_score": quality_score,
                "bant_summary": bant_summary,
                "bant_profile": bant_profile,
            }

            documents.append(
                {
                    "source_id": source_id,
                    "title": f"Call {callid}",
                    "source_type": source_type,
                    "source_uri": f"call://{bid}/{callid}",
                    "metadata": metadata,
                    "chunks": final_chunks,
                }
            )
            processed_calls += 1

        result = self.ingest_documents(bid, documents) if documents else {"documents": 0, "chunks": 0, "skipped": 0}
        return {
            "bid": bid,
            "processed_calls": processed_calls,
            "ingested_documents": result.get("documents", 0),
            "ingested_chunks": result.get("chunks", 0),
            "skipped": skipped + result.get("skipped", 0),
        }

    def _generate_single_embedding(self, text):
        cache_key = self._embed_cache_key(text)
        cached = self._cache_get(cache_key)
        if cached:
            return self._ensure_embedding_list(cached)

        if not self.bedrock_runtime:
            return None

        original_text = str(text or "")
        candidates = [original_text]
        if len(original_text) > 12000:
            candidates = [original_text[:12000], original_text[:8000], original_text[:4000]]

        for idx, candidate in enumerate(candidates):
            try:
                body = {"inputText": candidate}
                response = self.bedrock_runtime.invoke_model(
                    modelId=self.rag_embedding_model,
                    body=json.dumps(body),
                    accept="application/json",
                    contentType="application/json",
                )
                payload = json.loads(response["body"].read())
                embedding = payload.get("embedding")
                if not embedding:
                    return None
                vector = self._ensure_embedding_list(embedding)
                if vector:
                    self._cache_set(cache_key, vector, self.rag_embed_cache_ttl)
                return vector
            except Exception as exc:
                # Bedrock Titan can reject long inputs; retry with smaller candidates.
                msg = str(exc)
                token_error = ("Too many input tokens" in msg) or ("ValidationException" in msg)
                if token_error and idx < len(candidates) - 1:
                    logger.warning("Embedding input too large; retrying with shorter text (%s chars)", len(candidates[idx + 1]))
                    continue
                raise
        return None

    def _generate_query_embedding(self, text):
        return self._generate_single_embedding(text)

    def _generate_embeddings(self, texts):
        vectors = []
        for text in texts:
            vector = self._generate_single_embedding(text)
            if not vector:
                vectors.append(None)
            else:
                vectors.append(vector)
        return vectors

    def retrieve_chunks(self, bid, query_embedding, top_k=None, min_similarity=None, candidate_limit=1500):
        vector = self._ensure_embedding_list(query_embedding)
        if not vector:
            raise ValueError("query_embedding is required and must be a numeric array")

        q_norm = self._vector_norm(vector)
        if q_norm == 0:
            raise ValueError("query_embedding norm cannot be zero")

        top_k = int(top_k or self.rag_top_k)
        min_similarity = float(min_similarity if min_similarity is not None else self.rag_similarity_threshold)
        cache_key = self._retrieval_cache_key(bid, vector, top_k, min_similarity)
        cached = self._cache_get(cache_key)
        if cached:
            return cached

        if self.qdrant_client:
            try:
                collection_name = self._qdrant_collection_name(bid)
                hits = self.qdrant_client.search(
                    collection_name=collection_name,
                    query_vector=vector,
                    limit=top_k,
                    with_payload=True,
                    with_vectors=False,
                )
                scored = []
                for hit in hits:
                    score = float(getattr(hit, "score", 0.0))
                    if score < min_similarity:
                        continue
                    payload = getattr(hit, "payload", {}) or {}
                    scored.append(
                        {
                            "chunk_id": payload.get("chunk_id"),
                            "document_id": payload.get("document_id"),
                            "source_id": payload.get("source_id"),
                            "title": payload.get("title"),
                            "source_type": payload.get("source_type"),
                            "source_uri": payload.get("source_uri"),
                            "content": payload.get("content"),
                            "metadata": payload.get("metadata") or {},
                            "similarity": round(score, 6),
                        }
                    )
                scored.sort(key=lambda x: x["similarity"], reverse=True)
                result = scored[:top_k]
                self._cache_set(cache_key, result, self.rag_retrieval_cache_ttl)
                if result:
                    return result
            except Exception as exc:
                logger.warning("Qdrant search failed for bid %s; falling back to MySQL vectors: %s", bid, exc)

        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT
                    c.id,
                    c.chunk_id,
                    c.document_id,
                    c.content,
                    c.metadata,
                    c.embedding,
                    c.embedding_dim,
                    c.embedding_norm,
                    d.source_id,
                    d.title,
                    d.source_type,
                    d.source_uri
                FROM rag_chunks c
                JOIN rag_documents d ON d.id = c.document_id
                WHERE c.bid = %s AND d.is_active = TRUE
                ORDER BY c.updated_at DESC
                LIMIT %s
                """,
                (str(bid), int(candidate_limit)),
            )
            rows = cursor.fetchall()
        finally:
            conn.close()

        scored = []
        for row in rows:
            if row.get("embedding_dim") != len(vector):
                continue
            emb = self._ensure_embedding_list(row.get("embedding"))
            if not emb:
                continue

            denom = q_norm * float(row.get("embedding_norm") or 0)
            if denom == 0:
                continue

            similarity = self._dot_product(vector, emb) / denom
            if similarity < min_similarity:
                continue

            scored.append(
                {
                    "chunk_id": row["chunk_id"],
                    "document_id": row["document_id"],
                    "source_id": row.get("source_id"),
                    "title": row.get("title"),
                    "source_type": row.get("source_type"),
                    "source_uri": row.get("source_uri"),
                    "content": row.get("content"),
                    "metadata": self._parse_json(row.get("metadata"), {}),
                    "similarity": round(similarity, 6),
                }
            )

        scored.sort(key=lambda x: x["similarity"], reverse=True)
        result = scored[:top_k]
        self._cache_set(cache_key, result, self.rag_retrieval_cache_ttl)
        return result

    def _build_prompt(self, message, retrieved_chunks, profile, memory_messages):
        context_blocks = []
        for idx, chunk in enumerate(retrieved_chunks, start=1):
            context_blocks.append(
                f"[Source {idx}] title={chunk.get('title') or 'Untitled'} score={chunk.get('similarity')}\n"
                f"{chunk.get('content')}"
            )

        memory_lines = []
        for m in memory_messages[-self.rag_memory_messages:]:
            role = m.get("role", "user")
            content = m.get("content", "")
            memory_lines.append(f"{role}: {content}")

        system_prompt = (
            "You are an enterprise RAG assistant. Use only provided context when factual claims are needed. "
            "If context is missing, state uncertainty and ask a targeted follow-up question. "
            "Keep response concise and actionable."
        )
        memory_text = "\n".join(memory_lines) if memory_lines else "None"
        context_text = "\n\n".join(context_blocks) if context_blocks else "None"

        prompt = (
            f"SYSTEM:\n{system_prompt}\n\n"
            f"USER PROFILE:\n{json.dumps(profile, ensure_ascii=True)}\n\n"
            f"RECENT MEMORY:\n{memory_text}\n\n"
            f"RETRIEVED CONTEXT:\n{context_text}\n\n"
            f"USER QUESTION:\n{message}\n\n"
            "Answer with direct guidance."
        )
        return prompt

    def _invoke_bedrock_chat_model(self, prompt, model_name=None, runtime_config=None):
        if not self.bedrock_runtime:
            logger.warning("Bedrock runtime is not configured")
            return None
        runtime_config = runtime_config or {}
        request_body = {
            "messages": [
                {
                    "role": "user",
                    "content": [{"text": prompt}],
                }
            ],
            "inferenceConfig": {
                "max_new_tokens": int(runtime_config.get("max_tokens", self.config.get("RAG_MAX_TOKENS", 600))),
                "temperature": float(runtime_config.get("temperature", self.config.get("RAG_TEMPERATURE", 0.2))),
                "top_p": float(runtime_config.get("top_p", self.config.get("RAG_TOP_P", 0.9))),
            },
        }

        try:
            response = self.bedrock_runtime.invoke_model(
                modelId=str(model_name or self.rag_chat_model),
                body=json.dumps(request_body),
                accept="application/json",
                contentType="application/json",
            )
            payload = json.loads(response["body"].read())
        except Exception as exc:
            logger.warning("Bedrock chat request failed: %s", exc)
            return None

        output = payload.get("output", {})
        message = output.get("message", {})
        content = message.get("content", [])

        texts = []
        for item in content:
            text = item.get("text") if isinstance(item, dict) else None
            if text:
                texts.append(text)

        return "\n".join(texts).strip() if texts else None

    def _invoke_ollama_chat_model(self, prompt, model_name=None, runtime_config=None):
        runtime_config = runtime_config or {}
        model = str(model_name or self.rag_chat_model or "").strip()
        if not model:
            return None

        payload = {
            "model": model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": float(runtime_config.get("temperature", self.config.get("RAG_TEMPERATURE", 0.2))),
                "top_p": float(runtime_config.get("top_p", self.config.get("RAG_TOP_P", 0.9))),
                "num_predict": int(runtime_config.get("max_tokens", self.config.get("RAG_MAX_TOKENS", 600))),
            },
        }
        try:
            response = requests.post(
                f"{self.ollama_base_url}/api/generate",
                json=payload,
                timeout=self.ollama_timeout_seconds,
            )
            response.raise_for_status()
            data = response.json() if response.content else {}
            answer = data.get("response")
            return str(answer).strip() if answer else None
        except Exception as exc:
            logger.warning("Ollama chat request failed: %s", exc)
            return None

    def _invoke_chat_model(self, prompt, provider=None, model_name=None, runtime_config=None):
        selected_provider = str(provider or self.rag_chat_provider or "bedrock").strip().lower()
        runtime_config = runtime_config or {}

        if selected_provider == "ollama":
            return self._invoke_ollama_chat_model(prompt, model_name=model_name, runtime_config=runtime_config)
        if selected_provider == "bedrock":
            return self._invoke_bedrock_chat_model(prompt, model_name=model_name, runtime_config=runtime_config)

        # auto mode: try provider from env, then bedrock, then ollama
        if selected_provider == "auto":
            default_provider = str(self.rag_chat_provider or "bedrock").strip().lower()
            if default_provider == "ollama":
                answer = self._invoke_ollama_chat_model(prompt, model_name=model_name, runtime_config=runtime_config)
                if answer:
                    return answer
                return self._invoke_bedrock_chat_model(prompt, model_name=model_name, runtime_config=runtime_config)
            answer = self._invoke_bedrock_chat_model(prompt, model_name=model_name, runtime_config=runtime_config)
            if answer:
                return answer
            return self._invoke_ollama_chat_model(prompt, model_name=model_name, runtime_config=runtime_config)

        # fallback unknown provider -> try bedrock then ollama
        answer = self._invoke_bedrock_chat_model(prompt, model_name=model_name, runtime_config=runtime_config)
        if answer:
            return answer
        return self._invoke_ollama_chat_model(prompt, model_name=model_name, runtime_config=runtime_config)

    def _fallback_answer(self, message, retrieved_chunks):
        if not retrieved_chunks:
            return (
                "I do not have enough context in the knowledge base to answer this accurately yet. "
                "Please ingest relevant documents or include a query embedding." 
            )

        top = retrieved_chunks[0]
        snippet = top.get("content", "")[:500]
        return (
            "Using the closest stored context, here is what I found:\n\n"
            f"{snippet}\n\n"
            "If you want a stronger answer, provide a query embedding or configure Bedrock credentials for model generation."
        )

    def query(self, bid, user_id, message, query_embedding=None, conversation_id=None, top_k=None, min_similarity=None, metadata=None, profile_updates=None):
        if not message or not str(message).strip():
            raise ValueError("message is required")
        conversation_id = self._resolve_conversation_id(
            bid=bid,
            user_id=user_id,
            metadata=metadata or {},
            explicit_id=conversation_id,
        )

        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            conversation_id = self._get_or_create_conversation(
                cursor,
                bid=bid,
                user_id=user_id,
                conversation_id=conversation_id,
                metadata=metadata,
            )
            conn.commit()
        finally:
            conn.close()
        self._save_conversation_resume(bid, user_id, metadata or {}, conversation_id)

        if query_embedding is None:
            query_embedding = self._generate_query_embedding(message)

        if query_embedding is None:
            raise ValueError(
                "query_embedding is required unless Bedrock embedding credentials/model are configured"
            )

        retrieved_chunks = self.retrieve_chunks(
            bid=bid,
            query_embedding=query_embedding,
            top_k=top_k,
            min_similarity=min_similarity,
        )

        memory = self.get_conversation_messages(bid, conversation_id, limit=self.rag_memory_messages)
        profile = self.update_user_profile(bid, user_id, user_text=message, profile_updates=profile_updates)

        self.save_message(
            bid=bid,
            conversation_id=conversation_id,
            user_id=user_id,
            role="user",
            content=message,
            metadata=metadata,
            retrieved_chunk_ids=[c["chunk_id"] for c in retrieved_chunks],
        )

        prompt = self._build_prompt(
            message=message,
            retrieved_chunks=retrieved_chunks,
            profile=profile,
            memory_messages=memory,
        )
        if isinstance(metadata, dict):
            agent_instruction = str(metadata.get("agent_instruction") or "").strip()
            if agent_instruction:
                prompt = (
                    f"AGENT INSTRUCTION:\n{agent_instruction}\n\n"
                    f"{prompt}"
                )

        answer = None
        model_provider = None
        model_name = None
        runtime_cfg = None
        if isinstance(metadata, dict):
            model_provider = metadata.get("llm_provider")
            model_name = metadata.get("llm_model_name")
            runtime_cfg = metadata.get("llm_runtime_config")
        try:
            answer = self._invoke_chat_model(
                prompt,
                provider=model_provider,
                model_name=model_name,
                runtime_config=runtime_cfg,
            )
        except Exception as exc:
            logger.warning("Bedrock chat call failed, using fallback: %s", exc)

        if not answer:
            answer = self._fallback_answer(message, retrieved_chunks)

        self.save_message(
            bid=bid,
            conversation_id=conversation_id,
            user_id=user_id,
            role="assistant",
            content=answer,
            metadata={"top_k": top_k or self.rag_top_k},
            retrieved_chunk_ids=[c["chunk_id"] for c in retrieved_chunks],
        )

        return {
            "conversation_id": conversation_id,
            "answer": answer,
            "retrieved_context": [
                {
                    "chunk_id": c["chunk_id"],
                    "source_id": c.get("source_id"),
                    "title": c.get("title"),
                    "similarity": c.get("similarity"),
                    "metadata": c.get("metadata"),
                }
                for c in retrieved_chunks
            ],
            "profile": profile,
            "timestamp": datetime.utcnow().isoformat() + "Z",
        }
