'''
©2024, Francisco C. Marques, Institute for Biomechanics, ETH Zurich
'''

import h5py
import os
import multiprocessing

import numpy as np
import pandas as pd

from functools import partial
from PIL import Image
from tqdm import tqdm
from readlif.reader import LifFile

from scipy.io import loadmat
from scipy.ndimage import gaussian_filter
from scipy.optimize import minimize
from skimage.color import rgb2gray
from skimage.exposure import adjust_sigmoid, equalize_adapthist
from skimage.transform import rescale


# ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ###
# Preprocessing functions
# ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ###


def optimize_sigmoid_cutoff_value(cutoff: float,
                                  image: np.ndarray,
                                  gain: float):
    """
    Objective function for tuning the cutoff of a sigmoid contrast adjustment.

    The image is passed through ``adjust_sigmoid`` with the given ``gain`` and
    ``cutoff``; the score is the Euclidean distance between the resulting
    (minimum, maximum) intensity pair and the pair (0, 1).

    Parameters
    ----------
    cutoff : float
        Candidate cutoff of the sigmoid function. Must lie in [0, 1].
    image : ndarray
        The input image.
    gain : float
        The gain parameter of the sigmoid function.

    Returns
    -------
    float
        ``sqrt((1 - max)**2 + min**2)`` of the sigmoid-adjusted image.
        Note that this is the score of the candidate cutoff, not a cutoff.

    Raises
    ------
    ValueError
        If the cutoff value is not between 0 and 1.
    """

    # Written as two explicit comparisons so that a NaN cutoff is not rejected here.
    out_of_range = (cutoff < 0) or (cutoff > 1)
    if out_of_range:
        raise ValueError("The cutoff value should be between 0 and 1.")

    adjusted = adjust_sigmoid(image, gain=gain, cutoff=cutoff)
    darkest = np.amin(adjusted)
    brightest = np.amax(adjusted)

    distance = np.sqrt(np.square(1-brightest) + np.square(darkest))
    return np.absolute(distance)


def apply_brightness_correction(input_image: np.ndarray,
                                bins: np.ndarray = None,
                                saturation_value: float = 0.35,
                                default_intensity: int = 200,
                                max_intensity: int = 255):
    """
    Apply (Fiji-like) brightness correction to the input image.
    Adapted from: https://github.com/imagej/ImageJ/blob/5910a5b600d6879928b2b87468d8f12cd35c32bb/ij/plugin/frame/ContrastAdjuster.java#L816

    The intensity histogram is scanned from both ends until the cumulative
    pixel count reaches ``total_pixels * saturation_value / default_intensity``.
    The image is then clipped to the two bin indices found, divided by the
    upper index and multiplied by the number of histogram bins.

    Parameters
    ----------
    input_image : ndarray
        The input image. A 2-D array is treated as greyscale; a 3-D array is
        converted with ``rgb2gray`` before the histogram is computed.

    bins : ndarray, optional
        The bin edges for the histogram. If not provided, ``np.arange(0, 256, 1)``
        is used.

    saturation_value : float, optional
        The saturation value used to calculate the threshold. Default is 0.35.

    default_intensity : int, optional
        The default intensity value used to calculate the threshold. Default is 200.

    max_intensity : int, optional
        The maximum intensity value, used to normalise a 3-D image before the
        greyscale conversion. Default is 255.

    Returns
    -------
    ndarray
        The corrected image as an integer array with a trailing channel axis
        (a 2-D input is returned with shape ``(rows, cols, 1)``). When the
        upper threshold falls on the first histogram bin, an all-zero image is
        returned.
    """

    is_rgb = input_image.ndim == 3
    if input_image.ndim == 2:
        input_image = np.expand_dims(input_image, -1)

    if bins is None:
        bins = np.arange(0, 256, 1)

    if is_rgb:
        norm_img = rgb2gray(input_image/max_intensity) * max_intensity
    else:
        norm_img = input_image

    hist_img = np.histogram(norm_img.ravel(), bins=bins)[0]

    total_pixel_count = np.sum(hist_img)
    threshold = int(total_pixel_count * saturation_value / default_intensity)

    # argmin over a boolean array gives the first index where the cumulative
    # count is no longer below the threshold.
    min_thresh = np.argmin(np.cumsum(hist_img) < threshold)
    max_thresh = hist_img.size - np.argmin(np.cumsum(np.flip(hist_img)) < threshold) - 1

    if max_thresh == 0:
        # Every value would be clipped to 0 and then divided by 0, producing
        # NaNs whose integer cast is undefined. Return a black image instead.
        return np.zeros(input_image.shape, dtype=int)

    corrected_image = np.clip(input_image, min_thresh, max_thresh) / max_thresh * hist_img.size

    return corrected_image.astype(int)


def apply_sliding_window_bright_corr(z_stack: np.ndarray,
                                     bright_corr_patch_window_size: int,
                                     bright_corr_patch_stride_size: int,
                                     default_intensity: int = 2,
                                     saturation_value: float = 0.35):
    """
    Apply sliding window brightness correction (like Fiji) algorithm to a stack of images.

    Each image is tiled with square patches whose top-left corners are spaced
    by the stride. Every patch is corrected independently with
    ``apply_brightness_correction`` and, for each pixel, the corrected values
    of all patches covering it are averaged.

    Parameters
    ----------
    z_stack : ndarray
        Stack of images (iterated along its first axis).

    bright_corr_patch_window_size : int
        Size of the patch side.

    bright_corr_patch_stride_size : int
        Stride of the patch.

    default_intensity : int, optional
        Default intensity value, by default 2.

    saturation_value : float, optional
        Saturation value, by default 0.35.

    Returns
    -------
    ndarray
        Average stack of images after applying the sliding window brightness correction (like Fiji) algorithm.
        The images are stacked along the last axis. Pixels that no patch
        covers (possible when the stride exceeds the window size) are 0.
    """

    img_shape = z_stack[0].shape
    window = bright_corr_patch_window_size

    # All (row, col) combinations of patch origins. The image extent itself is
    # appended to each axis; patches starting there are empty and skipped below.
    patch_corners = np.array(np.meshgrid(*[np.concatenate([np.arange(0, i, bright_corr_patch_stride_size), (i, )]) for i in img_shape[:2]])).reshape(2, -1).T

    avg_z_stack = np.zeros((*img_shape[:2], len(z_stack)))
    for img_idx, image in enumerate(z_stack):
        patch_corrected_img = np.zeros(img_shape)
        patch_corrected_count = np.zeros(img_shape[:2])

        for row_start, col_start in patch_corners:
            patch = (slice(row_start, row_start + window),
                     slice(col_start, col_start + window))
            sample_img = image[patch]

            if 0 in sample_img.shape:
                continue

            output = apply_brightness_correction(sample_img, default_intensity=default_intensity, saturation_value=saturation_value)

            patch_corrected_img[patch] += np.squeeze(output).clip(0, 255)
            patch_corrected_count[patch] += 1

        # Without an explicit `out`, entries excluded by `where` would be left
        # as uninitialised memory; start from zeros so they are well defined.
        avg_z_stack[..., img_idx] = np.divide(patch_corrected_img,
                                              patch_corrected_count,
                                              out=np.zeros_like(patch_corrected_img),
                                              where=patch_corrected_count > 0)

    return avg_z_stack


def single_section_processing(input_image: np.ndarray,
                              img_pixel_sizes: list,
                              bright_corr_patch_window_size: int = 200,
                              bright_corr_patch_stride_size: int = 20,
                              clahe_kernel_size: int = 31):
    """
    Run the full preprocessing chain on one image section.

    The steps are, in order: sliding-window brightness correction, Gaussian
    smoothing, normalisation by the maximum, adaptive histogram equalisation
    (CLAHE) and a sigmoid contrast adjustment whose cutoff is found by
    minimising ``optimize_sigmoid_cutoff_value``.

    Parameters
    ----------
    input_image : np.ndarray
        The input image.

    img_pixel_sizes : list or np.ndarray
        The pixel sizes of the image (one value per axis). Only the first
        entry is used, to scale the Gaussian sigma.

    bright_corr_patch_window_size : int, optional
        The size of the patch window. Default is 200.

    bright_corr_patch_stride_size : int, optional
        The stride of the patch window. Default is 20.

    clahe_kernel_size : int, optional
        The size of the kernel for adaptive histogram equalization. Default is 31.

    Returns
    -------
    np.ndarray
        The processed image.
    """

    histogram, _ = np.histogram(input_image.ravel(), bins=np.arange(0, 256, 1))

    # Saturation: first intensity level with the lowest pixel count, scaled by 255.
    rarest_level = np.argmax(histogram == np.amin(histogram))
    saturation = rarest_level / 255

    # Default intensity: the most frequent level, or 5 when that level is 0 or 1.
    modal_level = np.argmax(histogram)
    if modal_level > 1:
        default_level = modal_level
    else:
        default_level = 5

    brightness_corrected = apply_sliding_window_bright_corr(
        [input_image],
        bright_corr_patch_window_size,
        bright_corr_patch_stride_size,
        default_intensity=default_level,
        saturation_value=saturation,
    )

    # Original code considered a sigma of 0.65 for a pixel size of 0.2 um
    sigma_gfilt = (0.65 * img_pixel_sizes[0])/0.2
    smoothed = gaussian_filter(brightness_corrected, sigma=sigma_gfilt, truncate=5)
    normalised = smoothed / np.amax(smoothed)

    equalised = equalize_adapthist(normalised, kernel_size=clahe_kernel_size)

    # The minimisation starts from a cutoff of 0.5; the same gain is used for
    # the search and for the final adjustment.
    sigmoid_gain = 10
    cutoff_search = minimize(optimize_sigmoid_cutoff_value, 0.5, args=(equalised, sigmoid_gain))
    best_cutoff = cutoff_search.x[0]

    return adjust_sigmoid(equalised, gain=sigmoid_gain, cutoff=best_cutoff)


def read_image_from_LIF(image_path: str,
                        image_index: int):
    '''
    Load one image of a Leica Image Format (LIF) file together with its pixel sizes.

    Parameters
    ----------
    image_path : str
        Path of the LIF file containing the desired image.

    image_index : int
        Index of the image within the LIF file.

    Returns
    -------
    z_stack : np.ndarray
        The Z planes of the image (time point 0, channel 0) stacked along the
        last axis.

    img_pixel_sizes : np.ndarray
        Absolute value of the reciprocal of the first three entries of the
        image ``scale`` attribute.

    Notes
    -----
    This function uses `readlif` library to open and process images stored in LIF format.
    '''

    lif_image = LifFile(image_path).get_image(image_index)

    # Pixel sizes are taken as the reciprocal of the scale entries.
    scale_xyz = np.array(lif_image.scale[:3])
    img_pixel_sizes = np.absolute(1 / scale_xyz)

    # One 2-D array per Z plane, stacked depth-wise.
    planes = [np.array(plane) for plane in lif_image.get_iter_z(t=0, c=0)]
    z_stack = np.dstack(planes)

    return z_stack, img_pixel_sizes


def process_image_stack(image_stack: np.ndarray,
                        img_pixel_sizes: list,
                        number_of_processes: int = 42,
                        default_kernel_bright_corr: int = 50,
                        default_stride_bright_corr: int = 8,
                        default_kernel_clahe: int = 6):
    '''
    Function to process a stack of images using multiprocessing.

    Each image of the stack is passed to ``single_section_processing`` in a
    worker process. The physical kernel and stride sizes are converted to
    pixels using the first entry of ``img_pixel_sizes``.

    Parameters
    ----------
    image_stack : np.ndarray
        A stack of images to be processed. The shape of the stack should be (n, x, y), where n is the number of images.

    img_pixel_sizes : list
        A list of pixel sizes for the images in the stack. The list should contain one value per axes (X, Y, Z).

    number_of_processes : int, optional
        The number of processes to use for parallel processing.
        Default is 42.

    default_kernel_bright_corr : int, optional
        The default kernel size for brightness correction processing.
        Default is 50 um.

    default_stride_bright_corr : int, optional
        The default stride size for brightness correction processing.
        Default is 8 um.

    default_kernel_clahe : int, optional
        The default kernel size for CLAHE processing.
        Default is 6 um.

    Returns
    -------
    np.ndarray
        Processed stack of images, with singleton dimensions removed.
    '''

    pixel_size = img_pixel_sizes[0]

    # Clamp to at least one pixel: a stride of 0 cannot be used to tile an
    # image and a window of 0 would produce only empty patches.
    kernel_bright_corr = max(1, int(np.round(default_kernel_bright_corr / pixel_size)))
    stride_bright_corr = max(1, int(np.round(default_stride_bright_corr / pixel_size)))
    kernel_clahe = int(np.round(default_kernel_clahe / pixel_size)) + 1

    # Create a partial function with img_pixel_sizes as a fixed argument
    partial_single_section_processing = partial(single_section_processing,
                                                img_pixel_sizes=img_pixel_sizes,
                                                bright_corr_patch_window_size=kernel_bright_corr,
                                                bright_corr_patch_stride_size=stride_bright_corr,
                                                clahe_kernel_size=kernel_clahe)

    # The context manager releases the worker processes even if a section
    # fails. All results are collected before the block exits, so terminating
    # the pool on exit does not discard any work.
    with multiprocessing.Pool(processes=number_of_processes) as pool:
        processed_stacks = list(tqdm(pool.imap(partial_single_section_processing, image_stack),
                                     total=len(image_stack)))

    return np.squeeze(np.array(processed_stacks))


def process_LIF_image(lif_object,
                      target_pixel_size: float = 0.4,
                      number_of_processes: int = 2,
                      default_kernel_bright_corr: int = 50,
                      default_stride_bright_corr: int = 8,
                      default_kernel_clahe: int = 6,
                      image_name_filter: list = None,
                      channel_axes: int = 2):
    """
    Preprocess and rescale the images of a LIF object.

    For every selected image, the Z planes (time point 0, channel 0) are
    read, preprocessed with ``process_image_stack`` and finally rescaled so
    that the pixel size matches ``target_pixel_size``.

    Parameters
    ----------
    lif_object : LIF object
        The LIF object containing the image stack.

    target_pixel_size : float, optional
        The desired output pixel size in micrometers. Assumes isotropic scaling.
        Default is 0.4.

    number_of_processes : int, optional
        The number of processes to use for parallel processing.
        Default is 2.

    default_kernel_bright_corr : int, optional
        The default kernel size for brightness correction processing.
        Default is 50 um.

    default_stride_bright_corr : int, optional
        The default stride size for brightness correction processing.
        Default is 8 um.

    default_kernel_clahe : int, optional
        The default kernel size for CLAHE processing.
        Default is 6 um.

    image_name_filter : list, optional
        List with image names to be analysed. If an image in the LIF file is not in this filter
        it is skipped. When None, every image is processed.

    channel_axes : int, optional
        Currently not used by this function.

    Returns
    -------
    img_stack : dict
        A dictionary containing the original image stacks.
        The keys represent the image names, and the values represent the image stacks.

    img_preprocessed : dict
        A dictionary containing the preprocessed image stacks.
        The keys represent the image names, and the values represent the preprocessed image stacks.

    scaled_img_preprocessed : dict
        A dictionary containing the rescaled and preprocessed image stacks.
        The keys represent the image names, and the values represent the rescaled and preprocessed image stacks.
    """

    # Map each image name to its position in the LIF file.
    name_to_index = {entry['name']: position
                     for position, entry in enumerate(lif_object.image_list)}

    img_stack, img_preprocessed, scaled_img_preprocessed = {}, {}, {}

    for img_name, img_idx in name_to_index.items():

        if image_name_filter is not None and img_name not in image_name_filter:
            continue

        lif_image = lif_object.get_image(img_idx)
        img_pixel_sizes = np.absolute(1 / np.array(lif_image.scale[:3]))

        # Stack the Z planes along the last axis, then bring that axis to the
        # front so the stack can be iterated plane by plane.
        planes_last = np.dstack([np.array(plane) for plane in lif_image.get_iter_z(t=0, c=0)])
        z_stack = planes_last.swapaxes(0, 2).swapaxes(1, 2)
        img_stack[img_name] = z_stack

        processed_stacks = process_image_stack(z_stack,
                                               img_pixel_sizes,
                                               number_of_processes,
                                               default_kernel_bright_corr,
                                               default_stride_bright_corr,
                                               default_kernel_clahe)
        img_preprocessed[img_name] = processed_stacks

        # Move the plane axis back to the end before rescaling each axis by
        # the ratio between the current and the target pixel size.
        scaling_factors = np.absolute(img_pixel_sizes) / target_pixel_size
        planes_last_processed = processed_stacks.swapaxes(0, 2).swapaxes(0, 1)
        scaled_img_preprocessed[img_name] = rescale(planes_last_processed,
                                                    scaling_factors,
                                                    preserve_range=True)

    return img_stack, img_preprocessed, scaled_img_preprocessed


# ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ###
# Export output of preprocessing
# ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ###


def export_threshold_images(scaled_images: dict,
                            export_path: str,
                            threshold_value: float = 0.5):
    """
    Binarise every section of the given images and write each one as a tiff file.

    For each image a sub-folder is created inside ``export_path``. The folder
    name is the longest '-'-separated part of the image name (the first one
    in case of a tie). Sections are taken along the last axis of the image and
    saved as 8-bit images with values 0 and 255.

    Parameters
    ----------
    scaled_images : dict
        A dictionary containing the preprocessed images.
        The keys represent the image names, and the values represent the scaled images.

    export_path : str
        The path to the folder where the thresholded images will be saved.

    threshold_value : float
        The threshold value for binarization. Pixels strictly above it are
        set to 255.

    Returns
    -------
        None
    """

    # The threshold is embedded in the file name with '.' replaced by 'f'.
    threshold_tag = str(threshold_value).replace(".", "f")

    for image_name, scaled_image in scaled_images.items():
        save_image_name = max(image_name.split('-'), key=len)
        section_dir = f'{export_path}/{save_image_name}'
        os.makedirs(section_dir, exist_ok=True)

        for section_idx in range(scaled_image.shape[2]):
            above_threshold = scaled_image[..., section_idx] > threshold_value
            binary_section = (above_threshold * 255).astype(np.uint8)

            output_section_filename = f'{section_dir}/section_thres{threshold_tag}_{section_idx:03}.tiff'
            Image.fromarray(binary_section).save(output_section_filename)


# ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ###
# Output statistics
# ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ### # ###


def _n_compute_stats(float_values, array_values):
    """
    Build a one-row summary table from the scalar and array network outputs.

    Parameters
    ----------
    float_values : dict
        Scalar values; every key becomes a column of the summary. Must contain
        the key 'N', which is used as the denominator of the percentages.

    array_values : dict
        Arrays stored under the keys 't_nodes', 'c_nodes', 'e_nodes' and 'edg'.
        Only their sizes are used.

    Returns
    -------
    pandas.DataFrame
        A single-row DataFrame with the scalar values followed by the columns
        'tree_prct', 'clus_prct' and 'end_prct'.
    """

    node_total = float_values['N']

    # Percentages relative to N; 'edg' entries are subtracted from 'e_nodes'.
    end_count = array_values['e_nodes'].size - array_values['edg'].size
    percentages = {
        'tree_prct': 100 * array_values['t_nodes'].size / node_total,
        'clus_prct': 100 * array_values['c_nodes'].size / node_total,
        'end_prct': 100 * end_count / node_total,
    }

    # NOTE: the following metric is currently disabled:
    # S = (float_values['acc'] / array_values['acc_ER']) / (float_values['asp_logic'] / float_values['asp_ER'])

    scalar_row = pd.DataFrame.from_dict(float_values, orient='index').T
    return scalar_row.assign(**percentages)


def compute_matlab_stats(matlab_output_folders):
    """
    Collect summary statistics from a list of Matlab output folders.

    Every folder is expected to contain an HDF5 file 'n.mat' with a group 'n'
    and a file 'cell.mat' readable by ``scipy.io.loadmat``. Folders whose
    'n.mat' cannot be opened, or does not contain 'n', are skipped with a
    printed message.

    Parameters
    ----------
    matlab_output_folders : list
        List of folder paths containing the outputs from Matlab.

    Returns
    -------
    complete_df : pandas.DataFrame
        One row per processed folder with the values computed by
        ``_n_compute_stats`` plus the columns 'number_of_cells',
        'sample_name', 'region_name' and 'sample_group'. Empty when no folder
        could be processed.
    """

    n_float_keys = ['scale', 'N', 'N_nocell', 'dd', 'lac_por', 'can_len',
                    'd_net', 'd_cell', 'd_can', 'mean_deg', 'median_deg',
                    'mean_deg_ep', 'mean_deg_noedg', 'edge_den',
                    'edge_den_nocell', 'asp', 'acc', 'asp_logic', 'asp_ER']
    n_array_keys = ['t_nodes', 'c_nodes', 'e_nodes', 'edg', 'acc_ER', 'd_ratio3']

    sample_frames = []

    for sample_region_folder in matlab_output_folders:
        # Open read-only so a missing file is reported instead of created.
        try:
            n_file = h5py.File(f'{sample_region_folder}/n.mat', 'r')
        except OSError:
            print(f'Skipping {sample_region_folder}')
            continue

        # Read everything needed while the file is open; the handle is closed
        # on leaving the block, including on the skip path.
        with n_file:
            try:
                n_mat = n_file['n']
            except KeyError:
                print(f'Skipping {sample_region_folder}')
                continue

            n_float_values = {key: float(np.squeeze(np.array(n_mat[key])))
                              for key in n_float_keys}
            n_arr_values = {key: np.squeeze(np.array(n_mat[key]))
                            for key in n_array_keys}

        cell_mat = loadmat(f'{sample_region_folder}/cell.mat')

        # Typical image name structure: 'Polg 20w M WT 3082 a1'
        # normpath drops a trailing separator so the last component is never empty.
        image_name = os.path.basename(os.path.normpath(sample_region_folder))
        sample_name = image_name
        region_name = image_name.rsplit('-')[0].split(" ")[-1]
        sample_group = image_name.split(" ")[0]

        sample_df = _n_compute_stats(n_float_values, n_arr_values)
        sample_df['number_of_cells'] = cell_mat['cell'][0].size
        sample_df['sample_name'] = sample_name
        sample_df['region_name'] = region_name
        sample_df['sample_group'] = sample_group

        sample_frames.append(sample_df)

    if not sample_frames:
        return pd.DataFrame()

    # Concatenate once rather than growing a DataFrame inside the loop.
    return pd.concat(sample_frames, axis=0)
