#!/usr/bin/env python3
import argparse
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Create a "results" folder in the current working directory (i.e. relative to
# where the script is run, not to this file). This happens at import time.
# NOTE(review): nothing else visible in this file references RESULTS_DIR; the
# CSV/plot outputs are written to the paths given by the caller — confirm
# whether outputs were meant to land in this folder.
RESULTS_DIR = Path("results")
RESULTS_DIR.mkdir(exist_ok=True)

# Gravitational acceleration used to convert drop mass to drop weight.
G = 9.80665  # m/s^2

def tnorm(rng, mean, sd, lo, hi, size):
    """
    Draw normal samples and clamp them into the interval [lo, hi].

    Samples come from ``rng.normal(mean, sd, size)``. Any draw that falls
    outside the interval is set to the nearest bound (clipping); draws are
    NOT rejected and re-sampled, so the bounds themselves can carry
    probability mass.
    """
    draws = rng.normal(loc=mean, scale=sd, size=size)
    clamped = np.clip(draws, lo, hi)
    return clamped

def compute_sigma_fix(df, iters=30000, seed=98765, r_mean_mm=1.0, r_sd_mm=0.25,
                      beta_mean=0.90, beta_sd=0.05):
    """
    Compute sigma (drop-weight only) with Yildirim factor prior.

    For every drop-weight row a Monte Carlo sample of size ``iters`` is drawn
    for the per-drop mass m, the radius r and the correction factor beta, and
    sigma = m * G / (2 * pi * r * beta) is evaluated in N/m.

    Drop-weight rows are those where at least 4 of the M columns are numeric
    and satisfy 0 < Mi < 2 (grams, mass of 20 drops). Only the readings that
    pass this range check contribute to the row's mean and standard deviation;
    out-of-range or non-numeric readings are ignored.

    Parameters
    ----------
    df : DataFrame containing the M columns. It is not modified.
    iters : number of Monte Carlo draws per row.
    seed : seed for ``np.random.default_rng``; one generator is shared by all rows.
    r_mean_mm, r_sd_mm : mean / SD of the radius in millimetres.
    beta_mean, beta_sd : mean / SD of the correction factor.

    Returns
    -------
    (out, is_dropweight)
        ``out`` is a copy of ``df`` with sigma_{mean,lo,hi}_N_per_m_FIX and
        sigma_{mean,lo,hi}_mN_per_m_FIX columns (mean and 2.5 / 97.5
        percentiles), populated on drop-weight rows and NaN elsewhere.
        ``is_dropweight`` is a boolean Series aligned with ``df``.
    """
    sigma_cols = ["sigma_mean_N_per_m_FIX", "sigma_lo_N_per_m_FIX", "sigma_hi_N_per_m_FIX"]
    out = df.copy()

    # NOTE: every column whose name starts with "M" is treated as a mass column.
    Mcols = [c for c in out.columns if str(c).startswith("M")]
    masses = out[Mcols].copy()
    for c in Mcols:
        masses[c] = pd.to_numeric(masses[c], errors="coerce")

    in_range = (masses > 0) & (masses < 2)
    is_dropweight = in_range.sum(axis=1) >= 4

    # Keep only plausible readings so a stray 0 or outlier in a qualifying
    # row cannot distort that row's mean/SD.
    valid = masses.where(in_range).loc[is_dropweight]

    # Per-drop mass (from 20-drop masses)
    m20_mean = valid.mean(axis=1)
    m20_sd = valid.std(axis=1, ddof=1).fillna(0.0)
    m_drop_mean_g = m20_mean / 20.0
    # If a row shows zero spread, fall back to the between-row spread
    # (or a tiny positive value when that is unavailable).
    between_rows = m20_mean.std()
    fallback = between_rows / 20.0 if between_rows > 0 else 1e-6
    m_drop_sd_g = (m20_sd / 20.0).replace(0, fallback)

    rng = np.random.default_rng(seed)

    # One (mean, lo, hi) triple per drop-weight row, in row order.
    stats = np.full((len(m_drop_mean_g), 3), np.nan)
    for i, (mu_g, sd_g) in enumerate(zip(m_drop_mean_g, m_drop_sd_g)):
        mu_g = float(mu_g)
        sd_g = float(max(sd_g, mu_g * 0.05))  # ensure at least 5% rel SD
        m_samp = tnorm(rng, mu_g, sd_g, max(1e-6, mu_g * 0.5), mu_g * 1.5, iters) * 1e-3  # kg
        r_samp = tnorm(rng, r_mean_mm * 1e-3, r_sd_mm * 1e-3, 0.4e-3, 2.0e-3, iters)  # m
        beta_samp = tnorm(rng, beta_mean, beta_sd, 0.70, 1.10, iters)
        sigma = (m_samp * G) / (2 * np.pi * r_samp * beta_samp)  # N/m
        lo, hi = np.percentile(sigma, [2.5, 97.5])
        stats[i] = (np.mean(sigma), lo, hi)

    # Scatter results back by position: this works when there are no
    # drop-weight rows at all and does not depend on the index being unique.
    values = np.full((len(out), 3), np.nan)
    values[is_dropweight.to_numpy()] = stats
    for j, col in enumerate(sigma_cols):
        out[col] = values[:, j]

    out["sigma_mean_mN_per_m_FIX"] = out["sigma_mean_N_per_m_FIX"] * 1e3
    out["sigma_lo_mN_per_m_FIX"] = out["sigma_lo_N_per_m_FIX"] * 1e3
    out["sigma_hi_mN_per_m_FIX"] = out["sigma_hi_N_per_m_FIX"] * 1e3

    return out, is_dropweight

def make_plot(df, out_prefix="figure_sigma_vs_concentration_FIXED"):
    """
    Build plot using corrected sigma (FIX columns) vs concentration with 95% CIs on both axes.

    Requires columns: c_mmol_per_L_mean, c_mmol_per_L_lo, c_mmol_per_L_hi,
                      sigma_mean_mN_per_m_FIX, sigma_lo_mN_per_m_FIX, sigma_hi_mN_per_m_FIX

    Rows where either the sigma mean or the concentration mean is NaN are
    dropped. The figure is saved as ``{out_prefix}.png`` and
    ``{out_prefix}.pdf``; the plotted rows are saved to
    ``plot_data_sigma_vs_conc_FIXED.csv``.

    Returns
    -------
    (png, pdf, plot_csv) paths as strings, or (None, None, None) when a
    required column is missing (a warning is printed per missing column) or
    when no rows remain after filtering.
    """
    needed = ["c_mmol_per_L_mean", "c_mmol_per_L_lo", "c_mmol_per_L_hi",
              "sigma_mean_mN_per_m_FIX", "sigma_lo_mN_per_m_FIX", "sigma_hi_mN_per_m_FIX"]
    missing = [c for c in needed if c not in df.columns]
    for c in missing:
        print(f"[WARN] Missing column for plot: {c}")
    if missing:
        # Nothing can be plotted without the full column set; bail out the
        # same way as the "no rows" case instead of failing on a lookup.
        return None, None, None

    mask = df["sigma_mean_mN_per_m_FIX"].notna() & df["c_mmol_per_L_mean"].notna()
    plot_df = df.loc[mask, needed].copy()

    if plot_df.empty:
        print("[WARN] No rows available for plotting after filtering.")
        return None, None, None

    x = plot_df["c_mmol_per_L_mean"].to_numpy()
    y = plot_df["sigma_mean_mN_per_m_FIX"].to_numpy()
    # Asymmetric error bars: row 0 is the distance down to the lower bound,
    # row 1 the distance up to the upper bound.
    xerr = np.vstack([
        (plot_df["c_mmol_per_L_mean"] - plot_df["c_mmol_per_L_lo"]).to_numpy(),
        (plot_df["c_mmol_per_L_hi"] - plot_df["c_mmol_per_L_mean"]).to_numpy(),
    ])
    yerr = np.vstack([
        (plot_df["sigma_mean_mN_per_m_FIX"] - plot_df["sigma_lo_mN_per_m_FIX"]).to_numpy(),
        (plot_df["sigma_hi_mN_per_m_FIX"] - plot_df["sigma_mean_mN_per_m_FIX"]).to_numpy(),
    ])

    png = f"{out_prefix}.png"
    pdf = f"{out_prefix}.pdf"

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.errorbar(x, y, xerr=xerr, yerr=yerr, fmt='o', capsize=3)
        ax.set_xlabel("Surfactant Concentration (mmol/L)")
        ax.set_ylabel("Surface Tension (mN/m)")
        ax.set_title("Drop-weight Surface Tension vs Concentration (95% CI, corrected)")
        ax.grid(True, which='both', linestyle='--', alpha=0.4)
        fig.tight_layout()
        fig.savefig(png, dpi=200)
        fig.savefig(pdf)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

    plot_csv = "plot_data_sigma_vs_conc_FIXED.csv"
    plot_df.to_csv(plot_csv, index=False)
    return png, pdf, plot_csv

def main():
    """
    Command-line entry point.

    Reads the merged CSV given by --input, adds the *_FIX sigma columns via
    compute_sigma_fix, writes the result to --output, then attempts to build
    the sigma-vs-concentration plot and reports what was written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="Merged CSV (with M1..M6 + c_mmol_per_L_* columns)")
    parser.add_argument("--output", required=True, help="Output CSV with *_FIX sigma columns added")
    # Monte Carlo settings: (flag, value type, default)
    numeric_options = (
        ("--iters", int, 30000),
        ("--seed", int, 98765),
        ("--r-mean-mm", float, 1.0),
        ("--r-sd-mm", float, 0.25),
        ("--beta-mean", float, 0.90),
        ("--beta-sd", float, 0.05),
    )
    for flag, caster, default in numeric_options:
        parser.add_argument(flag, type=caster, default=default)
    opts = parser.parse_args()

    source = pd.read_csv(opts.input)

    mc_settings = {
        "iters": opts.iters,
        "seed": opts.seed,
        "r_mean_mm": opts.r_mean_mm,
        "r_sd_mm": opts.r_sd_mm,
        "beta_mean": opts.beta_mean,
        "beta_sd": opts.beta_sd,
    }
    df_fixed, mask = compute_sigma_fix(source, **mc_settings)
    df_fixed.to_csv(opts.output, index=False)
    print(f"[OK] Wrote fixed CSV: {opts.output}")
    detected = int(mask.sum())
    print(f"[INFO] Drop-weight rows detected: {detected} / {len(mask)}")

    png, pdf, plot_csv = make_plot(df_fixed)
    if not png:
        print("[WARN] Plot not generated (no valid rows).")
        return
    print(f"[OK] Saved plot: {png}")
    print(f"[OK] Saved plot: {pdf}")
    print(f"[OK] Saved plot data: {plot_csv}")

if __name__ == "__main__":
    main()
