"""
Three stress experiments that reviewers will ask about:

  (a) MISMATCH sweep: mm ∈ {0.10, 0.25, 0.50}. Stresses the K-step believed-R
      vs plant-R gap. Proves K-R benefit scales with mismatch severity.

  (b) NOISE robustness: multiply sensor σ by {1×, 2×, 4×}. Shows K-R does not
      require clean simulation.

  (c) DRIFT test: during a single charge cycle, linearly ramp the plant's
      resistances (R_int, R1, R2) by 0%, +15% or +30% over the simulated
      horizon while the controller's parameters stay FIXED. This tests the
      "no re-identification" claim.

Two cells (RW9, RW20), three scenarios (S1/S3/S5; the drift test uses S1/S5
only), two methods (K-only, KR), and 8 trials per configuration — a reduced
count that keeps the runtime under the timeout.
"""
import os, sys, time, pickle, gzip
import zlib
import numpy as np

from nasa_loader import calibrate_cell
from simulator import (run_trial, scenario_configs, SIGMA_V, SIGMA_I, SIGMA_T,
                       _init_kr_model_params)
from plant_controllers import (PlantParams, PlantState, plant_step, ocv_of,
                               ECMObserver, controller_K, KR_Residual,
                               controller_cccv, arrhenius_R, soh_scale_R)
import simulator as _sim

OUT = "/home/claude/kr_sim/results"
os.makedirs(OUT, exist_ok=True)

CELLS = [
    ("DS1_Uniform_ChgDis", "RW9",
     "/home/claude/work/ds1/Battery_Uniform_Distribution_Charge_Discharge_DataSet_2Post/data/Matlab/RW9.mat"),
    ("DS5_Skewed_High_RT", "RW20",
     "/home/claude/work/ds5/RW_Skewed_High_Room_Temp_DataSet_2Post/data/Matlab/RW20.mat"),
]
N_TRIALS = 8


# ─────────────────────────────────────────────────────────────────────────────
def _run(cell, cfg, method, seed, mm_override=None, noise_scale=1.0,
         drift_pct=0.0):
    """Simulate one charge trial with optional stress knobs and return summary metrics.

    Args:
        cell: calibrated cell (from ``calibrate_cell``); its ``rint_vs_soc``,
            ``ocv_table`` and ``soh_baseline`` attributes are read.
        cfg: scenario config (from ``scenario_configs``); supplies ``T_amb``,
            ``SOH``, ``mismatch``, ``SOC_start``, ``SOC_target``, ``V_max``,
            ``I_cc``, ``Imax``, ``dt`` and ``sim_time_s``.
        method: ``"CC-CV"``, ``"K-only"`` or ``"KR"``.
        seed: seed for the trial's numpy Generator (sensor noise and, for
            ``"KR"``, the residual learner).
        mm_override: mismatch fraction used instead of ``cfg.mismatch``. The
            observer/controller believe every resistance is ``(1 - mm)`` times
            the plant's.
        noise_scale: multiplier applied to every sensor-noise standard deviation.
        drift_pct: fractional change of the plant resistances (R_int, R1, R2),
            ramped linearly from 0 at t=0 to ``drift_pct`` at
            ``t = cfg.sim_time_s``. The controller's belief is never updated.
            Negative values model a falling resistance.

    Returns:
        dict with
          * ``t80``: first time SOC >= ``cfg.SOC_target`` (``cfg.sim_time_s`` if never),
          * ``ov_mV``: peak terminal-voltage excess over ``cfg.V_max`` (x1000),
          * ``rmse_V_mV``: RMS of (terminal voltage - ``cfg.V_max``) over the trial (x1000),
          * ``R_rms_mA``: RMS of the KR current correction (x1000; 0 for other methods).

    Raises:
        ValueError: if ``method`` is not one of the three supported names.
    """
    if method not in ("CC-CV", "K-only", "KR"):
        raise ValueError(method)
    rng = np.random.default_rng(seed)

    # Plant: R_int taken from the calibrated R_int(SOC) curve at SOC = 0.5.
    R_plant = float(np.interp(0.5, cell.rint_vs_soc[:, 0], cell.rint_vs_soc[:, 1]))
    p = PlantParams(Q_nom_Ah=cell.soh_baseline, ocv_table=cell.ocv_table,
                    R_int_25C=R_plant, T_amb=cfg.T_amb, SOH=cfg.SOH)

    # Controller/observer belief: every resistance branch scaled by (1 - mm).
    # R1/R2 come from the plant params (instead of hard-coded constants) so the
    # believed/plant ratio is exactly (1 - mm) for all three branches.
    mm = cfg.mismatch if mm_override is None else mm_override
    sc = 1.0 - mm
    ocv_model = cell.ocv_table.copy()
    R_model = (R_plant + p.R1 + p.R2) * sc
    obs = ECMObserver(SOC=cfg.SOC_start, R_int_m=R_plant * sc,
                      R1_m=p.R1 * sc, R2_m=p.R2 * sc, Q_nom_Ah=cell.soh_baseline)
    kr = KR_Residual(rng=rng) if method == "KR" else None
    x = PlantState(SOC=cfg.SOC_start, T_C=cfg.T_amb)
    V_t = ocv_of(p, x.SOC)
    I_applied = 0.0
    I_prev = 0.0

    t_end = cfg.sim_time_s
    N = int(t_end / cfg.dt) + 1
    t_arr, V_arr, SOC_arr, R_arr = [], [], [], []

    # Sensor-noise standard deviations, all scaled by the same factor.
    sV = SIGMA_V * noise_scale
    sI = SIGMA_I * noise_scale
    sT = SIGMA_T * noise_scale
    sSOC = 0.005 * noise_scale

    # Drift touches only the plant. The t_end guard avoids 0/0 for a
    # zero-length horizon, where the ramp factor would be 1.0 anyway.
    apply_drift = drift_pct != 0.0 and t_end > 0

    for k in range(N):
        t = k * cfg.dt
        if apply_drift:
            drift_factor = 1.0 + drift_pct * (t / t_end)
            # Per-step copy of the plant params with R_int_25C, R1 and R2 all
            # scaled by the ramp; `p` itself (and the controller belief) stay fixed.
            p_step = PlantParams(Q_nom_Ah=p.Q_nom_Ah, ocv_table=p.ocv_table,
                                 R_int_25C=R_plant * drift_factor,
                                 R1=p.R1 * drift_factor, R2=p.R2 * drift_factor,
                                 T_amb=p.T_amb, SOH=p.SOH)
        else:
            p_step = p

        # Noisy measurements (draw order is part of the seeded RNG stream).
        V_meas = V_t + rng.normal(0, sV)
        I_meas = I_applied + rng.normal(0, sI)
        T_meas = x.T_C + rng.normal(0, sT)
        SOC_est = x.SOC + rng.normal(0, sSOC)

        # NOTE(review): V_meas is the plant response to I_applied (the input of
        # the latest plant_step/obs.propagate), but predict_V is given I_prev,
        # one step older. Confirm which input predict_V expects.
        if method == "CC-CV":
            I_cmd = controller_cccv(cfg.V_max, cfg.I_cc, I_meas, V_meas, T_meas,
                                    SOC_est, cfg.SOH, cfg.Imax)
            R_corr = 0.0
        elif method == "K-only":
            I_cmd, _, _ = controller_K(cfg.V_max, cfg.I_cc, I_meas, V_meas,
                                       T_meas, SOC_est, cfg.SOH, cfg.Imax,
                                       ocv_model, R_model, obs=obs)
            # Kept for parity with the KR branch's observer call sequence.
            eK = V_meas - obs.predict_V(I_prev, ocv_model)
            R_corr = 0.0
        else:  # "KR" (validated above)
            V_pred = obs.predict_V(I_prev, ocv_model)
            eK = V_meas - V_pred
            f = kr.features(eK, T_meas, SOC_est, cfg.SOH, t)
            R_corr = kr.predict(f)
            I_K, _, _ = controller_K(cfg.V_max, cfg.I_cc, I_meas, V_meas, T_meas,
                                     SOC_est, cfg.SOH, cfg.Imax,
                                     ocv_model, R_model, obs=obs)
            # Residual correction is added to the K command, then clipped to [0, Imax].
            I_cmd = float(np.clip(I_K + R_corr, 0.0, cfg.Imax))
            kr.record_and_maybe_train(f, eK, t, obs.R_int_m)

        I_prev = I_applied
        I_applied = I_cmd
        x, V_t = plant_step(p_step, x, I_applied, cfg.dt, T_amb=cfg.T_amb)
        obs.propagate(I_applied, cfg.dt)

        t_arr.append(t)
        V_arr.append(V_t)
        SOC_arr.append(x.SOC)
        R_arr.append(R_corr)

        # Stop once the target SOC is (nearly) reached, current has tapered
        # below 5% of I_cc, and the measured voltage sits at the CV limit.
        if (x.SOC > cfg.SOC_target * 0.99 and abs(I_cmd) < 0.05 * cfg.I_cc
                and V_meas > cfg.V_max - 0.01):
            break

    V_arr, SOC_arr, t_arr = np.array(V_arr), np.array(SOC_arr), np.array(t_arr)
    R_arr = np.array(R_arr)
    idx = np.where(SOC_arr >= cfg.SOC_target)[0]
    t80 = float(t_arr[idx[0]]) if len(idx) else float(cfg.sim_time_s)
    ov = float(max(0.0, V_arr.max() - cfg.V_max) * 1000.0)
    rmse = float(np.sqrt(np.mean((V_arr - cfg.V_max)**2)) * 1000.0)
    return {"t80": t80, "ov_mV": ov, "rmse_V_mV": rmse,
            "R_rms_mA": float(np.sqrt(np.mean(R_arr**2))*1000.0) if len(R_arr) else 0.0}


# ─────────────────────────────────────────────────────────────────────────────
def run_mismatch_sweep():
    rows = []
    mms = [0.10, 0.25, 0.50]
    for cell_tag, cell_name, path in CELLS:
        cell = calibrate_cell(path, cell_tag, cell_name)
        cfgs = scenario_configs(cell)
        for sname in ["S1", "S3", "S5"]:
            cfg = cfgs[sname]
            for mm in mms:
                for method in ["K-only", "KR"]:
                    for n in range(N_TRIALS):
                        seed = 5000 + hash((cell_name, sname, mm, method, n)) % (2**31)
                        r = _run(cell, cfg, method, seed, mm_override=mm)
                        r.update({"exp": "mismatch", "cell": cell_tag, "cell_name": cell_name,
                                  "scenario": sname, "mm": mm, "method": method})
                        rows.append(r)
            print(f"  mismatch {cell_tag} {sname}: done")
    return rows


def run_noise_robustness():
    rows = []
    scales = [1.0, 2.0, 4.0]
    for cell_tag, cell_name, path in CELLS:
        cell = calibrate_cell(path, cell_tag, cell_name)
        cfgs = scenario_configs(cell)
        for sname in ["S1", "S3", "S5"]:
            cfg = cfgs[sname]
            for s in scales:
                for method in ["K-only", "KR"]:
                    for n in range(N_TRIALS):
                        seed = 6000 + hash((cell_name, sname, s, method, n)) % (2**31)
                        r = _run(cell, cfg, method, seed, noise_scale=s)
                        r.update({"exp": "noise", "cell": cell_tag, "cell_name": cell_name,
                                  "scenario": sname, "noise_scale": s, "method": method})
                        rows.append(r)
            print(f"  noise {cell_tag} {sname}: done")
    return rows


def run_drift_test():
    """Plant R_int ramps +30% during the charge; controller is FIXED (no re-ID)."""
    rows = []
    drifts = [0.0, 0.15, 0.30]
    for cell_tag, cell_name, path in CELLS:
        cell = calibrate_cell(path, cell_tag, cell_name)
        cfgs = scenario_configs(cell)
        # Only standard-temperature scenarios are meaningful for this test (S1, S5)
        for sname in ["S1", "S5"]:
            cfg = cfgs[sname]
            for d in drifts:
                for method in ["K-only", "KR"]:
                    for n in range(N_TRIALS):
                        seed = 7000 + hash((cell_name, sname, d, method, n)) % (2**31)
                        r = _run(cell, cfg, method, seed, drift_pct=d)
                        r.update({"exp": "drift", "cell": cell_tag, "cell_name": cell_name,
                                  "scenario": sname, "drift_pct": d, "method": method})
                        rows.append(r)
            print(f"  drift {cell_tag} {sname}: done")
    return rows


# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # CLI: `python <script> [all|mismatch|noise|drift]`. "all" (the default)
    # runs every experiment in the order listed here.
    experiments = {
        "mismatch": ("=== MISMATCH SWEEP ===", run_mismatch_sweep),
        "noise": ("=== NOISE ROBUSTNESS ===", run_noise_robustness),
        "drift": ("=== DRIFT TEST ===", run_drift_test),
    }
    which = sys.argv[1] if len(sys.argv) > 1 else "all"
    if which != "all" and which not in experiments:
        # Without this check an unknown name would run nothing yet still
        # write an empty stress_<name>.pkl.gz.
        sys.exit(f"unknown experiment {which!r}; expected one of: all, "
                 + ", ".join(experiments))

    all_rows = []
    t0 = time.time()
    for name, (banner, runner) in experiments.items():
        if which in (name, "all"):
            print(banner)
            all_rows += runner()

    suffix = "_" + which if which != "all" else ""
    fn = os.path.join(OUT, f"stress{suffix}.pkl.gz")
    with gzip.open(fn, "wb") as f:
        pickle.dump(all_rows, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"\nSaved {len(all_rows)} rows to {fn} in {time.time()-t0:.1f}s")
