# coding: UTF-8
from __future__ import division
import numpy as np
import pylab as pl
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy.matlib
import os
import shutil
from tqdm import tqdm
from scipy import stats
from numpy import linalg as LA
from numpy.linalg import matrix_rank
import sys
from numpy.linalg import norm
import matplotlib.cm as cm
from sklearn.decomposition import PCA
from matplotlib import animation
import numba
from scipy.stats import ortho_group
from mpl_toolkits.mplot3d import Axes3D
# Matplotlib font handling for exported figures (SVG / PDF) and the
# sans-serif family used for text.
mpl.rcParams.update({
    'svg.fonttype': 'none',
    'font.sans-serif': 'Arial',
    'pdf.fonttype': 42,
})

# NOTE(review): `params` is not passed to rcParams anywhere in the visible
# code, so these settings currently have no effect. 'text.fontsize' may not be
# a valid rcParam key in current Matplotlib -- verify before applying it.
# The figure size is presumably 10 cm x 6 cm converted to inches.
params = {
    'backend': 'ps',
    'axes.labelsize': 10,
    'text.fontsize': 10,
    'legend.fontsize': 10,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'text.usetex': False,
    'figure.figsize': [10 / 2.54, 6 / 2.54],
}

N = 800  # network size
np.random.seed(4)  # fixed seed: every random draw below is reproducible

def derivative(x, f):
    '''
        Estimate the derivative of a time series f sampled at positions x.
        Used for jPCA.

        For every sample, a least-squares line is fitted through the sample
        and its two neighbours. The series is padded by repeating its first
        and last value (not circular) and x is extended by one bin on each
        side, so the result has one entry per sample of f. The fitted slopes
        are divided by the bin size x[1] - x[0] before being returned.

        NOTE(review): linregress already returns a slope per unit of x, so
        the final division by the bin size rescales it a second time --
        confirm this is intended.
    '''
    from scipy.stats import linregress
    step = x[1] - x[0]
    padded_f = np.hstack((f[0], f, f[-1]))  # edge values repeated, not circular
    padded_x = np.hstack((np.array([x[0] - step]), x, np.array([x[-1] + step])))
    # Element 0 of the linregress result is the fitted slope.
    window_slopes = [
        linregress(padded_x[start:start + 3], padded_f[start:start + 3])[0]
        for start in range(len(f))
    ]
    return np.array(window_slopes) / step

def buildHMap(n, ):
    '''
        Build the H mapping for a given n.
        Used for the jPCA.

        H has shape (n*n, n*(n-1)/2). Column c corresponds to the c-th
        strictly-upper-triangular position (row, col) of an n x n matrix,
        in the order given by np.triu_indices(n, 1). That column holds +1
        at the row-major flat index of (row, col) and -1 at the flat index
        of (col, row), so (H @ k).reshape(n, n) is the skew-symmetric
        matrix whose upper-triangular entries are k.

        For n <= 1 there are no free parameters and an array with zero
        columns is returned.
    '''
    upper_rows, upper_cols = np.triu_indices(n, 1)
    n_free = len(upper_rows)
    H = np.zeros((n * n, n_free))
    column = np.arange(n_free)
    # positive entry at (row, col), negative entry at the mirrored (col, row)
    H[upper_rows * n + upper_cols, column] = 1.0
    H[upper_cols * n + upper_rows, column] = -1.0
    return H

# Learning rate used by both learning() and learning_readout() below.
eps = 10**(-3)#*0.5
## learning rule
@numba.njit(parallel=True, fastmath=True, nogil=True)
def learning(w, g_V_star, PSP_star, g_V_som):
    '''
        Error-driven update of the 2-D weight array w, applied in place.

        Every entry is updated as
            w[i, l] <- (w[i, l] + eps * (g_V_som[i] - g_V_star[i]) * PSP_star[l]) * mask[i, l]
        where eps and mask are read from module scope.

        Returns w (the array that was passed in, after the update).

        NOTE(review): Numba is documented to treat referenced globals as
        compile-time constants, so later changes to eps or mask would not
        be seen by the compiled function -- confirm if either is modified.
    '''
    n_rows, n_cols = w.shape
    for row in numba.prange(n_rows):
        for col in numba.prange(n_cols):
            step = (-(g_V_star[row]) + g_V_som[row]) * PSP_star[col]
            w[row, col] += eps * step
            # zero out entries where no connection exists
            w[row, col] *= mask[row, col]

    return w

@numba.njit(parallel=True, fastmath=True, nogil=True)
def learning_readout(w, y, r, z):
    '''
        Delta-rule update of the readout weight vector w, applied in place.

        Every entry is updated as  w[i] <- w[i] + eps * (z - y) * r[i],
        where y is the current readout, z the target, r the vector the
        readout was computed from, and eps is read from module scope.

        The loop runs over the length of w itself rather than the module
        global N, so the function does not depend on the network size
        constant matching the array it is given.

        Returns w (the array that was passed in, after the update).
    '''
    for i in numba.prange(w.shape[0]):
        delta = (-(y) + z) * r[i]
        w[i] += eps * delta
    return w

# ---------------------------------------------------------------------------
# Network, input and initial-state setup.
# NOTE: the random generator is seeded once earlier in the file, so the ORDER
# of the np.random calls below determines every value. Removing or reordering
# any draw (even for variables that look unused) changes all later draws.
# ---------------------------------------------------------------------------
# dt and tau appear in the leaky state update x <- (1 - dt/tau) * x + drive / tau.
dt = 1
tau = 10
# Recurrent weight matrix that is modified by learning() during training.
M = np.zeros((N, N))
# p: connection probability for M; g: gain of the initial random weights.
p=1
g=0.5
scale = 1.0 / np.sqrt(p * N)
# mask is 1 wherever M received a connection; learning() multiplies by it.
mask = np.zeros((N,N))
# Element-wise draws: one uniform draw decides the connection, one normal draw
# sets its weight.
for i in range(N):
    for j in range(N):
        if np.random.rand() < p:
            M[i, j] = (np.random.randn()) * g * scale
            mask[i,j] = 1

# q is not referenced elsewhere in the visible code and r is overwritten by
# `r=r01` below; the call still consumes N*3 uniform draws.
q, r = np.linalg.qr(2*(np.random.rand(N,3)-0.5))

#Win = ortho_group.rvs(dim=N)*np.sqrt(N)*2
# win multiplies the readout y inside the learning target (win * y + Chaos_term).
win = 3*2*(np.random.rand(N)-0.5)#*2*(np.random.rand()<1)*0.5


# One input vector per control condition; each is added to the drive when its
# control_input flag is 1.
w_control = 1*(np.random.rand(N)-0.5)#*2*(np.random.rand()<1)*0.5
w_control2 = 1*(np.random.rand(N)-0.5)#*2*(np.random.rand()<1)*0.5
w_control3 = 1*(np.random.rand(N)-0.5)#*2*(np.random.rand()<1)*0.5
w_control4 = 1*(np.random.rand(N)-0.5)#*2*(np.random.rand()<1)*0.5
# g_L and g_d are not referenced elsewhere in the visible code.
g_L=1/tau
g_d=0.7

# Initial state (x01) and its rate (r01); every training repetition and every
# test run restarts from this pair.
x01 = np.random.randn(N)
r01 = np.tanh(x01)

# Second initial state; not referenced elsewhere in the visible code, but the
# draw advances the random stream.
x02 = np.random.randn(N)
r02 = np.tanh(x02)

x=x01
r=r01

# M_chaos: fixed random matrix (entries set with probability 0.1) whose product
# with r is added to the recurrent input; it is not passed to learning().
# M_chaos2: second random matrix (probability 0.5); not referenced elsewhere in
# the visible code, but its draws advance the random stream.
M_chaos = np.zeros((N,N))
M_chaos2 = np.zeros((N,N))
scale_chaos = 1.0 / np.sqrt(0.1* N)
for i in range(N):
    for j in range(N):
        if np.random.rand() < 0.1:
            M_chaos[i, j] = (np.random.randn()) * 1.2 * scale_chaos
        if np.random.rand() < 0.5:
            M_chaos2[i, j] = (np.random.randn()) * 2 / np.sqrt(0.5* N)

# Placeholders; both names are re-created as arrays before the test runs.
z1_list = []
z2_list = []

# a, b and the two *_std_list names are not referenced elsewhere in the
# visible code.
a=1
b=1/3

M_std_list = []
M_chaos_std_list = []

# Number of time steps per training repetition.
simtime_len = 130#0
# chaos_std, w_chaos and tau_chaos are not referenced elsewhere in the visible
# code (w_chaos still consumes N normal draws).
chaos_std = np.std(M_chaos)
w_chaos = np.random.randn(N)/np.sqrt(N)
tau_chaos = 10

# Readout weights trained by learning_readout(); the readout is y = w_readout . r.
w_readout = np.random.randn(N)/np.sqrt(N)#*0.2

# input_low_passed, prediction and r_prediction are not referenced elsewhere
# in the visible code.
input_low_passed = np.zeros(N)
prediction = np.zeros(N)
r_prediction = np.zeros(N)
# Initial readout value; it is used in the first learning() call before the
# first readout is computed.
y = 0
# Number of training repetitions (cycled over the four control conditions).
n_rep = 4000
# target_list is not referenced elsewhere in the visible code.
target_list = np.zeros(simtime_len*n_rep)
# Twelve uniform coefficients in [-1, 1): three per control condition, used as
# amplitudes of the first three sine components of that condition's target.
coeef = 2*(np.random.rand(12)-0.5)
for jj in tqdm(range(n_rep), desc="[training]"):

    # Repetitions cycle through the four control conditions; exactly one
    # control flag is 1 on each repetition.
    condition = jj % 4
    control_input1 = 1 if condition == 0 else 0
    control_input2 = 1 if condition == 1 else 0
    control_input3 = 1 if condition == 2 else 0
    control_input4 = 1 if condition == 3 else 0

    # Every repetition restarts from the same initial state.
    x=x01
    r=r01

    # Condition k uses coeef[3k], coeef[3k+1], coeef[3k+2] as amplitudes of
    # the first three sine components; the fourth always has amplitude 1.
    amp1, amp2, amp3 = coeef[3 * condition:3 * condition + 3]

    for i in range(simtime_len):
        # Target for this condition at time step i.
        z = (amp1 * np.sin(0.5 * 1 * 30 * i / (50 * 2 * np.pi))
             + amp2 * np.sin(0.5 * 2 * 30 * i / (50 * 2 * np.pi))
             + amp3 * np.sin(0.5 * 3 * 30 * i / (50 * 2 * np.pi))
             + 1 * np.sin(0.5 * 4 * 30 * i / (50 * 2 * np.pi)))

        # Recurrent input: learned part (M) plus fixed random part (M_chaos).
        M_term = np.dot(M,r)
        Chaos_term = np.dot(M_chaos,r)
        rec_term = M_term + Chaos_term
        x = (1.0 - dt / tau) * x + (rec_term+control_input1*w_control + control_input2*w_control2+ control_input3*w_control3+ control_input4*w_control4)/tau

        # M is trained so that M.r approaches win * y + M_chaos.r, using the
        # readout y from the previous step and the rate r from before the
        # state update.
        M = learning(M,M_term, r, win * y + Chaos_term)
        r = np.tanh(x)
        y = np.dot(w_readout,r)

        # Readout weights are trained so that y approaches the target z.
        w_readout = learning_readout(w_readout, y, r, z)

# ---------------------------------------------------------------------------
# Testing: run the trained network once per control condition, without any
# weight updates, recording the readout and the corresponding target.
# ---------------------------------------------------------------------------
x = x01#np.random.randn(N)
r = r01#np.tanh(x)

# x_chaos, r_chaos and Q are not referenced elsewhere in the visible code.
# NOTE(review): r_chaos is computed from x, not x_chaos -- confirm intent.
x_chaos = np.random.randn(N)
r_chaos = np.tanh(x)
Q = np.random.randn(N,N)/np.sqrt(N)

simtime_len = 130


# Targets (z*_list) and readouts (y_list*) for the four conditions.
z1_list = np.zeros(simtime_len)
z2_list = np.zeros(simtime_len)
z3_list = np.zeros(simtime_len)
z4_list = np.zeros(simtime_len)
y_list = np.zeros(simtime_len)
y_list2 = np.zeros(simtime_len)
y_list3 = np.zeros(simtime_len)
y_list4 = np.zeros(simtime_len)

# Neither M nor M_chaos changes during testing, so their sum is formed once
# instead of at every time step.
W_test = M + M_chaos

test_runs = (
    (y_list, z1_list),
    (y_list2, z2_list),
    (y_list3, z3_list),
    (y_list4, z4_list),
)
for condition, (outputs, targets) in enumerate(test_runs):
    # Each run restarts from the same initial state with exactly one
    # control flag set.
    x = x01
    r = r01
    control_input1 = 1 if condition == 0 else 0
    control_input2 = 1 if condition == 1 else 0
    control_input3 = 1 if condition == 2 else 0
    control_input4 = 1 if condition == 3 else 0

    # Condition k uses coeef[3k], coeef[3k+1], coeef[3k+2] as amplitudes of
    # the first three sine components; the fourth always has amplitude 1.
    amp1, amp2, amp3 = coeef[3 * condition:3 * condition + 3]

    for i in tqdm(range(simtime_len), desc="[testing]"):
        rec_term = np.dot(W_test, r)
        x = (1.0 - dt / tau) * x + (rec_term+control_input1*w_control + control_input2*w_control2+ control_input3*w_control3+ control_input4*w_control4) / tau
        r = np.tanh(x)
        y = np.dot(w_readout,r)

        outputs[i] = y
        targets[i] = (amp1 * np.sin(0.5 * 1 * 30 * i / (50 * 2 * np.pi))
                      + amp2 * np.sin(0.5 * 2 * 30 * i / (50 * 2 * np.pi))
                      + amp3 * np.sin(0.5 * 3 * 30 * i / (50 * 2 * np.pi))
                      + 1 * np.sin(0.5 * 4 * 30 * i / (50 * 2 * np.pi)))
# One panel per control condition: readout (solid) against target (dashed).
fig = plt.figure(figsize=(12,4))
panels = (
    (y_list, z1_list),
    (y_list2, z2_list),
    (y_list3, z3_list),
    (y_list4, z4_list),
)
for panel_index, (output_trace, target_trace) in enumerate(panels, start=1):
    pl.subplot(1, 4, panel_index)
    pl.plot(output_trace, c='steelblue')
    pl.plot(target_trace, c='orangered', ls='--')
plt.savefig('readout_controlled_outputs.pdf', format='pdf',dpi=350)


# The string literal below is disabled analysis code (stacking the outputs and
# targets, and inspecting the singular values of M); it is never executed.
"""
output_list_controled = np.zeros((4,simtime_len))
target_list_controled = np.zeros((4,simtime_len))

output_list_controled[0,:] = y_list
output_list_controled[1,:] = y_list2
output_list_controled[2,:] = y_list3
output_list_controled[3,:] = y_list4

target_list_controled[0,:] = z1_list
target_list_controled[1,:] = z2_list
target_list_controled[2,:] = z3_list
target_list_controled[3,:] = z4_list

#np.savetxt("output_list_controled.txt",output_list_controled)
#np.savetxt("target_list_controled.txt",target_list_controled)

u, s, v = np.linalg.svd(M)
#p = s/np.linalg.norm(s, ord=1)
#M_trained_approx_rank = np.exp(-np.sum(p*np.log(p)))
M_trained_approx_rank = np.sum(s>(np.mean(s)+np.std(s)))#(np.sum(s)**2)/np.sum(s**2)
#val = s/(np.sum(s))
#M_trained_approx_rank = np.prod(val**(-val))
pl.figure()
pl.bar(np.arange(N),s)
pl.show()

print(M_trained_approx_rank)
print(matrix_rank(M))
"""