from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Optional

import pandas as pd
from transformers import pipeline


def load_table(file_path: Path) -> pd.DataFrame:
    """Read a tabular file (.xlsx or .csv) into a DataFrame.

    Raises:
        FileNotFoundError: if *file_path* does not exist.
        ValueError: if the file extension is neither .xlsx nor .csv.
    """
    if not file_path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")

    # Dispatch on the (case-insensitive) extension.
    readers = {".xlsx": pd.read_excel, ".csv": pd.read_csv}
    reader = readers.get(file_path.suffix.lower())
    if reader is None:
        raise ValueError("Unsupported file format. Please use .xlsx or .csv.")
    return reader(file_path)


def save_table(df: pd.DataFrame, file_path: Path) -> None:
    file_path.parent.mkdir(parents=True, exist_ok=True)
    suffix = file_path.suffix.lower()

    if suffix == ".xlsx":
        df.to_excel(file_path, index=False)
    elif suffix == ".csv":
        df.to_csv(file_path, index=False, encoding="utf-8-sig")
    else:
        raise ValueError("Unsupported output format. Please use .xlsx or .csv.")


def standardize_text(text: str) -> str:
    """Normalize a cell value for text-based matching.

    NaN/None become the empty string; anything else is stringified and
    stripped of surrounding whitespace.
    """
    return "" if pd.isna(text) else str(text).strip()


def merge_topic_assignments(
    segment_df: pd.DataFrame,
    topic_df: pd.DataFrame,
    segment_id_column: str = "segment_id",
    segment_text_column: str = "text_segment",
    topic_segment_id_column: Optional[str] = "segment_id",
    topic_text_column: Optional[str] = "text_segment",
    topic_column: str = "topic",
    topic_prob_column: Optional[str] = "topic_probability",
) -> pd.DataFrame:
    """
    Merge topic assignments into the review-segment dataset.

    Preferred order:
    1. Merge by segment_id.
    2. If segment_id is unavailable, merge by normalized text.
    """
    segment_df = segment_df.copy()
    topic_df = topic_df.copy()

    if topic_column not in topic_df.columns:
        raise ValueError(f"Topic column '{topic_column}' not found in topic assignment file.")

    if (
        topic_segment_id_column
        and topic_segment_id_column in topic_df.columns
        and segment_id_column in segment_df.columns
    ):
        keep_cols = [topic_segment_id_column, topic_column]
        if topic_prob_column and topic_prob_column in topic_df.columns:
            keep_cols.append(topic_prob_column)

        topic_sub = topic_df[keep_cols].copy()
        topic_sub = topic_sub.rename(columns={topic_segment_id_column: segment_id_column})
        return pd.merge(segment_df, topic_sub, on=segment_id_column, how="left")

    if topic_text_column is None or topic_text_column not in topic_df.columns:
        raise ValueError("Could not merge by segment_id, and no valid topic text column was provided.")

    if segment_text_column not in segment_df.columns:
        raise ValueError(f"Segment text column '{segment_text_column}' not found in segment file.")

    segment_df["_merge_text"] = segment_df[segment_text_column].apply(standardize_text)
    topic_df["_merge_text"] = topic_df[topic_text_column].apply(standardize_text)

    keep_cols = ["_merge_text", topic_column]
    if topic_prob_column and topic_prob_column in topic_df.columns:
        keep_cols.append(topic_prob_column)

    topic_sub = topic_df[keep_cols].drop_duplicates(subset=["_merge_text"]).copy()
    merged = pd.merge(segment_df, topic_sub, on="_merge_text", how="left")
    return merged.drop(columns=["_merge_text"])


def build_sentiment_pipeline(model_name: str):
    """
    Build a Hugging Face sentiment-analysis pipeline.

    The same identifier is used for both model and tokenizer; truncation is
    enabled so over-length inputs do not raise at inference time.
    """
    pipeline_kwargs = {
        "task": "sentiment-analysis",
        "model": model_name,
        "tokenizer": model_name,
        "truncation": True,
    }
    return pipeline(**pipeline_kwargs)


def map_label_to_sentiment_score(label: str, score: float) -> float:
    """
    Map model output to a continuous sentiment score in [0, 1].

    Positive labels keep the confidence as-is, negative labels invert it
    (1 - confidence), and neutral / unknown labels fall back to 0.5.
    """
    if label is None:
        return 0.5

    normalized = str(label).strip().upper()

    if normalized in {"POSITIVE", "LABEL_1", "POS", "5 STARS", "4 STARS"}:
        return float(score)
    if normalized in {"NEGATIVE", "LABEL_0", "NEG", "1 STAR", "2 STARS"}:
        return float(1 - score)

    # Neutral ("3 STAR...", "NEUTRAL") and any unrecognized label collapse
    # to the midpoint, exactly as the explicit neutral branch did.
    return 0.5


def compute_sentiment_scores(
    df: pd.DataFrame,
    text_column: str,
    model_name: str,
    batch_size: int = 32,
) -> pd.DataFrame:
    """Score every row of *df* with the given sentiment model.

    Adds three columns to a copy of *df*: ``sentiment_label`` (raw model
    label), ``sentiment_confidence`` (raw model score), and
    ``sentiment_score`` (mapped to [0, 1] via map_label_to_sentiment_score).

    Raises:
        ValueError: if *text_column* is not present in *df*.
    """
    if text_column not in df.columns:
        raise ValueError(f"Text column '{text_column}' not found.")

    classifier = build_sentiment_pipeline(model_name=model_name)
    scored = df.copy()
    texts = scored[text_column].fillna("").astype(str).tolist()

    labels = []
    confidences = []
    mapped_scores = []

    # Feed the classifier in fixed-size batches to bound memory use.
    for offset in range(0, len(texts), batch_size):
        batch = texts[offset:offset + batch_size]
        for prediction in classifier(batch):
            raw_label = prediction.get("label")
            raw_confidence = float(prediction.get("score", 0.5))

            labels.append(raw_label)
            confidences.append(raw_confidence)
            mapped_scores.append(map_label_to_sentiment_score(raw_label, raw_confidence))

    scored["sentiment_label"] = labels
    scored["sentiment_confidence"] = confidences
    scored["sentiment_score"] = mapped_scores
    return scored


def main() -> None:
    """CLI entry point: merge topics into segments, score sentiment, save output."""
    parser = argparse.ArgumentParser(description="Topic-sentiment scoring pipeline for review segments.")

    # Plain string options that need no help text, in their original order.
    simple_string_options = [
        ("--segment_file", "outputs/review_segments.xlsx"),
        ("--topic_file", "outputs/bertopic_results/document_topic_assignments.xlsx"),
        ("--output_file", "outputs/review_segments_with_topic_and_sentiment.xlsx"),
        ("--segment_id_column", "segment_id"),
        ("--segment_text_column", "text_segment"),
        ("--topic_segment_id_column", "segment_id"),
        ("--topic_text_column", "text_segment"),
        ("--topic_column", "topic"),
        ("--topic_probability_column", "topic_probability"),
    ]
    for flag, default_value in simple_string_options:
        parser.add_argument(flag, type=str, default=default_value)

    parser.add_argument(
        "--model_name",
        type=str,
        default="uer/roberta-base-finetuned-jd-binary-chinese",
        help="Sentiment model used in the study.",
    )
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--output_config", type=str, default="outputs/topic_sentiment_model_config.json")

    args = parser.parse_args()

    # Load both inputs, attach topics, then run sentiment scoring.
    segments = load_table(Path(args.segment_file))
    topics = load_table(Path(args.topic_file))

    merged = merge_topic_assignments(
        segment_df=segments,
        topic_df=topics,
        segment_id_column=args.segment_id_column,
        segment_text_column=args.segment_text_column,
        # Empty string means "don't merge by id" — normalize it to None.
        topic_segment_id_column=args.topic_segment_id_column or None,
        topic_text_column=args.topic_text_column,
        topic_column=args.topic_column,
        topic_prob_column=args.topic_probability_column,
    )

    scored = compute_sentiment_scores(
        df=merged,
        text_column=args.segment_text_column,
        model_name=args.model_name,
        batch_size=args.batch_size,
    )
    save_table(scored, Path(args.output_file))

    # Persist a small provenance record describing the run configuration.
    config_path = Path(args.output_config)
    config_path.parent.mkdir(parents=True, exist_ok=True)
    config_payload = {
        "model_name": args.model_name,
        "merge_priority": "segment_id first, normalized text second",
        "text_column": args.segment_text_column,
        "topic_column": args.topic_column,
    }
    config_path.write_text(json.dumps(config_payload, ensure_ascii=False, indent=2), encoding="utf-8")

    print("Topic sentiment scoring completed successfully.")
    print(f"Output file: {args.output_file}")
    print(f"Rows without topic assignment: {scored[args.topic_column].isna().sum()}")


if __name__ == "__main__":
    main()
