
import sys
import argparse
import gzip

# import the IntervalTree and IntervalNode classes from quicksect file
from quicksect import IntervalTree, IntervalNode

def _strip_chr_prefix(chrom):
    """Return *chrom* without a leading 'chr' so 'chr1' and '1' compare equal."""
    if chrom.startswith("chr"):
        return chrom[3:]
    return chrom


def _gene_label(attributes):
    """Return the gene_name attribute, else gene_id, else None."""
    return attributes.get("gene_name") or attributes.get("gene_id") or None


def _open_gtf(path):
    """Open a plain or gzip-compressed (name ends in '.gz') GTF in text mode."""
    if path.endswith(".gz"):
        # "rt" makes gzip yield str lines, so no per-line decode is needed.
        return gzip.open(path, "rt", encoding="utf8")
    return open(path, "r")


def _parse_attribute_filters(spec):
    """Parse 'attr1:X,Y;attr2:A,B,C' into {attr: set of allowed values}.

    Returns an empty dict when *spec* is None. Raises ValueError for an
    entry that has no 'key:' part.
    """
    filters = dict()
    if spec is None:
        return filters
    for entry in spec.split(";"):
        entry = entry.strip()
        if entry == "":
            # tolerate a trailing or doubled ';'
            continue
        key, sep, values = entry.partition(":")
        if sep == "" or key == "":
            raise ValueError("malformed --attributes-is-in entry '" +
                             entry + "' (expected 'attribute:val1,val2').")
        filters[key] = set(values.split(","))
    return filters


def _parse_gtf_attributes(attributes_str):
    """Split a GTF column-9 string into (attribute dict, set of tag values).

    Double quotes are removed from keys and values. A key that occurs more
    than once (e.g. 'tag') keeps only its last value in the dict, which is
    why every 'tag' value is also collected into the returned set.
    Semicolons inside quoted values are not supported.
    """
    attributes = dict()
    tags = set()
    for field in attributes_str.split(";"):
        field = field.strip()
        if field == "":
            continue
        # split on the first space only, so values containing spaces survive
        key, _, value = field.partition(" ")
        key = key.replace('"', '')
        value = value.strip().replace('"', '')
        attributes[key] = value
        if key == "tag":
            tags.add(value)
    return attributes, tags


def _passes_filters(attributes, tags, attribute_filters, tags_required):
    """Return True if a feature satisfies the attribute and tag filters."""
    for key, allowed in attribute_filters.items():
        if key not in attributes or attributes[key] not in allowed:
            return False
    # The tag requirement is independent of the attribute filters: it must
    # hold even when no --attributes-is-in filter was supplied.
    if tags_required is not None and not tags_required.issubset(tags):
        return False
    return True


def _annotate_cnvs(cnv_path, out_path, tree):
    """Write 'interval<TAB>GENE1,GENE2,...' for every interval in *cnv_path*.

    The CNV file is whitespace-delimited with columns chrom, start, end and
    interval label. Blank lines and lines starting with '#' are skipped.
    Gene labels are de-duplicated and sorted.
    """
    with open(cnv_path, "r") as cnv_fh, open(out_path, "w") as out_fh:
        for line_no, line in enumerate(cnv_fh, start=1):
            fields = line.split()
            if not fields or fields[0].startswith("#"):
                continue
            if len(fields) < 4:
                sys.exit("ERROR : CNV line " + str(line_no) +
                         " has fewer than 4 columns.")
            chrom = _strip_chr_prefix(fields[0])
            start = int(fields[1])
            end = int(fields[2])
            interval_label = fields[3]

            genes = set()

            def collect(node, genes=genes):
                label = _gene_label(node.infodict)
                if label is not None:
                    genes.add(label)

            # NOTE(review): features are stored with their 1-based GTF start
            # while the query start is taken as-is from the CNV file; whether
            # that is off by one depends on the intersect semantics of the
            # quicksect IntervalTree - confirm.
            tree.intersect(chrom, start, end, collect)
            out_fh.write(interval_label + "\t" +
                         ",".join(sorted(genes)) + "\n")


def main(userargs):
    """Filter GTF features, write them as BED and optionally annotate CNVs.

    Every GTF feature that passes the user-defined filters (feature class,
    attribute values, required tags) is written to ``out_bed`` as
    ``seqname, start-1, end, gene`` where gene is the ``gene_name``
    attribute, falling back to ``gene_id`` and then to ``.``.

    When ``--cnv-bed`` and ``--out-tsv`` are both given, the kept features
    are also loaded into an IntervalTree and, for each CNV interval, a line
    ``interval<TAB>GENE1,GENE2,..,GENEn`` is written to the TSV, listing the
    distinct genes having at least one kept feature overlapping the CNV.

    The options --iid-prefix-rm, --iid-postfix-rm, --genesymbol-keep-listfile
    and --ensg-keep-listfile are accepted but not used by this function.

    Args:
        userargs: list of command-line tokens (without the program name).

    Raises:
        SystemExit: on argument errors, or on an input line with the wrong
            number of columns.
    """

    # get args
    parser = argparse.ArgumentParser(prog='cnv_gtf_annotation',
                                     description='using a user-defined ' + \
                                                 'GTF file, get gene-level ' +\
                                                 'overlaps for CNVs in ' + \
                                                 'penncnv cnv file.')
    parser.add_argument('--iid-prefix-rm', action='store', type=str,
                        default=None,
                        help='prefix for iids listed in penncnv output to remove.')
    parser.add_argument('--iid-postfix-rm', action='store', type=str,
                        default=None,
                        help='postfix for iids listed in penncnv output to remove.')
    parser.add_argument('--feature-classifs', action='store', type=str,
                        default=None,
                        help='subset of features in GTF to intersect on.')
    parser.add_argument('--feature-pos-5prime-adjust', action='store', type=int,
                        default=0,
                        help='amount of bases to adjust input feature ' + \
                             'coord by at the 5-prime end (strand sensitive).')
    parser.add_argument('--feature-pos-3prime-adjust', action='store', type=int,
                        default=0,
                        help='amount of bases to adjust input feature ' + \
                             'coord by at the 3-prime end (strand sensitive).')
    parser.add_argument('--genesymbol-keep-listfile', action='store',
                        type=str, default=None,
                        help='file containing list of gene symbols to ' + \
                             'subset on')
    parser.add_argument('--ensg-keep-listfile', action='store',
                        type=str, default=None,
                        help='file containing list of ENSGs to ' + \
                             'subset on')
    parser.add_argument('--attributes-is-in', action='store', 
                        type=str, default=None,
                        help='string containing a set of attribute to ' + \
                             'subset on, with the following format : ' + \
                             '"atribute1:X,Y;attribute2:A,B,C"')
    parser.add_argument('--tags-required', action='store', type=str, 
                        default=None,
                        help='comma-delimited set of tags required for ' + \
                             'gtf interval inclusion.')
    # The two options below replace attributes (args.cnv_bed, args.out_tsv)
    # that the code used to read without ever declaring them.
    parser.add_argument('--cnv-bed', action='store', type=str,
                        default=None,
                        help='whitespace-delimited CNV file with columns ' + \
                             'chrom, start, end, interval label. ' + \
                             'requires --out-tsv.')
    parser.add_argument('--out-tsv', action='store', type=str,
                        default=None,
                        help='output tsv of gene overlaps per CNV ' + \
                             'interval. requires --cnv-bed.')
    parser.add_argument('in_gtf', type=str,
                        help='input gtf or gtf.gz file.')
    parser.add_argument('out_bed', type=str,
                        help='output BED file.')
    args = parser.parse_args(userargs)

    if (args.cnv_bed is None) != (args.out_tsv is None):
        parser.error("--cnv-bed and --out-tsv must be given together.")
    annotate_cnvs = args.cnv_bed is not None

    # convert particular userargs from string to set
    feature_classifs = None
    if args.feature_classifs is not None:
        feature_classifs = set(args.feature_classifs.split(","))
    tags_required = None
    if args.tags_required is not None:
        tags_required = set(args.tags_required.split(','))

    # if specific attribute subsetting defined, load from
    # the user-provided string
    try:
        attributes_is_in = _parse_attribute_filters(args.attributes_is_in)
    except ValueError as err:
        parser.error(str(err))

    # the interval tree is only needed when CNVs are going to be annotated
    gtf_intervaltree = IntervalTree() if annotate_cnvs else None

    n_lines = 0
    n_intervals_stored = 0
    with _open_gtf(args.in_gtf) as gtf_fh, open(args.out_bed, "w") as bed_fh:
        for n_lines, raw_line in enumerate(gtf_fh, start=1):
            if raw_line.startswith("#"):
                continue
            line = raw_line.rstrip("\r\n")
            if line.strip() == "":
                continue
            columns = line.split("\t")

            # it is assumed that gtf/gff will have 9 columns. if not,
            # then report the error and halt the run.
            if len(columns) != 9:
                sys.exit("ERROR : line " + str(n_lines) + \
                         " does not have 9 columns.")
            (seqname, _source, feature, start, end,
             _score, strand, _frame, attributes_str) = columns

            # if a feature classif(s) is defined, then skip line if feature
            # does not match it
            if feature_classifs is not None and feature not in feature_classifs:
                continue

            start = int(start)
            end = int(end)

            # adjust 5prime and 3prime ends by user-defined increments.
            # the 5prime end is `start` on '+' and `end` on '-'.
            # NOTE(review): the increment is added in genomic coordinates on
            # both strands, so one value moves the 5prime end in opposite
            # transcript-relative directions on '+' and '-' - confirm this
            # is the intended convention.
            if strand == "-":
                start += args.feature_pos_3prime_adjust
                end += args.feature_pos_5prime_adjust
            else:
                start += args.feature_pos_5prime_adjust
                end += args.feature_pos_3prime_adjust
            # an adjustment must not push the start before the first base,
            # which would give a negative (invalid) BED start
            start = max(start, 1)

            attributes, tags = _parse_gtf_attributes(attributes_str)
            if not _passes_filters(attributes, tags,
                                   attributes_is_in, tags_required):
                continue

            # write locus as BED (0-based start, end unchanged)
            bed_fh.write("\t".join([seqname,
                                    str(start - 1),
                                    str(end),
                                    _gene_label(attributes) or "."]) + "\n")

            # if CNVs are to be annotated, insert into intervaltree. the
            # chromosome is normalised the same way as the CNV chromosome
            # so that 'chr1' in one file matches '1' in the other.
            if gtf_intervaltree is not None:
                gtf_intervaltree.insert(_strip_chr_prefix(seqname),
                                        start, end,
                                        linenum=n_lines,
                                        infodict=attributes)
            n_intervals_stored += 1

    print("N intervals stored : " + str(n_intervals_stored))
    print("N lines parsed : " + str(n_lines))

    if annotate_cnvs:
        _annotate_cnvs(args.cnv_bed, args.out_tsv, gtf_intervaltree)

    return

if __name__ == "__main__":
    # Script entry point: hand every command-line token after the program
    # name to main().
    main(sys.argv[1:])
