"""
Aligned signal processing module

This module combines nanopore sequence-to-signal alignments with the raw signal,
extracting the signal chunks for the bases of interest and preparing them for
visualization.

Usage:
    python script.py <alignment> <pod5> <position> <ref_name> <window> <max_reads> <output>

Example:
    python script.py aligned.parquet data.pod5 5000 chr1 30 1000 signals.pkl
"""

import sys
import pickle
from typing import Optional

import numpy as np
import pandas as pd
import pod5 as p5
from tqdm import tqdm

# Type alias for the processed signal data structure returned by process_signals.
# Each element describes one read as a tuple of (signal_chunks, x_coordinates, read_id):
#   - signal_chunks: list of signal arrays, one per base in the extracted region
#   - x_coordinates: list of x-value arrays, one per signal chunk
#   - read_id: identifier of the read
ProcessedSignals = list[tuple[list[np.ndarray], list[np.ndarray], str]]


def check_contains_full_region(
    ref_name: str,
    ref_start: int,
    ref_end: int,
    target_position: Optional[int] = None,
    window_size: Optional[int] = None,
    target_ref: Optional[str] = None
) -> bool:
    """
    Tell whether an aligned read spans the whole window around a target position.

    The window is [target_position - window_size, target_position + window_size];
    a read qualifies only when its alignment starts at or before the window's
    first position and ends at or after its last position.

    Args:
        ref_name: Name of the reference the read is aligned to
        ref_start: First reference position covered by the read
        ref_end: Last reference position covered by the read
        target_position: Center of the window (optional)
        window_size: Number of bases on each side of the center (optional)
        target_ref: Reference name the read must be aligned to (optional)

    Returns:
        True when the read covers the complete window on the expected reference.
        Also True, without any further check, when target_position or
        window_size is None (filtering disabled).
    """
    # Without a complete window definition there is nothing to filter on
    filtering_disabled = target_position is None or window_size is None
    if filtering_disabled:
        return True

    # A read aligned to a different reference can never cover the window
    on_wrong_reference = target_ref is not None and ref_name != target_ref
    if on_wrong_reference:
        return False

    # Both window edges must lie inside the aligned interval
    starts_before_window = ref_start <= target_position - window_size
    return starts_before_window and ref_end >= target_position + window_size


def standardize(signal: np.ndarray) -> np.ndarray:
    """
    Z-score a signal: subtract its mean and divide by its standard deviation.

    Formula: z = (x - μ) / σ, with μ and σ computed over the whole input.

    Args:
        signal: Signal array to normalize

    Returns:
        The standardized signal. When the standard deviation is 0 (constant
        signal) or NaN, a float array of zeros with the input's shape is
        returned instead.
    """
    spread = np.std(signal)

    # A degenerate spread would produce inf/NaN everywhere; fall back to zeros
    if np.isnan(spread) or spread == 0:
        return np.zeros_like(signal, dtype=float)

    centered = signal - np.mean(signal)
    return centered / spread


def process_signals(
        fishnet_data: pd.DataFrame,
        pod5_data: p5.DatasetReader,
        target_position: int,
        target_ref: str,
        window_size: int,
        max_reads: Optional[int] = None,
    ) -> ProcessedSignals:
    """
    Extract and process signals for a specific genomic region.

    This function performs the core signal processing pipeline:
    1. Filters reads that fully contain the target region
    2. Extracts the relevant signal segments using alignment information
    3. Standardizes signals using z-score normalization
    4. Filters out reads with extreme signal values (outliers)
    5. Splits signals into base-level chunks with interpolated x-coordinates
       for an even x-range for each base

    Args:
        fishnet_data: DataFrame containing alignment information with columns:
            - read_id: Unique identifier for each read
            - ref_name: Reference sequence name
            - ref_start: Start position on reference
            - ref_end: End position on reference (optional; when the column is
              missing or the value is None it is computed as
              ref_start + len(ref_sequence) - 1)
            - ref_sequence: Reference sequence string (only read when ref_end
              has to be computed)
            - ref_to_signal: Array mapping reference positions to signal indices
        pod5_data: POD5 dataset reader for accessing raw signal data
        target_position: Genomic position of interest (center of extraction window)
        target_ref: Reference sequence name to filter for
        window_size: Half-width of the extraction window around target_position
        max_reads: Maximum number of reads to process (None or 0 for unlimited)

    Returns:
        List of tuples, each containing:
            - signal_split: List of signal arrays, one per base in the region
            - x_vals_split: List of coordinate arrays for plotting/interpolation;
              the array for base i spans [i, i + 1]
            - read_id: Read identifier string

    Notes:
        - Signals are reversed ([::-1]) to match standard orientation (working with RNA data)
        - The whole read is standardized; reads whose standardized values inside
          the target region exceed |z| > 5 are filtered out
        - The extraction window is [target_position - window_size, target_position + window_size]
        - Progress is displayed using tqdm progress bar (accepted reads only)
        - Filters reads that don't fully span the target region
        - Reads whose alignment yields fewer than two signal boundaries for the
          window are skipped, since no base-level chunk can be formed from them
        - Reads that cannot be found in the POD5 dataset are skipped and
          reported in a warning after processing
    """
    # Calculate the target region boundaries
    target_start = target_position - window_size
    target_end = target_position + window_size

    signal_chunks = []
    n_reads_processed = 0
    n_reads_missing = 0

    print(f"Target region: {target_ref}:{target_start}-{target_end}")
    print(f"Window size: +/-{window_size} bases around position {target_position}")

    with tqdm(desc="Processing reads") as progress:
        for row in fishnet_data.itertuples():
            # Get reference coordinates; derive ref_end from the sequence length
            # only when it is not supplied, so ref_sequence is not touched needlessly
            ref_start = int(row.ref_start)
            ref_end = getattr(row, "ref_end", None)
            if ref_end is None:
                ref_end = ref_start + len(row.ref_sequence) - 1

            # Check if read contains the full target region
            if not check_contains_full_region(
                row.ref_name, ref_start, ref_end,
                target_position, window_size, target_ref
            ):
                continue

            # Offsets of the target region relative to the alignment start.
            # The slice holds 2 * window_size + 2 split points, which delimit
            # the 2 * window_size + 1 bases of the inclusive window (hence + 2).
            offset_start = target_start - ref_start
            offset_end = target_end - ref_start + 2

            # Extract signal indices for the target region from the alignment
            alignment_region = row.ref_to_signal[offset_start:offset_end]

            # At least two split points are required to delimit one base;
            # indexing an empty slice below would raise IndexError
            if len(alignment_region) < 2:
                continue

            # DatasetReader.get_read is documented to return None for a read id
            # that is not part of the dataset; skip such reads instead of crashing
            read_record = pod5_data.get_read(row.read_id)
            if read_record is None:
                n_reads_missing += 1
                continue

            # Load and reverse signal (POD5 stores RNA signals in 3'-5')
            signal_standardized = standardize(read_record.signal[::-1])

            # Skip reads with outliers beyond 5 standard deviations
            signal_region = signal_standardized[alignment_region[0]:alignment_region[-1]]
            if np.any(np.abs(signal_region) > 5):
                continue

            # Split signal into base-level chunks; the first and last pieces are
            # the signal before and after the target region and are dropped
            signal_split = np.split(signal_standardized, alignment_region)[1:-1]

            # Create normalized x-coordinates for each base (0 to 1, 1 to 2, etc.),
            # linearly spaced over as many points as the base has samples
            x_vals_split_equal_size = [
                np.linspace(i, i + 1, len(chunk))
                for i, chunk in enumerate(signal_split)
            ]

            # Store processed signal data
            signal_chunks.append((signal_split, x_vals_split_equal_size, row.read_id))

            progress.update()
            n_reads_processed += 1

            # Stop if maximum number of reads reached
            if max_reads and n_reads_processed >= max_reads:
                break

    if n_reads_missing:
        print(f"Warning: skipped {n_reads_missing} read(s) not found in the POD5 dataset")

    return signal_chunks
    

def get_start_end_idx(ps: ProcessedSignals, dist_from_center: int) -> tuple[int, int]:
    """
    Calculate start and end indices for a symmetric window around the center.

    This helper function determines which bases to include when extracting
    a sub-region centered around the middle of the processed signals.

    Args:
        ps: Processed signals data structure (must contain at least one read)
        dist_from_center: Number of bases to include on each side of center

    Returns:
        Tuple of (start_index, end_index) defining the extraction window.
        Both indices are clamped to [0, number of bases], so a window wider
        than the available region is truncated at the region's edges.

    Raises:
        IndexError: If ps is empty, because the number of bases is unknown.

    Notes:
        - The number of bases is taken from the first element of ps
        - end_idx is exclusive (Python slice convention)
    """
    if not ps:
        raise IndexError(
            "cannot compute window indices: no processed signals were provided"
        )

    n_bases = len(ps[0][0])
    center_idx = n_bases // 2

    # Clamp to the valid range: a negative start would silently wrap around to
    # the end of the sequence when used as a slice bound
    start_idx = min(max(center_idx - dist_from_center, 0), n_bases)
    end_idx = min(max(center_idx + dist_from_center + 1, 0), n_bases)
    return start_idx, end_idx


def main():
    """
    Main entry point for the signal processing script.

    Command-line arguments:
        1. alignment_path: Path to parquet file with alignment data (from Fishnet)
        2. pod5_path: Path to POD5 file containing raw signal data
        3. position_of_interest: Genomic position to center the extraction window
        4. reference_seq_name: Reference sequence name (e.g., 'chr1')
        5. half_window_size: Half-width of extraction window (bases on each side)
        6. max_reads: Maximum number of reads to process (or 0 for unlimited)
        7. pickle_out: Output path for pickled processed signals

    Output:
        Pickle file containing ProcessedSignals structure:
        List of (signal_chunks, x_coordinates, read_id) tuples

    Exits with status 1 after printing the module usage text when the number
    of command-line arguments is not exactly 7.
    """
    if len(sys.argv) != 8:
        print(__doc__)
        print("\nError: Incorrect number of arguments")
        print(f"Expected 7 arguments, got {len(sys.argv) - 1}")
        sys.exit(1)

    alignment_path = sys.argv[1]
    pod5_path = sys.argv[2]
    position_of_interest = int(sys.argv[3])
    reference_seq_name = sys.argv[4]
    half_window_size = int(sys.argv[5])
    max_reads_arg = int(sys.argv[6])
    pickle_out = sys.argv[7]

    # Convert 0 to None for unlimited reads
    max_reads = None if max_reads_arg == 0 else max_reads_arg

    print("=" * 64)
    print("Preparing sequence-to-signal alignments and signals for plotting")
    print("=" * 64)
    print(f"Alignment file: {alignment_path}")
    print(f"POD5 file: {pod5_path}")
    print(f"Target: {reference_seq_name}:{position_of_interest}")
    print(f"Window: +/-{half_window_size} bases")
    print(f"Max reads: {'unlimited' if max_reads is None else max_reads}")
    print("=" * 64 + "\n")

    print("Loading alignment data...")
    alignment_df = pd.read_parquet(alignment_path)
    print(f"Loaded {len(alignment_df)} alignments")

    print("Opening POD5 dataset...")
    # Use the reader as a context manager so its underlying file handles are
    # released as soon as processing is done, even if processing fails
    with p5.DatasetReader(pod5_path) as pod5_dataset:
        print("Processing aligned signals...")
        processed_signals = process_signals(
            alignment_df,
            pod5_dataset,
            position_of_interest,
            reference_seq_name,
            half_window_size,
            max_reads
        )

    print(f"Saving processed signals to {pickle_out}...")
    with open(pickle_out, "wb") as file:
        pickle.dump(processed_signals, file)

    print("Finished.")


# Run the command-line entry point only when executed as a script, not on import
if __name__ == "__main__":
    main()