import datetime as dtt
import matplotlib.dates as mdates
import numpy as np
import os
import scipy.signal as ssig
import yaml
import time

from copy import deepcopy
from importlib import reload
from matplotlib import pyplot as plt
from obspy.core import Stats, UTCDateTime
from scipy.io import matlab
from scipy.interpolate import interp1d


def compute_eps_steps(stretch_range: float, stretch_steps: int):
    """
    Computes a vector containing the stretching coefficients epsilon that will be tested.

    The values are evenly spaced over ``[-stretch_range, stretch_range]``. Because
    ``stretch_steps`` must be odd, the central value is forced to be exactly 0 so
    that the "no stretching" hypothesis is always tested. ``np.linspace`` alone
    only guarantees the two endpoints, and with a single step it would return
    ``-stretch_range``.

    :param stretch_range: Maximum stretching amplitude that will be tested. Should be positive.
    :type stretch_range: float
    :param stretch_steps: Number of stretching steps. Should be a positive odd integer
        to include 0 stretching. NumPy integer scalars are accepted.
    :type stretch_steps: int
    :raises ValueError: If ``stretch_range`` is negative, or ``stretch_steps`` is not
        a positive odd integer.
    :raises TypeError: If ``stretch_steps`` is not an integer.
    :return: 1D array of length ``stretch_steps`` holding the stretching coefficients.
    :rtype: np.ndarray
    """

    if stretch_range < 0:
        raise ValueError("stretch_range needs to be positive")

    # np.integer scalars (e.g. produced by array arithmetic) are valid integers too.
    if not isinstance(stretch_steps, (int, np.integer)):
        raise TypeError("stretch_steps should be an integer")

    # The explicit positivity test matters: in Python -1 % 2 == 1, so negative
    # values would otherwise pass the parity check.
    if stretch_steps < 1 or stretch_steps % 2 != 1:
        raise ValueError("stretch_steps should be odd to account for 0 stretching")

    eps_steps = np.linspace(-stretch_range, stretch_range, stretch_steps)

    # Pin the centre sample to exactly zero (floating-point linspace may leave a
    # tiny residual there, and for stretch_steps == 1 it would be -stretch_range).
    eps_steps[stretch_steps // 2] = 0.0

    return eps_steps


def stretch_reference(ref_trace: np.ndarray, eps_steps: np.ndarray) -> np.ndarray:
    """
    Function generating the stretched versions of a given reference
    trace for each of the given stretching values. Returns
    a matrix whose rows are each a different stretched version of
    the initial reference trace.

    The stretching is done as g(t) = f((1+eps)t), where t is the sample index
    and f is a cubic interpolant of the reference trace. For eps > 0 the last
    samples of g fall beyond the end of the reference and are obtained by
    cubic extrapolation.

    :param ref_trace: Vector containing the chosen reference trace to be stretched
    :type ref_trace: np.ndarray
    :param eps_steps: Vector with the stretching values to be tested. Positive stretching coefficient means faster wave.
    :type eps_steps: np.ndarray
    :return: Array of shape ``(len(eps_steps), len(ref_trace))`` whose k-th row is
        ``ref_trace`` stretched by ``eps_steps[k]``.
    :rtype: np.ndarray
    """

    ref_trace = np.asarray(ref_trace)
    eps_steps = np.asarray(eps_steps, dtype=float)

    t = np.arange(ref_trace.shape[0])

    # The interpolant depends only on the reference trace, so build it once
    # rather than once per stretching value.
    stretch_fun = interp1d(t, ref_trace, "cubic", fill_value="extrapolate")

    stretched_ref_trace_vec = np.zeros((eps_steps.shape[0], ref_trace.shape[0]))

    for k, eps in enumerate(eps_steps):
        stretched_ref_trace_vec[k, :] = stretch_fun((1 + eps) * t)

    return stretched_ref_trace_vec


def stretch_single(ref_trace: np.ndarray, eps: float) -> np.ndarray:
    """
    Function generating a single stretched version of a given reference
    trace given a stretching value. Returns the stretched trace.

    The stretching is done as g(t) = f((1+eps)t), where t is the sample index
    and f is a cubic interpolant of the reference trace. Samples that map
    outside the reference (e.g. the tail for eps > 0) are obtained by cubic
    extrapolation.

    :param ref_trace: Vector containing the chosen reference trace to be stretched
    :type ref_trace: np.ndarray
    :param eps: Signed stretching coefficient. Positive stretching coefficient means faster wave.
    :type eps: float
    :return: Stretched trace, same length as ``ref_trace``.
    :rtype: np.ndarray
    """

    ref_trace = np.asarray(ref_trace)

    t = np.arange(ref_trace.shape[0])

    stretch_fun = interp1d(t, ref_trace, "cubic", fill_value="extrapolate")

    return stretch_fun((1 + eps) * t)


def compute_velocity_change_stretch(
    traces: np.ndarray,
    tws: np.ndarray,
    stretched_ref_trace_vec: np.ndarray,
    eps_steps: np.ndarray,
    return_sim_mat: bool = False,
) -> dict:
    """
    Contributed by Peter Makus (https://github.com/PeterMakus/SeisMIC)

    Velocity change estimate through stretching and comparison.

    This function computes the velocity change over a selected time window (in samples) using the stretching method.
    The stretching is performed on the chosen reference trace, and then compared to the measurement.
    Therefore, we end up with eps = DV/V, as opposed to eps = -DV/V had it been the other way around.

    For every time window and every trace, the normalized correlation
    coefficient with each stretched reference is computed over the samples of
    the window only. The best-matching stretch is then selected.

    :param traces: 2D array containing the waveforms to analyse, one per row.
    :type traces: np.ndarray
    :param tws: Sequence of time windows. Each element is an array of sample
        indices (cast to int) selecting the samples used for that window.
    :type tws: np.ndarray of int
    :param stretched_ref_trace_vec: 2D array whose rows are stretched versions of
        the reference trace (for instance as produced by ``stretch_reference``).
        Must have as many columns as ``traces``.
    :type stretched_ref_trace_vec: np.ndarray
    :param eps_steps: Stretch amount of each row of ``stretched_ref_trace_vec``.
    :type eps_steps: np.ndarray or list
    :param return_sim_mat: Whether to also return the full similarity matrix.
    :type return_sim_mat: bool
    :raises ValueError: If ``eps_steps`` does not have one entry per row of
        ``stretched_ref_trace_vec``.

    :rtype: Dictionary
    :return: **dv**: Dictionary with the following keys

        *corr*: Correlation value of the best match for each trace and each
            time window. Shape ``(len(tws), traces.shape[0])`` before
            ``np.squeeze`` is applied.
        *value*: Stretch amount corresponding to the best match for each trace
            and each time window. Same shape as *corr*.
        *sim_mat*: ``None`` unless ``return_sim_mat`` is True, in which case it
            holds the correlation coefficients with every stretched reference.
            Shape ``(traces.shape[0], len(eps_steps), len(tws))`` before
            ``np.squeeze`` is applied.
        *stretch_vector*: The stretch vector used for the velocity change estimate.
        *time_windows*: The time windows used for the processing.

        If a trace (or a stretched reference) is identically zero inside a window,
        the normalization divides by zero and the corresponding correlations are NaN.
    """

    eps_steps = np.asarray(eps_steps)
    nstr = stretched_ref_trace_vec.shape[0]

    # Catch a mismatch early: otherwise the argmax index would be mapped onto
    # the wrong stretch value, or fail with an obscure indexing error.
    if eps_steps.shape[0] != nstr:
        raise ValueError(
            "eps_steps must have one entry per row of stretched_ref_trace_vec"
        )

    ntr, nsamp = traces.shape

    corr = np.zeros((len(tws), ntr))
    dt = np.zeros((len(tws), ntr))
    sim_mat = np.zeros((ntr, nstr, len(tws))) if return_sim_mat else None

    for ii, tw in enumerate(tws):
        # Selecting the window columns is equivalent to multiplying by a 0/1
        # mask (samples outside the window add nothing to the dot products or
        # norms), without building and multiplying full-size mask matrices.
        sel = np.zeros(nsamp, dtype=bool)
        sel[np.asarray(tw).astype(np.int32)] = True

        first = traces[:, sel]
        second = stretched_ref_trace_vec[:, sel]

        dprod = first @ second.T

        # Normalization: product of the L2 norms of every (trace, reference) pair.
        den = np.sqrt(np.outer(np.sum(first**2, axis=1), np.sum(second**2, axis=1)))

        tmp = dprod / den
        if sim_mat is not None:
            sim_mat[:, :, ii] = tmp

        corr[ii, :] = tmp.max(axis=1)
        dt[ii, :] = eps_steps[tmp.argmax(axis=1)]

    dv = {
        "corr": np.squeeze(corr),
        "value": np.squeeze(dt),
        "stretch_vector": eps_steps,
        "time_windows": tws,
        "sim_mat": np.squeeze(sim_mat) if sim_mat is not None else None,
    }

    return dv


def compute_shifts_interpolate(
    traces: np.ndarray,
    tws: np.ndarray,
    ref_trace: np.ndarray,
    shift_steps: int = 2001,
    shift_range: float = 10,
    return_sim_mat: bool = False,
) -> dict:
    """
    Computes the shifts on a waveform array using interpolation in
    the time domain. Returns a dictionary object containing the shift values
    along with (optionally) the similarity matrix containing the correlation
    coefficients.

    The returned delay values tau are defined as follows: for a shifted signal
    s_shift and a reference signal s_ref, we have:

    s_shift(t) = s_ref(t - tau)

    so a positive tau means the waveform arrives later than the reference.

    :param traces: A 2D array where each row is one of the waveforms
    :type traces: np.ndarray

    :param tws: The time windows on which to perform the shift computation.
        Each element is an array of sample indices (cast to int).
    :type tws: np.ndarray

    :param ref_trace: The reference trace that will be shifted around
        for comparison with the other waveforms. Should have the same length
        as the rows of ``traces``.
    :type ref_trace: np.ndarray

    :param shift_steps: The number of steps to perform the shifting on. Should
        be an odd integer to account for 0-shifting; when it is odd, the
        central shift is exactly 0.
    :type shift_steps: int

    :param shift_range: The range of shifting to look into, in samples. The
        function will then explore the range between [-shift_range, shift_range].
        Its absolute value is used.
    :type shift_range: float

    :param return_sim_mat: Whether or not to return the similarity matrix along
        with the other values.
    :type return_sim_mat: bool

    :return: Dictionary with keys *corr* (best correlation per window and trace),
        *value* (corresponding shift tau, in samples), *stretch_vector* (the
        tested shifts), *time_windows* and *sim_mat*. *corr* and *value* have
        shape ``(len(tws), traces.shape[0])`` before ``np.squeeze``. *sim_mat* is
        ``None`` unless ``return_sim_mat`` is True, in which case it has shape
        ``(traces.shape[0], shift_steps, len(tws))`` before ``np.squeeze``.
    :rtype: dict
    """

    shift_range = np.abs(shift_range)

    shift_vec = np.linspace(-shift_range, shift_range, shift_steps)
    if shift_steps % 2 == 1:
        # linspace does not guarantee an exact 0 at the centre (and with a
        # single step it returns -shift_range), so pin it explicitly.
        shift_vec[shift_steps // 2] = 0.0

    ntr, nsamp = traces.shape

    s = np.arange(ref_trace.shape[0])

    interp_ref_trace_fun = interp1d(s, ref_trace, "cubic", fill_value="extrapolate")

    # Row k holds the reference delayed by shift_vec[k], i.e. ref(t - shift_vec[k]).
    shift_traces = np.zeros((shift_steps, nsamp))
    for k, sstep in enumerate(shift_vec):
        shift_traces[k, :] = interp_ref_trace_fun(s - sstep)

    corr = np.zeros((len(tws), ntr))
    dt = np.zeros((len(tws), ntr))
    # Only allocate the (potentially large) similarity matrix when requested.
    sim_mat = np.zeros((ntr, shift_steps, len(tws))) if return_sim_mat else None

    for ii, tw in enumerate(tws):
        # Selecting the window columns is equivalent to multiplying by a 0/1
        # mask, without building full-size mask matrices.
        sel = np.zeros(nsamp, dtype=bool)
        sel[np.asarray(tw).astype(np.int32)] = True

        first = traces[:, sel]
        second = shift_traces[:, sel]

        dprod = first @ second.T

        # Normalization: product of the L2 norms of every (trace, shifted ref) pair.
        den = np.sqrt(np.outer(np.sum(first**2, axis=1), np.sum(second**2, axis=1)))

        tmp = dprod / den
        if sim_mat is not None:
            sim_mat[:, :, ii] = tmp

        corr[ii, :] = tmp.max(axis=1)
        dt[ii, :] = shift_vec[tmp.argmax(axis=1)]

    shifts = {
        "corr": np.squeeze(corr),
        "value": np.squeeze(dt),
        "stretch_vector": shift_vec,
        "time_windows": tws,
        # Same convention as compute_velocity_change_stretch: None when not requested.
        "sim_mat": np.squeeze(sim_mat) if sim_mat is not None else None,
    }

    return shifts


def correct_shift(shift, wf):
    """
    Shift a waveform in time by ``shift`` samples using cubic interpolation.

    The output is ``wf(t - shift)`` evaluated on the original sample grid
    ``t = 0, 1, ..., len(wf) - 1``, so a positive ``shift`` moves features of the
    waveform towards later samples. Samples that map outside the original
    waveform are obtained by cubic extrapolation.

    NOTE(review): if a waveform satisfies ``wf(t) = ref(t - tau)``, then
    ``correct_shift(-tau, wf)`` realigns it with ``ref``; confirm the sign that
    callers pass.

    :param shift: Shift to apply, in samples (may be fractional).
    :type shift: float
    :param wf: 1D waveform to shift.
    :type wf: np.ndarray
    :return: The shifted waveform, same length as ``wf``.
    :rtype: np.ndarray
    """

    wf = np.asarray(wf)

    t = np.arange(wf.shape[0])

    wf_interp_fun = interp1d(t, wf, "cubic", fill_value="extrapolate")

    return wf_interp_fun(t - shift)


def filter_wf(
    waveforms: np.ndarray,
    fsamp: float,
    freqs: "list | tuple | None" = None,
    order: int = 3,
):
    """
    Filters waveforms using a (causal) Butterworth filter of order 'order'.

    :param waveforms: A numpy array where each row is a waveform to be filtered.
        Filtering is applied along the last axis (``scipy.signal.sosfilt`` default).
    :type waveforms: np.ndarray

    :param fsamp: The sampling frequency of the traces in Hertz
    :type fsamp: float

    :param freqs: A pair ``(low, high)`` containing the corner frequencies in Hertz.
        If the first component is zero or None, the filter acts as a low-pass.
        If the second component is zero or None, it is a high-pass.
        If both are nonzero numbers, then it acts as a bandpass filter.
        If both are zero/None (or ``freqs`` itself is None), the waveforms are
        returned unchanged. A high corner exactly at the Nyquist frequency is
        nudged slightly below it, since the Butterworth design requires corners
        strictly below Nyquist.
    :type freqs: tuple

    :param order: The integer order to be used for the Butterworth filter.
    :type order: int

    :return: The filtered waveforms (or the input object itself when no filtering
        is requested).
    :rtype: np.ndarray
    """

    # None default instead of a mutable list default.
    if freqs is None:
        freqs = (None, None)

    low, high = freqs[0], freqs[1]

    # Normalize: 0 and None both mean "no corner". Convert to float so that
    # integer inputs are not truncated by the Nyquist nudge below.
    low = None if low is None or low == 0 else float(low)
    high = None if high is None or high == 0 else float(high)

    if low is None and high is None:
        return waveforms

    if high is not None and high == fsamp / 2:
        high = (1 - 1e-5) * high

    if low is None:
        method = "low"
        filter_freqs = high
    elif high is None:
        method = "high"
        filter_freqs = low
    else:
        method = "band"
        filter_freqs = [low, high]

    sos = ssig.butter(order, filter_freqs, btype=method, fs=fsamp, output="sos")

    return ssig.sosfilt(sos, waveforms)


def rms(vec: np.ndarray):
    """
    Return the root-mean-square amplitude of a 1D vector, sqrt(mean(vec**2)).

    The values are cast to float before squaring, which avoids integer
    wrap-around for integer-typed inputs.

    :param vec: Vector whose RMS needs to be computed
    :type vec: np.ndarray
    :return: The RMS amplitude.
    """

    as_float = vec.astype(float)
    mean_power = np.square(as_float).mean()
    return np.sqrt(mean_power)
