"""rag_eval.py (M8): measure your RAG app, then improve it.

M7 built a RAG app. But is it any *good*? You can't improve what you don't measure.
This file adds an EVAL SET (questions with known right answers) and a scorecard with
two numbers:
  - RETRIEVAL hit rate: did we fetch the chunk that holds the answer?
  - ANSWER match rate: does the final answer contain the right fact?

Then it runs two configurations, a weak BASELINE and a TUNED one (smaller chunks,
retrieve more, then rerank), so you can compare scores and learn how to push them up.
(Tuning doesn't *always* help; that's exactly why you measure instead of guessing.)

Run (venv active, key in .env, chromadb installed, from this folder):
    python rag_eval.py
"""

import os
import re
import uuid

import anthropic
import chromadb
from dotenv import load_dotenv

# Copy settings from a local .env file (the module docstring asks for the API key
# to live there) into the environment BEFORE the client is created, so the client
# can pick the key up.
load_dotenv()
client = anthropic.Anthropic()  # one shared client, used by answer() for every call
MODEL = "claude-opus-4-8"  # model for every answer() call; NOTE(review): confirm this id is valid


# ---------- chunking (a lever you can tune) ----------------------------------
def chunk_paragraphs(text):
    """M7's chunking: one chunk per paragraph (can be coarse).

    Paragraphs are separated by one or more blank lines. A "blank" line may
    still contain spaces, tabs, or a stray carriage return. Splitting on a
    literal double newline would miss those and merge two paragraphs into one
    chunk, so the split tolerates whitespace between the two newlines.

    Returns the non-empty paragraphs, stripped, in document order
    (empty or all-whitespace text gives []).
    """
    return [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]


def chunk_small(text, size_words=20, overlap_words=6):
    """Smaller, overlapping chunks: finer-grained retrieval, less topic-mixing.

    Slides a window of `size_words` words across `text`. Each step advances
    ``size_words - overlap_words`` words, so consecutive chunks share
    `overlap_words` words. The loop stops as soon as a window reaches the last
    word, so no trailing chunk consists only of words the previous chunk
    already covered.

    Returns a list of chunk strings (words re-joined with single spaces).
    Empty or all-whitespace text gives [].

    Raises ValueError unless ``0 <= overlap_words < size_words``. A step of
    zero or less would never advance (infinite loop), and a negative overlap
    would silently skip words between chunks.
    """
    if not 0 <= overlap_words < size_words:
        raise ValueError(
            "need 0 <= overlap_words < size_words, got "
            f"overlap_words={overlap_words}, size_words={size_words}"
        )
    words = text.split()
    step = size_words - overlap_words
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + size_words]))
        if start + size_words >= len(words):
            break  # this window already includes the final word
    return chunks


# ---------- index / retrieve / rerank / answer -------------------------------
def build_index(chunks):
    """Store `chunks` in a brand-new Chroma collection and return that collection.

    Each chunk gets the id ``c<i>`` (its position in `chunks`). The collection
    name embeds a random UUID, so every call creates a fresh collection and
    different configs never collide. Names built from ``id(...)`` can repeat,
    because Python reuses ids once an object is freed.
    NOTE(review): recent chromadb versions appear to share collections between
    ``chromadb.Client()`` instances in one process, which makes unique names
    necessary. Confirm this for your installed version.

    Raises ValueError if `chunks` is empty; there is nothing to index.
    """
    if not chunks:
        raise ValueError("build_index() needs at least one chunk; is the document empty?")
    chroma = chromadb.Client()
    collection = chroma.create_collection(f"docs-{len(chunks)}-{uuid.uuid4().hex}")
    collection.add(documents=chunks, ids=[f"c{i}" for i in range(len(chunks))])
    return collection


def retrieve(collection, question, k):
    """Return the texts of the `k` chunks Chroma's query ranks closest to `question`.

    Chroma answers a batch of query texts with one result list per query.
    Only one query is sent here, so only the first list is returned, in the
    order Chroma produced it.
    """
    hits = collection.query(n_results=k, query_texts=[question])
    per_query_documents = hits["documents"]
    return per_query_documents[0]


def rerank(question, candidates, keep):
    """A simple second pass: re-score candidates by word overlap, keep the best `keep`.

    A "word" is a lower-cased run of letters, digits, or underscores, so
    punctuation never blocks a match. For example, "open?" at the end of the
    question matches "open" in a chunk. Candidates with equal overlap keep
    their incoming (retrieval) order, because `sorted` is stable even with
    ``reverse=True``.

    (Production apps use dedicated rerank models / cross-encoders, same idea: a broad,
    cheap first pass, then a sharper second pass.)
    """
    def words(s):
        # One tokenizer for both the question and the chunks, so they compare like-for-like.
        return set(re.findall(r"\w+", s.lower()))

    q_words = words(question)
    scored = sorted(candidates, key=lambda c: len(q_words & words(c)), reverse=True)
    return scored[:keep]


def answer(question, context_chunks):
    """Ask MODEL to answer `question` using only `context_chunks` as context.

    The chunks are joined with blank lines into one context section. The prompt
    tells the model to reply "I don't know based on the document." when the
    answer is not in the context.

    Returns the reply text: every text block in the response, concatenated
    ("" if the response contains no text block).
    """
    context = "\n\n".join(context_chunks)
    prompt = (
        "Answer using ONLY the context below. "
        'If it is not in the context, say "I don\'t know based on the document."\n\n'
        f"Context:\n{context}\n\nQuestion: {question}"
    )
    response = client.messages.create(
        model=MODEL, max_tokens=400, messages=[{"role": "user", "content": prompt}],
    )
    # Don't assume content[0] exists and is text: the response is a list of typed
    # content blocks, so collect only the text ones.
    return "".join(block.text for block in response.content if block.type == "text")


# ---------- the eval set: questions with known right answers -----------------
# source_must_contain: a phrase from the chunk that SHOULD be retrieved (None = not in doc)
# answer_must_contain: a phrase the final answer should include
EVAL_SET = [
    # NOTE: "7" is a loose answer check; any reply containing the digit 7 passes.
    dict(q="What time does the café open on weekdays?",
         source_must_contain="7:00 AM",
         answer_must_contain="7"),
    dict(q="What's the guest wifi password?",
         source_must_contain="freshbeans2024",
         answer_must_contain="freshbeans2024"),
    dict(q="Can I get my money back on a coffee I didn't like?",
         source_must_contain="refund",
         answer_must_contain="refund"),
    dict(q="How long is a break on a long shift?",
         source_must_contain="30-minute",
         answer_must_contain="30"),
    # Grounding check: the answer is NOT in the document. None excludes it from the
    # retrieval metric; the model should fall back to its "don't know" reply.
    dict(q="Who is the café's CEO?",
         source_must_contain=None,
         answer_must_contain="don't know"),
]


# Map typographic quotes to their ASCII forms before comparing phrases.
_QUOTE_FOLD = str.maketrans({"\u2018": "'", "\u2019": "'", "\u201c": '"', "\u201d": '"'})


def _contains(text, phrase):
    """Case-insensitive substring test that treats curly quotes as straight ones.

    Model replies can use typographic apostrophes ("don’t know"). A plain
    lower-cased substring check would score those as misses against the
    expected "don't know".
    """
    return phrase.lower().translate(_QUOTE_FOLD) in text.lower().translate(_QUOTE_FOLD)


def run_config(name, text, chunk_fn, k, rerank_keep):
    """Run the whole eval set under one configuration; print and return the scorecard.

    Steps: chunk `text` with `chunk_fn`, index the chunks, then for each EVAL_SET
    item retrieve `k` chunks, optionally rerank down to `rerank_keep` (a falsy
    value disables reranking), and ask the model for an answer.

    Returns (retrieval_hits, answer_hits). Retrieval is scored only for items
    whose answer is in the document (source_must_contain is not None); answers
    are scored for every item.
    """
    chunks = chunk_fn(text)
    collection = build_index(chunks)
    retrieval_hits = answer_hits = scorable_retrieval = 0

    for item in EVAL_SET:
        found = retrieve(collection, item["q"], k)
        if rerank_keep:
            found = rerank(item["q"], found, rerank_keep)

        # retrieval metric (only for questions whose answer is in the doc)
        source_phrase = item["source_must_contain"]
        if source_phrase is not None:
            scorable_retrieval += 1
            if any(_contains(c, source_phrase) for c in found):
                retrieval_hits += 1

        # answer metric
        reply = answer(item["q"], found)
        if _contains(reply, item["answer_must_contain"]):
            answer_hits += 1

    print(f"  {name:<10} chunks={len(chunks):<3} k={k} rerank={rerank_keep or 'off'}  "
          f"|  retrieval {retrieval_hits}/{scorable_retrieval}  "
          f"answers {answer_hits}/{len(EVAL_SET)}")
    return retrieval_hits, answer_hits


if __name__ == "__main__":
    # `with` closes the file handle as soon as the document has been read.
    with open("sample_notes.txt", encoding="utf-8") as notes:
        text = notes.read()
    print("Scorecard (higher is better):")
    run_config("baseline", text, chunk_paragraphs, k=1, rerank_keep=None)
    run_config("tuned",    text, chunk_small,      k=6, rerank_keep=3)
    print("\nThe scorecard is the judge. Change ONE lever, re-run, keep what scores higher")
    print("on YOUR document. Retrieving more (higher k) is usually the most reliable lift.")
