"""memory.py: the two kinds of agent MEMORY, in plain Python (no key, no internet, no libraries).

An agent forgets everything between calls unless you give it memory. There are two kinds:

  ShortTermMemory: the running conversation, kept inside one session. It cannot grow forever
                   (the model has a context limit and tokens cost money), so it trims to a budget.

  LongTermMemory:  facts that survive ACROSS sessions, saved to disk and recalled by meaning.
                   Here we use a tiny built-in similarity so it runs offline; production uses a
                   real vector store with embeddings (exactly what you built in M7).

These are deliberately small so you can read them end to end.
"""

import re
import math
import json


def approx_tokens(text):
    """Estimate the token cost of `text`, assuming roughly 4 characters per token.

    The estimate never drops below 1, so even an empty string is charged one token.
    Coarse, but good enough for budgeting.
    """
    estimate = len(text) // 4
    return estimate if estimate >= 1 else 1


# ---------- short-term memory: the conversation, kept under a token budget ----
class ShortTermMemory:
    """The running conversation for one session, read back through a token budget.

    Every turn added is stored; `window()` decides which of them still fit.
    """

    def __init__(self, token_budget=120):
        self.turns = []                       # chronological list of {"role","content"}
        self.token_budget = token_budget      # cap on the summed approx_tokens() of a window

    def add(self, role, content):
        """Append one turn to the end of the conversation."""
        turn = {"role": role, "content": content}
        self.turns.append(turn)

    def window(self):
        """Return the newest run of turns whose combined cost fits the budget, oldest first.

        Walks backwards from the latest turn and stops at the first one that would
        overflow, so the result is always an unbroken tail of the conversation.
        """
        spent = 0
        start = len(self.turns)               # index where the kept tail begins
        while start > 0:
            price = approx_tokens(self.turns[start - 1]["content"])
            if spent + price > self.token_budget:
                break
            spent += price
            start -= 1
        return self.turns[start:]             # slicing copies, so callers get a fresh list

    def used_tokens(self):
        """Return the estimated token cost of the turns currently in the window."""
        total = 0
        for turn in self.window():
            total += approx_tokens(turn["content"])
        return total


# ---------- long-term memory: facts that persist across sessions --------------
# Stopwords: glue words (the, is) that are dropped before texts are compared, so that
# similarity keys on CONTENT words (name, hiking) instead.
_STOP = set(
    """
    a an the is are was were be been am do does did
    i you he she it we they me my your our their his her
    of to in on for and or but with about at by as
    what who when where why how this that these those
    have has had know tell remind please would could can will
    """.split()
)


def _bag(text):
    """Turn `text` into a bag of words: {word: occurrence count}.

    The text is lower-cased and split into runs of ASCII letters and digits;
    words listed in _STOP are discarded before counting.
    """
    words = [w for w in re.findall(r"[a-z0-9]+", text.lower()) if w not in _STOP]
    freq = dict.fromkeys(words, 0)            # keys appear in first-occurrence order
    for w in words:
        freq[w] += 1
    return freq


def _cosine(a, b):
    """Cosine similarity between two {word: count} bags.

    Returns 0.0 when either bag has zero length (for example, when it is empty).
    """
    norm_a = math.sqrt(sum(count * count for count in a.values()))
    norm_b = math.sqrt(sum(count * count for count in b.values()))
    if not (norm_a and norm_b):
        return 0.0
    # Only words present in both bags contribute to the dot product.
    dot = sum(a[word] * b[word] for word in set(a) & set(b))
    return dot / (norm_a * norm_b)


class LongTermMemory:
    """Facts that outlive a session: kept in a list, ranked by word overlap, stored as JSON."""

    def __init__(self):
        self.facts = []                       # list of strings

    def remember(self, fact):
        """Store `fact`, unless an identical fact is already stored."""
        if fact not in self.facts:
            self.facts.append(fact)

    def recall(self, query, k=3):
        """Return up to k facts most relevant to the query (by word-overlap similarity).

        Facts that share no content word with the query are never returned, and
        equally scored facts keep the order in which they were remembered.
        `k=None` returns every match; a zero or negative `k` returns nothing.
        """
        # Without this guard a negative k would slice from the wrong end:
        # hits[:-1] means "all but the last match", not "no results".
        if k is not None and k <= 0:
            return []
        q = _bag(query)
        scored = [(f, _cosine(q, _bag(f))) for f in self.facts]
        hits = [(f, s) for f, s in scored if s > 0]
        hits.sort(key=lambda fs: fs[1], reverse=True)   # stable: ties stay in insertion order
        return [f for f, _ in hits[:k]]

    # persistence: this is what makes memory survive a restart
    def save(self, path):
        """Write the facts to `path` as a JSON array, replacing any existing file."""
        # JSON is UTF-8 by specification; naming the encoding keeps the file
        # independent of the platform's locale default.
        with open(path, "w", encoding="utf-8") as fh:
            json.dump(self.facts, fh)

    def load(self, path):
        """Replace the stored facts with the JSON array in `path` and return self.

        The current facts are left untouched if loading fails.

        Raises:
            OSError: if `path` cannot be opened (e.g. it does not exist yet).
            ValueError: if the file is not valid UTF-8 JSON, or its top-level
                value is not a list.
        """
        with open(path, encoding="utf-8") as fh:
            data = json.load(fh)
        # Every other method assumes a list; fail here rather than later with an
        # obscure AttributeError from remember() or recall().
        if not isinstance(data, list):
            raise ValueError(
                f"{path!r} does not contain a JSON list of facts (got {type(data).__name__})"
            )
        self.facts = data
        return self
