"""pricing.py: a tiny, dependency-free cost and latency model for Claude calls.

M20 showed you how to SEE cost (tokens per run). This module is how to ESTIMATE and REDUCE it. We
price calls with the course's model table and Anthropic's prompt-caching multipliers, so you can
compare optimizations offline, with no key and no spend.

Prices are per 1,000,000 tokens and are ILLUSTRATIVE (the course model table); always check current
pricing before you rely on a number. Cache multipliers approximate Anthropic prompt caching: a cache
WRITE costs about 1.25x the input price, a cache READ about 0.1x.
"""

# Course model table: dollars per 1,000,000 tokens, keyed by model name.
# "in" is the input-token rate and "out" is the output-token rate.
PRICES = {
    "claude-opus-4-8": {"in": 5.0, "out": 25.0},
    "claude-sonnet-4-6": {"in": 3.0, "out": 15.0},
    "claude-haiku-4-5": {"in": 1.0, "out": 5.0},
}

# Prompt-caching multipliers, both applied to a model's input price.
# Writing a prefix into the cache costs a bit more than normal input.
CACHE_WRITE_MULT = 1.25
# Reading a cached prefix is about a tenth of the input price.
CACHE_READ_MULT = 0.1

# Illustrative latency figures in seconds, keyed by model name.
# "base" is a fixed cost per call; "per_out" is added for every output token.
LATENCY = {
    "claude-opus-4-8": {"base": 3.0, "per_out": 0.006},
    "claude-sonnet-4-6": {"base": 1.8, "per_out": 0.004},
    "claude-haiku-4-5": {"base": 0.7, "per_out": 0.002},
}


def approx_tokens(text):
    """Estimate the token count of *text* at roughly 4 characters per token.

    The length is floor-divided by 4 and the result is never less than 1,
    so empty or very short text still counts as one token.
    """
    estimate = len(text) // 4
    return estimate if estimate > 1 else 1


def dollars(model, in_tokens=0, out_tokens=0, cache_write_tokens=0, cache_read_tokens=0):
    """Return the dollar cost of a single call to *model*.

    Args:
        model: Key into PRICES; an unknown name raises KeyError.
        in_tokens: Input tokens billed at the normal input rate.
        out_tokens: Output tokens billed at the output rate.
        cache_write_tokens: Prefix tokens billed at input rate * CACHE_WRITE_MULT.
        cache_read_tokens: Prefix tokens billed at input rate * CACHE_READ_MULT.

    Returns:
        The summed cost in dollars (PRICES rates are per 1,000,000 tokens).
    """
    rates = PRICES[model]
    in_rate = rates["in"]
    out_rate = rates["out"]

    # Each term is in "dollars per million tokens" units until the final divide.
    input_cost = in_tokens * in_rate
    output_cost = out_tokens * out_rate
    cache_write_cost = cache_write_tokens * in_rate * CACHE_WRITE_MULT
    cache_read_cost = cache_read_tokens * in_rate * CACHE_READ_MULT

    per_million = input_cost + output_cost + cache_write_cost + cache_read_cost
    return per_million / 1_000_000


def latency(model, out_tokens):
    """Return the estimated wall-clock seconds for one call to *model*.

    The estimate is the model's fixed "base" time plus "per_out" seconds for
    each output token, both read from LATENCY. Input size is not considered.
    An unknown model name raises KeyError.
    """
    profile = LATENCY[model]
    fixed_seconds = profile["base"]
    seconds_per_token = profile["per_out"]
    return fixed_seconds + out_tokens * seconds_per_token
