
import argparse
import random
import numpy as np

def main():
    """
    Main block

    Parse the command line, load the individual/group/phenotype table and
    the clumped CNV overlap BED, count qualifying loci in the observed data,
    repeat the counts on within-group phenotype permutations, and write
    ``<outroot>.perm_counts.tsv`` (one row per permutation) and
    ``<outroot>.results.tsv`` (observed count, permutation median, rank and
    p-value for each of the two tests).
    """
    # Local import: only needed to honour the documented "stdin" BED input.
    import sys

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--dels-only', action='store_true', default=False,
                        help='run analysis on deletions only.')
    parser.add_argument('--dups-only', action='store_true', default=False,
                        help='run analysis on duplications only.')
    parser.add_argument('--seed-number', type=int, default=None,
                        help='seed number for permutation procedure.')
    parser.add_argument('--size-min', type=int, default=30000,
                        help='minimum CNV size.')
    parser.add_argument('--size-max', type=int, default=200000000,
                        help='maximum CNV size.')
    parser.add_argument('--n-permutations', type=int, default=100,
                        help='number of permutations to use in testing procedure.')
    parser.add_argument('in_iid_group_pheno_tsv',
                        help='Input TSV where column1 is individual ID, '
                             'column2 is case/control group and column3 is '
                             'phenotype (1=case, 0=control)')
    parser.add_argument('callset_overlaps_cds_clumped_bed',
                        help='Input BED (supports "stdin"). Columns 1-6 are '
                             'same as callset bed (chrom,start0,end,locus,cnvtype,iid) '
                             'and column 7 is gene clump name')
    parser.add_argument('outroot', help='root filepath for results')
    args = parser.parse_args()

    # The p-value divides by the number of permutations, so require >= 1.
    if args.n_permutations < 1:
        parser.error("--n-permutations must be at least 1")

    # use seed number if one is defined by user
    if args.seed_number is not None:
        random.seed(args.seed_number)

    # Pooled lists are kept apart from the per-group dicts so that a group
    # that happens to be named "ALL" cannot clobber the pooled data.
    all_iids = []
    all_phenos = []
    group_iids = {}
    group_phenos = {}
    # Groups in first-seen order: a deterministic order is required for a
    # fixed seed to reproduce the same permutations on every run.
    groups_list = []
    cnv_carriers = {"DEL": {}, "DUP": {}}

    # read iid / group / pheno data
    with open(args.in_iid_group_pheno_tsv, 'r') as in_fh:
        for line in in_fh:
            data = line.split()
            if not data:
                # tolerate blank lines
                continue
            iid, group, pheno = data[:3]
            pheno = int(pheno)
            if group not in group_iids:
                group_iids[group] = []
                group_phenos[group] = []
                groups_list.append(group)
            all_iids.append(iid)
            all_phenos.append(pheno)
            group_iids[group].append(iid)
            group_phenos[group].append(pheno)

    # in callset overlaps bed get cnv carriers per locus, split by del, dup
    n_tests = 0
    bed_path = args.callset_overlaps_cds_clumped_bed
    bed_fh = sys.stdin if bed_path == "stdin" else open(bed_path, "r")
    try:
        for line in bed_fh:
            if not line.strip():
                # tolerate blank lines
                continue
            data = line.rstrip().split("\t")

            # get cnv size and skip if too big or too small
            length = int(data[2]) - int(data[1])
            if length < args.size_min or length > args.size_max:
                continue

            # skip if cnvtype is not in types to consider
            cnvtype = data[4]
            if args.dups_only and cnvtype != "DUP":
                continue
            if args.dels_only and cnvtype != "DEL":
                continue

            # get iid and region name and store observation
            iid = data[5]
            overlap_name = data[6]
            if overlap_name not in cnv_carriers[cnvtype]:
                cnv_carriers[cnvtype][overlap_name] = set()
                n_tests += 1
            cnv_carriers[cnvtype][overlap_name].add(iid)
    finally:
        if bed_fh is not sys.stdin:
            bed_fh.close()

    # print number of unique test loci/cnvtype to stdout
    print("Number of seperate tests : "+str(n_tests))

    # get observed counts for :
    # 1. number of test loci with exactly 1 case carrier and no control carriers
    # 2. number of test loci with 2 or more case carriers where the carrier
    #    frequency in cases exceeds the carrier frequency in controls
    n_loci_ca1_co0 = get_n_loci(all_iids, all_phenos, cnv_carriers,
                                nca_1_min=1, nca_1_max=1,
                                nco_1_min=0, nco_1_max=0)
    n_loci_cafrq_gt_cofrq = get_n_loci(all_iids, all_phenos, cnv_carriers,
                                       nca_1_min=2,
                                       freq_mode=True)

    # init lists for storing perm results (kept in permutation order)
    n_loci_ca1_co0_l = []
    n_loci_cafrq_gt_cofrq_l = []

    # iids stay in a fixed order for every permutation; only the phenotypes
    # are shuffled, so build the iid list once.
    iids_perm = []
    for group in groups_list:
        iids_perm.extend(group_iids[group])

    # for each permutation ..
    for i in range(1, args.n_permutations + 1):

        # print to stdout if i is divisible by 1000
        if i % 1000 == 0:
            print("permutation number : " + str(i))

        # permute phenotypes within each group (shuffled in place; the
        # observed counts above were taken from the separate pooled lists)
        phenos_perm = []
        for group in groups_list:
            phenos_g = group_phenos[group]
            random.shuffle(phenos_g)
            phenos_perm.extend(phenos_g)

        # recompute both locus counts on the permuted phenotypes
        n_loci_ca1_co0_l.append(
            get_n_loci(iids_perm, phenos_perm, cnv_carriers,
                       nca_1_min=1, nca_1_max=1,
                       nco_1_min=0, nco_1_max=0))
        n_loci_cafrq_gt_cofrq_l.append(
            get_n_loci(iids_perm, phenos_perm, cnv_carriers,
                       nca_1_min=2,
                       freq_mode=True))

    # get rank for observed count amongst permutation derived counts.
    # Copies are passed so the per-permutation order of the lists survives
    # even if the ranking helper sorts its argument in place; otherwise the
    # two columns of perm_counts.tsv would each be sorted independently.
    n_loci_ca1_co0_r = test_val_rank(n_loci_ca1_co0, list(n_loci_ca1_co0_l))
    n_loci_cafrq_gt_cofrq_r = test_val_rank(n_loci_cafrq_gt_cofrq,
                                            list(n_loci_cafrq_gt_cofrq_l))

    # compute pval
    n_loci_ca1_co0_p = n_loci_ca1_co0_r / float(args.n_permutations)
    n_loci_cafrq_gt_cofrq_p = n_loci_cafrq_gt_cofrq_r / float(args.n_permutations)

    print("pval (nca 1, nco 0) : " + str(n_loci_ca1_co0_p))
    print("pval (ca frq > co frq) : " + str(n_loci_cafrq_gt_cofrq_p))

    # write per-permutation counts to file
    with open(args.outroot + ".perm_counts.tsv", "w") as out_fh:
        out_fh.write("n_loci_ca1_co0\tn_loci_cafrq_gt_cofrq\n")
        for n_ca1_co0, n_cafrq in zip(n_loci_ca1_co0_l,
                                      n_loci_cafrq_gt_cofrq_l):
            out_fh.write("\t".join([str(n_ca1_co0), str(n_cafrq)]) + "\n")

    # write results to file
    with open(args.outroot + ".results.tsv", "w") as out_fh:
        out_fh.write("test\tn_obs\tn_perm_median\tperm_rank\tperm_p\n")
        out_fh.write("\t".join(["n_loci_ca1_co0",
                                str(n_loci_ca1_co0),
                                str(np.median(n_loci_ca1_co0_l)),
                                str(n_loci_ca1_co0_r),
                                str(n_loci_ca1_co0_p)])+"\n")
        out_fh.write("\t".join(["n_loci_cafrq_gt_cofrq",
                                str(n_loci_cafrq_gt_cofrq),
                                str(np.median(n_loci_cafrq_gt_cofrq_l)),
                                str(n_loci_cafrq_gt_cofrq_r),
                                str(n_loci_cafrq_gt_cofrq_p)])+"\n")


def get_n_loci(iids, phenotypes, cnv_carriers,
               nca_1_min=0,nca_1_max=float('inf'),
               nco_1_min=0,nco_1_max=float('inf'),
               freq_mode=False):
    """
    Count loci whose case/control carrier counts meet the given criteria.

    Args:
        iids: sequence of individual IDs.
        phenotypes: sequence aligned with ``iids`` by position; 1 marks a
            case and 0 a control, any other value is ignored.
        cnv_carriers: dict mapping cnv type -> {locus name -> set of
            carrier IDs}.
        nca_1_min, nca_1_max: inclusive bounds on the number of case
            carriers at a locus.
        nco_1_min, nco_1_max: inclusive bounds on the number of control
            carriers at a locus.
        freq_mode: when true, additionally require the carrier frequency
            among cases to be strictly greater than among controls.

    Returns:
        Number of (cnv type, locus) entries satisfying every criterion.
    """
    # split individuals into case and control sets
    cases = set()
    ctrls = set()
    for iid, pheno in zip(iids, phenotypes):
        if pheno == 1:
            cases.add(iid)
        elif pheno == 0:
            ctrls.add(iid)
    n_cases = len(cases)
    n_ctrls = len(ctrls)

    n_loci_qual = 0
    for loci in cnv_carriers.values():
        for carrier_iids in loci.values():
            n_cases_1 = len(carrier_iids.intersection(cases))
            n_ctrls_1 = len(carrier_iids.intersection(ctrls))

            # skip locus if counts don't match user-specified qualifications
            if n_cases_1 < nca_1_min or n_cases_1 > nca_1_max:
                continue
            if n_ctrls_1 < nco_1_min or n_ctrls_1 > nco_1_max:
                continue

            # use freq mode if user desires it
            if freq_mode:
                # An empty case or control set has carrier frequency 0
                # rather than triggering a division by zero.
                ca_freq = float(n_cases_1) / n_cases if n_cases else 0.0
                co_freq = float(n_ctrls_1) / n_ctrls if n_ctrls else 0.0
                if ca_freq <= co_freq:
                    continue

            # count locus if it does meet user qual
            n_loci_qual += 1

    return n_loci_qual

def test_val_rank(test_number, number_list):
    """
    Get the number of values in number_list that are >= test_number.

    The observed value and the ten smallest and ten largest values of
    number_list are printed to stdout. number_list itself is left
    unmodified; a sorted copy is used internally so callers can keep
    relying on the original element order.
    """
    ordered = sorted(number_list)
    print(test_number)
    print(ordered[:10])
    print(ordered[-10:])
    return sum(1 for number in ordered if number >= test_number)


# The name starts with "test_"; mark the helper so pytest-style collectors
# do not try to run it as a test case.
test_val_rank.__test__ = False

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
