import datetime as dtt
import matplotlib as mpl
import matplotlib.dates as mdates
import matplotlib.ticker as ticker
import numpy as np
import os
import scipy.stats as sstats
import time
import yaml
import mts_cwi.stretch as mtss
import mts_cwi.mts_data as mtsd
import mts_cwi.inversion as mtsi
from copy import deepcopy
from importlib import reload
from matplotlib import pyplot as plt
from matplotlib.patches import Ellipse
from matplotlib.patches import Rectangle
from obspy.core import Stats, UTCDateTime
from scipy.io import matlab


def shape_plot(number_of_plots):
    """
    Choose a (rows, columns) grid that can hold ``number_of_plots`` subplots
    while staying as close to a square as possible.

    The column count is the ceiling of the square root of the number of plots.
    One row is dropped when the square grid would leave at least a full row
    of cells empty.

    :param number_of_plots: how many subplots the grid must hold
    :returns: tuple ``(rows, columns)``
    """
    n_cols = int(np.ceil(np.sqrt(number_of_plots)))

    # At least n_cols spare cells means the last row would be entirely empty.
    empty_last_row = (n_cols**2 - number_of_plots) >= n_cols
    n_rows = n_cols - 1 if empty_last_row else n_cols

    return (n_rows, n_cols)


def compressed_shade(ax, tjs, sc="#f05039", alpha=0.2, set_axis=True):
    """
    Shade two vertical bands, ``[tjs[0], tjs[1]]`` and ``[tjs[2], tjs[3]]``,
    across the current y-range of ``ax``.

    :param ax: Axes to draw on
    :param tjs: four x-positions delimiting the two bands
    :param sc: fill colour of the bands
    :param alpha: fill transparency
    :param set_axis: if True, re-apply the y-limits read before shading
    :returns: the same Axes object
    """
    y_range = ax.get_ylim()

    for start in (0, 2):
        left = tjs[start]
        right = tjs[start + 1]
        # zorder=-1 draws the bands underneath the plotted data
        ax.fill_betweenx(
            y_range, [left, left], [right, right], color=sc, alpha=alpha, zorder=-1
        )

    if set_axis:
        ax.set_ylim(y_range)

    return ax


def plot_dv(
    eh,
    dv,
    tw_id: int = 0,
    a: int = 4,
    r: float = 16 / 9,
    datelim: list = None,
    ax=None,
):
    """
    Plot the DV/V measurements of one time window obtained with the stretching
    method. By default a new figure is created; pass ``ax`` to draw on an
    existing Axes object instead.

    :param eh: object whose ``tw_to_seconds`` converts ``dv.time_windows`` into
        the window start/end times shown in the title (labelled in μs)
    :param dv: measurement object exposing ``sourcetimes`` (items with a
        ``.datetime`` attribute), ``value``, ``time_windows`` and ``pair``
    :param tw_id: index of the time window (row of ``dv.value``) to plot
    :param a: figure height in inches when a new figure is created
    :param r: width/height ratio of a newly created figure
    :param datelim: ``[start, end]`` date limits, each ``None`` or anything
        accepted by ``UTCDateTime``; limits are clipped to the data span.
        Defaults to the full data span.
    :param ax: Axes to draw on; a new figure is created when ``None``
    :returns: the Axes used for plotting
    """
    if datelim is None:
        datelim = (None, None)

    stime = dv.sourcetimes

    dv_val = np.atleast_2d(dv.value)[tw_id, :]
    tws = dv.time_windows

    times = np.array([utcdt.datetime for utcdt in stime], dtype=np.datetime64)

    if ax is None:
        fig = plt.figure(
            figsize=(r * a, a), dpi=180, facecolor="white", edgecolor="none"
        )
        ax = fig.add_subplot()

    # One major tick per hour, labelled HH:MM
    ax.xaxis.set_major_locator(mdates.HourLocator(byhour=range(24), interval=1))
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M"))

    ax.set_xlabel("Time [h]", size="x-large")
    ax.set_ylabel("DV/V [%]", size="x-large")

    ax.grid()
    ax.set_axisbelow(True)

    # Stored values are fractions; display them in percent
    ax.plot(times, 100 * dv_val, "o", markersize=2, linewidth=0.5)

    # atleast_2d so a single (1-D) window is indexed like a stack of windows
    window_times = np.atleast_2d(eh.tw_to_seconds(tws))[tw_id]

    title = f"{dv.pair}: [{window_times[0]:.1f}μs - {window_times[1]:.1f}μs]"

    ax.set_title(title)

    # Clip the requested date range to the span covered by the data
    if datelim[0] is None:
        dstart = stime[0]
    else:
        dstart = max(stime[0], UTCDateTime(datelim[0]))

    if datelim[-1] is None:
        dend = stime[-1]
    else:
        dend = min(stime[-1], UTCDateTime(datelim[-1]))

    # Plain datetimes, like the plotted x-data
    ax.set_xlim(dstart.datetime, dend.datetime)

    return ax


def plot_shifts(
    adp,
    tw_id: int = 0,
    a: int = 4,
    r: float = 16 / 9,
    datelim: list = None,
    ax=None,
):
    """
    Plot the shift measurements of one time window obtained with the
    interpolation method. By default a new figure is created; pass ``ax`` to
    draw on an existing Axes object instead.

    :param adp: data object exposing ``sourcetimes``, a ``shifts`` mapping with
        keys ``"value"`` and ``"time_windows"``, and ``tw_to_seconds``;
        optionally ``frequency_range`` and ``sampling_rate`` for the title
    :param tw_id: index of the time window (row of the shift values) to plot
    :param a: figure height in inches when a new figure is created
    :param r: width/height ratio of a newly created figure
    :param datelim: ``[start, end]`` date limits, each ``None`` or anything
        accepted by ``UTCDateTime``; limits are clipped to the data span.
        Defaults to the full data span.
    :param ax: Axes to draw on; a new figure is created when ``None``
    :returns: the Axes used for plotting
    """
    if datelim is None:
        datelim = (None, None)

    stime = adp.sourcetimes

    shifts = np.atleast_2d(adp.shifts["value"])[tw_id, :]
    tws = adp.shifts["time_windows"]

    times = np.array([utcdt.datetime for utcdt in stime], dtype=np.datetime64)

    if ax is None:
        fig = plt.figure(
            figsize=(r * a, a), dpi=180, facecolor="white", edgecolor="none"
        )
        ax = fig.add_subplot()

    # One major tick per hour, labelled HH:MM
    ax.xaxis.set_major_locator(mdates.HourLocator(byhour=range(24), interval=1))
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M"))

    ax.set_xlabel("Time [h]", size="x-large")
    ax.set_ylabel("Shifts [μs]", size="x-large")

    ax.grid()
    ax.set_axisbelow(True)

    # NOTE(review): the fixed 0.1 factor presumably converts shifts in samples
    # to μs for a 10 MHz sampling rate — confirm against adp.sampling_rate.
    ax.plot(times, 0.1 * shifts, "o", markersize=2, linewidth=0.5)

    # atleast_2d so a single (1-D) window is indexed like a stack of windows
    window_times = np.atleast_2d(adp.tw_to_seconds(tws))[tw_id]

    title = f"Shifts [{window_times[0]:.1f}μs - {window_times[1]:.1f}μs]"

    if hasattr(adp, "frequency_range"):
        # A missing bound means "no filtering": 0 Hz below, Nyquist above
        f1 = 0 if adp.frequency_range[0] is None else adp.frequency_range[0]
        if adp.frequency_range[1] is None:
            f2 = 0.5 * adp.sampling_rate
        else:
            f2 = adp.frequency_range[1]

        title += f", {f1:.2e} - {f2:.2e} Hz"

    ax.set_title(title)

    # Clip the requested date range to the span covered by the data
    if datelim[0] is None:
        dstart = stime[0]
    else:
        dstart = max(stime[0], UTCDateTime(datelim[0]))

    if datelim[-1] is None:
        dend = stime[-1]
    else:
        dend = min(stime[-1], UTCDateTime(datelim[-1]))

    # Plain datetimes, like the plotted x-data
    ax.set_xlim(dstart.datetime, dend.datetime)

    return ax


def plot_correlation_coef(
    adp,
    measurement_type: str = "shifts",
    tw_id: int = 0,
    a: int = 4,
    r: float = 16 / 9,
    datelim: list = None,
    ax=None,
    sim_mat=False,
):
    """
    Plot the correlation coefficients of one time window for either the DV/V
    or the shift measurements. By default a new figure is created; pass ``ax``
    to draw on an existing Axes object instead.

    :param adp: data object exposing ``sourcetimes``, ``tw_to_seconds`` and a
        ``shifts`` or ``dv`` mapping with keys ``"corr"`` and ``"time_windows"``;
        optionally ``frequency_range`` and ``sampling_rate`` for the title
    :param measurement_type: ``"shifts"`` or ``"dv"``
    :param tw_id: index of the time window (row of the coefficients) to plot
    :param a: figure height in inches when a new figure is created
    :param r: width/height ratio of a newly created figure
    :param datelim: ``[start, end]`` date limits, each ``None`` or anything
        accepted by ``UTCDateTime``; limits are clipped to the data span.
        Defaults to the full data span.
    :param ax: Axes to draw on; a new figure is created when ``None``
    :param sim_mat: accepted for backward compatibility; currently unused
    :raises ValueError: if the requested measurement is missing, or if
        ``measurement_type`` is neither ``"shifts"`` nor ``"dv"``
    :returns: the Axes used for plotting
    """
    if datelim is None:
        datelim = (None, None)

    if measurement_type == "shifts":
        if getattr(adp, "shifts", None) is None:
            raise ValueError("Shifts could not be found")
        corr = np.atleast_2d(adp.shifts["corr"])[tw_id, :]
        tws = adp.shifts["time_windows"]

    elif measurement_type == "dv":
        if getattr(adp, "dv", None) is None:
            raise ValueError("DV/V could not be found")
        corr = np.atleast_2d(adp.dv["corr"])[tw_id, :]
        tws = adp.dv["time_windows"]

    else:
        raise ValueError("Measurement type must be either shifts or dv")

    stime = adp.sourcetimes

    times = np.array([utcdt.datetime for utcdt in stime], dtype=np.datetime64)

    if ax is None:
        fig = plt.figure(
            figsize=(r * a, a), dpi=180, facecolor="white", edgecolor="none"
        )
        ax = fig.add_subplot()

    # One major tick per hour, labelled HH:MM
    ax.xaxis.set_major_locator(mdates.HourLocator(byhour=range(24), interval=1))
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M"))

    ax.set_xlabel("Time [h]", size="x-large")
    ax.set_ylabel("Correlation coefficient", size="x-large")

    ax.grid()
    ax.set_axisbelow(True)

    ax.plot(times, corr, "o", markersize=2, linewidth=0.5)

    # atleast_2d so a single (1-D) window is indexed like a stack of windows
    window_times = np.atleast_2d(adp.tw_to_seconds(tws))[tw_id]

    title = (
        f"CC for {measurement_type} [{window_times[0]:.1f}μs - {window_times[1]:.1f}μs]"
    )

    if hasattr(adp, "frequency_range"):
        # A missing bound means "no filtering": 0 Hz below, Nyquist above
        f1 = 0 if adp.frequency_range[0] is None else adp.frequency_range[0]
        if adp.frequency_range[1] is None:
            f2 = 0.5 * adp.sampling_rate
        else:
            f2 = adp.frequency_range[1]

        title += f", {f1:.2e} - {f2:.2e} Hz"

    ax.set_title(title)

    # Clip the requested date range to the span covered by the data
    if datelim[0] is None:
        dstart = stime[0]
    else:
        dstart = max(stime[0], UTCDateTime(datelim[0]))

    if datelim[-1] is None:
        dend = stime[-1]
    else:
        dend = min(stime[-1], UTCDateTime(datelim[-1]))

    # Plain datetimes, like the plotted x-data
    ax.set_xlim(dstart.datetime, dend.datetime)

    return ax


def plot_waveform(
    adp,
    waveform_id: np.ndarray = np.array([0]),
    tws: np.ndarray = None,
    a: int = 4,
    r: float = 16 / 9,
    tlim: list = None,
    ax=None,
):
    """
    Plot one or several acoustic traces of a data pair. Defaults to the first
    trace.

    :param adp: the acousticDataPair object whose waveforms we want to plot
    :type adp: acousticDataPair

    :param waveform_id: one or multiple waveform IDs (rows of ``adp.waveforms``)
        that we want to plot
    :type waveform_id: np.ndarray

    :param tws: one or multiple time windows (as 1D or 2D arrays) to be displayed
        on the waveform as red boxes, in microseconds
    :type tws: np.ndarray

    :param tlim: set the time bounds for the plot, in microseconds. Either bound
        can be None, which means the start/end of the trace. Defaults to the full
        trace.
    :type tlim: list

    :param ax: Axes object on which to plot the waveform. If no object is passed,
        a new figure is created.
    :type ax: matplotlib.axes

    :returns: the Axes object used for plotting
    :rtype: matplotlib.axes
    """
    if tlim is None:
        tlim = (None, None)

    waveform_id = np.atleast_1d(waveform_id)
    Nwf = waveform_id.shape[0]

    # Time axis relative to the trigger sample
    t = (np.arange(0, adp.signal_length) - adp.trigger) * (1 / float(adp.sampling_rate))
    tscale = 10**6  # Convert seconds to microseconds on the x-axis

    t = t * tscale

    if ax is None:
        fig = plt.figure(
            figsize=(r * a, a), dpi=180, facecolor="white", edgecolor="none"
        )
        ax = fig.add_subplot()

    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; use the registry
    cmap = mpl.colormaps["rainbow"]

    # One row per requested waveform, read with the same per-row access pattern
    traces = np.array([adp.waveforms[wid, :] for wid in waveform_id])

    for k in range(Nwf):
        ax.plot(
            t,
            traces[k],
            color=cmap(k / Nwf),
            linewidth=0.5,
            label=str(waveform_id[k]),
        )

    title = f"Trace n.{np.squeeze(waveform_id)} for {adp.pair}"

    if hasattr(adp, "frequency_range"):
        # A missing bound means "no filtering": 0 Hz below, Nyquist above
        f1 = 0 if adp.frequency_range[0] is None else adp.frequency_range[0]
        if adp.frequency_range[1] is None:
            f2 = 0.5 * adp.sampling_rate
        else:
            f2 = adp.frequency_range[1]

        title += f", {f1:.2e} - {f2:.2e} Hz"

    ax.set_title(title)

    ax.set_xlabel("Time [μs]", size="large")
    ax.set_ylabel("Amplitude [AU]", size="large")

    ax.grid()
    ax.set_axisbelow(True)

    ax.margins(x=0)

    ax.legend()

    if tws is not None:
        # atleast_2d: a single [start, end] window is handled like a stack of one
        windows = np.atleast_2d(adp.seconds_to_tw(tws)).astype(int)

        for tw in windows:
            first, last = tw[0], tw[-1]
            lo, hi = min(first, last), max(first, last)

            # Box height: peak |amplitude| of every plotted trace over all
            # samples inside the window, padded by 10%
            h = 1.1 * np.amax(np.abs(traces[:, lo : hi + 1]))

            ax.add_patch(
                Rectangle(
                    (t[first], -h),
                    t[last] - t[first],
                    2 * h,
                    linewidth=1,
                    edgecolor="r",
                    facecolor="none",
                )
            )

    tstart = t[0] if tlim[0] is None else np.amax([t[0], tlim[0]])
    tend = t[-1] if tlim[-1] is None else np.amin([t[-1], tlim[-1]])

    # Fit the y-range to every plotted trace within the displayed time span,
    # padded by 10% away from zero
    span = np.atleast_1d(np.squeeze(adp.seconds_to_tw([tstart, tend]))).astype(int)
    lo, hi = min(span[0], span[-1]), max(span[0], span[-1])
    visible = traces[:, lo : hi + 1]

    top = np.amax(visible)
    bottom = np.amin(visible)
    ymax = (1 + 0.1 * np.sign(top)) * top
    ymin = (1 - 0.1 * np.sign(bottom)) * bottom

    ax.set_xlim(tstart, tend)
    ax.set_ylim(ymin, ymax)

    return ax


def plot_sensor_map(
    pair: str = "08-06",
    coord: list = [(0, 0), (0, 0)],
    center_id: int = 3,
    a: int = 4,
    r: float = 16 / 9,
    ax=None,
):
    """
    Draw the transducer layout with the source-receiver path of ``pair``
    highlighted.

    Each transducer is drawn as an ellipse labelled with its 1-based number.
    The pulsing transducer is red, the receiver green, all others blue. A
    dashed line joins the pulser and the receiver. Columns are shifted
    (modulo 8) so the pulsing transducer is drawn in column ``center_id``.

    :param pair: string whose first two and last two characters are the
        1-based numbers of the pulsing and receiving transducers (e.g. "08-06")
    :param coord: ``(row, column)`` position of each transducer, indexed by
        transducer number - 1. The list is not modified. NOTE: the default
        (two entries) is too short for the default pair "08-06"; pass the
        actual coordinates.
    :param center_id: column in which the pulsing transducer is drawn
    :param a: figure height in inches when a new figure is created; also
        scales line widths, font size and markers
    :param r: width/height ratio of a newly created figure
    :param ax: Axes to draw on; a new figure is created when ``None``
    :returns: the Axes used for plotting
    """
    pulseind = int(pair[:2]) - 1
    receiverind = int(pair[-2:]) - 1

    # Center map around pulsing transducer. Build a new list: rewriting
    # ``coord`` in place would silently alter the caller's data.
    s = center_id - coord[pulseind][1]
    coord = [(c[0], (c[1] + s) % 8) for c in coord]

    Ntrans = len(coord)

    xscale = 45  # in degrees
    yscale = 1  # in centimeters

    xcoord = [c[1] * xscale - xscale * center_id for c in coord]
    ycoord = [(6 - (c[0]) + 2) * yscale for c in coord]

    xpul = xcoord[pulseind]
    ypul = ycoord[pulseind]

    xrec = xcoord[receiverind]
    yrec = ycoord[receiverind]

    if ax is None:
        fig = plt.figure(
            figsize=(r * a, a), dpi=180, facecolor="white", edgecolor="none"
        )
        ax = fig.add_subplot()

    # Source-receiver path
    ax.plot(
        (xpul, xrec),
        (ypul, yrec),
        color="black",
        linestyle="--",
        linewidth=a / 2,
        zorder=0.75,
    )

    circ_xscale = 2 * 0.25 * xscale
    circ_yscale = 2 * 0.35 * r * yscale

    for k in range(Ntrans):
        # Pulser takes precedence if pulser and receiver are the same
        if k == pulseind:
            facecolor = "r"
        elif k == receiverind:
            facecolor = "g"
        else:
            facecolor = "b"

        circa = Ellipse(
            (xcoord[k], ycoord[k]),
            circ_xscale,
            circ_yscale,
            facecolor=facecolor,
            linestyle="-",
            linewidth=a / 2,
            edgecolor="black",
        )

        ax.text(
            xcoord[k],
            ycoord[k],
            f"{k+1:02d}",
            color="white",
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=3 * a,
        )

        ax.add_patch(circa)

    xticks = np.linspace(-360, 360, 17)

    ax.set_xticks(xticks)
    ax.set_xlim(-center_id * xscale - 10, (7 - center_id) * xscale + 10)
    ax.set_xlabel("Angle [°]")

    ax.set_ylim(0, 10 * yscale)
    ax.set_yticks(np.linspace(0, 10 * yscale, 11))
    ax.set_ylabel("Height [cm]")

    title = f"Transducer map for source-receiver pair {pair}"

    ax.set_title(title)

    ax.grid()

    ax.set_axisbelow(True)

    return ax


def plot_fit_results(
    coefs,
    t,
    dvval,
    components,
    pair,
    theta,
    Nref,
    ebounds=None,
    tau_max: float = 1.5e4,
    r: float = 4 / 3,
    a: float = 7.5,
    axd=None,
):
    """
    Plot the results of inverting a set of DV/V data for linear trend, classical
    and nonclassical parameters.

    Three panels are drawn: the data with the fitted model ("DV"), the inverted
    coefficients with error bars ("Parameters"), and the residual histogram with
    its closest Gaussian ("Residuals").

    :param coefs: the three inversion coefficients (alpha, beta, delta)
    :param t: time axis of the data (labelled in seconds)
    :param dvval: measured DV/V values
    :param components: mapping with the model components ``"lin_trend"``,
        ``"classical"`` and ``"R"``; the model is their coefficient-weighted sum
    :param pair: source-receiver pair name, used in the figure title
    :param theta: angle in degrees, used in the figure title
    :param Nref: 0 for the direct P arrival, 1 for the reflected P arrival
    :param ebounds: ``(2, 3)`` error bars for the coefficients; computed with
        ``mtsi.compute_confidence_interval`` when ``None``
    :param tau_max: value shown in the "DV" panel title
    :param r: width/height ratio of a newly created figure
    :param a: figure height in inches when a new figure is created
    :param axd: mapping of Axes with keys "DV", "Parameters" and "Residuals";
        a new figure is created when ``None``
    :returns: the mapping of Axes
    """
    if axd is None:
        fig, axd = plt.subplot_mosaic(
            [["DV", "DV"], ["Parameters", "Residuals"]],
            layout="constrained",
            figsize=(r * a, a),
        )
    else:
        # Caller-supplied axes: take the owning figure from them so that the
        # suptitle below has a figure to attach to
        fig = axd["DV"].get_figure()

    L = components["lin_trend"]
    C = components["classical"]
    R = components["R"]

    model = coefs[0] * L + coefs[1] * C + coefs[2] * R

    refornot = {0: "direct P arrival", 1: "reflected P arrival"}

    fig.suptitle(f"Best fit for {pair}, {theta:.0f}°, {refornot[Nref]}")

    # DV AND MODEL
    axd["DV"].set_title(r"$\tau_{max}$" + f"={tau_max:.3e}")

    axd["DV"].plot(t, model, color="#f05039", linewidth=2)
    axd["DV"].scatter(t, dvval, color="#1f449c", s=0.75, alpha=0.85, zorder=3)

    axd["DV"].set_xlabel("Time [s]")
    axd["DV"].set_ylabel("DV/V [%]")
    axd["DV"].yaxis.set_major_formatter(ticker.FormatStrFormatter("%.1e"))
    axd["DV"].grid()

    # FITTING PARAMETERS
    axd["Parameters"].set_title("Inversion coefficients")

    coefficients_str = (r"$\alpha$", r"$\beta$", r"$\delta$")
    y_pos = np.arange(len(coefficients_str))

    # Display originally negative coefficients in red
    color = np.full_like(coefs, "#1f449c", dtype=object)
    color[np.sign(coefs) < 0] = "#f05039"

    if ebounds is None:
        ebounds = np.empty((2, len(coefs)))

        for k, which_coef in enumerate(["alpha", "beta", "delta"]):
            ebounds[:, k] = mtsi.compute_confidence_interval(
                which_coef, coefs, dvval, components
            )

    hbars = axd["Parameters"].barh(
        y_pos,
        coefs,
        align="center",
        color=color,
        xerr=ebounds,
        capsize=5,
        alpha=0.75,
    )

    # Label the bars
    axd["Parameters"].bar_label(
        hbars, fmt="%.2e", label_type="edge", color="black", padding=5
    )

    axd["Parameters"].set_xlabel("Coefficient value")
    axd["Parameters"].set_xlim(left=-2e-2, right=7e-2)  # adjust xlim to fit
    axd["Parameters"].xaxis.grid(True, zorder=-1)

    axd["Parameters"].set_yticks(y_pos, labels=coefficients_str)
    axd["Parameters"].invert_yaxis()  # labels read top-to-bottom

    axd["Parameters"].set_axisbelow(True)

    residual_vec = dvval - model

    # RESIDUALS DISTRIBUTION

    # Excess (Fisher) kurtosis of the residuals, tested against a normal distribution
    kurt = sstats.kurtosis(residual_vec)
    res = sstats.kurtosistest(residual_vec)

    # Parameters for the histogram
    Nbins = 50
    bins = np.linspace(np.amin(residual_vec), np.amax(residual_vec), Nbins)

    # Fitting the gaussian curve
    mu, std = sstats.norm.fit(residual_vec)
    x = np.linspace(
        1.2 * np.amin(residual_vec), 1.2 * np.amax(residual_vec), 10 * Nbins
    )
    g = sstats.norm.pdf(x, mu, std)

    # Plotting the histogram
    axd["Residuals"].set_title(
        f"\nClosest gaussian: μ = {mu:.2e}, σ = {std:.2e} \nExcess kurtosis: {kurt:.2f}, p-value: {res.pvalue:.2e}"
    )

    axd["Residuals"].hist(residual_vec, color="#1f449c", bins=bins, density=True)
    axd["Residuals"].plot(x, g, color="#f05039", linewidth=2, alpha=0.75)

    axd["Residuals"].grid()
    axd["Residuals"].set_axisbelow(True)

    # Make the residual x-axis symmetric around zero
    xlim = axd["Residuals"].get_xlim()
    half_width = np.abs(max(xlim[1], -xlim[0]))
    axd["Residuals"].set_xlim([-half_width, half_width])

    return axd


def plot_angle_dep(
    which_coef: str,
    mfile,
    r: float = 4 / 3,
    a: float = 4,
    plot_data: bool = True,
    plot_means: bool = True,
    ax=None,
    cax=None,
    dir_and_ref=True,
):
    """
    Plot one inversion coefficient against the angle theta.

    Individual data points can be scattered, coloured by their uncertainty.
    Averages over 1° bins between 39° and 90° can be overlaid with error bars.

    :param which_coef: ``"alpha"``, ``"beta"`` or ``"delta"``
    :param mfile: mapping (e.g. a loaded MATLAB file) with the keys
        ``which_coef``, ``"thetas"``, ``"<which_coef>_bounds"`` (row 1 is used
        as the uncertainty) and, when ``dir_and_ref`` is True, ``"Nrefs"``
        (0 = direct, 1 = reflected)
    :param r: width/height ratio of a newly created figure
    :param a: figure height in inches when a new figure is created; also
        scales marker sizes
    :param plot_data: scatter the individual values with an uncertainty colorbar
    :param plot_means: overlay the per-bin means with error bars
    :param ax: Axes to draw on; a new figure is created when ``None``
    :param cax: Axes for the colorbar; when ``None`` space is taken from ``ax``
    :param dir_and_ref: mark direct (circles) and reflected (stars) arrivals
        differently
    :raises ValueError: if ``which_coef`` is not one of the accepted names
    :returns: the Axes used for plotting
    """
    plotval = mfile[which_coef]
    thetas = mfile["thetas"]

    if which_coef == "alpha":
        strcoef = r"$\alpha$"
        ebounds = mfile["alpha_bounds"]

    elif which_coef == "beta":
        strcoef = r"$\beta$"
        ebounds = mfile["beta_bounds"]

    elif which_coef == "delta":
        strcoef = r"$\delta$"
        ebounds = mfile["delta_bounds"]

    else:
        raise ValueError("which_coef must be 'alpha', 'beta' or 'delta'")

    if ax is None:
        fig, ax = plt.subplots(figsize=(r * a, a), layout="constrained")
        fig.suptitle("Angle dependency for " + strcoef)
    else:
        ax.set_title("Angle dependency for " + strcoef)

    # Plotting the averaged velocity change per angle bin (51 bins of 1°)
    angle_counts, bins = np.histogram(thetas, 51, range=(39, 90))
    nbins = angle_counts.shape[0]
    occupied = angle_counts != 0

    means = np.zeros((nbins,))
    stds = np.zeros((nbins,))
    avg_angles = np.zeros((nbins,))

    med_err = np.median(ebounds[1, :])

    for kbin in np.flatnonzero(occupied):
        lo = bins[kbin]
        hi = bins[kbin + 1]

        # Same membership rule as np.histogram ([lo, hi), last bin closed), so
        # every selection holds exactly angle_counts[kbin] points, duplicates
        # included
        below_hi = thetas <= hi if kbin == nbins - 1 else thetas < hi
        in_bin = (thetas >= lo) & below_hi

        bin_vals = plotval[in_bin]

        # Arithmetic mean; the error shrinks with the number of points in the bin
        means[kbin] = np.average(bin_vals)
        stds[kbin] = (med_err / 2) / np.sqrt(bin_vals.shape[0])
        avg_angles[kbin] = np.average(thetas[in_bin])

    means = means[occupied]
    stds = stds[occupied]
    avg_angles = avg_angles[occupied]

    if plot_means:

        ax.scatter(
            avg_angles,
            means,
            color="white",
            edgecolors="black",
            s=7 * a,
            zorder=3,
            marker="D",
        )

        ax.errorbar(
            avg_angles,
            means,
            yerr=2 * stds,
            ecolor="black",
            capsize=2,
            fmt="none",
        )

    if plot_data:
        cmap = mpl.colormaps["plasma"]
        norm = mpl.colors.Normalize(0, 1.5e-2)

        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        if cax is None:
            plt.colorbar(sm, ax=ax, label="Uncertainty")
        else:
            plt.colorbar(sm, cax=cax, label="Uncertainty")

        colors = cmap(norm(ebounds[1, :]))

        if dir_and_ref:
            dinds = mfile["Nrefs"] == 0
            rinds = mfile["Nrefs"] == 1

            ax.scatter(
                thetas[dinds],
                plotval[dinds],
                marker="o",
                color=colors[dinds],
                s=10 * a,
                zorder=2,
            )
            ax.scatter(
                thetas[rinds],
                plotval[rinds],
                marker="*",
                color=colors[rinds],
                s=10 * a,
                zorder=2,
            )
        else:
            ax.scatter(thetas, plotval, marker=".", color=colors, s=10 * a, zorder=2)

    ax.set_facecolor("white")

    ax.set_xlabel(r"$\theta$" + " [°]")
    ax.set_ylabel(strcoef)

    # Axis limits
    ax.set_ylim([-0.01, 0.05])

    return ax


def plot_MC_fit_full(
    which_coef,
    fit_type,
    MC_res,
    r: float = 20 / 9,
    a: float = 4.5,
    axd=None,
):
    """
    Plot a Monte-Carlo anisotropy fit: every iteration's model over the binned
    means ("Data" panel), plus histograms of the fitted parameters u ("Fit1")
    and v ("Fit2") with their averages marked.

    :param which_coef: ``"alpha"``, ``"beta"`` or ``"delta"`` (used for labels)
    :param fit_type: ``"Thomsen"`` or ``"Linear"`` (used for labels and limits)
    :param MC_res: tuple ``(fit_params, vec, angles, means, means_cfi)`` where
        ``fit_params`` has u in row 0 and v in row 1, one column per iteration,
        and the model is ``u * vec + v``
    :param r: width/height ratio of a newly created figure
    :param a: figure height in inches when a new figure is created
    :param axd: mapping of Axes with keys "Data", "Fit1" and "Fit2"; a new figure
        is created when ``None``
    :raises ValueError: if ``which_coef`` or ``fit_type`` is not recognised
    :returns: the mapping of Axes
    """
    coef_labels = {"alpha": r"$\alpha$", "beta": r"$\beta$", "delta": r"$\delta$"}
    fit_labels = {
        "Thomsen": r"$\mathrm{u cos^2(\theta) + v}$",
        "Linear": r"$\mathrm{u (90 - \theta) + v}$",
    }

    fit_params, vec, angles, means, means_cfi = MC_res
    N_MC = fit_params.shape[1]

    # Fail early with a clear message instead of an UnboundLocalError later on
    if which_coef not in coef_labels:
        raise ValueError("which_coef must be 'alpha', 'beta' or 'delta'")
    if fit_type not in fit_labels:
        raise ValueError("fit_type must be 'Thomsen' or 'Linear'")

    strcoef = coef_labels[which_coef]
    strfit = fit_labels[fit_type]

    if axd is None:
        fig, axd = plt.subplot_mosaic(
            [["Data", "Data", "Fit1"], ["Data", "Data", "Fit2"]],
            figsize=(r * a, a),
            layout="tight",
        )

    # One faint model curve per Monte-Carlo iteration
    axd["Data"].plot(
        angles,
        np.outer(vec, fit_params[0, :]) + fit_params[1, :],
        color="orange",
        linewidth=0.1,
        alpha=0.1,
        zorder=-1,
    )

    u, v = np.average(fit_params, 1)

    # Data plot
    axd["Data"].set_title(
        f"{fit_type} fit for " + strcoef + ": " + strfit + f" ({N_MC} runs)"
    )
    axd["Data"].scatter(
        angles, means, marker="D", color="white", edgecolors="black", zorder=5
    )
    axd["Data"].errorbar(
        angles, means, means_cfi, fmt="None", color="black", capsize=2, zorder=-1
    )

    # Average model
    axd["Data"].plot(angles, u * vec + v, color="red")

    axd["Data"].set_xlabel("Angle [°]")
    axd["Data"].set_ylabel(strcoef)
    axd["Data"].set_ylim(0, 5e-2)

    # Parameter 1

    axd["Fit1"].set_title("u")
    h1 = axd["Fit1"].hist(fit_params[0, :], 20, color="blue", alpha=0.3)
    axd["Fit1"].vlines(
        u, ymin=0, ymax=np.amax(h1[0]), linestyle=":", color="blue", linewidth=2
    )
    axd["Fit1"].xaxis.set_major_formatter(ticker.FormatStrFormatter("%.1e"))
    axd["Fit1"].tick_params(axis="x", labelrotation=45)

    # Parameter 2

    axd["Fit2"].set_title("v")
    h2 = axd["Fit2"].hist(fit_params[1, :], 20, color="orange", alpha=0.3)
    axd["Fit2"].vlines(
        v,
        ymin=0,
        ymax=np.amax(h2[0]),
        linestyle=":",
        color="orange",
        linewidth=2,
    )
    axd["Fit2"].xaxis.set_major_formatter(ticker.FormatStrFormatter("%.1e"))
    axd["Fit2"].tick_params(axis="x", labelrotation=45)

    if fit_type == "Linear":
        axd["Fit1"].set_xlim([0, 6e-4])
        axd["Fit2"].set_xlim([-1e-2, 3e-2])

    elif fit_type == "Thomsen":
        axd["Fit1"].set_xlim([0, 5e-2])
        axd["Fit2"].set_xlim([0, 3e-2])

    return axd


def plot_MC_fit_means(
    which_coef,
    fit_type,
    MC_res,
    r: float = 16 / 9,
    a: float = 4.5,
    ax=None,
):
    """
    Plot every Monte-Carlo iteration's anisotropy model, the average model and
    the binned means with their error bars on a single Axes.

    :param which_coef: ``"alpha"``, ``"beta"`` or ``"delta"`` (used for labels)
    :param fit_type: name of the fitted law; accepted for signature symmetry
        with ``plot_MC_fit_full`` but not used here
    :param MC_res: tuple ``(fit_params, vec, angles, means, means_cfi)`` where
        ``fit_params`` has u in row 0 and v in row 1, one column per iteration,
        and the model is ``u * vec + v``
    :param r: width/height ratio of a newly created figure
    :param a: figure height in inches when a new figure is created
    :param ax: Axes to draw on; a new figure is created when ``None``
    :raises ValueError: if ``which_coef`` is not recognised
    :returns: the Axes used for plotting
    """
    coef_labels = {"alpha": r"$\alpha$", "beta": r"$\beta$", "delta": r"$\delta$"}

    fit_params, vec, angles, means, means_cfi = MC_res

    # Fail early with a clear message instead of an UnboundLocalError later on
    if which_coef not in coef_labels:
        raise ValueError("which_coef must be 'alpha', 'beta' or 'delta'")
    strcoef = coef_labels[which_coef]

    if ax is None:
        fig, ax = plt.subplots(
            1,
            1,
            figsize=(r * a, a),
            layout="tight",
        )

    # One faint model curve per Monte-Carlo iteration
    ax.plot(
        angles,
        np.outer(vec, fit_params[0, :]) + fit_params[1, :],
        color="orange",
        linewidth=0.1,
        alpha=0.1,
        zorder=-1,
    )

    u, v = np.average(fit_params, 1)

    # Data plot
    ax.set_title(f"{strcoef}")
    ax.scatter(angles, means, marker="D", color="white", edgecolors="black", zorder=5)
    ax.errorbar(
        angles, means, means_cfi, fmt="None", color="black", capsize=2, zorder=-1
    )

    # Average model
    ax.plot(angles, u * vec + v, color="red")

    ax.set_xlabel("Angle [°]")
    ax.set_ylabel(strcoef)
    ax.set_ylim(0, 5e-2)

    return ax


def plot_anisotropy_fit_direct(
    res,
    r: float = 16 / 9,
    a: float = 4.5,
    ax=None,
):
    """
    Plotting function for the inverted anisotropy law for a given dataset. Plots all
    individual iterations as well as the average model.

    :param res: The results from the inversion (bootstrap or Monte-Carlo), as a tuple
        ``(fit_params, vec, thetas)``. ``fit_params`` holds u in row 0 and v in
        row 1 (one column per iteration). The model evaluated at each angle in
        ``thetas`` is ``u * vec + v``.
    :type res: tuple

    :param r: width/height ratio of a newly created figure
    :type r: float

    :param a: figure height in inches when a new figure is created
    :type a: float

    :param ax: Axes object on which to plot. If no object is passed, a new figure
        is created.
    :type ax: matplotlib.axes

    :returns: ``(ax, anres)``. ``ax`` is the Axes object on which everything was
        plotted. ``anres`` is ``(u, sigma_u, v, sigma_v)``: the mean of each
        parameter over all iterations and its standard deviation across
        iterations.
    :rtype: tuple
    """
    fit_params, vec, thetas = res
    fit_params = np.asarray(fit_params)
    vec = np.asarray(vec)
    thetas = np.asarray(thetas)

    # Draw the curves in increasing angle order
    inds = np.argsort(thetas)
    sorted_thetas = thetas[inds]
    sorted_vec = vec[inds]

    if ax is None:
        fig, ax = plt.subplots(
            1,
            1,
            figsize=(r * a, a),
            layout="tight",
        )

    # Plot the individual iteration models
    ax.plot(
        sorted_thetas,
        np.outer(sorted_vec, fit_params[0, :]) + fit_params[1, :],
        color="orange",
        linewidth=0.1,
        alpha=0.1,
        zorder=-1,
    )

    # Plot the average model
    u, v = np.average(fit_params, 1)
    uerr, verr = np.std(fit_params, 1)

    anres = (u, uerr, v, verr)

    ax.plot(sorted_thetas, u * sorted_vec + v, color="red")

    ax.set_xlabel("Angle [°]")
    ax.set_ylim(-1e-2, 5e-2)

    return ax, anres
