# coding: UTF-8
from __future__ import division
import numpy as np
import pylab as pl
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy.matlib
import os
import shutil
from tqdm import tqdm
from scipy import stats
from numpy import linalg as LA
from numpy.linalg import matrix_rank
import sys
from numpy.linalg import norm
import matplotlib.cm as cm
from sklearn.decomposition import PCA
from matplotlib import animation
import numba
from scipy.stats import ortho_group
from mpl_toolkits.mplot3d import Axes3D
# Global matplotlib output settings.
mpl.rcParams['svg.fonttype'] = 'none'  # keep SVG text as text, not paths
mpl.rcParams['font.sans-serif'] = 'Arial'
mpl.rcParams['pdf.fonttype'] = 42  # embed TrueType (Type 42) fonts in PDFs
# Figure style presets. NOTE(review): this dict is not applied anywhere in this
# file (no rcParams.update(params) is visible). 'font.size' replaces the
# obsolete 'text.fontsize' key, which current matplotlib rejects with KeyError.
params = {'backend': 'ps',
          'axes.labelsize': 10,
          'font.size': 10,
          'legend.fontsize': 10,
          'xtick.labelsize': 10,
          'ytick.labelsize': 10,
          'text.usetex': False,
          'figure.figsize': [10 / 2.54, 6 / 2.54]}  # 10 cm x 6 cm in inches
N = 800  # network size (number of rate units)
np.random.seed(0)  # every np.random draw below depends on this seed

def derivative(x, f):
    '''
        Compute the derivative of a time series (used for jPCA).

        For every sample, a least-squares line is fitted through the 3-point
        window centred on it, and its slope is taken. The series is padded
        (not circularly) by repeating f[0] one bin before x[0] and f[-1] one
        bin after x[-1].

        Parameters
        ----------
        x : 1-D array-like
            Sample positions. The padding step is x[1] - x[0].
        f : 1-D array-like, same length as x
            Values of the time series.

        Returns
        -------
        ndarray of shape (len(f),)
            The per-window slopes divided by the bin size.

        Raises
        ------
        ValueError
            If the x values inside some window are all identical
            (e.g. a zero bin size).
    '''
    x = np.asarray(x, dtype=float)
    f = np.asarray(f, dtype=float)
    binsize = x[1] - x[0]
    tmpf = np.hstack((f[0], f, f[-1]))  # not circular
    tmpx = np.hstack((x[0] - binsize, x, x[-1] + binsize))
    # windows[k] = (k, k+1, k+2): the padded samples centred on f[k]
    windows = np.arange(len(f))[:, None] + np.arange(3)
    wx = tmpx[windows]
    wf = tmpf[windows]
    # Ordinary least-squares slope per window: sum(dx*df) / sum(dx*dx)
    dx = wx - wx.mean(axis=1, keepdims=True)
    df = wf - wf.mean(axis=1, keepdims=True)
    sxx = (dx * dx).sum(axis=1)
    if np.any(sxx == 0):
        raise ValueError("Cannot calculate a linear regression if all x values are identical")
    slopes = (dx * df).sum(axis=1) / sxx
    # NOTE(review): `slopes` is already df/dx. The extra division by binsize
    # reproduces the original implementation, but it is only a true derivative
    # when binsize == 1. Confirm the intended scaling before relying on units.
    return slopes / binsize

def buildHMap(n):
    '''
        Build the H mapping for a given n (used for jPCA).

        H linearly maps the n*(n-1)//2 free parameters k of a skew-symmetric
        n x n matrix to its row-major flattening: the strict upper triangle,
        taken in np.triu_indices(n, 1) order, holds +k and the mirrored
        lower triangle holds -k. That is, (H @ k).reshape(n, n) is
        skew-symmetric.

        Parameters
        ----------
        n : int
            Matrix dimension.

        Returns
        -------
        ndarray of float64, shape (n*n, n*(n-1)//2)
    '''
    rows, cols = np.triu_indices(n, 1)
    n_params = len(rows)
    H = np.zeros((n * n, n_params))
    param_idx = np.arange(n_params)
    # Upper-triangle entry (r, c) sits at flat index r*n + c and carries +k[p] ...
    H[rows * n + cols, param_idx] = 1.0
    # ... and its mirror (c, r) at flat index c*n + r carries -k[p].
    H[cols * n + rows, param_idx] = -1.0
    return H

# Learning rate shared by the two plasticity rules below. Both numba-compiled
# functions read it as a module global.
eps = 0.5 * 10 ** (-3)
# --- learning rules ---
@numba.njit(parallel=True, fastmath=True, nogil=True)
def learning(w, g_V_star, PSP_star, g_V_som):
    """Apply one plasticity step to the weight matrix `w`, in place.

    Each entry is updated as
        w[post, pre] += eps * (g_V_som[post] - g_V_star[post]) * PSP_star[pre]
    and is then multiplied by mask[post, pre].

    Parameters
    ----------
    w : 2-D array, modified in place and also returned.
    g_V_star : 1-D array indexed by row (post) of w.
    PSP_star : 1-D array indexed by column (pre) of w.
    g_V_som : 1-D array indexed by row (post) of w.

    NOTE: `eps` and `mask` are module globals. Numba treats globals as
    compile-time constants, so their values at the first call are baked in.
    """
    n_rows, n_cols = w.shape
    # Rows run in parallel; a prange nested inside a prange is serialized by
    # numba anyway, so the inner loop is written as a plain range.
    for post in numba.prange(n_rows):
        for pre in range(n_cols):
            delta = (g_V_som[post] - g_V_star[post]) * PSP_star[pre]
            w[post, pre] += eps * delta
            w[post, pre] *= mask[post, pre]
    return w



@numba.njit(parallel=True, fastmath=True, nogil=True)
def learning_readout(w, y, r, z):
    """Apply one delta-rule step to a linear readout, in place.

    Each weight is updated as
        w[i] += eps * (z - y) * r[i]
    for every i in range(len(w)).

    Parameters
    ----------
    w : 1-D array of readout weights, modified in place and also returned.
    y : scalar, the current readout output.
    r : 1-D array of presynaptic rates; it must have at least len(w) entries.
    z : scalar, the target value.

    The loop bound comes from `w` itself rather than the global `N`. Numba
    freezes globals at compile time and performs no bounds checking, so a
    global bound could silently skip weights or write out of range whenever
    len(w) != N. `eps` is still read as a module global.
    """
    for i in numba.prange(w.shape[0]):
        delta = (-(y) + z) * r[i]
        w[i] += eps * delta

    return w
# --- Recurrent network construction --------------------------------------------
dt = 1  # Euler integration step for the activation x
tau = 10  # time constant of the leaky activation update x <- (1 - dt/tau) x + input/tau
M = np.zeros((N, N))  # plastic recurrent weights (the matrix passed to `learning`)
p=1  # connection probability of M
g=0.5  # scale factor of M's initial Gaussian weights
scale = 1.0 / np.sqrt(p * N)
mask = np.zeros((N,N))  # 1 where M has a connection; read as a global by `learning`
# NOTE: rand() and randn() calls are interleaved per element, so the RNG draw
# order matters for reproducibility. Vectorizing this loop would change every
# later random draw. With p=1, rand() < p always holds (rand is in [0, 1)).
for i in range(N):
    for j in range(N):
        if np.random.rand() < p:
            M[i, j] = (np.random.randn()) * g * scale
            mask[i,j] = 1


# Weights through which the readouts y1..y3 enter the plasticity target. Each
# is uniform on [-sqrt(3), sqrt(3)), i.e. zero mean and unit variance.
win1 = 6*(np.random.rand(N)-0.5)/np.sqrt(3)
win2 = 6*(np.random.rand(N)-0.5)/np.sqrt(3)
win3 = 6*(np.random.rand(N)-0.5)/np.sqrt(3)

# NOTE(review): g_L and g_d are not referenced anywhere else in this file.
g_L=1/tau
g_d=0.7

# Initial network state: activation x and rate r = tanh(x).
x = np.random.randn(N)
r = np.tanh(x)

# Fixed sparse random connectivity (density 0.1, weights ~ N(0, 1.2**2/(0.1*N))).
# It is never passed to `learning`; the gain 1.2 > 1 is presumably chosen for
# chaotic dynamics. M_chaos2 (density 0.5) is not used elsewhere in this file.
M_chaos = np.zeros((N,N))
M_chaos2 = np.zeros((N,N))
scale_chaos = 1.0 / np.sqrt(0.1* N)

# As for M, the interleaved scalar RNG draws make this loop order-sensitive.
for i in range(N):
    for j in range(N):
        if np.random.rand() < 0.1:
            M_chaos[i, j] = (np.random.randn()) * 1.2 * scale_chaos

        if np.random.rand() < 0.5:
            M_chaos2[i, j] = (np.random.randn()) * 2 / np.sqrt(0.5* N)

# NOTE(review): the names below up to `tau_chaos` are not read by the training
# loop. x_chaos and w_chaos still consume RNG draws, so removing them would
# change the readout initialisation. r_chaos is computed from x, not x_chaos;
# confirm this is intended.
z1_list = []
z2_list = []

prediction = np.zeros(N)
r_prediction = np.zeros(N)

a=1
b=1/3

M_std_list = []
M_chaos_std_list = []
x_chaos = np.random.randn(N)
r_chaos = np.tanh(x)

simtime_len = 15000*1000  # number of training steps
chaos_std = np.std(M_chaos)
w_chaos = np.random.randn(N)/np.sqrt(N)
tau_chaos = 10

# Three linear readouts, one per Lorenz coordinate (trained by learning_readout).
w_readout1 = np.random.randn(N)/np.sqrt(N)#*0.2
w_readout2 = np.random.randn(N)/np.sqrt(N)#*0.2
w_readout3 = np.random.randn(N)/np.sqrt(N)#*0.2
input_low_passed = np.zeros(N)
y1 = 0
y2 = 0
y3 = 0

#output_list_x = np.zeros(simtime_len)
#output_list_y = np.zeros(simtime_len)
#output_list_z = np.zeros(simtime_len)
#target_list_x = np.zeros(simtime_len)
#target_list_y = np.zeros(simtime_len)
#target_list_z = np.zeros(simtime_len)
# Initial condition of the Lorenz target system.
x_L, y_L, z_L = (0., 1., 1.05)
# NOTE(review): input_list, first_term and second_term (simtime_len float64
# each, ~120 MB apiece) are not read or written by the training loop.
input_list = np.zeros(simtime_len)

first_term = np.zeros(simtime_len)
second_term = np.zeros(simtime_len)
# Readout outputs and scaled targets for the last 50*1000 training steps.
y_list_learning = np.zeros((3,50*1000))
target_list_learning = np.zeros((3,50*1000))
# --- Training phase ------------------------------------------------------------
# Each step advances the Lorenz target, updates the network state, applies the
# plasticity rule to M, and then trains the three linear readouts.
# NOTE: statement order matters. `learning` receives the rates r from *before*
# this step's tanh update (the same r that produced M_term), while the readouts
# are evaluated and trained on the *updated* r.
for i in tqdm(range(simtime_len), desc="[training]"):

    # Lorenz system (sigma=10, rho=28, beta=8/3): one explicit-Euler step of 0.007.
    dot_x_L = 10 * (y_L - x_L)
    dot_y_L = 28 * x_L - y_L - x_L * z_L
    dot_z_L = x_L * y_L - 8/3 * z_L

    x_L += dot_x_L*0.007
    y_L += dot_y_L*0.007
    z_L += dot_z_L*0.007

    # Recurrent input, split into the plastic (M) and fixed (M_chaos) parts.
    M_term = np.dot(M,r)
    Chaos_term = np.dot(M_chaos,r)

    # Leaky integration of the activation.
    x = (1.0 - dt / tau) * x + (M_term + Chaos_term)/tau

    #r_prediction = (1.0 - dt / tau) * r_prediction + (r) / tau

    #input_low_passed = (1.0 - dt / tau) * input_low_passed + (win1 * y1+win2 * y2+win3 * y3) / tau

    # Plasticity: M[i, l] += eps * (target[i] - M_term[i]) * r[l], where
    # target = Chaos_term + win*y and y1..y3 are the readouts from the previous
    # step (0 on the first step). `learning` edits M in place and returns it.
    M = learning(M, M_term, r,(Chaos_term+win1 * y1+win2 * y2+win3 * y3))

    # Post-update rates and the three linear readouts.
    r = np.tanh(x)
    y1 = np.dot(w_readout1,r)
    y2 = np.dot(w_readout2,r)
    y3 = np.dot(w_readout3,r)

    # Record the last 50*1000 steps: readout outputs and Lorenz state / 10.
    if i>simtime_len-50*1000-1:
        y_list_learning[0,i-(simtime_len-50*1000)] = y1
        y_list_learning[1,i-(simtime_len-50*1000)] = y2
        y_list_learning[2,i-(simtime_len-50*1000)] = y3
        target_list_learning[0,i-(simtime_len-50*1000)] = x_L/10
        target_list_learning[1,i-(simtime_len-50*1000)] = y_L/10
        target_list_learning[2,i-(simtime_len-50*1000)] = z_L/10
    #prediction = (1.0 - dt / tau) * prediction + (np.dot(M,r))/tau


    #input_list[i] = input_low_passed[0]

    #M*=mask

    # Delta rule for each readout toward its Lorenz coordinate scaled by 0.1
    # (the same scaling as the recorded targets).
    w_readout1 = learning_readout(w_readout1, y1, r, x_L*0.1)
    w_readout2 = learning_readout(w_readout2, y2, r, y_L*0.1)
    w_readout3 = learning_readout(w_readout3, y3, r, z_L*0.1)
# Save the readout outputs and the scaled Lorenz targets recorded during the
# last 50*1000 training steps (3 rows, one per coordinate).
np.savetxt("lorenz_no_filter_learning2.txt", y_list_learning)
np.savetxt("lorenz_target_learning2.txt", target_list_learning)
# Disabled plotting code, kept as inert string literals and comments. It refers
# to target_list_x/y/z and output_list_x/y/z, which are not defined in this
# file (their allocations are commented out), so it would raise NameError if
# re-enabled as-is. The recorded data lives in target_list_learning and
# y_list_learning.
"""
pl.figure()
pl.plot(target_list_x)
pl.plot(target_list_y)
pl.plot(target_list_z)
pl.show()
"""
#pl.figure()
#pl.plot(target_list_x)
#pl.plot(output_list_x)
#pl.show()
"""
fig = plt.figure(figsize=(4,4))
ax1 = fig.add_subplot(1, 1, 1,projection="3d")
ax1.set_box_aspect((1, 1, 1))

pl.plot(target_list_x,target_list_y,target_list_z,c="steelblue",alpha = 0.8,lw=1)


pl.show()


fig = plt.figure(figsize=(4,4))
ax1 = fig.add_subplot(1, 1, 1,projection="3d")
ax1.set_box_aspect((1, 1, 1))

pl.plot(output_list_x,output_list_y,output_list_z,c="steelblue",alpha = 0.8,lw=1)


pl.show()
"""
# --- Testing phase setup ---------------------------------------------------------
# These two commented lines would re-initialise the network state. Left
# disabled, the test run continues from the x and r reached at the end of
# training.
##x = np.random.randn(N)
##r = np.tanh(x)

# NOTE(review): x_chaos, r_chaos and Q are not read by the testing loop. They
# are kept because they draw from the global RNG. r_chaos is computed from x
# rather than x_chaos; confirm this is intended.
x_chaos = np.random.randn(N)
r_chaos = np.tanh(x)
Q = np.random.randn(N,N)/np.sqrt(N)

simtime_len = 50*1000  # number of testing steps

# NOTE(review): z1_list..z3_list are not used by the testing loop either.
z1_list = np.zeros(simtime_len)
z2_list = np.zeros(simtime_len)
z3_list = np.zeros(simtime_len)
y_list = np.zeros((3,simtime_len))  # readout outputs y1..y3 per step
target_list = np.zeros((3,simtime_len))  # Lorenz state / 10 per step

# The weights are not updated during testing, so M + M_chaos is loop-invariant.
# Form it once instead of adding two N x N matrices on every step.
M_total = M + M_chaos
for i in tqdm(range(simtime_len), desc="[testing]"):
    # Lorenz system (sigma=10, rho=28, beta=8/3): one explicit-Euler step of 0.007.
    dot_x_L = 10 * (y_L - x_L)
    dot_y_L = 28 * x_L - y_L - x_L * z_L
    dot_z_L = x_L * y_L - 8/3 * z_L

    x_L += dot_x_L*0.007
    y_L += dot_y_L*0.007
    z_L += dot_z_L*0.007

    # Reference trajectory, scaled by 1/10.
    target_list[0,i] = x_L/10
    target_list[1,i] = y_L/10
    target_list[2,i] = z_L/10

    # Leaky rate update driven only by the frozen recurrent weights.
    x = (1.0 - dt / tau) * x + (np.dot(M_total, r)) / tau
    r = np.tanh(x)
    y1 = np.dot(w_readout1,r)
    y2 = np.dot(w_readout2,r)
    y3 = np.dot(w_readout3,r)

    y_list[0,i] = y1
    y_list[1,i] = y2
    y_list[2,i] = y3
# Disabled plotting code kept as an inert string literal. It refers to
# y_list1/2/3, which are not defined (the outputs are stored in y_list rows).
"""
pl.figure()
pl.plot(y_list1)
pl.plot(y_list2)
pl.plot(y_list3)
pl.show()
"""
# 3-D trajectory of the three readouts (y1, y2, y3) recorded during testing.
fig = plt.figure(figsize=(4, 4))
ax1 = fig.add_subplot(1, 1, 1, projection="3d")
ax1.set_box_aspect((1, 1, 1))  # equal proportions for the x/y/z box

ax1.plot(y_list[0, :], y_list[1, :], y_list[2, :],
         c="steelblue", alpha=0.8, lw=1)

plt.show()

