# coding: UTF-8
from __future__ import division
import numpy as np
import pylab as pl
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy.matlib
import os
import shutil
from tqdm import tqdm
from scipy import stats
from numpy import linalg as LA
from numpy.linalg import matrix_rank
import sys
from numpy.linalg import norm
import matplotlib.cm as cm
from sklearn.decomposition import PCA
from matplotlib import animation
import numba
from scipy import signal
from scipy.stats import ortho_group
from mpl_toolkits.mplot3d import Axes3D
# Keep text editable in exported vector figures: SVG text stays as text and
# PDF fonts are embedded as TrueType (Type 42) rather than Type 3.
mpl.rcParams['svg.fonttype'] = 'none'
mpl.rcParams['font.sans-serif'] = 'Arial'
mpl.rcParams['pdf.fonttype'] = 42
# Figure style preset (10 cm x 6 cm, 10 pt text).
# NOTE(review): this dict is never passed to rcParams.update() in this file,
# so it currently has no effect. 'text.fontsize' is not a valid rc key in
# current matplotlib (update() would raise KeyError); 'font.size' replaces it.
params = {'backend': 'ps',
          'axes.labelsize': 10,
          'font.size': 10,
          'legend.fontsize': 10,
          'xtick.labelsize': 10,
          'ytick.labelsize': 10,
          'text.usetex': False,
          'figure.figsize': [10 / 2.54, 6 / 2.54]}  # cm -> inches
N = 500 #network size
#np.random.seed(0)  # uncomment for reproducible runs

def derivative(x, f):
    '''
        Compute the derivative of a time serie
        Used for jPCA

        The slope at each sample is the least-squares slope
        (``scipy.stats.linregress``) of the 3-point window centred on it.
        The series is padded at both ends by repeating its first/last value
        (not circular), placed one bin outside the original x range.

        Parameters
        ----------
        x : 1-D array-like
            Sample positions. Only ``x[1] - x[0]`` is used as the bin size,
            so uniform spacing is presumably assumed (TODO confirm callers).
        f : 1-D array-like
            Series values, same length as ``x``.

        Returns
        -------
        np.ndarray
            ``len(f)`` window slopes, each divided by the bin size.
    '''
    from scipy.stats import linregress
    binsize = x[1] - x[0]
    # Repeat the edge values so every sample owns a full 3-point window.
    padded_f = np.hstack((f[0], f, f[-1]))
    padded_x = np.hstack((np.array([x[0] - binsize]), x,
                          np.array([x[-1] + binsize])))
    slopes = [linregress(padded_x[k:k + 3], padded_f[k:k + 3]).slope
              for k in range(len(f))]
    # NOTE(review): linregress already returns df/dx, so this extra division
    # rescales the result by 1/binsize (a no-op only when binsize == 1).
    # Kept for compatibility with existing callers -- confirm intent.
    return np.array(slopes) / binsize

def buildHMap(n, ):
    '''
        build the H mapping for a given n
        used for the jPCA

        H linearly maps the n*(n-1)/2 free parameters k of a skew-symmetric
        n x n matrix S onto the row-major flattening of S:
        ``(H @ k).reshape(n, n)`` is skew-symmetric, with parameter j placed
        at the j-th upper-triangle position (``np.triu_indices(n, 1)`` order)
        with sign +1, and at its transposed position with sign -1. Diagonal
        rows of H are all zero.

        Returns
        -------
        np.ndarray of float64, shape (n*n, n*(n-1)//2)
    '''
    rows, cols = np.triu_indices(n, 1)
    n_pairs = len(rows)
    H = np.zeros((n * n, n_pairs))
    pair_idx = np.arange(n_pairs)
    # Flat row-major index of (a, b) is a*n + b.
    H[rows * n + cols, pair_idx] = 1.0
    H[cols * n + rows, pair_idx] = -1.0
    return H

# Learning rate shared by learning() and learning_readout().
eps = 10**(-3)


## learning rule
@numba.njit(parallel=True, fastmath=True, nogil=True)
def learning(w, g_V_star, PSP_star, g_V_som):
    """Apply one masked delta-rule step to ``w`` in place and return it.

    For every row i and column l:
        w[i, l] += eps * ((g_V_som[i] - g_V_star[i]) * PSP_star[l])
        w[i, l] *= mask[i, l]

    ``eps`` and ``mask`` are module-level globals. Numba treats globals as
    compile-time constants, so their values at the first call are baked into
    the compiled function; rebinding them afterwards has no effect.
    """
    n_rows, n_cols = w.shape
    for row in numba.prange(n_rows):
        # Same error term for the whole row.
        err = g_V_som[row] - g_V_star[row]
        # Numba serializes a prange nested in another prange, so a plain
        # range here is equivalent.
        for col in range(n_cols):
            w[row, col] += eps * (err * PSP_star[col])
            w[row, col] *= mask[row, col]
    return w



@numba.njit(parallel=True, fastmath=True, nogil=True)
def learning_readout(w, y, r, z):
    """Apply one unmasked delta-rule step to the readout weights ``w``.

    For every row i and column l:
        w[i, l] += eps * ((z[i] - y[i]) * r[l])
    ``w`` is modified in place and returned. With the caller computing
    y = w @ r, this moves y toward the target z. ``eps`` is a module-level
    global that Numba freezes as a compile-time constant.
    """
    n_out, n_in = w.shape
    for out in numba.prange(n_out):
        err = z[out] - y[out]
        # Inner loop is serial either way; see the nested-prange rule.
        for inp in range(n_in):
            w[out, inp] += eps * (err * r[inp])
    return w


# Integration step and time constant of the leaky rate units.
dt = 1
tau = 10
# Plastic recurrent weights M (trained by learning()). Each entry is kept
# with probability p and drawn from N(0, (g * scale)^2), scale = 1/sqrt(p*N).
# mask marks kept entries; learning() multiplies every update by it.
M = np.zeros((N, N))
p=1
g=0.5
scale = 1.0 / np.sqrt(p * N)
mask = np.zeros((N,N))
# Row-major loop: one np.random.rand() per entry plus one np.random.randn()
# per kept entry. With p = 1 every entry is kept and mask is all ones.
for i in range(N):
    for j in range(N):
        if np.random.rand() < p:
            M[i, j] = (np.random.randn()) * g * scale
            mask[i,j] = 1

# NOTE(review): q is never used and r is overwritten by r = np.tanh(x)
# below; this QR call only advances the random number generator.
q, r = np.linalg.qr(2*(np.random.rand(N,3)-0.5))
#print(np.dot(q[:,0],q[:,0]))

# Number of readout channels (length of the target z and readout y).
n_tar = 5
#Win = ortho_group.rvs(dim=N)*np.sqrt(N)*2
# Feedback weights applied to the readout during training (FB = win @ y);
# entries are uniform in [-3, 3) divided by sqrt(n_tar).
win = 3*2*(np.random.rand(N,n_tar)-0.5)/np.sqrt(n_tar)#*2*(np.random.rand()<1)*0.5

# NOTE(review): g_L and g_d are not referenced elsewhere in this script.
g_L=1/tau
g_d=0.7

# Initial network state and rate.
x = np.random.randn(N)
r = np.tanh(x)

# Fixed (never passed to learning()) random recurrent weights:
# M_chaos has density 0.1 and gain 1.2; M_chaos2 (density 0.5, gain 2)
# is built but not referenced elsewhere in this script.
M_chaos = np.zeros((N,N))
M_chaos2 = np.zeros((N,N))
scale_chaos = 1.0 / np.sqrt(0.1* N)
# Two np.random.rand() draws per entry, plus a randn() for each accepted one;
# the draw order determines the resulting matrices.
for i in range(N):
    for j in range(N):
        if np.random.rand() < 0.1:
            M_chaos[i, j] = (np.random.randn()) * 1.2 * scale_chaos
        if np.random.rand() < 0.5:
            M_chaos2[i, j] = (np.random.randn()) * 2 / np.sqrt(0.5* N)
# NOTE(review): z1_list/z2_list, prediction, r_prediction, a, b, the *_std
# lists, chaos_std, w_chaos, tau_chaos and input_low_passed are not used by
# the active code below (only by commented-out lines, if at all).
z1_list = []
z2_list = []

prediction = np.zeros(N)
r_prediction = np.zeros(N)

a=1
b=1/3

M_std_list = []
M_chaos_std_list = []
x_chaos = np.random.randn(N)
# NOTE(review): probably meant np.tanh(x_chaos); r_chaos is unused anyway.
r_chaos = np.tanh(x)
#M_tot_std_list = []
# Number of training steps.
simtime_len = 100*1000#400*1000
chaos_std = np.std(M_chaos)
w_chaos = np.random.randn(N)/np.sqrt(N)
tau_chaos = 10

# Trainable linear readout weights (n_tar x N): y = w_readout @ r.
w_readout = np.random.randn(n_tar,N)/np.sqrt(N)#*0.2

input_low_passed = np.zeros(N)
# Current readout y and target z (overwritten every training step).
y = np.zeros(n_tar)
z = np.zeros(n_tar)
# Per-step Euclidean norm of the readout error y - z.
mean_error_list = np.zeros(simtime_len)
for i in tqdm(range(simtime_len), desc="[training]"):
    # Targets, with phase theta = 30 * i / (50 * 2 * pi):
    #   z[0] = sin(0.5 theta)           z[1] = sin(0.25 theta)
    #   z[2] = sin(0.5 theta) + sin(1.5 theta)
    #   z[3] = sin(0.5 theta) + sin(theta)
    #   z[4] = sin(0.5 theta) + sin(theta) + sin(1.5 theta)
    z[0] = 1*np.sin(0.5*1*30 * i / (50 * 2 * np.pi))#+1*np.sin(0.5*2*30 * i / (50 * 2 * np.pi))#+1*np.sin(0.5*3*30 * i / (50 * 2 * np.pi))#+1*np.sin(0.5*3*30 * i / (50 * 2 * np.pi))
    z[1] = 1 * np.sin(0.25 * 1 * 30 * i / (50 * 2 * np.pi))
    z[2] = 1*np.sin(0.5*1*30 * i / (50 * 2 * np.pi))+1*np.sin(0.5*3*30 * i / (50 * 2 * np.pi))
    z[3] = 1*np.sin(0.5*1*30 * i / (50 * 2 * np.pi))+1*np.sin(0.5*2*30 * i / (50 * 2 * np.pi))
    z[4] = 1*np.sin(0.5*1*30 * i / (50 * 2 * np.pi))+1*np.sin(0.5*2*30 * i / (50 * 2 * np.pi))+1*np.sin(0.5*3*30 * i / (50 * 2 * np.pi))
    #/(0.00000001+abs(1*np.sin(0.5*1*30 * i / (50 * 2 * np.pi))))
    # Recurrent input from the plastic (M) and fixed (M_chaos) weights,
    # both evaluated on the rate r from the previous step.
    M_term = np.dot(M, r)
    Chaos_term = np.dot(M_chaos, r)
    rec_term = Chaos_term + M_term
    # Leaky-integrator update: x <- (1 - dt/tau) * x + rec_term / tau.
    x = (1.0 - dt / tau) * x + (rec_term) / tau

    #r_prediction = (1.0 - dt / tau) * r_prediction + (r) / tau
    #prediction = np.dot(M, r)
    #input_low_passed = (1.0 - dt / tau) * input_low_passed + (np.dot(win , y) ) / tau
    r = np.tanh(x)
    # FB is built from the readout of the PREVIOUS step: y is only refreshed
    # on the next line.
    FB = (np.dot(win, y))
    y = np.dot(w_readout, r)

    # Euclidean norm of the current readout error.
    mean_error_list[i] = np.sqrt(np.dot(y-z,y-z))

    # prediction = (1.0 - dt / tau) * prediction + (np.dot(M,r))/tau


    # input_list[i] = input_low_passed[0]
    # Delta-rule step moving M @ r toward FB + M_chaos @ r (see learning()).
    # NOTE(review): M_term and Chaos_term were computed from r before the
    # tanh update above, while the presynaptic factor passed here is the
    # updated r -- confirm this one-step offset is intended.
    M = learning(M, M_term, r, FB+Chaos_term)
    # M*=mask  (learning() already multiplies by mask)

    # Delta-rule step moving the readout y toward the target z.
    w_readout = learning_readout(w_readout, y, r, z)

    #z1_list.append(z1)
    #z2_list.append(z2)


#x = np.random.randn(N)
#r = np.tanh(x)

# Testing: run the trained network freely (no feedback term, no learning),
# continuing from the state x, r left by training, and record the readouts.
x_chaos = np.random.randn(N)
# NOTE(review): probably meant np.tanh(x_chaos). x_chaos, r_chaos, Q,
# r_list and z1_list..z3_list are not used by the loop below.
r_chaos = np.tanh(x)
Q = np.random.randn(N,N)/np.sqrt(N)

simtime_len = 2*1000
r_list = np.zeros((N,simtime_len))

z1_list = np.zeros(simtime_len)
z2_list = np.zeros(simtime_len)
z3_list = np.zeros(simtime_len)
y_list = np.zeros((n_tar,simtime_len))
# M and M_chaos are fixed during testing: form their sum once instead of
# allocating a new N x N array on every step (results are identical).
M_total = M + M_chaos
for i in tqdm(range(simtime_len), desc="[testing]"):
    x = (1.0 - dt / tau) * x + (np.dot(M_total, r)) / tau
    r = np.tanh(x)
    y = np.dot(w_readout,r)

    y_list[:,i] = y


# One subplot per readout channel. squeeze=False keeps axs 2-D, so the
# indexing below also works when n_tar == 1 (plain subplots(1) returns a
# single Axes, not an array). The first 1000 testing steps are not plotted.
fig, axs = plt.subplots(n_tar, squeeze=False)
for i in range(n_tar):
    axs[i, 0].plot(y_list[i,1000:])
plt.savefig('multiple_readouts.pdf', format='pdf',dpi=350)
#np.savetxt("readouts_multi.txt",y_list)

