
import argparse
from collections import OrderedDict

def _read_iids_keep(iids_keep_txt):
    """Return the set of IIDs listed one per line in iids_keep_txt, or None when no file is given.

    None means "keep every IID"; an empty set means "keep no IID".
    Trailing whitespace is stripped from each line and blank lines are ignored.
    """
    if iids_keep_txt is None:
        return None
    with open(iids_keep_txt, "r") as in_fh:
        return set(iid for iid in (line.rstrip() for line in in_fh) if iid)


def _read_cnv_gene_overlaps(cnv_gene_overlaps_bed, iid_colnum, gene_colnum,
                            iids_keep):
    """Return (gene -> set of carrier IIDs, locus -> set of genes) from the CNV gene overlaps BED.

    iid_colnum / gene_colnum are 1-based column numbers. Rows whose IID is
    not in iids_keep are skipped unless iids_keep is None. Blank lines are
    ignored.
    """
    genes_carriers_dict = {}
    locus_genes_dict = {}
    with open(cnv_gene_overlaps_bed, "r") as in_fh:
        for line in in_fh:
            data = line.rstrip().split()
            if not data:
                continue
            chrom = data[0]
            start0 = data[1]
            end = data[2]
            # start0 is shifted by +1 to build the "chrom:start-end" locus key
            locus = chrom + ":" + str(int(start0) + 1) + "-" + end
            iid = data[iid_colnum - 1]
            gene = data[gene_colnum - 1]

            # skip if a list of iids to keep was given and iid is not in it
            if iids_keep is not None and iid not in iids_keep:
                continue

            genes_carriers_dict.setdefault(gene, set()).add(iid)
            locus_genes_dict.setdefault(locus, set()).add(gene)
    return genes_carriers_dict, locus_genes_dict


def _multi_gene_cnv_genes(locus_genes_dict):
    """Return every gene that belongs to a locus overlapping more than one gene."""
    genes_include = set()
    for genes in locus_genes_dict.values():
        if len(genes) > 1:
            genes_include.update(genes)
    return genes_include


def _ordered_genes(loci_bed, genes_include):
    """Return the genes of genes_include in first-appearance order of loci BED column 4."""
    genes_dict = OrderedDict()
    with open(loci_bed, "r") as in_fh:
        for line in in_fh:
            data = line.rstrip().split()
            if not data:
                continue
            gene = data[3]
            if gene in genes_include:
                genes_dict[gene] = None
    return list(genes_dict.keys())


def _clump_genes(genes_assess, genes_carriers_dict, frac_min):
    """Group runs of adjacent genes whose carrier sets have |A & B| / |A | B| >= frac_min.

    Returns a list of clump names, each the sorted member genes joined by ",".
    """
    gene_clumps = []
    gene_clump_set = set()
    for gene_h, gene_i in zip(genes_assess, genes_assess[1:]):
        carrier_iids_h = genes_carriers_dict[gene_h]
        carrier_iids_i = genes_carriers_dict[gene_i]
        # the union is never empty: every assessed gene has >= 1 carrier
        frac = (len(carrier_iids_h & carrier_iids_i) /
                float(len(carrier_iids_h | carrier_iids_i)))
        if frac >= frac_min:
            gene_clump_set.add(gene_h)
            gene_clump_set.add(gene_i)
        elif gene_clump_set:
            gene_clumps.append(",".join(sorted(gene_clump_set)))
            gene_clump_set = set()
    # a run that reaches the last gene has no failing pair after it to
    # trigger the flush above, so flush it here
    if gene_clump_set:
        gene_clumps.append(",".join(sorted(gene_clump_set)))
    return gene_clumps


def _write_loci_with_clumps(loci_bed, gene_to_geneclump_dict):
    """Print each loci BED row tab-joined, replacing column 4 with its gene clump if it has one."""
    with open(loci_bed, "r") as in_fh:
        for line in in_fh:
            data = line.rstrip().split()
            if not data:
                continue
            data[3] = gene_to_geneclump_dict.get(data[3], data[3])
            print("\t".join(data))


def main():
    """
    Clump neighboring genes that share CNV carriers and print the loci BED.

    Reads an optional IID keep-list, the CNV gene overlaps BED and the loci
    BED (paths taken from the command line). Genes hit by a CNV locus that
    spans more than one gene are ordered as in the loci BED, and adjacent
    ones whose carrier IID sets overlap by at least
    --clump-sample-overlap-fraction-min are merged. The loci BED is then
    printed to stdout with column 4 replaced by the comma-joined clump name
    for every clumped gene.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--iids-keep-txt',
                        dest='iids_keep_txt',
                        action='store', default=None, type=str,
                        help='file with list of IIDs to keep '
                             '[default %(default)s]')
    parser.add_argument('--cnv-gene-overlaps-bed-iid-colnum',
                        dest='cnv_gene_overlaps_bed_iid_colnum',
                        action='store', default=6, type=int,
                        help='column number in cnv gene overlaps BED '
                             'for IID [default %(default)s]')
    parser.add_argument('--clump-sample-overlap-fraction-min',
                        dest='clump_sample_overlap_fraction_min',
                        action='store', default=0.5, type=float,
                        help='minimum fraction of cnv call sample overlap '
                             'required for clumping of neighboring genes '
                             '[default %(default)s]')
    parser.add_argument('--cnv-gene-overlaps-bed-gene-colnum',
                        dest='cnv_gene_overlaps_bed_gene_colnum',
                        action='store', default=7, type=int,
                        help='column number in cnv gene overlaps BED '
                             'for gene [default %(default)s]')
    parser.add_argument('loci_bed', help='Input loci BED')
    parser.add_argument('cnv_gene_overlaps_bed', help='CNV BED with gene overlaps')
    args = parser.parse_args()

    # if defined by user, get list of iids to keep (None = keep every iid)
    iids_keep = _read_iids_keep(args.iids_keep_txt)

    # pass 1 of CNV gene overlaps BED: carriers per gene, genes per locus
    genes_carriers_dict, locus_genes_dict = _read_cnv_gene_overlaps(
        args.cnv_gene_overlaps_bed,
        args.cnv_gene_overlaps_bed_iid_colnum,
        args.cnv_gene_overlaps_bed_gene_colnum,
        iids_keep)

    # focus on genes that are hit by cnvs that hit more than 1 gene
    genes_include = _multi_gene_cnv_genes(locus_genes_dict)

    # read loci bed to figure out the ordering of those genes
    genes_assess = _ordered_genes(args.loci_bed, genes_include)

    # merge neighboring genes with overlapping carrier iids
    gene_clumps = _clump_genes(genes_assess, genes_carriers_dict,
                               args.clump_sample_overlap_fraction_min)

    # form translation dictionary: member gene -> clump name
    gene_to_geneclump_dict = {}
    for gene_clump_str in gene_clumps:
        for gene in gene_clump_str.split(","):
            gene_to_geneclump_dict[gene] = gene_clump_str

    # final pass of loci bed: convert gene to geneclump if clump member
    _write_loci_with_clumps(args.loci_bed, gene_to_geneclump_dict)

if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script,
    # not when this module is imported.
    main()
