from __future__ import annotations

import argparse
from pathlib import Path
from typing import Optional

import pandas as pd


# Abbreviated labels of the seven core topics; output columns follow this order.
CORE_TOPICS = ["SSQ", "REQ", "AT", "NSE", "TOP", "SEI", "PFV"]

# Long-form topic names written with "and", mapped to their abbreviation.
_TOPIC_LONG_NAMES_AND = {
    "Slope and Snow Quality": "SSQ",
    "Rental and Equipment": "REQ",
    "Accessibility and Transport": "AT",
    "Natural Scenery and Environment": "NSE",
    "Ticketing and On-site Process": "TOP",
    "Service Encounter and Instruction": "SEI",
    "Price Fairness and Value": "PFV",
}

# Every accepted long-form spelling -> abbreviation: first the "and" variants,
# then the same names with " and " replaced by " & ".
FULL_NAME_TO_ABBR = {
    **_TOPIC_LONG_NAMES_AND,
    **{
        long_name.replace(" and ", " & "): abbr
        for long_name, abbr in _TOPIC_LONG_NAMES_AND.items()
    },
}


def load_table(file_path: Path) -> pd.DataFrame:
    """Read an ``.xlsx`` or ``.csv`` file into a DataFrame.

    The suffix is matched case-insensitively.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist (checked first).
        ValueError: if the suffix is neither ``.xlsx`` nor ``.csv``.
    """
    if not file_path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")

    readers = {".xlsx": pd.read_excel, ".csv": pd.read_csv}
    reader = readers.get(file_path.suffix.lower())
    if reader is None:
        raise ValueError("Unsupported file format. Please use .xlsx or .csv.")
    return reader(file_path)


def save_table(df: pd.DataFrame, file_path: Path) -> None:
    """Write ``df`` (without its index) to an ``.xlsx`` or ``.csv`` file.

    Missing parent directories are created. CSV output is written with the
    ``utf-8-sig`` codec, which prepends a UTF-8 BOM (presumably so that
    spreadsheet applications detect the encoding; confirm if that matters).

    Raises:
        ValueError: if the suffix (case-insensitive) is neither ``.xlsx`` nor
            ``.csv``. The check runs before any directory is created, so a
            rejected path leaves no stray directories behind.
    """
    suffix = file_path.suffix.lower()
    # Validate first: the original created the parent directory and only
    # then rejected the format, leaving an empty directory as a side effect.
    if suffix not in (".xlsx", ".csv"):
        raise ValueError("Unsupported output format. Please use .xlsx or .csv.")

    file_path.parent.mkdir(parents=True, exist_ok=True)
    if suffix == ".xlsx":
        df.to_excel(file_path, index=False)
    else:
        df.to_csv(file_path, index=False, encoding="utf-8-sig")


def normalize_topic_value(value) -> Optional[str]:
    """Map a raw topic label to its standard form.

    Returns ``None`` for missing values (anything ``pd.isna`` treats as
    missing). Otherwise the value is converted to ``str`` and stripped of
    surrounding whitespace. A core abbreviation is returned unchanged, and a
    known long-form name is replaced by its abbreviation. Any other label is
    returned stripped but otherwise untouched, so callers can filter it out.
    """
    if pd.isna(value):
        return None

    label = str(value).strip()
    if label in CORE_TOPICS:
        return label
    return FULL_NAME_TO_ABBR.get(label, label)


def prepare_topic_column(df: pd.DataFrame, topic_col: str, exclude_noise: bool = True) -> pd.DataFrame:
    """Return a copy of ``df`` with a ``topic_std`` column, restricted to core topics.

    ``topic_std`` is ``normalize_topic_value`` applied to ``df[topic_col]``.
    Rows are kept only if ``topic_std`` is non-missing and is one of
    ``CORE_TOPICS``. Surviving rows keep their original order and index
    labels. The input frame is not modified.

    Note: when ``exclude_noise`` is true, rows whose label is ``"-1"`` are
    also dropped. Because ``"-1"`` is not in ``CORE_TOPICS``, those rows are
    removed by the core-topic filter anyway. The flag therefore does not
    change the result.

    Raises:
        ValueError: if ``topic_col`` is not a column of ``df``.
    """
    if topic_col not in df.columns:
        raise ValueError(f"Topic column '{topic_col}' not found in input file.")

    out = df.copy()
    out["topic_std"] = out[topic_col].apply(normalize_topic_value)
    std = out["topic_std"]

    keep = std.notna() & std.isin(CORE_TOPICS)
    if exclude_noise:
        keep &= std.astype(str) != "-1"
    return out.loc[keep].copy()


def construct_destination_level_variables(
    df: pd.DataFrame,
    destination_id_col: str,
    destination_name_col: Optional[str],
    topic_col_std: str,
    sentiment_col: str,
    *,
    topics: Optional[list[str]] = None,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Aggregate segment-level topic sentiment into destination-level variables.

    Processing steps:
      * ``sentiment_col`` is coerced to numeric. Rows whose sentiment is
        missing or non-numeric are dropped, and the rest are clipped to
        ``[0, 1]``.
      * Rows are grouped by ``destination_id_col``, plus
        ``destination_name_col`` when it is given and present. Missing key
        values form their own group (``dropna=False``).
      * For each (destination, topic) the segment count, mean sentiment and
        weight (count / destination's total valid segments) are computed.
      * ``EQ`` is the sum of weight * mean over the topics present. This
        equals the destination's mean segment sentiment over all rows that
        reach this function.

    Args:
        df: Segment-level data.
        destination_id_col: Destination identifier column (required).
        destination_name_col: Optional destination name column, used as an
            additional grouping key when present in ``df``.
        topic_col_std: Column holding the standardised topic label.
        sentiment_col: Column holding the sentiment score.
        topics: Topics that get wide columns in the final table. Defaults to
            ``CORE_TOPICS``.

    Returns:
        ``(topic_summary, final_df)``.
          * ``topic_summary``: long format, one row per (destination,
            topic).
          * ``final_df``: one row per destination, with the group keys,
            ``total_valid_segments``, ``observed_topics`` and ``EQ``. It also
            has, per topic, ``<topic>`` (mean sentiment, NaN when the
            destination has no segments for it), ``<topic>_count`` (int, 0
            when absent) and ``<topic>_weight`` (0.0 when absent).

    Raises:
        ValueError: if a required column is missing from ``df``.
    """
    topic_list = list(CORE_TOPICS if topics is None else topics)

    required_cols = [destination_id_col, topic_col_std, sentiment_col]
    for col in required_cols:
        if col not in df.columns:
            raise ValueError(f"Required column '{col}' not found.")

    working = df.copy()
    working[sentiment_col] = pd.to_numeric(working[sentiment_col], errors="coerce")
    working = working.loc[working[sentiment_col].notna()].copy()
    working[sentiment_col] = working[sentiment_col].clip(lower=0.0, upper=1.0)

    group_keys = [destination_id_col]
    if destination_name_col and destination_name_col in working.columns:
        group_keys.append(destination_name_col)

    total_counts = (
        working.groupby(group_keys, dropna=False)
        .size()
        .reset_index(name="total_valid_segments")
    )

    topic_summary = (
        working.groupby(group_keys + [topic_col_std], dropna=False)
        .agg(
            topic_segment_count=(sentiment_col, "size"),
            topic_sentiment_mean=(sentiment_col, "mean"),
        )
        .reset_index()
    )

    topic_summary = pd.merge(topic_summary, total_counts, on=group_keys, how="left")
    topic_summary["topic_weight"] = topic_summary["topic_segment_count"] / topic_summary["total_valid_segments"]
    topic_summary["topic_eq_contribution"] = topic_summary["topic_weight"] * topic_summary["topic_sentiment_mean"]

    eq_summary = (
        topic_summary.groupby(group_keys, dropna=False)
        .agg(
            EQ=("topic_eq_contribution", "sum"),
            total_valid_segments=("total_valid_segments", "first"),
            observed_topics=(topic_col_std, "nunique"),
        )
        .reset_index()
    )

    # Attach one topic at a time with left merges instead of pivot_table:
    # pivot_table (default dropna=True) silently drops groups whose key is
    # missing (e.g. a NaN destination name), whereas merge matches null keys.
    final_df = eq_summary
    for topic in topic_list:
        per_topic = topic_summary.loc[
            topic_summary[topic_col_std] == topic,
            group_keys + ["topic_sentiment_mean", "topic_segment_count", "topic_weight"],
        ].rename(
            columns={
                "topic_sentiment_mean": topic,
                "topic_segment_count": f"{topic}_count",
                "topic_weight": f"{topic}_weight",
            }
        )
        final_df = final_df.merge(per_topic, on=group_keys, how="left")
        # A destination without segments for this topic has zero count and
        # zero weight, whether or not other destinations mention the topic.
        # The mean stays NaN because it is undefined.
        final_df[f"{topic}_count"] = final_df[f"{topic}_count"].fillna(0).astype("int64")
        final_df[f"{topic}_weight"] = final_df[f"{topic}_weight"].fillna(0.0)

    ordered_cols = (
        group_keys
        + ["total_valid_segments", "observed_topics", "EQ"]
        + topic_list
        + [f"{topic}_count" for topic in topic_list]
        + [f"{topic}_weight" for topic in topic_list]
    )

    final_df = final_df[ordered_cols].copy()
    final_df = final_df.sort_values(by=group_keys).reset_index(drop=True)
    topic_summary = topic_summary.sort_values(by=group_keys + [topic_col_std]).reset_index(drop=True)

    return topic_summary, final_df


def main() -> None:
    """Command-line entry point.

    Steps: load the segment table, standardise topics, build the long
    per-topic summary and the destination-level dataset, and save both
    tables. The name column is used only if it exists in the input.
    """
    cli = argparse.ArgumentParser(
        description="Construct destination-level condition variables and EQ from topic-sentiment review segments."
    )
    # String-valued options, registered in this order with these defaults.
    string_options = (
        ("--input_file", "outputs/review_segments_with_topic_and_sentiment.xlsx"),
        ("--output_long_file", "outputs/destination_topic_summary_long.xlsx"),
        ("--output_final_file", "outputs/analytical_dataset_before_calibration.xlsx"),
        ("--destination_id_col", "destination_id"),
        ("--destination_name_col", "destination_name"),
        ("--topic_col", "topic"),
        ("--sentiment_col", "sentiment_score"),
    )
    for flag, default_value in string_options:
        cli.add_argument(flag, type=str, default=default_value)
    cli.add_argument("--exclude_noise", action="store_true")

    opts = cli.parse_args()

    raw = load_table(Path(opts.input_file))
    name_col = opts.destination_name_col
    if name_col not in raw.columns:
        name_col = None

    topic_summary_long, final_dataset = construct_destination_level_variables(
        df=prepare_topic_column(
            df=raw,
            topic_col=opts.topic_col,
            exclude_noise=opts.exclude_noise,
        ),
        destination_id_col=opts.destination_id_col,
        destination_name_col=name_col,
        topic_col_std="topic_std",
        sentiment_col=opts.sentiment_col,
    )

    outputs = (
        (topic_summary_long, opts.output_long_file),
        (final_dataset, opts.output_final_file),
    )
    for table, destination in outputs:
        save_table(table, Path(destination))

    print("Destination-level variable construction completed successfully.")
    print(f"Number of destinations in final dataset: {len(final_dataset)}")


if __name__ == "__main__":
    main()
