"""optimize.py: the levers that make an agent cheaper and faster, and a workload to measure them on.

Levers:
  route        : send EASY steps to a small fast model (Haiku), HARD steps to a strong one (Opus)
  prompt cache : a stable prefix (system + context) is written once per model, then read cheaply on later calls to that model
  trim         : cut tokens you do not need (shorter context, shorter prompts)

We estimate a 5-step support pipeline four ways (naive, cache-only, route-only, both) so you can see
what routing and caching are each worth; `trim` is provided as a helper but is not priced here.
Pure arithmetic over pricing.py: no API key, no spend.
"""

import pricing

# Shared prefix re-sent on every step (system instructions + retrieved context), in tokens.
PREFIX_TOKENS = 2_000

# The five pipeline steps, written as (name, difficulty, extra_in, out) rows:
#   extra_in : step-specific input tokens sent on top of the shared prefix
#   out      : output tokens the step generates
STEPS = [
    dict(name=name, difficulty=difficulty, extra_in=extra_in, out=out)
    for name, difficulty, extra_in, out in (
        ("classify intent",   "easy", 50,  10),
        ("detect language",   "easy", 30,  10),
        ("check sentiment",   "easy", 40,  10),
        ("decide resolution", "hard", 200, 300),
        ("write the reply",   "hard", 150, 250),
    )
]


def route(difficulty):
    """Pick the model for a step: Haiku when difficulty is "easy", Opus for anything else."""
    if difficulty != "easy":
        return "claude-opus-4-8"
    return "claude-haiku-4-5"


def trim(text, budget_tokens, chars_per_token=4):
    """Keep only the tail of `text` that fits within `budget_tokens`.

    Tokens are approximated as `chars_per_token` characters each (4 by default),
    so the result is at most budget_tokens * chars_per_token characters long.
    The END of the text is kept, since that is where the most recent content sits.

    A budget of zero or less returns an empty slice of the same type as `text`.
    """
    max_chars = budget_tokens * chars_per_token
    # Guard needed: text[-0:] is text[0:] (the WHOLE string), and a negative
    # max_chars would slice off the head instead of keeping a tail.
    if max_chars <= 0:
        return text[:0]
    return text[-max_chars:]


def estimate(steps=STEPS, prefix_tokens=PREFIX_TOKENS, routing=False, caching=False):
    """Price one run of the pipeline under a single optimization config.

    steps         : list of step dicts with "difficulty", "extra_in" and "out" keys
    prefix_tokens : size of the shared prefix sent with every step
    routing       : if True, pick each step's model with route(); otherwise always Opus
    caching       : if True, the first call on each model pays a cache write for the
                    prefix and later calls on that model pay a cache read; if False,
                    the prefix is billed as ordinary input on every call

    Returns {"cost": total dollars, "latency": total seconds}. Per-step latencies are
    summed, i.e. the steps are treated as running one after another.
    """
    cost_sum = 0.0
    latency_sum = 0.0
    primed = set()  # models whose cache already holds the prefix during this run

    for s in steps:
        model = "claude-opus-4-8"
        if routing:
            model = route(s["difficulty"])

        if not caching:
            # No cache: the full prefix is resent and billed as input each call.
            step_cost = pricing.dollars(model, in_tokens=prefix_tokens + s["extra_in"],
                                        out_tokens=s["out"])
        else:
            first_use = model not in primed
            primed.add(model)
            step_cost = pricing.dollars(
                model,
                in_tokens=s["extra_in"],
                out_tokens=s["out"],
                cache_write_tokens=prefix_tokens if first_use else 0,
                cache_read_tokens=0 if first_use else prefix_tokens,
            )

        cost_sum += step_cost
        latency_sum += pricing.latency(model, s["out"])

    return {"cost": cost_sum, "latency": latency_sum}


# The optimization configs to compare, keyed by the label printed in the report.
# Each value is the set of keyword arguments handed to estimate().
CONFIGS = {
    "naive (all Opus, no cache)":    {"routing": False, "caching": False},
    "cache only (all Opus, cached)": {"routing": False, "caching": True},
    "route only (Haiku for easy)":   {"routing": True,  "caching": False},
    "both (route + cache)":          {"routing": True,  "caching": True},
}


if __name__ == "__main__":
    # Price every config exactly once; the naive run is the baseline for "saves".
    results = {label: estimate(**cfg) for label, cfg in CONFIGS.items()}
    base_cost = results["naive (all Opus, no cache)"]["cost"]
    for label, e in results.items():
        # A zero-cost baseline (e.g. an empty pipeline) would otherwise divide by zero.
        save = (1 - e["cost"] / base_cost) * 100 if base_cost else 0.0
        print(f"{label:32s} ${e['cost']:.5f}/run  {e['latency']:.1f}s  saves {save:.0f}%")
