import umap, pickle, sys
import numpy as np 
import pandas as pd

# Seed passed to both the UMAP reducer and the DataFrame subsampling so that
# repeated runs use the same random state.
RANDOM_STATE = 42

def calc_umap(features: np.ndarray, target: np.ndarray) -> pd.DataFrame:
    """
    Embed ``features`` into two dimensions with UMAP.

    The reducer is seeded with the module-level RANDOM_STATE. The returned
    DataFrame has the two embedding coordinates in columns UMAP1 and UMAP2
    and the values of ``target`` in column sample.

    NOTE(review): ``target`` is presumably one label per row of
    ``features`` -- confirm against the caller.
    """
    embedding = umap.UMAP(
        n_components=2,
        random_state=RANDOM_STATE,
    ).fit_transform(features)
    # Attach the labels next to the coordinates in a single frame.
    return pd.DataFrame(embedding, columns=["UMAP1", "UMAP2"]).assign(sample=target)

def load_data(
    interp_path_a: str,
    interp_path_b: str,
    sample: str,
    pickle_out: str
) -> None:
    """
    Load data, calculating UMAP and saving the results to pickle

    Both parquet files are read, the larger table is randomly subsampled to
    the row count of the smaller one, the rows are labelled
    "<SAMPLE>_mod" (file A) or "<SAMPLE>_unmod" (file B), and the combined
    feature matrix is passed to calc_umap. The resulting DataFrame is
    pickled to ``pickle_out``.

    Parameters:
        interp_path_a: Path to the parquet file labelled as modified.
        interp_path_b: Path to the parquet file labelled as unmodified.
        sample: Sample name; upper-cased and used as the label prefix.
        pickle_out: Path the pickled UMAP DataFrame is written to.

    Raises:
        ValueError: If the last data column is not named "dwell_<integer>",
            so the number of bases cannot be determined.
    """
    data_mod = pd.read_parquet(interp_path_a)
    data_unmod = pd.read_parquet(interp_path_b)

    # Balance both classes. When the tables already have the same number of
    # rows the else branch runs and only shuffles the unmodified rows.
    if data_mod.shape[0] > data_unmod.shape[0]:
        print(f"Subsetting modified data from {data_mod.shape[0]} to {data_unmod.shape[0]} reads...")
        data_mod = data_mod.sample(data_unmod.shape[0], random_state=RANDOM_STATE)
    else:
        print(f"Subsetting unmodified data from {data_unmod.shape[0]} to {data_mod.shape[0]} reads...")
        data_unmod = data_unmod.sample(data_mod.shape[0], random_state=RANDOM_STATE)

    data_mod["sample"] = f"{sample.upper()}_mod"
    data_unmod["sample"] = f"{sample.upper()}_unmod"

    data = pd.concat([data_mod, data_unmod]).reset_index(drop=True)
    # Release the per-sample frames before the feature matrix is built.
    del data_mod, data_unmod

    # "sample" was appended last, so the column before it is the last column
    # read from the parquet files; its "dwell_<i>" index gives the number of
    # bases. The prefix is removed explicitly: str.lstrip("dwell_") would
    # strip a *set of characters*, not the prefix.
    dwell_prefix = "dwell_"
    last_data_col = data.columns[-2]
    if not last_data_col.startswith(dwell_prefix):
        raise ValueError(
            f"Expected the last data column to start with '{dwell_prefix}', "
            f"got '{last_data_col}'"
        )
    n_bases = int(last_data_col[len(dwell_prefix):]) + 1

    # Everything that is not a numeric feature: metadata, the label and the
    # per-position base columns.
    cols_to_drop = [
        "read_id", "start_index_on_read", "region_of_interest", "sample"
    ] + [f"base_{i}" for i in range(n_bases)]

    features = data.drop(cols_to_drop, axis=1).to_numpy()
    target = data["sample"].to_numpy()

    print("Calculating UMAP...")
    umap_res = calc_umap(features, target)

    print(f"Saving UMAP results to {pickle_out}...")
    with open(pickle_out, "wb") as f:
        pickle.dump(umap_res, f)


def main():
    """
    Main entry point for the UMAP processing script.

    Command-line arguments:
        1. reformated_path_a: Path to parquet file with interpolated data for sample A
        2. reformated_path_b: Path to parquet file with interpolated data for sample B
        3. sample: Sample name used to label the reads and in the log output
        4. pickle_out: Output path for pickled UMAP results

    Output:
        Pickle file containing a Pandas dataframe with columns UMAP1, UMAP2 and sample.

    Exits with status 1 after printing this usage text when the number of
    command-line arguments is not exactly four.
    """
    if len(sys.argv) != 5:
        # Use this function's docstring as the usage text; a bare __doc__
        # would refer to the module docstring instead.
        print(main.__doc__)
        print("\nError: Incorrect number of arguments")
        print(f"Expected 4 arguments, got {len(sys.argv) - 1}")
        sys.exit(1)

    reformated_path_a = sys.argv[1]
    reformated_path_b = sys.argv[2]
    sample = sys.argv[3]
    pickle_out = sys.argv[4]

    print("=" * 64)
    print(f"Performing UMAP for sample {sample}")
    print("=" * 64)
    print(f"Interpolated file A: {reformated_path_a}")
    print(f"Interpolated file B: {reformated_path_b}")
    print(f"Output file: {pickle_out}")
    print("=" * 64 + "\n")

    load_data(
        reformated_path_a,
        reformated_path_b,
        sample,
        pickle_out
    )

    print("Finished.")
    print("=" * 64 + "\n")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__=="__main__":
    main()