"""evals.py: an EVALUATION harness for the agent.

Observability (tracer.py) tells you WHAT the agent did. Evaluation tells you whether that was
RIGHT. You write test cases (an input plus what you expect) and SCORERS that check the agent's
answer and its trace. Run the suite, get a scorecard. Run it again after any change to catch
regressions: the same idea as unit tests, but for an agent.

Scorers here are rule-based (deterministic, free, no API key). There is also an optional
LLM-as-judge scorer for open-ended answers; it calls the model, so it costs tokens (pilot).

Run (rule-based suite is free, no key needed if you pass a mock; live agent needs .env key):
    python evals.py
"""

import os
from dotenv import load_dotenv
import agent

# Read a local .env file at import time; per the module docstring, the live agent needs its key there.
load_dotenv()


# ---- the test set ("golden" cases: input + what a correct run looks like) ----
CASES = [
    # Two-digit product: the answer must mention 391 and come from multiply(23, 17).
    {
        "id": "basic",
        "task": "What is 23 times 17? Use the multiply tool.",
        "expect_contains": "391",
        "expect_tool": "multiply",
        "expect_args": {"a": 23, "b": 17},
        "max_model_calls": 3,
    },
    # Single-digit product: the answer must mention 42 and come from multiply(6, 7).
    {
        "id": "small",
        "task": "What is 6 times 7? Use the multiply tool.",
        "expect_contains": "42",
        "expect_tool": "multiply",
        "expect_args": {"a": 6, "b": 7},
        "max_model_calls": 3,
    },
]


# ---- scorers: each takes (case, answer, trace) and returns (label, passed, detail) ----
def score_answer_contains(case, answer, trace):
    """Check that the expected text appears in the agent's answer.

    Args:
        case: test-case dict; its optional "expect_contains" value is the text to look for.
        answer: the agent's final answer; a falsy answer (e.g. None) counts as empty text.
        trace: unused here; accepted so every scorer shares one signature.

    Returns:
        ("answer_contains", passed, detail). A case that sets no "expect_contains"
        (for example an open-ended case) passes with detail "n/a" instead of raising KeyError.
    """
    want = case.get("expect_contains")
    if want is None:
        return ("answer_contains", True, "n/a")
    haystack = answer or ""
    try:
        found = want in haystack
    except TypeError:
        # A non-text answer (e.g. a bare number) or a non-text expectation cannot be
        # matched with `in`; compare their text forms rather than crash the scorer.
        found = str(want) in str(haystack)
    return ("answer_contains", found, f"want '{want}' in answer")


def score_called_tool(case, answer, trace):
    """Check that the expected tool shows up among the trace's tool calls.

    Args:
        case: test-case dict; its optional "expect_tool" value is the tool name to look for.
        answer: unused here; accepted so every scorer shares one signature.
        trace: object whose tool_calls() yields spans exposing a .name attribute.

    Returns:
        ("called_tool", passed, detail). A case that sets no "expect_tool" passes with
        detail "n/a" instead of raising KeyError; the trace is not consulted in that case.
    """
    want = case.get("expect_tool")
    if want is None:
        return ("called_tool", True, "n/a")
    names = [span.name for span in trace.tool_calls()]
    return ("called_tool", want in names, f"want tool '{want}', saw {names}")


def score_tool_args(case, answer, trace):
    """Check that the expected tool was called with exactly the expected inputs.

    Returns ("tool_args", passed, detail). When the case has no (or empty) "expect_args"
    the check is skipped and reported as passed with detail "n/a".
    """
    expected_args = case.get("expect_args")
    if expected_args:
        matched = False
        for call in trace.tool_calls():
            # A hit needs both the right tool name and an equal inputs mapping.
            if call.name == case["expect_tool"] and call.inputs == expected_args:
                matched = True
                break
        return ("tool_args", matched, f"want {case['expect_tool']}({expected_args})")
    return ("tool_args", True, "n/a")


def score_no_errors(case, answer, trace):
    """Pass only when the trace reports no errored spans.

    Returns ("no_errors", passed, detail) where detail states how many spans errored.
    """
    failed_spans = trace.errored()
    clean = not failed_spans
    detail = f"{len(failed_spans)} errored span(s)"
    return ("no_errors", clean, detail)


def score_within_budget(case, answer, trace):
    """Pass when the number of model calls in the trace stays within the case's cap.

    The cap is the case's "max_model_calls", or 99 when the case does not set one.
    Returns ("within_budget", passed, detail).
    """
    limit = case.get("max_model_calls", 99)
    used = len(trace.model_calls())
    detail = f"{used} model call(s), cap {limit}"
    return ("within_budget", used <= limit, detail)


# Default rule-based scorers, in the order their results appear in each case's checks.
SCORERS = [
    score_answer_contains,
    score_called_tool,
    score_tool_args,
    score_no_errors,
    score_within_budget,
]


# ---- optional LLM-as-judge (for open-ended answers a rule can't check) -------
def score_llm_judge(client, task, answer, rubric, model="claude-opus-4-8"):
    """Ask the model to grade an answer PASS/FAIL against a rubric. Costs tokens (pilot).

    Args:
        client: object exposing messages.create(...); called once per grading.
        task: the task text the agent was given.
        answer: the agent's answer to grade.
        rubric: the grading criteria shown to the judge.
        model: model identifier forwarded to messages.create; defaults to the
            value that was previously hard-coded.

    Returns:
        ("llm_judge", passed, detail). `passed` is True only when the upper-cased reply
        starts with "PASS"; `detail` is that reply. A reply with no content blocks, or
        whose first block carries no text, is graded FAIL instead of raising.
    """
    response = client.messages.create(
        model=model, max_tokens=10,
        system="You are a strict grader. Reply with exactly PASS or FAIL.",
        messages=[{"role": "user", "content": f"Task: {task}\nAnswer: {answer}\nRubric: {rubric}"}],
    )
    blocks = response.content or []
    # Only the first block is read, as before; guard the cases where it is missing or non-text.
    text = getattr(blocks[0], "text", None) if blocks else None
    if text is None:
        return ("llm_judge", False, "no text verdict returned")
    verdict = text.strip().upper()
    return ("llm_judge", verdict.startswith("PASS"), verdict)


# ---- the harness ------------------------------------------------------------
def run_suite(cases=None, run_fn=None, client=None, scorers=None):
    """Run every case through the agent, score it, print a scorecard. Returns a summary.

    Args:
        cases: iterable of case dicts, each with at least "id" and "task". Defaults to the
            module-level CASES, looked up at call time.
        run_fn: callable invoked as run_fn(task, client=client) and expected to return
            (answer, trace). Defaults to agent.run.
        client: forwarded untouched to run_fn.
        scorers: callables taking (case, answer, trace) and returning
            (label, passed, detail). Defaults to the module-level SCORERS, looked up at
            call time.

    Returns:
        dict with "passed", "total", "pass_rate" (0.0 when there are no cases) and
        "results", a list of {"id", "pass", "checks"} dicts, one per case.

    A case whose run or scoring raises is recorded as a FAIL carrying a single
    ("error", False, detail) check, so one broken case cannot abort the whole suite
    and hide the scorecard for the others.
    """
    # None sentinels avoid binding the module-level lists as shared defaults at def time.
    cases = CASES if cases is None else cases
    scorers = SCORERS if scorers is None else scorers
    run_fn = run_fn or agent.run
    results = []
    print("EVAL SCORECARD")
    for case in cases:
        answer = None
        try:
            answer, trace = run_fn(case["task"], client=client)
            checks = [scorer(case, answer, trace) for scorer in scorers]
        except Exception as exc:  # suite boundary: report the failure, keep going
            checks = [("error", False, f"{type(exc).__name__}: {exc}")]
        case_pass = all(ok for _, ok, _ in checks)
        results.append({"id": case["id"], "pass": case_pass, "checks": checks})
        mark = "PASS" if case_pass else "FAIL"
        print(f"  [{mark}] {case['id']}: {answer!r}")
        for label, ok, detail in checks:
            if not ok:
                print(f"        miss: {label} ({detail})")
    total = len(results)
    passed = sum(1 for result in results if result["pass"])
    rate = passed / total if total else 0.0
    print(f"  ----\n  {passed}/{total} cases passed  ({rate*100:.0f}%)")
    return {"passed": passed, "total": total, "pass_rate": rate, "results": results}


if __name__ == "__main__":
    # Exit non-zero when any case fails so a shell or CI job can gate on the result,
    # the way a unit-test runner does; a fully passing suite still exits 0.
    _summary = run_suite()
    raise SystemExit(0 if _summary["passed"] == _summary["total"] else 1)
