"""agent.py: plain RAG vs an AGENTIC RAG research agent.

plain_rag: the M7-M8 pattern. Retrieve ONCE on the question, then answer from those chunks. Simple,
           but it cannot answer a question whose answer depends on a fact you only know to look for
           AFTER reading the first result (a multi-hop question).

agentic_rag: retrieval is a TOOL the agent calls inside the M9 loop. The agent decides whether to
             search, reads the results, searches AGAIN with a better query if it needs more, and only
             then answers, citing the documents it used. It can also decide NOT to search at all.

Both are grounded in corpus.py. The client is injectable so it all runs offline.
"""

import os
from dotenv import load_dotenv
import anthropic
import corpus

# Load variables from a local .env file (if one exists) into the process environment before any
# client is constructed; presumably this is where ANTHROPIC_API_KEY comes from — confirm.
load_dotenv()
# Model id sent with every request made by both plain_rag and agentic_rag.
MODEL = "claude-opus-4-8"

# Tool definition handed to the model in agentic_rag: a single "search" tool that takes one
# required free-text "query" string.
SEARCH_TOOL = [
    {
        "name": "search",
        "description": "Search the company knowledge base. Returns the most relevant documents.",
        "input_schema": {
            "type": "object",
            "properties": {
                "query": {"type": "string"},
            },
            "required": ["query"],
        },
    },
]


def _format(results):
    return "\n".join(f"{doc_id}: {text}" for doc_id, text, _ in results) or "(no documents found)"


def plain_rag(question, client=None, k=2):
    """Retrieve once, then answer from those chunks. One shot, no second look.

    Args:
        question: the user's question; it is also used verbatim as the retrieval query.
        client: an Anthropic-compatible client (injectable so this can run offline);
            defaults to a new ``anthropic.Anthropic()``.
        k: number of chunks to request from ``corpus.search``.

    Returns:
        dict with ``answer`` (all text the model returned) and ``sources`` (the doc ids of the
        retrieved chunks, in retrieval order — retrieved, not necessarily used by the answer).
    """
    client = client or anthropic.Anthropic()
    # Materialise once: the hits are read twice (prompt context and the sources list).
    results = list(corpus.search(question, k=k))
    context = _format(results)
    resp = client.messages.create(
        model=MODEL, max_tokens=400,
        system="Answer the question using only the provided context. If it is not enough, say so.",
        messages=[{"role": "user", "content": f"Context:\n{context}\n\nQuestion: {question}"}])
    # Join every text block instead of reading content[0]: an empty content list would raise
    # IndexError, and a reply split across several text blocks would be truncated.
    answer = "".join(getattr(block, "text", "") for block in resp.content)
    return {"answer": answer, "sources": [r[0] for r in results]}


def agentic_rag(question, client=None, max_searches=4, max_turns=None):
    """Let the agent search the knowledge base as many times as it needs, then answer with citations.

    The model is called in a loop with the ``search`` tool available. Each ``tool_use`` block it
    emits is run against ``corpus.search`` and returned as a ``tool_result``. The loop ends when
    the model stops requesting tools or when ``max_turns`` model calls have been made.

    Args:
        question: the user's question.
        client: an Anthropic-compatible client (injectable so this can run offline);
            defaults to a new ``anthropic.Anthropic()``.
        max_searches: total searches the agent may run. Later search requests are answered with
            an error tool result saying the limit was reached, rather than with results.
        max_turns: hard cap on model calls, so a model that keeps requesting tools cannot loop
            forever. Defaults to ``max_searches + 2``: one call per allowed search, one to
            receive the limit notice, and one to answer.

    Returns:
        dict with:
          - ``answer``: the text of the last model response. If the turn cap is hit while the
            model is still requesting tools, this is whatever text accompanied that request
            (possibly empty).
          - ``searches``: every query the agent requested and was answered for, in order,
            including ones refused by the limit.
          - ``sources``: unique doc ids actually retrieved, in first-seen order (retrieved, not
            necessarily cited).
    """
    client = client or anthropic.Anthropic()
    if max_turns is None:
        max_turns = max_searches + 2
    system = ("You are a research assistant. Use the search tool to find facts before answering, "
              "and search again with a refined query if you need more. Cite the document ids you "
              "used. Do not search for small talk.")
    messages = [{"role": "user", "content": question}]
    searches, sources = [], []
    answer = ""

    for turn in range(max_turns):
        resp = client.messages.create(
            model=MODEL, max_tokens=500, tools=SEARCH_TOOL, system=system, messages=messages)
        messages.append({"role": "assistant", "content": resp.content})
        answer = "".join(b.text for b in resp.content if getattr(b, "type", None) == "text")

        # Stop once the model answers. Also stop on the last allowed turn: no further call
        # would read any tool results, so running more searches would be wasted work.
        if resp.stop_reason != "tool_use" or turn == max_turns - 1:
            break

        results = []
        for block in resp.content:
            if getattr(block, "type", None) != "tool_use":
                continue
            query = block.input.get("query", "")
            searches.append(query)
            if len(searches) <= max_searches:
                hits = list(corpus.search(query))
                for doc_id, _, _ in hits:
                    if doc_id not in sources:
                        sources.append(doc_id)
                results.append({"type": "tool_result", "tool_use_id": block.id,
                                "content": _format(hits)})
            else:
                # Say explicitly that the budget is spent. Reporting "no documents found" would
                # misinform the model and invite it to keep re-querying.
                results.append({"type": "tool_result", "tool_use_id": block.id, "is_error": True,
                                "content": ("Search limit reached; no more searches are allowed. "
                                            "Answer now using the documents you already have, "
                                            "and say so if they are not enough.")})
        messages.append({"role": "user", "content": results})

    return {"answer": answer, "searches": searches, "sources": sources}


if __name__ == "__main__":
    # Run one question through both pipelines so their outputs can be compared side by side.
    # The question likely needs two lookups: which team runs billing, then who leads that team.
    demo_question = "Who leads the team that runs the billing service?"
    for label, pipeline in (("PLAIN  :", plain_rag), ("AGENTIC:", agentic_rag)):
        print(label, pipeline(demo_question))
