"""
Ablation study to isolate which K-R design choices carry the benefit.

Variants:
  - KR_full      : 8 features, 16-unit ReLU reservoir + ridge LS        (full model)
  - KR_linear    : 8 features, no reservoir (direct ridge LS on features)
  - KR_no_cross  : cross terms eK·SOC and T·eK zeroed (6 active features; vector stays length 8)
  - KR_no_SOH    : SOH feature zeroed (7 active features; vector stays length 8)
  - K_only       : no residual correction at all                        (baseline)

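Metric: peak voltage overshoot (mV) per (cell × scenario × variant), aggregated
over N_TRIALS seeded trials each (N_TRIALS is the module-level constant);
RMSE and t80 are recorded in the CSV as well.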
"""
import os, sys, time, csv, pickle
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

sys.path.insert(0, "/home/claude/kr_sim")
from nasa_loader import calibrate_cell
from simulator import SimConfig, scenario_configs, run_trial, METHODS
from plant_controllers import (PlantParams, PlantState, plant_step, ocv_of,
                               controller_cccv, controller_mpc, controller_K,
                               KR_Residual, _clip_current)

# Output locations: the ablation CSV is written directly under OUT, and the
# summary bar chart is saved under FIG (the "figures" subfolder of OUT).
OUT = "/home/claude/kr_sim/results"
FIG = os.path.join(OUT, "figures")

# ─────────────────────────────────────────────────────────────────────────────
class KR_Linear(KR_Residual):
    """Ablation variant of the residual corrector with the reservoir removed.

    The correction is a direct ridge least-squares map from the feature vector
    to the target; the fitted weights are stored in ``self.w2_lin``, which
    only exists once the first fit has run.
    """

    def predict(self, f):
        """Return the correction for feature vector ``f``, clipped to ±R_max.

        Before any fit has produced ``self.w2_lin`` the raw correction is 0.0.
        """
        if hasattr(self, "w2_lin"):
            raw = float(f @ self.w2_lin)
        else:
            raw = 0.0
        return float(np.clip(raw, -self.R_max, self.R_max))

    def record_and_maybe_train(self, f, eK, t, model_Rint):
        """Buffer one (features, target) sample and refit if a retrain is due.

        The target is ``-eK / model_Rint`` (resistance floored at 1e-3),
        clipped to ±R_max. A refit happens when at least
        ``retrain_every_s`` has elapsed since the last one and the buffer
        holds at least 20 samples.
        """
        resistance = max(model_Rint, 1e-3)
        sample_target = float(np.clip(-eK / resistance, -self.R_max, self.R_max))
        self.buffer.append((f.copy(), sample_target))

        # Sliding window: keep only the newest buf_N samples.
        if len(self.buffer) > self.buf_N:
            self.buffer = self.buffer[-self.buf_N:]

        retrain_due = (t - self.t_last_train) >= self.retrain_every_s
        if not (retrain_due and len(self.buffer) >= 20):
            return

        feats, targets = zip(*self.buffer)
        F = np.stack(feats)          # (N, d)
        y = np.array(targets)

        # Ridge normal equations: (FᵀF + λI) w = Fᵀy
        gram = F.T @ F + self.lam * np.eye(self.d)
        rhs = F.T @ y
        try:
            self.w2_lin = np.linalg.solve(gram, rhs)
        except np.linalg.LinAlgError:
            # Singular system: fall back to the least-squares solution.
            self.w2_lin = np.linalg.lstsq(gram, rhs, rcond=None)[0]
        self.t_last_train = t


class KR_FeatureAblation(KR_Residual):
    """K-R residual corrector restricted to a subset of its features.

    Dropped features are zeroed rather than removed, so the feature vector
    keeps its length (dim=8) and the reservoir W1 stays the same size.
    """

    def __init__(self, rng, drop_mask, **kw):
        """Build the base corrector and store the keep/drop mask.

        ``drop_mask`` has one entry per feature: 1 keeps it, 0 drops it.
        """
        super().__init__(rng, **kw)
        self.drop_mask = np.array(drop_mask, dtype=np.float64)

    def features(self, eK, T, SOC, SOH, t):
        """Return the parent's feature vector with dropped entries zeroed."""
        full_features = super().features(eK, T, SOC, SOH, t)
        return full_features * self.drop_mask


def _run_trial_with_residual(cell, cfg, residual_obj_factory, seed, method_name):
    """Simulate one charging trial on the K-R path with a custom residual object.

    Re-implements the K-R control loop so that the residual corrector can be
    swapped for an ablation variant.

    Args:
        cell: calibrated cell object; this function reads ``rint_vs_soc``,
            ``soh_baseline`` and ``ocv_table`` from it.
        cfg: scenario configuration; this function reads ``T_amb``, ``SOH``,
            ``SOC_start``, ``SOC_target``, ``sim_time_s``, ``dt``, ``V_max``,
            ``I_cc`` and ``Imax`` from it.
        residual_obj_factory: callable ``factory(rng)`` returning an object
            with ``features``, ``predict`` and ``record_and_maybe_train``.
        seed: seed for the trial's ``numpy`` random generator.
        method_name: unused here; kept so the call signature stays stable.

    Returns:
        Tuple ``(ov_mV, rmse_mV, t80)``: peak overshoot above ``cfg.V_max``
        in mV (floored at 0), RMS deviation of the voltage trace from
        ``cfg.V_max`` in mV, and the time stamp of the step at which SOC first
        reached ``cfg.SOC_target`` (``cfg.sim_time_s`` if it never did).
    """
    from simulator import SIGMA_V, SIGMA_I, SIGMA_T, _init_kr_model_params
    rng = np.random.default_rng(seed)

    # Plant resistance is taken from the calibrated R_int curve at SOC = 0.5.
    R_plant_25C = float(np.interp(0.5, cell.rint_vs_soc[:, 0], cell.rint_vs_soc[:, 1]))
    p = PlantParams(
        Q_nom_Ah=cell.soh_baseline, ocv_table=cell.ocv_table,
        R_int_25C=R_plant_25C, T_amb=cfg.T_amb, SOH=cfg.SOH)
    ocv_model, R_model, _ = _init_kr_model_params(cell, cfg)

    # The residual object shares the trial RNG, so it must be built before the
    # first measurement-noise draw to keep the random stream order fixed.
    kr = residual_obj_factory(rng)

    x = PlantState(SOC=cfg.SOC_start, T_C=cfg.T_amb)
    V_t = ocv_of(p, x.SOC)
    I_applied = 0.0
    I_prev_applied = 0.0
    x_prev_SOC = x.SOC

    N = int(cfg.sim_time_s / cfg.dt) + 1
    V_trace = np.zeros(N)
    n_steps = N          # number of valid entries in V_trace
    t80 = None
    for k in range(N):
        t_now = k * cfg.dt

        # Noisy measurements; the draw order (V, I, T, SOC) fixes the stream.
        V_meas = V_t + rng.normal(0, SIGMA_V)
        I_meas = I_applied + rng.normal(0, SIGMA_I)
        T_meas = x.T_C + rng.normal(0, SIGMA_T)
        SOC_est = x.SOC + rng.normal(0, 0.005)

        # Model prediction error: measured voltage minus the model's voltage
        # from the previous SOC and previously applied current.
        OCV_prev = float(np.interp(x_prev_SOC, ocv_model[:, 0], ocv_model[:, 1]))
        V_pred = OCV_prev + I_prev_applied * R_model
        eK = V_meas - V_pred

        I_K, _, _ = controller_K(cfg.V_max, cfg.I_cc, I_meas, V_meas, T_meas,
                                 SOC_est, cfg.SOH, cfg.Imax, ocv_model, R_model)
        f = kr.features(eK, T_meas, SOC_est, cfg.SOH, t_now)
        R_corr = kr.predict(f)
        I_cmd = _clip_current(I_K + R_corr, cfg.Imax)
        # Train only after predicting, so this step's sample cannot influence
        # this step's correction.
        kr.record_and_maybe_train(f, eK, t_now, R_model)

        I_applied = I_cmd
        x, V_t = plant_step(p, x, I_applied, cfg.dt, T_amb=cfg.T_amb)
        V_trace[k] = V_t
        if t80 is None and x.SOC >= cfg.SOC_target:
            t80 = t_now

        I_prev_applied = I_applied
        x_prev_SOC = x.SOC

        # Early stop: essentially at target SOC, current tapered below 5 % of
        # I_cc, and measured voltage within 10 mV of V_max.
        if (x.SOC > cfg.SOC_target * 0.99 and abs(I_cmd) < 0.05 * cfg.I_cc
                and V_meas > cfg.V_max - 0.01):
            n_steps = k + 1
            break

    # Drop the unfilled tail so the zeros do not distort max / RMSE.
    V_trace = V_trace[:n_steps]

    ov = float(max(0.0, V_trace.max() - cfg.V_max) * 1000.0)
    rmse = float(np.sqrt(np.mean((V_trace - cfg.V_max)**2)) * 1000.0)
    if t80 is None:
        t80 = cfg.sim_time_s
    return ov, rmse, t80


# Residual-corrector variants under ablation. Each value is a factory taking
# the trial RNG and returning a fresh residual object; the key is used as the
# variant label in the result rows.
# For the feature ablations, drop_mask has one entry per feature in the layout
# given below: 1 keeps the feature, 0 zeroes it.
VARIANTS = {
    "KR_full"      : lambda rng: KR_Residual(rng),
    "KR_linear"    : lambda rng: KR_Linear(rng),
    # feature layout: [eK, eK², T, SOC, eK·SOC, SOH, T·eK, dT/dt]
    "KR_no_cross"  : lambda rng: KR_FeatureAblation(rng, drop_mask=[1,1,1,1,0,1,0,1]),
    "KR_no_SOH"    : lambda rng: KR_FeatureAblation(rng, drop_mask=[1,1,1,1,1,0,1,1]),
}

# Cells included in the ablation grid, as (cell_tag, cell_name, path) tuples.
# The path points at the cell's .mat file and is handed to calibrate_cell
# together with the tag and name.
CELLS = [
    ("DS1_Uniform_ChgDis",   "RW9",
     "/home/claude/work/ds1/Battery_Uniform_Distribution_Charge_Discharge_DataSet_2Post/data/Matlab/RW9.mat"),
    ("DS2_Uniform_DisRT",    "RW3",
     "/home/claude/work/ds2/Battery_Uniform_Distribution_Discharge_Room_Temp_DataSet_2Post/data/Matlab/RW3.mat"),
    ("DS5_Skewed_High_RT",   "RW20",
     "/home/claude/work/ds5/RW_Skewed_High_Room_Temp_DataSet_2Post/data/Matlab/RW20.mat"),
    ("DS7_Skewed_Low_RT",    "RW13",
     "/home/claude/work/ds7/RW_Skewed_Low_Room_Temp_DataSet_2Post/data/Matlab/RW13.mat"),
]
# Number of independently seeded trials per (cell, scenario, variant).
N_TRIALS = 6


def _stable_seed(*parts):
    """Derive a trial seed that is identical across interpreter runs.

    Built-in ``hash()`` on strings is salted per process, so it cannot be used
    for reproducible seeding; a SHA-256 digest of the joined parts is used
    instead. The result lies in ``[202504, 202504 + 2**30)``.
    """
    import hashlib

    key = "|".join(str(part) for part in parts).encode("utf-8")
    digest = hashlib.sha256(key).digest()
    return 202504 + int.from_bytes(digest[:8], "big") % (2**30)


def _summary_row(cell_tag, cell_name, scen, variant, ov, rmse, t80):
    """Collapse per-trial metric lists into one mean/std result row."""
    return dict(cell=cell_tag, cell_name=cell_name, scenario=scen,
                variant=variant,
                ov_mean=float(np.mean(ov)), ov_std=float(np.std(ov)),
                rmse_mean=float(np.mean(rmse)), rmse_std=float(np.std(rmse)),
                t80_mean=float(np.mean(t80)))


def run_ablations():
    """Run the full ablation grid and persist the aggregated results.

    For every cell in ``CELLS`` and every scenario in S1/S3/S5, runs
    ``N_TRIALS`` seeded trials of the K-only baseline (via ``run_trial``) and
    of each entry in ``VARIANTS`` (via ``_run_trial_with_residual``), then
    aggregates overshoot, RMSE and t80 per (cell, scenario, variant).

    Side effects: prints progress, writes ``ablation.csv`` under ``OUT`` and a
    pickle of the rows to ``/tmp/ablation.pkl``.

    Returns:
        List of result-row dicts with keys ``cell``, ``cell_name``,
        ``scenario``, ``variant``, ``ov_mean``, ``ov_std``, ``rmse_mean``,
        ``rmse_std`` and ``t80_mean``.
    """
    t0 = time.time()
    # Create the output directory up front so a path problem surfaces before
    # the long computation instead of discarding its results at the end.
    os.makedirs(OUT, exist_ok=True)

    rows = []
    for cell_tag, cell_name, path in CELLS:
        print(f"\n[ABL]  {cell_name}")
        cell = calibrate_cell(path, cell_tag, cell_name)
        cfgs = scenario_configs(cell)
        for scen in ["S1", "S3", "S5"]:
            cfg = cfgs[scen]

            # K-only baseline through the standard run_trial (no residual).
            # NOTE(review): r.ov / r.rmse_V are assumed to use the same units
            # as the mV values returned by _run_trial_with_residual -- confirm
            # against simulator.run_trial.
            ov_k, rmse_k, t80_k = [], [], []
            for tr in range(N_TRIALS):
                seed = _stable_seed(cell_name, scen, "K-only", tr)
                r = run_trial(cell, cfg, "K-only", seed)
                ov_k.append(r.ov)
                rmse_k.append(r.rmse_V)
                t80_k.append(r.t80)
            rows.append(_summary_row(cell_tag, cell_name, scen, "K_only",
                                     ov_k, rmse_k, t80_k))
            print(f"  {scen}  K_only      OV={np.mean(ov_k):6.2f}±{np.std(ov_k):4.2f} mV")

            for var_name, factory in VARIANTS.items():
                ov_v, rmse_v, t80_v = [], [], []
                for tr in range(N_TRIALS):
                    seed = _stable_seed(cell_name, scen, var_name, tr)
                    ov, rmse, t80 = _run_trial_with_residual(cell, cfg, factory, seed, var_name)
                    ov_v.append(ov)
                    rmse_v.append(rmse)
                    t80_v.append(t80)
                rows.append(_summary_row(cell_tag, cell_name, scen, var_name,
                                         ov_v, rmse_v, t80_v))
                print(f"  {scen}  {var_name:12s}  OV={np.mean(ov_v):6.2f}±{np.std(ov_v):4.2f} mV")

    print(f"\nAblation grid done in {time.time()-t0:.1f} s")

    # Save CSV
    fn = os.path.join(OUT, "ablation.csv")
    with open(fn, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        w.writeheader()
        w.writerows(rows)
    print(f"wrote {fn}")
    # Pickled copy of the rows, written to a fixed /tmp location.
    with open("/tmp/ablation.pkl", "wb") as f:
        pickle.dump(rows, f)
    return rows


def make_ablation_figure(rows):
    """Render the ablation bar chart and save it as ``fig_ablation.png``.

    One panel per scenario (S1, S3, S5); each bar is the mean of the per-cell
    ``ov_mean`` values for one variant, with the standard deviation across
    those per-cell values as the error bar.

    Args:
        rows: result-row dicts; this function reads ``scenario``, ``variant``
            and ``ov_mean`` from each.

    Returns:
        Path of the written PNG file (inside ``FIG``).
    """
    scenarios = ["S1", "S3", "S5"]
    variants = ["K_only", "KR_no_SOH", "KR_no_cross", "KR_linear", "KR_full"]
    nice = {"K_only": "K-only", "KR_no_SOH": "K-R (no SOH feat.)",
            "KR_no_cross": "K-R (no cross terms)",
            "KR_linear": "K-R (linear, no reservoir)", "KR_full": "K-R ★ (full)"}
    colors = {"K_only": "#e6553f", "KR_no_SOH": "#c98cdd",
              "KR_no_cross": "#ffa94d", "KR_linear": "#4dabf7", "KR_full": "#2f9e44"}

    fig, axes = plt.subplots(1, 3, figsize=(14, 4.2), sharey=False)
    for ax, scen in zip(axes, scenarios):
        # Collect the per-cell overshoot means for this scenario, per variant.
        vals = {v: [] for v in variants}
        for r in rows:
            if r["scenario"] == scen and r["variant"] in vals:
                vals[r["variant"]].append(r["ov_mean"])

        # A variant with no rows yields NaN explicitly, avoiding numpy's
        # empty-slice RuntimeWarning.
        nan = float("nan")
        means = [np.mean(vals[v]) if vals[v] else nan for v in variants]
        stds = [np.std(vals[v]) if vals[v] else nan for v in variants]

        x = np.arange(len(variants))
        bars = ax.bar(x, means, yerr=stds, capsize=4,
                      color=[colors[v] for v in variants],
                      edgecolor="black", linewidth=0.7)
        ax.set_xticks(x)
        ax.set_xticklabels([nice[v] for v in variants], rotation=25, ha="right", fontsize=9)
        ax.set_ylabel("Peak voltage overshoot (mV)")
        ax.set_title(f"Scenario {scen}")

        # Annotate each bar with its value; bars without data get no label.
        for b, m in zip(bars, means):
            if not np.isfinite(m):
                continue
            ax.text(b.get_x() + b.get_width()/2, b.get_height()*1.02,
                    f"{m:.1f}", ha="center", va="bottom", fontsize=8)

    fig.suptitle("Ablation study: which K-R design choices carry the benefit?",
                 fontweight="bold", y=1.02)
    fig.tight_layout()

    # Make sure the figures directory exists before saving.
    os.makedirs(FIG, exist_ok=True)
    fn = os.path.join(FIG, "fig_ablation.png")
    fig.savefig(fn, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fn


if __name__ == "__main__":
    # Script entry point: run the ablation grid, then render the summary figure.
    ablation_rows = run_ablations()
    figure_path = make_ablation_figure(ablation_rows)
    print(f"wrote {figure_path}")
