"""rag.py, M7: a Q&A app over YOUR document (Retrieval-Augmented Generation).

The model doesn't know your document. RAG fixes that in three steps:
  1. RETRIEVE, find the few chunks of your document most relevant to the question
     (a vector store, Chroma, does the semantic search for us).
  2. AUGMENT, paste those chunks into the prompt as context.
  3. GENERATE, ask Claude to answer using only that context.

Run (venv active, key in .env, chromadb installed, from this folder):
    python rag.py
"""

import os
from dotenv import load_dotenv
import anthropic
import chromadb

# Load variables from a local .env file (the module docstring expects the API key there).
load_dotenv()
# One shared API client, reused by answer() for every question.
# NOTE(review): presumably picks up the key from the environment loaded above; confirm.
client = anthropic.Anthropic()
# Model ID passed to client.messages.create() in answer().
MODEL = "claude-opus-4-8"


def chunk_document(text):
    """Split a document into chunks (here: one chunk per paragraph).

    A paragraph boundary is one or more blank lines. A line that holds only
    whitespace (spaces, tabs, a stray carriage return) also counts as blank,
    so paragraphs separated that way are not merged into one chunk.

    Args:
        text: The full document as a single string.

    Returns:
        A list of non-empty paragraph strings, in document order, each with
        its surrounding whitespace stripped. Empty or whitespace-only input
        gives an empty list.
    """
    import re  # function-scope import: the file's import block does not pull in re

    # A newline, any run of whitespace (which may include further newlines),
    # then another newline: i.e. at least one visually blank line.
    parts = (p.strip() for p in re.split(r"\n\s*\n", text))
    return [p for p in parts if p]           # drop blank pieces


def build_index(chunks):
    """Store the chunks in a Chroma collection and return that collection.

    Chroma computes an embedding (a list of numbers capturing the meaning) for
    each chunk on its own, so no extra API key is needed for this step.

    Args:
        chunks: The document chunks to index, as a list of strings.

    Returns:
        The collection named "document", holding every chunk under the id
        "chunk-<position>".
    """
    store = chromadb.Client()                # lives in memory for this session only
    index = store.get_or_create_collection("document")
    # Ids only need to be unique, so the chunk's position in the list is enough.
    chunk_ids = [f"chunk-{position}" for position in range(len(chunks))]
    index.add(documents=chunks, ids=chunk_ids)
    return index


def retrieve(collection, question, k=3):
    """Look up the k chunks that are semantically nearest to the question.

    Args:
        collection: The collection to search (as returned by build_index).
        question: The user's question text.
        k: How many chunks to ask for.

    Returns:
        The list of matching chunk texts for this question.
    """
    hits = collection.query(n_results=k, query_texts=[question])
    # The result holds one list of documents per query text; we sent exactly one.
    documents_per_query = hits["documents"]
    return documents_per_query[0]


def answer(question, context_chunks):
    """Ask Claude to answer using ONLY the retrieved chunks.

    The chunks are joined with blank lines and pasted into the prompt as
    context, followed by the question and an instruction to admit when the
    context does not contain the answer.

    Args:
        question: The user's question text.
        context_chunks: The retrieved chunk strings to use as context.

    Returns:
        The text of the model's reply. If the reply holds several text blocks
        they are concatenated; if it holds none, an empty string is returned.
    """
    context = "\n\n".join(context_chunks)
    prompt = (
        "Answer the question using ONLY the context below. "
        'If the answer is not in the context, say "I don\'t know based on the document."\n\n'
        f"Context:\n{context}\n\n"
        f"Question: {question}"
    )
    response = client.messages.create(
        model=MODEL, max_tokens=500,
        messages=[{"role": "user", "content": prompt}],
    )
    # The reply is a list of content blocks. Indexing [0].text would raise on
    # an empty list or a non-text first block, and would drop later text blocks.
    return "".join(block.text for block in response.content if block.type == "text")


if __name__ == "__main__":
    # Read the whole document up front; `with` closes the file handle right away
    # instead of leaving it open for the rest of the session.
    with open("sample_notes.txt", encoding="utf-8") as notes_file:
        text = notes_file.read()
    chunks = chunk_document(text)
    if not chunks:
        # Nothing to index or search. Stop with a clear message rather than
        # failing later (presumably Chroma rejects an empty add; not verified).
        raise SystemExit("sample_notes.txt has no text to index.")
    print(f"Split the document into {len(chunks)} chunks.")
    collection = build_index(chunks)
    print("Indexed! Ask questions about the document (type 'quit' to leave).\n")
    while True:
        try:
            question = input("Question: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: leave quietly instead of a traceback.
            print()
            break
        if question.strip().lower() in {"quit", "exit"}:
            break
        if not question.strip():
            continue                         # blank line: skip retrieval and the API call
        found = retrieve(collection, question, k=3)
        print(f"  [retrieved {len(found)} relevant chunks]")
        print("Answer:", answer(question, found), "\n")
