
import sys
import argparse
import gzip

def main(userargs):
    """Summarise user-supplied gene scores for each CNV.

    Reads the CNV/gene table written by cnv_gtf_annotation
    (``locus<TAB>GENE1,GENE2,...``; the gene column is absent for CNVs that
    overlap no genes) and a whitespace-separated ``gene score`` file, then
    prints a TSV to stdout: a header row followed by one row per CNV with

    1. the locus,
    2. the number of genes the CNV overlaps,
    3. how many of those genes have a score,
    4. the ``--score-func`` (max/min/mean/median) of those scores, or ``NA``
       when none has a score,
    5. with ``--score-bins``, the number of scored genes in each bin. Bin
       edges are excluded unless ``--score-bin-lower-inclusive`` /
       ``--score-bin-upper-inclusive`` are given.

    Either input file may be gzip-compressed. Blank input lines are ignored.

    Args:
        userargs: command-line arguments, excluding the program name.

    Raises:
        SystemExit: on invalid arguments, including an unsupported score
            function or a malformed score bin (via argparse).
    """
    parser = argparse.ArgumentParser(
        prog='cnv_gene_scores',
        description='using output from cnv_gtf_annotation function, make tsv '
                    'containing stats for gene scores based on user-provided '
                    'file.')
    parser.add_argument('--score-name', type=str, default='score',
                        help='name of input score set, for column name '
                             'assignment')
    parser.add_argument('--score-func', type=str, default='max',
                        choices=['max', 'min', 'mean', 'median'],
                        help='name of function to apply to gene-level scores')
    parser.add_argument('--score-bins', type=str, default=None,
                        help='score bins to get n genes from, of format '
                             '"val1-val2,val2-val3,val3-val4". Bounds may be '
                             'negative (use --score-bins=... when the first '
                             'character is "-"). Bin edges are excluded '
                             'unless the inclusive flags are given.')
    parser.add_argument('--score-bin-lower-inclusive', action='store_true',
                        default=False,
                        help='allow values equaling bin lower bound to be '
                             'included in bin.')
    parser.add_argument('--score-bin-upper-inclusive', action='store_true',
                        default=False,
                        help='allow values equaling bin upper bound to be '
                             'included in bin.')
    parser.add_argument('in_cnv_gene_tsv', type=str,
                        help='input file made by cnv_gtf_annotation.')
    parser.add_argument('gene_score_txt', type=str,
                        help='input gene score text file. Column 1 is gene '
                             'symbol, column 2 is gene score')
    args = parser.parse_args(userargs)

    # Parse bins once, before any output, so a malformed bin is reported as a
    # usage error instead of crashing part-way through the table.
    bin_labels = []
    bin_bounds = []
    if args.score_bins is not None:
        for label in args.score_bins.split(','):
            label = label.strip()
            if not label:
                continue
            try:
                bin_bounds.append(_parse_score_bin(label))
            except ValueError:
                parser.error('invalid score bin "%s"; expected "lower-upper"'
                             % label)
            bin_labels.append(label)

    # gene symbol -> score; genes without a valid numeric score are absent.
    gene_scores = _read_gene_scores(args.gene_score_txt)

    column_names = ['locus', 'n_genes', 'n_genes_' + args.score_name,
                    'genes_' + args.score_name + '_' + args.score_func]
    column_names.extend('n_genes_' + args.score_name + '_' + label
                        for label in bin_labels)
    print('\t'.join(column_names))

    # Bin columns for rows where no overlapped gene has a score.
    empty_bin_counts = ['0'] * len(bin_labels)

    with _open_text(args.in_cnv_gene_tsv) as cnv_fh:
        for line in cnv_fh:
            record = line.rstrip()
            if not record:
                continue  # e.g. a trailing blank line
            fields = record.split('\t')
            locus = fields[0]
            # Column 2 is a comma-separated gene list; it is absent when the
            # CNV overlaps no genes. Empty entries (e.g. "A,,B") are dropped.
            genes = ([g for g in fields[1].split(',') if g]
                     if len(fields) > 1 else [])
            scores = [gene_scores[g] for g in genes if g in gene_scores]

            if not scores:
                row = [locus, str(len(genes)), '0', 'NA'] + empty_bin_counts
            else:
                row = [locus, str(len(genes)), str(len(scores)),
                       str(_summarize_scores(scores, args.score_func))]
                for lower, upper in bin_bounds:
                    n_in_bin = sum(
                        1 for s in scores
                        if _in_bin(s, lower, upper,
                                   args.score_bin_lower_inclusive,
                                   args.score_bin_upper_inclusive))
                    row.append(str(n_in_bin))
            print('\t'.join(row))

def _open_text(path):
    """Open ``path`` for text reading, transparently decompressing gzip."""
    # Detect gzip by its two-byte magic number rather than the file extension.
    with open(path, 'rb') as probe:
        magic = probe.read(2)
    if magic == b'\x1f\x8b':
        return gzip.open(path, 'rt')
    return open(path, 'r')

def _read_gene_scores(path):
    """Return a dict of gene symbol -> float score read from ``path``.

    Each line is whitespace-separated ``symbol score``. Blank lines and lines
    whose second field is missing or not a number (e.g. a header) are
    skipped. If a symbol appears more than once, its last valid score wins.
    """
    gene_scores = {}
    with _open_text(path) as fh:
        for line in fh:
            fields = line.split()
            if not fields:
                continue
            try:
                gene_scores[fields[0]] = float(fields[1])
            except (IndexError, ValueError):
                continue
    return gene_scores

def _parse_score_bin(bin_str):
    """Split a ``lower-upper`` bin string into a ``(lower, upper)`` float pair.

    Either bound may be negative or use exponent notation (``-1-0``,
    ``1e-5-0.1``). So instead of splitting on every '-', each '-' after the
    first character is tried as the separator, and the first split where both
    halves parse as floats is returned.

    Raises:
        ValueError: if no such split exists.
    """
    for i, char in enumerate(bin_str):
        if char != '-' or i == 0:
            continue
        try:
            return float(bin_str[:i]), float(bin_str[i + 1:])
        except ValueError:
            continue
    raise ValueError('invalid score bin: %r' % (bin_str,))

def _in_bin(score, lower, upper, lower_inclusive, upper_inclusive):
    """Return True if ``score`` lies between ``lower`` and ``upper``.

    Each edge counts as inside the bin only when its ``*_inclusive`` flag
    is set.
    """
    above_lower = lower <= score if lower_inclusive else lower < score
    below_upper = score <= upper if upper_inclusive else score < upper
    return above_lower and below_upper

def _summarize_scores(scores, func_name):
    """Reduce a non-empty list of scores with the named summary function."""
    if func_name == 'max':
        return max(scores)
    if func_name == 'min':
        return min(scores)
    if func_name == 'mean':
        return mean(scores)
    if func_name == 'median':
        return median(scores)
    raise ValueError('score function ' + func_name + ' not supported.')

def mean(values):
    """Return the arithmetic mean of ``values`` as a float.

    Accepts any iterable of numbers, including generators. Returns None for
    an empty input, matching ``median``.
    """
    values = list(values)
    if not values:
        return None
    return sum(values) / float(len(values))

def median(lst):
    """Return the median of ``lst``, or None if it is empty.

    For an odd count this is the middle element of the sorted values. For an
    even count it is the sum of the two middle elements divided by 2.0. Only
    the branch that applies is evaluated.

    Adapted from
    https://stackoverflow.com/questions/24101524/finding-median-of-list-in-python
    """
    n = len(lst)
    if not n:
        return None
    s = sorted(lst)
    mid = n // 2
    if n % 2:
        return s[mid]
    # sum() over the two-element slice mirrors the original arithmetic exactly.
    return sum(s[mid - 1:mid + 1]) / 2.0

if __name__ == "__main__":
    userargs = sys.argv[1:]
    main(userargs)
