"""Run one cell at a time so we can chunk large grids under container time limits."""
import os, sys, time, pickle, gzip
import numpy as np
from nasa_loader import calibrate_cell
from simulator import run_trial, scenario_configs, METHODS, TrialResult

CELLS = [
    ("DS1_Uniform_ChgDis",   "RW9",
     "/home/claude/work/ds1/Battery_Uniform_Distribution_Charge_Discharge_DataSet_2Post/data/Matlab/RW9.mat"),
    ("DS2_Uniform_DisRT",    "RW3",
     "/home/claude/work/ds2/Battery_Uniform_Distribution_Discharge_Room_Temp_DataSet_2Post/data/Matlab/RW3.mat"),
    ("DS5_Skewed_High_RT",   "RW20",
     "/home/claude/work/ds5/RW_Skewed_High_Room_Temp_DataSet_2Post/data/Matlab/RW20.mat"),
    ("DS7_Skewed_Low_RT",    "RW13",
     "/home/claude/work/ds7/RW_Skewed_Low_Room_Temp_DataSet_2Post/data/Matlab/RW13.mat"),
]
N_TRIALS = 20
BASE_SEED = 1337
OUT = "/home/claude/kr_sim/results"
os.makedirs(OUT, exist_ok=True)


def _result_to_dict(r, ct, cn, s, m, n):
    return {"cell_tag": ct, "cell_name": cn, "scenario": s, "method": m, "trial": n,
            "t": r.t.astype(np.float32), "V": r.V.astype(np.float32),
            "I": r.I.astype(np.float32), "SOC": r.SOC.astype(np.float32),
            "T": r.T.astype(np.float32), "eK": r.eK.astype(np.float32),
            "R_corr": r.R_corr.astype(np.float32), "V_pred_K": r.V_pred_K.astype(np.float32),
            "t80": r.t80, "ov_mV": r.ov, "rmse_V_mV": r.rmse_V,
            "eff": r.eff, "dT_max": r.dT_max}


def _stable_seed(base, *parts):
    """Return a reproducible seed in ``[0, 2**31)`` derived from ``base`` and ``parts``.

    The builtin ``hash()`` is salted per interpreter process for ``str`` data
    (PYTHONHASHSEED), so seeds built from it differ between runs. Because this
    script is meant to be invoked once per cell in separate processes, that
    would make results non-reproducible. A SHA-256 digest of the parts'
    string forms is identical across processes and machines.
    """
    import hashlib  # local import: only needed for seed derivation

    # NOTE(review): assumes str() of each part is stable across runs
    # (true for str/int; verify if METHODS ever holds custom objects).
    key = "|".join(str(p) for p in parts).encode("utf-8")
    digest = int.from_bytes(hashlib.sha256(key).digest()[:8], "big")
    # Keep the seed within a signed 32-bit range, which any RNG seeding API accepts.
    return (base + digest) % (2**31)


def _dump_atomic(obj, fn):
    """Gzip-pickle ``obj`` to ``fn`` through a temp file and an atomic rename.

    main() treats an existing output file as finished work. Writing straight to
    ``fn`` would leave a truncated file if the process were killed mid-write
    (e.g. by a container time limit), and later runs would then skip it.
    ``os.replace`` makes the final file appear all at once or not at all.
    """
    tmp = fn + ".tmp"
    with gzip.open(tmp, "wb") as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(tmp, fn)


def main(cell_idx):
    """Run the full scenario x method x trial grid for one cell and save it.

    Args:
        cell_idx: Index into ``CELLS`` selecting the cell to process.

    Side effects:
        Writes ``OUT/grid_<cell_tag>.pkl.gz`` (a gzip-compressed pickle of a
        list of row dicts) and prints progress to stdout. Does nothing except
        print a SKIP message if that file already exists, so re-running
        resumes an interrupted batch.

    Raises:
        IndexError: if ``cell_idx`` is out of range for ``CELLS``.
    """
    cell_tag, cell_name, path = CELLS[cell_idx]
    fn = os.path.join(OUT, f"grid_{cell_tag}.pkl.gz")
    if os.path.exists(fn):
        print(f"SKIP (already exists): {fn}")
        return
    print(f"=== {cell_tag} — {cell_name} ===")
    t0 = time.time()
    cell = calibrate_cell(path, cell_tag, cell_name)
    cfgs = scenario_configs(cell)
    rows = []
    # Only this subset of scenarios is run.
    for sname in ["S1", "S3", "S5"]:
        cfg = cfgs[sname]
        for method in METHODS:
            tm = time.time()
            for trial in range(N_TRIALS):
                # Deterministic per-(cell, scenario, method, trial) seed, stable across processes.
                seed = _stable_seed(BASE_SEED, cell_name, sname, method, trial)
                r = run_trial(cell, cfg, method, seed)
                rows.append(_result_to_dict(r, cell_tag, cell_name, sname, method, trial))
            print(f"  {sname}  {method:7s}  {N_TRIALS} trials  {time.time()-tm:4.1f}s")
    _dump_atomic(rows, fn)
    print(f"Saved {len(rows)} rows to {fn} ({os.path.getsize(fn)/1e6:.1f} MB) in {time.time()-t0:.1f}s")


if __name__ == "__main__":
    # The optional first command-line argument is the CELLS index; default to 0.
    cli_args = sys.argv[1:]
    main(int(cli_args[0]) if cli_args else 0)
