import sys
import argparse
import gzip

# import the IntervalTree and IntervalNode classes from quicksect file
from quicksect import IntervalTree, IntervalNode

def main(userargs):
    """Print the per-sample intersection of PennCNV and QuantiSNP CNV calls.

    For every PennCNV call, the QuantiSNP calls of the same sample that
    overlap it on the same chromosome are looked up. Overlaps are kept only
    when both callers agree on the CNV type (both cn < 2 or both cn > 2), or
    on the exact copy number when ``--exact-cn-match`` is given. For each
    kept overlap one line in PennCNV format is written to stdout describing
    the region covered by both calls; ``numsnp`` and ``length`` are
    recomputed for that region, while state and cn are copied from the
    PennCNV call.

    Args:
        userargs: list of command-line tokens (without the program name).

    Raises:
        SystemExit: on argument errors, or with status 1 when the QuantiSNP
            header lacks a required column.
    """
    args = _build_arg_parser().parse_args(userargs)

    # marker positions, used to count the SNPs inside each joint region
    snp_it = snpid_chr_pos_to_intervaltree(args.snpid_chr_pos_txt)

    # one IntervalTree of QuantiSNP calls per sample id
    quantisnp_it = _load_quantisnp_calls(args)

    with open(args.penncnv_txt, "r") as penncnv_fh:
        for line in penncnv_fh:
            data = line.rstrip().split()

            # tolerate blank lines (e.g. a trailing newline at end of file)
            if not data:
                continue

            # field 0 looks like "chr<chrom>:<start>-<end>"
            interval_i_x = data[0].split(":")
            chrom_i = interval_i_x[0].replace("chr", "")
            startend_i_x = interval_i_x[1].split("-")
            start_i = int(startend_i_x[0])
            end_i = int(startend_i_x[1])

            # fields 1 and 2 (numsnp=, length=) are not read: both values
            # are recomputed for the intersected region below

            # field 3 looks like "state<k>,cn=<n>"; kept as text for output
            state_cn_i = data[3].split(",")
            state_i = state_cn_i[0].replace("state", "")
            cn_i = state_cn_i[1].replace("cn=", "")

            # NOTE: str.replace drops every occurrence of the given text,
            # not only a leading / trailing one
            sampleid_i = data[4]
            if args.iid_prefix_rm is not None:
                sampleid_i = sampleid_i.replace(args.iid_prefix_rm, "")
            if args.iid_postfix_rm is not None:
                sampleid_i = sampleid_i.replace(args.iid_postfix_rm, "")

            startprobe_i = data[5].replace("startsnp=", "")
            endprobe_i = data[6].replace("endsnp=", "")

            # samples without any QuantiSNP call cannot produce a joint call
            if sampleid_i not in quantisnp_it:
                continue

            overlaps_i = _intersecting_infodicts(quantisnp_it[sampleid_i],
                                                 chrom_i, start_i, end_i)
            if not overlaps_i:
                continue

            penncnv_cn_i = int(cn_i)
            for overlap_i in overlaps_i:

                # both callers must agree on the kind of CNV
                if not _copynumbers_agree(penncnv_cn_i,
                                          overlap_i["copynumber"],
                                          args.exact_cn_match):
                    continue

                # 1. joint region = later start, earlier end; the boundary
                #    probe id comes from whichever call set that boundary
                if overlap_i["start"] > start_i:
                    intersect_start_i = overlap_i["start"]
                    intersect_startprobe_i = overlap_i["startprobe"]
                else:
                    intersect_start_i = start_i
                    intersect_startprobe_i = startprobe_i
                if overlap_i["end"] < end_i:
                    intersect_end_i = overlap_i["end"]
                    intersect_endprobe_i = overlap_i["endprobe"]
                else:
                    intersect_end_i = end_i
                    intersect_endprobe_i = endprobe_i

                # 2. count markers in the joint region; markers are stored
                #    as (pos - 1, pos), so the query start is shifted by one
                #    in the same way
                snp_hits_i = _intersecting_infodicts(snp_it,
                                                     chrom_i,
                                                     intersect_start_i - 1,
                                                     intersect_end_i)

                # 3. emit the joint call in PennCNV format
                penncnv_out_i = [
                    "chr" + chrom_i + ":" +
                    str(intersect_start_i) + "-" + str(intersect_end_i),
                    "numsnp=" + str(len(snp_hits_i)),
                    "length=" + str(intersect_end_i - intersect_start_i + 1),
                    "state" + state_i + "," + "cn=" + cn_i,
                    sampleid_i,
                    "startsnp=" + intersect_startprobe_i,
                    "endsnp=" + intersect_endprobe_i,
                ]
                print(" ".join(penncnv_out_i))


def _build_arg_parser():
    """Return the argparse parser describing the command-line interface."""
    parser = argparse.ArgumentParser(
        prog='quantisnp_calls_merge',
        description=('define CNV regions '
                     'per sample based on the '
                     'intersection of calls between '
                     'PennCNV and QuantiSNP.'))
    parser.add_argument('--quantisnp-sampleid-colname',
                        type=str, action='store', default='Sample Name',
                        help='name of column with sampleid in quantisnp file')
    parser.add_argument('--quantisnp-chrom-colname',
                        type=str, action='store', default='Chromosome',
                        help='name of column with chromosome in quantisnp file')
    parser.add_argument('--quantisnp-startpos-colname',
                        type=str, action='store',
                        default='Start Position (bp)',
                        help=('name of column with cnv start postion '
                              'in quantisnp file'))
    parser.add_argument('--quantisnp-endpos-colname',
                        type=str, action='store',
                        default='End Position (bp)',
                        help=('name of column with cnv end postion '
                              'in quantisnp file'))
    parser.add_argument('--quantisnp-startprobe-colname',
                        type=str, action='store',
                        default='Start Probe ID',
                        help=('name of column with cnv start probe '
                              'in quantisnp file'))
    parser.add_argument('--quantisnp-endprobe-colname',
                        type=str, action='store',
                        default='End Probe ID',
                        help=('name of column with cnv end probe '
                              'in quantisnp file'))
    parser.add_argument('--quantisnp-copynumber-colname',
                        type=str, action='store',
                        default='Copy Number',
                        help=('name of column with cnv copy number '
                              'in quantisnp file'))
    parser.add_argument('--exact-cn-match', action='store_true', default=False,
                        help=('make intersection require same copynumber, '
                              'rather than just the same copy change '
                              '(ex: calls with cn=3 and cn=4 would not '
                              'intersect, whereas under normal circumstances '
                              'since both are dups it would count as '
                              'intersection.'))
    parser.add_argument('--iid-prefix-rm', action='store', type=str,
                        default=None,
                        help=('prefix to remove from presumed sample filename, '
                              'to represent sample iid'))
    parser.add_argument('--iid-postfix-rm', action='store', type=str,
                        default=None,
                        help=('postfix to remove from presumed sample filename, '
                              'to represent sample iid'))
    parser.add_argument('snpid_chr_pos_txt', type=str,
                        help=('input file with snpid, chrom, pos '
                              'for each marker set used in both '
                              'quantisnp and penncnv'))
    parser.add_argument('quantisnp_txt', type=str,
                        help="input quantisnp file.")
    parser.add_argument('penncnv_txt', type=str,
                        help='input penncnv file.')
    return parser


def _load_quantisnp_calls(args):
    """Read the tab-delimited QuantiSNP file into one IntervalTree per sample.

    The first line is treated as the header and must contain every column
    named by the ``--quantisnp-*-colname`` options; otherwise an error is
    written to stderr and the program exits with status 1.

    Returns:
        dict mapping sample id to an IntervalTree whose nodes carry an
        ``infodict`` with keys sampleid, chrom, start, end, startprobe,
        endprobe and copynumber (start, end and copynumber as int).
    """
    sampleid_col = args.quantisnp_sampleid_colname
    chrom_col = args.quantisnp_chrom_colname
    startpos_col = args.quantisnp_startpos_colname
    endpos_col = args.quantisnp_endpos_colname
    startprobe_col = args.quantisnp_startprobe_colname
    endprobe_col = args.quantisnp_endprobe_colname
    copynumber_col = args.quantisnp_copynumber_colname
    req_cols = [sampleid_col, chrom_col, startpos_col, endpos_col,
                startprobe_col, endprobe_col, copynumber_col]

    quantisnp_it = dict()
    idx = dict()
    with open(args.quantisnp_txt, "r") as quantisnp_fh:
        for i, line in enumerate(quantisnp_fh, 1):
            data = line.rstrip().split("\t")

            if i == 1:
                missing_cols = set(req_cols).difference(data)
                if missing_cols:
                    # stdout carries the merged calls, so report on stderr
                    sys.stderr.write(
                        "ERROR : missing required cols in quantisnp file : " +
                        ",".join(sorted(missing_cols)) + "\n")
                    sys.exit(1)

                # column name -> position (last one wins on duplicates)
                idx = {colname: j for j, colname in enumerate(data)}
                continue

            # tolerate blank lines (e.g. a trailing newline at end of file)
            if not line.strip():
                continue

            sampleid_i = data[idx[sampleid_col]]
            chrom_i = data[idx[chrom_col]]
            startpos_i = int(data[idx[startpos_col]])
            endpos_i = int(data[idx[endpos_col]])

            infodict_i = {
                "sampleid": sampleid_i,
                "chrom": chrom_i,
                "start": startpos_i,
                "end": endpos_i,
                "startprobe": data[idx[startprobe_col]],
                "endprobe": data[idx[endprobe_col]],
                "copynumber": int(data[idx[copynumber_col]]),
            }

            # first call seen for this sample: create its tree
            if sampleid_i not in quantisnp_it:
                quantisnp_it[sampleid_i] = IntervalTree()

            quantisnp_it[sampleid_i].insert(chrom_i, startpos_i, endpos_i,
                                            linenum=i,
                                            infodict=infodict_i)
    return quantisnp_it


def _copynumbers_agree(cn_a, cn_b, exact):
    """Return True when two integer copy numbers count as the same CNV type."""
    if exact:
        return cn_a == cn_b
    # both deletions (< 2) or both duplications (> 2); cn == 2 never matches
    return (cn_a < 2 and cn_b < 2) or (cn_a > 2 and cn_b > 2)


def _intersecting_infodicts(tree, chrom, start, end):
    """Return the ``infodict`` of every node that tree.intersect reports."""
    hits = []
    tree.intersect(chrom, start, end, lambda node: hits.append(node.infodict))
    return hits

def snpid_chr_pos_to_intervaltree(snpid_chr_pos_txt):
    """Load marker positions from a text file into a single IntervalTree.

    Each whitespace-delimited line is read as ``snpid chrom pos``. A marker
    at position ``pos`` is inserted as the interval ``(pos - 1, pos)`` on its
    chromosome, with ``linenum`` set to the 1-based line number in the file
    and ``infodict`` set to None.

    Lines with fewer than three fields (including blank lines) and lines
    whose position is not an integer (such as a header row) are skipped.

    Args:
        snpid_chr_pos_txt: path to the marker position file.

    Returns:
        The populated IntervalTree.
    """
    snp_it = IntervalTree()
    with open(snpid_chr_pos_txt, "r") as fh:
        # start at 1 so linenum matches the line's position in the file,
        # counting skipped lines too
        for i, line in enumerate(fh, 1):
            data = line.split()

            # blank or truncated line: nothing usable to insert
            if len(data) < 3:
                continue

            chrom_i = data[1]
            try:
                pos_i = int(data[2])
            except ValueError:
                # non-numeric position, e.g. a header row
                continue

            snp_it.insert(chrom_i,
                          pos_i - 1, pos_i,
                          linenum=i,
                          infodict=None)
    return snp_it

if __name__ == "__main__":
    # script entry point: pass every command-line token after the program
    # name straight to main()
    main(sys.argv[1:])

