"""
Run the full Monte Carlo grid:
  4 cells × 3 scenarios (S1, S3, S5) × 4 methods × 20 trials = 960 simulations.
Saves the raw per-trial traces and summary metrics to a gzip-compressed pickle
(grid.pkl.gz in the results directory) for downstream analysis.
"""
import gzip
import os
import pickle
import time
import zlib

import numpy as np

from nasa_loader import calibrate_cell
from simulator import run_trial, scenario_configs, METHODS, TrialResult


CELLS = [
    ("DS1_Uniform_ChgDis",   "RW9",
     "/home/claude/work/ds1/Battery_Uniform_Distribution_Charge_Discharge_DataSet_2Post/data/Matlab/RW9.mat"),
    ("DS2_Uniform_DisRT",    "RW3",
     "/home/claude/work/ds2/Battery_Uniform_Distribution_Discharge_Room_Temp_DataSet_2Post/data/Matlab/RW3.mat"),
    ("DS5_Skewed_High_RT",   "RW20",
     "/home/claude/work/ds5/RW_Skewed_High_Room_Temp_DataSet_2Post/data/Matlab/RW20.mat"),
    ("DS7_Skewed_Low_RT",    "RW13",
     "/home/claude/work/ds7/RW_Skewed_Low_Room_Temp_DataSet_2Post/data/Matlab/RW13.mat"),
]

N_TRIALS = 20
BASE_SEED = 1337

OUT = "/home/claude/kr_sim/results"
os.makedirs(OUT, exist_ok=True)


def _result_to_dict(r: TrialResult, cell_tag: str, cell_name: str,
                    scenario: str, method: str, trial: int) -> dict:
    """Flatten one trial into a plain dict for the results pickle.

    The record contains, in this order:
      * the labels identifying the trial (cell, scenario, method, index),
      * every time-series trace cast to float32 (kept so figures and
        autocorrelation can be produced later),
      * the scalar summary metrics, copied unchanged.
    """
    record = dict(
        cell_tag=cell_tag,
        cell_name=cell_name,
        scenario=scenario,
        method=method,
        trial=trial,
    )

    # Traces are stored under the same name as the TrialResult attribute.
    for name in ("t", "V", "I", "SOC", "T", "eK", "R_corr", "V_pred_K"):
        record[name] = getattr(r, name).astype(np.float32)

    # Scalar metrics: (output key, TrialResult attribute).
    scalar_fields = (
        ("t80", "t80"),
        ("ov_mV", "ov"),
        ("rmse_V_mV", "rmse_V"),
        ("eff", "eff"),
        ("dT_max", "dT_max"),
    )
    for key, attr in scalar_fields:
        record[key] = getattr(r, attr)

    return record


def _trial_seed(cell_name: str, scenario: str, method: str, trial: int) -> int:
    """Return a reproducible RNG seed for one (cell, scenario, method, trial).

    The built-in ``hash()`` of ``str`` is salted per interpreter process
    (see PYTHONHASHSEED), so seeds derived from it change on every run and
    the grid cannot be reproduced. CRC-32 of a fixed key string is stable
    across processes and platforms. The result keeps the original range:
    BASE_SEED + [0, 2**31).
    """
    key = f"{cell_name}|{scenario}|{method}|{trial}".encode("utf-8")
    return BASE_SEED + zlib.crc32(key) % (2**31)


def _print_calibration(cell) -> None:
    """Print a one-line summary of a calibrated cell's key parameters."""
    q0 = cell.soh_baseline
    soh_caps = cell.soh_trajectory[:, 1]
    ocv_mid = np.interp(0.5, cell.ocv_table[:, 0], cell.ocv_table[:, 1])
    rint_mid = np.interp(0.5, cell.rint_vs_soc[:, 0], cell.rint_vs_soc[:, 1])
    print(f"  calibration: Q={q0:.3f} Ah, "
          f"SOH range {soh_caps.min()/q0:.2f}–"
          f"{soh_caps.max()/q0:.2f}, "
          f"ocv@0.5 = {ocv_mid:.3f} V, "
          f"R_int@0.5 = {rint_mid*1000:.1f} mΩ")


def _save_results(results: list, fn: str) -> None:
    """Write *results* to *fn* as a gzip-compressed pickle, atomically.

    Data goes to a temporary sibling file first and is then moved over *fn*
    with os.replace, so an interrupted run never leaves a truncated file.
    """
    os.makedirs(os.path.dirname(fn) or ".", exist_ok=True)
    tmp = fn + ".tmp"
    with gzip.open(tmp, "wb") as f:
        pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(tmp, fn)


def main():
    """Run every (cell, scenario, method, trial) combination and save it.

    For each entry in CELLS, the cell is calibrated and its scenario
    configs are built. Then run_trial is called N_TRIALS times for each
    scenario in ("S1", "S3", "S5") and each method in METHODS. Every trial
    is flattened with _result_to_dict, and the full list is written to
    OUT/grid.pkl.gz.
    """
    t_start = time.perf_counter()
    all_results = []   # one dict per trial (see _result_to_dict)

    for cell_tag, cell_name, path in CELLS:
        print(f"\n=== {cell_tag} — {cell_name} ===")
        t0 = time.perf_counter()
        cell = calibrate_cell(path, cell_tag, cell_name)
        cfgs = scenario_configs(cell)
        _print_calibration(cell)

        for sname in ("S1", "S3", "S5"):
            cfg = cfgs[sname]
            for method in METHODS:
                tm = time.perf_counter()
                for trial in range(N_TRIALS):
                    # Stable per-trial seed: identical across runs/processes.
                    seed = _trial_seed(cell_name, sname, method, trial)
                    r = run_trial(cell, cfg, method, seed)
                    all_results.append(_result_to_dict(
                        r, cell_tag, cell_name, sname, method, trial))
                print(f"  {sname}  {method:7s}  {N_TRIALS} trials in "
                      f"{time.perf_counter()-tm:4.1f}s")

        print(f"  cell done in {time.perf_counter()-t0:.1f}s")

    # Save everything
    fn = os.path.join(OUT, "grid.pkl.gz")
    _save_results(all_results, fn)
    print(f"\nSaved {len(all_results)} trials to {fn}  "
          f"({os.path.getsize(fn)/1e6:.1f} MB)")
    print(f"Total wall time: {time.perf_counter()-t_start:.1f}s")


if __name__ == "__main__":
    main()
