"""
Cross-cell (train/test-split) validation for the K-R residual learner.

Protocol:
  Split 1: train {DS1, DS2},  test {DS5, DS7}
  Split 2: train {DS5, DS7},  test {DS1, DS2}
  Split 3: train {DS1, DS5},  test {DS2, DS7}

For each split:
  1. Collect (feature, target) pairs by running the K-only controller on the
     TRAIN cells across all 3 scenarios (S1, S3, S5). `train_reservoir`
     defaults to 10 trials per (cell × scenario), which gives roughly 180k
     samples; `main()` currently uses 4 trials.
  2. Solve one closed-form ridge regression for w2 on the pooled training data.
  3. Freeze w2 and deploy the K-R controller on the TEST cells, 10 MC trials
     per (cell × scenario). There is no online retraining: the reservoir
     weights are fixed at inference time.

This tests the claim that the residual structure transfers to unseen NASA
cells, rather than being a per-cell artifact of online learning.
"""
import os, pickle, gzip, time
import numpy as np

from nasa_loader import calibrate_cell
from simulator import run_trial, scenario_configs, METHODS, SIGMA_V, SIGMA_I, SIGMA_T
from plant_controllers import (PlantParams, PlantState, plant_step, ocv_of,
                               ECMObserver, controller_K, KR_Residual)
from simulator import _init_kr_model_params


# Cells used in the cross-cell splits, one tuple per dataset:
#   (dataset tag, cell name passed to calibrate_cell, absolute path to its .mat file)
# The paths are machine-specific.
CELLS = [
    ("DS1_Uniform_ChgDis",   "RW9",
     "/home/claude/work/ds1/Battery_Uniform_Distribution_Charge_Discharge_DataSet_2Post/data/Matlab/RW9.mat"),
    ("DS2_Uniform_DisRT",    "RW3",
     "/home/claude/work/ds2/Battery_Uniform_Distribution_Discharge_Room_Temp_DataSet_2Post/data/Matlab/RW3.mat"),
    ("DS5_Skewed_High_RT",   "RW20",
     "/home/claude/work/ds5/RW_Skewed_High_Room_Temp_DataSet_2Post/data/Matlab/RW20.mat"),
    ("DS7_Skewed_Low_RT",    "RW13",
     "/home/claude/work/ds7/RW_Skewed_Low_Room_Temp_DataSet_2Post/data/Matlab/RW13.mat"),
]

# Dataset tag -> (cell name, .mat path); the split definitions refer to cells by tag.
CELL_MAP = {tag: (cell_name, mat_path) for tag, cell_name, mat_path in CELLS}

# Directory receiving the per-split and merged result pickles.
OUT = "/home/claude/kr_sim/results"


def _harvest_samples(cell, cfg, seed, W1):
    """Run one K-only trial and collect (feature, target) pairs for reservoir training.

    The plant is simulated with measurement noise drawn from an RNG seeded by
    ``seed``. The ECM observer is initialised with parameters scaled by
    ``(1 - cfg.mismatch)``, so its voltage prediction deliberately differs from
    the plant. At every step the observer's voltage-prediction error ``eK``
    becomes an 8-element feature vector and a clipped target current
    correction. Only the K controller drives the plant; no residual correction
    is applied while harvesting.

    Args:
        cell: calibrated cell from ``calibrate_cell``. Reads ``rint_vs_soc``,
            ``ocv_table`` and ``soh_baseline``.
        cfg: one scenario config from ``scenario_configs(cell)``.
        seed: seed for ``np.random.default_rng``, which drives all noise draws.
        W1: not used in this function. Kept so the call signature matches the
            ``train_reservoir`` call site.

    Returns:
        ``(F, y)``, where ``F`` has shape ``(n_steps, 8)`` and ``y`` has shape
        ``(n_steps,)``. There is one row per simulated step, up to ``N`` steps
        or the early stop.
    """
    rng = np.random.default_rng(seed)
    # Plant internal resistance: the R_int-vs-SOC curve evaluated at SOC = 0.5.
    R_plant = float(np.interp(0.5, cell.rint_vs_soc[:, 0], cell.rint_vs_soc[:, 1]))
    p = PlantParams(Q_nom_Ah=cell.soh_baseline, ocv_table=cell.ocv_table,
                    R_int_25C=R_plant, T_amb=cfg.T_amb, SOH=cfg.SOH)
    ocv_model, R_model, mdl = _init_kr_model_params(cell, cfg)
    # Observer resistances are scaled by (1 - mismatch) to model parameter error.
    sc = (1.0 - cfg.mismatch)
    obs = ECMObserver(SOC=cfg.SOC_start, R_int_m=mdl["R_int_plant"]*sc,
                      R1_m=0.006*sc, R2_m=0.008*sc, Q_nom_Ah=cell.soh_baseline)
    x = PlantState(SOC=cfg.SOC_start, T_C=cfg.T_amb)
    V_t = ocv_of(p, x.SOC)
    # I_applied: current used in the most recent plant_step (the one that produced V_t).
    # I_prev: the current applied one step before that.
    I_applied = 0.0; I_prev = 0.0
    T_prev = None
    feats = []; targs = []
    N = int(cfg.sim_time_s / cfg.dt) + 1
    for k in range(N):
        # Noisy measurements. The draw order (V, T, SOC) fixes the RNG stream.
        V_meas = V_t + rng.normal(0, SIGMA_V)
        T_meas = x.T_C + rng.normal(0, SIGMA_T)
        SOC_est = x.SOC + rng.normal(0, 0.005)
        # K-only controller
        I_cmd, _, _ = controller_K(cfg.V_max, cfg.I_cc, I_applied, V_meas, T_meas,
                                   SOC_est, cfg.SOH, cfg.Imax,
                                   ocv_model, R_model, obs=obs)
        # NOTE(review): predict_V receives I_prev, not the I_applied that produced
        # V_t. _run_test_trial does the same, so train and test are consistent,
        # but confirm this one-step lag is intended.
        V_pred = obs.predict_V(I_prev, ocv_model)
        eK = V_meas - V_pred
        # Per-step temperature difference (not divided by cfg.dt); 0 on the first step.
        dT_dt = 0.0 if T_prev is None else (T_meas - T_prev)
        T_prev = T_meas
        # NOTE(review): this hand-built feature vector must match what
        # KR_Residual.features(...) produces in _run_test_trial. Otherwise the
        # trained w2 is applied to different inputs at test time. Confirm.
        f = np.array([eK, eK*eK, T_meas, SOC_est, eK*SOC_est, cfg.SOH,
                       T_meas*eK, dT_dt])
        # Target is -eK / R_int_m: the current change that would offset eK across
        # the observer's resistance. R is floored at 1e-3 to avoid blow-up, and
        # the result is clipped to [-0.5, 0.5].
        target = -eK / max(obs.R_int_m, 1e-3)
        target = np.clip(target, -0.5, 0.5)
        feats.append(f); targs.append(target)
        # step
        I_prev = I_applied
        I_applied = I_cmd
        x, V_t = plant_step(p, x, I_applied, cfg.dt, T_amb=cfg.T_amb)
        obs.propagate(I_applied, cfg.dt)
        # Early stop when all three hold: SOC above 99% of the target, the
        # command tapered below 5% of I_cc, and V_meas within 0.01 of V_max.
        if (x.SOC > cfg.SOC_target * 0.99 and abs(I_cmd) < 0.05 * cfg.I_cc
                and V_meas > cfg.V_max - 0.01):
            break
    return np.stack(feats), np.array(targs)


def train_reservoir(train_cells, W1, n_trials=10, lam=1e-2):
    """Pool K-only samples from the training cells and fit the ridge readout w2.

    For every cell tag in ``train_cells`` and every scenario in (S1, S3, S5),
    ``n_trials`` K-only trials are run through ``_harvest_samples``. The
    resulting (feature, target) pairs are pooled, and a closed-form ridge
    regression is solved on the fixed random ReLU layer ``max(F @ W1, 0)``.

    Seeds are derived from ``zlib.crc32`` of the (cell, scenario, trial) key
    rather than the built-in ``hash()``. String hashing is salted per process
    (PYTHONHASHSEED), so ``hash()`` gave different seeds, and a different w2,
    on every run.

    Args:
        train_cells: iterable of ``CELL_MAP`` keys (dataset tags).
        W1: fixed reservoir input weights, shape ``(n_features, n_hidden)``.
        n_trials: K-only trials per (cell, scenario).
        lam: ridge (L2) regularisation strength.

    Returns:
        The readout weights ``w2``, shape ``(n_hidden,)``.

    Raises:
        ValueError: if no training samples were collected (empty
            ``train_cells`` or ``n_trials <= 0``).
    """
    import zlib  # local import: process-independent hashing for seeds

    F_all, y_all = [], []
    for cell_tag in train_cells:
        cell_name, path = CELL_MAP[cell_tag]
        cell = calibrate_cell(path, cell_tag, cell_name)
        cfgs = scenario_configs(cell)
        for sname in ["S1", "S3", "S5"]:
            cfg = cfgs[sname]
            for n in range(n_trials):
                key = repr((cell_name, sname, n)).encode("utf-8")
                seed = 7007 + zlib.crc32(key) % (2**31)
                F, y = _harvest_samples(cell, cfg, seed, W1)
                F_all.append(F)
                y_all.append(y)
    if not F_all:
        raise ValueError(
            "train_reservoir: no training samples collected "
            "(empty train_cells or n_trials <= 0)")
    F = np.concatenate(F_all, axis=0)
    y = np.concatenate(y_all, axis=0)
    print(f"  pooled training samples: {len(y):,}")
    # Hidden activations of the fixed random reservoir (ReLU).
    H = np.maximum(F @ W1, 0.0)
    # Closed-form ridge solution: (H^T H + lam*I) w2 = H^T y
    w2 = np.linalg.solve(H.T @ H + lam * np.eye(W1.shape[1]), H.T @ y)
    return w2


def _run_test_trial(cell, cfg, W1, w2_frozen, seed):
    """Run a single K-R trial with FROZEN w2 (no online retraining).

    The plant/observer setup matches ``_harvest_samples``. The applied command
    is ``clip(I_K + R_corr, 0, cfg.Imax)``, where ``I_K`` comes from the K
    controller and ``R_corr`` from a ``KR_Residual`` built with the preset
    ``W1``/``w2_frozen`` and ``freeze_w2=True``.

    Args:
        cell: calibrated cell from ``calibrate_cell``.
        cfg: one scenario config from ``scenario_configs(cell)``.
        W1: fixed reservoir input weights.
        w2_frozen: readout weights, e.g. as returned by ``train_reservoir``.
        seed: seed for the RNG shared by the noise draws and ``KR_Residual``.

    Returns:
        A dict with these keys:
          "t80": first recorded time with SOC >= cfg.SOC_target, or
              cfg.sim_time_s if the target is never reached.
          "ov_mV": max(0, max(V_t) - cfg.V_max) * 1000.
          "rmse_V_mV": RMS of (V_t - cfg.V_max) over the whole trial, * 1000.
          "|R|_rms_mA": RMS of the residual correction R_corr, * 1000.
    """
    rng = np.random.default_rng(seed)
    # Plant internal resistance: the R_int-vs-SOC curve evaluated at SOC = 0.5.
    R_plant = float(np.interp(0.5, cell.rint_vs_soc[:, 0], cell.rint_vs_soc[:, 1]))
    p = PlantParams(Q_nom_Ah=cell.soh_baseline, ocv_table=cell.ocv_table,
                    R_int_25C=R_plant, T_amb=cfg.T_amb, SOH=cfg.SOH)
    ocv_model, R_model, mdl = _init_kr_model_params(cell, cfg)
    # Observer resistances are scaled by (1 - mismatch) to model parameter error.
    sc = (1.0 - cfg.mismatch)
    obs = ECMObserver(SOC=cfg.SOC_start, R_int_m=mdl["R_int_plant"]*sc,
                      R1_m=0.006*sc, R2_m=0.008*sc, Q_nom_Ah=cell.soh_baseline)
    kr = KR_Residual(rng=rng, preset_W1=W1, preset_w2=w2_frozen, freeze_w2=True)
    x = PlantState(SOC=cfg.SOC_start, T_C=cfg.T_amb)
    V_t = ocv_of(p, x.SOC)
    # I_applied: current used in the most recent plant_step; I_prev: the one before it.
    I_applied = 0.0; I_prev = 0.0
    t_arr = []; V_arr = []; SOC_arr = []; eK_arr = []; R_arr = []
    N = int(cfg.sim_time_s / cfg.dt) + 1
    for k in range(N):
        t = k * cfg.dt
        # Noisy measurements. The draw order (V, T, SOC) fixes the RNG stream.
        V_meas = V_t + rng.normal(0, SIGMA_V)
        T_meas = x.T_C + rng.normal(0, SIGMA_T)
        SOC_est = x.SOC + rng.normal(0, 0.005)
        # K-R forward pass (frozen w2)
        V_pred = obs.predict_V(I_prev, ocv_model)
        eK = V_meas - V_pred
        # NOTE(review): features come from KR_Residual.features here, whereas
        # training in _harvest_samples builds them by hand. Confirm both
        # produce the same 8-vector in the same order.
        f = kr.features(eK, T_meas, SOC_est, cfg.SOH, t)
        R_corr = kr.predict(f)
        I_K, _, _ = controller_K(cfg.V_max, cfg.I_cc, I_applied, V_meas, T_meas,
                                 SOC_est, cfg.SOH, cfg.Imax,
                                 ocv_model, R_model, obs=obs)
        # Residual correction is added to the K command, then clipped to [0, Imax].
        I_cmd = float(np.clip(I_K + R_corr, 0.0, cfg.Imax))
        # step plant & observer
        I_prev = I_applied
        I_applied = I_cmd
        x, V_t = plant_step(p, x, I_applied, cfg.dt, T_amb=cfg.T_amb)
        obs.propagate(I_applied, cfg.dt)
        # NOTE(review): V_t and x.SOC are post-step values (time t + dt) but are
        # stamped with the pre-step t. As a result t80 may read one dt early.
        # Confirm this matches the convention used elsewhere before changing it.
        t_arr.append(t); V_arr.append(V_t); SOC_arr.append(x.SOC)
        eK_arr.append(eK); R_arr.append(R_corr)
        # Same early-stop rule as in _harvest_samples.
        if (x.SOC > cfg.SOC_target * 0.99 and abs(I_cmd) < 0.05 * cfg.I_cc
                and V_meas > cfg.V_max - 0.01):
            break
    t_arr, V_arr, SOC_arr = np.array(t_arr), np.array(V_arr), np.array(SOC_arr)
    eK_arr, R_arr = np.array(eK_arr), np.array(R_arr)
    # First sample at or above the SOC target; falls back to the full sim time.
    idx_80 = np.where(SOC_arr >= cfg.SOC_target)[0]
    t80 = float(t_arr[idx_80[0]]) if len(idx_80) else float(cfg.sim_time_s)
    ov = float(max(0.0, V_arr.max() - cfg.V_max) * 1000.0)
    rmse_V = float(np.sqrt(np.mean((V_arr - cfg.V_max)**2)) * 1000.0)
    return {"t80": t80, "ov_mV": ov, "rmse_V_mV": rmse_V,
            "|R|_rms_mA": float(np.sqrt(np.mean(R_arr**2))*1000.0)}


def _dump_pickle_gz(obj, path):
    """Write ``obj`` to ``path`` as a gzip-compressed pickle, atomically.

    The data goes to ``path + ".tmp"`` first and is then moved into place with
    ``os.replace``. An interrupted run therefore never leaves a truncated file
    at ``path``, which matters because ``main`` treats an existing per-split
    file as "already done".
    """
    tmp_path = path + ".tmp"
    with gzip.open(tmp_path, "wb") as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(tmp_path, path)


def main(only_split=None, *, n_train_trials=4, n_test_trials=10):
    """Run the cross-cell validation splits and save per-split and merged results.

    Splits whose per-split file is missing from ``OUT`` are run; if
    ``only_split`` is given, only that split is considered. For each one:

    1. Fit w2 on the training cells with ``train_reservoir``.
    2. Evaluate the frozen-w2 K-R controller on every test cell × scenario
       (S1, S3, S5).
    3. Write the rows to ``OUT/cross_cell_<split>.pkl.gz``.

    Finally, every per-split file present (new or pre-existing) is merged into
    ``OUT/cross_cell.pkl.gz``.

    Test seeds are derived from ``zlib.crc32`` of the (split, cell, scenario,
    trial) key. The built-in ``hash()`` of strings is salted per process, so
    it would give different seeds on every run.

    Args:
        only_split: name of one split to run ("split1", "split2", "split3"),
            or None to run all of them.
        n_train_trials: K-only trials per (train cell, scenario), passed to
            ``train_reservoir``.
        n_test_trials: frozen-w2 MC trials per (test cell, scenario).

    Raises:
        ValueError: if ``only_split`` is not one of the known split names.
    """
    import zlib  # local import: process-independent hashing for seeds

    t0 = time.time()
    # Fixed random-init reservoir, with the same W1 used across all splits for
    # reproducibility. This is a Glorot/Xavier-uniform bound for an
    # 8-input -> 16-unit layer.
    rng = np.random.default_rng(12345)
    a = float(np.sqrt(6.0 / (8 + 16)))
    W1 = rng.uniform(-a, a, size=(8, 16))

    splits = [
        ("split1", ["DS1_Uniform_ChgDis", "DS2_Uniform_DisRT"],
                   ["DS5_Skewed_High_RT", "DS7_Skewed_Low_RT"]),
        ("split2", ["DS5_Skewed_High_RT", "DS7_Skewed_Low_RT"],
                   ["DS1_Uniform_ChgDis", "DS2_Uniform_DisRT"]),
        ("split3", ["DS1_Uniform_ChgDis", "DS5_Skewed_High_RT"],
                   ["DS2_Uniform_DisRT", "DS7_Skewed_Low_RT"]),
    ]
    split_names = [name for name, _, _ in splits]
    if only_split is not None and only_split not in split_names:
        # Previously a mistyped split name silently ran nothing.
        raise ValueError(
            f"unknown split {only_split!r}; expected one of {split_names}")
    os.makedirs(OUT, exist_ok=True)

    for split_name, train, test in splits:
        if only_split is not None and split_name != only_split:
            continue
        # An existing per-split file marks the split as done.
        fn_split = os.path.join(OUT, f"cross_cell_{split_name}.pkl.gz")
        if os.path.exists(fn_split):
            print(f"SKIP {split_name} (already have {fn_split})")
            continue
        print(f"\n=== {split_name}: train {train} -> test {test} ===")
        w2 = train_reservoir(train, W1, n_trials=n_train_trials)
        split_rows = []
        for cell_tag in test:
            cell_name, path = CELL_MAP[cell_tag]
            cell = calibrate_cell(path, cell_tag, cell_name)
            cfgs = scenario_configs(cell)
            for sname in ["S1", "S3", "S5"]:
                cfg = cfgs[sname]
                scen_rows = []
                for trial in range(n_test_trials):
                    key = repr((split_name, cell_name, sname, trial)).encode("utf-8")
                    seed = 99999 + zlib.crc32(key) % (2**31)
                    r = _run_test_trial(cell, cfg, W1, w2, seed)
                    r.update({"split": split_name, "cell_tag": cell_tag,
                              "cell_name": cell_name, "scenario": sname,
                              "trial": trial, "mode": "frozen_w2"})
                    scen_rows.append(r)
                split_rows.extend(scen_rows)
                ov = np.array([r["ov_mV"] for r in scen_rows])
                print(f"  {cell_tag}  {sname}  OV frozen-w2 = {ov.mean():6.1f} ± {ov.std():4.1f} mV")
        _dump_pickle_gz(split_rows, fn_split)

    # Merge all per-split files (existing + new).
    # NOTE: pickle.load can execute arbitrary code; only load files this script wrote.
    merged = []
    for split_name in split_names:
        fn_split = os.path.join(OUT, f"cross_cell_{split_name}.pkl.gz")
        if os.path.exists(fn_split):
            with gzip.open(fn_split, "rb") as f:
                merged += pickle.load(f)
    fn = os.path.join(OUT, "cross_cell.pkl.gz")
    _dump_pickle_gz(merged, fn)
    print(f"\nSaved {len(merged)} rows to {fn} in {time.time()-t0:.1f}s")


if __name__ == "__main__":
    import sys

    # Optional single CLI argument: the split to run. With no argument,
    # main() runs every split.
    main(*sys.argv[1:2])
