
import h5py
import argparse
import pybedtools
from collections import defaultdict
from collections import OrderedDict
import matplotlib.pyplot as plt
from numpy import isnan
import numpy 
import sys
from sklearn.preprocessing import QuantileTransformer
from scipy.interpolate import UnivariateSpline

def main():
    """Exploratory driver for inspecting LRR/BAF intensity data around CNV calls.

    As written, the function:
      1. parses the command line (only ``in_hdf5``, ``in_cnv`` and ``outroot``
         are read in this function; the remaining options are parsed but
         not used here),
      2. opens the HDF5 file read-only and indexes samples and marker loci,
      3. plots the LRR values of one hard-coded sample over one hard-coded
         chromosome, before and after ``normalize_lrr_values``,
      4. terminates at the first ``quit()`` call.

    Everything after that first ``quit()`` (the PELT change-point step and
    the per-CNV statistics/plot loop) is currently unreachable.
    """
    args = argparse.ArgumentParser()
    args.add_argument("--hdf5-write-mode", action="store", type=str,
                      default="w", choices=['w','a'],
                      help="file write mode ('w' for write, or 'a' for " + \
                           "append). 'Write' mode will overwrite existing " + \
                           "file, while 'append' mode will add data to " + \
                           "already-existing file.")
    args.add_argument("--chroms", action="store", type=str, 
                      default="1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22",
                      help="set of chromosomes to store data for.")
    args.add_argument("--intensity-file-snpid-colnum", action="store", type=int,
                      default=1,
                      help="column number in intensity files for SNP ID.")
    args.add_argument("--intensity-file-baf-colnum", action="store", type=int,
                      default=5,
                      help="column number in intensity files for B allele freq.")
    args.add_argument("--intensity-file-lrr-colnum", action="store", type=int,
                      default=4,
                      help="column number in intensity files for log R ratio.")
    args.add_argument("--intensity-file-suffix", action="store", type=str,
                      default=".txt.adjusted.gz",
                      help="suffix that intensity files end with, for removal.")
    args.add_argument("--intensity-filename-delim",action="store",type=str,
                      default=".", 
                      help="delimiter of intensity filename, for extracting " +\
                           "sampleid.")
    args.add_argument("--intensity-filename-sampleid-pos",action="store",type=int,
                      default=3, 
                      help="position of sampleid in intensity filename, " + \
                           "delim by --intensity-filename-delim.")
    args.add_argument("--gzip-bed", action="store_true", default=False,
                      help="gzip the output BED file.")
    args.add_argument("--stop-loading-after-n-samples", default=None, type=int,
                      help="Stop loading after n samples. For testing purposes.")
    args.add_argument("in_hdf5", type=str,
                      help="input hdf5 file with intensity data.")
    args.add_argument("in_cnv", type=str,
                      help="input cnv file, in either penncnv or bed format.")
    args.add_argument("outroot", type=str,
                      help="output file with metrics per cnv call.")
    # From here on `args` is the parsed namespace, not the parser.
    args = args.parse_args()

    # create filehandle for reading of hdf5
    # NOTE(review): this handle is never closed explicitly in this function.
    h5_fh = h5py.File(args.in_hdf5, "r")
    
    # get group and dataset per iid
    # Every child of a top-level group that is itself an h5py.Group is
    # treated as a dataset container whose keys are sample ids (iid).
    # If the same iid occurs more than once, the last one seen wins.
    # nongroups = args.nongroups.split(",")
    iid_group_dataset = dict()
    for group in h5_fh.keys():
        for dataset in h5_fh[group].keys():
            if isinstance(h5_fh[group][dataset], h5py.Group):
                for iid in h5_fh[group][dataset].keys():
                    iid_group_dataset[iid] = (group, dataset)
    
    # get LOCI
    # marker_loci_chrom_startend : chrom -> [first LOCI row, last LOCI row]
    # marker_loci_idx            : chrom -> {position -> LOCI row}   (chr 1/2 only)
    # marker_loci_chroms / _pos  : flat lists, appended for chr 1/2 rows only
    # marker_loci_beds is created but never filled (the code that would fill
    # it is commented out below).
    marker_loci_beds = dict()
    marker_loci_chroms = []
    marker_loci_pos = []
    marker_loci_idx = dict()
    marker_loci_chrom_startend = dict()
    for group in h5_fh.keys():
        # Each LOCI row is read as (chrom, position, marker id), stored as
        # ASCII byte strings.
        marker_loci = h5_fh[group]['LOCI']
        marker_loci_bed = []
        print("loading BED .. ")
        for i in range(marker_loci.shape[0]):
            chrom = marker_loci[i, 0].decode("ascii")
            # Track the first and last row index seen for every chromosome
            # (this happens before the chr 1/2 filter below).
            if chrom not in marker_loci_chrom_startend:
                marker_loci_chrom_startend[chrom] = [i, i]
            else:
                marker_loci_chrom_startend[chrom][1] = i
            # Only chromosomes 1 and 2 are indexed any further.
            if chrom != "1" and chrom != "2": continue
            start = int(marker_loci[i, 1].decode("ascii"))
            # start0, end and markerid are only used by the commented-out
            # BED construction below.
            start0 = str(start - 1)
            end = marker_loci[i, 1].decode("ascii")
            markerid = marker_loci[i, 2].decode("ascii")
            if chrom not in marker_loci_idx:
                marker_loci_idx[chrom] = dict()
            marker_loci_idx[chrom][start] = i
            marker_loci_chroms.append(chrom)
            marker_loci_pos.append(start)
            # marker_loci_bed.append("\t".join(["chr"+chrom, start0, end, 
            #                                   markerid, str(i)]))
        print("done.")
    
        # create instance of pybedtools for marker_loci
        # marker_loci_beds[group] = pybedtools.BedTool("\n".join(marker_loci_bed),
        #                                              from_string=True)

    # Hard-coded example calls. `iid` is reassigned three times in a row, so
    # only the last assignment ("OCTS_156_403", chr2) takes effect.
    # CNV deletion call in PennCNV+QuantiSNP
    # chr1:143244433-143541389
    iid = 'OCTS_178_403'
    
    # chr1:247986129-248577691
    iid = "TDT08-1047-03"

    # chr2:50798055-50951081
    iid = "OCTS_156_403"
    chrom = '2'
    start = 50798055 
    end = 50951081


    # Row range of the whole chromosome in LOCI. The slices below exclude
    # end_idx itself, so the last marker of the chromosome is not included.
    start_idx = marker_loci_chrom_startend[chrom][0]
    end_idx = marker_loci_chrom_startend[chrom][1]
    print(end_idx)
    # start_idx=20000
    # end_idx=30000
    # Group and dataset are hard-coded here rather than looked up in
    # iid_group_dataset.
    group="GSA"
    dataset="TS_GSA_for_CNV"
    lrr_vals = h5_fh[group][dataset][iid]['lrr'][start_idx:end_idx]
    # NOTE(review): start_idx/end_idx are LOCI row indices, whereas
    # marker_loci_pos only holds chr 1/2 rows (from every group). The two
    # index spaces coincide only if the chr 1/2 rows come first in LOCI and
    # a single group was loaded -- confirm.
    pos_vals = marker_loci_pos[start_idx:end_idx]
    # Raw LRR along the chromosome, with the CNV boundaries marked.
    plt.plot(pos_vals, lrr_vals, ".g")
    plt.axvline(x=start)
    plt.axvline(x=end)
    plt.show()
    
    # Recenter and quantile-transform the LRR values.
    # NOTE(review): numpy.std returns NaN if lrr_vals contains any NaN.
    lrr_vals_q =  normalize_lrr_values(lrr_vals,
                                       expectedmean=0, penalty=60, quantile=True, 
                                       qspline=False, 
                                       sd=numpy.std(lrr_vals), 
                                       recenter=True,
                                       n_samplings=50)
    

    # Same plot for the normalized values.
    plt.plot(pos_vals, lrr_vals_q, ".g")
    plt.axvline(x=start)
    plt.axvline(x=end)
 
    plt.show()
 

    # use PELT on quantized LRR values
    import ruptures as rpt
    model='rbf'
    algo = rpt.Pelt(model=model, min_size=20, jump=50)
    # algo.fit(lrr_vals_q)
    print(lrr_vals_q)
    pen_x=10
    # The script stops here; nothing below this call is executed.
    quit()
    # --- unreachable from here on ---
    # NOTE(review): the fit() call above is commented out, so predict() would
    # be called on an unfitted model if this code were re-enabled.
    result = algo.predict(pen=pen_x)
    rpt.display(lrr_vals_q, result, figsize=(10, 6))
    plt.title('Change Point Detection: Pelt Search Method')
    plt.show()  
    quit()

    sys.exit(1)

    # Per-CNV statistics and plots (unreachable while the quit() above stays).
    # Despite the original comment below, the header is written to
    # <outroot>.stats.tsv, not to stdout.
    # print header to stdout
    header_list = ['cnvtype', 'cnvlength', 'nsnp',
                   'n_nan', 'n_BAF_0', 'n_BAF_33',
                   'n_BAF_50', 'n_BAF_66', 'n_BAF_100',
                   'het','hom']
    header = "\t".join(header_list)
    out_fh = open(args.outroot + ".stats.tsv", "w")
    out_fh.write(header + "\n")

    # open filehandle to cnv
    # Whitespace-separated columns read below:
    #   0 chrom, 1 start (0-based), 2 end, 3 interval, 4 cnvtype, 5 iid
    # NOTE(review): cnv_fh is never closed.
    cnv_fh = open(args.in_cnv, "r")
    for line in cnv_fh:
        data = line.rstrip().split()
        chrom = data[0]
        chrom = chrom.replace("chr","")
        # Only chromosomes 1 and 2 were indexed above.
        if chrom != "1" and chrom != "2": continue
        start0 = int(data[1])
        start = start0 + 1
        end = int(data[2])
        interval = data[3]
        cnvtype = data[4]
        iid = data[5]

        # get dataset, group
        (group, dataset) = iid_group_dataset[iid]

        # get start and end idx
        # Both boundaries must be exact marker positions, otherwise these
        # lookups raise KeyError. The slices exclude end_idx itself.
        start_idx = marker_loci_idx[chrom][start]
        end_idx = marker_loci_idx[chrom][end]
        positions = marker_loci_pos[start_idx:end_idx]

        # TEST : skip if < 15 positions
        if len(positions) < 15:
            continue

        # get bafs
        # baf_y is a sorted copy; below only its length is used.
        baf_x = h5_fh[group][dataset][iid]['baf'][start_idx:end_idx]
        baf_y = list(baf_x)
        baf_y.sort()

        # get lrrs
        # lrr_vals = h5_fh[group][dataset][iid]['lrr']
        lrr_x = h5_fh[group][dataset][iid]['lrr'][start_idx:end_idx]
        # lrr_vals = lrr_vals[isnan(lrr_vals)==False]
        # lrr_mean = numpy.mean(lrr_vals)
        # lrr_sd = numpy.std(lrr_vals)
        # lrr_x = (lrr_x - lrr_mean) / lrr_sd
        # lrr_y = list(lrr_x)
        # lrr_y.sort()

        # get bin counts
        # baf_bins returns one count per bin followed by the NaN count;
        # split the NaN count off the end.
        bin_counts = baf_bins(baf_x)
        nan_count = bin_counts[-1]
        bin_counts = bin_counts[:-1]

        # make plot
        # NOTE(review): plt.figure() and plt.subplots() each create a figure,
        # and neither is closed inside this loop.
        plt.figure()
        fig, axs = plt.subplots(2, sharex=True,
                                figsize=(10, 5))

        # adjust spacing
        fig.tight_layout()
        
        # Lower panel: BAF (blue for DEL, green for DUP; any other cnvtype
        # plots no points) plus horizontal guide lines.
        if cnvtype == "DEL":
            axs[1].plot(positions, baf_x, '.b')
        elif cnvtype == "DUP":
            axs[1].plot(positions, baf_x, '.g')
        axs[1].axhline(y=0)
        axs[1].axhline(y=0.1665)
        axs[1].axhline(y=0.415)
        axs[1].axhline(y=0.583)
        axs[1].axhline(y=0.833)
        axs[1].axhline(y=1)
        
        # Upper panel: LRR with the same colour convention, y-axis fixed to
        # [-2, 2].
        if cnvtype == "DEL":
            axs[0].plot(positions, lrr_x, '.b')
        elif cnvtype == "DUP":
            axs[0].plot(positions, lrr_x, '.g')
        axs[0].set_ylim(-2, 2)
        axs[0].axhline(y=0)
        axs[0].axhline(y=1)
        axs[0].axhline(y=-1)
 
        # A call is labelled FAIL when more than 5% of its markers (NaNs
        # included in the denominator) fall in bin index 2, i.e. the 0.5 bin
        # of baf_bins' default bins; otherwise PASS.
        if float(bin_counts[2])/len(baf_y) > 0.05:
            plt.savefig(args.outroot + \
                        "." + ".".join(["FAIL", group, dataset, iid,
                                        chrom, str(start), str(end), "pdf"]))

        else:
            plt.savefig(args.outroot + \
                        "." + ".".join(["PASS", group, dataset, iid,
                                        chrom, str(start), str(end), "pdf"]))



        # write to output 
        # Columns follow header_list: the five bin counts are joined into one
        # tab-separated field, 'het' is the fraction in the middle bin and
        # 'hom' the fraction in the first plus last bins.
        res_list = [cnvtype, str(end - start0),
                    str(len(baf_y)), str(nan_count),
                    "\t".join([str(x) for x in bin_counts]),
                    str(float(bin_counts[2])/len(baf_y)),
                    str(float(bin_counts[0]+bin_counts[-1])/len(baf_y))
                   ]
        res_str = "\t".join(res_list)
        out_fh.write(res_str + "\n")

    # close output filehandle
    out_fh.close()

    return

def baf_bins(bafs, bins=[0, 0.333, 0.5, 0.666, 1]):
    """Count how many BAF values fall nearest to each bin centre.

    Every non-NaN value in ``bafs`` is assigned to the bin centre with the
    smallest absolute distance; on an exact tie the earlier bin wins. NaN
    values are tallied separately.

    Parameters
    ----------
    bafs : indexable sequence of float
        B-allele-frequency values; must support ``len()`` and integer
        indexing.
    bins : sequence of float
        Bin centres. The default list is only read, never modified.

    Returns
    -------
    list of int
        ``len(bins) + 1`` entries: one count per bin centre, in the order of
        ``bins``, followed by the number of NaN values.
    """
    counts = [0] * len(bins)
    missing = 0
    for position in range(len(bafs)):
        value = bafs[position]
        if isnan(value):
            missing += 1
            continue
        # Linear scan for the closest centre; the strict '<' keeps the first
        # centre when two are equally close.
        nearest = None
        smallest_gap = float("inf")
        for slot, centre in enumerate(bins):
            gap = abs(value - centre)
            if gap < smallest_gap:
                nearest, smallest_gap = slot, gap
        counts[nearest] += 1
    return counts + [missing]

def quantileNormalize(df_input):
    """Quantile-normalize the columns of a DataFrame against each other.

    Each column is sorted, the sorted columns are averaged row-wise to get
    one reference value per rank, and every original value is replaced by
    the reference value of its rank within its own column. Tied values all
    receive the reference value of the lowest tied rank.

    Adapted from:
    https://github.com/ShawnLYU/Quantile_Normalize/blob/master/quantile_norm.py

    Parameters
    ----------
    df_input : pandas.DataFrame
        Numeric data, one sample per column. It is not modified.

    Returns
    -------
    pandas.DataFrame
        A copy of ``df_input`` with normalized values.
    """
    # pandas is not imported at module level in this file; import it here so
    # the function does not fail with NameError.
    import pandas

    df = df_input.copy()

    # Reference distribution: mean across columns at each rank.
    sorted_df = pandas.DataFrame({col: sorted(df[col]) for col in df})
    rank = sorted_df.mean(axis=1).tolist()

    # Map every value to the reference value of its (left-most) rank.
    for col in df:
        t = numpy.searchsorted(numpy.sort(df[col]), df[col])
        df[col] = [rank[i] for i in t]
    return df

def _quantile_normalize_to_normal(values, sd, n_samplings, seed=0):
    """Quantile-normalize ``values`` together with synthetic N(0, sd) columns."""
    normalized = numpy.array(values, dtype=float)  # always a copy
    finite = numpy.isfinite(normalized)
    observed = normalized[finite]
    if observed.size == 0:
        return normalized

    # A caller passing numpy.std(...) of data containing NaN hands in NaN;
    # fall back to the spread of the usable values instead of returning NaN
    # everywhere.
    if sd is None or not numpy.isfinite(sd):
        sd = observed.std()

    # Fixed seed so the same input always gives the same output.
    rng = numpy.random.RandomState(seed)
    columns = numpy.empty((observed.size, n_samplings + 1), dtype=float)
    columns[:, :n_samplings] = rng.normal(loc=0, scale=sd,
                                          size=(observed.size, n_samplings))
    columns[:, n_samplings] = observed

    # Reference value per rank: mean across all columns of the sorted data.
    rank_means = numpy.sort(columns, axis=0).mean(axis=1)

    # Replace each value by the reference value of its rank; tied values
    # share the average of the reference values over their tied ranks.
    _, inverse, counts = numpy.unique(observed, return_inverse=True,
                                      return_counts=True)
    starts = numpy.cumsum(counts) - counts
    tie_means = numpy.add.reduceat(rank_means, starts) / counts
    normalized[finite] = tie_means[inverse]
    return normalized


def normalize_lrr_values(lrr_vals,
                         expectedmean=0, penalty=60, quantile=False,
                         qspline=False, sd=0.18, recenter=False,
                         n_samplings=50):
    """Normalize a vector of log-R-ratio (LRR) values.

    Python port of the approach in
    https://github.com/mbertalan/iPsychCNV/blob/master/R/NormalizeData.R

    Steps, each enabled by its flag and applied in this order:

    * ``recenter`` -- subtract the mean of the non-NaN values.
    * ``qspline``  -- accepted for interface compatibility; spline
      detrending is not implemented, so this flag has no effect.
    * ``quantile`` -- quantile-normalize the values against ``n_samplings``
      synthetic columns drawn from N(0, ``sd``). As in the R reference
      (``normalize.quantiles`` on the synthetic columns plus the LRR column),
      the LRR column itself takes part in the per-rank average. The result
      keeps the rank order of the input and has approximately mean 0 and
      standard deviation ``sd``.

    Parameters
    ----------
    lrr_vals : array-like of float
        LRR values. NaN (and +/-inf in the quantile step) are passed through
        unchanged and ignored when computing statistics.
    expectedmean : float
        Unused; kept for interface compatibility.
    penalty : float
        Unused; kept for interface compatibility (spline penalty in the R
        reference).
    quantile : bool
        Apply the quantile-normalization step.
    qspline : bool
        No effect (see above).
    sd : float
        Standard deviation of the synthetic reference columns. If it is NaN
        or None, the standard deviation of the finite input values is used.
    recenter : bool
        Apply the recentering step.
    n_samplings : int
        Number of synthetic reference columns.

    Returns
    -------
    numpy.ndarray
        Normalized values, same shape as the input. When no step is enabled
        and the input is already a float array, the input array itself is
        returned.

    Raises
    ------
    ValueError
        If ``sd`` is negative (raised by the random sampler).
    """
    values = numpy.asarray(lrr_vals)
    if not numpy.issubdtype(values.dtype, numpy.floating):
        values = values.astype(float)

    if recenter:
        present = ~numpy.isnan(values)
        # With no usable values there is nothing to recenter.
        if present.any():
            values = values - values[present].mean()

    # qspline: detrending with a smoothing spline (R: smooth.spline with
    # `penalty`, then LRR - fitted values) is not implemented yet.

    if quantile:
        values = _quantile_normalize_to_normal(values, sd, n_samplings)

    return values
        
# Run the driver only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
