
import h5py
import argparse
import numpy 
import pandas 

def main():
    """Filter CNV calls by the heterozygosity of BAF values in each call.

    Workflow, as implemented below:

    1. Parse the command line.
    2. Open the input HDF5 file and map every sample id (iid) to the
       ``(group, dataset)`` pair it is stored under.
    3. For every top-level group, read its ``LOCI`` table and build a
       ``chrom -> position -> row index`` lookup.
    4. Read the CNV file line by line. Lines starting with ``#`` and blank
       lines are skipped. Each remaining line is split on whitespace and
       read as ``chrom  start0  end  interval  cnvtype  iid``.
    5. For each call, fetch the BAF values between the marker at
       ``start0 + 1`` and the marker at ``end``, drop NaNs, and count values
       strictly between 0.2 and 0.8 as heterozygous.
    6. A ``DEL`` passes when the het fraction is below 0.05; a ``DUP``
       passes when it is above 0.05; everything else fails.  A call with no
       usable BAF values has an undefined het fraction and therefore fails.
    7. Write passing CNV lines to ``out_bed`` and, when
       ``--out-metrics-tsv`` is given, one metrics row per call to that file.

    Raises:
        KeyError: if a CNV line names an iid that is not in the HDF5 file,
            or a chromosome/position that is not in the group's ``LOCI``
            table.
        IndexError: if a non-blank CNV line has fewer than six fields.
    """
    args = argparse.ArgumentParser()
    args.add_argument("--hdf5-write-mode", action="store", type=str,
                      default="w", choices=['w','a'],
                      help="file write mode ('w' for write, or 'a' for " + \
                           "append). 'Write' mode will overwrite existing " + \
                           "file, while 'append' mode will add data to " + \
                           "already-existing file.")
    args.add_argument("--chroms", action="store", type=str,
                      default="1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22",
                      help="set of chromosomes to store data for.")
    args.add_argument("--intensity-file-suffix", action="store", type=str,
                      default=".txt.adjusted.gz",
                      help="suffix that intensity files end with, for removal.")
    args.add_argument("--intensity-filename-delim",action="store",type=str,
                      default=".",
                      help="delimiter of intensity filename, for extracting " +\
                           "sampleid.")
    args.add_argument("--intensity-filename-sampleid-pos",action="store",type=int,
                      default=3,
                      help="position of sampleid in intensity filename, " + \
                           "delim by --intensity-filename-delim.")
    args.add_argument("--gzip-bed", action="store_true", default=False,
                      help="gzip the output BED file.")
    args.add_argument("--hom-baf-thresh", action="store", default=0.1,
                      type=float,
                      help="threshold for calling BAF homozygous")
    args.add_argument("--del-perc-hom-thresh", action="store", type=float,
                      default=None,
                      help="threshold for required percent of " + \
                           "vars in deletion locus to be homozygous")
    args.add_argument("--stop-loading-after-n-samples", default=None, type=int,
                      help="Stop loading after n samples. For testing purposes.")
    args.add_argument("--stop-running-after-n-calls", default=None, type=int,
                      help="Stop loading after n cnv calls. For testing purposes.")
    args.add_argument("--print-current-cnv", default=False,
                      action='store_true',
                      help='print iid / locus for current CNV to stdout.')
    args.add_argument("--hdf5-lrr-key", default="lrr", type=str,
                      help="key for lrr values in hdf5 file")
    args.add_argument("--hdf5-baf-key", default="baf", type=str,
                      help="key for baf values in hdf5 file")
    args.add_argument('--matplotlib-use', type=str, default=None,
                      help='non-default engine for matplotlib to use.')
    args.add_argument('--make-lrr-baf-plot', default=False, action="store_true",
                      help='for each cnv, plot lrr and baf')
    args.add_argument('--make-baf-density-plot', default=False, action="store_true",
                      help='for each cnv, plot BAFs and gaussian KDE of BAFs')
    args.add_argument('--baf-cluster-calling-method', default='k-medoids',
                      action='store',
                      choices=['meanshift','k-medoids','k-means'],
                      help='method used for calling cluster centers in BAFs')
    args.add_argument('--out-metrics-tsv', type=str, default=None,
                      action='store',
                      help='file to write metrics to')
    args.add_argument("in_hdf5", type=str,
                      help="input hdf5 file with intensity data.")
    args.add_argument("in_cnv", type=str,
                      help="input cnv file, in either penncnv or bed format.")
    args.add_argument("out_bed", type=str,
                      help="output bed file.")
    args = args.parse_args()

    # NOTE(review): several parsed options (e.g. --gzip-bed, --hom-baf-thresh,
    # --del-perc-hom-thresh, the --stop-* limits and the plotting flags) are
    # never read in this function -- confirm whether they should be wired in.

    # select the matplotlib backend before anything else can import pyplot
    if args.matplotlib_use is not None:
        import matplotlib
        matplotlib.use(args.matplotlib_use)

    # structs for storing output; filled while the input files are open and
    # written out only after every call has been processed
    out_bed_lines = []
    out_stats_lines = []

    # the context managers guarantee both inputs are closed even when a
    # malformed record raises part-way through
    with h5py.File(args.in_hdf5, "r") as h5_fh:

        # map each iid to the (group, dataset) it lives under; only members
        # of a top-level group that are themselves groups hold iids
        iid_group_dataset = dict()
        for group in h5_fh.keys():
            for dataset, node in h5_fh[group].items():
                if isinstance(node, h5py.Group):
                    for iid in node.keys():
                        iid_group_dataset[iid] = (group, dataset)

        # per group, build a tree of chrom -> position -> row index in LOCI
        marker_loci_idx = dict()
        for group in h5_fh.keys():
            marker_loci = h5_fh[group]['LOCI']
            loci_chroms = marker_loci[:, 0].astype('U13')
            loci_positions = marker_loci[:, 1].astype(int)

            chrom_tree = dict()
            for i, (loci_chrom, loci_pos) in enumerate(zip(loci_chroms,
                                                           loci_positions)):
                if loci_chrom not in chrom_tree:
                    chrom_tree[loci_chrom] = dict()
                chrom_tree[loci_chrom][loci_pos] = i
            marker_loci_idx[group] = chrom_tree

        with open(args.in_cnv, "r") as cnv_fh:
            for line in cnv_fh:
                if line.startswith("#"):
                    continue
                line = line.rstrip()
                # blank lines carry no call; skipping them avoids an
                # IndexError on data[0]
                if not line:
                    continue
                data = line.split()
                chrom = data[0].replace("chr", "")
                start0 = int(data[1])
                start = start0 + 1
                end = int(data[2])
                interval = data[3]
                cnvtype = data[4]
                iid = data[5]
                length = end - start0

                # print current iid / locus to stdout
                if args.print_current_cnv:
                    print(iid + ", " + interval)

                # get dataset, group
                (group, dataset) = iid_group_dataset[iid]

                # get start and end idx
                start_idx = marker_loci_idx[group][chrom][start]
                end_idx = marker_loci_idx[group][chrom][end]

                # get bafs
                # NOTE(review): the slice is end-exclusive, so the marker
                # located at `end` itself is not included -- confirm that
                # this is intended.
                sample = h5_fh[group][dataset][iid]
                baf_x = sample[args.hdf5_baf_key][start_idx:end_idx]

                # drop missing values, then pick out likely het sites
                baf_x = baf_x[~numpy.isnan(baf_x)]
                het_x = baf_x[(baf_x > 0.2) & (baf_x < 0.8)]

                # count number of likely hom and het sites in region
                n_x = len(baf_x)
                n_het_x = len(het_x)
                n_hom_x = n_x - n_het_x

                # fraction of sites that are het; undefined (nan) when the
                # region has no usable BAFs, which makes both comparisons
                # below false so the call is recorded as FAIL
                if n_x:
                    perc_het = float(n_het_x) / n_x
                else:
                    perc_het = float("nan")
                # a zero-length interval has no defined probe density
                if length != 0:
                    probe_density = float(n_x) / length
                else:
                    probe_density = float("nan")

                if perc_het < 0.05 and cnvtype == "DEL":
                    passfail = "PASS"
                elif perc_het > 0.05 and cnvtype == "DUP":
                    passfail = "PASS"
                else:
                    passfail = "FAIL"
                if passfail == "PASS":
                    out_bed_lines.append(line)

                out_list = [group, dataset, iid, interval, cnvtype,
                            str(length), str(n_x), str(probe_density),
                            str(n_hom_x), str(n_het_x), passfail]
                out_stats_lines.append("\t".join(out_list))

    # write files
    with open(args.out_bed, "w") as out_fh:
        for line in out_bed_lines:
            out_fh.write(line + "\n")
    if args.out_metrics_tsv is not None:
        with open(args.out_metrics_tsv, "w") as out_fh:
            out_fh.write("\t".join(["group","dataset","iid","interval","cnvtype",
                                    "length","n_probes","probe_density",
                                    "n_hom","n_het","passfail"]) + "\n")
            for line in out_stats_lines:
                out_fh.write(line + "\n")

    return
 
def baf_bins(bafs,
             bins=(0, 0.333, 0.5, 0.666, 1),
             return_nan_counts=False):
    """Count how many BAF values lie nearest to each bin center.

    Each non-NaN value in ``bafs`` is assigned to the entry of ``bins`` with
    the smallest absolute distance to it; on a tie the earliest entry wins.
    NaN values are never assigned to a bin.

    Args:
        bafs: iterable of numeric BAF values (list, tuple, numpy array,
            generator, ...).
        bins: sequence of bin centers. The default is an immutable tuple so
            it cannot be altered between calls.
        return_nan_counts: when truthy, the number of NaN values seen is
            appended as one extra trailing element of the result.

    Returns:
        A list of ``len(bins)`` integer counts, in the order of ``bins``,
        followed by the NaN count if ``return_nan_counts`` is truthy.

    Raises:
        ValueError: if a non-NaN value cannot be assigned to any bin, i.e.
            ``bins`` is empty or no bin is at a finite distance from it.
    """
    bin_counts = [0] * len(bins)
    n_nan = 0
    for baf in bafs:
        if numpy.isnan(baf):
            n_nan += 1
            continue
        # linear scan for the closest center; the strict '<' keeps the
        # first of several equally-close centers
        nearest_idx = None
        nearest_dist = float("inf")
        for j, center in enumerate(bins):
            dist = abs(baf - center)
            if dist < nearest_dist:
                nearest_dist = dist
                nearest_idx = j
        if nearest_idx is None:
            raise ValueError("no bin could be assigned to BAF value "
                             + repr(baf))
        bin_counts[nearest_idx] += 1
    if return_nan_counts:
        bin_counts.append(n_nan)
    return bin_counts

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
