#!/usr/bin/env python3
#
# Copyright (c) 2022-Present Matt Halvorsen <mhalvors1@gmail.com> 
# Distributed under terms of the MIT license.


"""
Run case/control association tests for CNV burden across multiple pre-defined
case/control groups, and report meta-analysis results. Also produce QQ plots and
produce an estimate for genomic inflation, with the understanding that large
CNVs overlapping multiple genes can inflate these statistics, particularly if
they are preferentially found in cases.
"""


import argparse
import sys
import pybedtools
from math import log10
import random
import numpy as np
import pandas
from statsmodels.stats.contingency_tables import StratifiedTable as cmh
from collections import OrderedDict

def _cmh_stats(cmh_input, cmh_stats_cache=None):
    """
    Return the cmh_test() result tuple for one locus.

    cmh_input is the dict produced by build_cmh_input(). If cmh_stats_cache
    is a dict, results are memoised under cmh_input['counts_str'] so that an
    identical set of per-group carrier counts is only tested once. If it is
    None, the test is always recomputed.
    """
    if cmh_stats_cache is None:
        return cmh_test(cmh_input['tbls'])
    counts_str = cmh_input['counts_str']
    if counts_str not in cmh_stats_cache:
        cmh_stats_cache[counts_str] = cmh_test(cmh_input['tbls'])
    return cmh_stats_cache[counts_str]


def main():
    """
    Main block

    Steps:
      1. parse command-line arguments
      2. read the sample TSV (iid, group, 0/1 case status) into per-group
         structs
      3. build the loci to test (gene regions if --gene-bed is given,
         otherwise CNV breakpoints) and intersect them with the
         size/chromosome filtered CNVs
      4. run the CMH test per locus on the observed phenotypes
      5. shuffle case/control labels within each group --n-perm times and
         average the sorted permuted p-values to get an expected p-value
         distribution
      6. write the results TSV and, if requested, QQ plots
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--cnv-types', type=str, default="DEL,DUP",
                        help='Comma-delimited list of CNV types to perform '+\
                             'assoc tests on. Choices : DEL, DUP, DELDUP.'+\
                             "[default %(default)s]")
    parser.add_argument('--chr', help='Restrict to a subset of chromosomes. ' +
                        "Specify as comma-separated list [default %(default)s]")
    parser.add_argument('--minsize', type=int, help='Minimum CNV size. ' +
                        '[default %(default)s]', default=30000)
    parser.add_argument('--maxsize', type=int, help='Maximum CNV size. ' +
                        '[default %(default)s]', default=20000000)
    parser.add_argument('--random-seed-number',
                        dest='random_seed_number',
                        type=int, default=None,
                        help="random seed number to set, for reproducability "+\
                             "[default %(default)s]")
    parser.add_argument('--gene-bed', dest='gene_bed', type=str, default=None,
                        help='input BED file with 4 columns, column 4 being ' +\
                             'the name of the gene/test unit ' +\
                             "[default %(default)s].")
    parser.add_argument('--n-gene-overlap-min',
                        dest='n_gene_overlap_min',
                        type=int, default=None,
                        help="min number of gene overlaps allowed for a " + \
                             "cnv to be allowed in analysis [default %(default)s]")
    parser.add_argument('--n-gene-overlap-max',
                        dest='n_gene_overlap_max',
                        type=int, default=None,
                        help="max number of gene overlaps allowed for a " + \
                             "cnv to be allowed in analysis [default %(default)s]")
    parser.add_argument('--n-perm',
                        dest='n_perm',
                        type=int, default=100,
                        help="number of permutations [default %(default)s]")
    parser.add_argument('--use-cmh-stats-cache',
                        dest='use_cmh_stats_cache',
                        default=False, action='store_true',
                        help="use cmh stats cache for faster permutations " +\
                             "[default %(default)s]")
    parser.add_argument('--sampleinfo-group-grep',
                        dest='sampleinfo_group_grep',
                        type=str, default="full",
                        help='if not "full", only keep iids where group ' + \
                             'string contains this substring ' + \
                             "[default %(default)s].")
    parser.add_argument('--sampleinfo-group-cols-paste',
                        dest='sampleinfo_group_cols_paste',
                        type=str, default="group,clusters",
                        help='comma-delim columns in sampleinfo tsv to paste together ' + \
                             'in order to form case/control GROUP column ' + \
                             "[default %(default)s].")
    parser.add_argument('--sampleinfo-iid-col',
                        dest='sampleinfo_iid_col',
                        type=str, default="IID",
                        help='name of column in sampleinfo tsv indicating ' + \
                             'sample individual ID [default %(default)s]')
    parser.add_argument('--sampleinfo-casestatus-col',
                        dest='sampleinfo_casestatus_col',
                        type=str, default="CASE",
                        help='name of column in sampleinfo tsv indicating ' + \
                             '0/1 case status [default %(default)s]')
    parser.add_argument('--cnv-qq-plot-pdf', dest='cnv_qq_plot_pdf',
                        type=str, default=None,
                        help='name of PDF file to create with QQ plot ' + \
                             'with both deletions and duplications ' + \
                             "[default %(default)s]")
    parser.add_argument('--del-qq-plot-pdf', dest='del_qq_plot_pdf',
                        type=str, default=None,
                        help='name of PDF file to create with QQ plot ' + \
                             'with deletions ' + \
                             "[default %(default)s]")
    parser.add_argument('--dup-qq-plot-pdf', dest='dup_qq_plot_pdf',
                        type=str, default=None,
                        help='name of PDF file to create with QQ plot ' + \
                             'with duplications ' + \
                             "[default %(default)s]")
    parser.add_argument('--qq-risk-cnv', dest='qq_risk_cnv',
                        action='store_true',default=False,
                        help='make QQ plot and calculate lambda based only ' + \
                             'on loci with odds ratio > 1 ' + \
                             "[default %(default)s]")
    parser.add_argument('--qq-protective-cnv', dest='qq_protective_cnv',
                        action='store_true',default=False,
                        help='make QQ plot and calculate lambda based only ' + \
                             'on loci with odds ratio < 1 ' + \
                             "[default %(default)s]")
    parser.add_argument('--matplotlib-use', type=str, default=None,
                        help='non-default engine for matplotlib to use.')
    parser.add_argument('in_iid_group_pheno_tsv', 
                        help='Input TSV where column1 is individual ID, ' + \
                             'column2 is case/control group and column3 is ' + \
                             'phenotype (1=case, 0=control)')
    parser.add_argument('in_bed', help='Input BED (supports "stdin").')
    parser.add_argument('out_tsv', help='name of output tsv with results')
    args = parser.parse_args()

    # make sure that matplotlib is correctly setup
    if args.matplotlib_use is not None:
        import matplotlib
        matplotlib.use(args.matplotlib_use)

    # if one is defined by user, set random seed number
    if args.random_seed_number is not None:
        random.seed(args.random_seed_number)

    # init structs for storing info per group
    iids = set()
    groups_set = set()
    groups_iids = dict()
    groups_case_iids = dict()
    groups_ctrl_iids = dict()
    groups_casestatus = dict()
    groups_n = dict()
    groups_nca = dict()

    # for each line of the sample file, extract iid, group and pheno and
    # store in structs. The context manager guarantees the handle is closed
    # even when we bail out on an invalid phenotype value.
    with open(args.in_iid_group_pheno_tsv, "r") as in_fh:
        for line in in_fh:
            data = line.rstrip().split()

            # tolerate blank lines (e.g. a trailing newline at end of file)
            if not data:
                continue

            iid_i = data[0]
            group_i = data[1]
            casestatus_i = int(data[2])

            # if group_grep not set as 'full', only proceed if match is found
            if args.sampleinfo_group_grep != "full":
                if group_i.find(args.sampleinfo_group_grep) == -1:
                    continue

            # add to full set of iids
            iids.add(iid_i)

            # if group not in structs, init value
            if group_i not in groups_set:
                groups_set.add(group_i)
                groups_iids[group_i] = []
                groups_case_iids[group_i] = set()
                groups_ctrl_iids[group_i] = set()
                groups_casestatus[group_i] = []
                groups_n[group_i] = 0
                groups_nca[group_i] = 0

            # add entry to structs
            groups_iids[group_i].append(iid_i)
            groups_casestatus[group_i].append(casestatus_i)
            groups_n[group_i] += 1
            if casestatus_i == 1:
                groups_nca[group_i] += 1
                groups_case_iids[group_i].add(iid_i)
            elif casestatus_i == 0:
                groups_ctrl_iids[group_i].add(iid_i)
            else:
                print("ERROR : values in pheno col can only be " + \
                      "'0' (control) or '1' (case).")
                sys.exit(1)

    # get and sort all unique groups
    groups = sorted(groups_set)

    # get cnv types to focus analyses on
    cnv_types = args.cnv_types.split(",")

    # Read input bed. stdin can only be consumed once, but the CNVs are
    # needed both for deriving breakpoints and for the intersect below, so a
    # stdin stream is first materialised to a temporary file via saveas()
    # and that file is used from here on.
    if args.in_bed in ('stdin', '-'):
        cnv_pbt = pybedtools.BedTool(fn=sys.stdin).saveas()
        in_bed_path = cnv_pbt.fn
    else:
        cnv_pbt = pybedtools.BedTool(fn=args.in_bed)
        in_bed_path = args.in_bed

    # if gene bed file provided, init pbt instance from it.
    # otherwise assume breakpoint-based test and load distinct breakpoints
    # from the cnv bed file to pbt instance.
    if args.gene_bed is not None:
        locus_pbt = bed_file_to_gene_pbt(args.gene_bed,
                                         cnv_types = cnv_types)
    else:
        locus_pbt = bed_file_to_breakpoints_pbt(in_bed_path,
                                                iids_keep=iids,
                                                add_deldup="DELDUP" in cnv_types,
                                                breakpoint_delim="_")

    # Restrict to a subset of chromosomes, or autosomes if not specified
    # (both "1" and "chr1" style names are accepted for autosomes)
    if args.chr is None:
        chroms = set()
        for c in range(1, 23):
            chroms.add(str(c))
            chroms.add('chr' + str(c))
    else:
        chroms = set(args.chr.split(','))
    cnv_pbt = cnv_pbt.filter(lambda x: x.chrom in chroms)

    # Restrict on minimum size
    if args.minsize is not None:
        cnv_pbt = cnv_pbt.filter(lambda x: len(x) >= args.minsize)

    # Restrict on maximum size
    if args.maxsize is not None:
        cnv_pbt = cnv_pbt.filter(lambda x: len(x) <= args.maxsize)

    # intersect cnvs with breakpoints (or gene loci if defined by user)
    xcnvs = locus_pbt.intersect(cnv_pbt,
                                wa=True,
                                wb=True)

    # if gene bed provided, parse cnv / gene region intersects.
    # otherwise, parse cnv / breakpoint intersects.
    if args.gene_bed is not None:
        xcnvs_carriers = get_locus_cnv_carriers(xcnvs, 
                                                iids_keep=iids,
                                                n_gene_overlap_min=args.n_gene_overlap_min,
                                                n_gene_overlap_max=args.n_gene_overlap_max,
                                                is_breakpoints=False)
    else:
        xcnvs_carriers = get_locus_cnv_carriers(xcnvs, 
                                                iids_keep=iids,
                                                is_breakpoints=True)

    # retrieve all locus names from ordered dict of cnv carriers
    locus_names = list(xcnvs_carriers.keys())

    # init dictionary for storing results (turn into pandas dataframe later)
    res_dict_keys = ["locus_name", "nca_1_tot", "nco_1_tot", "nca_0_tot",
                     "nco_0_tot", "breslowday_p", "cmh_or", "cmh_95ci_l",
                     "cmh_95ci_u","cmh_p"]
    res_dict = OrderedDict((key, []) for key in res_dict_keys)

    # cache of cmh sumstats keyed on carrier counts; None disables caching
    cmh_stats_cache = dict() if args.use_cmh_stats_cache else None

    # first perform tests for observed case/control phenotypes
    for locus_name in locus_names:
        carrier_iids = xcnvs_carriers[locus_name]
        cmh_input = build_cmh_input(carrier_iids, groups, 
                                    groups_case_iids, groups_ctrl_iids)
        (cmh_or, cmh_95ci_l, cmh_95ci_u, cmh_p,
         breslowday_stat, breslowday_p) = _cmh_stats(cmh_input,
                                                     cmh_stats_cache)

        # explicit mapping of output column -> value (no eval of local names)
        row = {"locus_name": locus_name,
               "nca_1_tot": cmh_input['nca_1_tot'],
               "nco_1_tot": cmh_input['nco_1_tot'],
               "nca_0_tot": cmh_input['nca_0_tot'],
               "nco_0_tot": cmh_input['nco_0_tot'],
               "breslowday_p": breslowday_p,
               "cmh_or": cmh_or,
               "cmh_95ci_l": cmh_95ci_l,
               "cmh_95ci_u": cmh_95ci_u,
               "cmh_p": cmh_p}
        for key in res_dict_keys:
            res_dict[key].append(row[key])

    # convert results to pandas df
    res_df = pandas.DataFrame(res_dict)

    # now time for derivation of permutation-based pvalues

    # running sum, per rank, of the sorted permuted p-values
    exp_pvals = np.zeros(len(locus_names))

    # for each permutation:
    for x0 in range(args.n_perm):

        # print permutation number to stdout (1-based)
        print("Permutation number : " + str(x0 + 1))

        # define structs to store permuted data
        groups_case_iids = dict()
        groups_ctrl_iids = dict()

        # for each input group..
        for group_i in groups:
            iids_i = groups_iids[group_i] 
            n_i = groups_n[group_i]
            nca_i = groups_nca[group_i]

            # shuffle the iids. first n_ca iids are cases, rest are controls,
            # so the case/control counts per group are preserved
            iids_perm_i = random.sample(iids_i, n_i)
            groups_case_iids[group_i] = set(iids_perm_i[:nca_i])
            groups_ctrl_iids[group_i] = set(iids_perm_i[nca_i:])

        # now for each locus ..
        cmh_pvals_perm = []
        for locus_name in locus_names:
            carrier_iids = xcnvs_carriers[locus_name]
            # generate cmh input using shuffled case/control phenotypes,
            # stratified by defined groups
            cmh_input = build_cmh_input(carrier_iids, groups, 
                                        groups_case_iids, groups_ctrl_iids)

            # index 3 of the cmh_test() result tuple is the cmh p-value
            cmh_pvals_perm.append(_cmh_stats(cmh_input, cmh_stats_cache)[3])

        # sort p-values by increasing value
        cmh_pvals_perm.sort()

        # add permuted p-values to array of p-value sums
        exp_pvals += np.array(cmh_pvals_perm)

    # calculate expected p-value distribution (mean per rank)
    exp_pvals = exp_pvals / args.n_perm

    # before writing results to file, sort by cmh_p and add exp p dist.
    # exp_pvals is a plain array, so it is assigned by position, i.e. the
    # i-th smallest observed p-value gets the i-th smallest expected one.
    res_df = res_df.sort_values('cmh_p')
    res_df['cmh_p_expected'] = exp_pvals

    # write results to file
    res_df.to_csv(path_or_buf=args.out_tsv,
                  sep="\t",
                  na_rep="NA",
                  header=True,
                  index=False)

    # for QQ plot + lambda calculations, if desired by user, subset on 
    # risk or protective loci
    if args.qq_risk_cnv:
        res_df = res_df[res_df.cmh_or > 1]
    elif args.qq_protective_cnv:
        res_df = res_df[res_df.cmh_or < 1]

    # qq plot (all cnvs)
    if args.cnv_qq_plot_pdf is not None:

        # make qq plot and derive lambda
        qq(res_df.cmh_p, res_df.cmh_p_expected, args.n_perm, 
           args.cnv_qq_plot_pdf, no_qq_logfile=False,
           top_n_exclude=None)

    # qq plot (dels)
    if args.del_qq_plot_pdf is not None:

        # derive expected and observed p-value dist (deletions only).
        # Match on the name suffix: a substring match on "_DEL" would also
        # pull in "_DELDUP" loci.
        is_del = res_df.locus_name.str.endswith("_DEL")
        res_del_df = res_df[is_del]

        # make qq plot and derive lambda
        qq(res_del_df.cmh_p, res_del_df.cmh_p_expected, args.n_perm, 
           args.del_qq_plot_pdf, no_qq_logfile=False)

    # qq plot (dups)
    if args.dup_qq_plot_pdf is not None:

        # derive expected and observed p-value dist (duplications only)
        is_dup = res_df.locus_name.str.endswith("_DUP")
        res_dup_df = res_df[is_dup]

        # make qq plot and derive lambda
        qq(res_dup_df.cmh_p, res_dup_df.cmh_p_expected, args.n_perm, 
           args.dup_qq_plot_pdf, no_qq_logfile=False)

def build_cmh_input(carriers_iids, groups, 
                    groups_case_iids, groups_ctrl_iids,
                    counts_per_group=False):
    """
    Build the per-group 2x2 tables used as input for the CMH test.

    carriers_iids    : set of iids carrying a CNV at the locus
    groups           : ordered iterable of group names (one stratum each)
    groups_case_iids : dict, group name -> collection of case iids
    groups_ctrl_iids : dict, group name -> collection of control iids
    counts_per_group : accepted but not used by this function

    Returns a dict with:
      "tbls"       : list with one [[case carriers, ctrl carriers],
                     [case non-carriers, ctrl non-carriers]] table per group
      "nca_1_tot", "nco_1_tot", "nca_0_tot", "nco_0_tot" : the four cells
                     summed across all groups
      "counts_str" : "_"-joined case/ctrl carrier counts of every group
    """
    tables = []
    carrier_counts = []
    # running totals in the order: nca_1, nco_1, nca_0, nco_0
    totals = [0, 0, 0, 0]

    for group_name in groups:
        group_cases = groups_case_iids[group_name]
        group_ctrls = groups_ctrl_iids[group_name]

        # carriers among cases / controls of this group
        n_case_carriers = len(carriers_iids.intersection(group_cases))
        n_ctrl_carriers = len(carriers_iids.intersection(group_ctrls))

        # everyone else in the group is a non-carrier
        cells = (n_case_carriers,
                 n_ctrl_carriers,
                 len(group_cases) - n_case_carriers,
                 len(group_ctrls) - n_ctrl_carriers)

        totals = [total + cell for total, cell in zip(totals, cells)]
        tables.append([[cells[0], cells[1]], [cells[2], cells[3]]])
        carrier_counts.append(str(n_case_carriers))
        carrier_counts.append(str(n_ctrl_carriers))

    return {"tbls": tables,
            "nca_1_tot": totals[0],
            "nco_1_tot": totals[1],
            "nca_0_tot": totals[2],
            "nco_0_tot": totals[3],
            "counts_str": "_".join(carrier_counts)}

def cmh_test(tbls):
    """
    Run a Cochran-Mantel-Haenszel test plus a Breslow-Day style test of odds
    ratio homogeneity on a list of 2x2 count tables (one table per stratum).

    tbls : list of [[a, b], [c, d]] count tables, as found under "tbls" in
           the dict returned by build_cmh_input()

    Returns the tuple
      (cmh_or, cmh_95ci_l, cmh_95ci_u, cmh_p,
       breslowday_statistic, breslowday_pvalue)
    where the odds ratio, its confidence interval and the CMH p-value come
    from the observed counts, and the homogeneity statistics come from the
    counts with 1 added to every cell.
    """

    # stack the 2x2 tables along a third axis -> shape (2, 2, n_strata).
    # Use the builtin int as dtype: the np.int alias was removed from NumPy.
    np_tbls = np.dstack(tbls).astype(int)

    # cmh test of the null of a common odds ratio of 1; True is passed as
    # the first positional argument of test_null_odds (continuity
    # correction, per the statsmodels documentation)
    cmh_res = cmh(np_tbls)
    res = cmh_res.test_null_odds(True)

    # add 1 to all cells of input count table to be able to perform
    # breslow-day test for homogeneity of odds ratios across data; True is
    # passed as the first positional argument of test_equal_odds (Tarone
    # adjustment, per the statsmodels documentation)
    cmh_res_1 = cmh(np_tbls + 1)
    breslowday_res = cmh_res_1.test_equal_odds(True)

    # retrieve cmh and breslow day test stats
    cmh_or = cmh_res.oddsratio_pooled 
    cmh_ci = cmh_res.oddsratio_pooled_confint()
    cmh_95ci_l = cmh_ci[0]
    cmh_95ci_u = cmh_ci[1]
    cmh_p = res.pvalue
    breslowday_statistic = breslowday_res.statistic
    breslowday_pvalue = breslowday_res.pvalue
    results = (cmh_or, cmh_95ci_l, cmh_95ci_u, cmh_p,
               breslowday_statistic, breslowday_pvalue) 
    return results

def bed_file_to_gene_pbt(bed_filename, delim="_", 
                         cnv_types=("DEL", "DUP")):
    """
    Read a tab-delimited gene BED file and return a pybedtools.BedTool with
    one interval per (gene, cnv type) combination.

    bed_filename : path to a BED file whose 4th column is the gene/test
                   unit name, or "-" / "stdin" to read from standard input
    delim        : accepted for interface compatibility but not used; names
                   are always joined with "_" because
                   get_locus_cnv_carriers() parses them with "_"
    cnv_types    : iterable of cnv type labels; each gene yields one
                   interval named <gene>_<cnvtype> per label

    Blank lines are skipped. Interval coordinates are taken unchanged from
    columns 2 and 3 of the input.
    """
    use_stdin = bed_filename == "-" or bed_filename == "stdin"
    in_fh = sys.stdin if use_stdin else open(bed_filename, "r")

    bed_list = []
    try:
        for line in in_fh:
            # tolerate blank lines (e.g. trailing newline at end of file)
            if not line.strip():
                continue
            data = line.rstrip().split("\t")
            chrom = data[0]
            start0 = int(data[1])
            end = int(data[2])
            locusname = data[3]
            for cnvtype in cnv_types:
                locusname_i = locusname + "_" + cnvtype
                bed_list.append([chrom, start0, end, locusname_i])
    finally:
        # close filehandle to bed, also on parse errors; stdin is not ours
        # to close
        if not use_stdin:
            in_fh.close()

    # init pbt instance
    pbt = pybedtools.BedTool(fn=bed_list)

    return pbt


def bed_file_to_breakpoints_pbt(bed_filename, iids_keep=frozenset(),
                                breakpoint_delim="_", add_deldup=False):
    """
    Read a tab-delimited CNV BED file and return a pybedtools.BedTool with
    two 1bp intervals (start and end breakpoint) per CNV.

    bed_filename     : path to the CNV BED file, or "-" / "stdin" to read
                       from standard input. Column 5 is read as the cnv
                       type and column 6 as the carrier iid.
    iids_keep        : if non-empty, CNVs whose iid is not in this
                       collection are skipped
    breakpoint_delim : accepted for interface compatibility but not used;
                       names are always joined with "_" because
                       get_locus_cnv_carriers() parses them with "_"
    add_deldup       : if True, each breakpoint is additionally emitted
                       with the cnv type label "DELDUP"

    Breakpoints are named <chrom>_<1-based position>_<cnvtype>. Blank
    lines are skipped.
    """
    use_stdin = bed_filename == "-" or bed_filename == "stdin"
    in_fh = sys.stdin if use_stdin else open(bed_filename, "r")

    bed_list = []
    try:
        for line in in_fh:
            # tolerate blank lines (e.g. trailing newline at end of file)
            if not line.strip():
                continue
            data = line.rstrip().split("\t")
            chrom = data[0]
            start0 = int(data[1])
            end = int(data[2])
            # 1-based position of the first base / 0-based start of the
            # last base of the cnv
            start = start0 + 1
            end0 = end - 1
            cnvtype = data[4]
            iid = data[5]
            if iids_keep and iid not in iids_keep:
                continue

            # one entry per breakpoint for the cnv's own type, plus one
            # for the combined DELDUP label if requested
            labels = [cnvtype]
            if add_deldup:
                labels.append("DELDUP")
            for label in labels:
                breakpoint_1_name = "_".join([chrom, str(start), label])
                breakpoint_2_name = "_".join([chrom, str(end), label])
                bed_list.append([chrom, start0, start, breakpoint_1_name])
                bed_list.append([chrom, end0, end, breakpoint_2_name])
    finally:
        # close filehandle, also on parse errors; stdin is not ours to close
        if not use_stdin:
            in_fh.close()

    # init pbt instance
    pbt = pybedtools.BedTool(fn=bed_list)

    return pbt

def get_locus_cnv_carriers(loci_cnvs_intersect, 
                           iids_keep=frozenset(),
                           n_gene_overlap_min=None,
                           n_gene_overlap_max=None,
                           is_breakpoints=True):
    """
    Collapse locus/cnv intersect features to locus name -> set of carrier
    iids.

    loci_cnvs_intersect : iterable of indexable features where field 3 is
                          the locus name (<name>_<cnvtype>), field 7 the
                          cnv identifier, field 8 the cnv type and field 9
                          the carrier iid
    iids_keep           : if non-empty, features whose iid is not in this
                          collection are skipped
    n_gene_overlap_min  : gene mode only; drop cnvs overlapping fewer than
                          this many distinct genes
    n_gene_overlap_max  : gene mode only; drop cnvs overlapping more than
                          this many distinct genes
    is_breakpoints      : True if loci are breakpoints, False if they are
                          gene regions

    The cnv type of a locus is whatever follows the LAST "_" of its name,
    so gene or contig names that themselves contain "_" are handled. A
    feature is only counted if the locus cnv type equals the cnv's type or
    is "DELDUP". The min and max overlap filters are applied independently
    of each other.

    Returns an OrderedDict (in order of first appearance) of locus name ->
    set of carrier iids. Loci without any remaining carrier are absent.
    Raises ValueError if a locus name contains no "_".
    """
    # cnv -> set of gene names it overlaps (gene mode only)
    overlaps_per_cnv = dict()
    # locus -> cnv -> iids
    carriers_by_locus_cnv = OrderedDict()
    # locus -> iids
    locus_cnv_carriers_dict = OrderedDict()

    for xfeature in loci_cnvs_intersect:
        locus_name = xfeature[3]
        cnv_locus = xfeature[7]
        cnv_type = xfeature[8]
        cnv_iid = xfeature[9]
        if iids_keep and cnv_iid not in iids_keep:
            continue

        # split on the last "_" only: everything before it is the gene
        # name (or chrom_pos for breakpoints), everything after the type
        locus_x_name, sep, locus_x_cnvtype = locus_name.rpartition("_")
        if not sep:
            raise ValueError("locus name '" + str(locus_name) +
                             "' is not of the form <name>_<cnvtype>")

        # skip entry if cnvtype from locus in question doesn't match cnv
        if locus_x_cnvtype != cnv_type and locus_x_cnvtype != "DELDUP":
            continue

        # keep track of distinct genes overlapped per cnv
        if not is_breakpoints:
            overlaps_per_cnv.setdefault(cnv_locus, set()).add(locus_x_name)

        # add cnv carrier iid to dict of carriers for locus
        locus_cnvs = carriers_by_locus_cnv.setdefault(locus_name, dict())
        locus_cnvs.setdefault(cnv_locus, set()).add(cnv_iid)

    # collapse down to dictionary of locus -> cnv_carriers
    for locus_name, locus_cnvs in carriers_by_locus_cnv.items():
        for cnv_locus, cnv_iids in locus_cnvs.items():

            # skip cnvs overlapping too many or too few genes; the two
            # bounds are checked independently so both can be combined
            if not is_breakpoints:
                n_overlaps = len(overlaps_per_cnv[cnv_locus])
                if (n_gene_overlap_max is not None
                        and n_overlaps > n_gene_overlap_max):
                    continue
                if (n_gene_overlap_min is not None
                        and n_overlaps < n_gene_overlap_min):
                    continue

            locus_cnv_carriers_dict.setdefault(locus_name,
                                               set()).update(cnv_iids)

    # return ordered dictionary of locus -> cnv_carriers
    return locus_cnv_carriers_dict

def qq(obs_p, exp_p, n_perm, qq_plot_pdf, no_qq_logfile=False, 
       top_n_exclude=None):
    """
    Make a QQ plot of observed vs. expected -log10 p-values and estimate
    lambda as the slope of a no-intercept linear fit of observed on
    expected -log10(p).

    obs_p         : observed p-values
    exp_p         : expected p-values, in the same order as obs_p
    n_perm        : number of permutations (only reported on plot / in log)
    qq_plot_pdf   : path of the plot file to write
    no_qq_logfile : unless True, lambda stats are also written to
                    <qq_plot_pdf>.log
    top_n_exclude : if not None, the first top_n_exclude entries are left
                    out of both the fit and the plot

    If no p-values are left to plot, a warning is written to stderr and
    nothing is created.
    """

    # generate -log10 p-values (observed and expected)
    log10p_observed = [log10(p) * -1 for p in list(obs_p)]
    log10p_expected = [log10(p) * -1 for p in list(exp_p)] 

    # if indicated by user, exclude top n
    if top_n_exclude is not None:
        log10p_observed = log10p_observed[top_n_exclude:]
        log10p_expected = log10p_expected[top_n_exclude:] 

    # nothing to fit or plot (e.g. no loci of the requested cnv type)
    if len(log10p_observed) == 0:
        sys.stderr.write("WARNING : no p-values to plot, skipping " +
                         str(qq_plot_pdf) + "\n")
        return

    # fit to linear model (no intercept term, so the single coefficient is
    # the slope through the origin)
    import statsmodels.api as sm
    y = log10p_observed
    x = log10p_expected
    lm = sm.OLS(y,x).fit()
    lambda_est = lm.params[0]
    lambda_se = lm.bse[0]

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    try:
        ax.scatter(x, y, s=6)
        ax.text(min(x) * 0.05, max(y) * 1, 
                '\u03BB = '+str(round(lambda_est, 4)))
        ax.text(min(x) * 0.05, max(y) * 0.95, 
                '\u03BB standard error = '+\
                str(round(lambda_se, 4)))
        ax.text(min(x) * 0.05, max(y) * 0.90, 
                'n tests = ' + str(len(obs_p)))
        ax.text(min(x) * 0.05, max(y) * 0.85, 
                'n permutations = ' + str(n_perm))
        try:
            ax.axline((0, 0), slope=1., 
                      color='black', label='by slope')
        except AttributeError:
            # Axes.axline is missing in older matplotlib releases; the
            # plot is still useful without the identity line
            sys.stderr.write("WARNING : matplotlib has no Axes.axline, " +
                             "identity line not drawn\n")
        ax.set(xlabel="expected -log10(p)",
               ylabel="observed -log10(p)")
        fig.savefig(qq_plot_pdf)
    finally:
        # release the figure; qq() can be called several times per run
        plt.close(fig)

    # unless specified by user, write lambda stats to logfile
    if not no_qq_logfile:
        with open(qq_plot_pdf + ".log", "w") as out_fh:
            out_fh.write("lambda_estimate:"+str(lambda_est) + "\n")
            out_fh.write("lambda_standard_error:"+str(lambda_se) + "\n")
            out_fh.write("n_tests:"+str(len(obs_p))+"\n")
            out_fh.write("n_permutations:"+str(n_perm)+"\n")

    return

def get_multigene_cnvs(locus_pbt, cnv_pbt, n_gene_min=2):
    """
    Return the set of cnv identifiers that overlap at least n_gene_min
    distinct genes.

    locus_pbt  : object whose intersect(cnv_pbt, wa=True, wb=True) yields
                 indexable features where field 3 is the locus name
                 (<gene>_<cnvtype>) and field 7 is the cnv identifier
    cnv_pbt    : passed through to locus_pbt.intersect()
    n_gene_min : minimum number of distinct genes a cnv has to overlap

    The gene name is everything before the LAST "_" of the locus name, so
    the per-cnv-type entries of one gene count once and gene names that
    contain "_" are kept intact. A locus name without "_" is used as is.
    """
    xcnvs1 = locus_pbt.intersect(cnv_pbt, wa=True, wb=True)

    # cnv identifier -> set of distinct gene names it overlaps
    interval_overlaps = dict()
    for feature in xcnvs1:
        gene_name, sep, _cnvtype = feature[3].rpartition("_")
        if not sep:
            gene_name = feature[3]
        interval_overlaps.setdefault(feature[7], set()).add(gene_name)

    intervals_multigene_cnvs = set()
    for interval, gene_names in interval_overlaps.items():
        if len(gene_names) >= n_gene_min:
            intervals_multigene_cnvs.add(interval)
    return intervals_multigene_cnvs

# run the analysis only when executed as a script, not when imported
if __name__ == '__main__':
    main()


