"""
Battery plant model (ground truth) and controller implementations.

Plant:  2-RC Thevenin ECM + lumped thermal (matching §3 and §11 of the
        simulation parameters document). Parameters calibrated from NASA
        Randomized Battery Usage reference discharges.

Controllers (§4, §9, §16):
  - CC-CV         (classical baseline)
  - MPC           (simplified one-step closed-form MPC on current with a
                   ΔI rate limit; stands in for a receding-horizon QP)
  - K-only        (ECM inverse, ∆V = V_target − V_pred  → I*)
  - K-R           (K-only + nonlinear residual corrector,
                   closed-form ridge regression on a 16-unit fixed ReLU layer;
                   retrained every 10 s from a 60-sample sliding window)
"""
from __future__ import annotations
import numpy as np
from dataclasses import dataclass, field
from typing import Callable, Tuple, Dict


# ─────────────────────────────────────────────────────────────────────────────
# Plant
# ─────────────────────────────────────────────────────────────────────────────
@dataclass
class PlantParams:
    """Ground-truth parameters of the simulated cell (2-RC ECM + lumped thermal).

    Resistances are reference (25 °C, fresh-cell) values; `plant_step`
    rescales R_int_25C, R1 and R2 each step for the cell temperature
    (Arrhenius) and for SOH, and uses Q_nom_Ah·SOH as the effective capacity.
    """
    # electrical
    Q_nom_Ah: float          # fresh (SOH = 1) capacity (Ah)
    ocv_table: np.ndarray    # (M,2)  SOC∈[0,1] → OCV (V); columns used by np.interp (SOC column presumably ascending)
    R_int_25C: float         # ohmic internal resistance at 25 °C (Ω), fresh
    R1:  float = 0.006       # RC1 resistance (Ω) at 25 °C, fresh
    C1:  float = 3500.0      # RC1 capacitance (F)
    R2:  float = 0.008       # RC2 resistance (Ω) at 25 °C, fresh
    C2:  float = 60000.0     # RC2 capacitance (F)
    # thermal
    mass_g: float = 44.0     # 18650 cell mass (g); converted to kg in plant_step
    Cp:     float = 1020.0   # specific heat capacity, J/(kg·K)
    hA:     float = 0.15     # effective h·A (W/K) — forced convection cylindrical 18650
    T_amb:  float = 25.0     # default ambient temperature (°C) when plant_step gets T_amb=None
    # aging
    SOH: float = 1.0         # state-of-health ∈ (0,1]; scales capacity down and resistances up


def arrhenius_R(R_ref: float, T_C: float, Ea_J: float = 25000.0) -> float:
    """Scale a 25 °C reference resistance to cell temperature T_C (°C).

    R(T) = R_ref · exp[(Ea / R_gas) · (1/T − 1/T_ref)], T in kelvin,
    T_ref = 298.15 K (25 °C), R_gas = 8.314 J/(mol·K).  For Ea_J > 0 the
    resistance grows as the cell cools and shrinks as it warms.
    """
    kelvin = T_C + 273.15
    exponent = (Ea_J / 8.314) * (1.0 / kelvin - 1.0 / 298.15)
    return R_ref * float(np.exp(exponent))


def soh_scale_R(R_ref: float, SOH: float) -> float:
    """Grow a resistance linearly with capacity fade (§12.1 aging model).

    R = R_ref · (1 + 0.8·(1 − SOH)); a fresh cell (SOH = 1) returns R_ref.
    """
    growth = 1.0 + 0.8 * (1.0 - SOH)
    return R_ref * growth


@dataclass
class PlantState:
    """Dynamic state of the ground-truth plant; `plant_step` returns a new one each step."""
    SOC: float               # state of charge (plant_step clips it to [0, 1])
    V_RC1: float = 0.0       # voltage across RC pair 1 (V), added to the terminal voltage
    V_RC2: float = 0.0       # voltage across RC pair 2 (V), added to the terminal voltage
    T_C: float = 25.0        # lumped cell temperature (°C)


def ocv_of(p: PlantParams, SOC: float) -> float:
    """Return the open-circuit voltage (V) at SOC from ``p.ocv_table``.

    SOC is clamped to [0, 1] and the table's SOC/OCV columns are linearly
    interpolated.
    """
    soc_knots = p.ocv_table[:, 0]
    volt_knots = p.ocv_table[:, 1]
    soc_clamped = float(np.clip(SOC, 0.0, 1.0))
    return float(np.interp(soc_clamped, soc_knots, volt_knots))


def plant_step(p: PlantParams, x: PlantState, I: float, dt: float,
               T_amb: float | None = None) -> Tuple[PlantState, float]:
    """
    Advance the ground-truth 2-RC Thevenin + lumped-thermal cell by one step.

    Convention:  I > 0 is CHARGE current (into the cell).

    Args:
        p:     Plant parameters. R_int, R1 and R2 are rescaled each step for
               the start-of-step cell temperature (Arrhenius) and for SOH.
        x:     State at the start of the step (not modified).
        I:     Applied current (A), held constant over the step.
        dt:    Step length (s).
        T_amb: Ambient temperature (°C); defaults to ``p.T_amb``.

    Returns:
        (new_state, V_terminal), where V_terminal uses the updated RC
        voltages and the OCV at the mid-step SOC.
    """
    if T_amb is None:
        T_amb = p.T_amb

    # Temperature- (at the start-of-step temperature) and SOH-adjusted parameters.
    R_int = soh_scale_R(arrhenius_R(p.R_int_25C, x.T_C), p.SOH)
    R1    = soh_scale_R(arrhenius_R(p.R1,        x.T_C), p.SOH)
    R2    = soh_scale_R(arrhenius_R(p.R2,        x.T_C), p.SOH)

    # RC voltage dynamics: implicit (backward) Euler on dV/dt = -V/tau + I/C,
    # which is unconditionally stable for any dt > 0.
    tau1 = R1 * p.C1
    tau2 = R2 * p.C2
    V_RC1_new = (x.V_RC1 + (I / p.C1) * dt) / (1.0 + dt / tau1)
    V_RC2_new = (x.V_RC2 + (I / p.C2) * dt) / (1.0 + dt / tau2)

    # SOC update (Coulomb counting; I>0 charges, capacity scaled by SOH).
    Q_eff = p.Q_nom_Ah * p.SOH
    SOC_new = x.SOC + I * dt / (Q_eff * 3600.0)
    SOC_new = float(np.clip(SOC_new, 0.0, 1.0))

    # Terminal voltage: OCV at mid-step SOC plus ohmic and both RC overpotentials.
    OCV = ocv_of(p, 0.5 * (x.SOC + SOC_new))
    V_t = OCV + I * R_int + V_RC1_new + V_RC2_new

    # Lumped thermal (explicit Euler on the start-of-step temperature).
    # Heat source is the irreversible term I·(V_t − OCV) = I²·R_int + I·(V_RC1 + V_RC2),
    # i.e. ohmic plus RC polarization heat; the reversible (entropic) term is neglected.
    q_gen = I * I * R_int + I * (V_RC1_new + V_RC2_new)
    m_kg = p.mass_g / 1000.0
    dT = (q_gen - p.hA * (x.T_C - T_amb)) / (m_kg * p.Cp) * dt
    T_new = x.T_C + dT

    return PlantState(SOC=SOC_new, V_RC1=V_RC1_new, V_RC2=V_RC2_new, T_C=T_new), V_t


# ─────────────────────────────────────────────────────────────────────────────
# Model observer (controller's internal ECM state — separate from the plant)
# ─────────────────────────────────────────────────────────────────────────────
@dataclass
class ECMObserver:
    """
    Controller-side ECM state estimate, kept separate from the plant.

    It has the same 2-RC Thevenin structure as the plant but is driven by the
    controller's own (possibly mismatched) parameters. It supplies the model
    prediction used to form the residual:

        V_model(t | t-1) = OCV(SOC_obs) + I_prev·R_int_m + V1_obs + V2_obs
        eK(t)            = V_meas(t) − V_model(t | t-1)

    Unlike the plant, it applies no temperature/SOH rescaling and tracks no
    temperature.
    """
    SOC: float
    V1: float = 0.0
    V2: float = 0.0
    # controller's believed parameters (may differ from plant by `mismatch`)
    R_int_m: float = 0.08
    R1_m:    float = 0.006
    C1_m:    float = 3500.0
    R2_m:    float = 0.008
    C2_m:    float = 60000.0
    Q_nom_Ah: float = 2.0

    @staticmethod
    def _rc_backward_euler(v: float, current: float, res: float, cap: float,
                           dt: float) -> float:
        """One implicit-Euler step of dv/dt = −v/(res·cap) + current/cap."""
        return (v + (current / cap) * dt) / (1.0 + dt / (res * cap))

    def predict_V(self, I_prev: float, ocv_table: np.ndarray) -> float:
        """Terminal voltage the observer expects for current I_prev given its
        present SOC and RC voltages (no state change)."""
        soc_axis = ocv_table[:, 0]
        ocv_axis = ocv_table[:, 1]
        ocv_now = float(np.interp(self.SOC, soc_axis, ocv_axis))
        ohmic = I_prev * self.R_int_m
        return ocv_now + ohmic + self.V1 + self.V2

    def propagate(self, I_applied: float, dt: float) -> None:
        """Advance V1, V2 and SOC (clipped to [0, 1]) by dt under I_applied."""
        self.V1 = self._rc_backward_euler(self.V1, I_applied, self.R1_m, self.C1_m, dt)
        self.V2 = self._rc_backward_euler(self.V2, I_applied, self.R2_m, self.C2_m, dt)
        charge_frac = I_applied * dt / (self.Q_nom_Ah * 3600.0)
        self.SOC = float(np.clip(self.SOC + charge_frac, 0.0, 1.0))


def _clip_current(I: float, I_max: float) -> float:
    return float(np.clip(I, 0.0, I_max))


# ---- 1. CC-CV ---------------------------------------------------------------
def controller_cccv(V_target: float, I_target: float,
                    I_meas: float, V_meas: float, T_meas: float,
                    SOC_est: float, SOH: float,
                    I_max: float, cv_cutoff_ratio: float = 0.05) -> float:
    """Classical CC-CV baseline (stateless, decided from V_meas alone).

    CC phase: while V_meas < V_target − 5 mV, command I_target.
    CV phase: otherwise, command a proportional (P-only) current
        I = kp · (V_target − V_meas),  kp = max(I_target, 1 A) / 0.1 V,
    so a 100 mV error maps to I_target (for I_target ≥ 1 A). An overshoot
    (V_meas > V_target) therefore commands 0 A after saturation.
    The result is always saturated to [0, I_max].

    NOTE(review): the controller keeps no phase memory, does not enforce a
    monotonic (non-increasing) CV taper, and has no termination current —
    `cv_cutoff_ratio` is accepted for interface compatibility but is
    currently unused. Because the CV band starts only 5 mV below V_target,
    the first CV command is at most 5 % of max(I_target, 1 A); with a
    non-negligible cell resistance this can drop V back below the CC
    threshold and make the controller alternate between phases. Confirm that
    this is the intended benchmark behaviour.

    I_meas, T_meas, SOC_est and SOH are unused here.
    """
    cv_band = 0.005  # V below V_target at which CC hands over to CV
    if V_meas < V_target - cv_band:
        return _clip_current(I_target, I_max)

    # CV phase: P-only law on the voltage error.
    err = V_target - V_meas
    # Gain chosen so that ~100 mV of error corresponds to the full CC current.
    kp = max(I_target, 1.0) / 0.10
    return _clip_current(kp * err, I_max)


# ---- 2. MPC (short-horizon) -------------------------------------------------
def controller_mpc(V_target: float, I_target: float,
                   I_meas: float, V_meas: float, T_meas: float,
                   SOC_est: float, SOH: float,
                   I_max: float, model_R: float,
                   dI_max: float = 0.5) -> float:
    """
    Simplified one-step MPC benchmark on current.

    With a purely resistive model V_next ≈ V_meas + model_R·ΔI (RC transients
    ignored), pick the current step that places the next voltage at V_target:
        ΔI = (V_target − V_meas) / model_R
    then apply a HARD rate limit |ΔI| ≤ dI_max around I_meas and saturate to
    [0, I_max]. This is a closed-form stand-in for a receding-horizon QP; the
    |ΔI| limit is a constraint, not a cost-function penalty.

    Args:
        model_R: controller's impedance estimate (Ω); floored at 1 mΩ to avoid
                 an unbounded gain.
        dI_max:  maximum change in current per call (A). The default 0.5
                 equals 0.5 A/s only if the controller is called once per
                 second.
        I_target, T_meas, SOC_est, SOH: unused here.

    Returns:
        Commanded current (A) in [0, I_max].

    Raises:
        ValueError: if dI_max is negative.
    """
    if dI_max < 0.0:
        raise ValueError("dI_max must be non-negative")

    # V_next ≈ V_meas + model_R · ΔI  →  ΔI = (V_target − V_meas)/model_R
    err = V_target - V_meas
    I_cmd = I_meas + err / max(model_R, 1e-3)
    # Hard per-call rate limit around the measured current.
    I_cmd = max(I_meas - dI_max, min(I_meas + dI_max, I_cmd))
    return _clip_current(I_cmd, I_max)


# ---- 3. K-only (ECM inverse) ------------------------------------------------
def controller_K(V_target: float, I_target: float,
                 I_meas: float, V_meas: float, T_meas: float,
                 SOC_est: float, SOH: float,
                 I_max: float,
                 ocv_table: np.ndarray,
                 model_Rtot: float,
                 obs: "ECMObserver | None" = None) -> Tuple[float, float, float]:
    """
    K step only: invert the model impedance for the current that puts the
    terminal voltage at V_target.

    Without an observer:  I = (V_target − OCV) / R_tot
    With an observer:     I = (V_target − OCV − V1 − V2) / R_int_m
    (the linearized inverse that matches the residual's model). The divisor is
    floored at 1 mΩ and the current is saturated to [0, I_max].

    V_pred is the model's predicted terminal voltage for the current actually
    commanded (after saturation), so it equals V_target only when the inverse
    is not saturated.

    Returns (I_cmd, V_pred, OCV_used). I_target, I_meas, V_meas, T_meas and
    SOH are unused here.
    """
    OCV = float(np.interp(SOC_est, ocv_table[:, 0], ocv_table[:, 1]))
    if obs is not None:
        # Linearized inverse: what I holds V_term = V_target given current RC state?
        I_raw = (V_target - OCV - obs.V1 - obs.V2) / max(obs.R_int_m, 1e-3)
    else:
        I_raw = (V_target - OCV) / max(model_Rtot, 1e-3)
    I_cmd = _clip_current(I_raw, I_max)

    # Predict with the saturated command (previously the unclipped value was
    # used, which made V_pred trivially equal to V_target).
    if obs is not None:
        V_pred = OCV + I_cmd * obs.R_int_m + obs.V1 + obs.V2
    else:
        V_pred = OCV + I_cmd * model_Rtot
    return I_cmd, V_pred, OCV


# ---- 4. K-R (K + residual learner) ------------------------------------------
class KR_Residual:
    """
    Nonlinear residual corrector as described in §4.2 / §16.
      features f ∈ R^8 : [eK, eK², T, SOC, eK·SOC, SOH, T·eK, ΔT]
                         (ΔT = temperature change since the previous
                          `features` call; it equals dT/dt only for a 1 s step)
      hidden  h = ReLU(W1ᵀ f)  with W1 ∈ R^(d×hidden) Xavier-initialised, FIXED
      output  R = clip(h · w2, ±R_max_A)                 (w2 trained online)
      training: closed-form ridge regression on the last `buf_N` samples, at
                most once every `retrain_every_s` seconds, once at least
                min(20, buf_N) samples are buffered.
    """

    # Preferred minimum number of buffered samples before a ridge fit. Capped
    # at buf_N so that a small buffer can still train.
    _MIN_TRAIN_SAMPLES = 20

    def __init__(self, rng: np.random.Generator, hidden: int = 16,
                 d_f: int = 8, lam: float = 1e-4,
                 buf_N: int = 60, retrain_every_s: float = 10.0,
                 R_max_A: float = 0.5,
                 preset_W1: np.ndarray | None = None,
                 preset_w2: np.ndarray | None = None,
                 freeze_w2: bool = False):
        """
        Args:
            rng:             source for the Xavier draw of W1 (unused if preset_W1).
            hidden:          hidden width when W1 is drawn randomly.
            d_f:             feature dimension when W1 is drawn randomly.
            lam:             ridge regularization λ.
            buf_N:           sliding-buffer length (samples).
            retrain_every_s: minimum time between refits (same units as t).
            R_max_A:         symmetric clip for targets and predictions (A).
            preset_W1/w2:    optional fixed weights (copied).
            freeze_w2:       if True, never record samples or refit w2.
        """
        self.d = d_f
        self.lam = lam
        self.buf_N = buf_N
        self.retrain_every_s = retrain_every_s
        self.R_max = R_max_A
        self.freeze_w2 = freeze_w2
        if preset_W1 is not None:
            self.W1 = preset_W1.copy()
        else:
            # Xavier: U(−√(6/(d+h)), √(6/(d+h)))
            a = float(np.sqrt(6.0 / (d_f + hidden)))
            self.W1 = rng.uniform(-a, a, size=(d_f, hidden))
        # Hidden width follows W1 so that a preset W1 of a different width stays
        # consistent with w2 and with the ridge system in the trainer.
        self.h = int(self.W1.shape[1])
        self.w2 = preset_w2.copy() if preset_w2 is not None else np.zeros(self.h)
        self.buffer = []          # list of (f, target_I_correction), oldest first
        self.t_last_train = -np.inf
        self.dT_prev = 0.0        # not read anywhere in this class
        self.T_prev = None        # temperature seen by the previous features() call

    def features(self, eK: float, T: float, SOC: float, SOH: float, t: float) -> np.ndarray:
        """Build the 8-element feature vector.

        Side effect: remembers T for the next call's ΔT term (ΔT = 0 on the
        first call). `t` is accepted but not used.
        """
        if self.T_prev is None:
            dT_dt = 0.0
        else:
            dT_dt = (T - self.T_prev)
        self.T_prev = T
        return np.array([eK, eK*eK, T, SOC, eK*SOC, SOH, T*eK, dT_dt], dtype=np.float64)

    def predict(self, f: np.ndarray) -> float:
        """Return the correction h·w2 for feature vector f, clipped to ±R_max."""
        h = np.maximum(self.W1.T @ f, 0.0)       # ReLU, shape (hidden,)
        r = float(h @ self.w2)
        return float(np.clip(r, -self.R_max, self.R_max))

    def record_and_maybe_train(self, f: np.ndarray, eK: float, t: float,
                               model_Rint: float) -> None:
        """
        Record sample; retrain w2 every retrain_every_s seconds.
        Target: the current correction that would have erased the prior-step eK,
        i.e. ΔI = −eK / R_int  (first-order linearization), clipped to ±R_max.
        A refit needs at least min(20, buf_N) buffered samples.
        No-op when freeze_w2 is set.
        """
        if self.freeze_w2:
            return
        target = -eK / max(model_Rint, 1e-3)
        target = float(np.clip(target, -self.R_max, self.R_max))
        self.buffer.append((f.copy(), target))
        # Drop the oldest samples beyond buf_N (also correct for buf_N == 0).
        excess = len(self.buffer) - self.buf_N
        if excess > 0:
            del self.buffer[:excess]

        min_needed = max(1, min(self._MIN_TRAIN_SAMPLES, self.buf_N))
        if (t - self.t_last_train) >= self.retrain_every_s and len(self.buffer) >= min_needed:
            F = np.stack([sample for sample, _ in self.buffer])   # (N, d)
            y = np.array([label for _, label in self.buffer])      # (N,)
            H = np.maximum(F @ self.W1, 0.0)                      # (N, hidden)
            # Ridge regression: w2 = (HᵀH + λI)⁻¹ Hᵀy
            A = H.T @ H + self.lam * np.eye(self.h)
            b = H.T @ y
            try:
                self.w2 = np.linalg.solve(A, b)
            except np.linalg.LinAlgError:
                self.w2 = np.linalg.lstsq(A, b, rcond=None)[0]
            self.t_last_train = t


def controller_KR(V_target: float, I_target: float,
                  I_meas: float, V_meas: float, T_meas: float,
                  SOC_est: float, SOH: float,
                  I_max: float,
                  ocv_table: np.ndarray,
                  model_Rtot: float,
                  kr: KR_Residual,
                  t: float,
                  I_prev: float,
                  obs: "ECMObserver") -> Tuple[float, float, float, float]:
    """
    K-R step:  I_final = clip(I_K + R_correction, [0, I_max]).

    1. Residual eK = V_meas − V_model, where V_model comes from
       `obs.predict_V(I_prev, ocv_table)`:

           V_model = OCV(SOC_obs) + I_prev · R_int_m + V1_obs + V2_obs

       This is a one-step-ahead prediction that includes the ohmic term and
       both RC polarization states. Keeping the RC states in the model stops
       polarization dynamics from leaking into eK, so the learned correction
       is not asked to absorb them alongside the aging/temperature mismatch.
       The prediction is causal only if the caller propagates `obs` with the
       applied current after this call (presumably the case — verify against
       the simulation loop).
    2. I_K from `controller_K` at SOC_est, using the observer's RC state.
    3. R_correction = kr.predict(kr.features(eK, T_meas, SOC_est, SOH, t)).
    4. The sample is passed to `kr.record_and_maybe_train` (may refit w2).

    Returns (I_cmd, eK, R_correction, V_model_prediction).
    """
    # 1) Causal, RC-aware model prediction and its residual.
    V_model = obs.predict_V(I_prev, ocv_table)
    residual = V_meas - V_model

    # 2) Model-inverse (K) command at the current SOC estimate.
    k_out = controller_K(V_target, I_target, I_meas, V_meas, T_meas,
                         SOC_est, SOH, I_max, ocv_table, model_Rtot, obs=obs)
    I_model = k_out[0]

    # 3) Learned residual (R) correction, then re-saturate.
    feat = kr.features(residual, T_meas, SOC_est, SOH, t)
    correction = kr.predict(feat)
    I_total = _clip_current(I_model + correction, I_max)

    # 4) Online training (no-op when w2 is frozen).
    kr.record_and_maybe_train(feat, residual, t, obs.R_int_m)

    return I_total, residual, correction, V_model
