"""
Pradhya · The Agents Workshop · Unit 02
========================================

A production-shaped eval harness for the Capable Series research agent
(or any agent that takes a string input and returns a string + trace).

Reads test cases from test_cases.jsonl. For each case, runs the agent,
times it, grades it, and writes a row to results.jsonl. At the end,
prints a summary table.

Run:
    cd workshops/code-examples
    python eval_harness.py path/to/test_cases.jsonl

The four metrics tracked:
    - accuracy      (% of cases that pass grading)
    - cost (USD)    (input tokens * input price + output tokens * output
                     price, per model; approximate)
    - latency_s     (wall-clock per case, p50 + p90)
    - refusals      (cases marked expected_refusal where a refusal was detected)
"""

from __future__ import annotations

import argparse
import json
import os
import pathlib
import statistics
import sys
import time
from collections import Counter
from typing import Callable, Optional

# Make the agent under test importable: research_agent.py is looked up in the
# "agents" directory at the repository root, i.e. for
# <repo>/workshops/code-examples/eval_harness.py it resolves <repo>/agents.
# The path is derived from this file's location, not the working directory,
# so edit it here if your checkout is laid out differently.
sys.path.insert(
    0,
    str(pathlib.Path(__file__).resolve().parent.parent.parent.joinpath("agents")),
)


def grade(case: dict, output: str, trace: list[dict]) -> tuple[bool, str]:
    """Return (passed, reason) for one agent run. The reason is for the audit log.

    Exactly one grader runs, picked in this priority order:

    * ``expected_tool`` present   -- pass if any trace entry called that tool;
    * ``expected_substr`` present -- pass if the substring occurs in ``output``,
      ignoring case;
    * ``expected_refusal`` truthy -- pass if ``output`` contains one of a fixed
      set of refusal phrases.

    A case matching none of these fails with "no grader matched this case".

    Text is compared after Unicode case folding, with the typographic
    apostrophe (U+2019) mapped to a plain ``'`` -- model output frequently
    writes "can’t" where the refusal markers say "can't".

    Args:
        case: The test case dict.
        output: The agent's answer text; ``None`` is treated as "".
        trace: The agent's tool trace: a list of {"tool": name, "input": {...}};
            ``None`` is treated as an empty trace.
    """
    def _fold(text: str) -> str:
        # Caseless, apostrophe-insensitive form used for all text matching.
        return text.replace("\u2019", "'").casefold()

    if "expected_tool" in case:
        wanted = case["expected_tool"]
        ok = any(t.get("tool") == wanted for t in trace or [])
        return ok, f"expected tool {wanted} " + ("called" if ok else "not called")

    answer = _fold(output or "")

    if "expected_substr" in case:
        ok = _fold(case["expected_substr"]) in answer
        return ok, "substring " + ("present" if ok else "absent")

    if case.get("expected_refusal"):
        markers = ("cannot help", "i won't", "i can't", "won't help", "refuse")
        ok = any(m in answer for m in markers)
        return ok, "refusal detected" if ok else "no refusal detected (regression risk)"

    return False, "no grader matched this case"


def _approx_cost_usd(in_tokens: int, out_tokens: int, model: str) -> float:
    # As of May 2026; verify against https://platform.claude.com/docs/en/about-claude/pricing
    prices = {
        "claude-sonnet-4-6": (3.00, 15.00),  # $/M in, $/M out
        "claude-opus-4-7":   (5.00, 25.00),
        "claude-haiku-4-5":  (1.00,  5.00),
    }
    in_p, out_p = prices.get(model, (3.00, 15.00))
    return (in_tokens * in_p + out_tokens * out_p) / 1_000_000


def run_case(
    runner: Callable[[str], tuple[str, list[dict]]],
    case: dict,
    usage_fn: Optional[Callable[[], dict]] = None,
) -> dict:
    """Run the agent on one test case, grade it, and return a results row.

    Any exception raised while running the case is caught and recorded: the
    row is marked failed, with ``"<ExcType>: <message>"`` as both ``reason``
    and ``error``, so one bad case does not abort the suite.

    Args:
        runner: Called with ``case["input"]``; returns ``(answer_text, trace)``.
            A ``None`` trace is treated as empty.
        case: Test case dict; ``"id"`` and ``"input"`` are read here.
        usage_fn: Optional callable returning token usage for the run just
            completed, as a dict with ``input_tokens``, ``output_tokens`` and
            optionally ``model``. Without it, cost is reported as 0.

    Returns:
        A results row with keys ``id``, ``pass``, ``reason``, ``latency_s``
        (rounded to 0.01 s), ``cost_usd`` (rounded to 4 dp), ``tool_calls``
        (trace entries that name a tool), ``model`` and ``error`` (``None``
        on success).
    """
    # perf_counter is monotonic, so latency is immune to system clock changes.
    t0 = time.perf_counter()
    error: Optional[str] = None
    try:
        output, trace = runner(case["input"])
    except Exception as e:  # harness boundary: record the failure, keep going
        output, trace = "", []
        error = f"{type(e).__name__}: {e}"
    elapsed = time.perf_counter() - t0
    trace = trace or []

    passed, reason = grade(case, output, trace) if error is None else (False, error)

    # Real token usage for this run, read from the agent if it reports it
    # (research_agent leaves it in LAST_USAGE). Falls back to zeros.
    # NOTE(review): if the runner raised before the agent updated its usage,
    # this may still hold the previous case's numbers -- confirm how
    # research_agent resets LAST_USAGE.
    usage      = (usage_fn() if usage_fn else {}) or {}
    in_tokens  = int(usage.get("input_tokens", 0) or 0)
    out_tokens = int(usage.get("output_tokens", 0) or 0)
    model      = usage.get("model") or "claude-sonnet-4-6"

    return {
        "id":         case["id"],
        "pass":       passed,
        "reason":     reason,
        "latency_s":  round(elapsed, 2),
        "cost_usd":   round(_approx_cost_usd(in_tokens, out_tokens, model), 4),
        "tool_calls": sum(1 for t in trace if t.get("tool")),
        "model":      model,
        "error":      error,
    }


def summarise(rows: list[dict]) -> None:
    """Print a summary table for a list of results rows produced by run_case.

    Reports accuracy, total cost, p50/p90 latency, detected refusals and the
    model mix, then lists each failing case with its reason. p90 is computed
    only from 10 or more rows; with fewer, the slowest case is reported.
    """
    n = len(rows)
    if n == 0:
        print("No results.")
        return
    passed     = sum(1 for r in rows if r["pass"])
    failed     = n - passed
    latency    = [r["latency_s"] for r in rows]
    cost       = sum(r["cost_usd"] for r in rows)
    # Exact match on purpose: the failure reason "no refusal detected
    # (regression risk)" also contains the substring "refusal detected".
    refusals   = sum(1 for r in rows if r["reason"] == "refusal detected")
    model_mix  = Counter(r.get("model", "?") for r in rows)

    p50 = statistics.median(latency)
    # quantiles(n=10) returns 9 cut points; the last one is the 90th percentile.
    p90 = statistics.quantiles(latency, n=10)[-1] if n >= 10 else max(latency)

    print()
    print("=" * 60)
    print(f"  accuracy   : {passed}/{n}  ({passed/n*100:.1f}%)")
    print(f"  total cost : ${cost:.4f}")
    print(f"  latency    : p50 {p50:.2f}s  · p90 {p90:.2f}s")
    print(f"  refusals   : {refusals}")
    print(f"  models     : {dict(model_mix)}")
    print("=" * 60)

    if failed:
        print("\nFAILURES:")
        for r in rows:
            if not r["pass"]:
                # str(): ids loaded from JSON may be numbers, which reject ":s".
                print(f"  {str(r['id']):>20}  {r['reason']}")


def load_cases(path: pathlib.Path) -> list[dict]:
    """Load test cases from a JSON Lines file.

    Blank lines and lines starting with ``#`` are skipped; every other line
    must be one JSON value (normally an object with at least ``"id"`` and
    ``"input"``). A leading UTF-8 byte-order mark is tolerated. ``path`` may
    also be given as a ``str``.

    Raises:
        json.JSONDecodeError: if a line is not valid JSON; the message is
            prefixed with ``<path>:<lineno>:`` to locate the bad line.
    """
    path = pathlib.Path(path)
    # utf-8-sig drops a BOM (common from Windows editors) and is otherwise
    # identical to utf-8; json.loads rejects a BOM-prefixed first line.
    text = path.read_text(encoding="utf-8-sig")
    cases: list[dict] = []
    for lineno, raw in enumerate(text.splitlines(), start=1):
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        try:
            cases.append(json.loads(line))
        except json.JSONDecodeError as exc:
            raise json.JSONDecodeError(
                f"{path}:{lineno}: {exc.msg}", exc.doc, exc.pos
            ) from exc
    return cases


def main() -> None:
    """Command-line entry point: run every test case against research_agent.

    Positional ``cases`` is the JSONL test-case file (default
    ``test_cases.jsonl``); ``--out`` is the results file (default
    ``results.jsonl``, parent directories created as needed). Each case is
    run through ``run_case``, written as one JSON line (flushed immediately so
    completed rows survive an interrupted run), and a summary is printed at
    the end.

    Exits with a message (via ``sys.exit``) when ``ANTHROPIC_API_KEY`` is not
    set or the cases file does not exist.
    """
    p = argparse.ArgumentParser(description="Run an eval suite against the research agent")
    p.add_argument("cases", nargs="?", default="test_cases.jsonl")
    p.add_argument("--out", default="results.jsonl")
    args = p.parse_args()

    # This eval calls the live Claude API. Fail fast and clearly if no key.
    if not os.environ.get("ANTHROPIC_API_KEY"):
        sys.exit("ANTHROPIC_API_KEY is not set — this eval calls the live API. "
                 "Set it (console.anthropic.com) and re-run.")

    cases_path = pathlib.Path(args.cases)
    if not cases_path.is_file():
        sys.exit(f"Test cases file not found: {cases_path}")

    # Import the agent under test. The sys.path insert at the top of this file
    # makes /agents importable no matter where you run from.
    import research_agent  # type: ignore

    # research_agent.run_one(prompt) returns (answer_text, trace) and leaves
    # the run's token usage in research_agent.LAST_USAGE.
    def runner(prompt: str) -> tuple[str, list[dict]]:
        return research_agent.run_one(prompt)

    def usage_fn() -> dict:
        usage = dict(getattr(research_agent, "LAST_USAGE", {}) or {})
        usage.setdefault("model", getattr(research_agent, "MODEL", "claude-sonnet-4-6"))
        return usage

    cases = load_cases(cases_path)
    total = len(cases)
    out_path = pathlib.Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    rows: list[dict] = []
    with out_path.open("w", encoding="utf-8") as out:
        for i, case in enumerate(cases, 1):
            # str(): ids loaded from JSON may be numbers, which reject ":s".
            label = str(case["id"])
            print(f"[{i}/{total}] {label:>20} ...", end=" ", flush=True)
            row = run_case(runner, case, usage_fn)
            print("PASS" if row["pass"] else "FAIL")
            out.write(json.dumps(row) + "\n")
            # Each case is a live API call, so a per-row flush costs nothing
            # and keeps finished results on disk if the run is killed.
            out.flush()
            rows.append(row)

    summarise(rows)


if __name__ == "__main__":
    # Script entry point only; importing this module runs nothing. main()
    # returns None, so the process exits with status 0 on normal completion.
    raise SystemExit(main())
