

import argparse
import gzip
import sys

import numpy
import pandas
import pyBigWig

def _build_arg_parser():
    """Build the command-line parser used by main()."""
    # NOTE(review): prog/description appear to be copied from a different
    # tool (they talk about GTF/CNV annotation, while this script counts
    # bigWig scores per bin). Left untouched so --help output is unchanged;
    # confirm the intended wording with the owner.
    parser = argparse.ArgumentParser(
        prog='cnv_gtf_annotation',
        description='using a user-defined '
                    'GTF file, get gene-level '
                    'overlaps for CNVs in '
                    'penncnv cnv file.')
    # NOTE(review): --score-bins and the two --score-bin-*-inclusive flags are
    # accepted but never read by main(); bins are always the unit-wide
    # intervals between score_min and score_max.
    parser.add_argument('--score-bins', type=str,
                        default="-20:2.27,2.27:12",
                        help="score bins to place nt counts into. "
                             "format: begin1:end1,begin2:end2, ..")
    parser.add_argument('--score-bin-lower-inclusive', action='store_true',
                        default=False,
                        help='allow values equaling bin lower bound to be '
                             'included in bin.')
    parser.add_argument('--score-bin-upper-inclusive', action='store_true',
                        default=False,
                        help='allow values equaling bin upper bound to be '
                             'included in bin.')
    parser.add_argument('--add-metric-cols', action='store_true',
                        default=False,
                        help='add columns for min, max, mean, median and '
                             'sd of scores at locus')
    parser.add_argument('--input-is-region-list', action='store_true',
                        default=False,
                        help='Instead of bed, each line of input file is in format: '
                             'chrom:start-end')
    parser.add_argument('--input-bed-locus-colnum', action='store', type=int,
                        default=4,
                        help='if defined by user, column number in input '
                             'BED file to utilize as locus name.')
    parser.add_argument('--col-suffix', type=str, default="",
                        help="string to add to end of columns. "
                             "[default %(default)s]")
    parser.add_argument('in_bigwig', type=str,
                        help='input bigwig file with scores.')
    parser.add_argument('score_min', type=int,
                        help='minimum score integer')
    parser.add_argument('score_max', type=int,
                        help='maximum score integer')
    parser.add_argument('in_bed', type=str,
                        help='input BED file (or region list).')
    parser.add_argument('out_tsv', type=str,
                        help="output tsv with bp counts per score bin across loci.")
    return parser


def _read_loci(path, is_region_list, name_colnum):
    """Group the intervals listed in *path* by locus name.

    Returns a dict mapping locus name to a set of 1-based
    "chrom:start-end" strings. For BED input the name comes from column
    *name_colnum* (1-based); for a region list each region is its own name.
    Files whose path contains ".gz" are read through gzip.
    """
    opener = gzip.open if ".gz" in path else open
    loci = {}
    # "rt" so gzip input yields str lines, just like the plain-text branch
    with opener(path, "rt") as loci_fh:
        for line in loci_fh:
            line = line.rstrip()
            # blank lines and '#' comments carry no interval
            if not line or line.startswith("#"):
                continue
            data = line.split("\t")
            if is_region_list:
                locus = data[0]
                locusname = locus
            else:
                # BED start is 0-based; locus strings are 1-based inclusive
                start = str(int(data[1]) + 1)
                locus = data[0] + ":" + start + "-" + data[2]
                locusname = data[name_colnum - 1]
            loci.setdefault(locusname, set()).add(locus)
    return loci


def _count_per_bin(values, score_min, n_bins):
    """Count floor(values) per unit-wide bin starting at *score_min*.

    Returns an integer array of length *n_bins*; element k is the number of
    values v with floor(v) == score_min + k. Values outside the covered
    range (including +/-inf) are left out instead of raising.
    """
    floored = numpy.floor(values)
    in_range = (floored >= score_min) & (floored < score_min + n_bins)
    shifted = floored[in_range].astype(numpy.intp) - score_min
    return numpy.bincount(shifted, minlength=n_bins)


def main(userargs):
    """Count bigWig score values per integer bin for each locus.

    Parses *userargs* (a list of command-line tokens, without the program
    name), reads intervals from the BED file / region list, fetches the
    per-base scores for each interval from the bigWig file and writes a
    tab-separated table with one row per locus name. Columns are "locus",
    one "<i>to<i+1>" count column per integer i in [score_min, score_max),
    and, with --add-metric-cols, nbp, nbp_nan, min, max, median, mean, sd.
    All column names except "locus" get --col-suffix appended.

    Scores whose floor lies outside [score_min, score_max) are not binned
    (a warning with their total is written to stderr) but still contribute
    to the metric columns.

    Exits with status 1 if the input is not a bigWig file, and with an
    argparse usage error if score_max is not greater than score_min.
    """
    parser = _build_arg_parser()
    args = parser.parse_args(userargs)
    if args.score_max <= args.score_min:
        parser.error("score_max must be greater than score_min")

    suffix = args.col_suffix
    n_bins = args.score_max - args.score_min
    bin_names = [str(i) + "to" + str(i + 1)
                 for i in range(args.score_min, args.score_max)]
    stat_names = ['min', 'max', 'median', 'mean', 'sd']

    # init filehandle to bigwig file, ensure that file is bigwig
    bw = pyBigWig.open(args.in_bigwig)
    try:
        if not bw.isBigWig():
            print("ERROR : bigwig file not bigwig-formated : '" +
                  args.in_bigwig + "'")
            sys.exit(1)

        # load every interval into memory, grouped by locus name
        loci = _read_loci(args.in_bed,
                          args.input_is_region_list,
                          args.input_bed_locus_colnum)
        locus_names = sorted(loci)
        nrow = len(locus_names)

        # Accumulate into plain arrays and build the DataFrame once at the
        # end; this avoids a per-cell DataFrame update for every bin.
        counts = numpy.zeros((nrow, n_bins), dtype=numpy.int64)
        nbp_col = numpy.zeros(nrow, dtype=numpy.int64)
        nbp_nan_col = numpy.zeros(nrow, dtype=numpy.int64)
        # summary statistics are real-valued, so keep them in float columns
        stats = dict((name, numpy.zeros(nrow, dtype=numpy.float64))
                     for name in stat_names)

        # chromosome table is constant for the file: fetch it once
        bw_chroms = bw.chroms()
        n_out_of_range = 0

        for row, locus in enumerate(locus_names):
            chunks = []
            nbp = 0
            nbp_nan = 0

            # sorted() gives a reproducible processing/print order
            for sublocus in sorted(loci[locus]):
                (chrom, start0, end) = locus_to_chromstart0end(sublocus)

                # Prefer the 'chr'-prefixed name; fall back to the name as
                # given when only that one exists in the bigwig.
                if not chrom.startswith('chr') and "chr" + chrom in bw_chroms:
                    chrom = "chr" + chrom

                # skip chrom if not in bigwig
                if chrom not in bw_chroms:
                    continue

                # get values at locus
                print(chrom, start0, end)
                values = bw.values(chrom, start0, end, numpy=True)
                nbp += end - start0

                # drop positions without a score, remembering how many
                valid = values[~numpy.isnan(values)]
                nbp_nan += len(values) - len(valid)
                chunks.append(valid)

                bin_counts = _count_per_bin(valid, args.score_min, n_bins)
                counts[row] += bin_counts
                n_out_of_range += len(valid) - int(bin_counts.sum())

            if args.add_metric_cols:
                # base counts are meaningful even when every base is NaN
                nbp_col[row] = nbp
                nbp_nan_col[row] = nbp_nan
                if chunks:
                    values_all = numpy.concatenate(chunks).astype(numpy.float64)
                else:
                    values_all = numpy.array([])
                # statistics stay at 0 when the locus has no scored bases
                if values_all.size > 0:
                    stats['min'][row] = values_all.min()
                    stats['max'][row] = values_all.max()
                    stats['median'][row] = numpy.median(values_all)
                    stats['mean'][row] = numpy.mean(values_all)
                    stats['sd'][row] = numpy.std(values_all)
    finally:
        bw.close()

    if n_out_of_range > 0:
        sys.stderr.write("WARNING : " + str(n_out_of_range) +
                         " values fell outside [score_min, score_max)" +
                         " and were not binned\n")

    # assemble the output table: locus, bin counts, then optional metrics
    columns = {"locus": locus_names}
    for k, bin_name in enumerate(bin_names):
        columns[bin_name + suffix] = counts[:, k]
    if args.add_metric_cols:
        columns['nbp' + suffix] = nbp_col
        columns['nbp_nan' + suffix] = nbp_nan_col
        for name in stat_names:
            columns[name + suffix] = stats[name]
    df = pandas.DataFrame(columns)

    # write pandas df to file
    df.to_csv(path_or_buf=args.out_tsv,
              sep="\t",
              header=True,
              index=False)

    return

def locus_to_chromstart0end(locus):
    """Parse a "chrom:start-end" locus string.

    The start in *locus* is treated as 1-based; it is returned 0-based
    (start - 1), while the end is returned unchanged.

    Args:
        locus: string of the form "chrom:start-end".

    Returns:
        Tuple (chrom, start0, end) with chrom as str and start0/end as int.

    Raises:
        IndexError: if *locus* has no ':' or the coordinate part has no '-'.
        ValueError: if start or end is not an integer literal.
    """
    # Split on the LAST colon so contig names that themselves contain ':'
    # (e.g. "HLA-A*01:01:01:01") are kept whole.
    chrom_startend = locus.rsplit(":", 1)
    chrom = chrom_startend[0]
    startend = chrom_startend[1].split("-")
    start0 = int(startend[0]) - 1
    end = int(startend[1])
    return (chrom, start0, end)

if __name__ == "__main__":
    # Script entry point: pass everything after the program name to main().
    main(sys.argv[1:])
