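"""
Single-cycle charging simulation and Monte Carlo harness.

Each trial:
  - initialises the plant at SOC = SOC_start (0.10 by default) with
    cell-calibrated parameters
  - runs the controller every dt seconds (1 s by default, i.e. 1 Hz) for up to
    sim_time_s (2800 s by default; the built-in scenarios use 3600-3800 s),
    stopping early once the charge is in CV taper (SOC > 0.99 · SOC_target,
    |I| < 5 % of I_cc, and measured V within 10 mV of V_max)
  - injects Gaussian sensor noise on V (2 mV), I (50 mA) and T (0.5 °C)
  - computes the reported metrics from the closed-loop traces: time to
    SOC_target (t80), voltage overshoot, voltage-tracking RMSE, energy
    efficiency and peak temperature rise
"""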
from __future__ import annotations

import hashlib
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Callable

import numpy as np

from plant_controllers import (PlantParams, PlantState, plant_step, ocv_of,
                               arrhenius_R, soh_scale_R,
                               ECMObserver,
                               controller_cccv, controller_mpc,
                               controller_K, controller_KR, KR_Residual)


# Standard deviations of the additive Gaussian sensor noise, one per channel.
SIGMA_V = 2e-3    # terminal voltage, V   (2 mV)
SIGMA_I = 5e-2    # current, A            (50 mA)
SIGMA_T = 0.5     # temperature, °C       (0.5 °C)


@dataclass
class SimConfig:
    """Configuration of one charging scenario (shared by every trial of it)."""
    scenario: str           # "S1" | "S3" | "S5"
    T_amb:    float         # ambient temperature (°C)
    SOH:      float         # state of health (e.g. 1.00 or 0.80 in the built-in scenarios)
    Imax:     float         # max charge current (A)
    I_cc:     float         # CC-CV target current (A)
    V_max:    float = 4.20  # voltage limit / CV setpoint (V)
    SOC_start: float = 0.10       # initial SOC (fraction)
    SOC_target: float = 0.80      # SOC for the t80 metric (time to reach 80 % SOC)
    sim_time_s: float = 2800.0    # simulation horizon (s)
    dt: float = 1.0               # control / integration step (s)
    # Controller-model mismatch: the K/KR controller's "nominal" ECM (and its
    # observer) use resistances of plant · (1 − mismatch), i.e. the controller
    # UNDER-estimates the true plant impedance by this fraction.
    mismatch: float = 0.15


@dataclass
class TrialResult:
    """Closed-loop traces and summary metrics of one simulated charge.

    All trace arrays have the same length: the number of simulated steps,
    which is shorter than the configured horizon when a trial stops early.
    """
    t:     np.ndarray          # sample time (s), k·dt
    V:     np.ndarray          # plant terminal V after each step (V)
    V_meas: np.ndarray         # noisy measured V seen by the controller (V)
    I:     np.ndarray          # applied current (A)
    SOC:   np.ndarray          # true plant SOC after each step
    T:     np.ndarray          # plant temperature (°C)
    eK:    np.ndarray          # K-step residual (V) — 0 for CC-CV / MPC
    R_corr: np.ndarray         # R-step correction (A) — 0 for non-KR
    V_pred_K: np.ndarray       # K-step predicted V (V) — 0 for CC-CV / MPC
    t80:   float               # time (s) to first reach SOC_target; sim_time_s if never reached
    ov:    float               # max overshoot above V_max, in mV (0 if none)
    rmse_V: float              # RMSE of plant V vs. the constant V_max reference, in mV
    eff:    float              # energy efficiency ∫OCV·I dt / ∫V·I dt (0 if no input energy)
    dT_max: float              # max temperature rise above T_amb (°C)
    method: str                # controller name: "CC-CV", "MPC", "K-only" or "KR"
    scenario: str              # scenario label, e.g. "S1"


def _v_target_profile(t: np.ndarray, V_max: float) -> np.ndarray:
    """Return the reference voltage used by the tracking-RMSE metric.

    The reference is deliberately simple: the CC-CV setpoint ``V_max`` held
    constant at every sample of ``t`` (no modelled CC ramp). The RMSE therefore
    also penalises the time spent below V_max during the CC phase.

    Parameters
    ----------
    t : array_like
        Sample times; only its shape is used.
    V_max : float
        Voltage setpoint (V).

    Returns
    -------
    np.ndarray
        Array shaped like ``t`` and filled with ``V_max``. A floating-point
        ``t`` keeps its dtype. An integer/bool ``t`` yields float64, so
        ``V_max`` is not truncated (``np.full_like`` on an int array would
        turn 4.2 into 4).
    """
    t = np.asarray(t)
    dtype = t.dtype if np.issubdtype(t.dtype, np.floating) else np.float64
    return np.full(t.shape, V_max, dtype=dtype)


def _init_kr_model_params(cell, cfg: SimConfig) -> Tuple[np.ndarray, float, Dict[str, float]]:
    """
    Build the controller-side (believed) model.

    Returns ``(ocv_table, R_model, info)``:
      * ``ocv_table`` – a private copy of ``cell.ocv_table``;
      * ``R_model``   – the STEADY-STATE impedance R_int + R1 + R2 that the
        K step inverts (per §1.2), scaled by ``(1 - cfg.mismatch)``. The
        controller therefore under-estimates the plant;
      * ``info``      – the plant's true values under ``"R_int_plant"`` and
        ``"R_total_plant"``.
    """
    soc_axis = cell.rint_vs_soc[:, 0]
    rint_axis = cell.rint_vs_soc[:, 1]
    r_int_mid = float(np.interp(0.5, soc_axis, rint_axis))  # R_int at SOC = 0.5
    r_steady_plant = r_int_mid + 0.006 + 0.008               # + R1 + R2 (Ω)

    believed_fraction = 1.0 - cfg.mismatch
    info = {"R_int_plant": r_int_mid, "R_total_plant": r_steady_plant}
    return cell.ocv_table.copy(), r_steady_plant * believed_fraction, info


def run_trial(cell, cfg: SimConfig, method: str, seed: int) -> TrialResult:
    """Run one full closed-loop charging simulation.

    Parameters
    ----------
    cell : calibrated cell object
        Must expose ``ocv_table``, a ``rint_vs_soc`` (SOC, R_int) table, and
        ``soh_baseline`` (used as the nominal capacity in Ah).
    cfg : SimConfig
        Scenario configuration (ambient temperature, SOH, current limits,
        horizon, step size and controller-model mismatch).
    method : str
        One of ``"CC-CV"``, ``"MPC"``, ``"K-only"`` or ``"KR"``.
    seed : int
        Seed of the trial's random generator. It drives the sensor noise, the
        SOC-estimate noise and, for ``"KR"``, the ``KR_Residual`` model.

    Returns
    -------
    TrialResult
        Closed-loop traces (truncated on early termination) and metrics.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    rng = np.random.default_rng(seed)

    # ---- Plant parameters (TRUTH). R_int is read from the calibration table
    # at SOC = 0.5. Temperature effects are applied by plant_step from x.T_C
    # (presumably via arrhenius_R — see plant_controllers).
    R_plant_25C = float(np.interp(0.5, cell.rint_vs_soc[:, 0],
                                       cell.rint_vs_soc[:, 1]))
    p = PlantParams(
        Q_nom_Ah  = cell.soh_baseline,
        ocv_table = cell.ocv_table,
        R_int_25C = R_plant_25C,
        T_amb     = cfg.T_amb,
        SOH       = cfg.SOH,
    )
    # Controller-side (believed) model: impedance under-estimated by `mismatch`.
    ocv_model, R_model, mdl = _init_kr_model_params(cell, cfg)

    # ---- Controller-internal observer (2-RC ECM with MISMATCHED parameters)
    # The observer uses the controller's BELIEFS. Its resistances are the
    # plant's structural values scaled by (1 - mismatch); the capacitances
    # are left at their nominal values.
    sc = (1.0 - cfg.mismatch)
    obs = ECMObserver(
        SOC      = cfg.SOC_start,
        V1       = 0.0,
        V2       = 0.0,
        R_int_m  = mdl["R_int_plant"] * sc,
        R1_m     = 0.006 * sc,
        C1_m     = 3500.0,
        R2_m     = 0.008 * sc,
        C2_m     = 60000.0,
        Q_nom_Ah = cell.soh_baseline,
    )

    # ---- Controller state (only KR carries a residual model; it shares rng)
    kr = KR_Residual(rng=rng) if method == "KR" else None

    # ---- Initial conditions
    x = PlantState(SOC=cfg.SOC_start, T_C=cfg.T_amb)

    N = int(cfg.sim_time_s / cfg.dt) + 1
    t_arr   = np.zeros(N)
    V_arr   = np.zeros(N)
    Vm_arr  = np.zeros(N)
    I_arr   = np.zeros(N)
    SOC_arr = np.zeros(N)
    T_arr   = np.zeros(N)
    eK_arr  = np.zeros(N)
    R_arr   = np.zeros(N)
    Vp_arr  = np.zeros(N)

    # The cell starts at rest, so the first "measured" voltage is the OCV.
    V_t = ocv_of(p, x.SOC)
    I_applied = 0.0
    I_prev_applied = 0.0   # previous step's applied current (for causal eK)
    n_used = N             # number of recorded samples (< N on early stop)

    for k in range(N):
        t = k * cfg.dt

        # Noisy measurements (BMS sensing). The draw order (V, I, T, SOC)
        # fixes the random stream for a given seed.
        V_meas = V_t + rng.normal(0, SIGMA_V)
        I_meas = I_applied + rng.normal(0, SIGMA_I)
        T_meas = x.T_C + rng.normal(0, SIGMA_T)
        # SOC estimate: true SOC plus white Gaussian noise (σ = 0.005), which
        # stands in for an imperfect Coulomb-counting estimator.
        SOC_est = x.SOC + rng.normal(0, 0.005)

        # Controller decides next current command
        if method == "CC-CV":
            I_cmd = controller_cccv(cfg.V_max, cfg.I_cc, I_meas, V_meas, T_meas,
                                    SOC_est, cfg.SOH, cfg.Imax)
            eK = 0.0
            R_corr = 0.0
            V_pred_K = 0.0
        elif method == "MPC":
            I_cmd = controller_mpc(cfg.V_max, cfg.I_cc, I_meas, V_meas, T_meas,
                                   SOC_est, cfg.SOH, cfg.Imax, R_model)
            eK = 0.0
            R_corr = 0.0
            V_pred_K = 0.0
        elif method == "K-only":
            # K-step inversion using full observer state (OCV + V1 + V2)
            I_cmd, _, _ = controller_K(cfg.V_max, cfg.I_cc, I_meas, V_meas,
                                       T_meas, SOC_est, cfg.SOH, cfg.Imax,
                                       ocv_model, R_model, obs=obs)
            # Causal RC-aware eK: V_meas vs. the observer's prediction under
            # the previous step's current (observer not yet propagated this step).
            V_pred_K = obs.predict_V(I_prev_applied, ocv_model)
            eK = V_meas - V_pred_K
            R_corr = 0.0
        elif method == "KR":
            I_cmd, eK, R_corr, V_pred_K = controller_KR(
                cfg.V_max, cfg.I_cc, I_meas, V_meas, T_meas,
                SOC_est, cfg.SOH, cfg.Imax, ocv_model, R_model, kr, t,
                I_prev=I_prev_applied, obs=obs)
        else:
            raise ValueError(method)

        # Apply current to plant (one step)
        I_applied = I_cmd
        x, V_t = plant_step(p, x, I_applied, cfg.dt, T_amb=cfg.T_amb)
        # Propagate observer's internal state using the SAME applied current
        obs.propagate(I_applied, cfg.dt)

        # Record. NOTE(review): index k stores the measurement taken at the
        # START of the step (V_meas, time t) next to the plant state at its
        # END (V_t, SOC, T at t + dt). t80 is therefore reported one step early.
        t_arr[k]   = t
        V_arr[k]   = V_t
        Vm_arr[k]  = V_meas
        I_arr[k]   = I_applied
        SOC_arr[k] = x.SOC
        T_arr[k]   = x.T_C
        eK_arr[k]  = eK
        R_arr[k]   = R_corr
        Vp_arr[k]  = V_pred_K

        # Update previous-step applied current for next iteration's causal eK
        # (observer state V1, V2, SOC is propagated inside obs.propagate())
        I_prev_applied = I_applied

        # Early termination in CV taper: SOC > 99 % of target, |I| < 5 % of
        # I_cc, and this step's measured V within 10 mV of V_max.
        if (x.SOC > cfg.SOC_target * 0.99 and abs(I_cmd) < 0.05 * cfg.I_cc
                and V_meas > cfg.V_max - 0.01):
            n_used = k + 1
            break

    # Keep only the recorded samples (no-op when the full horizon ran).
    t_arr, V_arr, Vm_arr = t_arr[:n_used], V_arr[:n_used], Vm_arr[:n_used]
    I_arr, SOC_arr, T_arr = I_arr[:n_used], SOC_arr[:n_used], T_arr[:n_used]
    eK_arr, R_arr, Vp_arr = eK_arr[:n_used], R_arr[:n_used], Vp_arr[:n_used]

    # ---- Metrics
    # t80: first time SOC ≥ SOC_target (sim_time_s if never reached)
    idx_80 = np.where(SOC_arr >= cfg.SOC_target)[0]
    t80 = float(t_arr[idx_80[0]]) if len(idx_80) > 0 else float(cfg.sim_time_s)

    # Overshoot beyond V_max, in mV (0 if V never exceeds V_max)
    ov = float(max(0.0, V_arr.max() - cfg.V_max) * 1000.0)

    # Voltage-tracking RMSE (mV) against the constant V_max reference,
    # evaluated over the whole recorded trace (including the CC rise).
    ref = _v_target_profile(t_arr, cfg.V_max)
    err = V_arr - ref
    rmse_V = float(np.sqrt(np.mean(err**2)) * 1000.0)

    # Energy efficiency: E_stored = ∫ OCV·I dt, E_input = ∫ V·I dt over the
    # whole trace. No I > 0 filter is applied; this assumes charge-only currents.
    # np.trapezoid is NumPy ≥ 2.0; fall back to np.trapz on older releases.
    _trapz = getattr(np, "trapezoid", None) or np.trapz
    ocv_tr = np.array([ocv_of(p, s) for s in SOC_arr])
    e_stored = float(_trapz(ocv_tr * I_arr, t_arr))
    e_input  = float(_trapz(V_arr * I_arr, t_arr))
    eff = (e_stored / e_input) if e_input > 1e-6 else 0.0

    dT_max = float(T_arr.max() - cfg.T_amb)

    return TrialResult(
        t=t_arr, V=V_arr, V_meas=Vm_arr, I=I_arr, SOC=SOC_arr, T=T_arr,
        eK=eK_arr, R_corr=R_arr, V_pred_K=Vp_arr,
        t80=t80, ov=ov, rmse_V=rmse_V, eff=eff, dT_max=dT_max,
        method=method, scenario=cfg.scenario,
    )


# ─────────────────────────────────────────────────────────────────────────────
# Monte Carlo driver
# ─────────────────────────────────────────────────────────────────────────────
METHODS = ["CC-CV", "MPC", "K-only", "KR"]


def _trial_seed(base_seed: int, cell_name, scenario: str, method: str, n: int) -> int:
    """Derive a per-trial seed that is reproducible across interpreter runs.

    Built-in ``hash()`` of ``str`` is salted per process (PYTHONHASHSEED), so
    seeds derived from it change from run to run. A SHA-256 digest of the
    trial identity is stable instead. The result lies in
    ``[base_seed, base_seed + 2**31)``.
    """
    key = repr((cell_name, scenario, method, n)).encode("utf-8")
    digest = hashlib.sha256(key).digest()
    return base_seed + int.from_bytes(digest[:8], "big") % (2**31)


def run_scenario(cell, cfg: SimConfig, n_trials: int = 20,
                 methods=METHODS, base_seed: int = 0) -> Dict[str, List[TrialResult]]:
    """Run ``n_trials`` Monte Carlo trials of each method on one scenario.

    Parameters
    ----------
    cell : calibrated cell object (its ``name`` enters the seed)
    cfg : SimConfig
        Scenario configuration (its ``scenario`` label enters the seed).
    n_trials : int
        Trials per method.
    methods : sequence of str
        Controller names passed to ``run_trial``.
    base_seed : int
        Offset added to every derived seed.

    Returns
    -------
    dict
        Maps each method name to its list of TrialResult, in trial order.
    """
    out: Dict[str, List[TrialResult]] = {m: [] for m in methods}
    for m in methods:
        for n in range(n_trials):
            seed = _trial_seed(base_seed, cell.name, cfg.scenario, m, n)
            out[m].append(run_trial(cell, cfg, m, seed))
    return out


def scenario_configs(cell) -> Dict[str, SimConfig]:
    """Build the three per-cell scenarios that match the results-report schema.

    Currents are C-rates of the cell's fresh capacity ``cell.soh_baseline``
    (1C ≈ Q_nom_Ah). A falsy capacity falls back to 2.0 Ah.

      S1: 25 °C, SOH 1.00, I_cc = 1C,   mismatch 0.15, 3600 s
      S3:  5 °C, SOH 1.00, I_cc = 0.5C, mismatch 0.25, 3800 s
      S5: 25 °C, SOH 0.80, I_cc = 1C,   mismatch 0.35, 3600 s

    Every scenario shares the absolute current ceiling Imax = 1.5C.
    """
    capacity_ah = cell.soh_baseline or 2.0
    one_c = capacity_ah
    current_ceiling = 1.5 * capacity_ah

    # (label, T_amb, SOH, I_cc, mismatch, horizon)
    specs = (
        ("S1", 25.0, 1.00, one_c,       0.15, 3600.0),
        ("S3",  5.0, 1.00, 0.5 * one_c, 0.25, 3800.0),
        ("S5", 25.0, 0.80, one_c,       0.35, 3600.0),
    )
    return {
        label: SimConfig(scenario=label, T_amb=t_amb, SOH=soh,
                         Imax=current_ceiling, I_cc=i_cc,
                         mismatch=mism, sim_time_s=horizon)
        for label, t_amb, soh, i_cc, mism, horizon in specs
    }


# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Quick sanity check: run one S1 trial per method on a fresh DS5 cell.
    # Usage: python <this file> [path/to/RWxx.mat]   (default: DS5 RW20)
    import sys
    import time
    from pathlib import Path

    from nasa_loader import calibrate_cell

    default_mat = Path(
        "/home/claude/work/ds5/RW_Skewed_High_Room_Temp_DataSet_2Post"
        "/data/Matlab/RW20.mat")
    mat_path = Path(sys.argv[1]) if len(sys.argv) > 1 else default_mat
    cell_id = mat_path.stem   # "RW20" for the default file

    print(f"Loading DS5 {cell_id} …")
    cell = calibrate_cell(str(mat_path), "DS5_Skewed_High_RT", cell_id)
    cfgs = scenario_configs(cell)
    cfg = cfgs["S1"]

    for m in METHODS:
        # perf_counter is monotonic and high-resolution, so it suits
        # interval timing better than time.time().
        t0 = time.perf_counter()
        r = run_trial(cell, cfg, m, seed=42)
        print(f"{m:7s}  t80={r.t80:6.1f}s  rmse_V={r.rmse_V:6.2f}mV  "
              f"ov={r.ov:6.2f}mV  eff={r.eff*100:5.2f}%  dT={r.dT_max:5.2f}°C  "
              f"(wall {time.perf_counter()-t0:.2f}s)")
