"""
Generate publication-grade figures from the 960-trial grid.

Figures produced (matching §8 of the simulation parameters document):
  fig_{CELL}_{SCEN}.png         — 4-panel V/I/SOC/eK traces
  fig_{CELL}_{SCEN}_R.png       — R-correction + temperature trace
  fig_summary_bars.png          — summary bars (RMSE + OV)
  fig_cross_dataset.png         — cross-dataset overshoot-reduction
  fig_autocorr.png              — eK autocorrelation proof of structure
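
All per-controller traces and bars use a consistent colour code: CC-CV gray, MPC blue,
K-only coral, K-R green (thicker). The cross-dataset figure is coloured by scenario and
the autocorrelation figure by cell instead.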
"""
import os, gzip, pickle
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Root of the simulation outputs. grid.pkl.gz is read from here, and figures are
# written to its "figures" sub-directory, which is created when the module is imported.
# The default is the original location; set KR_SIM_RESULTS to run the script elsewhere.
OUT = os.environ.get("KR_SIM_RESULTS", "/home/claude/kr_sim/results")
FIG = os.path.join(OUT, "figures")
os.makedirs(FIG, exist_ok=True)

# House style shared by every figure: DejaVu Sans at 10 pt, a faint grid,
# thin axes spines and a default trace width of 1.4.
_HOUSE_STYLE = {
    "font.family": "DejaVu Sans",
    "font.size": 10,
    "axes.grid": True,
    "grid.alpha": 0.3,
    "axes.linewidth": 0.8,
    "lines.linewidth": 1.4,
}
plt.rcParams.update(_HOUSE_STYLE)

# Per-controller plotting style: (colour, line width, legend label).
# K-R gets the thicker line so it stands out in every overlay.
_METHOD_STYLE = {
    "CC-CV":  ("#7a7a7a", 1.4, "CC-CV"),
    "MPC":    ("#3b5bdb", 1.4, "MPC"),
    "K-only": ("#e6553f", 1.4, "K-only"),
    "KR":     ("#2f9e44", 2.0, "K-R ★"),
}
COLOUR = {method: spec[0] for method, spec in _METHOD_STYLE.items()}
WIDTH = {method: spec[1] for method, spec in _METHOD_STYLE.items()}
LABEL = {method: spec[2] for method, spec in _METHOD_STYLE.items()}

# Full figure titles, keyed by the trial's "scenario" value.
SCEN_TITLE = {
    "S1": "S1: Standard (25 °C, SOH=1.00, 15% model mismatch)",
    "S3": "S3: Cold start (5 °C, SOH=1.00, 25% model mismatch)",
    "S5": "S5: Aged (25 °C, SOH=0.80, 35% model mismatch)",
}


def load():
    """Return the object unpickled from the gzip-compressed ``grid.pkl.gz`` under OUT."""
    grid_path = os.path.join(OUT, "grid.pkl.gz")
    # pickle can execute arbitrary code on load: only point this at a grid
    # file produced by the simulation pipeline itself.
    with gzip.open(grid_path, "rb") as handle:
        grid = pickle.load(handle)
    return grid


def group(rows, keys):
    """Bucket *rows* by the tuple of their values at *keys*.

    Returns a plain dict mapping each key tuple (e.g. ``("KR",)`` for
    ``keys=["method"]``) to the list of rows sharing it. Buckets appear in
    first-seen order, and rows keep their input order within each bucket.
    """
    buckets = {}
    for row in rows:
        bucket_key = tuple([row[k] for k in keys])
        buckets.setdefault(bucket_key, []).append(row)
    return buckets


def median_trial(trials):
    """Return the trial whose ``ov_mV`` lies closest to the median overshoot.

    This picks a real trial as a robust representative instead of averaging traces.
    Ties go to the earliest trial in *trials*. An empty sequence makes NumPy's
    argmin raise ValueError.
    """
    overshoots = np.asarray([trial["ov_mV"] for trial in trials])
    target = np.median(overshoots)
    distance = np.abs(overshoots - target)
    return trials[int(distance.argmin())]


# ─────────────────────────────────────────────────────────────────────────────
def fig_four_panel(rows, cell_tag, cell_name, scen):
    """Draw the 2×2 figure of V, I, SOC (×100) and eK (×1000) against time.

    For every controller present in *rows*, one representative trial (chosen by
    ``median_trial``) is overlaid on all four panels. Dotted reference lines mark
    4.20 V and 80 % SOC. Controllers with no trials are skipped.

    Returns the path of the written PNG.
    """
    by_method = group(rows, ["method"])
    fig, axes = plt.subplots(2, 2, figsize=(11.5, 7.5))
    ax_v, ax_i, ax_soc, ax_ek = axes.flat
    fig.suptitle(f"{cell_tag} — {cell_name}   {SCEN_TITLE[scen]}",
                 fontsize=11, fontweight="bold", y=0.995)

    # (axis, trial key, multiplier applied before plotting; None = plot as stored)
    series = (
        (ax_v, "V", None),
        (ax_i, "I", None),
        (ax_soc, "SOC", 100),
        (ax_ek, "eK", 1000),
    )
    for method in ("CC-CV", "MPC", "K-only", "KR"):
        candidates = by_method.get((method,), [])
        if not candidates:
            continue
        rep = median_trial(candidates)
        line_style = dict(color=COLOUR[method], linewidth=WIDTH[method], label=LABEL[method])
        for ax, key, scale in series:
            y = rep[key] if scale is None else rep[key] * scale
            ax.plot(rep["t"], y, **line_style)

    ax_v.axhline(4.20, color="k", linestyle=":", linewidth=0.8, label="V_max")
    ax_soc.axhline(80, color="k", linestyle=":", linewidth=0.8, label="SOC=80%")
    ax_ek.axhline(0.0, color="k", linewidth=0.5)

    decorations = (
        (ax_v, "Terminal voltage (V)", "(a) Voltage"),
        (ax_i, "Charge current (A)", "(b) Current"),
        (ax_soc, "State of charge (%)", "(c) SOC trajectory"),
        (ax_ek, "K-step residual eK (mV)", "(d) K-step residual"),
    )
    for ax, ylabel, title in decorations:
        ax.set_ylabel(ylabel)
        ax.set_xlabel("Time (s)")
        ax.set_title(title)
        ax.legend(loc="best", fontsize=8, framealpha=0.9)

    fig.tight_layout()
    out_path = os.path.join(FIG, f"fig_{cell_tag}_{scen}.png")
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return out_path


def fig_r_correction(rows, cell_tag, cell_name, scen):
    """Two side-by-side panels for one (cell, scenario): K-R correction and temperature.

    Panel (a) plots ``R_corr`` × 1000 of the representative K-R trial, shaded
    down to zero. Panel (b) overlays the representative ``T`` trace of CC-CV,
    K-only and K-R; MPC is not drawn in this figure.

    Returns the written PNG path, or None when *rows* contain no K-R trial.
    """
    by_method = group(rows, ["method"])
    kr_candidates = by_method.get(("KR",), [])
    if not kr_candidates:
        return None
    kr_rep = median_trial(kr_candidates)

    fig, (ax_corr, ax_temp) = plt.subplots(1, 2, figsize=(11.5, 3.6))
    fig.suptitle(f"{cell_tag} — {cell_name}   {SCEN_TITLE[scen]}",
                 fontsize=10, fontweight="bold", y=1.0)

    # (a) correction signal, filled against the zero line
    t = kr_rep["t"]
    corr = kr_rep["R_corr"] * 1000
    kr_colour = COLOUR["KR"]
    ax_corr.plot(t, corr, color=kr_colour, linewidth=1.8, label="R correction")
    ax_corr.fill_between(t, 0, corr, color=kr_colour, alpha=0.25)
    ax_corr.axhline(0, color="k", linewidth=0.5)
    ax_corr.set_xlabel("Time (s)")
    ax_corr.set_ylabel("R correction (mA)")
    ax_corr.set_title("(a) R-step correction signal")
    ax_corr.legend(loc="best", fontsize=9)

    # (b) temperature of the representative trial of each listed controller
    for method in ("CC-CV", "K-only", "KR"):
        candidates = by_method.get((method,), [])
        if candidates:
            rep = median_trial(candidates)
            ax_temp.plot(rep["t"], rep["T"], color=COLOUR[method],
                         linewidth=WIDTH[method], label=LABEL[method])
    ax_temp.set_xlabel("Time (s)")
    ax_temp.set_ylabel("Cell temperature (°C)")
    ax_temp.set_title("(b) Temperature rise (thermal safety)")
    ax_temp.legend(loc="best", fontsize=9)

    fig.tight_layout()
    out_path = os.path.join(FIG, f"fig_{cell_tag}_{scen}_R.png")
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return out_path


def _mean_std(values):
    """Return ``(mean, std)`` of *values* as floats, or ``(nan, nan)`` when empty.

    ``std`` is NumPy's default population standard deviation (ddof=0). Returning
    NaN explicitly for an empty group gives the same bar height as
    ``np.array([]).mean()`` without NumPy's empty-slice RuntimeWarning.
    """
    arr = np.asarray(values, dtype=np.float64)
    if arr.size == 0:
        return float("nan"), float("nan")
    return float(arr.mean()), float(arr.std())


def fig_summary_bars(all_rows):
    """Grouped bars of overshoot and tracking RMSE per scenario, one bar per controller.

    Each bar is the mean over all trials of a (scenario, method) pair, pooled
    across cells, with a ±1 std (ddof=0) error bar. Panel (a) uses ``ov_mV``;
    panel (b) uses ``rmse_V_mV``. A (scenario, method) pair with no trials
    gets a NaN bar.

    Returns the path of the written PNG.
    """
    gb = group(all_rows, ["scenario", "method"])
    scenarios = ["S1", "S3", "S5"]
    methods = ["CC-CV", "MPC", "K-only", "KR"]

    fig, (a1, a2) = plt.subplots(1, 2, figsize=(12, 4.2))
    x = np.arange(len(scenarios))
    w = 0.2
    # Offset so each group of len(methods) bars is centred on its tick
    # (the original hard-coded 1.5*w, which is only correct for 4 methods).
    centre = (len(methods) - 1) / 2
    for i, m in enumerate(methods):
        ov_stats, rmse_stats = [], []
        for s in scenarios:
            trs = gb.get((s, m), [])
            ov_stats.append(_mean_std([r["ov_mV"] for r in trs]))
            rmse_stats.append(_mean_std([r["rmse_V_mV"] for r in trs]))
        means_ov, stds_ov = (list(v) for v in zip(*ov_stats))
        means_r, stds_r = (list(v) for v in zip(*rmse_stats))

        xpos = x + (i - centre) * w
        bar_style = dict(color=COLOUR[m], label=LABEL[m], capsize=3,
                         edgecolor="black", linewidth=0.5)
        a1.bar(xpos, means_ov, w, yerr=stds_ov, **bar_style)
        a2.bar(xpos, means_r, w, yerr=stds_r, **bar_style)

    panels = (
        (a1, "Voltage overshoot (mV)", "(a) Overshoot beyond V_max (lower = better)"),
        (a2, "Voltage tracking RMSE (mV)", "(b) Tracking RMSE vs V_max setpoint"),
    )
    for ax, ylabel, title in panels:
        ax.set_xticks(x)
        ax.set_xticklabels(scenarios)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend(loc="best", fontsize=9)

    fig.tight_layout()
    fn = os.path.join(FIG, "fig_summary_bars.png")
    fig.savefig(fn, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fn


def _overshoot_reduction_pct(ko_ovs, kr_ovs):
    """Percent by which mean K-R overshoot undercuts mean K-only overshoot.

    Returns ``100 * (1 - mean(kr_ovs) / mean(ko_ovs))``. Negative values mean
    K-R overshot more than K-only. The function returns 0.0 when the K-only
    list is empty or its mean is not positive (including NaN), and NaN when
    only the K-R list is empty.

    NOTE(review): a zero K-only baseline reports 0 % even if K-R does overshoot;
    consider NaN for that case instead.
    """
    if len(ko_ovs) == 0:
        return 0.0
    ko = float(np.mean(ko_ovs))
    if not ko > 0:  # also false for NaN
        return 0.0
    if len(kr_ovs) == 0:
        return float("nan")
    kr = float(np.mean(kr_ovs))
    return 100.0 * (1 - kr / ko)


def fig_cross_dataset(all_rows):
    """Bar chart of K-R vs K-only overshoot reduction for every (cell, scenario).

    Cells are placed along x in sorted ``cell_tag`` order, with one bar per
    scenario in each cell group. Bars use the default colour cycle, so colour
    encodes scenario here.

    Returns the path of the written PNG.
    """
    gb = group(all_rows, ["cell_tag", "scenario", "method"])
    cells = sorted({r["cell_tag"] for r in all_rows})
    scenarios = ["S1", "S3", "S5"]

    fig, ax = plt.subplots(1, 1, figsize=(10, 4.2))
    x = np.arange(len(cells))
    w = 0.25
    centre = (len(scenarios) - 1) / 2  # centre each group of bars on its tick
    for i, s in enumerate(scenarios):
        reds = [
            _overshoot_reduction_pct(
                [r["ov_mV"] for r in gb.get((cell, s, "K-only"), [])],
                [r["ov_mV"] for r in gb.get((cell, s, "KR"), [])],
            )
            for cell in cells
        ]
        ax.bar(x + (i - centre) * w, reds, w, label=s, edgecolor="black", linewidth=0.5)

    ax.set_xticks(x)
    ax.set_xticklabels([c.replace("_", " ") for c in cells], rotation=15)
    ax.set_ylabel("Overshoot reduction K-R vs K-only (%)")
    # Report the number of cells actually present instead of a hard-coded count.
    ax.set_title("Cross-dataset generalisation: K-R overshoot reduction across "
                 f"{len(cells)} NASA RW usage profiles")
    ax.legend(title="Scenario", loc="best", fontsize=9)
    fig.tight_layout()
    fn = os.path.join(FIG, "fig_cross_dataset.png")
    fig.savefig(fn, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fn


def _short_scenario_title(full_title):
    """Condense a scenario title to ``'<tag>: <name> (<first parameter>)'``.

    For example, ``'S1: Standard (25 °C, SOH=1.00, 15% model mismatch)'`` becomes
    ``'S1: Standard (25 °C)'``, with the parenthesis properly closed. Titles
    without a ``':'`` are returned unchanged.
    """
    tag, sep, rest = full_title.partition(":")
    if not sep:
        return full_title
    name, paren, params = rest.partition("(")
    short = name.strip()
    if paren:
        first = params.split(",")[0].rstrip(")").strip()
        short = f"{short} ({first})"
    return f"{tag.strip()}: {short}"


def _acf(x, max_lag):
    """Sample autocorrelation of *x* at lags ``1..max_lag``.

    *x* is mean-centred, and each lagged-product sum is divided by the lag-0 sum
    of squares. Returns None when *x* is empty or constant, because the ACF is
    undefined in those cases.
    """
    x = np.asarray(x, dtype=np.float64)
    if x.size == 0:
        return None
    x = x - x.mean()
    denom = np.sum(x * x)
    if denom == 0:
        return None
    return np.array([np.sum(x[:-lag] * x[lag:]) / denom
                     for lag in range(1, max_lag + 1)])


def fig_autocorr(all_rows, max_lag=60, band=None):
    """Plot the autocorrelation of the K-only eK residual, one panel per scenario.

    Each cell contributes the representative (median-overshoot) K-only trial.
    ACF values well outside the dashed white-noise band indicate the residual
    is not white noise.

    Parameters
    ----------
    all_rows : list of dict
        Trial records; uses "cell_tag", "scenario", "method", "ov_mV" and "eK".
    max_lag : int
        Largest lag plotted (default 60). Series shorter than ``max_lag + 5``
        after filtering are skipped.
    band : float or None
        Half-width of the dashed reference band. None (the default) draws the
        large-sample 95 % bound ±1.96/√N for white noise, where N is the shortest
        series plotted in that panel. Pass a number (e.g. 0.1) for a fixed band.

    Returns the path of the written PNG.
    """
    gb = group(all_rows, ["cell_tag", "scenario", "method"])
    cells = sorted({r["cell_tag"] for r in all_rows})
    scenarios = ["S1", "S3", "S5"]
    lags = range(1, max_lag + 1)

    fig, axes = plt.subplots(1, 3, figsize=(13, 4))
    for ax, s in zip(axes, scenarios):
        n_min = None  # length of the shortest series drawn in this panel
        for cell in cells:
            trs = gb.get((cell, s, "K-only"), [])
            if not trs:
                continue
            eK = median_trial(trs)["eK"].astype(np.float64)
            # NOTE(review): this drops every |eK| <= 1e-6 sample, including interior
            # ones, which splices the series and shifts lags. It is presumably meant
            # to drop zero padding (e.g. steps before a K-step residual exists).
            # Confirm, and consider trimming only leading/trailing zeros.
            eK = eK[np.abs(eK) > 1e-6]
            if len(eK) < max_lag + 5:
                continue
            acf = _acf(eK, max_lag)
            if acf is None:
                continue
            ax.plot(lags, acf, linewidth=1.3, label=cell.split("_")[0])
            n_min = len(eK) if n_min is None else min(n_min, len(eK))

        ax.axhline(0.0, color="k", linewidth=0.5)
        half_width, band_label = None, None
        if band is not None:
            half_width, band_label = band, "white-noise band"
        elif n_min is not None:
            # Sample ACF of white noise is approx. N(0, 1/N) per lag -> 95% at ±1.96/√N.
            half_width = 1.96 / np.sqrt(n_min)
            band_label = "95% white-noise band (±1.96/√N)"
        if half_width is not None:
            ax.axhline(half_width, color="gray", linewidth=0.5, linestyle="--",
                       label=band_label)
            ax.axhline(-half_width, color="gray", linewidth=0.5, linestyle="--")
        ax.set_xlabel("Lag (time steps)")
        ax.set_ylabel("Autocorrelation of eK")
        ax.set_title(_short_scenario_title(SCEN_TITLE[s]))
        if ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=8, loc="best")
    fig.suptitle("Residual autocorrelation — proof of exploitable structure in the K-step error",
                 fontweight="bold", y=1.02)
    fig.tight_layout()
    fn = os.path.join(FIG, "fig_autocorr.png")
    fig.savefig(fn, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fn


# ─────────────────────────────────────────────────────────────────────────────
def main():
    """Load the trial grid, write every figure, and print the files produced with their sizes."""
    rows = load()
    print(f"Loaded {len(rows)} trials")

    produced = []
    per_cell = group(rows, ["cell_tag", "scenario"])
    for cell, scen in sorted(per_cell):
        subset = per_cell[(cell, scen)]
        name = subset[0]["cell_name"]
        # The R-correction figure returns None when the subset has no K-R trial.
        for make_figure in (fig_four_panel, fig_r_correction):
            path = make_figure(subset, cell, name, scen)
            if path:
                produced.append(path)
        print(f"  {cell} {scen}: done")

    produced.extend(make_figure(rows) for make_figure in
                    (fig_summary_bars, fig_cross_dataset, fig_autocorr))

    print(f"\nProduced {len(produced)} figures in {FIG}")
    for path in produced:
        size_kb = os.path.getsize(path) / 1024
        print(f"  {os.path.basename(path)}  ({size_kb:.0f} KB)")


if __name__ == "__main__":
    main()
