"""Generate the 7 manuscript headline figures."""
import os, gzip, pickle, csv
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Root of the simulation results tree; every input pickle/CSV is read from here.
# The original absolute path stays the default, but it can be overridden with
# the KR_SIM_RESULTS environment variable so the script runs on other machines.
# (An empty value falls back to the default rather than the current directory.)
OUT = os.environ.get("KR_SIM_RESULTS") or "/home/claude/kr_sim/results"
# Output directory for the PNGs; created at import time if it does not exist.
FIG = os.path.join(OUT, "figures")
os.makedirs(FIG, exist_ok=True)

# Global matplotlib style shared by every figure in this module.
plt.rcParams.update({
    "font.family": "DejaVu Sans",
    "font.size": 10, "axes.grid": True, "grid.alpha": 0.3,
    "axes.linewidth": 0.8, "lines.linewidth": 1.4,
})

# Cell tag (as stored in the result files) -> label shown in the figures.
# Insertion order is the bar order used by fig_crosscell.
CELL_LABELS = {
    "DS1_Uniform_ChgDis": "DS1 (RW9)",
    "DS2_Uniform_DisRT":  "DS2 (RW3)",
    "DS5_Skewed_High_RT": "DS5 (RW20)",
    "DS7_Skewed_Low_RT":  "DS7 (RW13)",
}


def _load(fn):
    """Return the object pickled inside the gzip-compressed file *fn*.

    NOTE: unpickling can execute arbitrary code, so only point this at result
    files produced by this project's own simulation runs.
    """
    with gzip.open(fn, mode="rb") as fh:
        obj = pickle.load(fh)
    return obj


# ─────────────────────────────────────────────────────────────────────────────
# 1. Cross-cell: online K-R (grid) vs frozen-w2 K-R (cross-cell)
def fig_crosscell():
    """Compare online K-R against frozen-weight K-R on each cell, per scenario.

    Loads ``grid.pkl.gz`` (online runs; only rows with ``method == "KR"`` are
    used) and ``cross_cell.pkl.gz`` (frozen-w2 runs) from ``OUT``, averages
    ``ov_mV`` for every (cell, scenario) pair, and draws side-by-side bars in
    three panels (S1, S3, S5).

    Returns:
        Path of the PNG written under ``FIG``.
    """
    online_runs = _load(os.path.join(OUT, "grid.pkl.gz"))
    frozen_runs = _load(os.path.join(OUT, "cross_cell.pkl.gz"))
    fig, axes = plt.subplots(1, 3, figsize=(13, 4))
    panels = (
        ("S1", "S1 (25 °C, SOH=1.00)"),
        ("S3", "S3 (5 °C, SOH=1.00)"),
        ("S5", "S5 (25 °C, SOH=0.80)"),
    )
    cell_tags = list(CELL_LABELS.keys())
    bar_w = 0.4
    for ax, (scenario, heading) in zip(axes, panels):
        online_means = [
            np.mean([run["ov_mV"] for run in online_runs
                     if run["cell_tag"] == tag and run["scenario"] == scenario
                     and run["method"] == "KR"])
            for tag in cell_tags
        ]
        frozen_means = [
            np.mean([run["ov_mV"] for run in frozen_runs
                     if run["cell_tag"] == tag and run["scenario"] == scenario])
            for tag in cell_tags
        ]
        pos = np.arange(len(cell_tags))
        ax.bar(pos - bar_w / 2, online_means, bar_w, label="Online K-R (own cell)",
               color="#2f9e44", edgecolor="black", linewidth=0.5)
        ax.bar(pos + bar_w / 2, frozen_means, bar_w,
               label="Frozen w2 (trained on other cells)", color="#51cf66",
               hatch="///", edgecolor="black", linewidth=0.5)
        ax.set_xticks(pos)
        ax.set_xticklabels([CELL_LABELS[tag] for tag in cell_tags],
                           rotation=15, fontsize=9)
        ax.set_ylabel("Voltage overshoot (mV)")
        ax.set_title(heading)
        ax.legend(fontsize=8, loc="best")
    fig.suptitle("Cross-cell generalisation — frozen-weight K-R on unseen NASA cells",
                 fontweight="bold", y=1.02)
    fig.tight_layout()
    out_path = os.path.join(FIG, "fig_crosscell_comparison.png")
    fig.savefig(out_path, dpi=160, bbox_inches="tight")
    plt.close(fig)
    return out_path


# ─────────────────────────────────────────────────────────────────────────────
# 2. Ablation: K_only, KR_no_cross, KR_no_SOH, KR_linear, KR_full
def fig_ablation():
    """Plot mean ± std overshoot of each K-R ablation variant, one panel per scenario.

    Reads ``ablation.csv`` from ``OUT``. For every (``variant``, ``scenario``)
    pair the ``ov_mean`` column values are collected; the bar height is their
    mean and the error bar their population standard deviation (``np.std``).

    Returns:
        Path of the PNG written under ``FIG``.
    """
    # Context manager so the CSV handle is closed deterministically;
    # newline="" is the open() mode the csv module documents for readers.
    with open(os.path.join(OUT, "ablation.csv"), newline="") as fh:
        rows = list(csv.DictReader(fh))
    variants = ["K_only", "KR_no_cross", "KR_no_SOH", "KR_linear", "KR_full"]
    labels = ["K-only\n(no R)", "−cross\nterms", "−SOH\nfeat.", "Linear\nridge", "Full\nreservoir"]
    scens = ["S1", "S3", "S5"]
    # One bar colour per variant, in the same order as `variants`
    # (red for K_only through dark green for KR_full).
    colors = ["#e6553f", "#f08c00", "#fcc419", "#51cf66", "#2f9e44"]
    x = np.arange(len(variants))
    fig, axes = plt.subplots(1, 3, figsize=(13, 4))
    for ax, scen in zip(axes, scens):
        means, stds = [], []
        for v in variants:
            sub = [float(r["ov_mean"]) for r in rows
                   if r["variant"] == v and r["scenario"] == scen]
            means.append(np.mean(sub))
            stds.append(np.std(sub))
        ax.bar(x, means, yerr=stds, color=colors, edgecolor="black", linewidth=0.5, capsize=3)
        ax.set_xticks(x)
        ax.set_xticklabels(labels, fontsize=9)
        ax.set_ylabel("Overshoot (mV)")
        ax.set_title(f"Scenario {scen}")
    fig.suptitle("Ablation: which K-R design choices carry the benefit?",
                 fontweight="bold", y=1.02)
    fig.tight_layout()
    fn = os.path.join(FIG, "fig_ablation.png")
    fig.savefig(fn, dpi=160, bbox_inches="tight")
    plt.close(fig)
    return fn


# ─────────────────────────────────────────────────────────────────────────────
# 3. Mismatch sweep: K-only vs K-R vs mismatch level, 3 scenarios
def fig_mismatch():
    """Overshoot of K-only vs K-R at 10/25/50 % controller/plant model mismatch.

    Reads ``stress_mismatch.pkl.gz`` (rows with ``method``, ``mm``,
    ``scenario`` and ``ov_mV`` keys) and draws one panel per scenario S1/S3/S5.

    Returns:
        Path of the PNG written under ``FIG``.
    """
    mms = [0.10, 0.25, 0.50]
    return _stress_bar_figure(
        pkl_name="stress_mismatch.pkl.gz",
        level_key="mm",
        levels=mms,
        scenarios=["S1", "S3", "S5"],
        tick_labels=[f"{m*100:.0f}%" for m in mms],
        xlabel="Model mismatch",
        title_fmt="Scenario {scen}",
        suptitle="Mismatch stress test — overshoot vs controller/plant parameter gap",
        out_name="fig_mismatch_sweep.png",
    )


# ─────────────────────────────────────────────────────────────────────────────
# 4. Noise robustness
def fig_noise():
    """Overshoot of K-only vs K-R at 1×, 2× and 4× the nominal sensor noise.

    Reads ``stress_noise.pkl.gz`` (rows with ``method``, ``noise_scale``,
    ``scenario`` and ``ov_mV`` keys) and draws one panel per scenario S1/S3/S5.

    Returns:
        Path of the PNG written under ``FIG``.
    """
    scales = [1.0, 2.0, 4.0]
    return _stress_bar_figure(
        pkl_name="stress_noise.pkl.gz",
        level_key="noise_scale",
        levels=scales,
        scenarios=["S1", "S3", "S5"],
        tick_labels=["1× σ\n(2 mV/50 mA/0.5 °C)", "2× σ", "4× σ"],
        tick_fontsize=8,
        xlabel="Sensor noise scale",
        title_fmt="Scenario {scen}",
        suptitle="Noise robustness — K-R vs K-only across 1×/2×/4× sensor σ",
        out_name="fig_noise_robustness.png",
    )


# ─────────────────────────────────────────────────────────────────────────────
# 5. Drift test
def fig_drift():
    """Overshoot of K-only vs K-R under 0/15/30 % plant R_int drift during charge.

    Reads ``stress_drift.pkl.gz`` (rows with ``method``, ``drift_pct``,
    ``scenario`` and ``ov_mV`` keys; ``drift_pct`` is stored as a fraction and
    shown ×100) and draws one panel each for scenarios S1 and S5.

    Returns:
        Path of the PNG written under ``FIG``.
    """
    drifts = [0.0, 0.15, 0.30]
    return _stress_bar_figure(
        pkl_name="stress_drift.pkl.gz",
        level_key="drift_pct",
        levels=drifts,
        scenarios=["S1", "S5"],
        tick_labels=[f"{d*100:.0f}%" for d in drifts],
        xlabel="Plant R_int drift during charge",
        title_fmt="Scenario {scen}   (controller params frozen — no re-ID)",
        suptitle="Drift test — K-R handles aging-like plant parameter change without re-ID",
        out_name="fig_drift_test.png",
        figsize=(9.5, 4),
    )


# ─────────────────────────────────────────────────────────────────────────────
# Shared helpers for the stress-test figures (3–5)
def _stress_group_stats(rows, method, level_key, level, scenario, tol=1e-9):
    """Return ``(mean, std)`` of ``ov_mV`` over the rows of one method/level/scenario.

    The stress level is compared with an absolute tolerance rather than ``==``,
    so a level that was computed (e.g. ``0.1 + 0.2``) still matches its nominal
    value (``0.3``) instead of silently producing an empty group. An empty group
    returns ``(nan, nan)`` (matplotlib then draws no bar) without NumPy's
    empty-slice RuntimeWarnings.
    """
    vals = [r["ov_mV"] for r in rows
            if r["method"] == method and abs(r[level_key] - level) <= tol
            and r["scenario"] == scenario]
    if not vals:
        return np.nan, np.nan
    return np.mean(vals), np.std(vals)


def _stress_bar_figure(pkl_name, level_key, levels, scenarios, tick_labels,
                       xlabel, title_fmt, suptitle, out_name,
                       figsize=(13, 4), tick_fontsize=None):
    """Draw K-only vs K-R overshoot bars (mean ± population std) against a stress level.

    Args:
        pkl_name: result pickle under ``OUT`` (a list of row dicts).
        level_key: row key holding the stress level (e.g. ``"mm"``).
        levels: nominal levels; one bar pair per level, in x order.
        scenarios: scenario tags; one subplot per tag.
        tick_labels: x tick label for each level.
        xlabel: x-axis label shared by all panels.
        title_fmt: panel title; ``{scen}`` is replaced by the scenario tag.
        suptitle: figure title.
        out_name: PNG file name written under ``FIG``.
        figsize: matplotlib figure size.
        tick_fontsize: optional font size for the x tick labels.

    Returns:
        Path of the saved PNG.
    """
    rows = _load(os.path.join(OUT, pkl_name))
    # squeeze=False always yields a 2-D Axes array, so a single-scenario call
    # still iterates correctly (plt.subplots would otherwise return a bare Axes).
    fig, axes = plt.subplots(1, len(scenarios), figsize=figsize, squeeze=False)
    x = np.arange(len(levels))
    w = 0.35
    tick_kw = {} if tick_fontsize is None else {"fontsize": tick_fontsize}
    # (method key in the data, legend label, colour, bar offset)
    series = (("K-only", "K-only", "#e6553f", -w / 2),
              ("KR", "K-R", "#2f9e44", +w / 2))
    for ax, scen in zip(axes[0], scenarios):
        for method, label, color, offset in series:
            stats = [_stress_group_stats(rows, method, level_key, lv, scen) for lv in levels]
            means = [m for m, _ in stats]
            stds = [s for _, s in stats]
            ax.bar(x + offset, means, w, yerr=stds, color=color, label=label,
                   edgecolor="black", linewidth=0.5, capsize=3)
        ax.set_xticks(x)
        ax.set_xticklabels(tick_labels, **tick_kw)
        ax.set_xlabel(xlabel)
        ax.set_ylabel("Voltage overshoot (mV)")
        ax.set_title(title_fmt.format(scen=scen))
        ax.legend(fontsize=9)
    fig.suptitle(suptitle, fontweight="bold", y=1.02)
    fig.tight_layout()
    fn = os.path.join(FIG, out_name)
    fig.savefig(fn, dpi=160, bbox_inches="tight")
    plt.close(fig)
    return fn


# ─────────────────────────────────────────────────────────────────────────────
# 6. Residual magnitude vs SOH  (physics insight)
def fig_residual_vs_soh():
    """Residual magnitude and K-R overshoot reduction vs state of health.

    Uses the ``sweep == "SOH"`` rows of ``physics_sweep.csv``. Both panels
    draw one line per cell, with the x axis inverted so SOH decreases from
    left to right.

    Returns:
        Path of the PNG written under ``FIG``.
    """
    # NOTE(review): panel (b)'s title states a data-dependent conclusion
    # ("peaks at SOH ≈ 0.90"); re-check it whenever the sweep data changes.
    return _physics_sweep_figure(
        sweep="SOH",
        xlabel="State of health (SOH)",
        title_a="(a) Correction magnitude scales with aging",
        title_b="(b) K-R benefit peaks at SOH ≈ 0.90",
        suptitle="Residual magnitude and K-R benefit vs state of health",
        out_name="fig_residual_vs_soh.png",
        invert_x=True,
    )


# ─────────────────────────────────────────────────────────────────────────────
# 7. Residual magnitude vs temperature
def fig_residual_vs_T():
    """Residual magnitude and K-R overshoot reduction vs ambient temperature.

    Uses the ``sweep == "T"`` rows of ``physics_sweep.csv``; both panels draw
    one line per cell.

    Returns:
        Path of the PNG written under ``FIG``.
    """
    # NOTE(review): panel (b)'s title states a data-dependent conclusion
    # ("largest at cold temperatures"); re-check it when the data changes.
    return _physics_sweep_figure(
        sweep="T",
        xlabel="Ambient temperature (°C)",
        title_a="(a) Correction magnitude vs temperature",
        title_b="(b) K-R benefit is largest at cold temperatures",
        suptitle="Residual magnitude and K-R benefit vs ambient temperature",
        out_name="fig_residual_vs_temperature.png",
    )


# ─────────────────────────────────────────────────────────────────────────────
# Shared helpers for the physics-sweep figures (6–7)
def _read_physics_sweep(sweep):
    """Load the rows of one sweep from ``physics_sweep.csv``, indexed by point.

    Returns:
        ``(values, cells, table)``: the sorted distinct sweep values (as
        floats), the sorted distinct cell tags, and a dict mapping
        ``(cell, value)`` to the CSV row for that point.
    """
    # Context manager closes the file; newline="" is what csv readers expect.
    with open(os.path.join(OUT, "physics_sweep.csv"), newline="") as fh:
        rows = [r for r in csv.DictReader(fh) if r["sweep"] == sweep]
    table = {}
    for r in rows:
        # setdefault: duplicate points resolve to their first row in file order.
        table.setdefault((r["cell"], float(r["variable"])), r)
    values = sorted({v for _, v in table})
    cells = sorted({c for c, _ in table})
    return values, cells, table


def _overshoot_reduction_pct(row):
    """Percent overshoot reduction of K-R relative to K-only for one sweep row.

    Returns 0 when the K-only overshoot is not positive, where a relative
    reduction is undefined.
    """
    ok = float(row["OV_Kon_mV_mean"])
    okr = float(row["OV_KR_mV_mean"])
    return 100 * (1 - okr / ok) if ok > 0 else 0


def _physics_sweep_figure(sweep, xlabel, title_a, title_b, suptitle, out_name,
                          invert_x=False):
    """Two-panel sweep figure: |R| RMS per cell (a) and K-R OV reduction per cell (b).

    A (cell, value) point missing from the CSV is plotted as NaN, so matplotlib
    leaves a gap in that cell's line. Without this, the left panel would raise
    ``StopIteration`` and the right panel would draw a spurious 0 % point.

    Args:
        sweep: value of the CSV ``sweep`` column to plot (e.g. ``"SOH"``).
        xlabel: x-axis label for both panels.
        title_a: title of the left (|R| RMS) panel.
        title_b: title of the right (overshoot reduction) panel.
        suptitle: figure title.
        out_name: PNG file name written under ``FIG``.
        invert_x: invert the x axis of both panels.

    Returns:
        Path of the saved PNG.
    """
    values, cells, table = _read_physics_sweep(sweep)
    fig, (a1, a2) = plt.subplots(1, 2, figsize=(11, 4))
    for cell in cells:
        label = CELL_LABELS.get(cell, cell)
        points = [table.get((cell, v)) for v in values]
        rms = [np.nan if p is None else float(p["R_rms_mA_mean"]) for p in points]
        red = [np.nan if p is None else _overshoot_reduction_pct(p) for p in points]
        # Each Axes keeps its own colour cycle, so interleaving the two panels'
        # plot calls still gives every cell the same colour in both panels.
        a1.plot(values, rms, marker="o", linewidth=1.6, label=label)
        a2.plot(values, red, marker="s", linewidth=1.6, label=label)

    a1.axhline(500, color="gray", linestyle="--", linewidth=0.8, label="Safety clamp")
    a1.set_xlabel(xlabel)
    a1.set_ylabel("|R| RMS (mA)")
    a1.set_title(title_a)

    a2.axhline(0, color="k", linewidth=0.5)
    a2.set_xlabel(xlabel)
    a2.set_ylabel("K-R overshoot reduction (%)")
    a2.set_title(title_b)

    for ax in (a1, a2):
        if invert_x:
            ax.invert_xaxis()
        ax.legend(fontsize=8, loc="best")

    fig.suptitle(suptitle, fontweight="bold", y=1.02)
    fig.tight_layout()
    fn = os.path.join(FIG, out_name)
    fig.savefig(fn, dpi=160, bbox_inches="tight")
    plt.close(fig)
    return fn


if __name__ == "__main__":
    # Report label -> figure builder; dict order is the order figures are built.
    figure_builders = {
        "crosscell": fig_crosscell,
        "ablation": fig_ablation,
        "mismatch": fig_mismatch,
        "noise": fig_noise,
        "drift": fig_drift,
        "residual_vs_soh": fig_residual_vs_soh,
        "residual_vs_T": fig_residual_vs_T,
    }
    # Build each figure in turn and report its file name and on-disk size.
    for label, build in figure_builders.items():
        path = build()
        size_kb = os.path.getsize(path) / 1024
        print(f"  {label}: {os.path.basename(path)}  ({size_kb:.0f} KB)")
