FlakeForge / scripts /plot_inference_episode.py
random70249's picture
Upload folder using huggingface_hub
7b5cf62 verified
#!/usr/bin/env python3
"""Visualize a FlakeForge inference episode JSON (stdout from `python inference.py`).
Generates a multi-panel figure: step rewards, pass rate & patch success, heatmap of
`reward_breakdown` keys, and key traces. Requires matplotlib.
python scripts/plot_inference_episode.py \\
-i data/inference_example_episode.json \\
-o docs/assets/inference_episode_dashboard.png
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
def _load_episode(path: Path) -> Dict[str, Any]:
    """Parse the UTF-8 encoded JSON file at ``path`` and return its content.

    Args:
        path: Location of the episode JSON file.

    Returns:
        Whatever ``json`` decodes from the file (callers expect a mapping).
    """
    raw_text = path.read_text(encoding="utf-8")
    return json.loads(raw_text)
def _trajectory_matrix(
    episode: Dict[str, Any]
) -> Tuple[List[int], List[str], np.ndarray]:
    """Build the (component x step) matrix of ``reward_breakdown`` values.

    Args:
        episode: Parsed episode JSON. Reads ``trajectory`` (list of step rows)
            and, for key discovery only, ``reward_breakdown_history``.

    Returns:
        A tuple ``(steps, keys, Z)`` where ``steps`` holds one step number per
        trajectory row (the row's ``step`` field, or its 1-based position when
        the field is absent or null), ``keys`` is the sorted list of breakdown
        keys with ``"total"`` moved to the end when present, and ``Z`` is a
        float64 array of shape ``(len(keys), len(steps))``. Cells whose value
        is missing, null, or not convertible to float are NaN.

    Raises:
        ValueError: If the episode has no trajectory rows.
    """
    traj: List[Dict[str, Any]] = episode.get("trajectory") or []
    if not traj:
        raise ValueError("episode has empty trajectory")

    # Collect every breakdown key seen in the trajectory rows and in the
    # optional history list; null entries in either place are skipped.
    key_set: set[str] = set()
    for row in traj:
        key_set.update(row.get("reward_breakdown") or {})
    for row in episode.get("reward_breakdown_history") or []:
        key_set.update(row or {})

    keys = sorted(key_set)
    if "total" in key_set:
        # Keep the aggregate on the last row so it is visually separated.
        keys.remove("total")
        keys.append("total")
    row_of = {k: i for i, k in enumerate(keys)}

    # Start from NaN so that "missing" never has to be written explicitly.
    Z = np.full((len(keys), len(traj)), np.nan, dtype=np.float64)
    steps: List[int] = []
    for j, row in enumerate(traj):
        step = row.get("step")
        steps.append(j + 1 if step is None else int(step))
        for k, v in (row.get("reward_breakdown") or {}).items():
            try:
                Z[row_of[k], j] = float(v)
            except (TypeError, ValueError):
                # JSON null or a non-numeric value: leave the cell as NaN
                # instead of aborting the whole plot.
                continue
    return steps, keys, Z
def _as_float(value: Any, default: float) -> float:
    """Coerce ``value`` to float; return ``default`` for null or non-numeric input."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _draw_reward_panel(
    ax: Any, x: np.ndarray, step_labels: List[str], rewards: List[float], cum: np.ndarray
) -> None:
    """Draw per-step reward bars with the cumulative reward on a twin y-axis."""
    c_bar = "#3b82f6"
    c_cum = "#b45309"
    ax.bar(x, rewards, color=c_bar, alpha=0.85, edgecolor="white", linewidth=0.5, label="Step reward")
    ax.set_xticks(x)
    ax.set_xticklabels(step_labels, fontsize=9)
    ax.set_xlabel("Environment step", fontsize=10)
    ax.set_ylabel("Step reward", color=c_bar, fontsize=10)
    ax.tick_params(axis="y", labelcolor=c_bar)
    ax.axhline(0.0, color="#94a3b8", linewidth=0.8, linestyle="--")

    ax_cum = ax.twinx()
    ax_cum.plot(x, cum, color=c_cum, linewidth=2.2, marker="o", markersize=5, label="Cumulative")
    ax_cum.set_ylabel("Cumulative reward", color=c_cum, fontsize=10)
    ax_cum.tick_params(axis="y", labelcolor=c_cum)
    ax.set_title("Reward per step and cumulative", fontsize=11, loc="left", color="#1e293b")

    # Merge handles from both y-axes so a single legend covers bars and line.
    h1, l1 = ax.get_legend_handles_labels()
    h2, l2 = ax_cum.get_legend_handles_labels()
    ax.legend(h1 + h2, l1 + l2, loc="lower left", framealpha=0.95, fontsize=8)


def _draw_pass_rate_panel(
    ax: Any,
    x: np.ndarray,
    step_labels: List[str],
    pass_rates: List[float],
    applied: List[bool],
    conf: List[float],
) -> None:
    """Draw pass rate, predicted confidence, and a per-step patch-applied marker."""
    ax.fill_between(x, 0, pass_rates, color="#8b5cf6", alpha=0.15, step="mid")
    ax.plot(x, pass_rates, color="#7c3aed", linewidth=2, marker="s", markersize=5, label="pass_rate")
    # Short coloured tick near the bottom of the axes: green when the patch
    # was applied at that step, red otherwise.
    for pos, ok in zip(x, applied):
        ax.axvline(
            pos,
            ymin=0.02,
            ymax=0.12,
            color=("#10b981" if ok else "#ef4444"),
            linewidth=3,
            alpha=0.9,
        )
    ax.plot(x, conf, color="#0ea5e9", linewidth=1.5, linestyle="--", label="predicted_confidence", alpha=0.9)
    ax.set_ylim(-0.05, 1.12)
    ax.set_xticks(x)
    ax.set_xticklabels(step_labels, fontsize=9)
    ax.set_xlabel("Environment step", fontsize=10)
    ax.set_ylabel("Rate / confidence", fontsize=10)
    ax.set_title(
        "Pass rate, confidence, patch applied (green=applied, red=failed)",
        fontsize=11,
        loc="left",
        color="#1e293b",
    )
    ax.legend(loc="upper right", fontsize=8, framealpha=0.95)


def _draw_heatmap_panel(
    fig: Any, ax: Any, keys: List[str], Z: np.ndarray, step_labels: List[str]
) -> None:
    """Draw the reward-breakdown heatmap, or a placeholder when there is no data."""
    ax.set_title("Reward breakdown (rows = JSON keys, cols = step)", fontsize=11, loc="left", color="#1e293b")
    if Z.size == 0:
        # No breakdown keys at all: np.nanmax would raise on the zero-size
        # array, so show a note instead of an image.
        ax.text(0.5, 0.5, "no reward_breakdown data", ha="center", va="center",
                transform=ax.transAxes, fontsize=10, color="#64748b")
        ax.set_xticks([])
        ax.set_yticks([])
        return

    # Symmetric colour limits around zero. Only finite cells contribute, so an
    # all-NaN matrix falls back to the 1e-6 floor instead of NaN limits.
    finite = np.abs(Z[np.isfinite(Z)])
    vmax = max(float(finite.max()), 1e-6) if finite.size else 1e-6
    im = ax.imshow(
        Z,
        aspect="auto",
        cmap="RdYlBu_r",
        vmin=-vmax,
        vmax=vmax,
        interpolation="nearest",
    )
    ax.set_yticks(np.arange(len(keys)))
    ax.set_yticklabels(keys, fontsize=8)
    ax.set_xticks(np.arange(len(step_labels)))
    ax.set_xticklabels(step_labels, fontsize=8, rotation=0)
    ax.set_xlabel("Step", fontsize=10)
    cbar = fig.colorbar(im, ax=ax, fraction=0.034, pad=0.04)
    cbar.set_label("Component value", fontsize=9)


def _draw_traces_panel(
    ax: Any, x: np.ndarray, step_labels: List[str], traj: List[Dict[str, Any]]
) -> None:
    """Draw line traces for a fixed selection of reward-breakdown components."""
    palette = [
        ("stability", "#dc2626"),
        ("oracle_reasoning", "#059669"),
        ("compile", "#d97706"),
        ("regression", "#7c2d12"),
        ("format", "#4f46e5"),
    ]
    for name, color in palette:
        ys = [_as_float((t.get("reward_breakdown") or {}).get(name), np.nan) for t in traj]
        if np.all(np.isnan(ys)):
            # Component never reported in this episode; skip the empty trace.
            continue
        ax.plot(x, ys, "o-", color=color, linewidth=1.8, markersize=4, label=name, alpha=0.9)
    ax.axhline(0.0, color="#94a3b8", linewidth=0.8, linestyle="--")
    ax.set_xticks(x)
    ax.set_xticklabels(step_labels, fontsize=9)
    ax.set_xlabel("Environment step", fontsize=10)
    ax.set_ylabel("Component value", fontsize=10)
    ax.set_title("Traces: stability, oracle, compile, regression, format", fontsize=11, loc="left", color="#1e293b")
    # Only build a legend when at least one trace was drawn; calling legend()
    # with no labelled artists makes matplotlib emit a warning.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc="best", fontsize=8, ncol=2, framealpha=0.95)


def plot_episode(episode: Dict[str, Any], out_path: Path, dpi: int = 150) -> None:
    """Render a 2x2 dashboard for one inference episode and save it to ``out_path``.

    Panels: step reward with cumulative reward, pass rate with confidence and
    patch-applied markers, a heatmap of ``reward_breakdown`` components, and
    line traces for a fixed selection of components.

    Args:
        episode: Parsed episode JSON. Reads ``trajectory`` plus the optional
            ``total_reward``, ``final_pass_rate`` and ``done_reason`` fields.
            Null or non-numeric numeric fields fall back to their defaults.
        out_path: Destination image path; parent directories are created.
        dpi: Figure resolution.

    Raises:
        ValueError: If the episode has an empty trajectory (raised by
            ``_trajectory_matrix``).
    """
    import matplotlib

    # Non-interactive backend so the script works without a display.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except OSError:
        # Style name not available in this matplotlib version; try a fallback
        # and otherwise keep the default style.
        try:
            plt.style.use("ggplot")
        except OSError:
            pass

    traj: List[Dict[str, Any]] = episode.get("trajectory") or []
    steps, keys, Z = _trajectory_matrix(episode)
    step_labels = [str(s) for s in steps]
    x = np.arange(1, len(traj) + 1)

    rewards = [_as_float(t.get("reward"), 0.0) for t in traj]
    cum = np.cumsum(rewards)
    pass_rates = [_as_float(t.get("pass_rate"), 0.0) for t in traj]
    applied = [bool(t.get("patch_applied", False)) for t in traj]
    conf = [_as_float(t.get("predicted_confidence"), 0.0) for t in traj]

    # Episode-level summaries fall back to values derived from the trajectory.
    total_r = _as_float(episode.get("total_reward"), float(cum[-1]) if len(cum) else 0.0)
    done = str(episode.get("done_reason", ""))
    final_pr = _as_float(episode.get("final_pass_rate"), pass_rates[-1] if pass_rates else 0.0)

    fig, axes = plt.subplots(2, 2, figsize=(14, 9.5), dpi=dpi, constrained_layout=True)
    fig.patch.set_facecolor("#fafbfc")
    # \u00b7 is the middle dot separator, written as an escape so the source
    # stays ASCII and cannot be corrupted by a wrong file encoding again.
    supt = (
        f"Inference episode \u00b7 total_reward={total_r:.2f} \u00b7 final_pass_rate={final_pr:.2f} "
        f"\u00b7 done_reason={done}"
    )
    fig.suptitle(supt, fontsize=12, fontweight="600", color="#0f172a", y=1.01)

    _draw_reward_panel(axes[0, 0], x, step_labels, rewards, cum)
    _draw_pass_rate_panel(axes[0, 1], x, step_labels, pass_rates, applied, conf)
    _draw_heatmap_panel(fig, axes[1, 0], keys, Z, step_labels)
    _draw_traces_panel(axes[1, 1], x, step_labels, traj)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)
def main() -> None:
    """Command-line entry point: parse arguments, render the dashboard, print its path."""
    parser = argparse.ArgumentParser(
        description="Plot inference episode JSON to PNG dashboard."
    )
    parser.add_argument(
        "-i",
        "--input",
        type=Path,
        required=True,
        help="Episode JSON (e.g. saved from stdout)",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        default=Path("docs/assets/inference_episode_dashboard.png"),
    )
    parser.add_argument("--dpi", type=int, default=150)
    cli = parser.parse_args()

    destination = cli.output
    plot_episode(_load_episode(cli.input), destination, dpi=cli.dpi)
    print(f"Wrote {destination.resolve()}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()