import sys
import os
import pysam
import numpy as np
from tqdm import tqdm
from pathlib import Path
import pod5 as p5
import logging

# Configure file-based logging for the whole script: INFO and above are
# appended to "subset_full_data.log" in the current working directory.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(
    filename="subset_full_data.log",
    level=logging.INFO,
    format=_LOG_FORMAT,
)
logger = logging.getLogger(__name__)


def determine_boundary_read_lengths(bam_path: str) -> tuple[int, int]:
    """
    Determine the read lengths to use as boundaries
    to split the bam file by the read length into
    three groups (short, medium, long reads).

    The boundaries are the 33.33rd and 66.67th percentiles
    (rounded to the nearest integer) of the query lengths of
    primary, mapped reads that are at least 100 bp long.
    Secondary and supplementary alignments are ignored so each
    read contributes at most once.

    The reads will then later be split into the following
    groups:
    - short: 100 <= read < boundary1
    - medium: boundary1 <= read < boundary2
    - long: boundary2 <= read

    Args:
        bam_path: Path to the input BAM file.

    Returns:
        A ``(boundary1, boundary2)`` tuple of read lengths.

    Raises:
        ValueError: If the BAM file contains no primary mapped reads of
            at least 100 bp, so no percentiles can be computed.
    """
    logger.info(f"Determining read length boundaries from {bam_path}")

    read_lengths = []
    # The context manager closes the BAM handle even if iteration raises.
    with pysam.AlignmentFile(bam_path, "rb") as bam:
        for read in tqdm(bam, desc="Reading BAM file"):
            # Only consider PRIMARY mapped reads for boundary calculation
            if (
                not read.is_unmapped
                and not read.is_secondary
                and not read.is_supplementary
                and read.query_length >= 100
            ):
                read_lengths.append(read.query_length)

    if not read_lengths:
        # np.percentile on an empty array fails with an opaque IndexError.
        raise ValueError(
            f"No primary mapped reads >= 100bp found in {bam_path}; "
            "cannot determine read length boundaries"
        )

    lengths = np.asarray(read_lengths)
    boundary1 = int(round(np.percentile(lengths, 33.33)))
    boundary2 = int(round(np.percentile(lengths, 66.67)))

    logger.info(f"Read length boundaries: short < {boundary1}, medium < {boundary2}, long >= {boundary2}")
    logger.info(f"Total mapped reads >= 100bp: {len(read_lengths)}")

    return boundary1, boundary2


def get_read_category(read_length: int, boundary1: int, boundary2: int) -> str:
    """
    Map a read length onto its length bucket.

    The thresholds are checked in order: a length below ``boundary1``
    is "short", otherwise a length below ``boundary2`` is "medium",
    and anything else is "long".
    """
    for upper_limit, label in ((boundary1, "short"), (boundary2, "medium")):
        if read_length < upper_limit:
            return label
    return "long"


class SubsetManager:
    """
    Manages multiple subsets for different read counts and categories.

    For every combination of length category ("short", "medium", "long")
    and target read count, a directory ``<category>_<count>`` is created
    under ``output_dir``. It holds a ``subset.bam`` and a ``subset.pod5``.
    A read added for a category is written to every subset of that
    category whose target has not been reached yet. As a result, the
    smaller subsets contain the first reads of the larger ones.
    """

    def __init__(self, output_dir: str, read_counts: list[int], bam):
        """
        Args:
            output_dir: Root directory under which subset directories are created.
            read_counts: Target read counts per subset; duplicates are ignored.
            bam: Open input ``pysam.AlignmentFile`` used as the header template
                for the output BAM files.
        """
        self.output_dir = Path(output_dir)
        # De-duplicate: a repeated count maps to the same dict key, so the writers
        # opened for its first occurrence would be overwritten and never closed.
        self.read_counts = sorted(set(read_counts))
        self.categories = ["short", "medium", "long"]

        # Per-subset state, all indexed as [category][count].
        self.counts = {}         # number of reads written so far
        self.bam_writers = {}    # output BAM writers
        self.pod5_writers = {}   # output POD5 writers
        self.finished = {}       # True once both writers have been closed
        self.written_reads = {}  # set of query names already written

        # Initialize all subsets
        self._initialize_subsets(bam)

    def _initialize_subsets(self, template):
        """Create output directories and open BAM/POD5 writers for all subsets."""
        for category in self.categories:
            self.counts[category] = {}
            self.bam_writers[category] = {}
            self.pod5_writers[category] = {}
            self.finished[category] = {}
            self.written_reads[category] = {}

            for count in self.read_counts:
                self.counts[category][count] = 0

                # Create output directory
                subset_dir = self.output_dir / f"{category}_{count}"
                subset_dir.mkdir(parents=True, exist_ok=True)

                bam_path = subset_dir / "subset.bam"
                self.bam_writers[category][count] = pysam.AlignmentFile(
                    str(bam_path), "wb", template=template
                )

                pod5_path = subset_dir / "subset.pod5"
                self.pod5_writers[category][count] = p5.Writer(str(pod5_path))

                self.finished[category][count] = False
                # A set keeps the duplicate check O(1) per read.
                self.written_reads[category][count] = set()

    def _get_active_counts(self, category: str) -> list[int]:
        """Get list of read counts that still need more reads for this category."""
        return [count for count in self.read_counts
                if self.counts[category][count] < count]

    def _close_subset(self, category: str, count: int):
        """Close both writers of one subset; calling it again is a no-op."""
        if self.finished[category][count]:
            return
        self.bam_writers[category][count].close()
        self.pod5_writers[category][count].close()
        self.finished[category][count] = True

    def add_read(self, read, pod5_read, category: str):
        """
        Add a read to all applicable subsets for the given category.

        Args:
            read: Primary alignment record to write to the subset BAM files.
            pod5_read: POD5 record for ``read``; converted with ``to_read()``.
            category: One of ``self.categories``.

        Returns:
            True if at least one subset of ``category`` was still open.
            False if all of them are full, or if ``read`` is a secondary or
            supplementary alignment.
        """
        active_counts = self._get_active_counts(category)

        if not active_counts:
            return False  # No more subsets need reads for this category

        if read.is_secondary or read.is_supplementary:
            logger.warning(
                f"Secondary/supplementary alignment {read.query_name} slipped through; skipping"
            )
            return False  # Don't write non-primary alignments

        # to_read() copies the signal, so convert at most once and share the
        # result across all subsets that need this read.
        pod5_record = None
        for count in active_counts:
            seen = self.written_reads[category][count]
            if read.query_name in seen:
                # Duplicate record: skip BAM too so BAM and POD5 hold the same reads.
                continue
            if pod5_record is None:
                pod5_record = pod5_read.to_read()

            seen.add(read.query_name)
            self.pod5_writers[category][count].add_read(pod5_record)
            self.bam_writers[category][count].write(read)
            self.counts[category][count] += 1

            # Close writers if target reached
            if self.counts[category][count] >= count:
                self._close_subset(category, count)
                logger.info(f"Completed {category}_{count} with {count} reads")

        return True

    def is_complete(self) -> bool:
        """Check if all subsets are complete."""
        return not any(self._get_active_counts(category) for category in self.categories)

    def cleanup(self):
        """Close any remaining open file handles (safe to call more than once)."""
        for category in self.categories:
            for count in self.read_counts:
                self._close_subset(category, count)

    def get_status(self) -> dict:
        """Get current/target counts and completion flag for every subset."""
        status = {}
        for category in self.categories:
            status[category] = {}
            for count in self.read_counts:
                current_count = self.counts[category][count]
                status[category][count] = {
                    'current': current_count,
                    'target': count,
                    'complete': current_count >= count
                }
        return status

def _passes_read_filters(read, pod5_read_ids) -> bool:
    """Return True for a primary, mapped alignment >= 100 bp whose read id is in ``pod5_read_ids``."""
    return not (
        read.is_unmapped
        or read.is_secondary
        or read.is_supplementary
        or read.query_length < 100
        or read.query_name not in pod5_read_ids
    )


def _log_subset_progress(subset_manager, read_counts, processed_reads, skipped_reads):
    """Log processed/skipped totals and every subset that is still filling up."""
    status = subset_manager.get_status()
    logger.info(f"Processed {processed_reads} reads, skipped {skipped_reads}")
    for cat in subset_manager.categories:
        active = [f"{count}({status[cat][count]['current']}/{count})"
                  for count in read_counts if not status[cat][count]['complete']]
        if active:
            logger.info(f"  {cat}: {', '.join(active)}")


def _log_final_summary(subset_manager, read_counts, processed_reads, skipped_reads):
    """Log the final current/target count of every subset."""
    logger.info(f"Subsetting completed. Processed {processed_reads} reads, skipped {skipped_reads}")
    final_status = subset_manager.get_status()

    for category in subset_manager.categories:
        logger.info(f"{category.capitalize()} reads:")
        for count in read_counts:
            current = final_status[category][count]['current']
            target = final_status[category][count]['target']
            status_str = "✓" if current >= target else "✗"
            logger.info(f"  {count}: {current}/{target} {status_str}")


def subset_nanopore_data(
    bam_path: str,
    pod5_path: str,
    output_dir: str,
    read_counts: "list[int] | None" = None,
):
    """
    Subset nanopore sequencing data into multiple subsets with different read counts and lengths.

    Args:
        bam_path: Path to input BAM file
        pod5_path: Path to input POD5 file or directory containing POD5 files
        output_dir: Output directory for subsets
        read_counts: List of read counts for subsets. Defaults to
            ``[100, 1000, 10000, 100000]`` when omitted or None.
    """
    # None instead of a list literal avoids a shared mutable default argument.
    if read_counts is None:
        read_counts = [100, 1000, 10000, 100000]

    logger.info("Starting nanopore data subsetting")

    # Determine read length boundaries
    boundary1, boundary2 = determine_boundary_read_lengths(bam_path)

    processed_reads = 0
    skipped_reads = 0

    with p5.DatasetReader(pod5_path) as pod5_dataset:
        # A set makes the per-record membership test O(1); a list made it
        # O(#pod5 reads) for every BAM record.
        pod5_read_ids = set(pod5_dataset.read_ids)

        # Process BAM file
        logger.info("Processing BAM file and creating subsets")
        with pysam.AlignmentFile(bam_path, "rb") as bam:
            subset_manager = SubsetManager(output_dir, read_counts, bam)
            try:
                for read in tqdm(bam, desc="Processing reads"):
                    processed_reads += 1

                    pod5_read = None
                    if _passes_read_filters(read, pod5_read_ids):
                        pod5_read = pod5_dataset.get_read(read.query_name)

                    if pod5_read is None:
                        skipped_reads += 1
                    else:
                        category = get_read_category(read.query_length, boundary1, boundary2)
                        subset_manager.add_read(read, pod5_read, category)

                        if subset_manager.is_complete():
                            logger.info("All subsets completed")
                            break

                    # Checked for every record (kept or skipped) so no progress line is missed.
                    if processed_reads % 10000 == 0:
                        _log_subset_progress(subset_manager, read_counts,
                                             processed_reads, skipped_reads)
            finally:
                # Close every subset writer even if processing raised.
                subset_manager.cleanup()

    # Final status report
    _log_final_summary(subset_manager, read_counts, processed_reads, skipped_reads)


# Command-line entry point, e.g.:
#   python <script> my_reads.bam pod5_folder output_folder 100,1000,5000
if __name__ == "__main__":
    args = sys.argv
    usage = (
        f"Invalid arguments.\n"
        f"Usage: {args[0]} <bam_file> <pod5_dir> <output_dir> <read_counts>\n"
        f"Example:\n"
        f"  {args[0]} my_reads.bam pod5_folder output_folder 100,1000,5000"
    )

    if len(args) != 5:
        print(usage, file=sys.stderr)
        sys.exit(1)

    bam_file = args[1]
    pod5_dir = args[2]
    output_dir = args[3]
    try:
        # Ignore empty items so "100,1000," is accepted.
        read_counts = [int(i) for i in args[4].split(",") if i.strip()]
    except ValueError:
        read_counts = []
    if not read_counts:
        print(usage, file=sys.stderr)
        sys.exit(1)

    subset_nanopore_data(bam_file, pod5_dir, output_dir, read_counts)