"""
Statistical Analysis Module for DNA Sequencing Data

This module performs statistical comparisons between two samples, calculating effect 
sizes and percentile distributions for various features at each base position.

Usage:
    python script.py <path_a> <path_b> <sample_a> <sample_b> <features> <out_percentiles> <out_stats>

Example:
    python script.py modified.parquet unmodified.parquet "Modified" "Unmodified" "mean,std,dwell" percentiles.pkl stats.pkl
"""

import sys
import pickle

import pandas as pd
import numpy as np
from scipy import stats
from statsmodels.stats.multitest import multipletests

def calculate_cohens_d(x: np.ndarray, y: np.ndarray) -> float:
    """
    Calculate Cohen's d for two independent samples.

    The effect size is the difference of the sample means divided by the
    pooled standard deviation, where each sample's variance (ddof=1) is
    weighted by its degrees of freedom (n - 1).

    Parameters:
        x (np.ndarray): Sample 1
        y (np.ndarray): Sample 2

    Returns:
        float: Cohen's d effect size. Returns NaN when either sample is
        empty or when the two samples together have fewer than three
        observations (the pooled variance is undefined). Returns 0.0 when
        the pooled standard deviation is exactly zero.
    """
    n_x = len(x)
    n_y = len(y)
    dof = n_x + n_y - 2

    # No mean for an empty sample, and no pooled variance without at least
    # one degree of freedom: the effect size is undefined.
    if n_x == 0 or n_y == 0 or dof <= 0:
        return float("nan")

    # A sample with a single observation has an undefined ddof=1 variance
    # (NaN), but its weight (n - 1) is zero. Use 0.0 so that the NaN does not
    # poison the pooled estimate via 0 * NaN.
    var_x = np.var(x, ddof=1) if n_x > 1 else 0.0
    var_y = np.var(y, ddof=1) if n_y > 1 else 0.0

    pooled_std = np.sqrt(((n_x - 1) * var_x + (n_y - 1) * var_y) / dof)

    # Avoid dividing by zero when neither sample has any spread.
    if pooled_std == 0:
        return 0.0

    return float((np.mean(x) - np.mean(y)) / pooled_std)



def data_to_stats(
    d_a: pd.DataFrame, 
    d_b: pd.DataFrame, 
    features: list[str]
) -> pd.DataFrame:
    """
    Perform statistical comparison between two samples across all base positions.

    For each base position of sample A that is also present in sample B, and
    for each feature, this function:
    1. Performs a two-sample Kolmogorov-Smirnov test to assess whether the
       distributions differ significantly
    2. Calculates Cohen's d to quantify the effect size
    3. Applies Bonferroni correction for multiple testing, jointly over all
       (position, feature) tests

    Positions found in sample A but not in sample B cannot be compared; they
    are skipped and reported through a single RuntimeWarning. Positions found
    only in sample B are ignored.

    Args:
        d_a: DataFrame for sample A with columns including 'base_index', 'base',
             and feature columns
        d_b: DataFrame for sample B with matching structure to d_a
        features: List of feature column names to compare (e.g., ['mean', 'std', 'dwell'])

    Returns:
        DataFrame sorted by feature and then base_index, with columns:
            - base_index: Position in the sequence
            - base: Base of the first row at this position, or 'X' if the
              two samples disagree
            - feature: Name of the feature being compared
            - stat: KS test statistic
            - p_val: Raw p-value from KS test
            - p_val_corrected: Bonferroni-corrected p-value
            - cohens_d: Effect size measure
        The frame is empty (but has all columns) when nothing could be compared.
    """
    import warnings  # local import: only needed for the skipped-position notice

    # Split sample B by position once instead of re-scanning the whole frame
    # for every (position, feature) combination.
    grouped_b = d_b.groupby("base_index")
    rows_b = grouped_b.indices

    records = []
    missing_in_b = []

    # groupby yields the positions in sorted order.
    for base_index, s_a in d_a.groupby("base_index"):
        if base_index not in rows_b:
            missing_in_b.append(base_index)
            continue
        s_b = grouped_b.get_group(base_index)

        # Determine base identity (mark as 'X' if samples disagree)
        base_a = s_a["base"].values[0]
        base_b = s_b["base"].values[0]
        base = base_a if base_a == base_b else "X"

        for feature in features:
            subset_a = s_a[feature]
            subset_b = s_b[feature]

            # Calculate effect size
            cohens_d = calculate_cohens_d(subset_a, subset_b)

            # Perform Kolmogorov-Smirnov test
            stat_res = stats.ks_2samp(subset_a, subset_b)
            stat, p_val = stat_res[0], stat_res[1]

            records.append([
                base_index,
                base,
                feature,
                stat, p_val, cohens_d
            ])

    if missing_in_b:
        warnings.warn(
            f"Skipping {len(missing_in_b)} base position(s) present in sample A "
            f"but absent from sample B (first few: {missing_in_b[:10]})",
            RuntimeWarning,
            stacklevel=2,
        )

    df_stats = pd.DataFrame(records, columns=[
        "base_index", 
        "base", 
        "feature", 
        "stat", 
        "p_val", 
        "cohens_d"
    ])

    # Apply Bonferroni correction for multiple testing. The correction is
    # skipped when there are no tests at all, because there is nothing to
    # correct and the correction routine is not meant for zero tests.
    if df_stats.empty:
        p_val_corrected = np.array([], dtype=float)
    else:
        _, p_val_corrected, _, _ = multipletests(
            df_stats["p_val"],
            method="bonferroni"
        )

    df_stats["p_val_corrected"] = p_val_corrected
    df_stats = df_stats.reindex(columns=[
        "base_index", 
        "base", 
        "feature", 
        "stat", 
        "p_val",
        "p_val_corrected",
        "cohens_d"
    ]).sort_values(["feature", "base_index"])

    return df_stats



def data_to_percentiles(
    d: pd.DataFrame, 
    features: list[str]
) -> pd.DataFrame:
    """
    Summarise each feature's distribution at every base position.

    The rows of ``d`` are grouped by 'base_index'; for every group and every
    requested feature the 5th, 25th, 50th (median), 75th and 95th percentiles
    of the feature values are computed.

    Args:
        d: DataFrame with columns including 'base_index', 'base', and feature columns
        features: List of feature column names to analyze

    Returns:
        DataFrame with one row per (base position, feature) and columns:
            - base_index: Position in the sequence
            - base: Base taken from the first row of the position's group
            - feature: Name of the feature
            - p05: 5th percentile
            - p25: 25th percentile (Q1)
            - p50: 50th percentile (median)
            - p75: 75th percentile (Q3)
            - p95: 95th percentile
    """
    quantile_levels = [0.05, 0.25, 0.50, 0.75, 0.95]
    output_columns = [
        "base_index",
        "base",
        "feature",
        "p05",
        "p25",
        "p50",
        "p75",
        "p95",
    ]

    rows = []
    for position, group in d.groupby("base_index"):
        for name in features:
            # One quantile call returns all five percentiles, in level order.
            quantiles = group[name].quantile(quantile_levels)
            nucleotide = group["base"].values[0]
            rows.append([position, nucleotide, name, *quantiles])

    return pd.DataFrame(rows, columns=output_columns)


def load_data(
    path_a: str, 
    path_b: str, 
    sample_a: str, 
    sample_b: str, 
    features: list[str]
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Read both samples from disk and run the full analysis on them.

    Steps:
    1. Read the two parquet files
    2. Compute per-position percentile tables for each sample and tag every
       row with the sample's label in a 'sample' column
    3. Stack the two percentile tables and run the statistical comparison
       of sample A against sample B

    Args:
        path_a: Path to parquet file for sample A (e.g., modified sample)
        path_b: Path to parquet file for sample B (e.g., unmodified sample)
        sample_a: Label for sample A
        sample_b: Label for sample B
        features: List of feature names to analyze

    Returns:
        Tuple containing:
            - data_percentiles: Combined percentile data for both samples
              (sample A rows first, index renumbered)
            - data_stats: Statistical comparison results between samples
    """
    # Read both inputs up front so a bad path fails before any computation.
    frame_a = pd.read_parquet(path_a)
    frame_b = pd.read_parquet(path_b)

    # Build one labelled percentile table per sample, A first.
    labelled_percentiles = []
    for label, frame in ((sample_a, frame_a), (sample_b, frame_b)):
        percentiles = data_to_percentiles(frame, features)
        percentiles["sample"] = label
        labelled_percentiles.append(percentiles)

    data_percentiles = pd.concat(labelled_percentiles, ignore_index=True)

    # Compare sample A against sample B position by position.
    data_stats = data_to_stats(frame_a, frame_b, features)

    return data_percentiles, data_stats


def main():
    """
    Main entry point for the statistical analysis script.

    Parses the command line, runs the analysis and pickles the two result
    tables. Exits with status 1 (after printing the module usage text) when
    the number of arguments is wrong or no feature name is given.

    Command-line arguments:
        1. stats_data_a: Path to parquet file for sample A
        2. stats_data_b: Path to parquet file for sample B
        3. sample_name_a: Label for sample A
        4. sample_name_b: Label for sample B
        5. features: Comma-separated list of features (e.g., "mean,std,dwell").
           Whitespace around names, empty entries and repeated names are ignored.
        6. outfile_percentiles: Output path for percentiles pickle file
        7. outfile_stats: Output path for statistics pickle file
    """
    if len(sys.argv) != 8:
        print(__doc__)
        print("\nError: Incorrect number of arguments")
        print(f"Expected 7 arguments, got {len(sys.argv) - 1}")
        sys.exit(1)

    (
        stats_data_a,
        stats_data_b,
        sample_name_a,
        sample_name_b,
        features_arg,
        outfile_percentiles,
        outfile_stats,
    ) = sys.argv[1:8]

    # Tolerate "mean, std" and trailing commas: a raw split would produce
    # names such as " std" or "" that match no column. dict.fromkeys removes
    # repeated names while keeping their order, so no test is run twice.
    stripped_names = (name.strip() for name in features_arg.split(","))
    features = list(dict.fromkeys(name for name in stripped_names if name))

    if not features:
        print(__doc__)
        print("\nError: No features given")
        print(f"Expected a comma-separated list of feature names, got {features_arg!r}")
        sys.exit(1)

    print(f"Loading data from {stats_data_a} and {stats_data_b}...")
    print(f"Analyzing features: {', '.join(features)}")

    data_percentiles, data_stats = load_data(
        stats_data_a, 
        stats_data_b, 
        sample_name_a,
        sample_name_b,
        features
    )

    print(f"Saving percentiles to {outfile_percentiles}...")
    with open(outfile_percentiles, "wb") as f:
        pickle.dump(data_percentiles, f)

    print(f"Saving statistics to {outfile_stats}...")
    with open(outfile_stats, "wb") as f:
        pickle.dump(data_stats, f)


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()