

import argparse
import gzip
import sys

import numpy
import pandas
import pyBigWig

def _build_arg_parser():
    """Return the argparse parser describing this script's command line."""
    # NOTE(review): 'prog' looks copied from another tool; it is left as-is
    # because the real script name is not visible here -- confirm and rename.
    parser = argparse.ArgumentParser(
        prog='cnv_gtf_annotation',
        description='for each locus in a BED file (or region list), '
                    'count the bases whose bigwig score falls into '
                    'each user-defined score bin.')
    # Example of a finer-grained binning that can be passed to --score-bins:
    # "-20:-6,-6:-5,-5:-4,-4:-3,-3:-2,-2:-1,-1:0,0:1,1:2,2:3,3:4,4:5,5:6,6:12"
    parser.add_argument('--score-bins', type=str,
                        default="-20:2.27,2.27:12",
                        help="score bins to place nt counts into. "
                             "format: begin1:end1,begin2:end2, ..")
    parser.add_argument('--score-bin-lower-inclusive', action='store_true',
                        default=False,
                        help='allow values equaling bin lower bound to be '
                             'included in bin.')
    parser.add_argument('--score-bin-upper-inclusive', action='store_true',
                        default=False,
                        help='allow values equaling bin upper bound to be '
                             'included in bin.')
    parser.add_argument('--add-metric-cols', action='store_true',
                        default=False,
                        help='add columns for min, max, mean, median and '
                             'sd of scores at locus')
    parser.add_argument('--input-is-region-list', action='store_true',
                        default=False,
                        help='Instead of bed, each line of input file is in format: '
                             'chrom:start-end')
    parser.add_argument('--input-bed-locus-colnum', action='store', type=int,
                        default=None,
                        help='if defined by user, column number in input '
                             'BED file to utilize as locus name.')
    parser.add_argument('in_bigwig', type=str,
                        help='input bigwig file with scores.')
    parser.add_argument('in_bed', type=str,
                        help='input BED file (or region list).')
    parser.add_argument('out_tsv', type=str,
                        help="output tsv with bp counts per score bin across loci.")
    return parser


def _parse_score_bins(spec):
    """Parse a 'begin1:end1,begin2:end2' string into bin names and bounds.

    Returns a tuple (bin_names, score_bins): bin_names lists the
    '<begin>to<end>' names in the order given, and score_bins maps each
    name to its (lower, upper) bounds as floats.
    """
    bin_names = []
    score_bins = {}
    for bin_spec in spec.split(","):
        bounds = bin_spec.split(":")
        lower_str = bounds[0]
        upper_str = bounds[1]
        bin_name = lower_str + "to" + upper_str
        score_bins[bin_name] = (float(lower_str), float(upper_str))
        bin_names.append(bin_name)
    return bin_names, score_bins


def _read_loci(path, is_region_list, name_colnum):
    """Read loci from a tab-delimited BED file or a region list.

    Args:
        path: input file; a name ending in '.gz' is read through gzip.
        is_region_list: when true, the first column of each line is
            already a 'chrom:start-end' locus; otherwise columns 1-3 are
            read as chrom, start0, end and the locus string is built with
            start0 + 1.
        name_colnum: 1-based column holding a display name for the locus,
            or None.

    Returns:
        (loci, names): loci is the list of locus strings in file order
        (duplicates kept); names maps locus string to the value in the
        name column (the last occurrence wins) and is empty when
        name_colnum is None.

    Blank lines and lines starting with '#' are skipped.
    """
    loci = []
    names = {}
    # Text mode for both openers so every line is str; gzip defaults to bytes.
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rt") as loci_fh:
        for line in loci_fh:
            line = line.rstrip()
            if not line or line.startswith("#"):
                continue
            data = line.split("\t")
            if is_region_list:
                locus = data[0]
                # Parsed only to fail early on a malformed region string.
                locus_to_chromstart0end(locus)
            else:
                chrom = data[0]
                start = str(int(data[1]) + 1)
                end = data[2]
                locus = chrom + ":" + start + "-" + end
            loci.append(locus)
            if name_colnum is not None:
                names[locus] = data[name_colnum - 1]
    return loci, names


def _count_in_bin(values, lower, upper, lower_inclusive, upper_inclusive):
    """Count entries of the numpy array `values` that fall between the bounds.

    Each bound is compared strictly unless its *_inclusive flag is set.
    """
    above = (values >= lower) if lower_inclusive else (values > lower)
    below = (values <= upper) if upper_inclusive else (values < upper)
    return int(numpy.count_nonzero(above & below))


def main(userargs):
    """Count, per locus, how many bigwig scores fall into each score bin.

    Loci are read from a BED file (or region list), de-duplicated and
    sorted; for each one the scores are fetched from the bigwig file, NaN
    entries are dropped, and the remaining values are counted per bin.
    The resulting table is written to the output path as tab-separated
    text with a header row.

    Args:
        userargs: list of command-line argument strings (without the
            program name).

    Returns:
        None.

    Raises:
        SystemExit: with status 1 when the input is not a bigwig file, or
            from argparse on invalid arguments.
    """
    args = _build_arg_parser().parse_args(userargs)

    bw = pyBigWig.open(args.in_bigwig)
    try:
        # ensure that file is bigwig
        if not bw.isBigWig():
            print("ERROR : bigwig file not bigwig-formated : '" +
                  args.in_bigwig + "'")
            sys.exit(1)

        bin_names, score_bins = _parse_score_bins(args.score_bins)
        locus_str_list, locus_names = _read_loci(args.in_bed,
                                                 args.input_is_region_list,
                                                 args.input_bed_locus_colnum)

        # unique loci only, in sorted order
        loci = sorted(set(locus_str_list))
        nrow = len(loci)

        # The index always holds the 'chrom:start-end' string, even when
        # the visible 'locus' column is swapped for user-supplied names.
        df = pandas.DataFrame({"locus": loci}, index=loci)
        if args.input_bed_locus_colnum is not None:
            df["locus"] = [locus_names.get(locus) for locus in loci]

        # add bin count columns, then optional metric columns
        for bin_name in bin_names:
            df[bin_name] = numpy.zeros(nrow)
        if args.add_metric_cols:
            for col in ['nbp', 'min', 'max', 'median', 'mean', 'sd']:
                df[col] = numpy.zeros(nrow)

        # Chromosome table is fixed for the file; fetch it once.
        bw_chroms = bw.chroms()

        for locus in df.index:
            (chrom, start0, end) = locus_to_chromstart0end(locus)

            # add 'chr' prefix when missing
            if not chrom.startswith('chr'):
                chrom = "chr" + chrom

            # loci on chromosomes absent from the bigwig keep all-zero rows
            if chrom not in bw_chroms:
                continue
            print(chrom, start0, end)
            values = bw.values(chrom, start0, end, numpy=True)

            # remove nan-values from array
            values = values[~numpy.isnan(values)]

            for bin_name in bin_names:
                (lower, upper) = score_bins[bin_name]
                df.loc[locus, bin_name] = float(
                    _count_in_bin(values, lower, upper,
                                  args.score_bin_lower_inclusive,
                                  args.score_bin_upper_inclusive))

            # Summary metrics do not depend on the bin, so compute them once
            # per locus; skipped when no non-NaN value remains.
            if args.add_metric_cols and len(values) > 0:
                df.loc[locus, 'nbp'] = end - start0
                df.loc[locus, "min"] = values.min()
                df.loc[locus, "max"] = values.max()
                df.loc[locus, "median"] = numpy.median(values)
                df.loc[locus, "mean"] = numpy.mean(values)
                df.loc[locus, "sd"] = numpy.std(values)
    finally:
        bw.close()

    # write pandas df to file
    df.to_csv(path_or_buf=args.out_tsv,
              sep="\t",
              header=True,
              index=False)

    return

def locus_to_chromstart0end(locus):
    """Split a 'chrom:start-end' locus string into (chrom, start0, end).

    chrom is the text before the first ':'. start and end are parsed as
    integers from the second ':'-separated field, split on '-'. The
    returned start0 is start minus one; end is returned unchanged.

    Raises:
        IndexError: if the ':' or '-' separator is missing.
        ValueError: if start or end is not an integer.
    """
    fields = locus.split(":")
    chrom, span = fields[0], fields[1]
    bounds = span.split("-")
    first, last = int(bounds[0]), int(bounds[1])
    return (chrom, first - 1, last)

# Script entry point: hand the command-line arguments (without the program
# name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
