#!/usr/bin/env python3

import h5py
import numpy as np
import argparse
import os
import sys
import gzip
from collections import OrderedDict

def _build_arg_parser():
    """Return the argparse parser describing this script's command line."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--hdf5-write-mode", action="store", type=str,
                        default="w", choices=['w', 'a'],
                        help="file write mode ('w' for write, or 'a' for "
                             "append). 'Write' mode will overwrite existing "
                             "file, while 'append' mode will add data to "
                             "already-existing file.")
    parser.add_argument("--chroms", action="store", type=str,
                        default=",".join(str(c) for c in range(1, 23)),
                        help="set of chromosomes to store data for.")
    parser.add_argument("--intensity-file-snpid-colnum", action="store",
                        type=int, default=1,
                        help="column number in intensity files for SNP ID.")
    parser.add_argument("--intensity-file-baf-colnum", action="store",
                        type=int, default=4,
                        help="column number in intensity files for B allele "
                             "freq.")
    parser.add_argument("--intensity-file-lrr-colnum", action="store",
                        type=int, default=5,
                        help="column number in intensity files for log R "
                             "ratio.")
    parser.add_argument("--intensity-file-suffix", action="store", type=str,
                        default=".txt.adjusted.gz",
                        help="suffix that intensity files end with, for "
                             "removal.")
    parser.add_argument("--intensity-filename-delim", action="store",
                        type=str, default=".",
                        help="delimiter of intensity filename, for extracting "
                             "sampleid.")
    parser.add_argument("--intensity-filename-sampleid-pos", action="store",
                        type=int, default=3,
                        help="position of sampleid in intensity filename, "
                             "delim by --intensity-filename-delim.")
    # NOTE: --gzip-bed is accepted for command-line compatibility but its
    # value is never read by this script.
    parser.add_argument("--gzip-bed", action="store_true", default=False,
                        help="gzip the output BED file.")
    parser.add_argument("--stop-loading-after-n-samples", default=None,
                        type=int,
                        help="Stop loading after n samples. For testing "
                             "purposes.")
    parser.add_argument("out_hdf5", type=str,
                        help="output hdf5 file to create.")
    parser.add_argument("datasetgroup", type=str,
                        help="name of group of datasets.")
    parser.add_argument("bim_file", type=str,
                        help="whitespace-delimited file with snpid/chrom/pos "
                             "per snp marker in data.")
    parser.add_argument("dataset_intensity_rootdirs",
                        type=str, nargs="+",
                        help="combos of dataset:intensity_file/root/dir.")
    return parser


def _read_bim(bim_path, chroms_set):
    """Read the markers on the requested chromosomes from a bim file.

    Each non-blank line is split on whitespace; column 1 is the chromosome,
    column 2 the SNP ID and column 4 the position.

    Returns a tuple ``(snpids, snpid_loci, chrom_start_idx, chrom_end_idx)``:
    ``snpids`` is the list of kept SNP IDs in file order, ``snpid_loci`` the
    matching list of ``[chrom, pos, snpid]`` string triples, and the two
    ordered dicts map each chromosome to the index of its first and last
    kept marker.

    Raises ValueError if a non-blank line has fewer than four columns.
    """
    snpids = []
    snpid_loci = []
    chrom_start_idx = OrderedDict()
    chrom_end_idx = OrderedDict()
    with open(bim_path, "r") as in_fh:
        for linenum, line in enumerate(in_fh, 1):
            fields = line.split()
            if not fields:
                continue
            if len(fields) < 4:
                raise ValueError(
                    "%s line %d: expected at least 4 columns, found %d"
                    % (bim_path, linenum, len(fields)))
            chrom, snpid, pos = fields[0], fields[1], fields[3]
            if chrom not in chroms_set:
                continue
            idx = len(snpids)
            chrom_start_idx.setdefault(chrom, idx)
            chrom_end_idx[chrom] = idx
            snpids.append(snpid)
            snpid_loci.append([chrom, pos, snpid])
    return snpids, snpid_loci, chrom_start_idx, chrom_end_idx


def _iter_intensity_files(rootdir, suffix):
    """Yield (filepath, filename) for files under rootdir ending in suffix."""
    for root, _dirs, files in os.walk(rootdir):
        for filename in files:
            if filename.endswith(suffix):
                yield os.path.join(root, filename), filename


def _read_intensity_file(filepath, snpid_col, baf_col, lrr_col):
    """Parse one tab-delimited intensity file into {snpid: (baf, lrr)}.

    The first line is treated as a header and skipped. Column numbers are
    1-based. Files ending in ".gz" are decompressed and read as text.
    """
    if filepath.endswith(".gz"):
        # Text mode is required: binary mode yields bytes, which cannot be
        # split on a str delimiter.
        intensity_fh = gzip.open(filepath, "rt")
    else:
        intensity_fh = open(filepath, "r")

    values = {}
    with intensity_fh:
        next(intensity_fh, None)  # skip header line
        for line in intensity_fh:
            fields = line.rstrip().split("\t")
            if fields == [""]:
                continue  # blank line
            values[fields[snpid_col - 1]] = (float(fields[baf_col - 1]),
                                             float(fields[lrr_col - 1]))
    return values


def _ordered_baf_lrr(values, snpids, filepath):
    """Return (baf_arr, lrr_arr) float arrays ordered like snpids.

    Raises KeyError naming the SNP and file if a SNP ID from the bim file
    has no entry in ``values``.
    """
    try:
        pairs = [values[snpid] for snpid in snpids]
    except KeyError as err:
        raise KeyError("SNP '%s' from the bim file is missing in %s"
                       % (err.args[0], filepath)) from err
    baf_arr = np.array([pair[0] for pair in pairs], dtype=float)
    lrr_arr = np.array([pair[1] for pair in pairs], dtype=float)
    return baf_arr, lrr_arr


def main():
    """Load per-sample BAF/LRR intensity vectors into an HDF5 file.

    Reads the command line from sys.argv, reads the marker list from the bim
    file, then writes to the output HDF5 file:

    * ``<datasetgroup>/LOCI``: [chrom, pos, snpid] strings per kept marker,
    * ``<datasetgroup>/CHROM_IDX``: [chrom, start_idx, end_idx] strings per
      chromosome,
    * ``<datasetgroup>/<dataset>/<sampleid>/baf`` and ``.../lrr``: one
      gzip-compressed float vector each per sample, ordered like the kept
      bim markers.

    With --stop-loading-after-n-samples, loading of each dataset stops once
    that many samples have been written for it.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Validate every dataset:rootdir argument before touching any file, so a
    # malformed argument cannot leave a half-written (or truncated) output.
    # Split at the first colon only so the root directory may contain colons.
    dataset_rootdirs = []
    for spec in args.dataset_intensity_rootdirs:
        dataset, sep, intensity_rootdir = spec.partition(":")
        if not sep:
            parser.error("expected dataset:intensity_rootdir, got '%s'"
                         % spec)
        dataset_rootdirs.append((dataset, intensity_rootdir))

    # define set of chromosomes to extract data for
    chroms_set = {chrom.strip() for chrom in args.chroms.split(",")}

    # read the bim file before opening the output, so a bad bim file does not
    # clobber an existing HDF5 file in 'w' mode
    snpids, snpid_loci, chrom_start_idx, chrom_end_idx = _read_bim(
        args.bim_file, chroms_set)

    str_dt = h5py.special_dtype(vlen=str)
    with h5py.File(args.out_hdf5, args.hdf5_write_mode) as hdf5_fh:
        # marker loci for the dataset group
        hdf5_fh.create_dataset(args.datasetgroup + "/LOCI",
                               data=np.array(snpid_loci, dtype=str_dt))

        # start and end marker index for each chromosome
        chrom_start_end_idxs = [
            [str(chrom), str(start_idx), str(chrom_end_idx[chrom])]
            for chrom, start_idx in chrom_start_idx.items()
        ]
        hdf5_fh.create_dataset(args.datasetgroup + "/CHROM_IDX",
                               data=np.array(chrom_start_end_idxs,
                                             dtype=str_dt))

        sample_limit = args.stop_loading_after_n_samples
        for dataset, intensity_rootdir in dataset_rootdirs:
            sampleids_processed = set()
            for filepath, filename in _iter_intensity_files(
                    intensity_rootdir, args.intensity_file_suffix):
                # sampleid is the user-defined (1-based) field of the
                # filename after splitting on the filename delimiter
                filename_parts = filename.split(args.intensity_filename_delim)
                sampleid = filename_parts[
                    args.intensity_filename_sampleid_pos - 1]
                print(sampleid)

                # skip sampleid if already processed
                if sampleid in sampleids_processed:
                    continue

                print(filepath)
                values = _read_intensity_file(
                    filepath,
                    args.intensity_file_snpid_colnum,
                    args.intensity_file_baf_colnum,
                    args.intensity_file_lrr_colnum)
                baf_arr, lrr_arr = _ordered_baf_lrr(values, snpids, filepath)

                # add to hdf5
                dset_root = "/".join([args.datasetgroup, dataset, sampleid])
                hdf5_fh.create_dataset(dset_root + "/baf",
                                       compression="gzip",
                                       data=baf_arr)
                hdf5_fh.create_dataset(dset_root + "/lrr",
                                       compression="gzip",
                                       data=lrr_arr)
                sampleids_processed.add(sampleid)

                # testing aid: count only samples actually written, then
                # move on to the next input dataset
                if (sample_limit is not None
                        and len(sampleids_processed) >= sample_limit):
                    break

    return

# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
