# coding: UTF-8
from __future__ import division
import numpy as np
import pylab as pl
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy.matlib
import os
from scipy.interpolate import Rbf, InterpolatedUnivariateSpline
import shutil
from tqdm import tqdm
from scipy import stats
from numpy import linalg as LA
from numpy.linalg import matrix_rank
import sys
from numpy.linalg import norm
import matplotlib.cm as cm
from sklearn.decomposition import PCA
from matplotlib import animation
from scipy import interpolate
import numba
import skvideo.io
from scipy import signal
from scipy.stats import ortho_group
from mpl_toolkits.mplot3d import Axes3D
# Reproducible random draws for every weight / initial state below.
np.random.seed(0)

# Keep text editable in exported SVG/PDF figures (TrueType fonts, Arial).
mpl.rcParams.update({
    'svg.fonttype': 'none',
    'font.sans-serif': 'Arial',
    'pdf.fonttype': 42,
})

# NOTE(review): `params` is never passed to rcParams in the visible code, and
# 'text.fontsize' looks like a legacy key (recent matplotlib uses 'font.size');
# verify before applying it with mpl.rcParams.update(params).
params = {
    'backend': 'ps',
    'axes.labelsize': 10,
    'text.fontsize': 10,
    'legend.fontsize': 10,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'text.usetex': False,
    'figure.figsize': [10 / 2.54, 6 / 2.54],  # 10 cm x 6 cm in inches
}

# Load the movie and spatially down-sample it (every 5th pixel inside a fixed
# window). Dividing by 100 maps 8-bit intensities to [0, 2.55].
videodata = skvideo.io.vread("movie_data.mp4")
n_itp = 10  # interpolated samples per original frame
T = 1001    # number of movie frames used
if len(videodata) < T:
    raise ValueError(
        "movie_data.mp4 has %d frames but T=%d are required" % (len(videodata), T))
videodata_down_sampled = videodata[0:T, 250:650:5, 40:500:5, :] / 100
print(np.shape(videodata_down_sampled))
frame_shape = videodata_down_sampled.shape[1:]  # (rows, cols, channels)
n_tar = int(np.prod(frame_shape))
# data[:, i] is frame i flattened in C (row-major) order, one column per frame.
data = np.ascontiguousarray(videodata_down_sampled.reshape(T, n_tar).T)
# Up-sample each pixel's time series with an interpolating spline
# (InterpolatedUnivariateSpline, default k=3). The T frames sit at t = 0..T-1;
# n_itp*T samples are taken on [0, T], so the last few are extrapolated.
data_ip_ = np.zeros((n_tar, n_itp * T))
x = np.arange(T)
xi = np.linspace(0, T, data_ip_.shape[1])
for pixel, series in enumerate(data):
    data_ip_[pixel, :] = InterpolatedUnivariateSpline(x, series)(xi)
dt = 1    # integration time step
tau = 10  # time constant of the recurrent network dynamics
# Drop 100 interpolated samples at each end. Because xi runs to T while the
# frames end at t = T-1, the last ~n_itp samples were extrapolated by the
# spline; trimming removes them (presumably also spline edge effects).
data_ip_ = data_ip_[:, 100:-100]
data_ip = np.zeros((n_tar, len(data_ip_[0, :])))
data_inst = data_ip_[:, 0]
simtime_len = len(data_ip[0, :])

# First-order low-pass filter, forward Euler:
#   s <- (1 - dt/tau_f) * s + (dt/tau_f) * u
# The input term previously lacked the dt factor; with dt == tau_filter == 1
# both forms reduce to copying the input, so results are unchanged here.
tau_filter = 1
for i in range(simtime_len):
    data_inst = (1 - dt / tau_filter) * data_inst + dt * data_ip_[:, i] / tau_filter
    data_ip[:, i] = data_inst

# Z-score every pixel's time series. A pixel that never changes has std == 0
# (or a round-off-sized std after spline interpolation). That would produce
# NaN / amplified noise here, and because the dense feedback win @ y mixes all
# pixels into every neuron, a single such pixel would contaminate the network.
# Give those pixels unit std: they normalize to ~0 and de-normalize
# (value * std + mean) back to their mean.
data_mean = np.mean(data_ip , axis=1)
data_std = np.std(data_ip , axis=1)
data_std = np.where(data_std > 1e-12, data_std, 1.0)
data_ip = ((data_ip.T-data_mean)/data_std).T

# Disabled preview of the normalized target traces (kept for reference):
# plt.subplot(1, 1, 1)
# #plt.plot(x, ((data.T-data_mean)/data_std).T[0,:], 'bo')
# for i in range(n_tar):
#     plt.plot( data_ip[i,:], '-')
# pl.show()
# Render every n_itp-th sample of the de-normalized target as a GIF frame.
fig = plt.figure(figsize=(10, 10))
ims = []
frame_shape = videodata_down_sampled.shape[1:]  # (rows, cols, channels)
for i in range(int(simtime_len/n_itp)):
    frame = np.reshape(data_ip[:, i * n_itp] * data_std + data_mean, frame_shape)
    # Spline interpolation can overshoot the original [0, 2.55] range. Casting
    # out-of-range floats to uint8 does not saturate (values wrap or are
    # platform-dependent), so clip to the displayable range first.
    img = plt.imshow(np.clip(100 * frame, 0, 255).astype(np.uint8))
    ims.append([img])

ani = animation.ArtistAnimation(fig, ims, interval=100, blit=True,
                                repeat_delay=100)
ani.save('anim_target.gif', writer="pillow")



# Disabled quick look at the last down-sampled frame:
#pl.imshow(videodata_down_sampled[1000,:,:,:])
#pl.show()

N = 800  # number of recurrent units

# Learning rate shared by `learning` and `learning_readout`. numba reads it as
# a global, which it freezes at first compilation.
eps = 10.0 ** -3
## learning rule
@numba.njit(parallel=True, fastmath=True, nogil=True)
def learning(w, g_V_star, PSP_star, g_V_som):
    """Apply one masked delta-rule step to ``w`` in place and return it.

    For every synapse (i, l)::

        w[i, l] += eps * (g_V_som[i] - g_V_star[i]) * PSP_star[l]
        w[i, l] *= mask[i, l]

    ``eps`` and ``mask`` are module-level globals; numba freezes them as
    constants when the function is first compiled. Only the row loop runs in
    parallel. numba serializes nested prange loops, so the inner loop is a
    plain range.
    """
    n_post, n_pre = w.shape
    for post in numba.prange(n_post):
        err = -(g_V_star[post]) + g_V_som[post]
        for pre in range(n_pre):
            # disabled decay term from the original: -1*0.1*w[i,l]
            w[post, pre] += eps * (err * PSP_star[pre])
            w[post, pre] *= mask[post, pre]
    return w



@numba.njit(parallel=True, fastmath=True, nogil=True)
def learning_readout(w, y, r, z):
    """Apply one delta-rule step to the readout weights ``w`` in place.

    For every weight (i, l)::

        w[i, l] += eps * (z[i] - y[i]) * r[l]

    ``eps`` is a module-level global frozen by numba at first compilation.
    Returns ``w`` (the same array, modified in place).
    """
    n_out, n_in = w.shape
    for out in numba.prange(n_out):
        err = -(y[out]) + z[out]
        for unit in range(n_in):
            # disabled decay term from the original: -1*0.1*w[i,l]
            w[out, unit] += eps * (err * r[unit])
    return w

# Learnable recurrent weights M. Each entry exists with probability p and, if
# present, is drawn from N(0, (g * scale)^2). `mask` marks existing entries
# (the `learning` rule multiplies by it). Draws are made entry by entry in
# row-major order, which fixes the global RNG stream used afterwards.
M = np.zeros((N, N))
p = 1    # connection probability
g = 0.5  # gain of the initial weights
scale = 1.0 / np.sqrt(p * N)
mask = np.zeros((N, N))
for row in range(N):
    for col in range(N):
        if not np.random.rand() < p:
            continue
        M[row, col] = np.random.randn() * g * scale
        mask[row, col] = 1

# NOTE: every np.random call below advances the global RNG seeded at the top
# of the file, so the draws here (and their order) determine all later values.
# QR of a random N x 3 matrix. Neither q nor r is used in the visible code
# (r is reassigned to r0 before training), but the call still consumes N*3 draws.
q, r = np.linalg.qr(2*(np.random.rand(N,3)-0.5))
#print(np.dot(q[:,0],q[:,0]))

print(n_tar)
#Win = ortho_group.rvs(dim=N)*np.sqrt(N)*2
# Feedback weights (N x n_tar) that inject the readout y into the network
# (FB = win @ y in the training loop): uniform in [-3, 3), scaled by 1/sqrt(n_tar).
win = 3*2*(np.random.rand(N,n_tar)-0.5)/np.sqrt(n_tar)#*2*(np.random.rand()<1)*0.5

# g_L = 1/tau (presumably a leak conductance) and g_d; neither is read in the
# visible code.
g_L=1/tau
g_d=0.7

# Initial state x0 ~ N(0, 1) and rate r0 = tanh(x0). Every training epoch and
# the test run restart from these.
x0 = np.random.randn(N)
r0 = np.tanh(x0)

# Fixed (never trained) random recurrent weights, drawn entry by entry:
#   M_chaos:  density 0.1, nonzero entries ~ N(0, (1.2 * scale_chaos)^2)
#   M_chaos2: density 0.5, nonzero entries ~ N(0, (2 / sqrt(0.5*N))^2);
#             M_chaos2 is not read in the visible code.
# The Bernoulli and Gaussian draws for both matrices are interleaved per
# (i, j), so this loop order fixes the RNG stream for everything drawn later.
M_chaos = np.zeros((N,N))
M_chaos2 = np.zeros((N,N))
scale_chaos = 1.0 / np.sqrt(0.1* N)
for i in range(N):
    for j in range(N):
        if np.random.rand() < 0.1:
            M_chaos[i, j] = (np.random.randn()) * 1.2 * scale_chaos
        if np.random.rand() < 0.5:
            M_chaos2[i, j] = (np.random.randn()) * 2 / np.sqrt(0.5* N)
# Bookkeeping and constants. z1_list, z2_list, a, b, M_std_list,
# M_chaos_std_list, chaos_std, tau_chaos, input_low_passed, prediction and
# r_prediction are not read in the visible code.
z1_list = []
z2_list = []

a=1
b=1/3

M_std_list = []
M_chaos_std_list = []

#M_tot_std_list = []

chaos_std = np.std(M_chaos)
# Not read in the visible code, but still consumes N draws from the global RNG.
w_chaos = np.random.randn(N)/np.sqrt(N)
tau_chaos = 10

# Readout weights (n_tar x N): y = w_readout @ r, initial entries ~ N(0, 1/N).
w_readout = np.random.randn(n_tar,N)/np.sqrt(N)#*0.2
# Constant input to every unit. The 0. factor makes it all zeros, but the
# uniform draws still advance the RNG.
w_control = 0.*2*(np.random.rand(N)-0.5)#*2*(np.random.rand()<1)*0.5
input_low_passed = np.zeros(N)
# y, z, x and r are reassigned at the start of each training epoch.
y = np.zeros(n_tar)
z = np.zeros(n_tar)
x=x0
r=r0
#w_control = 0*1*(np.random.rand(N)-0.5)#*2*(np.random.rand()<1)*0.5
prediction = np.zeros(N)
r_prediction = np.zeros(N)
tot_loop = 1500  # number of training epochs over the whole target sequence
# Per-epoch mean squared errors (accumulated as step averages in the training loop).
M_error_list = np.zeros(tot_loop)
readout_error_list = np.zeros(tot_loop)
# Training: tot_loop epochs, each replaying the whole target sequence from the
# fixed initial state (x0, r0). At every step:
#   * `learning` updates M in place so that M @ r tracks win @ y + M_chaos @ r;
#   * `learning_readout` updates w_readout so that w_readout @ r tracks the
#     current target frame z.
for loop in tqdm(range(tot_loop), desc="[training]"):

    x = x0
    r = r0
    # At step 0 the feedback input is the first target frame.
    y = data_ip[:, 0]
    #prediction = np.zeros(N)
    #r_prediction = np.zeros(N)
    for i in range(simtime_len):
        z = data_ip[:,i]#*0
        #z+=1*np.sin(0.5*1*30 * i / (50 * 2 * np.pi))
        # Both recurrent currents use the rate r from the previous step.
        M_term = np.dot(M, r)
        Chaos_term = np.dot(M_chaos, r)

        rec_term = M_term + Chaos_term

        # Leaky integration: x <- (1 - dt/tau) x + (rec_term + w_control)/tau.
        # NOTE(review): the input term has no dt factor; this matches a forward
        # Euler step only while dt == 1.
        x = (1.0 - dt / tau) * x + (rec_term+w_control) / tau

        #r_prediction = (1.0 - dt / tau) * r_prediction + (r) / tau
        #prediction = np.dot(M, r_prediction)
        #input_low_passed = (1.0 - dt / tau) * input_low_passed + (np.dot(win , y) ) / tau
        #print(input_low_passed[5])
        # Feedback from the previous step's readout y.
        FB = (np.dot(win , y) )
        # In-place masked delta rule: M += eps * outer(FB + Chaos_term - M_term, r).
        M = learning(M, M_term, r, FB+Chaos_term)
        r = np.tanh(x)
        y = np.dot(w_readout, r)

        # M*=mask
        #print(np.max(np.dot(win , z)))
        # In-place readout delta rule toward the current target z, using the new r.
        w_readout = learning_readout(w_readout, y, r, z)

        # Epoch-averaged squared errors. FB-(M_term-Chaos_term) equals
        # FB + Chaos_term - M_term, the error driven to zero by `learning`.
        M_error_list[loop] += np.mean((FB-(M_term-Chaos_term))**2)/simtime_len
        readout_error_list[loop] += np.mean((z-y)**2)/simtime_len
# Save the per-epoch training errors twice: the raw values, then each curve
# divided by its own maximum (red = recurrent-weight error M_error_list,
# blue = readout error readout_error_list).
for out_name, rescale in (('movie_error.pdf', False),
                          ('movie_error_norm.pdf', True)):
    pl.figure()
    for curve, colour in ((M_error_list, 'r'), (readout_error_list, 'b')):
        pl.plot(curve / np.max(curve) if rescale else curve, c=colour)
    plt.savefig(out_name, format='pdf', dpi=350)

## Testing: run the trained network freely (no learning, no feedback) from
## the same initial state used in training and record its output.

frame_shape = videodata_down_sampled.shape[1:]  # (rows, cols, channels)
r_list = np.zeros((N, simtime_len))
output_list = np.zeros((n_tar, simtime_len))
target_list = np.zeros((n_tar, simtime_len))
# NOTE(review): y_list holds the same data as output_list (de-normalized and
# reshaped), i.e. another simtime_len x n_tar float64 array; only every
# n_itp-th frame is used by the animation below.
y_list = np.zeros((simtime_len,) + frame_shape)
x = x0
r = r0
y = data_ip[:, 0]
# M and M_chaos are constant during testing. Summing them once avoids
# allocating an N x N temporary on every step, with identical results.
M_total = M + M_chaos
for i in range(simtime_len):
    rec_term = np.dot(M_total, r)
    x = (1.0 - dt / tau) * x + (rec_term + w_control) / tau

    r = np.tanh(x)
    y = np.dot(w_readout, r)
    r_list[:, i] = r

    # De-normalize the readout back to pixel space for the animation.
    y_list[i, :, :, :] = np.reshape(y * data_std + data_mean, frame_shape)
    output_list[:, i] = y
    target_list[:, i] = data_ip[:, i]

# Compare the first output pixel with its target over the whole test run:
# dashed red = target, solid blue = network output.
fig = plt.figure(figsize=(8, 4))
plt.subplot(1, 1, 1)
#plt.plot(x, data[10,:], 'bo')
for trace, line_style, colour in ((target_list[0, :], '--', 'r'),
                                  (output_list[0, :], '-', 'b')):
    plt.plot(trace, line_style, c=colour)
plt.savefig('movie_output_target.pdf', format='pdf', dpi=350)



# Render every n_itp-th frame of the trained network's free-running output.
fig = plt.figure(figsize=(10, 10))
ims = []

for i in range(int(simtime_len/n_itp)):
    # The linear readout is unconstrained. Casting out-of-range floats to uint8
    # does not saturate (values wrap or are platform-dependent), so clip to
    # the displayable range first.
    img = plt.imshow(np.clip(100 * y_list[n_itp * i, :, :, :], 0, 255).astype(np.uint8))
    ims.append([img])

ani = animation.ArtistAnimation(fig, ims, interval=100, blit=True,
                                repeat_delay=100)
# NOTE(review): the output filename contains a space ("trained _model"); it is
# kept as-is in case downstream tooling expects this exact name.
ani.save('movie_data_trained _model.gif', writer="pillow")
"""
