import sys
import pandas as pd
import numpy as np
from tqdm import tqdm

def compare_alignments(fn_algn: np.ndarray, rm_algn: np.ndarray) -> list[float]:
    """
    Compare two 1D alignments position by position.

    Both inputs are cast to int64. The absolute per-position differences are
    normalised by the total signal span, i.e. the larger of the two final
    values minus the smaller of the two first values.

    Returns, in order:
    1. norm_mean_diff: mean absolute difference / total signal span
    2. norm_max_diff: largest absolute difference / total signal span
    3. pct_identical_boundaries: fraction (0..1) of positions where both
       alignments hold the same value

    Raises:
        ValueError: if the alignments differ in length or are empty.

    If the total signal span is zero, the two normalised values follow NumPy
    division semantics (nan or inf, with a RuntimeWarning) instead of raising.
    """
    # np.asarray lets plain sequences through as well as ndarrays.
    fn_algn = np.asarray(fn_algn).astype(np.int64)
    rm_algn = np.asarray(rm_algn).astype(np.int64)

    if len(fn_algn) != len(rm_algn):
        raise ValueError("Alignment length mismatch")
    if len(fn_algn) == 0:
        # Without this guard the [-1] / [0] lookups below fail with an
        # unhelpful IndexError.
        raise ValueError("Alignments are empty")

    total_signal_span = max(fn_algn[-1], rm_algn[-1]) - min(fn_algn[0], rm_algn[0])

    # One vectorized pass instead of a Python-level loop over the elements.
    boundary_diffs = np.abs(fn_algn - rm_algn)

    # The divisions stay in NumPy so a zero span yields nan/inf rather than
    # a ZeroDivisionError; float() only unwraps the NumPy scalar afterwards.
    return [
        float(boundary_diffs.mean() / total_signal_span),
        float(boundary_diffs.max() / total_signal_span),
        np.count_nonzero(boundary_diffs == 0) / boundary_diffs.size,
    ]


def main():
    """
    Compare the two alignments of every read in a parquet file.

    Command line: <in_parquet> <out_file>

    Each row of the input is expected to provide ``read_id``,
    ``alignment_fishnet`` and ``alignment_remora``. The two alignments are
    passed to ``compare_alignments``; rows for which that call raises are
    reported on stdout and left out. The output parquet file holds the
    columns ``norm_mean_diff``, ``norm_max_diff``,
    ``pct_identical_boundaries`` and ``read_id``.
    """
    if len(sys.argv) < 3:
        sys.exit(f"Usage: {sys.argv[0]} <in_parquet> <out_file>")
    in_parquet = sys.argv[1]
    out_file = sys.argv[2]

    df = pd.read_parquet(in_parquet)

    columns = [
        "norm_mean_diff",
        "norm_max_diff",
        "pct_identical_boundaries",
    ]
    read_ids = []
    results = []

    # itertuples() has no length, so pass total explicitly to let tqdm show
    # a percentage and an ETA.
    for row in tqdm(df.itertuples(), total=len(df)):
        try:
            result = compare_alignments(
                row.alignment_fishnet,
                row.alignment_remora
            )
        except Exception as exc:
            # Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit still stop the run instead of skipping one row.
            print(f"Skipped {row.read_id}: {exc}")
            continue

        read_ids.append(row.read_id)
        results.append(result)

    # reshape keeps the array two-dimensional even when no row was compared,
    # so the frame still gets its three metric columns.
    metrics = np.asarray(results, dtype=np.float64).reshape(-1, len(columns))

    data = pd.DataFrame(metrics, columns=columns)
    data["read_id"] = read_ids

    data.to_parquet(out_file)


# Run the comparison only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()