import numpy as np
import mts_cwi.inversion as mtsi
import mts_cwi.plots as mtsp

from matplotlib import pyplot as plt


"""
HELPER FUNCTIONS FOR THE SUPPLEMENTARY DATA

These functions are meant to be companion functions to help navigate the code and data associated with the manuscript
"""


def load_etas(data_dir="../data"):
    """
    Load the relative velocity change data and the inversion results from the provided files

    Two comma-separated files are read from ``data_dir``:

    - ``etas.csv``: parsed as a float64 array and returned unchanged
    - ``inversion_data.csv``: one header line (skipped) followed by one row per entry; the
      first six columns are read in this order: pair, theta, reflected, alpha, beta, delta

    :param data_dir: Directory containing the two CSV files. The default, ``"../data"``, is
        resolved relative to the current working directory
    :type data_dir: str

    :return: The tuple ``(etas, pairs, thetas, reflected, alphas, betas, deltas, Nproc)``.
        ``pairs`` is an object array, the five other per-row vectors are float arrays, all of
        length ``Nproc``, the number of data rows found in ``inversion_data.csv``
    :rtype: tuple
    """

    etas = np.genfromtxt(f"{data_dir}/etas.csv", dtype=np.float64, delimiter=",")

    inversion_data = np.genfromtxt(
        f"{data_dir}/inversion_data.csv",
        dtype=None,
        delimiter=",",
        skip_header=1,
        encoding=None,
    )

    # genfromtxt squeezes its result: a file holding a single data row comes back
    # as a 0-d array, which has no shape[0] and cannot be indexed row by row.
    # Promote it to 1-D so a one-row file behaves like any other.
    inversion_data = np.atleast_1d(inversion_data)

    Nproc = inversion_data.shape[0]

    pairs = np.empty(Nproc, dtype="object")
    thetas = np.empty(Nproc)
    reflected = np.empty(Nproc)
    alphas = np.empty(Nproc)
    betas = np.empty(Nproc)
    deltas = np.empty(Nproc)

    # Output vectors listed in the same order as the CSV columns they receive
    columns = (pairs, thetas, reflected, alphas, betas, deltas)

    for proc, row in enumerate(inversion_data):
        for col_idx, column in enumerate(columns):
            column[proc] = row[col_idx]

    return (etas, pairs, thetas, reflected, alphas, betas, deltas, Nproc)


def plot_components(comps, t, tjs, r=16 / 9, w=20, fs=10, ax=None):
    """
    Plot the three individual components of the linear inversion on a single Axes

    :param comps: Mapping holding the component vectors under the keys ``"lin_trend"``,
        ``"classical"`` and ``"R"``
    :type comps: dict

    :param t: Time vector used as abscissa for every component
    :type t: np.ndarray

    :param tjs: Loading change times, forwarded to ``mtsp.compressed_shade``
    :type tjs: np.ndarray

    :param r: Display ratio (width / height) of the figure created when ``ax`` is None
    :type r: float

    :param w: Display width in centimetres of the figure created when ``ax`` is None
    :type w: float

    :param fs: Font size of the legend
    :type fs: int

    :param ax: Axes on which to plot the components; a new figure is created when None
    :type ax: AxesObject

    :return: The Axes object on which the components were plotted
    :rtype: AxesObject
    """

    # Matplotlib expects inches, the caller provides centimetres
    width_in = w / 2.54
    height_in = width_in / r

    if ax is None:
        _, ax = plt.subplots(
            figsize=(width_in, height_in), layout="constrained", facecolor="none"
        )

    # Fetch every component up front so that a missing key fails before any curve is drawn
    curves = {key: comps[key] for key in ("lin_trend", "classical", "R")}

    ax.set_title("Individual components for the linear inversion")

    # (component key, colour, line style, legend label), listed in drawing order
    curve_styles = (
        ("lin_trend", "#840032", "dashed", r"$L(t)$: Linear component"),
        ("R", "#e59500", "solid", r"$R(t)$: Nonclassical component"),
        ("classical", "#002642", ":", r"$C(t)$: Classical component"),
    )

    for key, colour, line_style, legend_label in curve_styles:
        ax.plot(
            t,
            curves[key],
            color=colour,
            linewidth=3,
            linestyle=line_style,
            label=legend_label,
        )

    ax.set_xlabel("Time [s]")
    ax.set_ylabel("Normalized amplitude")
    ax.grid(zorder=-1)

    ax.legend(fontsize=fs)

    mtsp.compressed_shade(ax, tjs=tjs)

    return ax


def plot_example_eta(
    eta, pair, theta, reflected, t, tjs, r=16 / 9, w=20, fs=10, ax=None
):
    """
    Scatter-plot the eta measurements of one sensor combination against time

    :param eta: Vector containing the velocity measurements, plotted against ``t``
    :type eta: array-like (presumably np.ndarray, like ``t``)

    :param pair: Sensor combination, shown in the title
    :type pair: str

    :param theta: Angle of the sensor combination, shown in the title without decimals
    :type theta: float

    :param reflected: Whether the wave undergoes a reflection or not; 0/False is titled
        "direct wave", 1/True "reflected wave", any other value raises ``KeyError``
    :type reflected: bool

    :param t: Time vector
    :type t: np.ndarray

    :param tjs: Loading change times, forwarded to ``mtsp.compressed_shade``
    :type tjs: np.ndarray

    :param r: Display ratio (width / height) of the figure created when ``ax`` is None
    :type r: float

    :param w: Display width in centimetres of the figure created when ``ax`` is None
    :type w: float

    :param fs: Font size of the axis labels
    :type fs: int

    :param ax: Axes on which to plot the measurements; a new figure is created when None
    :type ax: AxesObject

    :return: The Axes object on which the measurements were plotted
    :rtype: AxesObject
    """

    # Matplotlib expects inches, the caller provides centimetres
    width_in = w / 2.54
    height_in = width_in / r

    if ax is None:
        _, ax = plt.subplots(
            figsize=(width_in, height_in), layout="constrained", facecolor="none"
        )

    wave_labels = {0: "direct wave", 1: "reflected wave"}
    ax.set_title(f"Combination {pair}: θ={theta:.0f}°, {wave_labels[reflected]}")

    # Appearance of the measurement markers
    marker_style = dict(color="#1f449c", marker=".", s=5, zorder=2, alpha=1)
    ax.scatter(t, eta, **marker_style)

    ax.set_xlabel("Time [s]", fontsize=fs)
    ax.set_ylabel(r"$\eta$ [%]", fontsize=fs)

    ax.grid(zorder=-1)

    # Ten ticks spaced by 0.01, starting at -0.05
    y_ticks = np.arange(10) * 0.01 - 0.05
    ax.set_yticks(y_ticks)

    mtsp.compressed_shade(ax, tjs=tjs)

    return ax


def plot_example_fit(
    eta,
    alpha,
    beta,
    delta,
    pair,
    theta,
    reflected,
    t,
    tjs,
    r=16 / 9,
    w=20,
    fs=10,
    ax=None,
):
    """
    Scatter-plot the eta measurements of one sensor combination and overlay the model fit

    The fit is ``alpha * lin_trend + beta * classical + delta * R``, where the three
    components come from ``mtsi.create_components(t, tjs, cumul=False)``.

    :param eta: Vector containing the velocity measurements, plotted against ``t``
    :type eta: array-like (presumably np.ndarray, like ``t``)

    :param alpha: Linear component coefficient
    :type alpha: float

    :param beta: Classical component coefficient
    :type beta: float

    :param delta: Nonclassical component coefficient
    :type delta: float

    :param pair: Sensor combination, shown in the title
    :type pair: str

    :param theta: Angle of the sensor combination, shown in the title without decimals
    :type theta: float

    :param reflected: Whether the wave undergoes a reflection or not; 0/False is titled
        "direct wave", 1/True "reflected wave", any other value raises ``KeyError``
    :type reflected: bool

    :param t: Time vector
    :type t: np.ndarray

    :param tjs: Loading change times, used to build the components and forwarded to
        ``mtsp.compressed_shade``
    :type tjs: np.ndarray

    :param r: Display ratio (width / height) of the figure created when ``ax`` is None
    :type r: float

    :param w: Display width in centimetres of the figure created when ``ax`` is None
    :type w: float

    :param fs: Font size of the axis labels
    :type fs: int

    :param ax: Axes on which to plot the measurements; a new figure is created when None
    :type ax: AxesObject

    :return: The Axes object on which the measurements and the fit were plotted
    :rtype: AxesObject
    """
    # Build the model curve first, before any figure is created
    components = mtsi.create_components(t, tjs, cumul=False)

    linear_part = components["lin_trend"]
    classical_part = components["classical"]
    relaxation_part = components["R"]

    model_fit = alpha * linear_part + beta * classical_part + delta * relaxation_part

    # Matplotlib expects inches, the caller provides centimetres
    width_in = w / 2.54
    height_in = width_in / r

    if ax is None:
        _, ax = plt.subplots(
            figsize=(width_in, height_in), layout="constrained", facecolor="none"
        )

    wave_labels = {0: "direct wave", 1: "reflected wave"}
    ax.set_title(f"Combination {pair}: θ={theta:.0f}°, {wave_labels[reflected]}")

    # DV measurements
    marker_style = dict(color="#1f449c", marker=".", s=5, zorder=2, alpha=1)
    ax.scatter(t, eta, **marker_style)

    # Model fit
    ax.plot(t, model_fit, color="#f05039")

    ax.set_xlabel("Time [s]", fontsize=fs)
    ax.set_ylabel(r"$\eta$ [%]", fontsize=fs)

    ax.grid(zorder=-1)

    # Ten ticks spaced by 0.01, starting at -0.05
    y_ticks = np.arange(10) * 0.01 - 0.05
    ax.set_yticks(y_ticks)

    mtsp.compressed_shade(ax, tjs=tjs)

    return ax
