
import h5py
import argparse
import pybedtools
from collections import defaultdict
from collections import OrderedDict
import matplotlib.pyplot as plt
from numpy import isnan
import numpy 
from scipy.stats import gaussian_kde
import seaborn as sb
from sklearn.neighbors import KernelDensity
from scipy.signal import find_peaks
from sklearn.cluster import MeanShift
from sklearn.metrics import silhouette_score
import pandas 

# Plot marker style per CNV type; CNVs of any other type are not drawn.
_CNV_PLOT_STYLE = {"DEL": ".b", "DUP": ".g"}

# BAF reference levels drawn as horizontal guide lines (same values main()
# passes to baf_bins as cluster-center bins).
_BAF_GUIDE_LEVELS = (0, 0.333, 0.5, 0.666, 1)


def _parse_args():
    """Build the command-line interface and return the parsed arguments.

    Note: --hdf5-write-mode, --chroms, --intensity-*, --gzip-bed and
    --stop-loading-after-n-samples are accepted but not read by main().
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--hdf5-write-mode", action="store", type=str,
                        default="w", choices=['w', 'a'],
                        help="file write mode ('w' for write, or 'a' for "
                             "append). 'Write' mode will overwrite existing "
                             "file, while 'append' mode will add data to "
                             "already-existing file.")
    parser.add_argument("--chroms", action="store", type=str,
                        default="1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22",
                        help="set of chromosomes to store data for.")
    parser.add_argument("--intensity-file-suffix", action="store", type=str,
                        default=".txt.adjusted.gz",
                        help="suffix that intensity files end with, for removal.")
    parser.add_argument("--intensity-filename-delim", action="store", type=str,
                        default=".",
                        help="delimiter of intensity filename, for extracting "
                             "sampleid.")
    parser.add_argument("--intensity-filename-sampleid-pos", action="store", type=int,
                        default=3,
                        help="position of sampleid in intensity filename, "
                             "delim by --intensity-filename-delim.")
    parser.add_argument("--gzip-bed", action="store_true", default=False,
                        help="gzip the output BED file.")
    parser.add_argument("--hom-baf-thresh", action="store", default=0.1,
                        type=float,
                        help="threshold for calling BAF homozygous")
    parser.add_argument("--del-perc-hom-thresh", action="store", type=float,
                        default=None,
                        help="threshold for required percent of "
                             "vars in deletion locus to be homozygous")
    parser.add_argument("--stop-loading-after-n-samples", default=None, type=int,
                        help="Stop loading after n samples. For testing purposes.")
    parser.add_argument("--print-current-cnv", default=False,
                        action='store_true',
                        help='print iid / locus for current CNV to stdout.')
    parser.add_argument("--skip-dels", default=False,
                        action='store_true',
                        help='skip BAF validation in deletion CNV calls.')
    parser.add_argument("--hdf5-lrr-key", default="lrr", type=str,
                        help="key for lrr values in hdf5 file")
    parser.add_argument("--hdf5-baf-key", default="baf", type=str,
                        help="key for baf values in hdf5 file")
    parser.add_argument('--matplotlib-use', type=str, default=None,
                        help='non-default engine for matplotlib to use.')
    parser.add_argument('--make-lrr-baf-plot', default=False, action="store_true",
                        help='for each cnv, plot lrr and baf')
    parser.add_argument('--make-baf-density-plot', default=False, action="store_true",
                        help='for each cnv, plot BAFs and gaussian KDE of BAFs')
    parser.add_argument('--baf-cluster-calling-method', default='k-medoids',
                        action='store',
                        choices=['meanshift', 'k-medoids'],
                        help='method used for calling cluster centers in BAFs')
    parser.add_argument('--out-filtered-cnv', type=str, default=None,
                        action='store',
                        help='file to write filtered cnv callset to')
    parser.add_argument("in_hdf5", type=str,
                        help="input hdf5 file with intensity data.")
    parser.add_argument("in_cnv", type=str,
                        help="input cnv file, in either penncnv or bed format.")
    parser.add_argument("outroot", type=str,
                        help="output file with metrics per cnv call.")
    return parser.parse_args()


def _index_iids(h5_fh):
    """Map every sample id to the ``(group, dataset)`` that stores its data.

    Only sub-entries that are HDF5 groups are treated as datasets of samples;
    plain datasets such as a group's LOCI table are skipped.
    """
    iid_group_dataset = dict()
    for group in h5_fh.keys():
        for dataset in h5_fh[group].keys():
            if isinstance(h5_fh[group][dataset], h5py.Group):
                for iid in h5_fh[group][dataset].keys():
                    iid_group_dataset[iid] = (group, dataset)
    return iid_group_dataset


def _load_marker_index(h5_group):
    """Read one group's LOCI table and build a position lookup for it.

    Returns ``(marker_pos, marker_idx)``: an integer array with the position of
    every LOCI row, and a ``{chrom: {position: row index}}`` dict.
    """
    marker_loci = h5_group['LOCI']
    marker_chroms = marker_loci[:, 0].astype('U13')
    marker_pos = marker_loci[:, 1].astype(int)
    marker_idx = dict()
    for row, (chrom, pos) in enumerate(zip(marker_chroms.tolist(),
                                           marker_pos.tolist())):
        marker_idx.setdefault(chrom, dict())[pos] = row
    return marker_pos, marker_idx


def _hom_het_fractions(baf_nonan, hom_baf_thresh):
    """Return ``(hom_fraction, het_fraction)`` of the given non-NaN BAFs.

    A BAF counts as homozygous when it lies within ``hom_baf_thresh`` of 0 or
    of 1. Both fractions are NaN when no BAF values are available.
    """
    n_total = len(baf_nonan)
    if n_total == 0:
        return float("nan"), float("nan")
    is_hom = (baf_nonan < hom_baf_thresh) | (baf_nonan > 1 - hom_baf_thresh)
    hom_perc = float(numpy.count_nonzero(is_hom)) / n_total
    return hom_perc, 1 - hom_perc


def _call_baf_clusters(baf_nonan, method):
    """Cluster 1-D BAF values; return the fitted model or None if impossible.

    'k-medoids' (the CLI name) actually fits sklearn KMeans for k = 2..5 and
    keeps the model with the highest silhouette score; any other method fits
    MeanShift with bandwidth 0.15. The returned model's ``cluster_centers_``
    has one column because the input is reshaped to (n, 1).
    """
    values = numpy.asarray(baf_nonan, dtype=float).reshape(-1, 1)
    n_samples = values.shape[0]
    if n_samples == 0:
        return None
    if method != 'k-medoids':
        return MeanShift(bandwidth=0.15).fit(values)

    from sklearn.cluster import KMeans
    best_score = -2
    best_mdl = None
    for k in (2, 3, 4, 5):
        # KMeans needs at least k samples, and silhouette_score requires
        # 2 <= n_labels <= n_samples - 1, so larger k cannot be scored.
        if k >= n_samples:
            break
        mdl = KMeans(n_clusters=k, random_state=1).fit(values)
        n_labels = len(numpy.unique(mdl.labels_))
        if not 2 <= n_labels <= n_samples - 1:
            # duplicate BAF values can collapse clusters; skip unscorable fits
            continue
        score = silhouette_score(values, mdl.labels_, metric='euclidean')
        if score > best_score:
            best_score = score
            best_mdl = mdl
    return best_mdl


def _plot_lrr_baf(out_path, cnvtype, positions, baf_x, lrr_x):
    """Save a two-panel PDF: LRR (top) and BAF (bottom) for the CNV markers."""
    fig, axs = plt.subplots(2, sharex=True, figsize=(10, 5))
    style = _CNV_PLOT_STYLE.get(cnvtype)
    if style is not None:
        axs[0].plot(positions, lrr_x, style)
        axs[1].plot(positions, baf_x, style)
    axs[0].set_ylim(-2, 2)
    axs[0].axhline(y=0)
    # nanmean so a single missing LRR does not make the mean line vanish
    axs[0].axhline(y=numpy.nanmean(lrr_x), color='red')
    for level in _BAF_GUIDE_LEVELS:
        axs[1].axhline(y=level)
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


def _plot_baf_density(out_path, cnvtype, positions, baf_x, baf_nonan,
                      cluster_centers):
    """Save a two-panel PDF: BAF per marker (top) and BAF KDE with cluster centers (bottom)."""
    fig, axs = plt.subplots(2, sharex=False, figsize=(10, 5))
    style = _CNV_PLOT_STYLE.get(cnvtype)
    if style is not None:
        axs[0].plot(positions, baf_x, style)
    for level in _BAF_GUIDE_LEVELS[1:]:
        axs[0].axhline(y=level)
    # draw explicitly on the lower panel rather than the implicit current axes
    sb.kdeplot(list(baf_nonan), bw_adjust=0.15, fill=True, ax=axs[1])
    for center in cluster_centers:
        axs[1].axvline(x=center)
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


def _write_filtered_cnv(in_cnv, out_path, res_dict):
    """Copy the lines of ``in_cnv`` whose (iid, interval) was called PASS to ``out_path``."""
    passing = {iid + " " + interval
               for iid, interval, passfail in zip(res_dict["iid"],
                                                  res_dict["interval"],
                                                  res_dict["passfail"])
               if passfail == "PASS"}
    with open(in_cnv, "r") as cnv_fh, open(out_path, "w") as out_fh:
        for line in cnv_fh:
            line = line.rstrip()
            data = line.split()
            if not data:
                continue
            if data[5] + " " + data[3] in passing:
                out_fh.write(line + "\n")


def main():
    """Validate CNV calls using B-allele-frequency (BAF) values stored in HDF5.

    Each line of ``in_cnv`` is read as whitespace-separated columns
    ``chrom, 0-based start, end, interval, cnvtype, iid``. The BAFs of the
    markers from ``start + 1`` to ``end`` (both must be marker positions in the
    sample's group LOCI table) are clustered, and the call is marked:

    * DEL: PASS when exactly two BAF cluster centers are found and, if
      ``--del-perc-hom-thresh`` is given, the homozygous fraction exceeds it;
    * DUP: PASS when no cluster center falls nearest the 0.5 BAF bin;
    * any other type, or a region whose BAFs cannot be clustered: FAIL.

    Writes ``<outroot>.stats.tsv`` (one row per evaluated CNV), optional
    per-CNV PDF plots, and optionally the PASS lines of ``in_cnv`` to
    ``--out-filtered-cnv``.
    """
    args = _parse_args()

    # select a non-default matplotlib backend if requested
    if args.matplotlib_use is not None:
        import matplotlib
        matplotlib.use(args.matplotlib_use)

    # per-CNV results; keys are the output TSV columns, in order
    res_dict = {"group": [],
                "dataset": [],
                "iid": [],
                "interval": [],
                "cnvtype": [],
                "cnvlength": [],
                "nsnp": [],
                "nsnp_nan": [],
                "het": [],
                "hom": [],
                "n_baf_peaks": [],
                "baf_peak_centers": [],
                "passfail": []}

    with h5py.File(args.in_hdf5, "r") as h5_fh:
        iid_group_dataset = _index_iids(h5_fh)

        # marker tables are group specific: each CNV must be indexed with the
        # table of the group its sample lives in
        marker_index_by_group = dict()
        for group in h5_fh.keys():
            print("Loading array marker info ..")
            marker_index_by_group[group] = _load_marker_index(h5_fh[group])
            print("done.")

        with open(args.in_cnv, "r") as cnv_fh:
            for line_num, line in enumerate(cnv_fh, start=1):
                data = line.rstrip().split()
                if not data:
                    continue
                chrom = data[0].replace("chr", "")
                start0 = int(data[1])
                start = start0 + 1  # 0-based start -> 1-based first marker position
                end = int(data[2])
                interval = data[3]
                cnvtype = data[4]
                iid = data[5]

                if args.print_current_cnv:
                    print(str(line_num) + ":" + iid + ", " + interval)

                if args.skip_dels and cnvtype == "DEL":
                    continue

                group, dataset = iid_group_dataset[iid]
                marker_pos, marker_idx = marker_index_by_group[group]

                # `end` is looked up as an existing marker position, so the
                # slice must include that marker (exclusive bound is +1)
                start_idx = marker_idx[chrom][start]
                end_idx = marker_idx[chrom][end] + 1
                positions = marker_pos[start_idx:end_idx]

                sample = h5_fh[group][dataset][iid]
                baf_x = sample[args.hdf5_baf_key][start_idx:end_idx]
                baf_nonan = baf_x[~isnan(baf_x)]
                nan_count = len(baf_x) - len(baf_nonan)
                hom_perc, het_perc = _hom_het_fractions(baf_nonan,
                                                        args.hom_baf_thresh)

                mdl = _call_baf_clusters(baf_nonan,
                                         args.baf_cluster_calling_method)
                centers = ([] if mdl is None
                           else [x[0] for x in mdl.cluster_centers_])
                cluster_centers_str = ";".join(str(x) for x in centers)
                baf_clusters = baf_bins(centers,
                                        bins=[0, 0.333, 0.5, 0.666, 1],
                                        return_nan_counts=False)
                n_baf_peaks = sum(baf_clusters)

                if mdl is None:
                    # no BAF clusters could be called, nothing supports the call
                    ispass = False
                elif cnvtype == "DEL":
                    ispass = n_baf_peaks == 2
                    if args.del_perc_hom_thresh is not None:
                        ispass = ispass and hom_perc > args.del_perc_hom_thresh
                elif cnvtype == "DUP":
                    ispass = baf_clusters[2] == 0
                else:
                    ispass = False
                passfail_str = "PASS" if ispass else "FAIL"

                plot_root = args.outroot + "." + ".".join(
                    [passfail_str, group, dataset, iid, chrom,
                     str(start), str(end)])
                if args.make_lrr_baf_plot:
                    lrr_x = sample[args.hdf5_lrr_key][start_idx:end_idx]
                    _plot_lrr_baf(plot_root + ".pdf", cnvtype, positions,
                                  baf_x, lrr_x)
                if args.make_baf_density_plot:
                    _plot_baf_density(plot_root + ".density.pdf", cnvtype,
                                      positions, baf_x, baf_nonan, centers)

                res_dict["group"].append(group)
                res_dict["dataset"].append(dataset)
                res_dict["iid"].append(iid)
                res_dict["interval"].append(interval)
                res_dict["cnvtype"].append(cnvtype)
                res_dict["cnvlength"].append(str(end - start0))
                res_dict["nsnp"].append(str(len(baf_x)))
                res_dict["nsnp_nan"].append(str(nan_count))
                res_dict["het"].append(str(het_perc))
                res_dict["hom"].append(str(hom_perc))
                res_dict["n_baf_peaks"].append(str(n_baf_peaks))
                res_dict["baf_peak_centers"].append(cluster_centers_str)
                res_dict["passfail"].append(passfail_str)

    # write per-CNV metrics as a tab-separated table
    res_df = pandas.DataFrame.from_dict(res_dict)
    res_df.to_csv(path_or_buf=args.outroot + ".stats.tsv",
                  sep="\t", index=False,
                  header=True)

    # if desired by user, write filtered callset to file
    if args.out_filtered_cnv is not None:
        _write_filtered_cnv(args.in_cnv, args.out_filtered_cnv, res_dict)

    return

def baf_bins(bafs,
             bins=(0, 0.333, 0.5, 0.666, 1),
             return_nan_counts=False):
    """Count how many values fall nearest to each bin center.

    Args:
        bafs: sequence or array of values; multi-dimensional arrays (e.g. a
            (k, 1) ``cluster_centers_`` array) are flattened first.
        bins: bin centers. Each non-NaN value is counted in the bin whose
            center is closest; ties go to the lower-index bin.
        return_nan_counts: if true, append the number of NaN values as an
            extra final element.

    Returns:
        list of ints, one count per bin (plus the NaN count if requested).
    """
    values = numpy.asarray(bafs, dtype=float).ravel()
    bin_counts = [0] * len(bins)
    n_nan = 0
    for value in values:
        if isnan(value):
            n_nan += 1
            continue
        # min() returns the first minimal index, matching the strict '<'
        # tie-break, and always yields an index (even for +/-inf values)
        nearest = min(range(len(bins)), key=lambda j: abs(value - bins[j]))
        bin_counts[nearest] += 1
    if return_nan_counts:
        bin_counts.append(n_nan)
    return bin_counts

if __name__ == "__main__":
    # run the command-line entry point only when executed as a script
    main()
