import datetime as dtt
import matplotlib.dates as mdates
import numpy as np
import os
import yaml
import mts_cwi.stretch as mtss
import mts_cwi.plots as mtsp
import pickle
import pprint
import scipy.signal as ssig
import scipy.stats as sstats
import scipy.fft as sf

from copy import deepcopy
from importlib import reload
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from mpl_toolkits.mplot3d import Axes3D
from obspy.core import Stats, UTCDateTime
from scipy.io import matlab
from collections import defaultdict

import mts_cwi.index_utils as mtsiu

"""
---------------------------------------------------------------
                    EXPERIMENT HANDLER CLASS
---------------------------------------------------------------
"""


class expHandler(object):
    """
    Handler for a single experiment listed in ``./cfg/experiments.yaml``.

    On construction it reads the experiment's acquisition parameters (signal
    length, sampling rate, trigger sample), its start dates and its data
    directory, then prints a summary of itself. It also keeps a dictionary of
    loaded acousticDataPair objects (``self.adps``, keyed by pair name) and a
    ``geometry`` object (``self.gm``) for experiment IDs up to 20.
    """

    @staticmethod
    def _load_cfg() -> dict:
        """Read and parse the experiment configuration file ``./cfg/experiments.yaml``."""
        with open("./cfg/experiments.yaml") as file:
            return yaml.load(file, Loader=yaml.FullLoader)

    def __init__(self, expid: int):
        """
        :param expid: Experiment ID as stated in the experiments.yaml file
        :type expid: int
        """
        self.expid = expid

        # Load parameters from parameter file
        sparam = self._load_cfg()

        self.exp = sparam["exps"][int(expid)]

        self.signal_length = int(self.exp["signal_length"])
        self.sampling_rate = float(self.exp["sampling_rate"])
        self.trigger = int(self.exp["trigger"])

        self.expname = self.exp["expname"]
        self.dax_start = self.exp["dax_start"]
        self.mts_start = self.exp["mts_start"]

        self.basepath = f"/home/manuel/Documents/Code/mts-acoustic-data-analysis/exps/{self.expname}/"

        # Cache of acousticDataPair objects loaded with store=True, keyed by pair
        self.adps = {}

        # Experiments with expid <= 20 use the "Ben1_Ben4" sensor locations.
        # No geometry is configured for the others, but the attribute always
        # exists so callers can test ``self.gm is None``.
        self.gm = None
        if int(expid) <= 20:
            sloc = sparam["sensor_locations"]["Ben1_Ben4"]
            self.gm = geometry(diameter=50, length=100, sensor_locations=sloc)

        print(self)

    def __str__(self):
        out1 = (
            f"Experiment: {self.expid} - {self.expname}\n"
            f"Experiment date (mechanical data): {self.mts_start}\n"
            f"Experiment date (acoustic data): {self.dax_start}\n"
            f"Path: {self.basepath}\n"
        )

        # Only the pair names are listed, in insertion order
        list_adps = list(self.adps)

        out2 = f"Available ADPs: {list_adps}\n"

        return out1 + out2

    def load_adp(
        self,
        pair: str,
        loadpath: str = None,
        processing: str = "nofilt",
        store: bool = False,
    ):
        """
        Reads a saved acousticDataPair object from an .npz file.
        Takes either a file path, or a sensor pair from which the default
        path is built (the experiment ID is taken from this handler).

        :param pair: Sensor pair in the '08-06' format
        :type pair: str

        :param loadpath: Path to file. If given, ``pair`` and ``processing`` are not used to locate it.
        :type loadpath: str

        :param processing: A string containing the name of the chosen processing parameters for the analysis, defaults to 'nofilt'
        :type processing: str

        :param store: If True, also cache the loaded object in ``self.adps`` under its pair name
        :type store: bool

        :raises ValueError: if ``loadpath`` is not given and ``pair`` (or the handler's expid) is None

        :return: the corresponding acousticDataPair object
        :rtype: acousticDataPair
        """

        expid = self.expid

        if loadpath is None:
            if (pair is None) or (expid is None):
                raise ValueError(
                    "If loadpath was not specified, you need to enter both an expid and a pair"
                )

            sparam = self._load_cfg()
            expname = sparam["exps"][expid]["expname"]

            loadpath = f"/home/manuel/Documents/Code/mts-acoustic-data-analysis/exps/{expname}/data/{processing}/{pair[:2]}/{pair[-2:]}/adp/adp_{pair}.npz"

        # The context manager closes the archive even if a key is missing
        with np.load(loadpath, allow_pickle=True) as loaded:
            # expid and pair are taken from the file, not from the arguments
            expid = int(loaded["expid"])
            pair = str(loaded["pair"])

            waveforms = loaded["waveforms"]
            sourcetimes = loaded["sourcetimes"]

            was_corrected = bool(loaded["was_corrected"])
            pulse_shifts = loaded["pulse_shifts"]

        adp = acousticDataPair(
            expid=expid,
            pair=pair,
            waveforms=waveforms,
            sourcetimes=sourcetimes,
            was_corrected=was_corrected,
            pulse_shifts=pulse_shifts,
        )

        if store:
            self.adps[pair] = adp

        return adp

    def _list_available_dv(self, dvdir: str, pair: str) -> list:
        """
        Print and return the time windows (in μs) of the DV files saved in ``dvdir``.

        DV files are named ``dv_<start>-<end>.npz``, where start/end are sample
        indices. Files not matching that pattern are ignored. A missing
        directory is treated as containing no files.
        """
        if os.path.isdir(dvdir):
            list_files = [
                fname
                for fname in os.listdir(dvdir)
                if fname.startswith("dv_") and fname.endswith(".npz")
            ]
        else:
            list_files = []

        available_dict = {}
        available_files = []

        for k, fname in enumerate(list_files):
            # Strip the 'dv_' prefix and the '.npz' suffix, then split on '-'.
            # This also handles indices with more than 4 digits.
            t0_str, t1_str = fname[3:-4].split("-")
            t0, t1 = int(t0_str), int(t1_str)

            tw_av = list(self.tw_to_seconds([t0, t1]))

            available_files.append(tw_av)
            available_dict[k] = tw_av

        if not available_dict:
            print(f"No available DV files for {pair}")
        else:
            print(f"Available DV for {pair}:")
            pprint.pprint(available_dict)

        return available_files

    def load_dv(
        self,
        pair: str = None,
        tw: np.ndarray = None,
        loadpath: str = None,
        processing: str = "nofilt",
    ):  # -> DV:
        """
        Reads a saved DV object from an .npz file.
        Takes either a file path, or a sensor pair and a time window from which
        the default path is built (the experiment ID is taken from this handler).

        :param pair: Sensor pair in the '08-06' format
        :type pair: str

        :param tw: The time window used for the DV computation, expressed in microseconds. If not specified (and no
            loadpath is given), the function prints and returns the list of available time windows for the chosen pair.
        :type tw: np.ndarray

        :param loadpath: Path to file
        :type loadpath: str

        :param processing: A string containing the name of the chosen processing parameters for the analysis, defaults to 'nofilt'
        :type processing: str

        :raises ValueError: if ``loadpath`` is not given and ``pair`` (or the handler's expid) is None

        :return: the corresponding DV object, or a list of available time windows when ``tw`` is None
        :rtype: DV
        """

        expid = self.expid

        if loadpath is None:
            if (pair is None) or (expid is None):
                raise ValueError(
                    "If loadpath was not specified, you need to enter both an expid, a pair and a time window"
                )

            sparam = self._load_cfg()
            expname = sparam["exps"][expid]["expname"]

            dvdir = f"/home/manuel/Documents/Code/mts-acoustic-data-analysis/exps/{expname}/data/{processing}/{pair[:2]}/{pair[-2:]}/dv/"

            if tw is None:
                return self._list_available_dv(dvdir, pair)

            tw = np.squeeze(self.seconds_to_tw(tw))

            loadpath = f"{dvdir}dv_{tw[0]:04d}-{tw[-1]:04d}.npz"

        # The context manager closes the archive (the original left it open)
        with np.load(loadpath, allow_pickle=True) as loaded:
            pair = str(loaded["pair"])
            corr = loaded["corr"]
            value = loaded["value"]
            stretch_vector = loaded["stretch_vector"]
            time_windows = loaded["time_windows"]
            sim_mat = loaded["sim_mat"]
            sourcetimes = loaded["sourcetimes"]
            cluster_corrected = bool(loaded["cluster_corrected"])

        # np.savez stores a None value as a 0-d object array; unwrap it so that
        # a missing similarity matrix is represented as None again.
        if sim_mat.ndim == 0 and sim_mat.dtype == object:
            sim_mat = sim_mat.item()

        return DV(
            expid=expid,
            pair=pair,
            corr=corr,
            value=value,
            stretch_vector=stretch_vector,
            time_windows=time_windows,
            sim_mat=sim_mat,
            sourcetimes=sourcetimes,
            cluster_corrected=cluster_corrected,
        )

    def plot_dv(
        self,
        dv,
        tw_id: int = 0,
        a: int = 4,
        r: float = 16 / 9,
        datelim: list = None,
        ax=None,
    ):
        """
        Plot a DV object for this experiment through ``mts_cwi.plots.plot_dv``.

        :param datelim: Date bounds for the plot, defaults to [None, None]
        """
        # A fresh list per call avoids sharing a mutable default between calls
        if datelim is None:
            datelim = [None, None]

        return mtsp.plot_dv(self, dv, tw_id=tw_id, a=a, r=r, datelim=datelim, ax=ax)

    def seconds_to_tw(self, times, micros: bool = True):
        """
        Create the list of indices to pass on as a time window for all the processing and computations.
        The trigger time is taken as the origin point for time.

        :param times: A [start, end] pair, or a sequence of such pairs. Only the first and last element of each pair are used.
        :type times: array-like

        :param micros: Set to "True" to pass values as microseconds, False to pass as seconds. Defaults to True.
        :type micros: bool

        :raises ValueError: if a start time is after its end time, or if a window starts before
            the first sample or ends at/after ``signal_length``

        :return: An object array with one row per window; each row holds the sample indices
            from start to end, inclusive
        :rtype: np.ndarray
        """

        tws = []

        times = np.atleast_2d(np.array(times))

        srate = float(self.sampling_rate)
        # Factor converting the user's unit to seconds (1e-6 for μs, 1 for s)
        to_seconds = (10**-6) ** micros
        # Factor converting samples to the user's unit
        sample_to_unit = ((10**6) ** micros) / srate
        unit = "μs" if micros else "s"

        for window in times:
            if window[0] > window[-1]:
                warn = "End time in the time window must be after the start time"
                raise ValueError(warn)

            tstart = window[0]
            tend = window[-1]

            twstart = int(tstart * srate * to_seconds + self.trigger)
            twend = int(tend * srate * to_seconds + self.trigger)

            # Negative indices would silently wrap around to the end of the signal
            if twstart < 0:
                tmin = -self.trigger * sample_to_unit
                warn = (
                    "Starting time when creating the time window is earlier than start of signal. Earliest possible time is {}".format(
                        tmin
                    )
                    + unit
                )
                raise ValueError(warn)

            if twend >= self.signal_length:
                tmax = (self.signal_length - 1 - self.trigger) * sample_to_unit
                warn = (
                    "Ending time when creating the time window is later than end of signal. Latest possible time is {}".format(
                        tmax
                    )
                    + unit
                )
                raise ValueError(warn)

            tws.append(np.arange(twstart, twend + 1))

        return np.array(tws, dtype=object)

    def tw_to_seconds(self, tws: np.ndarray, micros: bool = True):
        """
        Convert a list of sample indices to the corresponding list in seconds (or microseconds).
        The sourcetime is set at the pretrigger sample number.

        :param tws: A time window (array of sample indices), or a 2D array with one window per row.
            Only the first and last index of each window are converted.
        :type tws: np.ndarray

        :param micros: Set to "True" to return values in microseconds, False to return seconds. Defaults to True.
        :type micros: bool

        :raises ValueError: if a window's first index is after its last one

        :return: [start, end] per window, squeezed (a single window gives a 1D array of two values)
        :rtype: np.ndarray
        """

        tws = np.atleast_2d(np.array(tws))

        times = []

        srate = float(self.sampling_rate)
        sample_to_unit = ((10**6) ** micros) / srate
        # Rounding to 0.1 μs / 1e-7 s (same resolution in both units)
        ndigits = 1 if micros else 7

        for tw in tws:
            if tw[0] > tw[-1]:
                warn = "End time in the time window must be after the start time"
                raise ValueError(warn)

            tstart = round((tw[0] - self.trigger) * sample_to_unit, ndigits)
            tend = round((tw[-1] - self.trigger) * sample_to_unit, ndigits)

            times.append([tstart, tend])

        return np.squeeze(times)


"""
---------------------------------------------------------------
                            DV CLASS
---------------------------------------------------------------
"""


class DV(object):
    """
    An object designed to contain the results from either relative
    velocity change or shift computations, done on ADP objects
    """

    def __init__(
        self,
        expid: int = None,
        pair: str = None,
        sourcetimes: np.ndarray = None,
        corr: np.ndarray = None,
        value: np.ndarray = None,
        stretch_vector: np.ndarray = None,
        time_windows: np.ndarray = None,
        sim_mat: np.ndarray = None,
        cluster_corrected: bool = False,
    ):
        """
        :param expid: Experiment ID as stated in the experiments.yaml file
        :param pair: Sensor pair in the '08-06' format
        :param sourcetimes: Source times of the measurements
        :param corr: Correlation coefficients of the measurements
        :param value: Measured values (velocity changes or shifts)
        :param stretch_vector: The grid of tested stretch values
        :param time_windows: Time windows (sample indices) used for the computation
        :param sim_mat: Optional similarity matrix
        :param cluster_corrected: Whether the values were already cluster-corrected
        """
        self.expid = expid
        self.pair = pair

        self.corr = corr
        self.value = value
        self.stretch_vector = stretch_vector
        self.time_windows = time_windows
        self.sourcetimes = sourcetimes
        self.sim_mat = sim_mat
        self.cluster_corrected = cluster_corrected

    @staticmethod
    def _has_sim_mat(sim_mat) -> bool:
        """
        Return True if ``sim_mat`` holds an actual similarity matrix.

        Both ``None`` and a 0-d object array wrapping ``None`` (what ``np.savez``
        writes for a ``None`` value) count as "no matrix". This uses identity
        checks rather than ``!= None``, because the elementwise comparison of a
        real matrix cannot be used in an ``if``.
        """
        if sim_mat is None:
            return False
        if (
            isinstance(sim_mat, np.ndarray)
            and sim_mat.ndim == 0
            and sim_mat.dtype == object
        ):
            return sim_mat.item() is not None
        return True

    def __str__(self):
        max_corr = np.amax(self.corr)
        min_corr = np.amin(self.corr)

        stretch_range = 100 * np.amax(self.stretch_vector)

        # First and last sample index of every time window, flattened
        ltw = []
        for tw in self.time_windows:
            ltw.extend([tw[0], tw[-1]])

        out = f"Max corr. value: {max_corr} \nMin corr. value: {min_corr} \nStretch range: {stretch_range}% \nTime windows (samples): {ltw} \nCluster-corrected: {self.cluster_corrected} \n"

        if self._has_sim_mat(self.sim_mat):
            out_smat = "Similarity matrix included"
        else:
            out_smat = "Similarity matrix not included"

        return out + out_smat

    def save(self, processing: str = "nofilt", savepath: str = None, ow: bool = False):
        """
        Saves a DV object to the specified path in a compressed .npz format

        :param processing: A string containing the name of the chosen processing parameters for the analysis, defaults to 'nofilt'
        :type processing: str

        :param savepath: Saving path for file, .npz extension is added automatically if not present.
            If None, a default path is built from the experiment name, pair and time window.
        :type savepath: str

        :param ow: Set to True to overwrite an already existing file.
        :type ow: bool
        """

        expid = self.expid
        pair = str(self.pair)

        if savepath is None:
            # The default file name assumes a single time window
            tw = self.time_windows.squeeze()

            with open("./cfg/experiments.yaml") as file:
                sparam = yaml.load(file, Loader=yaml.FullLoader)

            expname = sparam["exps"][expid]["expname"]

            dirpath = f"/home/manuel/Documents/Code/mts-acoustic-data-analysis/exps/{expname}/data/{processing}/{pair[:2]}/{pair[-2:]}/dv"

            os.makedirs(dirpath, exist_ok=True)

            savepath = f"{dirpath}/dv_{tw[0]:04d}-{tw[-1]:04d}.npz"

        # np.savez_compressed appends '.npz' when missing; apply it here too so
        # the existence check looks at the file that would actually be written.
        savepath = os.fspath(savepath)
        if not savepath.endswith(".npz"):
            savepath += ".npz"

        if os.path.isfile(savepath) and not ow:
            print("DV file already found. Set ow to True to overwrite.")

            return None

        np.savez_compressed(
            savepath,
            expid=expid,
            pair=pair,
            corr=self.corr,
            value=self.value,
            stretch_vector=self.stretch_vector,
            time_windows=self.time_windows,
            sim_mat=self.sim_mat,
            sourcetimes=self.sourcetimes,
            cluster_corrected=self.cluster_corrected,
        )

        return None

    def correct_dv(
        self,
        ijs=None,
        tw_indices: np.ndarray = None,
        ow: bool = False,
        ow_indices: bool = False,
        inplace: bool = True,
    ):
        """
        A method to get rid of the bimodal distribution that sometimes appears on DV measurements.
        Clusters points using a running median calculation, and then averages out the upper and lower
        distributions. It also returns the differences between the upper and lower clusters once interpolated.

        :param ijs: The indices of the stress changes, excluding 0 but including the last measurement number
        :type ijs: list

        :param tw_indices: The time window used to compute the indices of the upper and lower clusters for a
            given source. Pick one where the bimodal effect is most apparent across different source-receiver
            combinations. Defaults to [160, 200].
        :type tw_indices: np.ndarray

        :param ow: Set to True to correct a DV object again even if it was already cluster-corrected. Defaults to False.
        :type ow: bool

        :param ow_indices: Set to True to recompute the cluster indices instead of loading them from file. Defaults to False.
        :type ow_indices: bool

        :param inplace: Whether or not to replace the current DV values by the corrected ones. Defaults to True.
        :type inplace: bool

        :return: ``(dv_value, delta)`` — the corrected DV values and the differences between the interpolated
            upper and lower clusters — or None if the object was already corrected and ``ow`` is False.
        :rtype: tuple
        """

        if self.cluster_corrected and not ow:
            print(
                "DV object was already corrected. No further correction was performed."
            )

            return None

        # A fresh list per call avoids sharing a mutable default between calls
        if tw_indices is None:
            tw_indices = [160, 200]

        dv_value = deepcopy(self.value)

        pair = str(self.pair)
        source = int(pair[:2])
        expid = self.expid

        with open("./cfg/experiments.yaml") as file:
            sparam = yaml.load(file, Loader=yaml.FullLoader)

        expname = sparam["exps"][expid]["expname"]

        fname = f"/home/manuel/Documents/Code/mts-acoustic-data-analysis/exps/{expname}/data/cluster_indices/indices_{source:02d}.pkl"

        # Load or compute indices
        if ow_indices:
            print(f"Computing indices from scratch for source {source:02d}")
            indices = mtsiu.determine_cluster_indices(
                source,
                tw_indices,
                fname,
                expid,
                ijs,
                ow_indices,
            )
        else:
            try:
                with open(fname, "rb") as fh:
                    indices = pickle.load(fh)

                print("Loading indices from file")
            # Missing, unreadable or truncated/corrupt index files fall back to
            # recomputation; other errors (and Ctrl-C) still propagate.
            except (OSError, EOFError, pickle.UnpicklingError):
                print(
                    f"Could not find pre-computed indices, computing from scratch for source {source:02d}"
                )
                indices = mtsiu.determine_cluster_indices(
                    source, tw_indices, fname, expid, ijs, ow_indices
                )

        # Perform the actual cluster correction
        dv_value, delta = mtsiu.correct_cluster_indices(
            dv_value, indices["inds"], indices["indl"], ijs
        )

        # Relative measurements start at 0 with respect to the first measurement point
        dv_value = dv_value - dv_value[0]

        if inplace:
            self.value = dv_value
            self.cluster_corrected = True

        return dv_value, delta


"""
---------------------------------------------------------------
                    ACOUSTIC DATA PAIR CLASS
---------------------------------------------------------------
"""


class acousticDataPair(expHandler):
    """
    An object designed to contain the acoustic datasets and the relevant processing methods as well
    """

    def __init__(
        self,
        expid: int = None,
        pair: str = None,
        waveforms: np.ndarray = None,
        sourcetimes: np.ndarray = None,
        dv: DV = None,
        shifts: dict = None,
        was_corrected: bool = False,
        pulse_shifts: np.ndarray = None,
    ):
        """
        This creates an object designed to store acoustic measurements for a given sensor pair.

        The parent ``expHandler.__init__`` is not called; the experiment parameters are
        read here directly from ``./cfg/experiments.yaml`` when ``expid`` is given, and
        are left as None otherwise.

        :param expid: Experiment ID as stated in the experiments.yaml file
        :type expid: int

        :param pair: Name of the sensor pair in the sender-receiver form, e.g. '08-06'
        :type pair: str

        :param waveforms: Array with all the acoustic waveforms for a given pair, one waveform per row
        :type waveforms: np.ndarray

        :param sourcetimes: Array with the source time of each waveform
        :type sourcetimes: np.ndarray

        :param dv: Previously computed DV object, if any
        :type dv: DV

        :param shifts: Previously computed shifts dictionary, if any
        :type shifts: dict

        :param was_corrected: Whether the waveforms were already corrected for pulse shifts
        :type was_corrected: bool

        :param pulse_shifts: Pulse shifts estimated during that correction, if any
        :type pulse_shifts: np.ndarray
        """

        self.pair = pair

        # Experiment parameters default to None so that the attributes always
        # exist (e.g. for __str__) even when no expid is given.
        self.exp = None
        self.signal_length = None
        self.sampling_rate = None
        self.trigger = None
        self.expname = None
        self.dax_start = None
        self.mts_start = None

        if expid is not None:
            # Load parameters from parameter file
            with open("./cfg/experiments.yaml") as file:
                sparam = yaml.load(file, Loader=yaml.FullLoader)

            self.exp = sparam["exps"][int(expid)]

            self.signal_length = int(self.exp["signal_length"])
            self.sampling_rate = float(self.exp["sampling_rate"])
            self.trigger = int(self.exp["trigger"])

            self.expname = self.exp["expname"]
            self.dax_start = self.exp["dax_start"]
            self.mts_start = self.exp["mts_start"]

        self.expid = expid

        self.waveforms = waveforms
        self.sourcetimes = sourcetimes

        self.dv = dv
        self.shifts = shifts

        self.was_corrected = was_corrected
        self.pulse_shifts = pulse_shifts

    def __str__(self):
        out1 = (
            f"Experiment: {self.expid} - {self.expname}\n"
            f"Pair: {self.pair}\n"
            f"Experiment date (acoustic data): {self.dax_start}\n"
            f"Corrected for pulse shifts: {self.was_corrected}"
        )
        # frequency_range is only set once filter() has been called
        if hasattr(self, "frequency_range"):
            out1 += f"\nFrequency range: {self.frequency_range}"

        if getattr(self, "dv", None) is not None:
            out_dv = "\n\nTime windows for DV computations (μs):"

            # atleast_2d: a single window is listed as one [start, end] line
            twssec = np.atleast_2d(self.tw_to_seconds(self.dv.time_windows))

            for i, tw_sec in enumerate(twssec):
                out_dv = out_dv + f"\n{i}: {tw_sec}"

        else:
            out_dv = ""

        if getattr(self, "shifts", None) is not None:
            out_shifts = "\n\nTime windows for shifts computations (μs):"

            twssec = np.atleast_2d(self.tw_to_seconds(self.shifts["time_windows"]))

            for i, tw_sec in enumerate(twssec):
                out_shifts = out_shifts + f"\n{i}: {tw_sec}"

        else:
            out_shifts = ""

        return out1 + out_dv + out_shifts

    def compute_shifts(
        self,
        tws: np.ndarray,
        ref_trace_id: int = 0,
        shift_steps: int = 2001,
        shift_range_us: float = 1,
        return_sim_mat: bool = False,
    ) -> dict:
        """
        Compute time shifts of every waveform relative to a reference waveform,
        using ``mts_cwi.stretch.compute_shifts_interpolate``.

        :param tws: Time window(s) in microseconds, converted to samples with seconds_to_tw
        :type tws: np.ndarray

        :param ref_trace_id: Index of the reference waveform, defaults to 0
        :type ref_trace_id: int

        :param shift_steps: Number of tested shifts, defaults to 2001
        :type shift_steps: int

        :param shift_range_us: Width of the shift search, in microseconds, defaults to 1
        :type shift_range_us: float

        :param return_sim_mat: Passed on to compute_shifts_interpolate
        :type return_sim_mat: bool

        :return: the shifts dictionary, which is also stored in ``self.shifts``
        :rtype: dict
        """
        # Convert microseconds to samples using the experiment's sampling rate
        # (in Hz) rather than assuming 10 samples per microsecond.
        shift_range = shift_range_us * self.sampling_rate / 1e6

        tws_samples = self.seconds_to_tw(tws)

        shifts = mtss.compute_shifts_interpolate(
            self.waveforms,
            tws_samples,
            self.waveforms[ref_trace_id, :],
            shift_steps,
            shift_range,
            return_sim_mat,
        )

        self.shifts = shifts

        return shifts

    def compute_dv_stretch(
        self,
        tws: np.ndarray,
        ref_id: int = 0,
        stretch_range: float = 0.1,
        stretch_steps: int = 1001,
        return_sim_mat: bool = False,
    ) -> DV:
        """
        Compute relative velocity changes with the stretching method and store the
        resulting DV object in ``self.dv``.

        :param tws: Time window(s) in microseconds, converted to samples later on
        :type tws: np.ndarray

        :param ref_id: Index of the reference waveform, defaults to 0
        :type ref_id: int

        :param stretch_range: width of the epsilon grid search, in percents
        :type stretch_range: float

        :param stretch_steps: Number of tested stretch values, defaults to 1001
        :type stretch_steps: int

        :param return_sim_mat: Passed on to compute_velocity_change_stretch
        :type return_sim_mat: bool

        :return: the computed DV object
        :rtype: DV
        """

        stretch_range = 0.01 * stretch_range  # Convert percents to proportion

        tws = self.seconds_to_tw(tws)

        eps_steps = mtss.compute_eps_steps(stretch_range, stretch_steps)

        ref_traces = mtss.stretch_reference(self.waveforms[ref_id, :], eps_steps)

        dv_dict = mtss.compute_velocity_change_stretch(
            self.waveforms, tws, ref_traces, eps_steps, return_sim_mat
        )

        # Keyword arguments guard against DV's parameter order changing
        dv = DV(
            expid=self.expid,
            pair=self.pair,
            sourcetimes=self.sourcetimes,
            corr=dv_dict["corr"],
            value=dv_dict["value"],
            stretch_vector=dv_dict["stretch_vector"],
            time_windows=dv_dict["time_windows"],
            sim_mat=dv_dict["sim_mat"],
        )

        self.dv = dv

        return dv

    def correct_pulse_shifts_fourier(
        self, return_shifts: bool = False, alpha: float = 0.1
    ):
        """
        This function corrects the traces in the ADP waveforms by computing
        time shifts on the associated pulsing channel in the Fourier domain,
        and then applying those time shifts in the Fourier domain as well before
        doing an inverse transform.

        The first pulse is the reference. For every other pulse, the slope of the
        unwrapped phase difference against angular frequency (positive-frequency
        half of the spectrum) is stored in ``self.pulse_shifts``.

        :param return_shifts: If True, also return the array of estimated shifts
        :type return_shifts: bool

        :param alpha: Shape parameter of the Tukey taper applied before each FFT, defaults to 0.1
        :type alpha: float

        :raises ValueError: if this dataset was already corrected

        :return: the shifts array if ``return_shifts`` is True, otherwise None
        """

        if self.was_corrected:
            raise ValueError("This dataset was corrected for pulse-shifting already")

        # The pulsing channel is loaded as the '<source>-<source>' pair.
        # NOTE(review): create_adp is not defined in the visible part of this
        # module — presumably a module-level helper defined elsewhere; confirm.
        rpair = self.pair[:2] + "-" + self.pair[:2]
        padp = create_adp(expid=self.expid, pair=rpair)

        Nwf = self.waveforms.shape[0]

        dt = 1 / self.sampling_rate

        # The FFT window spans the whole signal, so Nfft == signal_length
        fftwin = np.arange(self.signal_length)

        Nfft = fftwin.shape[0]
        freqs = sf.fftfreq(Nfft, d=dt)
        half = Nfft // 2
        # Angular frequencies of the non-negative half of the spectrum, used as
        # the regressor for the phase-slope (shift) estimate
        omega_half = 2 * np.pi * freqs[:half]

        taper_fun = ssig.windows.tukey(Nfft, alpha)

        # FFT of the tapered reference pulse
        fpref = sf.fft(padp.waveforms[0, :][fftwin] * taper_fun)

        shifts = np.zeros(Nwf)

        for k in range(1, Nwf):
            fpshift = sf.fft(padp.waveforms[k, :][fftwin] * taper_fun)

            # Divide shifted/ref to obtain the phase shift, then unwrap it to
            # get rid of cycle skipping
            dphi = np.unwrap(np.angle(fpshift / fpref))

            shifts[k] = sstats.linregress(omega_half, dphi[:half])[0]

            # Correct the trace by multiplying its spectrum by exp(-i * dphi).
            # NOTE(review): the taper is applied to the trace itself as well, so
            # its edges are attenuated in the corrected waveform.
            fft_trace = sf.fft(self.waveforms[k, :] * taper_fun)

            expmdphi = np.cos(dphi) - np.sin(dphi) * 1j

            corr_trace = sf.ifft(fft_trace * expmdphi)

            # Waveforms are presumably real-valued: keep the real part
            # explicitly instead of letting NumPy drop the imaginary part with a
            # ComplexWarning.
            self.waveforms[k, :] = corr_trace.real

        self.was_corrected = True
        # Persist the shifts so they are saved alongside was_corrected
        self.pulse_shifts = shifts

        if return_shifts:
            return shifts
        return None

    def filter(self, freqs: list = None, order: int = 3):
        """Filters the waveforms of the ADP object using a Butterworth filter of order 'order'.
        The filtering is done in-place, so the user should copy the object beforehand.

        :param freqs: A pair containing the frequencies to be used for the filtering, defaults to [None, None].
            If the first component is zero or None, the filter acts as a low-pass.
            If the second component is zero or None, it is a high-pass.
            If both are floats, then it acts as a bandpass filter.
        :type freqs: list

        :param order: The integer order to be used for the Butterworth filter.
        :type order: int
        """
        # A fresh list per call: freqs is stored on the instance below, so a
        # shared default list could be mutated through one object and leak into others.
        if freqs is None:
            freqs = [None, None]

        self.waveforms = mtss.filter_wf(self.waveforms, self.sampling_rate, freqs, order)

        self.frequency_range = freqs

        return None

    def plot_dv(
        self,
        tw_id: int = 0,
        a: int = 4,
        r: float = 16 / 9,
        datelim: list = None,
        ax=None,
    ):
        """
        Plot the DV measurements stored in ``self.dv``.

        :param datelim: Date bounds for the plot, defaults to [None, None]
        :raises AttributeError: if no DV measurements are stored
        """
        if getattr(self, "dv", None) is None:
            raise AttributeError("No DV measurements found for this object")

        if datelim is None:
            datelim = [None, None]

        return mtsp.plot_dv(
            self, self.dv, tw_id=tw_id, a=a, r=r, datelim=datelim, ax=ax
        )

    def plot_shifts(
        self,
        tw_id: int = 0,
        a: int = 4,
        r: float = 16 / 9,
        datelim: list = None,
        ax=None,
    ):
        """
        Plots the shifts, expressed in microseconds of delay compared to a reference trace

        :param datelim: Date bounds for the plot, defaults to [None, None]
        :raises AttributeError: if no shifts are stored
        """

        if getattr(self, "shifts", None) is None:
            raise AttributeError("No shifts found for this object")

        if datelim is None:
            datelim = [None, None]

        return mtsp.plot_shifts(self, tw_id=tw_id, a=a, r=r, datelim=datelim, ax=ax)

    def plot_correlation_coef(
        self,
        measurement_type: str = "shifts",
        tw_id: int = 0,
        a: int = 4,
        r: float = 16 / 9,
        datelim: list = None,
        ax=None,
        sim_mat=False,
    ):
        """
        Plots the correlation coefficients

        :param measurement_type: Which measurement to plot, defaults to 'shifts'
        :param datelim: Date bounds for the plot, defaults to [None, None]
        """
        if datelim is None:
            datelim = [None, None]

        return mtsp.plot_correlation_coef(
            self,
            measurement_type=measurement_type,
            tw_id=tw_id,
            a=a,
            r=r,
            datelim=datelim,
            ax=ax,
            sim_mat=sim_mat,
        )

    def plot_waveform(
        self,
        waveform_id: np.ndarray = None,
        tws: np.ndarray = None,
        a: int = 4,
        r: float = 16 / 9,
        tlim: list = None,
        ax=None,
    ):
        """
        Plotting function for the acoustic trace for a given event ID.
        Defaults to the first trace.

        :param waveform_id: one or multiple waveform IDs that we want to plot, defaults to np.array([0])
        :type waveform_id: np.ndarray

        :param tws: one or multiple time windows (as 1D or 2D arrays) to be displayed on
            the waveform, in microseconds
        :type tws: np.ndarray

        :param tlim: set the time bounds for the plot. Either bound can be set to None. Defaults to [None, None]
        :type tlim: list

        :param ax: Axes object on which to plot the waveform. If no object is passed
            as an argument, a figure is created and the function does not return anything.
        :type ax: matplotlib.axes
        """
        # Fresh defaults per call instead of shared mutable default objects
        if waveform_id is None:
            waveform_id = np.array([0])
        if tlim is None:
            tlim = [None, None]

        return mtsp.plot_waveform(
            adp=self, waveform_id=waveform_id, tws=tws, a=a, r=r, ax=ax, tlim=tlim
        )

    def save(self, savepath: str = None, ow: bool = False, processing: str = "nofilt"):
        """
        Saves an acousticDataPair object to the specified path in a compressed .npz format

        :param savepath: Saving path for file, .npz extension is added automatically if not present.
            If None, a default path is built from the experiment name and pair.
        :type savepath: str

        :param ow: Set to True to overwrite an already existing file.
        :type ow: bool

        :param processing: A string containing the name of the chosen processing parameters for the analysis, defaults to 'nofilt'
        :type processing: str
        """

        expid = self.expid
        pair = str(self.pair)

        if savepath is None:
            with open("./cfg/experiments.yaml") as file:
                sparam = yaml.load(file, Loader=yaml.FullLoader)

            expname = sparam["exps"][expid]["expname"]

            dirpath = f"/home/manuel/Documents/Code/mts-acoustic-data-analysis/exps/{expname}/data/{processing}/{pair[:2]}/{pair[-2:]}/adp"

            os.makedirs(dirpath, exist_ok=True)

            savepath = f"{dirpath}/adp_{pair}.npz"

        # np.savez_compressed appends '.npz' when missing; apply it here too so
        # the existence check looks at the file that would actually be written.
        savepath = os.fspath(savepath)
        if not savepath.endswith(".npz"):
            savepath += ".npz"

        if os.path.isfile(savepath) and not ow:
            print(f"File already found at {savepath}. Set ow to True to overwrite.")

            return None

        np.savez_compressed(
            savepath,
            expid=expid,
            pair=pair,
            waveforms=self.waveforms,
            sourcetimes=self.sourcetimes,
            was_corrected=self.was_corrected,
            pulse_shifts=self.pulse_shifts,
        )

        return None


"""
---------------------------------------------------------------
                    GEOMETRY CLASS
---------------------------------------------------------------
"""


class geometry:
    """
    Cylindrical sample geometry together with the layout of the sensors on its surface.

    Sizes are in millimeters and azimuths in degrees (see ``__init__``).
    """

    def __init__(self, diameter=50, length=100, sensor_locations=None):
        """
        Create a geometry object that contains information on the geometry of the sample, assumed to be cylindrical, as well as information on the sensor locations and arrangement.

        :type diameter: float
        :param diameter: The sample's diameter, in millimeters

        :type length: float
        :param length: The sample's length (or height), in millimeters

        :type sensor_locations: dict
        :param sensor_locations: A dictionary containing information about the sensor locations, namely their name, azimuth (in degrees, clockwise) and z-coordinate (in millimeters). For example: {'01':{'az':135,'height':25},'02':{'az':0,'height':-50}}
        """

        self.geometry = {
            "diameter": diameter,
            "length": length,
            "radius": diameter / 2,
        }
        self.sensor_locs = sensor_locations

    def hor_dist(self, daz):
        """
        Calculate the horizontal (chord) distance between two points of the
        circular cross-section separated by an azimuth ``daz`` (in degrees),
        i.e. ``diameter * sin(daz / 2)``. The sign follows the sign of ``daz``.
        """
        diameter = self.geometry["diameter"]

        s = diameter * np.sin(daz / 180 * np.pi / 2)

        return s

    def path_length(self, pair, Nref=0, arc=0):
        """
        Calculate the length of the path between sensor s1 and sensor s2 for a given total number of reflections and path number

        :type pair: str
        :param pair: chosen sensor pair, e.g. '08-06' (first two and last two characters are the sensor names)

        :type Nref: int
        :param Nref: total number of reflection points

        :type arc: int
        :param arc: reflected path number (<= Nref)

        :return: ``(path_length, surf_length)``: the straight-segment (reflected) path length
            through the sample, and the length along the surface between the two sensors.
            The surface length uses ``|az1 - az2|`` directly (not wrapped to the shorter arc).
        """

        s1 = pair[:2]
        s2 = pair[-2:]

        az1 = self.sensor_locs[s1]["az"]
        az2 = self.sensor_locs[s2]["az"]

        taz, phi = self.path_azimuth(az1, az2, Nref, arc)

        # Nref + 1 identical chords between consecutive reflection points
        hdist = (Nref + 1) * self.hor_dist(phi)

        z1 = self.sensor_locs[s1]["height"]
        z2 = self.sensor_locs[s2]["height"]

        vdist = np.abs(z2 - z1)

        path_length = np.sqrt(hdist**2 + vdist**2)

        r = self.geometry["radius"]

        cdist = r * np.abs(az1 - az2) * np.pi / 180

        surf_length = np.sqrt(cdist**2 + vdist**2)

        return path_length, surf_length

    def path_azimuth(self, az1, az2, Nref=0, arc=0):
        """
        Calculate the total azimuth of a wave (particle) traveled along a given
        path, and the azimuth between consecutive reflection points.

        Even ``arc`` values go around in the direction of ``az2 - az1``, odd
        values go the other way round; every two arcs add one full turn.

        :return: ``(taz, phi)``, signed total azimuth and signed azimuth per
            segment (``taz / (Nref + 1)``), in degrees.
        """
        assert arc <= Nref

        daz = az2 - az1

        # azimuth between sensors in the given direction
        az = (arc % 2) * 360 + (-1) ** (arc % 2) * np.abs(daz)

        # total azimuth travelled
        taz = az + np.floor(arc / 2) * 360

        # azimuth in between reflections
        phi = taz / (Nref + 1)

        if daz == 0:
            sign = 1
        else:
            sign = (-1) ** (arc % 2) * np.sign(daz)

        return sign * taz, sign * phi

    def compression_angle(self, pair, Nref=0, arc=0):
        """
        Returns alpha in degrees, where alpha is the angle between the vertical
        axis and the ray path. Because the absolute height difference is used,
        alpha lies between 0° (vertical path) and 90° (horizontal path).
        """

        pul = pair[:2]
        rec = pair[-2:]

        path_length = self.path_length(pair, Nref, arc)[0]

        z_pul = self.sensor_locs[pul]["height"]
        z_rec = self.sensor_locs[rec]["height"]

        delta_z = z_pul - z_rec

        alpha = np.arccos(np.abs(delta_z) / path_length) * 180 / np.pi

        return alpha

    def plot_circle(self, pair, Nref=0, arc=0):
        """
        Plot, in the current matplotlib axes, the sample cross-section and the
        horizontal projection of the reflected path between the sensors of ``pair``.
        """
        s1 = pair[:2]
        s2 = pair[-2:]

        r = self.geometry["radius"]

        az1 = self.sensor_locs[s1]["az"]
        az2 = self.sensor_locs[s2]["az"]

        # Getting the total and incremental azimuth of the reflection paths
        taz, phi = self.path_azimuth(az1, az2, Nref, arc)

        # Determining the Cartesian coordinates of the reflection points
        phis = np.arange(Nref + 2) * phi / 180 * np.pi
        xs = r * np.sin(phis + az1 / 180 * np.pi)
        ys = r * np.cos(phis + az1 / 180 * np.pi)

        # Angles for plotting the circle
        pphi = np.linspace(0, 2 * np.pi, 360)
        plt.plot(r * np.sin(pphi), r * np.cos(pphi))

        plt.plot(np.array(xs), np.array(ys))
        plt.gca().set_aspect("equal")

        return None

    def plot_ref_3d(self, pair, Nref=0, arc=0):
        """
        Plot, in a new 3D figure, the sample cylinder and the reflected path
        between the sensors of ``pair``. The first point is drawn in red, the
        last in green and intermediate reflection points in grey.
        """
        s1 = pair[:2]
        s2 = pair[-2:]

        r = self.geometry["radius"]
        h = self.geometry["length"]

        az1 = self.sensor_locs[s1]["az"]
        az2 = self.sensor_locs[s2]["az"]

        # Plotting the plain cylinder according do the dimensions
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111, projection="3d")

        Xc, Yc, Zc = data_for_cylinder_along_z(0, 0, h / 2, r, h)
        ax.plot_surface(Xc, Yc, Zc, alpha=0.25)

        # dirty trick to make the axes look equal
        ax.plot(
            [-h / 2, h / 2, h / 2], [-h / 2, -h / 2, h / 2], [0, 0, 0], color="none"
        )

        # Getting the total and incremental azimuth of the reflection paths
        taz, phi = self.path_azimuth(az1, az2, Nref, arc)

        # Determining the Cartesian coordinates of the reflection points
        z1 = self.sensor_locs[s1]["height"]
        z2 = self.sensor_locs[s2]["height"]

        phis = np.arange(Nref + 2) * phi / 180 * np.pi
        xs = r * np.sin(phis + az1 / 180 * np.pi)
        ys = r * np.cos(phis + az1 / 180 * np.pi)
        zs = np.linspace(z1, z2, Nref + 2)

        ax.plot3D(xs, ys, zs, color="black")

        color = np.tensordot(np.ones_like(xs), [0.5, 0.5, 0.5], axes=0)
        color[0] = [1, 0, 0]
        color[-1] = [0, 1, 0]

        dotsize = np.ones_like(xs) * 100
        dotsize[0] = 500
        dotsize[-1] = 500

        ax.scatter3D(xs, ys, zs, color=color, s=dotsize)

        return None

    def deform(self, eps, nu=0.2):
        """
        Returns a new, deformed geometry object for a given strain.

        The diameter is scaled by ``(1 + nu * eps)``, the length and every sensor
        height by ``(1 - eps)``; azimuths are unchanged. The current object
        (including its sensor dictionary) is left untouched.

        :param eps: The applied strain, in relative units (not percentage). Rock mechanics conventions apply: eps > 0 means compression
        :type eps: float

        :param nu: Poisson ratio for the material, between 0 and 1/2
        :type nu: float

        :return: the deformed geometry
        :rtype: geometry
        """

        geom = self.geometry

        diameter = (1 + nu * eps) * geom["diameter"]

        length = (1 - eps) * geom["length"]

        # Deep copy: scaling the heights in place would otherwise also modify
        # the sensor layout of this (undeformed) geometry.
        sensor_locations = deepcopy(getattr(self, "sensor_locs", None))

        if sensor_locations is not None:
            for loc in sensor_locations.values():
                loc["height"] = loc["height"] * (1 - eps)

        return geometry(diameter, length, sensor_locations=sensor_locations)

    def velocity_change_deform(self, pair, eps, Nref=0, arc=0, nu=0.2):
        """
        Returns the relative velocity change induced by the sample geometry change during deformation. This means
        this value has to be subtracted from the computed DV/V to get rid of the apparent DV/V due to sample
        deformation.

        Computed as ``1 - L_deformed / L_undeformed`` for the chosen reflected path.

        :param pair: The sensor pair for which we want to compute the relative velocity change
        :type pair: str

        :param eps: The applied strain, in relative units (not percentage). Taken positive for compression
        :type eps: float

        :param Nref: total number of reflection points of the path
        :type Nref: int

        :param arc: reflected path number (<= Nref)
        :type arc: int

        :param nu: Poisson ratio for the material, between 0 and 1/2
        :type nu: float
        """

        len_norm = self.path_length(pair, Nref, arc)[0]

        gm_def = self.deform(eps, nu)

        len_def = gm_def.path_length(pair, Nref, arc)[0]

        dvv_def = 1 - len_def / len_norm

        return dvv_def


"""
---------------------------------------------------------------
                    UTILS
---------------------------------------------------------------
"""


def create_adp(expid: int, pair: str):
    """
    Build an acousticDataPair for one sensor pair of one experiment.

    Acquisition parameters are taken from the experiment entry in
    ./cfg/experiments.yaml; waveforms are read from ``<path_matfiles>/wf_<pair>.mat``,
    which must have been separated and exported beforehand.

    :param expid: Experiment ID as stated in the experiments.yaml file
    :type expid: int

    :param pair: Sensor pair in the '08-06' format
    :type pair: str

    :return: the corresponding acousticDataPair object
    :rtype: acousticDataPair
    """

    adp = acousticDataPair()
    adp.pair, adp.expid = pair, expid

    # Experiment catalogue
    with open("./cfg/experiments.yaml") as cfg_file:
        catalogue = yaml.load(cfg_file, Loader=yaml.FullLoader)

    adp.exp = catalogue["exps"][expid]
    exp_cfg = adp.exp

    # Numeric acquisition settings, cast to their expected types
    for attr, cast in (
        ("signal_length", int),
        ("sampling_rate", float),
        ("trigger", int),
    ):
        setattr(adp, attr, cast(exp_cfg[attr]))

    # Descriptive metadata, copied as-is
    for attr in ("expname", "dax_start", "mts_start"):
        setattr(adp, attr, exp_cfg[attr])

    matcont = matlab.loadmat(
        exp_cfg["path_matfiles"] + f"/wf_{pair}.mat", squeeze_me=True
    )

    # Absolute measurement times: recording start date shifted by each source time
    adp.sourcetimes = np.array(
        [
            UTCDateTime(matcont["stats"]["date"].item()) + offset
            for offset in matcont["sourcetime"]
        ]
    )

    # One measurement per row
    adp.waveforms = np.transpose(matcont["mat"])

    return adp


def running_mean(x, N):
    """
    A simple trailing running mean over a window of N samples, used to smooth curves and plots.

    Element k of the result is the mean of ``x[k], ..., x[k + N - 1]``, so the
    returned array has ``len(x) - N + 1`` elements (N - 1 fewer than the input),
    and is empty when N exceeds ``len(x)``. Multi-dimensional input is flattened
    first (``np.insert`` without an axis flattens).

    :param x: The input 1D array to be smoothed
    :type x: np.ndarray

    :param N: The window length, in samples; must be at least 1 (N = 1 returns x unchanged, as floats)
    :type N: int

    :raises ValueError: if N < 1
    """

    if N < 1:
        raise ValueError(f"Window length N must be at least 1, got {N}")

    # Prepending 0 lets each window sum be written as a difference of two cumulative sums
    cumsum = np.cumsum(np.insert(x, 0, 0))
    return (cumsum[N:] - cumsum[:-N]) / float(N)


def uncert_format(a, b):
    """
    Takes two floats a and b and displays them as (a ± b)e±p, with both mantissas
    scaled by the same power of ten so they are easy to compare.

    The exponent p is the order of magnitude of a, ``floor(log10(|a|))``. If a is
    zero (or not finite), the order of magnitude of b is used instead, and p = 0
    if neither is usable.

    :param a: The measured value
    :type a: float

    :param b: The standard error associated with the measure, usually positive
    :type b: float

    :return str_uncert: The formatted uncertainty string, e.g. '( 2.500 ± 0.100 )e-3'
    :type str_uncert: str
    """

    def _order(x):
        # Power of ten of the leading digit of x (x must be finite and non-zero)
        return int(np.floor(np.log10(np.abs(x))))

    if a != 0 and np.isfinite(a):
        p = _order(a)
    elif b != 0 and np.isfinite(b):
        # log10(0) is -inf, which would make int() overflow
        p = _order(b)
    else:
        p = 0

    fma = a * 10.0**-p
    fmb = b * 10.0**-p

    str_uncert = f"( {fma:.3f} ± {fmb:.3f} )e{p}"

    return str_uncert


def sensor_map_to_physical_coord(
    transducer_coord: list = None,
    delta_phi: float = np.pi / 4,
    delta_z: float = 10,
    ref_transducer: int = 8,
    plot: bool = False,
) -> list:
    """
    Converts the integer transducer coordinates from the sensor map to
    actual physical coordinates for sensors on the surface of a cylinder,
    independent of radius.

    :param transducer_coord: The list of integer (i,j) coordinates from
        the sensor map, where i is the vertical index and j the horizontal
        index (matrix convention). Transducer number k (1-based) is
        ``transducer_coord[k - 1]``.
    :type transducer_coord: list

    :param delta_phi: Angle step on the sensor map, in radians
    :type delta_phi: float

    :param delta_z: Vertical height step on the sensor map, in mm
    :type delta_z: float

    :param ref_transducer: The transducer number (1-based) that will serve as
        reference (origin) for the vertical and angle coordinates
    :type ref_transducer: int

    :param plot: Plots the new sensor map with the transducer numbers.
        Defaults to False, set to True to plot
    :type plot: bool

    :raises IndexError: if ref_transducer is not between 1 and len(transducer_coord)

    :returns transducer_phys_coord: A list with the physical cylindrical
        coordinates of the transducers, in the form of (phi, z), independent
        of radius. z grows upwards, i.e. with decreasing row index i.
    :type transducer_phys_coord: list
    """

    if transducer_coord is None:
        transducer_coord = []

    # Transducers are numbered from 1; a 0 or negative number would otherwise
    # silently index from the end of the list and pick the wrong reference.
    if not 1 <= ref_transducer <= len(transducer_coord):
        raise IndexError(
            f"ref_transducer must be between 1 and {len(transducer_coord)}, "
            f"got {ref_transducer}"
        )

    ref = transducer_coord[ref_transducer - 1]
    i0 = ref[0]
    j0 = ref[1]

    transducer_phys_coord = [
        ((trd[1] - j0) * delta_phi, (i0 - trd[0]) * delta_z)
        for trd in transducer_coord
    ]

    if plot:
        ax = plt.gca()
        for k, (phi, z) in enumerate(transducer_phys_coord, start=1):
            phi_deg = phi * 180 / np.pi
            plt.scatter(phi_deg, z)
            ax.annotate(f"{k}", (phi_deg, z))

    return transducer_phys_coord


def data_for_cylinder_along_z(center_x, center_y, center_z, radius, height_z):
    """
    Grid coordinates of a vertical cylinder surface, suitable for ``plot_surface``.

    The surface is sampled on a 50 x 50 grid (height x angle). Horizontally it is
    centred on (center_x, center_y). Vertically, the 50 heights are
    ``linspace(0, height_z) - center_z``, so ``center_z`` acts as a downward
    offset: passing ``height_z / 2`` gives a cylinder spanning
    [-height_z / 2, height_z / 2].

    :return: ``(x_grid, y_grid, z_grid)`` arrays of shape (50, 50)
    """
    angles = np.linspace(0, 2 * np.pi, 50)
    heights = np.linspace(0, height_z, 50) - center_z

    # rows follow the heights, columns follow the angles
    theta_grid, z_grid = np.meshgrid(angles, heights)

    x_grid = center_x + radius * np.cos(theta_grid)
    y_grid = center_y + radius * np.sin(theta_grid)
    return x_grid, y_grid, z_grid
