import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
from scipy.interpolate import interp1d
from importlib import reload
from copy import deepcopy

from collections import deque
from bisect import insort, bisect_left
from itertools import islice

import mts_cwi.mts_data as mtsd


def demean(trc):
    """Return ``trc`` with its mean (``np.mean`` over all elements) removed."""
    offset = np.mean(trc)
    return trc - offset


def get_offset(indicesA, indicesB):
    """Cross-correlate the indicator traces of two index dictionaries.

    Both arguments are turned into traces with ``make_ind_trace`` and the
    output of ``np.correlate(..., mode="full")`` is returned unchanged; the
    lag of the correlation maximum is not extracted here.
    """
    trace_a = make_ind_trace(indicesA)
    trace_b = make_ind_trace(indicesB)
    return np.correlate(trace_a, trace_b, mode="full")


def make_ind_trace(indices):
    """Build a signed indicator trace from an index dictionary.

    ``indices`` must provide the keys ``"inds"``, ``"indl"`` and ``"indn"``.
    The trace length is the combined number of entries of the three groups.
    Positions listed in ``"indl"`` are set to 1, positions listed in
    ``"inds"`` to -1, and all remaining positions stay 0.
    """
    n_samples = sum(len(indices[key]) for key in ("inds", "indl", "indn"))
    trace = np.zeros(n_samples)
    # "inds" is written last, so it wins if an index appears in both groups.
    trace[indices["indl"]] = 1
    trace[indices["inds"]] = -1
    return trace


def combine_ind_joint(indslist):
    """Merge several index dictionaries by majority vote.

    Every entry of ``indslist`` is converted to a +1/-1/0 trace with
    ``make_ind_trace`` and the traces are summed sample by sample. A sample
    is assigned to

    * ``"indl"`` when its vote sum is greater than ``lim``,
    * ``"inds"`` when its vote sum is less than ``-lim``,
    * ``"indn"`` otherwise,

    where ``lim`` is 20 % of ``len(indslist)``.

    The three returned groups always partition the full trace: samples whose
    vote sum equals ``+lim`` or ``-lim`` exactly are counted as ``"indn"``
    instead of being dropped, so the result can be fed back into
    ``make_ind_trace`` (which derives the trace length from the group sizes).

    Parameters
    ----------
    indslist : sequence of dict
        Dictionaries with the keys ``"inds"``, ``"indl"`` and ``"indn"``.
        All of them must describe traces of the same length, because the
        traces are added in place.

    Returns
    -------
    dict
        ``{"inds": ..., "indl": ..., "indn": ...}`` holding sorted integer
        index arrays.

    Raises
    ------
    IndexError
        If ``indslist`` is empty.
    """
    lim = len(indslist) * 0.2

    votes = make_ind_trace(indslist[0])
    for inds in indslist[1:]:
        votes += make_ind_trace(inds)

    cindl = np.where(votes > lim)[0]
    cinds = np.where(votes < -lim)[0]
    # Inclusive bounds: everything that is not clearly in one cluster is
    # neutral, which keeps the three groups a complete partition.
    cindn = np.where(np.abs(votes) <= lim)[0]

    return {"inds": cinds, "indl": cindl, "indn": cindn}


def combine_ind(ind1, ind2):
    """Combine two index dictionaries into one consistent classification.

    The cluster labels of ``ind2`` may be swapped relative to ``ind1``. To
    detect this, the size of the union of ``ind1["inds"]`` with each cluster
    of ``ind2`` is divided by the summed group lengths; a smaller ratio means
    more overlap. When ``ind1["inds"]`` overlaps ``ind2["indl"]`` more than
    ``ind2["inds"]``, the two clusters of ``ind2`` are exchanged (and
    ``"exchange"`` is printed).

    After the optional exchange an index is assigned to a cluster when both
    inputs agree, or when one input has it in that cluster and the other
    lists it as neutral (``"indn"``). Every other index that appears in
    ``ind1`` becomes neutral.

    Parameters
    ----------
    ind1, ind2 : dict
        Dictionaries with the keys ``"inds"``, ``"indl"`` and ``"indn"``,
        each holding a sequence of hashable indices.

    Returns
    -------
    dict
        ``{"inds": list, "indl": list, "indn": list}`` with each list sorted
        in ascending order.
    """
    inds1, indl1, indn1 = ind1["inds"], ind1["indl"], ind1["indn"]
    inds2, indl2, indn2 = ind2["inds"], ind2["indl"], ind2["indn"]

    ses1, sel1, sen1 = set(inds1), set(indl1), set(indn1)
    ses2, sel2, sen2 = set(inds2), set(indl2), set(indn2)

    def union_ratio(set_a, set_b, total):
        # Two empty groups carry no evidence of overlap; treat them like
        # disjoint groups instead of dividing by zero.
        return len(set_a | set_b) / total if total else 1.0

    r1 = union_ratio(ses1, ses2, len(inds1) + len(inds2))
    r2 = union_ratio(ses1, sel2, len(inds1) + len(indl2))
    if r1 > r2:  # inds1 corresponds to indl2
        print("exchange")
        ses2, sel2 = sel2, ses2

    ses = (ses1 & ses2) | (sen1 & ses2) | (sen2 & ses1)
    sel = (sel1 & sel2) | (sen1 & sel2) | (sen2 & sel1)
    sen = (ses1 | sel1 | sen1) - (ses | sel)

    return {"inds": sorted(ses), "indl": sorted(sel), "indn": sorted(sen)}


def dezip(seq):
    """Split the samples of ``seq`` into two groups plus a neutral rest.

    ``find_lims`` is evaluated for the window sizes 3, 10, 30, 100 and 300
    and its per-sample results are summed.

    Returns
    -------
    tuple
        ``(indsum, inds, indl, indn)`` where ``indsum`` is the per-sample sum
        and ``inds`` / ``indl`` / ``indn`` are the indices at which that sum
        is positive / negative / zero.

    NOTE(review): ``find_lims`` gives +1 to samples lying above both running
    medians, so ``inds`` holds the positively voted samples here -- confirm
    that this inds/indl naming is intended.
    """
    window_sizes = (3, 10, 30, 100, 300)
    votes = [find_lims(seq, size) for size in window_sizes]

    indsum = np.sum(votes, axis=0)
    positive = np.nonzero(indsum > 0)[0]
    negative = np.nonzero(indsum < 0)[0]
    neutral = np.nonzero(indsum == 0)[0]
    return indsum, positive, negative, neutral


def find_lims(seq, meds):
    """Flag samples lying outside the band spanned by two running medians.

    A running median of window size ``meds`` is computed once over ``seq``
    and once over the reversed ``seq`` (the latter is flipped back). The
    element-wise maximum and minimum of both medians form a band.

    Returns
    -------
    array
        Created with ``np.zeros_like(seq)``: +1 where the sample is greater
        than both medians, -1 where it is smaller than both, 0 elsewhere.
    """
    forward = running_median_insort(seq, meds)
    backward = running_median_insort(seq[::-1], meds)[::-1]

    upper = np.max(np.array((forward, backward)), axis=0)
    lower = np.min(np.array((forward, backward)), axis=0)

    above = np.where((seq > upper) & (seq > lower))[0]
    below = np.where((seq < lower) & (seq < upper))[0]

    groups = np.zeros_like(seq)
    groups[above] += 1
    groups[below] -= 1
    return groups


def running_median_insort(seq, window_size):
    """Trailing running median of ``seq`` (contributed by Peter Otten).

    A deque keeps the current window in arrival order while a parallel list
    is kept sorted with ``insort``. For each input item one value is
    produced: the element at index ``len(window) // 2`` of the sorted window
    while the window is still filling, and at index ``window_size // 2``
    afterwards. For an even number of elements this is the upper of the two
    middle values.

    Returns
    -------
    list
        One entry per item of ``seq``.
    """
    stream = iter(seq)
    window = deque()
    ordered = []
    medians = []

    # Fill phase: the window grows until it holds ``window_size`` items.
    for value in islice(stream, window_size):
        window.append(value)
        insort(ordered, value)
        medians.append(ordered[len(window) // 2])

    # Sliding phase: drop the oldest item, insert the new one.
    middle = window_size // 2
    for value in stream:
        oldest = window.popleft()
        window.append(value)
        ordered.pop(bisect_left(ordered, oldest))
        insort(ordered, value)
        medians.append(ordered[middle])

    return medians


def remove_all_occurrences(elem, arr):
    """Return a NumPy array of ``arr`` without any element equal to ``elem``.

    The input is copied into a new list first, so a list passed by the
    caller is left untouched and any iterable (list, tuple, 1-D array) is
    accepted. Elements are matched with the semantics of ``list.remove``.

    Parameters
    ----------
    elem : object
        Value to remove.
    arr : iterable
        Source values.

    Returns
    -------
    numpy.ndarray
        ``np.array`` built from the remaining elements, in original order.
    """
    remaining = list(arr)

    while True:
        try:
            remaining.remove(elem)
        except ValueError:
            # list.remove raises ValueError once no occurrence is left.
            break

    return np.array(remaining)


def correct_cluster_indices(
    dv_value,
    inds,
    indl,
    ijs: list = None,
):
    """Interpolate both clusters onto every sample and combine them.

    Within each segment the values of ``dv_value`` at the ``inds`` samples
    and at the ``indl`` samples are linearly interpolated onto all sample
    indices of the segment. Outside the first/last sample of a cluster the
    nearest cluster value is used.

    Parameters
    ----------
    dv_value : array
        One-dimensional series to correct. It is not modified.
    inds, indl : array-like of int
        Sample indices of the two clusters. Lists and unsorted input are
        accepted.
    ijs : list, optional
        Segment boundaries. ``0`` is prepended and each pair of consecutive
        entries ``[ij0, ij1)`` is processed independently. When omitted the
        whole series is one segment.
        NOTE(review): samples after the last entry of ``ijs`` are left at
        zero unless ``ijs`` ends with ``len(dv_value)`` -- confirm callers
        pass it.

    Returns
    -------
    tuple
        ``(dv_averaged, delta)``: the mean of the two interpolated curves and
        their difference (``inds`` curve minus ``indl`` curve), both created
        with ``np.zeros_like(dv_value)``.

    Raises
    ------
    IndexError
        If a segment contains no sample of one of the clusters.
    """
    dv = np.asanyarray(dv_value)
    inds = np.asarray(inds)
    indl = np.asarray(indl)

    if ijs is not None:
        ijs = np.append([0], ijs)
        print("Performing piecewise correction")
    else:
        ijs = np.array([0, dv.shape[0]])

    indfull = np.arange(dv.shape[0])
    dv_averaged = np.zeros_like(dv)
    delta = np.zeros_like(dv)

    for k in range(ijs.shape[0] - 1):
        ij0 = int(ijs[k])
        ij1 = int(ijs[k + 1])

        # Cluster samples inside [ij0, ij1). Sorting guarantees that the
        # first/last values used as fill values are the outermost samples.
        pinds = np.sort(inds[(inds >= ij0) & (inds < ij1)])
        pindl = np.sort(indl[(indl >= ij0) & (indl < ij1)])

        if pinds.size == 0 or pindl.size == 0:
            raise IndexError(
                f"segment [{ij0}, {ij1}) has no samples in one of the clusters"
            )

        dvs = dv[pinds]
        dvl = dv[pindl]

        interp_dvs = interp1d(
            pinds, dvs, bounds_error=False, fill_value=(dvs[0], dvs[-1])
        )
        interp_dvl = interp1d(
            pindl, dvl, bounds_error=False, fill_value=(dvl[0], dvl[-1])
        )

        full_dvs = interp_dvs(indfull[ij0:ij1])
        full_dvl = interp_dvl(indfull[ij0:ij1])

        dv_averaged[ij0:ij1] = (full_dvs + full_dvl) / 2
        delta[ij0:ij1] = full_dvs - full_dvl

    return (dv_averaged, delta)


def _segment_cluster_indices(dvv, bounds):
    """Run ``dezip`` on each ``[bounds[k], bounds[k + 1])`` slice of ``dvv``.

    Returns ``(inds, indl, indn)`` as integer arrays expressed in indices of
    the full series.
    """
    inds = np.empty(0, int)
    indl = np.empty(0, int)
    indn = np.empty(0, int)

    for k in range(bounds.shape[0] - 1):
        ij0 = int(bounds[k])
        ij1 = int(bounds[k + 1])

        _, pinds, pindl, pindn = dezip(dvv[ij0:ij1])

        # dezip returns indices relative to the slice; shift them back.
        inds = np.append(inds, pinds + ij0)
        indl = np.append(indl, pindl + ij0)
        indn = np.append(indn, pindn + ij0)

    return inds, indl, indn


def determine_cluster_indices(
    source: int,
    tw: np.ndarray,
    fname: str,
    expid: int = 12,
    ijs: list = None,
    ow: bool = False,
):
    """Classify the samples of one source into two clusters and store them.

    For every receiver number from 1 to 14 other than ``source`` the pair
    ``"SS-RR"`` is loaded through ``mtsd.expHandler(expid).load_adp``,
    ``compute_dv_stretch(tws=tw)`` is run on it and ``dv.value`` is split
    into groups with ``dezip``. The per-receiver results are merged with
    ``combine_ind_joint`` and the merged dictionary is pickled to ``fname``.

    Pairs that raise an exception are reported on stdout and skipped.

    Parameters
    ----------
    source : int
        Source number; it is excluded from the receiver list.
    tw : np.ndarray
        Time window, forwarded as ``tws`` to ``compute_dv_stretch``.
    fname : str
        Path of the pickle file for the merged indices.
    expid : int, optional
        Experiment id passed to ``mtsd.expHandler``.
    ijs : list, optional
        Indices for the first sample after the stress change, if applicable.
        ``0`` is prepended and ``dezip`` is applied separately between
        consecutive entries. Defaults to None (whole series at once).
        NOTE(review): samples after the last entry of ``ijs`` are not
        classified unless ``ijs`` ends with the series length -- confirm
        callers pass it.
    ow : bool, optional
        Overwrite ``fname`` if it already exists.

    Returns
    -------
    dict
        Merged ``{"inds", "indl", "indn"}`` dictionary from
        ``combine_ind_joint``.

    Raises
    ------
    IndexError
        If no receiver pair could be processed.
    """
    receivers = [r for r in range(1, 15) if r != source]

    print(f"Source: {source:02d}")
    print(f"Time window: {tw}")

    eh = mtsd.expHandler(expid)

    bounds = None
    if ijs is not None:
        bounds = np.append([0], ijs)
        print("Performing piecewise bimodal correction")

    per_receiver = []
    for receiver in receivers:
        pair = f"{source:02d}-{receiver:02d}"

        try:
            print(f"Pair {pair}")

            adp = eh.load_adp(pair)
            adp.compute_dv_stretch(tws=tw)
            dvv = adp.dv.value

            # Without explicit boundaries the first pair that loads fixes
            # the single segment [0, N) used for all following pairs.
            if bounds is None:
                bounds = np.array([0, dvv.shape[0]])

            inds, indl, indn = _segment_cluster_indices(dvv, bounds)
        except Exception as err:
            # Best effort: keep going with the remaining pairs, but say why
            # this one was dropped. KeyboardInterrupt is not swallowed.
            print(f"Failed for {pair}: {err!r}")
            continue

        per_receiver.append({"inds": inds, "indl": indl, "indn": indn})

    if not per_receiver:
        raise IndexError(
            f"No receiver pair could be processed for source {source:02d}"
        )

    cindj = combine_ind_joint(per_receiver)

    if ow or not os.path.isfile(fname):
        with open(fname, "wb+") as fh:
            pickle.dump(cindj, fh)
        print("Indices file saved")
    else:
        print(
            "Indices were computed but not saved. Set ow to True to overwrite existing indices"
        )

    return cindj
