"""corpus.py: a tiny synthetic knowledge base and a search tool over it.

This is the retrieval layer from M7-M8, shrunk to run offline with no embeddings and no key. The
`search` function is the TOOL the agent calls. Some questions need TWO searches to answer (you have
to read one document to learn what to search for next), which is exactly where a single-shot RAG
falls down and an agent that can search again shines.

Real systems use embeddings in a vector store (M7); this keyword search is a free, offline stand-in
with the same search(query) -> ranked documents shape.
"""

import re
import math

# A small knowledge base. Answering "who leads the team that runs billing?" needs two hops:
# first learn billing is run by the Payments team (D1), then learn who leads Payments (D3).
# Maps document id -> document text; search() scores every entry against the query.
DOCS = {
    # service -> operating team (the first hop)
    "D1": "The billing service is operated by the Payments team.",
    "D2": "The search service is operated by the Discovery team.",
    # team -> team lead (the second hop)
    "D3": "The Payments team is led by Dana Okafor.",
    "D4": "The Discovery team is led by Sam Rivera.",
    # unrelated to the services and teams above
    "D5": "Company office hours are 9am to 5pm, Monday through Friday.",
}

# Words ignored when tokenising queries and documents. Besides ordinary function words, this
# deliberately drops "team"/"service"/"leads" etc. so matches key on the DISTINCTIVE words
# (billing, payments, discovery, dana, ...), which is what makes the two-hop behaviour clear.
_STOP = {
    # articles, conjunctions and prepositions
    "a", "an", "the", "and", "or", "of", "to", "by", "in", "on", "for",
    # forms of "to be"
    "is", "are",
    # question and relative words
    "who", "what", "that", "which",
    # generic domain words that would otherwise match across unrelated documents
    "team", "service", "leads", "led", "runs", "run", "operated",
}


def _bag(text):
    """Return a {token: count} dict for the non-stopword tokens in text.

    The text is lower-cased and split into runs of ASCII letters and digits;
    tokens listed in _STOP are left out of the result.
    """
    tally = {}
    words = re.findall(r"[a-z0-9]+", text.lower())
    for word in words:
        if word not in _STOP:
            tally[word] = tally.get(word, 0) + 1
    return tally


def _cosine(a, b):
    """Return the cosine similarity of two sparse vectors given as {key: weight} dicts.

    Only keys present in both dicts contribute to the dot product. When either
    vector has zero length (for example, an empty dict) the result is 0.0.
    """
    common = set(a) & set(b)
    dot = sum(a[term] * b[term] for term in common)
    norm_a = math.sqrt(sum(weight * weight for weight in a.values()))
    norm_b = math.sqrt(sum(weight * weight for weight in b.values()))
    if norm_a and norm_b:
        return dot / (norm_a * norm_b)
    # A zero-length vector has no direction, so treat it as matching nothing.
    return 0.0


def search(query, k=2):
    """Return up to k (doc_id, text, score) tuples most relevant to the query (score > 0).

    Every document in DOCS is scored by the cosine similarity between its bag of
    non-stopword tokens and the query's. Documents that score 0 (no such token
    in common with the query) are dropped. Hits are ordered best-first; equal
    scores keep DOCS insertion order because the sort is stable.

    Args:
        query: Free-text search string.
        k: Maximum number of hits to return. Zero or a negative value yields an
            empty list; None returns every hit.

    Returns:
        A list of (doc_id, text, score) tuples, possibly empty.
    """
    q = _bag(query)
    scored = [(doc_id, text, _cosine(q, _bag(text))) for doc_id, text in DOCS.items()]
    hits = [t for t in scored if t[2] > 0]
    hits.sort(key=lambda t: t[2], reverse=True)
    # A negative k must not reach the slice: hits[:-1] means "all but the last
    # hit", not "no hits". None is passed through so "no limit" keeps working.
    limit = k if k is None else max(k, 0)
    return hits[:limit]
