from __future__ import annotations

import argparse
import json
import warnings
from pathlib import Path
from typing import Optional

import pandas as pd
from bertopic import BERTopic
from hdbscan import HDBSCAN
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer
from umap import UMAP


def load_input_data(file_path: Path, text_column: str, id_column: Optional[str] = None) -> pd.DataFrame:
    """
    Load review-segment data from Excel or CSV.

    Only the text column is kept, plus the ID column (placed first) when one is
    given, so it is available for downstream merges. Text is stripped of
    surrounding whitespace. Rows whose text is missing or empty are dropped, and
    the index is reset to 0..n-1.

    Args:
        file_path: Path to a ``.xlsx`` or ``.csv`` file (suffix matched
            case-insensitively).
        text_column: Name of the column holding the text segments.
        id_column: Optional name of an identifier column to keep.

    Returns:
        A new DataFrame with columns ``[id_column, text_column]`` or
        ``[text_column]``.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the suffix is unsupported or a requested column is missing.
    """
    if not file_path.exists():
        raise FileNotFoundError(f"Input file not found: {file_path}")

    suffix = file_path.suffix.lower()
    if suffix == ".xlsx":
        df = pd.read_excel(file_path)
    elif suffix == ".csv":
        df = pd.read_csv(file_path)
    else:
        raise ValueError("Unsupported file type. Please use .xlsx or .csv")

    keep_cols = [text_column]
    if id_column:
        if id_column not in df.columns:
            raise ValueError(f"Column '{id_column}' not found in input data.")
        keep_cols.insert(0, id_column)

    if text_column not in df.columns:
        raise ValueError(f"Column '{text_column}' not found in input data.")

    # Drop missing text BEFORE casting: astype(str) turns NaN into the literal
    # string "nan", which would otherwise survive as a bogus document.
    df = df[keep_cols].dropna(subset=[text_column]).copy()
    df[text_column] = df[text_column].astype(str).str.strip()
    return df[df[text_column] != ""].reset_index(drop=True)


def build_topic_model(
    min_topic_size: int = 100,
    n_neighbors: int = 15,
    n_components: int = 5,
    min_dist: float = 0.0,
    random_state: int = 42,
    embedding_model_name: str = "all-MiniLM-L6-v2",
    min_df: int = 5,
) -> BERTopic:
    """
    Build an unfitted BERTopic model using Sentence-BERT + UMAP + HDBSCAN.

    Args:
        min_topic_size: HDBSCAN ``min_cluster_size`` (smallest cluster kept as a topic).
        n_neighbors: UMAP ``n_neighbors``.
        n_components: Dimensionality of the UMAP projection that HDBSCAN clusters.
        min_dist: UMAP ``min_dist``.
        random_state: Seed passed to UMAP.
        embedding_model_name: Sentence-Transformers model to load for embeddings.
        min_df: ``CountVectorizer`` minimum document frequency for topic terms.

    Returns:
        A configured, unfitted ``BERTopic`` instance with per-topic probability
        calculation enabled and HDBSCAN prediction data retained.
    """
    # NOTE(review): the default all-MiniLM-L6-v2 is an English-trained model; if the
    # reviews are genuinely multilingual, a multilingual Sentence-Transformers model
    # may be more appropriate — confirm against the corpus.
    embedding_model = SentenceTransformer(embedding_model_name)

    umap_model = UMAP(
        n_neighbors=n_neighbors,
        n_components=n_components,
        min_dist=min_dist,
        metric="cosine",
        random_state=random_state,
    )

    hdbscan_model = HDBSCAN(
        min_cluster_size=min_topic_size,
        metric="euclidean",
        cluster_selection_method="eom",
        # Keep prediction data so cluster membership probabilities can be computed.
        prediction_data=True,
    )

    # NOTE(review): BERTopic fits this vectorizer for c-TF-IDF on per-topic
    # concatenated documents, so min_df appears to count topics rather than
    # individual segments — confirm against the installed BERTopic version.
    vectorizer_model = CountVectorizer(
        stop_words=None,
        ngram_range=(1, 2),
        min_df=min_df,
    )

    return BERTopic(
        embedding_model=embedding_model,
        umap_model=umap_model,
        hdbscan_model=hdbscan_model,
        vectorizer_model=vectorizer_model,
        # NOTE(review): with an explicit embedding_model, `language` presumably only
        # affects BERTopic's internal text preprocessing (not model choice) — verify.
        language="multilingual",
        calculate_probabilities=True,
        verbose=True,
    )


def _best_topic_probability(probs, n_docs: int):
    """Reduce ``fit_transform`` probabilities to one value per document, or return None."""
    if probs is None:
        return None
    shape = getattr(probs, "shape", None)
    if not shape or shape[0] != n_docs:
        return None
    if len(shape) == 2 and shape[1] > 0:
        # Full document-by-topic matrix: keep the highest probability per document.
        return probs.max(axis=1)
    if len(shape) == 1:
        # Already one probability per document.
        return probs
    return None


def _topic_keywords_frame(topic_model: BERTopic, topics: list[int]) -> pd.DataFrame:
    """Build a long-format (topic, rank, word, weight) table for non-outlier topics."""
    records = []
    for topic_id in sorted(set(topics)):
        if topic_id == -1:
            continue
        terms = topic_model.get_topic(topic_id)
        # get_topic may return a falsy non-list (None/False/empty) for unknown IDs.
        if not terms:
            continue
        for rank, (word, weight) in enumerate(terms, start=1):
            records.append({
                "topic": topic_id,
                "rank": rank,
                "word": word,
                "weight": weight,
            })
    return pd.DataFrame(records, columns=["topic", "rank", "word", "weight"])


def _representative_docs_frame(topic_model: BERTopic) -> Optional[pd.DataFrame]:
    """Tabulate representative texts per topic; warn and return None if unavailable."""
    try:
        rep_docs = topic_model.get_representative_docs()
        records = []
        for topic_id, texts in rep_docs.items():
            if texts is None:
                continue
            for example_id, text in enumerate(texts, start=1):
                records.append({
                    "topic": topic_id,
                    "example_id": example_id,
                    "representative_text": text,
                })
    except Exception as err:  # best-effort: these examples are optional output
        warnings.warn(
            f"Skipping representative_documents.xlsx: {err!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return None
    return pd.DataFrame(records, columns=["topic", "example_id", "representative_text"])


def _model_parameters(topic_model: BERTopic, min_topic_size: Optional[int]) -> dict:
    """Collect the run settings written to model_parameters.json."""
    umap_model = getattr(topic_model, "umap_model", None)
    hdbscan_model = getattr(topic_model, "hdbscan_model", None)
    if min_topic_size is None:
        # Fall back to the clusterer's setting so the file records the value actually used.
        min_topic_size = getattr(hdbscan_model, "min_cluster_size", None)
    return {
        # NOTE(review): not read from the fitted model; keep in sync with build_topic_model.
        "embedding_model": "all-MiniLM-L6-v2",
        "language": getattr(topic_model, "language", "multilingual"),
        "calculate_probabilities": getattr(topic_model, "calculate_probabilities", True),
        "min_topic_size": min_topic_size,
        "umap": {
            name: getattr(umap_model, name, None)
            for name in ("n_neighbors", "n_components", "min_dist", "metric", "random_state")
        },
        "hdbscan": {
            name: getattr(hdbscan_model, name, None)
            for name in ("min_cluster_size", "metric", "cluster_selection_method")
        },
        "notes": "The final number of analytical themes can be obtained by semantic consolidation of BERTopic outputs."
    }


def save_topic_outputs(
    topic_model: BERTopic,
    df_input: pd.DataFrame,
    text_column: str,
    id_column: Optional[str],
    topics: list[int],
    probs,
    output_dir: Path,
    min_topic_size: Optional[int] = None,
) -> None:
    """
    Save model outputs for peer-review inspection and reproducibility.

    Writes into ``output_dir`` (created if missing):
      - topic_info.xlsx: ``topic_model.get_topic_info()``.
      - document_topic_assignments.xlsx: one row per input segment with its topic,
        plus ``topic_probability`` when ``probs`` reduces to one value per row.
      - topic_keywords.xlsx: ranked (word, weight) pairs per non-outlier topic.
      - representative_documents.xlsx: example texts per topic (skipped with a
        RuntimeWarning if BERTopic cannot provide them).
      - bertopic_model: the model saved via ``topic_model.save``.
      - model_parameters.json: run settings.

    Args:
        topic_model: Fitted BERTopic model.
        df_input: Frame whose rows align with ``topics`` / ``probs``.
        text_column: Column holding the document text.
        id_column: Optional ID column copied into the assignments sheet.
        topics: Topic per document from ``fit_transform`` (-1 marks outliers).
        probs: Probabilities from ``fit_transform``; None, 1-D, or 2-D.
        output_dir: Destination directory.
        min_topic_size: Value to record in model_parameters.json; when None it is
            read from ``topic_model.hdbscan_model.min_cluster_size`` if present.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    topic_model.get_topic_info().to_excel(output_dir / "topic_info.xlsx", index=False)

    doc_df = pd.DataFrame({
        text_column: df_input[text_column].tolist(),
        "topic": topics,
    })
    if id_column:
        doc_df.insert(0, id_column, df_input[id_column].tolist())
    best_prob = _best_topic_probability(probs, len(doc_df))
    if best_prob is not None:
        doc_df["topic_probability"] = best_prob
    doc_df.to_excel(output_dir / "document_topic_assignments.xlsx", index=False)

    _topic_keywords_frame(topic_model, topics).to_excel(
        output_dir / "topic_keywords.xlsx", index=False
    )

    rep_df = _representative_docs_frame(topic_model)
    if rep_df is not None:
        rep_df.to_excel(output_dir / "representative_documents.xlsx", index=False)

    topic_model.save(output_dir / "bertopic_model")

    params = _model_parameters(topic_model, min_topic_size)
    with open(output_dir / "model_parameters.json", "w", encoding="utf-8") as f:
        # default=str deliberately records non-JSON scalars read from sub-models
        # (e.g. numpy ints or a RandomState) as text instead of aborting the save.
        json.dump(params, f, ensure_ascii=False, indent=2, default=str)


def main() -> None:
    """
    Command-line entry point: load review segments, fit BERTopic, and save outputs.

    An invalid ``--min_topic_size`` (below 2) is reported through ``argparse``.
    An input with no usable text exits with a message instead of failing deep
    inside the model fit.
    """
    parser = argparse.ArgumentParser(description="BERTopic pipeline for peer-review materials.")
    parser.add_argument(
        "--input_file",
        type=str,
        default="data/Data_3_segmented_reviews.xlsx",
        help="Path to the cleaned review-segment dataset.",
    )
    parser.add_argument(
        "--text_column",
        type=str,
        default="text_segment",
        help="Name of the text column in the input data.",
    )
    parser.add_argument(
        "--id_column",
        type=str,
        default="segment_id",
        help="Optional segment ID column for downstream merges.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs/bertopic_results",
        help="Directory to save outputs.",
    )
    parser.add_argument(
        "--min_topic_size",
        type=int,
        default=100,
        help="Minimum cluster size for HDBSCAN.",
    )
    parser.add_argument(
        "--random_state",
        type=int,
        default=42,
        help="Random seed passed to UMAP.",
    )

    args = parser.parse_args()

    # HDBSCAN requires min_cluster_size > 1; reject it here with a clear CLI error.
    if args.min_topic_size < 2:
        parser.error("--min_topic_size must be at least 2.")

    input_file = Path(args.input_file)
    output_dir = Path(args.output_dir)
    # An empty string (e.g. --id_column "") disables the ID column.
    id_column = args.id_column or None

    df = load_input_data(
        file_path=input_file,
        text_column=args.text_column,
        id_column=id_column,
    )
    docs = df[args.text_column].tolist()
    if not docs:
        # Nothing to model; stop with an actionable message rather than a fit error.
        raise SystemExit(
            f"No non-empty text found in column '{args.text_column}' of {input_file}."
        )

    topic_model = build_topic_model(
        min_topic_size=args.min_topic_size,
        random_state=args.random_state,
    )
    topics, probs = topic_model.fit_transform(docs)

    save_topic_outputs(
        topic_model=topic_model,
        df_input=df,
        text_column=args.text_column,
        id_column=id_column,
        topics=topics,
        probs=probs,
        output_dir=output_dir,
    )

    print("BERTopic analysis completed successfully.")
    print(f"Input file: {input_file}")
    print(f"Output directory: {output_dir}")


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
