"""
NASA PCoE Randomized Battery Usage — loader and calibrator.

Extracts per-cell:
  - Reference discharge cycles → capacity (Ah) vs date → SOH trajectory
  - OCV–SOC table from the first reference discharge (constant-current part, IR-compensated)
  - Fresh-cell R_int from the initial voltage drop of the first reference discharges
  - Representative charge cycles (V, I, T, t) at selected SOH points

Ref:  NASA Prognostics Center of Excellence, Randomized Battery Usage Data,
      Bole, Kulkarni, Daigle (2014).
"""
from __future__ import annotations
import numpy as np
import scipy.io as sio
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class Cycle:
    """One logged step of a cell test: aligned time/voltage/current/temperature samples.

    The arrays are expected to have equal length; ``load_cell`` truncates them
    to a common length before constructing a ``Cycle``.
    """
    step_idx: int           # index of the step within the source step array
    comment: str            # step label, e.g. 'reference discharge'
    t: np.ndarray           # relative time (s)
    V: np.ndarray           # voltage (V)
    I: np.ndarray           # current (A) — NASA sign convention: discharge = +, charge = -
    T: np.ndarray           # temperature (°C)
    date: str               # timestamp string stored with the step
    day: float              # decimal days from start

    @property
    def duration(self) -> float:
        """Seconds between the first and last sample; 0.0 with fewer than two samples."""
        return float(self.t[-1] - self.t[0]) if len(self.t) > 1 else 0.0

    @property
    def capacity_Ah(self) -> float:
        """Charge throughput in Ah: trapezoidal integral of ``|I|`` over time.

        Returns 0.0 when there are fewer than two samples or the time span is
        not positive, so empty steps no longer raise ``IndexError``.
        """
        if len(self.t) < 2:
            return 0.0
        if (self.t[-1] - self.t[0]) <= 0:
            return 0.0
        # np.trapezoid only exists on NumPy >= 2.0; older releases call it np.trapz.
        integrate = getattr(np, "trapezoid", None) or np.trapz
        return float(integrate(np.abs(self.I), self.t) / 3600.0)


@dataclass
class CellData:
    """Indexed cycles and calibration results for a single cell.

    The three cycle lists are filled by ``load_cell``; the calibration
    outputs stay ``None`` until the matching calibration function has run.
    """
    name: str
    dataset_tag: str
    ref_discharges: List[Cycle] = field(default_factory=list)
    ref_charges: List[Cycle] = field(default_factory=list)
    rw_charges: List[Cycle] = field(default_factory=list)  # "charge (after random walk discharge)"

    # Calibration outputs (None until computed)
    soh_trajectory: np.ndarray | None = None  # (N, 2): [day, capacity Ah]; divide by soh_baseline for SOH
    soh_baseline: float | None = None         # fresh-cell capacity (Ah)
    ocv_table: np.ndarray | None = None       # (M, 2): [SOC 0..1, OCV V]
    rint_vs_soc: np.ndarray | None = None     # (M, 2): [SOC, R_int Ω] — fresh cell

    # NOTE: fit_ocv_soc and fit_rint_vs_soc additionally attach the private
    # attributes `_Rint0` and `_Rint_fresh` at runtime; they are not fields.


def _date_to_day(date_str: str, ref_str: str) -> float:
    """Return the time elapsed from *ref_str* to *date_str* in decimal days.

    Both strings are parsed with the pattern ``%d-%b-%Y %H:%M:%S``. The
    result is negative when *date_str* lies before *ref_str*.
    """
    from datetime import datetime

    stamp, anchor = (
        datetime.strptime(text, "%d-%b-%Y %H:%M:%S")
        for text in (date_str, ref_str)
    )
    elapsed = stamp - anchor
    return elapsed.total_seconds() / 86400.0


def load_cell(mat_path: str, dataset_tag: str, cell_name: str) -> CellData:
    """Load one NASA RW ``.mat`` file and index its reference / RW-charge steps.

    Args:
        mat_path: Path to the MATLAB file. It must hold a ``data`` struct with
            a ``step`` array whose entries expose ``voltage``, ``current``,
            ``relativeTime`` and ``date``; ``comment`` and ``temperature``
            are optional (temperature defaults to 25.0).
        dataset_tag: Dataset label stored on the returned object.
        cell_name: Cell identifier stored on the returned object.

    Returns:
        A ``CellData`` whose ``ref_discharges``, ``ref_charges`` and
        ``rw_charges`` lists hold the matching steps in file order. Steps
        with any other comment, or with fewer than 5 aligned samples, are
        dropped. Calibration outputs are left unset.
    """
    m = sio.loadmat(mat_path, struct_as_record=False, squeeze_me=True)
    # squeeze_me collapses a one-element struct array to a bare struct;
    # atleast_1d restores a uniform 1-D view so len()/indexing always work.
    steps = np.atleast_1d(m['data'].step)

    cell = CellData(name=cell_name, dataset_tag=dataset_tag)
    if steps.size == 0:
        return cell

    # anchor date = first step
    ref_date = str(steps[0].date)

    # Only these three step labels are kept; everything else is skipped
    # before any array conversion or date parsing is done.
    buckets = {
        'reference discharge': cell.ref_discharges,
        'reference charge': cell.ref_charges,
        'charge (after random walk discharge)': cell.rw_charges,
    }

    for k, s in enumerate(steps):
        comment = str(getattr(s, 'comment', ''))
        bucket = buckets.get(comment)
        if bucket is None:
            continue

        v = np.atleast_1d(s.voltage).astype(np.float64).ravel()
        i = np.atleast_1d(s.current).astype(np.float64).ravel()
        t = np.atleast_1d(s.relativeTime).astype(np.float64).ravel()
        if hasattr(s, 'temperature'):
            T = np.atleast_1d(s.temperature).astype(np.float64).ravel()
        else:
            T = np.full_like(v, 25.0)

        # Align lengths conservatively, then reject steps that are too short
        # *after* alignment so no channel can leave a near-empty cycle behind.
        L = min(len(v), len(i), len(t), len(T))
        if L < 5:
            continue

        date = str(s.date)
        bucket.append(Cycle(step_idx=k, comment=comment,
                            t=t[:L], V=v[:L], I=i[:L], T=T[:L],
                            date=date, day=_date_to_day(date, ref_date)))

    return cell


def calibrate_soh(cell: CellData, *,
                  plausible_Ah: Tuple[float, float] = (0.3, 3.0)) -> None:
    """Compute capacity (Ah) per reference discharge → capacity trajectory.

    Args:
        cell: Object whose ``ref_discharges`` provide ``day`` and
            ``capacity_Ah``. It is updated in place.
        plausible_Ah: Exclusive ``(low, high)`` capacity window; measurements
            outside it are discarded. The default suits a ~2 Ah cell.

    Side effects:
        ``cell.soh_trajectory`` becomes an ``(N, 2)`` array of
        ``[day, capacity Ah]`` rows sorted by day. It has shape ``(0, 2)``
        when every measurement was rejected. ``cell.soh_baseline`` becomes
        the mean of the first three capacities (or the first one when fewer
        than three exist) and is left untouched when no measurement survives.
        Nothing is written when ``cell.ref_discharges`` is empty.
    """
    if not cell.ref_discharges:
        return

    low, high = plausible_Ah
    measured = ((c.day, c.capacity_Ah) for c in cell.ref_discharges)
    kept = sorted((day, ah) for day, ah in measured if low < ah < high)

    # reshape keeps the documented (N, 2) layout even when nothing was kept
    pts = np.array(kept, dtype=np.float64).reshape(-1, 2)
    cell.soh_trajectory = pts

    # baseline = mean of first 3 capacity measurements (fresh cell)
    if len(pts) >= 3:
        cell.soh_baseline = float(np.mean(pts[:3, 1]))
    elif len(pts) > 0:
        cell.soh_baseline = float(pts[0, 1])


def fit_ocv_soc(cell: CellData, n_bins: int = 50) -> None:
    """Build an OCV–SOC table from the first (freshest) reference discharge.

    Model: terminal V = OCV − I·R_int with I > 0 on discharge, so
    OCV = V + I·R_int. SOC is obtained by integrating the current and
    normalising by the total charge removed in this same discharge, i.e. the
    cycle is assumed to run from ~100 % to ~0 % SOC
    (NOTE(review): relies on a full charge + rest beforehand — confirm).

    Args:
        cell: Object providing ``ref_discharges``; updated in place.
        n_bins: Number of equal-width SOC bins spanning 0.02–0.98. Empty bins
            are omitted from the table.

    Side effects:
        On success sets ``cell.ocv_table`` to an ``(M, 2)`` array of
        ``[SOC, OCV]`` rows, ascending in SOC and non-decreasing in OCV, and
        ``cell._Rint0`` to the resistance used for the IR compensation.
        Returns without writing anything when there is no reference
        discharge, it has fewer than 20 samples, no net charge was removed,
        or fewer than 5 bins are populated.
    """
    if not cell.ref_discharges or n_bins < 1:
        return
    c = cell.ref_discharges[0]
    V, I, t = c.V, c.I, c.t
    if len(V) < 20:
        return

    # Estimate R_int from a rest → load transition within the first samples;
    # fall back to a default when no such step is recorded in this cycle.
    Rint0 = 0.08  # default LCO 18650 internal resistance
    for k in range(1, min(10, len(V))):
        if I[k] > 0.3 and abs(I[k-1]) < 0.1:
            dV = V[k] - V[k-1]
            dI = I[k] - I[k-1]
            if abs(dI) > 0.05:
                r = abs(dV / dI)
                if 0.02 < r < 0.25:
                    Rint0 = r
            break  # only the first transition is considered

    # Discharged Ah, backward-rectangle rule: q[k] = q[k-1] + I[k]·Δt/3600
    q = np.concatenate(([0.0], np.cumsum(I[1:] * np.diff(t) / 3600.0)))
    q_full = q[-1]
    if q_full <= 0:
        return
    soc = 1.0 - q / q_full                 # 1 → 0 over the discharge

    ocv_est = V + I * Rint0                # IR-compensated terminal voltage
    cc_mask = I > 0.3                      # loaded samples only (skip rest tails)

    # n_bins bins need n_bins + 1 edges
    edges = np.linspace(0.02, 0.98, n_bins + 1)
    rows = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = cc_mask & (soc >= lo) & (soc < hi)
        if in_bin.any():
            rows.append((0.5 * (lo + hi), float(np.mean(ocv_est[in_bin]))))
    if len(rows) < 5:
        return

    # rows are already ascending in SOC because the edges are ascending
    ocv_tbl = np.array(rows)
    # Enforce monotonic (non-decreasing with SOC)
    ocv_tbl[:, 1] = np.maximum.accumulate(ocv_tbl[:, 1])

    # Anchor endpoints to datasheet-typical limits when the outer bins are
    # missing; clamp the anchors so they can never undo the monotonicity.
    if ocv_tbl[0, 0] > 0.05:
        ocv_tbl = np.vstack([[0.0, min(3.20, ocv_tbl[0, 1])], ocv_tbl])
    if ocv_tbl[-1, 0] < 0.95:
        ocv_tbl = np.vstack([ocv_tbl, [1.0, max(4.20, ocv_tbl[-1, 1])]])

    cell.ocv_table = ocv_tbl
    cell._Rint0 = Rint0


def fit_rint_vs_soc(cell: CellData) -> None:
    """Derive a fresh-cell R_int(SOC) table from the reference discharges.

    Method:
      - The top row of ``cell.ocv_table`` stands in for the rest voltage
        before a discharge (V_rest ≈ OCV at the highest tabulated SOC).
      - For each of the first five reference discharges, the first sample
        gives R = (V_rest − V[0]) / I[0]. Cycles with fewer than 3 samples
        or with I[0] < 0.2 A are skipped, and only 0.03 Ω < R < 0.20 Ω is
        accepted.
      - The median of the accepted values (80 mΩ if none) is the fresh-cell
        resistance, which is then scaled by a fixed U-shaped SOC profile.

    Writes ``cell.rint_vs_soc`` and ``cell._Rint_fresh``. When there is no
    OCV table or no reference discharge, a fixed three-point default table is
    written instead and ``_Rint_fresh`` is not set.
    """
    inputs_ready = cell.ocv_table is not None and bool(cell.ref_discharges)
    if not inputs_ready:
        cell.rint_vs_soc = np.array([[0.05, 0.085],
                                     [0.50, 0.075],
                                     [0.95, 0.085]])
        return

    top_ocv = float(cell.ocv_table[-1, 1])

    def ohmic_estimate(cyc):
        # One-sample IR-drop estimate; None when the cycle is unusable.
        if len(cyc.V) < 3:
            return None
        i_first, v_first = float(cyc.I[0]), float(cyc.V[0])
        if i_first < 0.2:
            return None
        r = (top_ocv - v_first) / i_first
        # accept physically plausible values only
        return r if 0.03 < r < 0.20 else None

    # first 5 (freshest) reference discharges, to average out noise
    estimates = (ohmic_estimate(cyc) for cyc in cell.ref_discharges[:5])
    accepted = [r for r in estimates if r is not None]
    if accepted:
        fresh = float(np.median(accepted))
    else:
        fresh = 0.080  # 80 mΩ fallback

    # U-curve multipliers (higher at low/high SOC, minimum near 50 %), per the
    # original note referencing §3.2 of the simulation parameters document.
    soc_points = np.array([0.05, 0.10, 0.25, 0.50, 0.75, 0.90, 0.95])
    multipliers = np.array([1.25, 1.15, 1.05, 1.00, 1.05, 1.15, 1.25])
    cell.rint_vs_soc = np.column_stack((soc_points, fresh * multipliers))
    cell._Rint_fresh = fresh


def calibrate_cell(mat_path: str, dataset_tag: str, cell_name: str) -> CellData:
    """Load a cell and run every calibration stage on it, in order.

    Stages: capacity/SOH trajectory, OCV–SOC table, then the R_int table.
    The order matters because the R_int stage reads the OCV table produced
    by the stage before it.
    """
    loaded = load_cell(mat_path, dataset_tag, cell_name)
    for stage in (calibrate_soh, fit_ocv_soc, fit_rint_vs_soc):
        stage(loaded)
    return loaded


# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke-test driver: calibrate a fixed set of cells and print a summary.
    # Usage: python <this file> [DATA_ROOT]
    import json
    import os
    import sys

    mats = [
        ("ds1/Battery_Uniform_Distribution_Charge_Discharge_DataSet_2Post/data/Matlab/RW9.mat",
         "DS1_Uniform_ChgDis", "RW9"),
        ("ds2/Battery_Uniform_Distribution_Discharge_Room_Temp_DataSet_2Post/data/Matlab/RW3.mat",
         "DS2_Uniform_DisRT", "RW3"),
        ("ds5/RW_Skewed_High_Room_Temp_DataSet_2Post/data/Matlab/RW20.mat",
         "DS5_Skewed_High_RT", "RW20"),
        ("ds7/RW_Skewed_Low_Room_Temp_DataSet_2Post/data/Matlab/RW13.mat",
         "DS7_Skewed_Low_RT", "RW13"),
    ]
    # Dataset root: first CLI argument when given, otherwise the historical default.
    base = sys.argv[1] if len(sys.argv) > 1 else "/home/claude/work"
    for rel, tag, name in mats:
        print(f"--- {tag} {name} ---")
        cell = calibrate_cell(os.path.join(base, rel), tag, name)
        print(f"  ref_discharge cycles : {len(cell.ref_discharges)}")
        print(f"  ref_charge cycles    : {len(cell.ref_charges)}")
        print(f"  RW charge cycles     : {len(cell.rw_charges)}")
        if cell.soh_baseline is not None:
            print(f"  baseline capacity    : {cell.soh_baseline:.3f} Ah")
        # The trajectory stores capacities; SOH needs a baseline to divide by.
        if (cell.soh_trajectory is not None and len(cell.soh_trajectory) > 1
                and cell.soh_baseline):
            sohs = cell.soh_trajectory[:,1] / cell.soh_baseline
            print(f"  SOH range            : {sohs.min():.3f}–{sohs.max():.3f}")
            print(f"  cycles over {cell.soh_trajectory[-1,0]:.1f} days, {len(cell.soh_trajectory)} refs")
        if cell.ocv_table is not None:
            print(f"  OCV table            : {len(cell.ocv_table)} bins, "
                  f"V@SOC=0.5 = {float(np.interp(0.5, cell.ocv_table[:,0], cell.ocv_table[:,1])):.3f} V")
        if cell.rint_vs_soc is not None:
            print(f"  R_int table          : {len(cell.rint_vs_soc)} bins, "
                  f"R@SOC=0.5 = {float(np.interp(0.5, cell.rint_vs_soc[:,0], cell.rint_vs_soc[:,1]))*1000:.1f} mΩ")
