"""
Lead Insights Generator
=======================
Generates two types of AI-powered insights:

1. ``generate_lead_insights(db, bid, lead_phone)``
   Unified cross-call insights for the Insights tab:
   - objections (type, description, how handled, customer satisfied)
   - BANT synthesised across all calls
   - data_capture (competitors, products discussed)
   - path_to_conversion (previous summary bullets + 3 action steps with scripts)

2. ``generate_rich_summary(transcript, speaker_segments, call_metadata)``
   Detailed per-call structured summary for the Conversations → Summary tab:
   - call_summary  (sections with title + bullets)
   - bant_insights  (budget / authority / needs / timeline)
   - competitors
   - objection / objection_handling / objection_categories
   - next_steps
"""
from __future__ import annotations

import json
import logging
import os
import re
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

# ── model defaults ────────────────────────────────────────────────────────────
# All values below are read from the environment once, at import time.
# AWS region handed to the Bedrock runtime client.
_AWS_REGION = os.getenv("AWS_REGION", "us-east-1")
# AWS credentials; None when the variables are unset (passed to boto3 as-is).
_AWS_KEY = os.getenv("AWS_ACCESS_KEY_ID")
_AWS_SECRET = os.getenv("AWS_SECRET_ACCESS_KEY")
# Default Bedrock model id (override with AWS_NOVA_MODEL).
_DEFAULT_BEDROCK_MODEL = os.getenv("AWS_NOVA_MODEL", "amazon.nova-lite-v1:0")
# Base URL of the Ollama HTTP server; "/api/chat" is appended by _call_ollama.
_OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
# Default Ollama model id (override with OLLAMA_MODEL).
_DEFAULT_OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "qwen2.5:14b")


# ── LLM helpers (same pattern as agent_runner.py) ────────────────────────────

def _call_bedrock(model_id: str, system: str, user: str, temperature: float = 0.1, max_tokens: int = 4096) -> str:
    """Send a single user turn to a Bedrock model via the Converse API.

    A ``bedrock-runtime`` client is created on every call from the
    module-level AWS settings. The system prompt is sent only when it is
    non-empty.

    Returns:
        The text of the first response content block that has a ``"text"``
        key, or ``""`` when no such block exists.
    """
    # Function-scope import: boto3 is only needed on this code path.
    import boto3

    runtime = boto3.client(
        service_name="bedrock-runtime",
        region_name=_AWS_REGION,
        aws_access_key_id=_AWS_KEY,
        aws_secret_access_key=_AWS_SECRET,
    )
    system_blocks = [{"text": system}] if system else []
    reply = runtime.converse(
        modelId=model_id,
        messages=[{"role": "user", "content": [{"text": user}]}],
        system=system_blocks,
        inferenceConfig={"temperature": temperature, "maxTokens": max_tokens},
    )
    content_blocks = reply.get("output", {}).get("message", {}).get("content", [])
    return next((blk["text"] for blk in content_blocks if "text" in blk), "")


def _call_ollama(model_id: str, system: str, user: str, temperature: float = 0.1, max_tokens: int = 4096) -> str:
    """POST a system + user conversation to Ollama's ``/api/chat`` endpoint.

    The request is non-streaming and times out after 180 seconds.
    ``max_tokens`` is forwarded as the ``num_predict`` option.

    Returns:
        ``message.content`` from the JSON response, or ``""`` when absent.

    Raises:
        Whatever ``raise_for_status()`` raises for a non-success HTTP status.
    """
    # Function-scope import: requests is only needed on this code path.
    import requests as req_lib

    chat_messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    body = {
        "model": model_id,
        "messages": chat_messages,
        "stream": False,
        "options": {"temperature": temperature, "num_predict": max_tokens},
    }
    endpoint = f"{_OLLAMA_BASE_URL}/api/chat"
    http_reply = req_lib.post(endpoint, json=body, timeout=180)
    http_reply.raise_for_status()
    message = http_reply.json().get("message", {})
    return message.get("content", "")


def _extract_json(text: str) -> Any:
    """Best-effort extraction of a JSON value from LLM output.

    Strategy, in order:

    1. Parse the whole (stripped) text.
    2. Starting at the first ``{`` (then, failing that, the first ``[``),
       scan forward to the matching closing bracket, ignoring brackets that
       appear inside JSON strings, and parse that span.
    3. Parse the greedy span from the first ``{`` to the last ``}``.

    Args:
        text: Raw model output. ``bytes`` are decoded as UTF-8; ``None`` and
            any other non-text value are treated as "no JSON present".

    Returns:
        The decoded JSON value, or ``None`` when nothing parseable is found.
        Malformed input never raises.
    """
    if isinstance(text, (bytes, bytearray)):
        text = bytes(text).decode("utf-8", errors="replace")
    if not isinstance(text, str):
        # e.g. a provider returned null content; there is nothing to parse.
        return None

    # json.loads reports bad input as JSONDecodeError (a ValueError) and
    # pathological nesting as RecursionError; anything else is a real bug
    # and should not be swallowed here.
    parse_errors = (ValueError, RecursionError)

    try:
        return json.loads(text.strip())
    except parse_errors:
        pass

    for start_ch, end_ch in (("{", "}"), ("[", "]")):
        idx = text.find(start_ch)
        if idx == -1:
            continue
        depth, in_str, escape = 0, False, False
        for i, ch in enumerate(text[idx:], start=idx):
            if escape:
                # Character following a backslash inside a string: skip it.
                escape = False
                continue
            if ch == "\\" and in_str:
                escape = True
                continue
            if ch == '"':
                in_str = not in_str
                continue
            if in_str:
                # Brackets inside string literals do not affect nesting.
                continue
            if ch == start_ch:
                depth += 1
            elif ch == end_ch:
                depth -= 1
                if depth == 0:
                    try:
                        return json.loads(text[idx : i + 1])
                    except parse_errors:
                        # Balanced but invalid; fall through to the next strategy.
                        break

    m = re.search(r"\{[\s\S]+\}", text)
    if m:
        try:
            return json.loads(m.group(0))
        except parse_errors:
            pass
    return None


def _llm_call(provider: str, model_id: str, system: str, user: str, temperature: float = 0.1, max_tokens: int = 4096) -> str:
    """Dispatch a chat request to the backend selected by *provider*.

    ``"bedrock"`` routes to ``_call_bedrock``; every other value routes to
    ``_call_ollama``. All remaining arguments are forwarded unchanged.
    """
    backend = _call_bedrock if provider == "bedrock" else _call_ollama
    return backend(model_id, system, user, temperature, max_tokens)


def _detect_provider_model(db, bid: str):
    """Resolve the ``(provider, model_id)`` pair to use for business *bid*.

    Reads the first entry returned by ``db.get_agent_configs(bid)``. The
    provider defaults to ``"ollama"`` and is lower-cased. When the config has
    no ``model_id``, the default model of the *selected* provider is used, so
    a Bedrock config never receives an Ollama model id.

    Any failure while reading the config is logged and the Ollama default is
    returned; this lookup is deliberately best-effort and never raises.

    Returns:
        Tuple ``(provider, model_id)``.
    """
    try:
        configs = db.get_agent_configs(bid)
        if configs:
            cfg = configs[0]
            provider = (cfg.get("model_provider") or "ollama").lower()
            default_model = _DEFAULT_BEDROCK_MODEL if provider == "bedrock" else _DEFAULT_OLLAMA_MODEL
            return provider, (cfg.get("model_id") or default_model)
    except Exception as exc:
        # Keep the fallback behaviour, but leave a trace of why it was used.
        logger.warning("[%s] Could not read agent config, falling back to Ollama default: %s", bid, exc)
    return "ollama", _DEFAULT_OLLAMA_MODEL


# ── transcript formatter (reuse from agent_runner approach) ──────────────────

def _format_transcript(transcript: str, speaker_segments: List[Dict]) -> str:
    """Render a call as ``[ROLE]: text`` lines, one per non-empty segment.

    The speaker label comes from the segment's ``role``, then ``speaker``,
    then the literal ``"Speaker"``; ``None`` and ``""`` count as missing.
    Labels are passed through ``str()`` before upper-casing, so numeric
    speaker ids are accepted.

    Args:
        transcript: Plain transcript used when no segment yields any text.
        speaker_segments: Segment dicts; may be empty or ``None``.

    Returns:
        The joined segment lines, or ``transcript`` (``""`` when it is falsy)
        if the segments produce no lines.
    """
    lines = []
    for seg in speaker_segments or []:
        if not isinstance(seg, dict):
            continue
        text = str(seg.get("text") or "").strip()
        if not text:
            continue
        role = seg.get("role")
        if role is None or role == "":
            role = seg.get("speaker")
        if role is None or role == "":
            role = "Speaker"
        lines.append(f"[{str(role).upper()}]: {text}")
    if lines:
        return "\n".join(lines)
    return transcript or ""


# ── Lead Insights (cross-call unified) ───────────────────────────────────────

# System prompt for the cross-call lead-insights request (passed to _llm_call
# by generate_lead_insights); it tells the model to answer with JSON only.
_LEAD_INSIGHTS_SYSTEM = """You are an expert sales call analyst. You will be given summaries and transcripts from multiple sales calls with the same customer. Your task is to produce a unified lead intelligence report.

Return ONLY a valid JSON object — no markdown, no prose, no triple backticks."""

# Inner lines of the "data_capture" schema object used when no usable
# data-capture fields are supplied. The final value's closing quote is
# escaped so it is not swallowed by the closing triple quote; without it the
# schema shown to the model ends with an unterminated JSON string.
_LEGACY_DATA_CAPTURE_JSON = """    "competitors_mentioned": "<comma-separated list or 'None'>",
    "products_discussed": "<key products/services discussed>",
    "customer_sentiment_trend": "<improving | declining | neutral>",
    "key_concerns": "<top 1-3 concerns the customer keeps raising>\""""


def _build_data_capture_json_fragment(field_rows: List[Dict]) -> str:
    """Build the inner lines of the ``data_capture`` object for the prompt schema.

    One schema line is produced per row that has a non-blank ``field_key``.
    The label is ``display_name`` (falling back to the key) with double
    quotes replaced by single quotes; the type is ``field_type`` (default
    ``"text"``). When no line is produced, the legacy fragment is returned.
    """
    entries = []
    for row in field_rows or []:
        field_key = (row.get("field_key") or "").strip()
        if field_key:
            display = (row.get("display_name") or field_key).replace('"', "'")
            kind = (row.get("field_type") or "text").strip()
            entries.append(
                f'    "{field_key}": "<string — synthesise across ALL calls for \'{display}\' '
                f'(field type: {kind}). Use Not mentioned if absent.>"'
            )
    return ",\n".join(entries) if entries else _LEGACY_DATA_CAPTURE_JSON


def _build_lead_insights_user_message(
    lead_phone: str,
    bid: str,
    num_calls: int,
    calls_block: str,
    field_rows: List[Dict],
) -> str:
    """Assemble the user prompt for the cross-call lead-insights request.

    The prompt contains the lead header, the pre-formatted call data, the
    expected JSON schema, and the generation rules. The ``data_capture``
    schema lines come from ``_build_data_capture_json_fragment``.

    The closing ``data_capture`` rule is chosen from the usable keys (rows
    with a non-blank ``field_key``), not from the mere presence of rows, so
    the rule always agrees with the schema: rows without any usable key get
    the legacy rule alongside the legacy schema.

    Args:
        lead_phone: Lead identifier shown in the prompt header.
        bid: Business identifier shown in the prompt header.
        num_calls: Number of calls contained in *calls_block*.
        calls_block: Pre-formatted text of all calls.
        field_rows: Data-capture field definitions; may be empty or ``None``.

    Returns:
        The complete user message.
    """
    dc_inner = _build_data_capture_json_fragment(field_rows)
    keys = [
        key
        for key in ((f.get("field_key") or "").strip() for f in field_rows or [])
        if key
    ]
    if keys:
        keys_csv = ", ".join(f'"{key}"' for key in keys)
        dc_rule = f'\n- data_capture must contain EXACTLY these keys (and no others): {keys_csv}'
    else:
        dc_rule = '\n- data_capture must contain the four legacy keys shown in the schema.'

    return f"""LEAD PHONE: {lead_phone}
BUSINESS: {bid}
NUMBER OF CALLS: {num_calls}

=== CALL DATA ===
{calls_block}
=== END CALL DATA ===

Generate a JSON object with exactly these keys:

{{
  "objections": [
    {{
      "objection": "<exact objection raised by customer>",
      "type": "<one of: Price | Product | Trust | Urgency | Competitor | Need | Timing | Other>",
      "how_handled": "<what the salesperson said/did in response>",
      "customer_satisfied": <true|false>,
      "call_date": "<YYYY-MM-DD of the call this happened>"
    }}
  ],
  "bant": {{
    "budget": "<synthesised budget insight across all calls>",
    "authority": "<who is the decision maker>",
    "needs": "<what the customer needs/wants>",
    "timeline": "<when they want to act>"
  }},
  "data_capture": {{
{dc_inner}
  }},
  "path_to_conversion": {{
    "previous_summary": [
      "<bullet 1 summarising the conversation history>",
      "<bullet 2>",
      "<bullet 3 — add more as needed>"
    ],
    "action_steps": [
      {{
        "title": "<short action title>",
        "say": "<exact script the salesperson should say — 2-4 sentences, first person>",
        "rationale": "<one sentence explaining why this works>",
        "works_because": "<the psychological or sales principle at play>"
      }}
    ]
  }}
}}

Rules:
- action_steps must have exactly 3 items ordered by priority
- Each "say" must be a realistic, natural-sounding script in the context of this specific lead
- If a field has no data write "Not mentioned" or [] as appropriate
- Do not invent objections or facts not present in the call data{dc_rule}"""


def _normalize_lead_insights_data_capture(parsed: Dict[str, Any], field_rows: List[Dict]) -> None:
    """Back-fill ``parsed["data_capture"]`` in place for every configured key.

    Does nothing when *field_rows* is empty. A ``data_capture`` value that is
    not a dict is replaced by a fresh dict. Each non-blank ``field_key``
    whose value is missing, ``None``, or a blank string is set to
    ``"Not mentioned"``; other values and extra keys are left untouched.
    """
    if not field_rows:
        return

    capture = parsed.get("data_capture")
    if not isinstance(capture, dict):
        capture = {}
        parsed["data_capture"] = capture

    wanted = ((row.get("field_key") or "").strip() for row in field_rows)
    for name in wanted:
        if not name:
            continue
        current = capture.get(name)
        is_blank = isinstance(current, str) and current.strip() == ""
        if current is None or is_blank:
            capture[name] = "Not mentioned"


def generate_lead_insights(db, bid: str, lead_phone: str) -> Dict[str, Any]:
    """Produce and persist unified insights for *lead_phone* across all calls.

    Steps: load qualifying calls, render one prompt block per call (each
    transcript capped at 3000 characters), load the data-capture field
    definitions (best-effort), call the LLM, parse and normalise the JSON
    reply, and save it with ``db.save_lead_insights``.

    Args:
        db: Data-access object providing ``get_lead_calls_for_insights``,
            ``get_data_capture_fields``, ``get_agent_configs`` and
            ``save_lead_insights``.
        bid: Business identifier.
        lead_phone: Lead whose calls are analysed.

    Returns:
        The parsed insights dict that was saved.

    Raises:
        ValueError: No qualifying calls exist for the lead.
        RuntimeError: The model reply could not be parsed into a non-empty
            JSON object.
    """
    calls = db.get_lead_calls_for_insights(bid, lead_phone, min_duration_seconds=20)
    if not calls:
        raise ValueError(f"No qualifying calls (>20s with transcript) found for {lead_phone}")

    max_transcript_chars = 3000  # keeps the combined prompt manageable

    # Build per-call block for the prompt
    call_blocks = []
    for i, c in enumerate(calls, 1):
        formatted = _format_transcript(c["transcript"], c.get("speaker_segments") or [])
        if len(formatted) > max_transcript_chars:
            formatted = formatted[:max_transcript_chars] + "\n... [truncated]"
        # Stringify before slicing (as _fetch_call_data does) so a datetime
        # or NULL call_starttime cannot raise TypeError.
        call_date = c.get("call_date") or str(c.get("call_starttime") or "")[:10]
        block = (
            f"--- Call {i} | Date: {call_date} | "
            f"Duration: {c['duration_seconds']}s | Agent: {c['agentname']} ---\n"
            f"SUMMARY: {c.get('summary') or 'Not available'}\n"
            f"OBJECTIONS NOTE: {c.get('objections_concerns') or 'None noted'}\n"
            f"BANT NOTE: {c.get('bant_profile') or 'Not available'}\n"
            f"TRANSCRIPT:\n{formatted}"
        )
        call_blocks.append(block)

    calls_block = "\n\n".join(call_blocks)

    # Field definitions are optional: on failure the prompt falls back to the
    # legacy data_capture schema.
    field_rows = []
    try:
        field_rows = db.get_data_capture_fields(bid) or []
    except Exception as exc:
        logger.warning("[%s] Could not load data_capture_fields: %s", bid, exc)

    user_msg = _build_lead_insights_user_message(
        lead_phone=lead_phone,
        bid=bid,
        num_calls=len(calls),
        calls_block=calls_block,
        field_rows=field_rows,
    )

    provider, model_id = _detect_provider_model(db, bid)
    logger.info("[%s] Generating lead insights for %s via %s/%s (%d calls)", bid, lead_phone, provider, model_id, len(calls))

    raw = _llm_call(provider, model_id, _LEAD_INSIGHTS_SYSTEM, user_msg, temperature=0.1, max_tokens=4096)
    parsed = _extract_json(raw)
    if not parsed or not isinstance(parsed, dict):
        # `raw or ""` keeps the intended RuntimeError even if the provider
        # returned None instead of a string.
        raise RuntimeError(f"AI returned non-JSON response: {(raw or '')[:300]}")

    _normalize_lead_insights_data_capture(parsed, field_rows)

    db.save_lead_insights(bid, lead_phone, parsed)
    return parsed


# ── Rich Per-Call Summary ─────────────────────────────────────────────────────

# System prompt for the per-call rich summary (passed to _llm_call by
# generate_rich_summary); it tells the model to answer with JSON only.
_RICH_SUMMARY_SYSTEM = """You are an expert sales call analyst. Analyse the provided sales call transcript and produce a detailed structured summary.

Return ONLY a valid JSON object — no markdown, no prose, no triple backticks."""

# User-prompt template for the per-call rich summary, filled with str.format().
# Placeholders: {call_date}, {agent_name}, {duration}, {transcript}.
# Literal JSON braces are doubled ({{ }}) so str.format() leaves them intact.
_RICH_SUMMARY_USER_TMPL = """CALL DATE: {call_date}
AGENT: {agent_name}
DURATION: {duration}s

=== TRANSCRIPT ===
{transcript}
=== END TRANSCRIPT ===

Generate a JSON object with exactly these keys:

{{
  "call_summary": [
    {{
      "section_title": "<topic heading e.g. 'Introduction & Purpose', 'Product Discussion', 'Pricing Negotiation', 'Booking & Next Steps'>",
      "points": [
        "<bullet point 1>",
        "<bullet point 2>"
      ]
    }}
  ],
  "bant_insights": {{
    "budget": "<budget info from call, or 'Not mentioned'>",
    "authority": "<decision maker info, or 'Not mentioned'>",
    "needs": "<customer needs/interests>",
    "timeline": "<when customer wants to act, or 'Not mentioned'>"
  }},
  "competitors": "<competitor names mentioned, or 'None'>",
  "objection": "<main objection raised by customer, or 'None'>",
  "objection_categories": ["<e.g. Price Sensitivity>"],
  "objection_handling": "<how the salesperson responded to the objection>",
  "next_steps": [
    {{
      "action": "<what needs to happen next>",
      "due": "<date/time if mentioned, else 'Not specified'>"
    }}
  ]
}}

Rules:
- call_summary must have 3-6 sections covering the full conversation chronologically
- Each section must have 2-5 bullet points
- Bullets must be specific and factual — cite exact prices, names, packages if mentioned
- next_steps should only list things explicitly discussed in the call"""


def generate_rich_summary(
    transcript: str,
    speaker_segments: List[Dict],
    call_metadata: Optional[Dict] = None,
    provider: str = "ollama",
    model_id: str = _DEFAULT_OLLAMA_MODEL,
) -> Dict[str, Any]:
    """Build a structured summary of one call with the LLM.

    The transcript is rendered with ``_format_transcript`` and cut to its
    first 6000 characters before being placed in the prompt. Metadata keys
    ``call_date``, ``agent_name`` and ``duration_seconds`` default to
    ``"Unknown"`` when absent. Nothing is written to the database; the
    caller is responsible for saving the result.

    Returns:
        The parsed summary dict.

    Raises:
        RuntimeError: The model reply is not a non-empty JSON object.
    """
    details = call_metadata or {}
    rendered = _format_transcript(transcript, speaker_segments)
    prompt = _RICH_SUMMARY_USER_TMPL.format(
        call_date=details.get("call_date", "Unknown"),
        agent_name=details.get("agent_name", "Unknown"),
        duration=details.get("duration_seconds", "Unknown"),
        # Slicing is a no-op for transcripts at or under the cap.
        transcript=rendered[:6000],
    )

    logger.info("Generating rich summary via %s/%s", provider, model_id)
    reply = _llm_call(provider, model_id, _RICH_SUMMARY_SYSTEM, prompt, temperature=0.1, max_tokens=3000)
    summary = _extract_json(reply)
    if isinstance(summary, dict) and summary:
        return summary
    raise RuntimeError(f"AI returned non-JSON: {reply[:300]}")


def generate_and_save_rich_summary(db, bid: str, callid: str) -> Dict[str, Any]:
    """Create the rich summary for call *callid* and store it.

    Loads the call through ``_fetch_call_data``, resolves the provider and
    model for *bid*, generates the summary, then persists it with
    ``db.save_rich_summary``.

    Returns:
        The summary dict that was saved.

    Raises:
        ValueError: The call has no transcript (or does not exist).
    """
    call_data = _fetch_call_data(db, bid, callid)
    transcript_text, speaker_segments, call_meta = call_data
    if not transcript_text:
        raise ValueError(f"No transcript found for callid={callid}")

    provider_name, model_name = _detect_provider_model(db, bid)
    summary = generate_rich_summary(
        transcript_text,
        speaker_segments,
        call_meta,
        provider_name,
        model_name,
    )
    db.save_rich_summary(bid, callid, summary)
    return summary


def _fetch_call_data(db, bid: str, callid: str):
    """Load transcript, speaker segments and basic metadata for one call.

    Queries ``{bid}_raw_calls``, left-joining ``{bid}_sarvamresponse`` when
    ``db._table_exists`` reports that table; otherwise the transcript
    columns are selected as ``NULL``.

    Returns:
        ``(transcript, segments, metadata)``. When no row matches, the tuple
        is ``(None, [], {})``.
    """
    raw_table = f"{bid}_raw_calls"
    sarvam_table = f"{bid}_sarvamresponse"

    with db.get_connection() as conn:
        cursor = conn.cursor()

        # Without the transcript table, select NULL placeholders and skip the join.
        if db._table_exists(cursor, sarvam_table):
            transcript_expr = "s.transcript"
            speaker_expr = "s.speaker_segments"
            duration_expr = "s.duration"
            join_sarvam = f"LEFT JOIN `{sarvam_table}` s ON r.callid = s.callid"
        else:
            transcript_expr = speaker_expr = duration_expr = "NULL"
            join_sarvam = ""

        # Table names are interpolated (identifiers cannot be bound as
        # parameters); only callid is passed as a bound parameter.
        cursor.execute(
            f"""
            SELECT
                r.callid,
                r.agentname,
                r.call_starttime,
                TIMESTAMPDIFF(SECOND, r.call_starttime, r.call_endtime) AS duration_seconds,
                {transcript_expr} AS transcript,
                {speaker_expr} AS speaker_segments,
                {duration_expr} AS transcript_duration
            FROM `{raw_table}` r
            {join_sarvam}
            WHERE r.callid = %s
            LIMIT 1
            """,
            (callid,),
        )
        row = cursor.fetchone()

    if not row:
        return None, [], {}

    # NOTE(review): row is accessed with .get(), so the cursor presumably
    # returns dict-like rows; confirm against db.get_connection().
    call_transcript = row.get("transcript") or ""
    call_segments = db._parse_json_field(row.get("speaker_segments")) or []
    # Prefer the start/end difference; fall back to the transcript duration.
    seconds = row.get("duration_seconds") or row.get("transcript_duration") or 0
    call_meta = {
        "call_date": str(row.get("call_starttime") or "")[:10],
        "agent_name": row.get("agentname") or "Unknown",
        "duration_seconds": int(seconds),
    }
    return call_transcript, call_segments, call_meta
