"""
Analyse the 960-trial grid:
  - per (cell, scenario, method): mean ± std of t80, ov_mV, rmse_V_mV, eff, dT_max
  - per (cell, scenario): paired t-tests K-R vs CC-CV and K-R vs K-only, Cohen's d
  - per (cell, scenario): KR-specific metrics
      eK autocorrelation at lag 1, |R| RMS, MSE and overshoot reduction (K-R vs K-only),
      and "orthogonality" (cosine alignment between eK and -R)
  - saves summary.csv, ttests.csv and kr_metrics.csv
"""
import os, gzip, pickle, csv
import numpy as np
from scipy import stats


# Results directory: grid.pkl.gz is read from here and summary.csv, ttests.csv
# and kr_metrics.csv are written here.
# NOTE(review): hard-coded absolute path — adjust when running elsewhere.
OUT = "/home/claude/kr_sim/results"


def load():
    """Unpickle and return the contents of ``OUT/grid.pkl.gz``.

    ``main`` treats the result as a sequence of per-trial dicts.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this
    at files produced by this pipeline.
    """
    grid_path = os.path.join(OUT, "grid.pkl.gz")
    with gzip.open(grid_path, "rb") as handle:
        payload = pickle.load(handle)
    return payload


def group_by(rows, keys):
    """Bucket ``rows`` by the tuple of their values at ``keys``.

    Returns a plain dict mapping ``tuple(row[k] for k in keys)`` to the list
    of rows sharing that key, in first-seen order. Single-key groupings still
    use 1-tuples as keys, e.g. ``("KR",)``.
    """
    buckets = {}
    for row in rows:
        bucket_key = tuple(row[k] for k in keys)
        buckets.setdefault(bucket_key, []).append(row)
    return buckets


def summarize(rows):
    """Return mean and sample std (ddof=1) of the scalar metrics over ``rows``.

    Output keys are ``<metric>_mean`` / ``<metric>_std`` for each of t80,
    ov_mV, rmse_V_mV, eff and dT_max, followed by ``n`` (number of rows).
    With a single row the std is reported as 0.0 instead of NaN.
    """
    metric_names = ("t80", "ov_mV", "rmse_V_mV", "eff", "dT_max")
    result = {}
    for name in metric_names:
        column = np.array([row[name] for row in rows])
        result[f"{name}_mean"] = float(np.mean(column))
        if len(column) > 1:
            spread = float(np.std(column, ddof=1))
        else:
            spread = 0.0
        result[f"{name}_std"] = spread
    result["n"] = len(rows)
    return result


def paired_t(a, b):
    """Paired t-test of ``a`` against ``b`` plus Cohen's d for paired samples.

    Returns ``(t, p, d)`` as Python floats, where ``d = mean(a - b) /
    std(a - b, ddof=1)`` (d_z). ``d`` is 0.0 when that std is not positive
    (including the NaN produced by fewer than two pairs).
    """
    x = np.asarray(a)
    y = np.asarray(b)
    delta = x - y
    t_stat, p_val = stats.ttest_rel(x, y)
    delta_sd = np.std(delta, ddof=1)
    effect = float(np.mean(delta) / delta_sd) if delta_sd > 0 else 0.0
    return float(t_stat), float(p_val), effect


def autocorr_lag1(x):
    """Return the lag-1 autocorrelation of ``x`` after dropping non-finite values.

    Uses the biased estimator ``sum(c[i] * c[i+1]) / sum(c**2)`` with ``c``
    the mean-removed series. Returns 0.0 when fewer than 10 finite samples
    remain or the series has zero variance.
    """
    series = np.asarray(x, dtype=np.float64)
    series = series[np.isfinite(series)]
    if len(series) < 10:
        return 0.0
    centred = series - series.mean()
    energy = np.sum(centred * centred)
    if energy == 0:
        return 0.0
    lagged = np.sum(centred[:-1] * centred[1:])
    return float(lagged / energy)


def _median_or_zero(values):
    """Return the median of ``values`` as a float, or 0.0 if it is empty."""
    return float(np.median(values)) if values else 0.0


def _mean_reduction(new, base):
    """Return ``1 - mean(new) / mean(base)``, or 0.0 unless ``mean(base) > 0``.

    This is a ratio of means across all trials, not a per-trial paired statistic.
    """
    base_mean = np.mean(base)
    if not base_mean > 0:  # also catches NaN
        return 0.0
    return 1.0 - float(np.mean(new) / base_mean)


def _ek_lag1_autocorr(k_rows):
    """Median lag-1 autocorrelation of the K-only residual ``eK`` across trials.

    Only the contiguous span from the first to the last sample with
    ``|eK| > 1e-6`` is used. This strips leading/trailing padding while keeping
    interior samples, so neighbouring entries stay one time step apart (a
    boolean mask would splice non-adjacent samples together and bias lag 1).
    Trials with 30 or fewer active samples are skipped.
    """
    values = []
    for r in k_rows:
        eK = np.asarray(r["eK"])  # assumes a 1-D time series — TODO confirm
        active = np.flatnonzero(np.abs(eK) > 1e-6)
        if active.size > 30:
            values.append(autocorr_lag1(eK[active[0]:active[-1] + 1]))
    return _median_or_zero(values)


def _correction_rms(kr_rows):
    """Median RMS of ``R_corr`` over samples with ``|R_corr| > 1e-9``.

    Trials with 30 or fewer such samples are skipped. Units are those of
    ``R_corr`` (treated as amperes by the caller).
    """
    rms = []
    for r in kr_rows:
        Rc = np.asarray(r["R_corr"])
        active = Rc[np.abs(Rc) > 1e-9]
        if active.size > 30:
            rms.append(float(np.sqrt(np.mean(active ** 2))))
    return _median_or_zero(rms)


def _correction_alignment(kr_rows):
    """Median cosine similarity between ``eK`` and ``-R_corr`` over KR trials.

    Samples where either signal is active (``|eK| > 1e-6`` or
    ``|R_corr| > 1e-9``) are compared. Trials with 30 or fewer such samples,
    or with a zero-norm vector, are skipped. Values near +1 mean ``R_corr``
    points opposite to ``eK``, i.e. it acts against the residual.
    """
    sims = []
    for r in kr_rows:
        eK = np.asarray(r["eK"], dtype=np.float64)
        Rc = np.asarray(r["R_corr"], dtype=np.float64)
        mask = (np.abs(eK) > 1e-6) | (np.abs(Rc) > 1e-9)
        if np.count_nonzero(mask) <= 30:
            continue
        a = eK[mask]
        b = -Rc[mask]
        na = np.linalg.norm(a)
        nb = np.linalg.norm(b)
        if na > 0 and nb > 0:
            sims.append(float(np.dot(a, b) / (na * nb)))
    return _median_or_zero(sims)


def kr_specific(cell_tag, cell_name, scen, rows_by_method):
    """Compute KR-specific diagnostics for one (cell, scenario) group.

    Args:
        cell_tag, cell_name, scen: identify the group. They are not used here;
            the caller attaches cell/scenario to the returned dict itself.
        rows_by_method: mapping of method name -> list of trial dicts. Only the
            "KR" and "K-only" entries are read.

    Returns:
        None if either "KR" or "K-only" has no trials. Otherwise a dict with:
          eK_autocorr_lag1 -- median lag-1 autocorrelation of the K-only eK;
          R_rms_mA         -- median RMS of R_corr over KR trials, x1000;
          mse_reduction    -- 1 - mean(KR rmse^2) / mean(K-only rmse^2);
          orthogonality    -- median cosine(eK, -R_corr) over KR trials
                              (despite the name, ~+1 means aligned);
          ov_reduction     -- 1 - mean(KR ov_mV) / mean(K-only ov_mV).
    """
    kr_rows = rows_by_method.get("KR", [])
    k_rows = rows_by_method.get("K-only", [])
    if not kr_rows or not k_rows:
        return None

    # Tracking MSE in mV^2, derived from each trial's RMSE.
    mse_K = np.array([r["rmse_V_mV"] ** 2 for r in k_rows])
    mse_KR = np.array([r["rmse_V_mV"] ** 2 for r in kr_rows])
    ov_K = np.array([r["ov_mV"] for r in k_rows])
    ov_KR = np.array([r["ov_mV"] for r in kr_rows])

    return {
        "eK_autocorr_lag1": _ek_lag1_autocorr(k_rows),
        "R_rms_mA":         _correction_rms(kr_rows) * 1000.0,
        "mse_reduction":    _mean_reduction(mse_KR, mse_K),
        "orthogonality":    _correction_alignment(kr_rows),
        "ov_reduction":     _mean_reduction(ov_KR, ov_K),
    }


# ─────────────────────────────────────────────────────────────────────────────
_SUMMARY_FIELDS = ["cell", "scenario", "method", "n",
                   "t80_mean", "t80_std",
                   "ov_mV_mean", "ov_mV_std",
                   "rmse_V_mV_mean", "rmse_V_mV_std",
                   "eff_mean", "eff_std",
                   "dT_max_mean", "dT_max_std"]
_KR_FIELDS = ["cell", "scenario",
              "eK_autocorr_lag1", "R_rms_mA",
              "mse_reduction", "orthogonality",
              "ov_reduction"]
# (trial-dict key, label written to ttests.csv) for each paired-tested metric.
_TTEST_METRICS = [("t80", "t80"), ("ov_mV", "OV"),
                  ("rmse_V_mV", "RMSE_V"), ("eff", "eff")]
# Reference methods KR is compared against.
_TTEST_REFS = ["CC-CV", "K-only"]


def _write_csv(path, fieldnames, rows):
    """Write ``rows`` with a header to ``path``; leave the file empty if ``fieldnames`` is None."""
    with open(path, "w", newline="") as f:
        if fieldnames is not None:
            w = csv.DictWriter(f, fieldnames=fieldnames)
            w.writeheader()
            w.writerows(rows)
    print(f"wrote {path}")


def _summary_rows(rows):
    """Return one summarize() dict per (cell_tag, scenario, method), sorted by that key."""
    out = []
    groups = group_by(rows, ["cell_tag", "scenario", "method"])
    for (cell, scen, method), trial_rows in sorted(groups.items()):
        s = summarize(trial_rows)
        s.update({"cell": cell, "scenario": scen, "method": method})
        out.append(s)
    return out


def _significance(p):
    """Return the star code for a p-value; a NaN p compares False and yields 'ns'."""
    if p < 0.001:
        return "***"
    if p < 0.01:
        return "**"
    if p < 0.05:
        return "*"
    return "ns"


def _ttest_rows(groups):
    """Paired t-tests of KR against each reference method, per (cell, scenario).

    Groups lacking KR or a reference method are skipped for that comparison,
    as are comparisons where the two methods do not share exactly the same
    trial indices (or have fewer than two trials). Matching lengths alone
    is not enough, because mismatched trial indices would pair unrelated trials.
    """
    out = []
    for (cell, scen), trial_rows in sorted(groups.items()):
        # method -> its trials ordered by trial index, so methods line up pairwise
        by_method = {k[0]: sorted(v, key=lambda r: r["trial"])
                     for k, v in group_by(trial_rows, ["method"]).items()}
        kr_trials = by_method.get("KR")
        if kr_trials is None or len(kr_trials) < 2:
            continue
        kr_ids = [r["trial"] for r in kr_trials]
        for key, label in _TTEST_METRICS:
            a = np.array([r[key] for r in kr_trials])
            for ref in _TTEST_REFS:
                ref_trials = by_method.get(ref)
                if ref_trials is None or [r["trial"] for r in ref_trials] != kr_ids:
                    continue
                b = np.array([r[key] for r in ref_trials])
                t, p, d = paired_t(a, b)
                ref_mean = float(np.mean(b))
                # negative = KR lower (better for t80, OV, RMSE)
                mean_diff = float(np.mean(a) - ref_mean)
                # A relative change against a zero reference mean is undefined.
                rel = 100.0 * mean_diff / ref_mean if ref_mean != 0 else float("nan")
                out.append({
                    "cell": cell, "scenario": scen, "metric": label,
                    "comparison": f"KR - {ref}",
                    "mean_diff": mean_diff,
                    "rel_pct":   rel,
                    "t":         t,
                    "p":         p,
                    "cohens_d":  d,
                    "sig":       _significance(p),
                })
    return out


def _kr_metric_rows(groups):
    """Return kr_specific() results (tagged with cell/scenario) for groups that have them."""
    out = []
    for (cell, scen), trial_rows in sorted(groups.items()):
        by_method = {k[0]: v for k, v in group_by(trial_rows, ["method"]).items()}
        km = kr_specific(cell, "", scen, by_method)
        if km is not None:
            km.update({"cell": cell, "scenario": scen})
            out.append(km)
    return out


def _print_headlines(groups, kr_rows):
    """Print the KR vs K-only overshoot table and the KR-specific metric table."""
    print("\n" + "=" * 80)
    print("HEADLINE: K-R vs K-only, overshoot (mV)")
    print("=" * 80)
    for (cell, scen), trial_rows in sorted(groups.items()):
        by_m = group_by(trial_rows, ["method"])
        if ("K-only",) not in by_m or ("KR",) not in by_m:
            continue  # nothing to compare in this group
        ko = np.array([r["ov_mV"] for r in by_m[("K-only",)]])
        kr = np.array([r["ov_mV"] for r in by_m[("KR",)]])
        red = 100.0 * (1 - np.mean(kr) / np.mean(ko)) if np.mean(ko) > 0 else 0
        # NOTE: population std (ddof=0) here, unlike summary.csv (ddof=1).
        print(f"  {cell:22s}  {scen}  K-only={np.mean(ko):6.1f}±{np.std(ko):4.1f}  "
              f"KR={np.mean(kr):6.1f}±{np.std(kr):4.1f}  reduction={red:5.1f}%")

    print("\n" + "=" * 80)
    print("HEADLINE: KR-specific metrics")
    print("=" * 80)
    print(f"  {'Cell':22s}  {'Scen':3s}  {'eK_AC':>7s}  {'|R|_RMS':>9s}  "
          f"{'MSE↓':>7s}  {'Orth':>6s}  {'OV↓':>7s}")
    for row in kr_rows:
        print(f"  {row['cell']:22s}  {row['scenario']:3s}  "
              f"{row['eK_autocorr_lag1']:7.3f}  "
              f"{row['R_rms_mA']:7.1f}mA  "
              f"{row['mse_reduction']*100:+6.1f}%  "
              f"{row['orthogonality']:+6.3f}  "
              f"{row['ov_reduction']*100:+6.1f}%")


# ─────────────────────────────────────────────────────────────────────────────
def main():
    """Run the full analysis: load the grid, write the three CSVs, print headlines.

    Writes summary.csv, ttests.csv and kr_metrics.csv into OUT. ttests.csv and
    kr_metrics.csv are left empty (no header) when there is nothing to report.
    """
    rows = load()
    print(f"Loaded {len(rows)} trials")

    # 1) Per (cell, scenario, method) summary
    _write_csv(os.path.join(OUT, "summary.csv"), _SUMMARY_FIELDS, _summary_rows(rows))

    by_cell_scen = group_by(rows, ["cell_tag", "scenario"])

    # 2) Paired t-tests: KR vs CC-CV and KR vs K-only
    ttest_rows = _ttest_rows(by_cell_scen)
    _write_csv(os.path.join(OUT, "ttests.csv"),
               list(ttest_rows[0].keys()) if ttest_rows else None,
               ttest_rows)

    # 3) KR-specific metrics
    kr_rows = _kr_metric_rows(by_cell_scen)
    _write_csv(os.path.join(OUT, "kr_metrics.csv"),
               _KR_FIELDS if kr_rows else None,
               kr_rows)

    _print_headlines(by_cell_scen, kr_rows)


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
