# coding: UTF-8
from __future__ import division
import numpy as np
import pylab as pl
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy.matlib
import os
from scipy.stats import gaussian_kde
import copy
import shutil
from tqdm import tqdm
import scipy.stats as st
from scipy.stats._continuous_distns import _distn_names
from scipy.optimize import curve_fit
import sklearn.decomposition
from scipy import stats
from numpy import linalg as LA
from numpy.linalg import matrix_rank
import sys
from numpy.linalg import norm
import matplotlib.cm as cm
from sklearn.decomposition import PCA
from matplotlib import animation
import numba
from scipy.stats import ortho_group
from mpl_toolkits.mplot3d import Axes3D
from scipy import signal
# Matplotlib output settings applied to every figure saved by this script:
# SVG text is not converted to paths, Arial is the sans-serif face, and PDF
# output uses font type 42.
mpl.rcParams.update({
    'svg.fonttype': 'none',
    'font.sans-serif': 'Arial',
    'pdf.fonttype': 42,
})

# NOTE(review): this dictionary is only defined here; the visible code never
# passes it to ``mpl.rcParams``, so none of these values take effect.
params = {
    'backend': 'ps',
    'axes.labelsize': 10,
    'text.fontsize': 10,
    'legend.fontsize': 10,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'text.usetex': False,
    'figure.figsize': [10 / 2.54, 6 / 2.54],
}



# Number of units in the network; also the length of every per-unit vector.
N = 500

# Step size multiplying each weight change in both learning rules.
eps = 10 ** -3
## learning rule
@numba.njit(parallel=True, fastmath=True, nogil=True)
def learning(w, g_V_star, PSP_star, g_V_som):
    """Apply one in-place error-driven update to the weight matrix ``w``.

    Entry ``w[i, l]`` is incremented by
    ``eps * (g_V_som[i] - g_V_star[i]) * PSP_star[l]`` and then multiplied by
    ``mask[i, l]``.  ``eps`` and ``mask`` are module-level globals.

    Args:
        w: 2-D weight array; modified in place.
        g_V_star: per-row values subtracted in the error term.
        PSP_star: per-column factors that multiply the row error.
        g_V_som: per-row values added in the error term.

    Returns:
        The same array ``w`` after the update.
    """
    n_rows, n_cols = w.shape

    for row in numba.prange(n_rows):
        # The error depends on the row only, so it is computed once per row.
        row_error = -(g_V_star[row]) + g_V_som[row]
        # Only the outer loop is a parallel loop; the columns of one row are
        # handled sequentially by the thread that owns the row.
        for col in range(n_cols):
            updated = w[row, col] + eps * (row_error * PSP_star[col])
            w[row, col] = updated * mask[row, col]

    return w

@numba.njit(parallel=True, fastmath=True, nogil=True)
def learning_readout(w, y, r, z):
    """Apply one in-place error-driven update to the readout weights ``w``.

    The first ``N`` entries of ``w`` are incremented by
    ``eps * (z - y) * r[k]``.  ``N`` and ``eps`` are module-level globals.

    Args:
        w: 1-D readout weight array; modified in place.
        y: current readout value.
        r: per-unit values that multiply the error.
        z: target value for the readout.

    Returns:
        The same array ``w`` after the update.
    """
    # The error is the same for every weight, so it is formed once.
    error = -(y) + z

    for k in numba.prange(N):
        w[k] += eps * (error * r[k])

    return w

# ---------------------------------------------------------------------------
# Network and recording setup for the training run.
# ---------------------------------------------------------------------------

dt = 1    # step size in the state update x <- (1 - dt/tau) * x + ...
tau = 10  # time constant in the same update

# Plastic recurrent weights M and their connectivity mask.  ``mask`` is read
# as a global by ``learning`` and multiplies every updated weight.
p_connect_M = 1  # connection probability; 1 means every entry is present
g = 0.5          # gain of the initial random weights
scale = 1.0 / np.sqrt(p_connect_M * N)
# The connectivity pattern and the weights are drawn as whole arrays instead
# of inside an N*N Python loop.  No random seed is set in this script, so the
# changed order of the random draws does not affect reproducibility.
connected_M = np.random.rand(N, N) < p_connect_M
mask = connected_M.astype(np.float64)
M = np.where(connected_M, np.random.randn(N, N) * g * scale, 0.0)

# Weights that turn the scalar readout y into the per-unit feedback FB;
# drawn uniformly from [-3, 3).
win = 3*2*(np.random.rand(N)-0.5)

# Initial state and the corresponding rates.
x = np.random.randn(N)
r = np.tanh(x)

# Number of equal segments used when averaging the alignment trace.
n_sect = 50

# Fixed (non-plastic) sparse random recurrent weights.
p_connect_chaos = 0.1
scale_chaos = 1.0 / np.sqrt(p_connect_chaos* N)
g_chaos = 1.2
connected_chaos = np.random.rand(N, N) < p_connect_chaos
M_chaos = np.where(connected_chaos,
                   np.random.randn(N, N) * g_chaos * scale_chaos,
                   0.0)

z1_list = []
z2_list = []

prediction = np.zeros(N)
r_prediction = np.zeros(N)

a=1
b=1/3

M_std_list = []
M_chaos_std_list = []
x_chaos = np.random.randn(N)
# Fixed: this previously took tanh of ``x`` (the main state) rather than of
# the ``x_chaos`` vector drawn on the line above.
r_chaos = np.tanh(x_chaos)
simtime_len = 100*1000  # number of training steps
chaos_std = np.std(M_chaos)
w_chaos = np.random.randn(N)/np.sqrt(N)
tau_chaos = 10

# Readout weights; y = w_readout . r
w_readout = np.random.randn(N)/np.sqrt(N)

input_low_passed = np.zeros(N)
y = 0

first_term = np.zeros(simtime_len)
second_term = np.zeros(simtime_len)

# Recording buffers filled by the training loop.  Each (N, simtime_len)
# float64 array holds N * simtime_len * 8 bytes (400 MB with the values
# above).
r_list = np.zeros((10,simtime_len))
y_list_learning = np.zeros(simtime_len)
target_list = np.zeros(simtime_len)
feedback_list = np.zeros((N,simtime_len))
prediction_list = np.zeros((N,simtime_len))
# NOTE(review): rec_current_list is neither written nor read anywhere in the
# visible code; consider dropping it to avoid reserving another full buffer.
rec_current_list = np.zeros((N,simtime_len))
chaos_list = np.zeros((N,simtime_len))
chaos_pred_prod_list = np.zeros(simtime_len)
FB = (np.dot(win, y))
error_list = np.zeros(simtime_len)

# Period (in steps) of the target signal.  It is constant, so it is set once
# here rather than on every iteration.
T = 150
# Alternative target definitions that the original kept as comments:
#   square wave: T = 60;  z = 1.5*2*((np.sin(2*np.pi*i/T)>=0)-0.5)
#   sawtooth:    T = 100; z = 1.5*signal.sawtooth(2*np.pi*i/T, width=0.8)
#   sine wave:   T = 6*tau (short), 10*tau (normal) or 200*tau (long);
#                z = 1.5*np.sin(2*np.pi*i/T)

for i in tqdm(range(simtime_len), desc="[training]"):
    # Target: sum of three sinusoids with periods T, T/2 and T/3.
    z = np.sin(2*np.pi*i/T) + np.sin(2*2*np.pi*i/T) + np.sin(3*2*np.pi*i/T)

    # Recurrent input from the plastic and the fixed weights, both computed
    # from the rates of the previous step.
    M_term = np.dot(M, r)
    Chaos_term = np.dot(M_chaos, r)
    rec_term = M_term + Chaos_term

    # Leaky Euler step.  The drive is scaled by dt like the leak term; the
    # original divided by tau only, which is correct only while dt == 1.
    # With dt = 1 this expression gives bit-identical values.
    x = (1.0 - dt / tau) * x + dt * rec_term / tau

    # Feedback of the previous readout value through win.
    FB = np.dot(win, y)

    # Update M with the error (FB + Chaos_term) - M_term, using the rates
    # from before this step's state update.
    M = learning(M, M_term, r, FB + Chaos_term)

    r = np.tanh(x)
    y = np.dot(w_readout, r)

    # Update the readout with the error z - y on the new rates.
    w_readout = learning_readout(w_readout, y, r, z)

    # Record this step.
    r_list[:, i] = r[0:10]
    y_list_learning[i] = y
    target_list[i] = z

    feedback_list[:, i] = FB
    prediction_list[:, i] = M_term
    chaos_list[:, i] = Chaos_term

    # Mean per-unit product of the two recurrent input components.
    chaos_pred_prod_list[i] = np.dot(M_term, Chaos_term)/N

    error_list[i] = (y-z)


# --- Eigenvalue spectrum of the combined recurrent matrix -----------------
w = np.linalg.eigvals(M + M_chaos)
fig, ax = plt.subplots(figsize=(4, 4))
ax.scatter(w.real, w.imag)
fig.savefig('eigen_spectrum.pdf', format='pdf', dpi=350)


# --- Joint distribution of the two recurrent input components -------------
# Uses the last 1000 recorded steps of every unit, flattened to 1-D.
pred_tail = prediction_list[:, -1000:].reshape(N * 1000)
chaos_tail = chaos_list[:, -1000:].reshape(N * 1000)
fig, ax = plt.subplots(figsize=(2, 2))
ax.hist2d(pred_tail, chaos_tail, bins=(100, 100), cmap=plt.cm.jet)
ax.set_xlim(-2, 2)
ax.set_ylim(-2, 2)
fig.savefig('alignment_distribution.pdf', format='pdf', dpi=350)


# --- Rates of the first five recorded units -------------------------------
# Left traces: the first 600 training steps.  Right traces: the 600 steps
# that end 700 steps before the end of training.  Each unit is shifted up by
# 2 so the traces do not overlap.
early_axis = np.arange(600)
late_axis = np.arange(700, 1300, 1)
fig, ax = plt.subplots(figsize=(6, 2))
for unit in range(5):
    offset = 2 * unit
    ax.plot(early_axis, offset + 0.8 * r_list[unit, 0:600], lw=1.2)
    ax.plot(late_axis, offset + 0.8 * r_list[unit, -1300:-700], lw=1.2)
ax.vlines(x=1000, ymin=(-1), ymax=(10), color='k', ls='--')
fig.tight_layout()
fig.savefig('firing_rates.pdf', format='pdf', dpi=350)


# --- Alignment trace averaged over n_sect equal segments ------------------
segment_len = int(simtime_len / n_sect)
integrated_cor_list = [
    np.mean(chaos_pred_prod_list[k * segment_len:(k + 1) * segment_len])
    for k in range(n_sect)
]
fig, ax = plt.subplots(figsize=(2, 2))
ax.plot(integrated_cor_list)
fig.savefig('alignment_dynamics.pdf', format='pdf', dpi=350)


#pl.plot(mean_error_list)
#pl.show()
## initialization
#x = np.random.randn(N)
#r = np.tanh(x)

# ---------------------------------------------------------------------------
# Test run: the network continues from the state left by training, with no
# further weight updates.
# ---------------------------------------------------------------------------

x_chaos = np.random.randn(N)
# Fixed: this previously took tanh of ``x`` (the main state) rather than of
# the ``x_chaos`` vector drawn on the line above.
r_chaos = np.tanh(x_chaos)
Q = np.random.randn(N,N)/np.sqrt(N)

# NOTE: this rebinds simtime_len; from here on it is the test length, not
# the training length.
simtime_len = 1*1000

# Despite its name, this buffer receives the rates r, not the state x.
x_list_testing = np.zeros((N,simtime_len))

z1_list = np.zeros(simtime_len)
z2_list = np.zeros(simtime_len)
z3_list = np.zeros(simtime_len)
y_list = np.zeros(simtime_len)

# Neither matrix changes during testing, so their sum is formed once instead
# of on every step.
M_total = M + M_chaos

for i in tqdm(range(simtime_len), desc="[testing]"):
    # Leaky Euler step.  The drive is scaled by dt like the leak term; the
    # original divided by tau only, which is correct only while dt == 1.
    # With dt = 1 this expression gives bit-identical values.
    x = (1.0 - dt / tau) * x + dt * np.dot(M_total, r) / tau
    r = np.tanh(x)
    y = np.dot(w_readout, r)

    x_list_testing[:, i] = r
    y_list[i] = y


# Readout produced during the test run.
fig, ax = plt.subplots(figsize=(6, 2))
ax.plot(y_list)
fig.savefig('readout.pdf', format='pdf', dpi=350)

# Disabled export of the recorded traces to text files.  The block below is a
# bare string literal, so none of these np.savetxt calls are executed.
"""
np.savetxt("r_list.txt",r_list)
np.savetxt("y_list_learning.txt", y_list_learning)
np.savetxt("target_list.txt", target_list)
np.savetxt("readout_model_periodic.txt", y_list)
np.savetxt("feedback_list.txt",feedback_list)
np.savetxt("prediction_list.txt",prediction_list)
np.savetxt("chaos_list.txt",chaos_list)
"""


