import numpy as np
import scipy.stats as sstats


def snieder_relax(t: np.ndarray, tau_min: float, tau_max: float, N: int = 50):
    """
    :param t: vector with the time values, in the same units as tau_min and tau_max
    :type t: np.ndarray

    :param tau_min: shortest timescale of recovery processes (observability limited by sampling rate)
    :type tau_min: float

    :param tau_max: longest timescale of recovery processes (full recovery for longer times)
    :type tau_max: float

    :param N: the number of logarithmically spaced tau values used to compute the integral
        (N - 1 trapezoids). Must be at least 2.
    :type N: int

    :return out: vector containing the values of -R(t) (see below for the sign)
    :rtype out: np.ndarray

    :raises ValueError: if N < 2

    This function computes the relaxation function R(t), ie the integral from tau_min to tau_max
    of 1/tau * exp(-t/tau) dtau (as described in Snieder et al., 2017) using the trapezoidal rule
    on N logarithmically spaced values of tau. The value returned is the NEGATIVE of that
    integral.

    It is set to 0 for non-positive values of time (t <= 0).
    """
    if N < 2:
        raise ValueError("N must be at least 2 to define an integration interval")

    t = np.asarray(t)
    positive = t > 0
    t_pos = t[positive]

    # Log-spaced nodes tau_k = tau_min * a**k (k = 0 .. N-1), so each
    # trapezoid spans [tau_k, a * tau_k] and has width tau_k * (a - 1).
    taus = np.logspace(np.log10(tau_min), np.log10(tau_max), N)
    a = (tau_max / tau_min) ** (1 / (N - 1))

    # Trapezoid of f(tau) = exp(-t/tau)/tau on [tau_k, a * tau_k]:
    #   tau_k * (a - 1) / 2 * (f(tau_k) + f(a * tau_k))
    #   = (a - 1) / 2 * (exp(-t/tau_k) + exp(-t/(a * tau_k)) / a)
    # Only the first N - 1 nodes start an interval inside [tau_min, tau_max];
    # tau_max itself is the right end of the last interval, not the start of
    # another one.
    acc = np.zeros_like(t_pos, dtype=float)
    for tau in taus[:-1]:
        acc = acc + np.exp(-t_pos / (a * tau)) / a + np.exp(-t_pos / tau)

    out = np.zeros_like(t, dtype=float)
    out[positive] = -acc * (a - 1) / 2

    return out


def heaviside(t, flip: bool = False):
    """
    Unit step function evaluated element-wise.

    :param t: time value(s); scalars, sequences and arrays are accepted
    :param flip: if True, return the negated step (-1.0 where t > 0)
    :return: float array with the shape of ``t``: 1.0 where t > 0 and 0.0
        elsewhere (including t == 0 and NaN), negated when ``flip`` is True
    """
    # np.where on the comparison works for 0-d input too, unlike indexing a
    # 0-d array with np.where(t > 0).
    step = np.where(np.asarray(t) > 0, 1.0, 0.0)

    return -step if flip else step


def norm_and_detrend(vec):
    """
    Rescale ``vec`` to the [0, 1] range, then remove its mean.

    :param vec: input values (array-like)
    :return: float array ``(vec - min) / (max - min)`` minus its mean. A
        constant input (max == min) has no shape to normalize; it returns all
        zeros instead of NaNs from a 0/0 division.
    """
    vec = np.asarray(vec, dtype=float)

    shifted = vec - np.amin(vec)
    span = np.amax(shifted)

    if span == 0:
        return np.zeros_like(shifted)

    out = shifted / span
    return out - np.mean(out)


def create_R(t, tjs, tau_min=0.1, tau_max=1.5e4, cumul=True):
    """
    Build the relaxation component for a sequence of events, normalized and
    detrended with ``norm_and_detrend``.

    :param t: time vector, same units as tau_min and tau_max
    :param tjs: event times
    :param tau_min: shortest relaxation timescale passed to ``snieder_relax``
    :param tau_max: longest relaxation timescale passed to ``snieder_relax``
    :param cumul: if True, the relaxation curves of all events are summed.
        Otherwise, for each event in ``tjs`` order, the curve overwrites the
        signal at every time after that event, so later entries of ``tjs``
        win. Events at or after ``t[-1]`` are skipped (this presumes ``t`` is
        sorted ascending; confirm at call sites).
    :return: the normalized, detrended relaxation vector
    """
    total = np.zeros_like(t, dtype=float)

    for event_time in tjs:
        if cumul:
            total = total + snieder_relax(t - event_time, tau_min, tau_max)
            continue

        if event_time >= t[-1]:
            continue

        after_event = t > event_time
        curve = snieder_relax(t - event_time, tau_min, tau_max)
        total[after_event] = curve[after_event]

    return norm_and_detrend(total)


def create_classical(t, tjs):
    """
    Build the "classical" component: a sum of unit steps at the event times,
    alternating in sign (the first event steps up, the second steps down, and
    so on), then normalized and detrended with ``norm_and_detrend``.

    :param t: time vector
    :param tjs: event times
    :return: the normalized, detrended step vector
    """
    steps = np.zeros_like(t, dtype=float)

    for k, event_time in enumerate(tjs):
        # Even-indexed events add +1 after them, odd-indexed events add -1.
        steps = steps + heaviside(t - event_time, flip=(k % 2 == 1))

    return norm_and_detrend(steps)


def create_lin_trend(t, scale=28792):
    """
    Linear-trend component: ``t / scale`` with its mean removed.

    :param t: time vector
    :param scale: divisor applied to ``t`` before removing the mean.
        NOTE(review): the default 28792 was previously described as the
        8-hour mark in seconds, but 8 h = 28800 s. The value is kept for
        backward compatibility; confirm which one is intended.
    :return: float array ``t / scale - mean(t / scale)``
    """
    scaled = np.asarray(t) / scale
    return scaled - np.mean(scaled)


def create_components(t, tjs, tau_min=0.1, tau_max=1.5e4, cumul=True):
    """
    Generate the three base components for the linear inversion.

    :param t: time vector
    :param tjs: event times
    :param tau_min: shortest relaxation timescale, forwarded to ``create_R``
    :param tau_max: longest relaxation timescale, forwarded to ``create_R``
    :param cumul: relaxation mode, forwarded to ``create_R``
    :return: dict with keys "R" (from ``create_R``), "classical" (from
        ``create_classical``) and "lin_trend" (from ``create_lin_trend``)
    """
    return {
        "R": create_R(t, tjs, tau_min, tau_max, cumul),
        "classical": create_classical(t, tjs),
        "lin_trend": create_lin_trend(t),
    }


def cgs(x, eps, N):
    """
    Centered grid search values around ``x``.

    :param x: central value
    :param eps: relative half width of the grid
    :param N: number of grid points
    :return: N evenly spaced values from ``(1 - eps) * x`` to ``(1 + eps) * x``,
        both ends included (descending when x is negative)
    """
    lower = (1 - eps) * x
    upper = (1 + eps) * x

    return np.linspace(lower, upper, num=N)


def compute_models(coefs, components, which_coef, eps, N):
    """
    Computes the models for a given set of computed fitting coefficients and the model
    components that go along with it.

    The coefficient named by ``which_coef`` is swept over ``cgs(x_opt, eps, N)``
    while the two others stay at their optimal values. Each row of the result is
    ``alpha * L + beta * C + delta * R`` for one grid value, where L, C and R are
    ``components["lin_trend"]``, ``components["classical"]`` and ``components["R"]``.

    :param coefs: optimal coefficients; coefs[0], coefs[1], coefs[2] are alpha, beta, delta
    :param components: dict with the "lin_trend", "classical" and "R" vectors
    :param which_coef: "alpha", "beta" or "delta"
    :param eps: relative half width of the search grid
    :param N: number of grid values
    :return: ``(models, opt_model, search_vec)``: an N x M array of values, where N is
        the number of different models and M is the length of each model vector; the
        optimal model; and the N swept values
    :raises ValueError: if ``which_coef`` is not one of the three names
    """
    coef_names = ("alpha", "beta", "delta")
    if which_coef not in coef_names:
        raise ValueError("which_coef must be 'alpha', 'beta' or 'delta'")

    alpha_opt = coefs[0]
    beta_opt = coefs[1]
    delta_opt = coefs[2]

    L = components["lin_trend"]
    C = components["classical"]
    R = components["R"]

    search_vec = cgs(coefs[coef_names.index(which_coef)], eps, N)

    # One coefficient value per model row: the swept coefficient takes the grid
    # values, the other two are repeated N times at their optimum.
    ones = np.ones(N)
    alpha_row = search_vec if which_coef == "alpha" else alpha_opt * ones
    beta_row = search_vec if which_coef == "beta" else beta_opt * ones
    delta_row = search_vec if which_coef == "delta" else delta_opt * ones

    models = np.outer(alpha_row, L) + np.outer(beta_row, C) + np.outer(delta_row, R)

    opt_model = alpha_opt * L + beta_opt * C + delta_opt * R

    return (models, opt_model, search_vec)


def f_test(ref_sample, sample, alpha_stat=0.05):
    """
    Two-tailed F-test for equality of the variances of two samples.

    H0: the two samples have the same variance
    H1: the two samples have different variances

    :param ref_sample: reference sample (1-D array; ``.shape[0]`` gives its size)
    :param sample: sample compared with the reference
    :param alpha_stat: significance level, split equally between both tails
    :return: dict with
        "res": True when H0 is rejected, i.e. the variances can be considered different;
        "F-statistic": var(ref_sample) / var(sample), both with ddof=1;
        "p-value": the F-distribution CDF evaluated at the statistic. This is the
        lower-tail probability, not the probability that H0 is true.
    """
    dof_ref = ref_sample.shape[0] - 1
    dof = sample.shape[0] - 1

    # The reference is usually the optimal solution, whose variance is expected
    # to be the smaller one, so a one-tailed test would in principle be enough.
    f_value = np.var(ref_sample, ddof=1) / np.var(sample, ddof=1)
    cdf = sstats.f.cdf(f_value, dof_ref, dof)

    half_alpha = alpha_stat / 2
    # Reject H0 when the statistic falls in either tail of the F distribution.
    rejected = bool((cdf < half_alpha) or (cdf > 1 - half_alpha))

    return {"res": rejected, "F-statistic": f_value, "p-value": cdf}


def compute_confidence_interval(
    which_coef,
    coefs,
    dvval,
    components,
    eps=0.3,
    N=101,
    mpl_compatible=True,
    full_return=False,
    max_iter=8,
):
    """
    Computes the confidence interval for one fitting parameter on one model
    fit.

    The chosen coefficient is swept over N values between (1 - eps) * x0 and
    (1 + eps) * x0 (x0 being its fitted value), with the other coefficients
    fixed (see ``compute_models``). For every value, the residual
    ``dvval - model`` is compared with the residual of the optimal model using
    ``f_test``. Values whose residual variance is not significantly different
    belong to the interval. While no swept value is rejected, the grid is
    considered too narrow: eps and N are tripled and the sweep is repeated.

    :param which_coef: "alpha", "beta" or "delta" (coefs[0], coefs[1], coefs[2])
    :param coefs: fitted coefficients
    :param dvval: data the model was fitted to
    :param components: dict with "lin_trend", "classical" and "R" vectors
    :param eps: initial relative half width of the search grid
    :param N: initial number of grid points
    :param mpl_compatible: if True, the bounds are returned as absolute
        distances from x0 (presumably for matplotlib error bars)
    :param full_return: if True, return a dictionary of diagnostics
    :param max_iter: maximum number of sweeps (each one triples eps and N)

    For a full return:
        Returns a dictionary containing the bounds of the interval, the full
        interval, the relative error, the scores for the tested values of the
        parameter and the grid of search parameters.
    Otherwise:
        Returns the bounds of the confidence interval

    :raises ValueError: for an unknown ``which_coef``, or when x0 == 0 (the
        proportional grid collapses to a single point and can never widen)
    :raises RuntimeError: if no swept value is rejected within ``max_iter`` sweeps
    """
    coef_index = {"alpha": 0, "beta": 1, "delta": 2}
    if which_coef not in coef_index:
        raise ValueError("which_coef must be 'alpha', 'beta' or 'delta'")

    x0 = coefs[coef_index[which_coef]]

    if x0 == 0:
        raise ValueError(
            "cannot build a proportional search grid around a zero coefficient"
        )

    for _ in range(max_iter):
        models, opt_model, search_vec = compute_models(
            coefs, components, which_coef, eps, N
        )

        # Residuals = dvval - models
        opt_residual = dvval - opt_model
        residuals = np.outer(np.ones(N), dvval) - models

        confidence_interval = np.full(N, np.nan)
        statistic_interval = np.full(N, np.nan)

        for k, residual in enumerate(residuals):
            test_res = f_test(opt_residual, residual)
            statistic_interval[k] = test_res["F-statistic"]

            # Values the F-test cannot distinguish from the optimum are part
            # of the interval; rejected values stay NaN.
            if not test_res["res"]:
                confidence_interval[k] = search_vec[k]

        # Any NaN means at least one value was rejected: the grid extends past
        # the interval and is wide enough.
        if np.isnan(np.sum(confidence_interval)):
            break

        eps = 3 * eps
        N = 3 * N
    else:
        raise RuntimeError(
            f"no tested value of {which_coef} was rejected after {max_iter} "
            "grid widenings; the coefficient appears to be unconstrained"
        )

    lower = np.nanmin(confidence_interval)
    upper = np.nanmax(confidence_interval)
    bounds = [lower, upper]

    if mpl_compatible:
        # np.asarray so the subtraction also works when x0 is a Python float.
        bounds = np.abs(np.asarray(bounds) - x0)

    relative_err = np.abs((upper - x0) / x0) * 100

    if full_return:
        return {
            "bounds": bounds,
            "interval": confidence_interval,
            "relative error": relative_err,
            "scores": statistic_interval,
            "grid": search_vec,
        }
    else:
        return bounds


def _bin_means_by_angle(values, thetas, med_err):
    """
    Per-bin statistics for the non-empty 1-degree bins of
    ``np.histogram(thetas, 51, range=(39, 90))``.

    Bin membership follows np.histogram: [lo, hi) for every bin except the
    last, which is closed [lo, hi]. Every element of a bin counts, including
    repeated values or angles.

    :return: (means, stds, avg_angles) arrays, one entry per non-empty bin;
        stds is (med_err / 2) / sqrt(bin count)
    """
    _, edges = np.histogram(thetas, 51, range=(39, 90))
    last_bin = edges.shape[0] - 2

    means, stds, avg_angles = [], [], []

    for kbin in range(last_bin + 1):
        lo, hi = edges[kbin], edges[kbin + 1]
        below_hi = thetas <= hi if kbin == last_bin else thetas < hi
        in_bin = (thetas >= lo) & below_hi

        count = np.count_nonzero(in_bin)
        if count == 0:
            continue

        # Arithmetic mean and std
        means.append(np.mean(values[in_bin]))
        stds.append((med_err / 2) / np.sqrt(count))
        avg_angles.append(np.mean(thetas[in_bin]))

    return (
        np.array(means, dtype=float),
        np.array(stds, dtype=float),
        np.array(avg_angles, dtype=float),
    )


def MC_fit_mean(which_coef, fit_type, mfile, N_MC: int = 5000):
    """
    Do the Monte-Carlo inversion on the mean values of each angle bin
    for a given coefficient and a given fit type (either Linear or Thomsen).

    The values ``mfile[which_coef]`` are grouped by ``mfile["thetas"]`` into the
    1-degree bins of ``np.histogram(thetas, 51, range=(39, 90))``. Each non-empty
    bin gives a mean value, a mean angle and a standard deviation equal to
    (median of row 1 of the error bounds / 2) / sqrt(bin count). N_MC synthetic
    sets of bin means are drawn from normal distributions with those means and
    standard deviations, and each set is fitted with:

    * "Linear": y = slope * (90 - theta) + intercept
    * "Thomsen": y = gamma * cos^2(theta) + intercept (theta in degrees)

    :param which_coef: "alpha", "beta" or "delta"
    :param fit_type: "Linear" or "Thomsen"
    :param mfile: mapping providing ``which_coef``, ``"thetas"`` and
        ``which_coef + "_bounds"`` (2-D; row 1 is used as the error)
    :param N_MC: number of Monte-Carlo draws
    :return: ``(fit_params, vec, avg_angles, means, means_cfi)``: fit_params is
        2 x N_MC (slope or gamma, then intercept); vec is the abscissa used for
        the fit; avg_angles and means are per-bin; means_cfi = 2 * std per bin
    :raises ValueError: for an unknown ``which_coef`` or ``fit_type``
    """
    if which_coef not in ("alpha", "beta", "delta"):
        raise ValueError("which_coef must be 'alpha', 'beta' or 'delta'")
    if fit_type not in ("Linear", "Thomsen"):
        raise ValueError("fit_type must be 'Linear' or 'Thomsen'")

    plotval = np.asarray(mfile[which_coef])
    thetas = np.asarray(mfile["thetas"])
    ebounds = mfile[which_coef + "_bounds"]

    med_err = np.median(ebounds[1, :])

    means, stds, avg_angles = _bin_means_by_angle(plotval, thetas, med_err)
    means_cfi = 2 * stds  # 2*sigma contains 95% of values

    # Synthetic data sets: column k holds the N_MC draws for bin k.
    means_MC = np.zeros((N_MC, means.shape[0]))
    for k in range(means.shape[0]):
        means_MC[:, k] = np.random.normal(means[k], stds[k], N_MC)

    if fit_type == "Linear":
        vec = 90 - avg_angles
        fit_params = np.zeros((2, N_MC))
        for m in range(N_MC):
            lreg = sstats.linregress(vec, means_MC[m, :])
            fit_params[0, m] = lreg.slope
            fit_params[1, m] = lreg.intercept
    else:
        vec = np.cos(avg_angles * np.pi / 180) ** 2
        M = np.column_stack((vec, np.ones(vec.shape[0])))
        # One least-squares solve with one right-hand side per draw.
        fit_params = np.linalg.lstsq(M, means_MC.T, rcond=None)[0]

    return fit_params, vec, avg_angles, means, means_cfi


def MC_fit_direct(which_coef, fit_type, mfile, N_MC: int = 5000):
    """
    Do the Monte-Carlo inversion directly on the values of each individual
    fit result for a given coefficient and a given fit type (either Linear
    or Thomsen).

    Each value ``mfile[which_coef][k]`` is perturbed with a normal draw whose
    standard deviation is half of ``mfile[which_coef + "_bounds"][1, k]``.
    Each of the N_MC synthetic data sets is fitted with:

    * "Linear": y = slope * (90 - theta) + intercept
    * "Thomsen": y = gamma * cos^2(theta) + intercept (theta in degrees)

    :param which_coef: "alpha", "beta" or "delta"
    :param fit_type: "Linear" or "Thomsen"
    :param mfile: mapping providing ``which_coef``, ``"thetas"`` and
        ``which_coef + "_bounds"`` (2-D; row 1 is used as the error)
    :param N_MC: number of Monte-Carlo draws
    :return: ``(fit_params, vec, thetas)``: fit_params is 2 x N_MC (slope or
        gamma, then intercept); vec is the abscissa used for the fit
    :raises ValueError: for an unknown ``which_coef`` or ``fit_type``
    """
    if which_coef not in ("alpha", "beta", "delta"):
        raise ValueError("which_coef must be 'alpha', 'beta' or 'delta'")
    if fit_type not in ("Linear", "Thomsen"):
        raise ValueError("fit_type must be 'Linear' or 'Thomsen'")

    coefs = mfile[which_coef]
    thetas = mfile["thetas"]
    ebounds = mfile[which_coef + "_bounds"]

    stds = ebounds[1, :] / 2  # 2*sigma contains 95% of values for a normal distribution

    # Synthetic data sets: column k holds the N_MC draws for coefficient k.
    coefs_MC = np.zeros((N_MC, coefs.shape[0]))
    for k in range(coefs.shape[0]):
        coefs_MC[:, k] = np.random.normal(coefs[k], stds[k], N_MC)

    if fit_type == "Linear":
        vec = 90 - thetas
        fit_params = np.zeros((2, N_MC))
        for m in range(N_MC):
            lreg = sstats.linregress(vec, coefs_MC[m, :])
            fit_params[0, m] = lreg.slope
            fit_params[1, m] = lreg.intercept
    else:
        vec = np.cos(thetas * np.pi / 180) ** 2
        M = np.column_stack((vec, np.ones(vec.shape[0])))
        # One least-squares solve with one right-hand side per draw.
        fit_params = np.linalg.lstsq(M, coefs_MC.T, rcond=None)[0]

    return fit_params, vec, thetas


def bootstrap_fit_direct(
    coefs: np.ndarray, thetas: np.ndarray, N: int = 5000, rng=None
):
    """
    Bootstrapping algo for inverting y = u cos^2(theta) + v from inverted
    classical and nonclassical parameter values. Not suitable for a linear regression.

    Each iteration resamples the (coef, theta) pairs with replacement and solves
    the least-squares problem for u and v.

    :param coefs: The coefficients for which the bootstrap fit will be performed
    :type coefs: np.ndarray

    :param thetas: The angle values (in degrees) corresponding to each coefficient
    :type thetas: np.ndarray

    :param N: The number of iterations for the bootstrap algorithm
    :type N: int

    :param rng: seed or ``np.random.Generator`` passed to ``np.random.default_rng``.
        With None (default), fresh entropy is used on every call.

    :return fit_params: 2 x N array with the u and v coefficients for each iteration
    :rtype fit_params: np.ndarray

    :return vec: The cosine squared vector of ``thetas``, in the same order as the `thetas` array
    :rtype vec: np.ndarray

    :return thetas: The angles used for the inversion
    :rtype thetas: np.ndarray
    """
    coefs = np.asarray(coefs)
    thetas = np.asarray(thetas)

    Ndata = coefs.shape[0]
    fit_params = np.zeros((2, N))
    rgen = np.random.default_rng(rng)

    for m in range(N):
        # Draw indices with replacement; pairs stay together.
        picks = rgen.choice(Ndata, Ndata, replace=True)

        # Design matrix for the Thomsen inversion of the resampled pairs.
        cos2 = np.cos(thetas[picks] * np.pi / 180) ** 2
        M = np.column_stack((cos2, np.ones(Ndata)))

        fit_params[:, m] = np.linalg.lstsq(M, coefs[picks], rcond=None)[0]

    # Return a Thomsen vector in the right order for plotting
    vec = np.cos(thetas * np.pi / 180) ** 2

    return fit_params, vec, thetas
