#!/usr/bin/env python
import argparse, os, pod5
import numpy as np
from tqdm import tqdm
from argparse import Namespace
from remora import io, refine_signal_map


def init_parser(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Build the command-line interface and parse the arguments.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            argparse falls back to ``sys.argv[1:]``, so existing callers that
            pass nothing keep their behavior.

    Returns:
        The parsed arguments as an ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(
        description="Perform signal-to-sequence alignment with Remora"
    )

    # Required input/output paths.
    parser.add_argument(
        "--bam", type=str,
        required=True,
        help="Path to a bam file",
    )
    parser.add_argument(
        "--pod5", type=str,
        required=True,
        help="Path to one/more pod5 file(s)/directories",
    )
    parser.add_argument(
        "--kmer", type=str,
        required=True,
        help="Path to a kmer-table file",
    )
    parser.add_argument(
        "--out", type=str,
        required=True,
        help="Path to the output file",
    )

    parser.add_argument(
        "--reverse_signal", action="store_true",
        help="Set if dRNA data",
    )

    # Signal-mapping refinement options.
    parser.add_argument(
        "--alignment_type", type=str,
        choices=["query", "reference", "both"],
        default="query",
        help="Whether to perform query-to-signal, reference-to-signal alignment, or both",
    )
    parser.add_argument(
        "--refinement_iters", type=int,
        default=2,
        help="Number of refinement iterations",
    )
    # NOTE(review): confirm these spellings match the algorithm names accepted
    # by remora's SigMapRefiner (e.g. "Viterbi" vs "viterbi").
    parser.add_argument(
        "--refinement_algo", type=str,
        choices=["viterbi", "dwell_penalty"],
        default="dwell_penalty",
        help="Which refinement algorithm to use",
    )
    parser.add_argument(
        "--half_bandwidth", type=int,
        default=5,
        help="Half bandwidth around the initial signal mapping to search during refinement",
    )
    parser.add_argument(
        "--do_fix_gauge", action="store_true",
        help="Whether to normalize the levels in the kmer table",
    )
    parser.add_argument(
        "--rough_rescale_method", type=str,
        choices=["theil_sen", "least_squares", "none"],
        default="theil_sen",
        help="Which algorithm to use for rough rescaling (only considered if rough rescaling is performed)",
    )
    return parser.parse_args(argv)


def adjust_alignment(
    alignment: np.ndarray,
    signal_len: int,
    tags: dict[str, int],
    reverse_signal: bool
):
    """
    Shift an alignment so that it indexes into the untrimmed signal.

    The shift is derived from the optional "sp", "ts" and "ns" entries of
    `tags`; "sp" and "ts" default to 0 when absent, and when "ns" is absent
    the end position defaults to `signal_len`.

    Returns `alignment` plus a constant offset:
    - reverse_signal is False: offset = sp + ts
    - reverse_signal is True:  offset = signal_len - end, where
      end = sp + ns if "ns" is present, else signal_len
    """
    start_pos = tags.get("sp", 0)
    trimmed_start = tags.get("ts", 0)

    if "ns" in tags:
        signal_end = start_pos + tags["ns"]
    else:
        signal_end = signal_len

    if reverse_signal:
        shift = signal_len - signal_end
    else:
        shift = start_pos + trimmed_start
    return alignment + shift



def _format_mapping(mapping) -> str:
    """Render a sequence-to-signal mapping as a comma-separated string.

    Returns the literal string "None" when no mapping was computed; this is
    the placeholder written to the output file for a mapping type that was
    not requested.
    """
    if mapping is None:
        return "None"
    return ",".join(str(e) for e in mapping.tolist())


def _record_skip(skipped: dict, reason: str, exc: Exception) -> None:
    """Count a skipped read under *reason*, remembering the first exception."""
    entry = skipped.setdefault(reason, [0, exc])
    entry[0] += 1


def run_remora(args: Namespace):
    """Refine signal mappings for every pod5 read and write them to a file.

    For each read in the pod5 dataset, the first alignment with the same read
    id is fetched from the BAM file, a remora read is built from both, and the
    query-to-signal and/or reference-to-signal mapping (depending on
    ``args.alignment_type``) is refined and shifted with ``adjust_alignment``.

    One line is written per successfully processed read, with three
    tab-separated columns: the read id, the query-to-signal mapping and the
    reference-to-signal mapping. Mappings are comma-separated values; a
    mapping that was not requested is written as ``None``.

    Reads that fail at any stage are skipped (best effort). A per-stage
    summary of skipped reads, including the first error seen, is printed to
    stderr at the end.

    Args:
        args: Parsed command-line arguments as returned by ``init_parser``.
    """
    import sys  # local import: only needed for the diagnostics on stderr

    bam_fh = io.ReadIndexedBam(args.bam)

    sig_map_refiner = refine_signal_map.SigMapRefiner(
        kmer_model_filename=args.kmer,
        do_rough_rescale=args.rough_rescale_method != "none",
        scale_iters=args.refinement_iters,
        algo=args.refinement_algo,
        # Previously parsed from the command line but never forwarded.
        half_bandwidth=args.half_bandwidth,
        # (sic) keyword spelling kept exactly as in the original call.
        do_fix_guage=args.do_fix_gauge,
        rough_rescale_method=args.rough_rescale_method,
    )

    want_query = args.alignment_type in ("query", "both")
    want_reference = args.alignment_type in ("reference", "both")
    skipped = {}  # reason -> [count, first exception]

    with pod5.DatasetReader(args.pod5) as pod5_dr, open(args.out, "w") as f:
        for pod5_read in tqdm(pod5_dr.reads()):
            read_id = str(pod5_read.read_id)

            # Every failure must skip the read. Falling through would reuse
            # `bam_read` / `remora_read` from the previous iteration and write
            # another read's mapping under this read id.
            try:
                bam_read = bam_fh.get_first_alignment(read_id)
            except Exception as e:
                _record_skip(skipped, "no alignment found in BAM", e)
                continue

            try:
                remora_read = io.Read.from_pod5_and_alignment(
                    pod5_read,
                    bam_read,
                    reverse_signal=args.reverse_signal,
                )
            except Exception as e:
                _record_skip(skipped, "could not build remora read", e)
                continue

            try:
                signal_len = pod5_read.signal.size
                # NOTE(review): `_trim_tags` is a private attribute of the
                # remora read; presumably it holds the sp/ts/ns trimming
                # tags -- confirm against the remora version in use.
                query_to_signal = None
                if want_query:
                    remora_read.set_refine_signal_mapping(sig_map_refiner, ref_mapping=False)
                    query_to_signal = adjust_alignment(
                        remora_read.query_to_signal,
                        signal_len=signal_len,
                        tags=remora_read._trim_tags,
                        reverse_signal=args.reverse_signal,
                    )

                reference_to_signal = None
                if want_reference:
                    remora_read.set_refine_signal_mapping(sig_map_refiner, ref_mapping=True)
                    reference_to_signal = adjust_alignment(
                        remora_read.ref_to_signal,
                        signal_len=signal_len,
                        tags=remora_read._trim_tags,
                        reverse_signal=args.reverse_signal,
                    )

                outline = "\t".join([
                    read_id,
                    _format_mapping(query_to_signal),
                    _format_mapping(reference_to_signal),
                ])
            except Exception as e:
                _record_skip(skipped, "signal mapping refinement failed", e)
                continue

            # Written outside the try block so that I/O errors on the output
            # file are not mistaken for per-read failures.
            f.write(outline + "\n")

    for reason, (count, first_exc) in skipped.items():
        print(
            f"Skipped {count} read(s): {reason} (first error: {first_exc!r})",
            file=sys.stderr,
        )


def main():
    """Parse the command-line arguments and run the alignment workflow."""
    run_remora(init_parser())

# Run the workflow only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()