from __future__ import annotations

import argparse
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# Candidate necessary-condition columns expected in the input table, one score
# per ski destination (acronyms follow the source dataset's column names —
# their expansions are not defined here; confirm against the data dictionary).
CONDITIONS = ["SSQ", "REQ", "AT", "NSE", "TOP", "SEI", "PFV"]
# Outcome column that every condition is tested against.
OUTCOME = "EQ"


def load_table(file_path: Path) -> pd.DataFrame:
    """Read a tabular file (.xlsx or .csv) into a DataFrame.

    Raises:
        FileNotFoundError: if *file_path* does not exist.
        ValueError: if the extension is neither .xlsx nor .csv.
    """
    if not file_path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")
    # Dispatch on the (case-insensitive) extension.
    readers = {".xlsx": pd.read_excel, ".csv": pd.read_csv}
    reader = readers.get(file_path.suffix.lower())
    if reader is None:
        raise ValueError("Unsupported input format. Please use .xlsx or .csv.")
    return reader(file_path)


def save_table(df: pd.DataFrame, file_path: Path) -> None:
    """Write *df* to .xlsx or .csv, creating parent directories as needed.

    CSV output uses utf-8-sig so Excel opens it with the correct encoding.

    Raises:
        ValueError: if the extension is neither .xlsx nor .csv.
    """
    file_path.parent.mkdir(parents=True, exist_ok=True)
    ext = file_path.suffix.lower()
    if ext == ".csv":
        df.to_csv(file_path, index=False, encoding="utf-8-sig")
        return
    if ext == ".xlsx":
        df.to_excel(file_path, index=False)
        return
    raise ValueError("Unsupported output format. Please use .xlsx or .csv.")


def prepare_input_data(df: pd.DataFrame) -> pd.DataFrame:
    """Validate and clean the raw input table for the NCA.

    Keeps only the destination name, outcome, and condition columns, coerces
    the numeric columns, and drops rows with a missing value in any of them.

    Raises:
        ValueError: if any required column is absent from *df*.
    """
    required_cols = ["destination_name", OUTCOME] + CONDITIONS
    missing_cols = [c for c in required_cols if c not in df.columns]
    if missing_cols:
        # Fail fast with the exact column names so the input file can be fixed.
        # (Message fixed: the check runs on the raw input, not "after preparation".)
        raise ValueError(f"Missing required columns in input data: {missing_cols}")

    numeric_cols = [OUTCOME] + CONDITIONS
    df = df[required_cols].copy()
    for col in numeric_cols:
        # errors="coerce" turns non-numeric junk into NaN so it is dropped below.
        df[col] = pd.to_numeric(df[col], errors="coerce")
    return df.dropna(subset=numeric_cols).reset_index(drop=True)


def compute_ce_fdh_frontier(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the CE-FDH ceiling line over the observed (x, y) points.

    For each distinct x value (ascending), take the maximum y observed at that
    x, then make the sequence non-decreasing via a running maximum — the
    free-disposal-hull ceiling.

    Args:
        x: condition values (any array-like of numbers).
        y: outcome values, same length as *x*.

    Returns:
        (unique_x, ceiling_y): sorted distinct x values and the ceiling height
        at each. Both arrays are empty when the input is empty (the previous
        implementation returned ``[nan]`` arrays in that case — a latent bug).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.size == 0:
        # Explicit empty result instead of accumulating a spurious NaN point.
        return np.array([], dtype=float), np.array([], dtype=float)

    # Stable sort by x so groups of equal x are contiguous.
    order = np.argsort(x, kind="mergesort")
    x_sorted = x[order]
    y_sorted = y[order]

    # Start index of each run of equal x values.
    starts = np.flatnonzero(np.r_[True, x_sorted[1:] != x_sorted[:-1]])
    unique_x = x_sorted[starts]
    # Max y within each run, then the running maximum forms the ceiling.
    y_at_unique_x = np.maximum.reduceat(y_sorted, starts)
    ceiling_y = np.maximum.accumulate(y_at_unique_x)
    return unique_x, ceiling_y


def compute_ce_fdh_effect_size(x: np.ndarray, y: np.ndarray) -> Dict[str, float]:
    """Compute the CE-FDH necessity effect size d = empty_space / scope.

    The scope is the area of the bounding rectangle of the observed values;
    the empty space is the area above the CE-FDH ceiling line inside that
    rectangle. Returns NaN for d (and empty_space) when the scope is
    degenerate (zero width or height).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    xmin = float(np.min(x))
    xmax = float(np.max(x))
    ymin = float(np.min(y))
    ymax = float(np.max(y))
    scope = (xmax - xmin) * (ymax - ymin)
    bounds = {"xmin": xmin, "xmax": xmax, "ymin": ymin, "ymax": ymax}

    if scope <= 0:
        # Degenerate rectangle: the effect size is undefined.
        return {"d": np.nan, "scope": scope, "empty_space": np.nan, **bounds}

    frontier_x, frontier_y = compute_ce_fdh_frontier(x, y)

    # Sum the rectangular strips above each step of the ceiling line.
    empty_space = 0.0
    for left, right, ceiling in zip(frontier_x[:-1], frontier_x[1:], frontier_y[:-1]):
        empty_space += max(0.0, ymax - ceiling) * (right - left)

    return {
        "d": float(empty_space / scope), "scope": float(scope), "empty_space": float(empty_space),
        **bounds,
    }


def permutation_test_ce_fdh(x: np.ndarray, y: np.ndarray, n_permutations: int = 1000, random_state: int = 42) -> float:
    """One-sided permutation-test p-value for the CE-FDH effect size.

    Shuffles y relative to x to build a null distribution of d and returns
    the proportion of permuted d values at least as large as the observed
    one, with add-one smoothing.
    """
    rng = np.random.default_rng(random_state)
    observed_d = compute_ce_fdh_effect_size(x, y)["d"]

    null_ds = np.array(
        [compute_ce_fdh_effect_size(x, rng.permutation(y))["d"] for _ in range(n_permutations)],
        dtype=float,
    )
    # (k + 1) / (n + 1) keeps the p-value strictly above zero.
    return float((np.sum(null_ds >= observed_d) + 1) / (len(null_ds) + 1))


def compute_bottleneck_table(x: np.ndarray, y: np.ndarray, x_name: str, y_name: str, n_levels: int = 10) -> pd.DataFrame:
    """Build the NCA bottleneck table for one condition.

    For n_levels + 1 evenly spaced outcome levels between the observed min
    and max, report the minimum condition value whose ceiling reaches that
    level, both raw and as a percentage of the condition's observed range.
    NaN marks levels no ceiling point reaches or degenerate ranges.
    """
    stats = compute_ce_fdh_effect_size(x, y)
    frontier_x, frontier_y = compute_ce_fdh_frontier(x, y)

    xmin, xmax = stats["xmin"], stats["xmax"]
    ymin, ymax = stats["ymin"], stats["ymax"]

    records = []
    for level in np.linspace(ymin, ymax, n_levels + 1):
        # First ceiling point that attains this outcome level, if any.
        reachable = np.flatnonzero(frontier_y >= level)
        required_x = frontier_x[reachable[0]] if reachable.size else np.nan

        if xmax > xmin and pd.notna(required_x):
            required_x_pct = (required_x - xmin) / (xmax - xmin) * 100
        else:
            required_x_pct = np.nan
        outcome_pct = (level - ymin) / (ymax - ymin) * 100 if ymax > ymin else np.nan

        records.append({
            "Condition": x_name,
            "Outcome": y_name,
            "Outcome_level_raw": round(float(level), 6),
            "Outcome_level_pct_of_range": round(float(outcome_pct), 2),
            "Required_condition_raw": np.nan if pd.isna(required_x) else round(float(required_x), 6),
            "Required_condition_pct_of_range": np.nan if pd.isna(required_x_pct) else round(float(required_x_pct), 2),
        })

    return pd.DataFrame(records)


def interpret_effect_size(d: float) -> str:
    """Map an effect size d to a qualitative label ("NA" for NaN).

    Bands: d < 0.1 Small, 0.1 <= d < 0.3 Medium, 0.3 <= d < 0.5 Large,
    d >= 0.5 Very strong.
    """
    if pd.isna(d):
        return "NA"
    for upper_bound, label in ((0.1, "Small"), (0.3, "Medium"), (0.5, "Large")):
        if d < upper_bound:
            return label
    return "Very strong"


def plot_ceiling_line(x: np.ndarray, y: np.ndarray, x_name: str, y_name: str, output_file: Path) -> None:
    """Save a PNG scattering the raw cases with the CE-FDH ceiling overlaid.

    Uses a post-step line so the ceiling height holds until the next
    distinct x value. Creates the output directory if needed.
    """
    ceiling_x, ceiling_y = compute_ce_fdh_frontier(x, y)

    # Object-oriented matplotlib API; closing the figure frees its memory.
    fig, ax = plt.subplots(figsize=(6, 4.5))
    ax.scatter(x, y, alpha=0.7, s=20)
    ax.step(ceiling_x, ceiling_y, where="post", linewidth=2)
    ax.set_xlabel(x_name)
    ax.set_ylabel(y_name)
    ax.set_title(f"NCA ceiling line: {x_name} -> {y_name}")
    fig.tight_layout()

    output_file.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_file, dpi=300)
    plt.close(fig)


def run_nca_analysis(df: pd.DataFrame, outcome_col: str, condition_cols: List[str], n_permutations: int, random_state: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Run the CE-FDH NCA for every condition against the outcome.

    Returns:
        A summary DataFrame (one row per condition, sorted by effect size
        descending) and the concatenated bottleneck tables for all conditions.
    """
    summary_rows = []
    bottlenecks = []
    outcome_values = df[outcome_col].to_numpy(dtype=float)

    for condition in condition_cols:
        condition_values = df[condition].to_numpy(dtype=float)
        stats = compute_ce_fdh_effect_size(condition_values, outcome_values)
        p_value = permutation_test_ce_fdh(
            x=condition_values,
            y=outcome_values,
            n_permutations=n_permutations,
            random_state=random_state,
        )

        summary_rows.append({
            "Condition": condition,
            "Effect_size_d_CE_FDH": round(stats["d"], 6),
            "p_value": round(p_value, 6),
            "Interpretation": interpret_effect_size(stats["d"]),
            "Scope": round(stats["scope"], 6),
            "Empty_space": round(stats["empty_space"], 6),
            "N_cases": len(df),
        })
        bottlenecks.append(
            compute_bottleneck_table(
                x=condition_values, y=outcome_values, x_name=condition, y_name=outcome_col, n_levels=10
            )
        )

    summary_df = (
        pd.DataFrame(summary_rows)
        .sort_values(by="Effect_size_d_CE_FDH", ascending=False)
        .reset_index(drop=True)
    )
    return summary_df, pd.concat(bottlenecks, axis=0, ignore_index=True)


def main() -> None:
    """CLI entry point: load the data, run the NCA, and write all outputs."""
    parser = argparse.ArgumentParser(description="CE-FDH Necessary Condition Analysis for ski destination data.")
    # (flag, type, default) specs for every CLI option.
    for flag, arg_type, default in (
        ("--input_file", str, "outputs/analytical_dataset_before_calibration.xlsx"),
        ("--output_results_file", str, "outputs/NCA_results.xlsx"),
        ("--output_bottleneck_file", str, "outputs/NCA_bottleneck_table.xlsx"),
        ("--plot_dir", str, "outputs/NCA_ceiling_line_plots"),
        ("--n_permutations", int, 1000),
        ("--random_state", int, 42),
    ):
        parser.add_argument(flag, type=arg_type, default=default)
    args = parser.parse_args()

    data = prepare_input_data(load_table(Path(args.input_file)))
    results, bottlenecks = run_nca_analysis(
        df=data,
        outcome_col=OUTCOME,
        condition_cols=CONDITIONS,
        n_permutations=args.n_permutations,
        random_state=args.random_state,
    )

    save_table(results, Path(args.output_results_file))
    save_table(bottlenecks, Path(args.output_bottleneck_file))

    # One ceiling-line plot per condition; the outcome array is shared.
    plot_dir = Path(args.plot_dir)
    outcome_values = data[OUTCOME].to_numpy(dtype=float)
    for condition in CONDITIONS:
        plot_ceiling_line(
            x=data[condition].to_numpy(dtype=float),
            y=outcome_values,
            x_name=condition,
            y_name=OUTCOME,
            output_file=plot_dir / f"NCA_ceiling_line_{condition}.png",
        )

    print("NCA analysis completed successfully.")
    print(f"Number of cases: {len(data)}")


# Run the full analysis when executed as a script.
if __name__ == "__main__":
    main()
