"""
AI agent runner for post-call analysis.

Loads ``business_agent_config`` records for a given bid from the database,
builds prompts using the call transcript and speaker segments, calls the
configured LLM (AWS Bedrock or Ollama), and returns an aggregated analysis
dict.

Each agent produces a JSON blob stored under its ``agent_name`` key.  Two
top-level shortcuts — ``quality_score`` and ``summary`` — are promoted from
the first agent that emits them so the call records list can display scores
without parsing per-agent JSON.
"""
from __future__ import annotations

import json
import logging
import os
import re
from typing import Any, Dict, List, Optional


logger = logging.getLogger(__name__)

# ── AWS defaults (override via env or per-agent model_id) ────────────────────
# All four values are resolved once, when this module is imported.
_AWS_REGION = os.environ.get("AWS_REGION", "us-east-1")
# Explicit credentials; each is None when the variable is not set.
_AWS_KEY = os.environ.get("AWS_ACCESS_KEY_ID")
_AWS_SECRET = os.environ.get("AWS_SECRET_ACCESS_KEY")
# Bedrock model id used when an agent config does not name its own model_id.
_DEFAULT_BEDROCK_MODEL = os.environ.get("AWS_NOVA_MODEL", "amazon.nova-lite-v1:0")


# ── Prompt helpers ────────────────────────────────────────────────────────────

def _format_transcript(transcript: str, speaker_segments: List[Dict]) -> str:
    """Return a readable conversation string from *speaker_segments*, or the raw transcript.

    Each dict segment with non-blank ``text`` becomes one ``[ROLE]: text`` line.
    The label comes from ``role``, then ``speaker``, then ``"Speaker"``; a key
    that is missing OR ``None`` falls through to the next choice. The label is
    stringified before upper-casing, so non-string labels such as ``0`` work.
    Non-dict entries are skipped. When no segment yields a line, *transcript*
    (or ``""`` if it is falsy) is returned instead.
    """
    lines = []
    for seg in speaker_segments or []:
        if not isinstance(seg, dict):
            continue
        text = str(seg.get("text") or "").strip()
        if not text:
            continue
        # ``seg.get("role", default)`` alone would return None for an explicit
        # ``"role": None`` and crash on ``.upper()``; fall through instead.
        role = seg.get("role")
        if role is None:
            role = seg.get("speaker")
        if role is None:
            role = "Speaker"
        lines.append(f"[{str(role).upper()}]: {text}")
    if lines:
        return "\n".join(lines)
    return transcript or ""


def _render_template(template: str, context: Dict[str, Any]) -> str:
    """Replace ``{key}`` placeholders in *template* with values from *context*.

    Substitution is a single pass over *template*: text inserted for one
    placeholder (e.g. a transcript that happens to contain ``{bid}``) is never
    re-scanned for further placeholders. ``None`` renders as ``""``; every other
    value is rendered with ``str()``, so ``0`` becomes ``"0"`` rather than
    disappearing. Placeholders without a matching key are left untouched.
    """
    if not context:
        return template
    replacements = {
        "{" + str(key) + "}": "" if value is None else str(value)
        for key, value in context.items()
    }
    # Longest tokens first so an alternation never stops at a shorter token
    # that is a prefix of a longer one (only possible with odd keys containing "}").
    tokens = sorted(replacements, key=len, reverse=True)
    pattern = re.compile("|".join(re.escape(token) for token in tokens))
    return pattern.sub(lambda match: replacements[match.group(0)], template)


def _coerce_points(value: Any) -> float:
    """Convert a configured ``max_score`` value to a number of points.

    Missing or non-numeric values (``None``, ``""``, ``"abc"``) and NaN/infinite
    values count as 0. Whole numbers are returned as ``int`` so they render as
    ``10`` rather than ``10.0``; fractional scores such as ``2.5`` are kept.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        return 0
    if number != number or number in (float("inf"), float("-inf")):
        return 0
    return int(number) if number.is_integer() else number


def _build_scoring_parameters_block(scoring_parameters: List[Dict]) -> str:
    """Format a list of scoring parameter dicts into a structured prompt block.

    Only dict entries with ``enabled != False`` are included; anything else in
    the list is ignored. Returns ``""`` when nothing is active. Parameters are
    grouped by ``parameter_group`` (default ``"General"``) in first-seen order,
    and ``max_score`` values are coerced with ``_coerce_points`` so string or
    fractional scores neither crash nor get truncated in the total.
    """
    active = [
        p for p in (scoring_parameters or [])
        if isinstance(p, dict) and p.get("enabled") is not False
    ]
    if not active:
        return ""

    # Round away float noise (e.g. 0.1 + 0.2) before rendering the total.
    total = _coerce_points(round(sum(_coerce_points(p.get("max_score")) for p in active), 6))
    lines = [
        "SCORING PARAMETERS TO EVALUATE:\n",
        f"Total possible score: {total} points\n",
    ]

    # Group by parameter_group; dicts keep insertion order.
    groups: Dict[str, List[Dict]] = {}
    for p in active:
        groups.setdefault(p.get("parameter_group") or "General", []).append(p)

    for group_name, params in groups.items():
        lines.append(f"\n[{group_name}]")
        for p in params:
            name = p.get("parameter_name") or "Unnamed parameter"
            ptype = p.get("parameter_type") or "Required"
            fatal = " ⚠ FATAL" if p.get("is_fatal") else ""
            max_points = _coerce_points(p.get("max_score"))
            lines.append(f"  • {name} ({ptype}{fatal}) — max {max_points} pts")
            if p.get("check_description"):
                lines.append(f"    Check: {p['check_description']}")
            if p.get("detailed_description"):
                lines.append(f"    Detail: {p['detailed_description']}")
            if p.get("sample_utterances"):
                lines.append(f"    Example: {p['sample_utterances']}")

    lines.append(
        "\nFor each parameter return: score (0 to max_score), evidence (exact quote), "
        "reasoning, and applicable (true/false). "
        "If a FATAL parameter scores 0, set overall quality_score to 0."
    )
    return "\n".join(lines)


def _find_balanced_end(text: str, start: int, open_char: str, close_char: str) -> int:
    """Return the index of the bracket that closes ``text[start]``, or -1 if none.

    Brackets inside double-quoted string literals (honouring backslash escapes)
    are ignored; only *open_char*/*close_char* affect the nesting depth.
    """
    depth = 0
    in_str = False
    escape = False
    for i, ch in enumerate(text[start:], start=start):
        if escape:
            escape = False
        elif in_str:
            if ch == "\\":
                escape = True
            elif ch == '"':
                in_str = False
        elif ch == '"':
            in_str = True
        elif ch == open_char:
            depth += 1
        elif ch == close_char:
            depth -= 1
            if depth == 0:
                return i
    return -1


def _extract_json(text: str) -> Any:
    """Extract a JSON value from LLM output *text*, preferring objects over arrays.

    Strategy:
    1. Parse the whole (stripped) text.
    2. Otherwise, for ``{…}`` and then ``[…]``, walk balanced bracket spans from
       left to right. Return the first one that parses; a balanced span that is
       not JSON (e.g. ``{placeholder}`` in prose) is skipped. An unbalanced span
       (e.g. output truncated mid-object) ends the search for that bracket type,
       so an inner fragment of a broken object is never returned.
    3. Last resort: the greedy first-``{``-to-last-``}`` substring.

    Returns ``None`` when *text* is not a non-blank string or nothing parses.
    """
    if not isinstance(text, str) or not text.strip():
        return None
    try:
        return json.loads(text.strip())
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        pass

    for open_char, close_char in (("{", "}"), ("[", "]")):
        search_from = 0
        while True:
            start = text.find(open_char, search_from)
            if start == -1:
                break
            end = _find_balanced_end(text, start, open_char, close_char)
            if end == -1:
                break
            try:
                return json.loads(text[start:end + 1])
            except ValueError:
                search_from = end + 1

    # Fallback: greedy regex span (less reliable, kept as a final attempt).
    match = re.search(r"\{[\s\S]+\}", text)
    if match:
        try:
            return json.loads(match.group(0))
        except ValueError:
            pass
    return None


# ── LLM call implementations ──────────────────────────────────────────────────

def _call_bedrock(
    model_id: str,
    system_prompt: str,
    user_message: str,
    temperature: float = 0.1,
    max_tokens: int = 4096,
) -> str:
    """Send a single user turn to a Bedrock model via the Converse API.

    A fresh ``bedrock-runtime`` client is built on every call from the
    module-level region and credential settings. An empty *system_prompt* is
    sent as an empty ``system`` list. Returns the text of the first content
    block in the reply that carries a ``text`` field, or ``""`` if none does.
    """
    import boto3

    runtime = boto3.client(
        service_name="bedrock-runtime",
        region_name=_AWS_REGION,
        aws_access_key_id=_AWS_KEY,
        aws_secret_access_key=_AWS_SECRET,
    )

    reply = runtime.converse(
        modelId=model_id,
        messages=[{"role": "user", "content": [{"text": user_message}]}],
        system=[{"text": system_prompt}] if system_prompt else [],
        inferenceConfig={"temperature": temperature, "maxTokens": max_tokens},
    )

    blocks = reply.get("output", {}).get("message", {}).get("content", [])
    return next((blk["text"] for blk in blocks if "text" in blk), "")


def _call_ollama(
    model_id: str,
    system_prompt: str,
    user_message: str,
    temperature: float = 0.1,
    max_tokens: int = 4096,
) -> str:
    """Call an Ollama model via ``POST /api/chat`` (non-streaming) and return the reply text.

    The server URL comes from ``OLLAMA_BASE_URL`` (default
    ``http://localhost:11434``). A trailing slash is stripped so the request
    never targets ``//api/chat``, and an empty value falls back to the default.
    An empty *system_prompt* is omitted rather than sent as an empty system
    message. Returns ``""`` when the response has no message content.

    Raises
    ------
    requests.HTTPError
        For a non-2xx status (via ``raise_for_status``).
    requests.RequestException
        For connection failures or the 120 s timeout.
    """
    import requests as req_lib

    base_url = (os.getenv("OLLAMA_BASE_URL") or "http://localhost:11434").rstrip("/")

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_message})

    payload = {
        "model": model_id,
        "messages": messages,
        "stream": False,
        "options": {"temperature": temperature, "num_predict": max_tokens},
    }
    resp = req_lib.post(f"{base_url}/api/chat", json=payload, timeout=120)
    resp.raise_for_status()
    # ``message``/``content`` may be absent or null; always honour the ``-> str`` contract.
    message = resp.json().get("message") or {}
    return message.get("content") or ""


# ── AgentRunner ────────────────────────────────────────────────────────────────

class AgentRunner:
    """Run all enabled AI agents for a bid against a given call transcript.

    Parameters
    ----------
    db_handler:
        Shared ``DatabaseHandler`` instance; only its
        ``get_agent_configs(bid)`` method is called here.
    """

    def __init__(self, db_handler) -> None:
        self._db = db_handler

    def run(
        self,
        bid: str,
        callid: str,
        transcript: str,
        speaker_segments: Optional[List[Dict]] = None,
        call_metadata: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Run all enabled agents and return an aggregated analysis dict.

        Each agent is isolated. A failure while preparing its prompt (bad
        temperature/max_tokens, scoring parameters, template) or calling its
        model is logged and recorded in ``agents_failed``, and the remaining
        agents still run. Configs without an ``agent_name`` are logged and
        skipped.

        Returns
        -------
        dict with keys:
            - ``<agent_name>``: parsed JSON output for each agent that
              completed (``{"raw_output": <text>}`` when no JSON was found)
            - ``quality_score``: float, promoted from the first agent whose
              ``quality_score`` converts to a float (absent if none does)
            - ``summary``: str, promoted from the first agent with a non-null
              ``summary`` (absent if none)
            - ``agents_run``: list of agent names that completed
            - ``agents_failed``: list of agent names that raised an error
        """
        configs = self._db.get_agent_configs(bid)
        if not configs:
            logger.info("[%s][%s] No agent configs found for bid", bid, callid)
            return {"agents_run": [], "agents_failed": []}

        formatted = _format_transcript(transcript, speaker_segments or [])
        context: Dict[str, Any] = {
            "transcript": formatted,
            "callid": callid,
            "bid": bid,
            # NOTE(review): metadata is spread last, so its keys can override
            # "transcript"/"callid"/"bid" — confirm that is intended.
            **(call_metadata or {}),
        }

        result: Dict[str, Any] = {"agents_run": [], "agents_failed": []}
        quality_score_set = False
        summary_set = False

        for cfg in configs:
            if not cfg.get("agent_enabled", True):
                continue

            agent_name = cfg.get("agent_name")
            if not agent_name:
                logger.error(
                    "[%s][%s] Skipping agent config without an agent_name", bid, callid,
                )
                continue

            try:
                parsed = self._run_agent(cfg, agent_name, context, bid, callid)
            except Exception as exc:  # one bad agent must not abort the others
                logger.error(
                    "[%s][%s] Agent '%s' failed: %s",
                    bid, callid, agent_name, exc, exc_info=True,
                )
                result["agents_failed"].append(agent_name)
                continue

            result[agent_name] = parsed
            result["agents_run"].append(agent_name)

            # Promote top-level shortcuts from the first agent that provides usable values.
            if not isinstance(parsed, dict):
                continue
            if not quality_score_set and "quality_score" in parsed:
                score = self._to_float(parsed["quality_score"])
                if score is None:
                    logger.warning(
                        "[%s][%s] Agent '%s' returned non-numeric quality_score %r; not promoted",
                        bid, callid, agent_name, parsed["quality_score"],
                    )
                else:
                    result["quality_score"] = score
                    quality_score_set = True
            if not summary_set and parsed.get("summary") is not None:
                result["summary"] = str(parsed["summary"])
                summary_set = True

        return result

    def _run_agent(
        self,
        cfg: Dict[str, Any],
        agent_name: str,
        context: Dict[str, Any],
        bid: str,
        callid: str,
    ) -> Any:
        """Build one agent's prompt, call its model and parse the reply.

        Returns the JSON value extracted from the model output, or
        ``{"raw_output": <text>}`` when none could be extracted. Any exception
        (non-numeric temperature/max_tokens, unknown provider, model or HTTP
        error) propagates to ``run``, which records the agent as failed.
        """
        provider: str = (cfg.get("model_provider") or "bedrock").lower()
        # NOTE(review): the Bedrock default is also used for Ollama agents
        # without a model_id — confirm that is intended.
        model_id: str = cfg.get("model_id") or _DEFAULT_BEDROCK_MODEL
        system_prompt: str = cfg.get("system_prompt") or ""
        user_template: str = cfg.get("user_prompt_template") or "{transcript}"
        raw_temperature = cfg.get("temperature")
        # Test for "unset" explicitly: ``value or 0.1`` would silently replace a
        # configured temperature of 0 with 0.1.
        temperature: float = 0.1 if raw_temperature in (None, "") else float(raw_temperature)
        max_tokens: int = int(cfg.get("max_tokens") or 4096)

        # Inject the scoring parameters block if the agent has them configured.
        scoring_parameters = self._scoring_parameters(cfg, agent_name, bid, callid)
        scoring_block = _build_scoring_parameters_block(scoring_parameters)
        user_message = _render_template(
            user_template, {**context, "scoring_parameters": scoring_block}
        )
        # If the template doesn't reference {scoring_parameters} but params exist, append them.
        if scoring_block and "{scoring_parameters}" not in user_template:
            user_message = f"{user_message}\n\n{scoring_block}"

        logger.info(
            "[%s][%s] Running agent '%s' via %s/%s (scoring_params=%d)",
            bid, callid, agent_name, provider, model_id, len(scoring_parameters),
        )

        if provider == "bedrock":
            raw_text = _call_bedrock(
                model_id, system_prompt, user_message, temperature, max_tokens
            )
        elif provider == "ollama":
            raw_text = _call_ollama(
                model_id, system_prompt, user_message, temperature, max_tokens
            )
        else:
            raise ValueError(f"Unknown model provider: {provider}")

        parsed = _extract_json(raw_text)
        if parsed is None:
            # Keep the raw text so a failed JSON extraction is still inspectable.
            parsed = {"raw_output": raw_text}
        return parsed

    @staticmethod
    def _scoring_parameters(
        cfg: Dict[str, Any], agent_name: str, bid: str, callid: str
    ) -> List[Dict]:
        """Return ``runtime_config["scoring_parameters"]`` for *cfg*, or ``[]``.

        ``runtime_config`` may be a dict or a JSON string. A blank string is
        treated as unset. An invalid JSON string is logged and treated as empty.
        """
        raw_cfg = cfg.get("runtime_config")
        if isinstance(raw_cfg, str):
            if not raw_cfg.strip():
                return []
            try:
                raw_cfg = json.loads(raw_cfg)
            except ValueError:
                logger.warning(
                    "[%s][%s] Agent '%s' has invalid runtime_config JSON; ignoring it",
                    bid, callid, agent_name,
                )
                return []
        if isinstance(raw_cfg, dict):
            return raw_cfg.get("scoring_parameters") or []
        return []

    @staticmethod
    def _to_float(value: Any) -> Optional[float]:
        """Return ``float(value)``, or ``None`` when *value* cannot be converted."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return None
