"""tracer.py: a tiny, dependency-free OBSERVABILITY tool for agents.

Observability means being able to SEE what your agent did: every model call and every tool
call, with its inputs, outputs, how long it took, and how many tokens it used. Production
teams use tools like LangSmith, Langfuse, Arize Phoenix, or OpenTelemetry for this; here we
build the same idea in a few dozen lines so you understand what those tools record.

A Trace is a list of spans. A span is one step (one model call or one tool call).

No API key, no internet, no libraries beyond the standard library.
"""

import time


class Span:
    """One recorded step: what happened, with timing and (optional) token counts.

    Attributes:
        kind: "model" or "tool".
        name: the model or tool name, e.g. "claude-opus-4-8" or "multiply".
        inputs: whatever was passed into the step, stored as-is.
        output: the step's result; None until finish() is called.
        tokens: token count for the step; 0 unless finish() is given one.
        status: "ok" or "error".
        start: wall-clock start time in seconds since the epoch (time.time()).
        duration: elapsed seconds, set by finish(); 0.0 until then.
    """

    def __init__(self, kind, name, inputs):
        self.kind = kind            # "model" or "tool"
        self.name = name            # e.g. "claude-opus-4-8" or "multiply"
        self.inputs = inputs
        self.output = None
        self.tokens = 0
        self.status = "ok"          # "ok" or "error"
        # Wall-clock timestamp: useful for lining spans up with other logs.
        self.start = time.time()
        # Durations come from a monotonic, high-resolution clock instead, so a
        # system-clock adjustment mid-step cannot yield negative/bogus timings.
        self._t0 = time.perf_counter()
        self.duration = 0.0

    def finish(self, output, tokens=0, status="ok"):
        """Close the span: record its result, token usage, status and duration.

        Args:
            output: what the step produced (stored as-is).
            tokens: tokens used by the step; leave 0 for steps without tokens.
            status: "ok" or "error".
        """
        self.output = output
        self.tokens = tokens
        self.status = status
        self.duration = time.perf_counter() - self._t0

    def __repr__(self):
        return (f"{type(self).__name__}(kind={self.kind!r}, name={self.name!r}, "
                f"status={self.status!r}, duration={self.duration:.3f}s, "
                f"tokens={self.tokens})")


class Trace:
    """Collects spans for one agent run and can print/summarize them."""

    def __init__(self, label="agent-run"):
        self.label = label          # shown in the print_tree() header
        self.spans = []             # spans in the order record() opened them

    def record(self, kind, name, inputs):
        """Open a span and append it to this trace.

        Args:
            kind: "model" or "tool".
            name: the model or tool name.
            inputs: the step's inputs, stored as-is.

        Returns:
            The new span. Remember to call ``.finish(...)`` on it; otherwise
            its output and duration are never recorded.
        """
        span = Span(kind, name, inputs)
        self.spans.append(span)
        return span

    # --- read the trace ------------------------------------------------------
    def total_tokens(self):
        """Return the sum of token counts over all spans."""
        return sum(s.tokens for s in self.spans)

    def tool_calls(self):
        """Return the spans whose kind is "tool", in recorded order."""
        return [s for s in self.spans if s.kind == "tool"]

    def model_calls(self):
        """Return the spans whose kind is "model", in recorded order."""
        return [s for s in self.spans if s.kind == "model"]

    def errored(self):
        """Return the spans whose status is "error", in recorded order."""
        return [s for s in self.spans if s.status == "error"]

    def print_tree(self, file=None):
        """Print a human-readable trace, the thing you stare at when an agent misbehaves.

        Args:
            file: writable text stream to print to. The default, None, is
                treated by ``print`` as ``sys.stdout``. Pass an
                ``io.StringIO`` or an open file to capture the trace instead.
        """
        print(f"TRACE: {self.label}", file=file)
        for i, s in enumerate(self.spans, 1):
            line = f"  {i}. [{s.kind}] {s.name}  ({s.duration*1000:.0f} ms"
            if s.tokens:  # token counts are optional; omit them when zero
                line += f", {s.tokens} tok"
            line += f", {s.status})"
            print(line, file=file)
            print(f"       in:  {s.inputs}", file=file)
            print(f"       out: {s.output}", file=file)
        print(f"  totals: {len(self.model_calls())} model call(s), "
              f"{len(self.tool_calls())} tool call(s), {self.total_tokens()} tokens, "
              f"{len(self.errored())} error(s)", file=file)
