
import argparse
import h5py
import pybedtools
import numpy
import pandas
import matplotlib.pyplot as plt
import os
import sys

# colour for the in-CNV markers when --huang-2017-color-formatting is set;
# any other CNV type falls back to black instead of raising KeyError
_CALL_COLORS = {"DEL": "red", "DUP": "blue"}


def _build_parser():
    """Build the command-line parser for plot_cnv_loci.

    NOTE: --length-min/--length-max, --numsnp-min/--numsnp-max,
    --chr-include, --output-as-bed, --datasets-subset and --recip-overlap-min
    are accepted for command-line compatibility but are not read by main().
    """
    parser = argparse.ArgumentParser(prog='plot_cnv_loci',
                                     description='plot BAF and LRR values '
                                                 'around sample-level '
                                                 'call loci.')
    parser.add_argument('--cnvtype', action='store', type=str,
                        default=None,
                        choices=("DEL", "DUP"),
                        help="type of CNVs to make plots for, default %(default)s.")
    parser.add_argument('--length-min', action='store', type=int,
                        default=30000,
                        help="minimum allowed size for CNV calls, default %(default)s.")
    parser.add_argument('--length-max', action='store', type=int,
                        default=20000000,
                        help="maximum allowed size for CNV calls, default %(default)s.")
    parser.add_argument("--numsnp-min", action='store', type=int,
                        default=15,
                        help="minimum number of snps required for CNV calls, default %(default)s.")
    parser.add_argument("--numsnp-max", action='store', type=int,
                        default=float("inf"),
                        help="maximum number of snps allowed for CNV calls, default %(default)s.")
    parser.add_argument("--chr-include", action='store', type=str,
                        default="1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,X,Y,MT",
                        help="chromosomes to subset on, default %(default)s.")
    parser.add_argument("--output-as-bed", action='store_true', default=False,
                        help="write output in BED format, default %(default)s")
    parser.add_argument("--datasets-subset", action="store", type=str,
                        default=None,
                        help="comma-delim datasets to subset case/control comparison on.")
    parser.add_argument('--recip-overlap-min', type=float,
                        help='Minimum reciprocal overlap '
                             'between test locus and CNV to count as an overlap. '
                             '[default: %(default)s]', default=0.5)
    parser.add_argument('--in-locus', type=str, default=None,
                        help="input locus to subset all images on "
                             "(format : chrom:start-end)")
    parser.add_argument('--locus-padding-fraction', type=float, default=1.0,
                        help="amount to extend endpoints of image beyond "
                             "CNV endpoints, defined as "
                             "(start-(CNVlength*X), end+(CNVlength*X))")
    parser.add_argument('--plot-width-height', type=str, default="10,5",
                        help="plot width,height in inches [default %(default)s]")
    parser.add_argument('--plot-files-outroot', type=str, default=None,
                        help='output root name for plot files. If set as None, '
                             'then the matplotlib interactive mode is used '
                             'instead.')
    parser.add_argument('--plot-files-type', type=str, default="pdf",
                        help='filetype extension for output plot files.')
    parser.add_argument("--baf-plot-matplotlib-fmt", type=str, default=".g",
                        help="matplotlib format string for baf plot "
                             "[default %(default)s]")
    parser.add_argument("--lrr-plot-matplotlib-fmt", type=str, default=".b",
                        help="matplotlib format string for lrr plot "
                             "[default %(default)s]")
    parser.add_argument('--hdf5-baf-key', type=str, default='baf',
                        help='key for B allele freq in input HDF5 file.')
    parser.add_argument('--hdf5-lrr-key', type=str, default='lrr',
                        help='key for Log R Ratio in input HDF5 file.')
    parser.add_argument('--huang-2017-color-formatting', default=False,
                        action='store_true',
                        help='color coding of figures in the style of '
                             'Huang et al. 2017 PGC TS CNV paper, '
                             'Figures S5 and S6')
    parser.add_argument('--matplotlib-use', type=str, default=None,
                        help='non-default engine for matplotlib to use.')
    parser.add_argument("in_hdf5", type=str,
                        help="input hdf5 file with format array_group/LOCI,dataset/baf,lrr")
    parser.add_argument("in_group_dataset_iid_tsv", type=str,
                        help="input tab-delim file with group, dataset and IID.")
    parser.add_argument("in_fam", type=str,
                        help="input PLINK fam file, which has pheno in it.")
    parser.add_argument("datasetgroup", type=str,
                        help="name of input dataset group.")
    parser.add_argument("in_cnv_bed", type=str,
                        help="name of cnv BED file.")
    return parser


def _parse_locus(locus_str):
    """Parse a 1-based ``chrom:start-end`` string into BED-style strings.

    'chr' is stripped from the chromosome (as for markers and CNVs) and the
    start is shifted to 0-based; thousands separators are tolerated, so
    "chr1:1,000-2,000" -> ("1", "999", "2000").

    Raises:
        ValueError: if the string is not of the form chrom:start-end.
    """
    chrom_startend = locus_str.split(":")
    startend = chrom_startend[1].split("-") if len(chrom_startend) > 1 else []
    if len(startend) < 2:
        raise ValueError("--in-locus must look like chrom:start-end, got %r"
                         % locus_str)
    chrom = chrom_startend[0].replace("chr", "")
    start = startend[0].replace(",", "")
    end = startend[1].replace(",", "")
    return chrom, str(int(start) - 1), end


def _load_iid_group_dataset(path):
    """Map each IID (column 3) to its (group, dataset) (columns 1-2).

    Blank lines are skipped; a repeated IID keeps its last entry.
    """
    iid_group_dataset = {}
    with open(path, "r") as in_fh:
        for line in in_fh:
            data = line.split()
            if not data:
                continue
            iid_group_dataset[data[2]] = (data[0], data[1])
    return iid_group_dataset


def _load_pheno_sex(path):
    """Map each IID (PLINK .fam column 2) to (phenotype, sex) as ints."""
    iid_pheno_sex = {}
    with open(path, "r") as in_fh:
        for line in in_fh:
            data = line.split()
            if not data:
                continue
            iid_pheno_sex[data[1]] = (int(data[5]), int(data[4]))
    return iid_pheno_sex


def _pheno_label(pheno):
    """Return "CASE" for code 2, "CTRL" for code 1, else "unknown_phe"."""
    return {2: "CASE", 1: "CTRL"}.get(pheno, "unknown_phe")


def _index_markers(marker_loci):
    """Decode marker rows and index the row range of each chromosome.

    ``marker_loci`` is indexed as ``[:, 0]`` (chrom) and ``[:, 1]`` (position),
    with byte-string cells decoded as ASCII.

    Returns:
        (chroms, positions, chrom_ranges): a numpy str array of chromosomes
        with 'chr' removed, an int64 array of positions, and a dict mapping
        each chromosome to the half-open row range (first, last + 1) that
        spans its markers. Markers of one chromosome are assumed contiguous.
    """
    chroms = [str(x, 'ascii').replace("chr", "")
              for x in numpy.array(marker_loci[:, 0])]
    positions = numpy.array([int(str(x, 'ascii'))
                             for x in numpy.array(marker_loci[:, 1])],
                            dtype=numpy.int64)
    chrom_ranges = {}
    for row, chrom in enumerate(chroms):
        first = chrom_ranges.get(chrom, (row, row + 1))[0]
        chrom_ranges[chrom] = (first, row + 1)
    return numpy.array(chroms, dtype=str), positions, chrom_ranges


def _cnv_marker_slice(chroms, positions, chrom, start0, end):
    """Return the slice of rows whose markers fall inside a CNV.

    A marker at position p overlaps the 0-based half-open CNV [start0, end)
    when start0 < p <= end -- the same rule as intersecting the marker's BED
    interval [p - 1, p) with the CNV. Returns slice(0, 0) if none overlap.
    """
    hits = numpy.flatnonzero((chroms == chrom)
                             & (positions > start0)
                             & (positions <= end))
    if hits.size == 0:
        return slice(0, 0)
    return slice(int(hits[0]), int(hits[-1]) + 1)


def _draw_panel(ax, positions, values, cnv_rows, fmt, huang, cnv_color,
                cnv_start0, cnv_end, baseline, xlim):
    """Plot one signal track on ``ax`` with CNV boundaries and a baseline.

    In Huang et al. 2017 mode all markers are silver and those in
    ``cnv_rows`` are re-drawn in ``cnv_color``; otherwise ``fmt`` is used.
    """
    if huang:
        ax.plot(positions, values, '.', color='silver')
        ax.plot(positions[cnv_rows], values[cnv_rows], '.', color=cnv_color)
    else:
        ax.plot(positions, values, fmt)
    ax.axvline(x=cnv_start0, color='black')
    ax.axvline(x=cnv_end, color='black')
    ax.axhline(y=baseline, color='black')
    ax.set_xlim(*xlim)
    ax.set_xticks([])


def main(userargs):
    """Plot a carrier's LRR and BAF values around each CNV call.

    For every CNV in ``in_cnv_bed`` (optionally restricted to --in-locus and
    --cnvtype) whose carrier belongs to ``datasetgroup``, draw a two-panel
    figure: LRR on top (y in [-1, 1], baseline 0) and BAF below (baseline 0.5).
    The figure spans the CNV's chromosome and is zoomed to the padded CNV or
    to --in-locus. Figures are saved under --plot-files-outroot, or shown
    interactively when that option is unset.

    Args:
        userargs: command-line arguments, excluding the program name.
    """
    parser = _build_parser()
    args = parser.parse_args(userargs)

    # pyplot is imported at module level; matplotlib.use() switches its
    # backend in place. Re-importing pyplot inside this function would make
    # `plt` a local name and raise UnboundLocalError when the option is unset.
    if args.matplotlib_use is not None:
        import matplotlib
        matplotlib.use(args.matplotlib_use)

    locus = None
    if args.in_locus is not None:
        try:
            locus = _parse_locus(args.in_locus)
        except ValueError as err:
            parser.error(str(err))

    try:
        plot_dims = args.plot_width_height.split(",")
        fig_size = (float(plot_dims[0]), float(plot_dims[1]))
    except (ValueError, IndexError):
        parser.error("--plot-width-height must look like width,height, got %r"
                     % args.plot_width_height)

    iid_group_dataset = _load_iid_group_dataset(args.in_group_dataset_iid_tsv)
    iid_pheno_sex = _load_pheno_sex(args.in_fam)

    cnvs = pybedtools.BedTool(args.in_cnv_bed)
    if locus is not None:
        locus_pbt = pybedtools.BedTool("\t".join(locus), from_string=True)
        print(locus_pbt)
        cnvs = cnvs.intersect(locus_pbt, wa=True, u=True)

    lrr_key = args.hdf5_lrr_key
    baf_key = args.hdf5_baf_key
    huang = args.huang_2017_color_formatting

    with h5py.File(args.in_hdf5, "r") as hdf5_fh:
        group_h5 = hdf5_fh[args.datasetgroup]
        print("indexing marker loci .. ")
        chrom_arr, positions_all, chrom_ranges = _index_markers(group_h5['LOCI'])
        print("done.")

        for cnv_x in cnvs:
            cnvtype = cnv_x[4]
            if args.cnvtype is not None and cnvtype != args.cnvtype:
                continue

            carrier_iid = cnv_x[5]
            if carrier_iid not in iid_group_dataset:
                print("warning: no group/dataset for IID %s; skipping."
                      % carrier_iid, file=sys.stderr)
                continue
            group, dataset = iid_group_dataset[carrier_iid]
            # only samples from the requested dataset group live in group_h5
            if group != args.datasetgroup:
                continue

            pheno_str_x = _pheno_label(
                iid_pheno_sex.get(carrier_iid, (None, None))[0])

            chrom_x = cnv_x[0].replace("chr", "")
            start0_x = cnv_x[1]
            end_x = cnv_x[2]
            start0_pos = int(start0_x)
            end_pos = int(end_x)

            if chrom_x not in chrom_ranges:
                print("warning: no markers on chromosome %s; skipping."
                      % chrom_x, file=sys.stderr)
                continue

            # values across the whole chromosome (stop is exclusive, so the
            # chromosome's last marker is included)
            lo, hi = chrom_ranges[chrom_x]
            sample_h5 = group_h5[dataset][carrier_iid]
            positions = positions_all[lo:hi]
            lrr_vals = sample_h5[lrr_key][lo:hi]
            baf_vals = sample_h5[baf_key][lo:hi]
            cnv_rows = _cnv_marker_slice(chrom_arr[lo:hi], positions,
                                         chrom_x, start0_pos, end_pos)

            if locus is not None:
                xlim = (int(locus[1]), int(locus[2]))
            else:
                pad = (end_pos - start0_pos) * args.locus_padding_fraction
                xlim = (start0_pos - pad, end_pos + pad)

            fig, axs = plt.subplots(2, sharex=True, figsize=fig_size)
            fig.tight_layout()
            fig.suptitle(" - ".join([group, dataset, carrier_iid]) +
                         " (" + pheno_str_x + ") | " +
                         chrom_x + ":" + start0_x + "-" + end_x + "_" + cnvtype,
                         y=0.995)

            cnv_color = _CALL_COLORS.get(cnvtype, "black")
            _draw_panel(axs[0], positions, lrr_vals, cnv_rows,
                        args.lrr_plot_matplotlib_fmt, huang, cnv_color,
                        start0_pos, end_pos, 0, xlim)
            axs[0].set_ylim(-1, 1)
            _draw_panel(axs[1], positions, baf_vals, cnv_rows,
                        args.baf_plot_matplotlib_fmt, huang, cnv_color,
                        start0_pos, end_pos, 0.5, xlim)
            fig.subplots_adjust(bottom=0.015)

            if args.plot_files_outroot is not None:
                fig.savefig(".".join([args.plot_files_outroot,
                                      group, dataset, carrier_iid,
                                      chrom_x + "_" + start0_x + "_" + end_x +
                                      "_" + cnvtype,
                                      args.plot_files_type]))
            else:
                plt.show()

            plt.close(fig)

    return

if __name__ == "__main__":
    # run as a script: hand every argument after the program name to main()
    main(sys.argv[1:])
