# coding: UTF-8
from __future__ import division
import numpy as np
import pylab as pl
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy.matlib
import os
import shutil
from tqdm import tqdm
from scipy import stats
import sklearn.decomposition
from sklearn.decomposition import PCA
from matplotlib.gridspec import GridSpec
from numpy import linalg as LA
from numpy.linalg import matrix_rank
import sys
from numpy.linalg import norm
import matplotlib.cm as cm
from sklearn.decomposition import PCA
from matplotlib import animation
import numba
from scipy.stats import ortho_group
from mpl_toolkits.mplot3d import Axes3D
from scipy import signal
# Global Matplotlib settings for exported figures: SVG text is kept as text,
# PDF fonts are embedded as Type 42 (TrueType), and the sans-serif face is Arial.
mpl.rcParams.update({
    'svg.fonttype': 'none',
    'font.sans-serif': 'Arial',
    'pdf.fonttype': 42,
})
# Figure-style presets (font sizes in points, figure size converted from cm to
# inches).  NOTE(review): nothing in the code visible here passes this dict to
# mpl.rcParams.update(), so it currently has no effect.  The old key
# 'text.fontsize' is not a valid rcParam in current Matplotlib and would raise
# KeyError if the dict were applied; 'font.size' is the intended setting.
params = {
    'backend': 'ps',
    'axes.labelsize': 10,
    'font.size': 10,
    'legend.fontsize': 10,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'text.usetex': False,
    'figure.figsize': [10 / 2.54, 6 / 2.54],
}

N = 600 #network size


eps = 10**(-3)#*0.5
## learning rule
@numba.njit(parallel=True, fastmath=True, nogil=True)
def learning(w, g_V_star, PSP_star, g_V_som):
    """Apply one masked delta-rule step to the weight matrix `w`, in place.

    For every row i and column l:
        w[i, l] = (w[i, l] + eps * (g_V_som[i] - g_V_star[i]) * PSP_star[l]) * mask[i, l]

    `eps` and `mask` are module-level globals; numba treats globals as
    compile-time constants, so the values present at first call are used.
    Only the row loop runs in parallel (numba serialises nested prange loops,
    so the inner loop is written as a plain range).  Returns `w` itself.
    """
    n_rows, n_cols = w.shape
    for row in numba.prange(n_rows):
        # Error term is constant along the row.
        row_err = g_V_som[row] - g_V_star[row]
        for col in range(n_cols):
            w[row, col] += eps * (row_err * PSP_star[col])
            w[row, col] *= mask[row, col]
    return w

@numba.njit(parallel=True, fastmath=True, nogil=True)
def learning_readout(w, y, r, z):
    """Delta-rule (LMS) update of a linear readout, in place.

    w[k] += eps * (z - y) * r[k] for every k, where `y` is the current readout
    value, `z` its target and `r` the presynaptic activity vector.
    Returns `w` itself.

    The loop runs over len(w) instead of the module global N: N is reassigned
    later in this script, and numba freezes globals at first compilation, so
    relying on it silently couples this kernel to whatever N held then.
    """
    err = z - y  # scalar error, shared by every weight
    for k in numba.prange(w.shape[0]):
        w[k] += eps * (err * r[k])
    return w

# --- Network and task setup ---------------------------------------------------
# NOTE: the element-wise random loops below are kept (rather than vectorised
# draws) so the random-number stream, and thus results for a given seed, are
# unchanged.
dt = 1
tau = 10  # time constant of the rate dynamics, in units of dt

# Plastic recurrent weights M: each entry present with probability p and drawn
# as randn * g / sqrt(p*N).  mask marks present entries; `learning` multiplies
# every update by mask, so absent entries stay zero.
M = np.zeros((N, N))
p = 1
g = 0.5
scale = 1.0 / np.sqrt(p * N)
mask = np.zeros((N, N))
for i in range(N):
    for j in range(N):
        if np.random.rand() < p:
            M[i, j] = (np.random.randn()) * g * scale
            mask[i, j] = 1

# Weights (uniform in [-3, 3)) mapping the readout y into the learning target
# of M (see the `learning` call in the training loop).
win = 3 * 2 * (np.random.rand(N) - 0.5)

# Initial state; overwritten with x0/r0 at the start of every trial.
x = np.random.randn(N)
r = np.tanh(x)

# Fixed sparse recurrent weights (density 0.1, gain 1.2); never passed to a
# learning rule in the visible code.
M_chaos = np.zeros((N, N))
scale_chaos = 1.0 / np.sqrt(0.1 * N)
for i in range(N):
    for j in range(N):
        if np.random.rand() < 0.1:
            M_chaos[i, j] = (np.random.randn()) * 1.2 * scale_chaos

# The following bookkeeping variables are not used by the visible code.
z1_list = []
z2_list = []

prediction = np.zeros(N)
r_prediction = np.zeros(N)

M_std_list = []
M_chaos_std_list = []
x_chaos = np.random.randn(N)
# Previously np.tanh(x): the rate must be taken from its own state x_chaos.
r_chaos = np.tanh(x_chaos)

simtime_len = 300  # time steps per trial
chaos_std = np.std(M_chaos)
w_chaos = np.random.randn(N) / np.sqrt(N)
tau_chaos = 10

# Linear readout weights (trained by `learning_readout`).
w_readout = np.random.randn(N) / np.sqrt(N)

input_low_passed = np.zeros(N)
y = 0

# Input weights of the two cue pulses, uniform in [-1, 1).
w_input1 = 2 * (np.random.rand(N) - 0.5)
w_input2 = 2 * (np.random.rand(N) - 0.5)
# Common initial state of every trial.
x0 = np.random.randn(N)
r0 = np.tanh(x0)

# Cue delays used during training: 50, 60, 70, 80 steps.
target_list = [50 + 10 * kk for kk in range(4)]
T = 15  # width of the Gaussian stimulus / target pulses
# Training: 50000 trials.  Each trial restarts from (x0, r0) with a cue delay
# drawn at random from target_list; the plastic recurrent weights M and the
# readout weights w_readout are updated at every time step.
# (`x = x0` binds rather than copies; this is safe because x and r are only
# ever rebound below, never modified in place.)
for loop in tqdm(range(50000), desc="[training]"):#50000
    x = x0
    r = r0
    y=0
    target_delay = target_list[np.random.randint(len(target_list))]
    # Gaussian pulses of width ~T: cue 1 peaks at t=30, cue 2 at t=30+delay,
    # and the target readout at t=30+2*delay, i.e. the network must reproduce
    # the interval between the two cues.
    input1_vec = (np.exp(-(np.arange(simtime_len)-30)**2/T**2)-0.)*1
    input2_vec = (np.exp(-(np.arange(simtime_len)-30-target_delay)**2/T**2)-0.)*1
    target_vec = (np.exp(-(np.arange(simtime_len)-30-2*target_delay)**2/T**2)-0.)*1


    for i in range(simtime_len):
        input1 = input1_vec[i]
        input2 = input2_vec[i]
        target = target_vec[i]
        # Disabled alternative stimulus (sinusoidal bursts), kept as a no-op
        # string literal.
        """
        input1 = -0.5
        input2 = -0.5
        target = -0.5
        if i>=0 and i<50:
            input1 = (np.sin(i*2*np.pi/T)-0.25)*2
        if i>=target_delay and i<target_delay+50:
            input2 = (np.sin((i-target_delay) * 2 * np.pi / T)-0.25)*2
        if i >= 2*target_delay and i < 2*target_delay + 50:
            target = (np.sin((i - 2*target_delay) * 2 * np.pi / T)-0.25)*2
        """
        # Recurrent drive, split into the plastic (M) and fixed (M_chaos) parts.
        M_term = np.dot(M,r)
        Chaos_term = np.dot(M_chaos,r)
        rec_term = M_term + Chaos_term
        # Leaky Euler step of the state, driven by recurrence and both cues.
        x = (1.0 - dt / tau) * x + (rec_term + input1*w_input1 + input2*w_input2)/tau
        # Move M @ r toward win*y + M_chaos @ r (using the pre-update r, the
        # same r that produced M_term).  `learning` modifies M in place.
        M = learning(M, (M_term), r, (np.dot(win, y)+Chaos_term))

        r = np.tanh(x)
        y = np.dot(w_readout,r)

        # Delta-rule update of the readout toward the target pulse.
        w_readout = learning_readout(w_readout, y, r, target)

# --- Evaluation ----------------------------------------------------------------
# n_sim test trials with cue delays 30 .. 30+n_sim-1 (training used 50..80),
# without any plasticity.  Stimuli, readout and population activity are
# recorded and saved to text files.
n_sim = 70
simtime_len = 300
input1_list = np.zeros((n_sim, simtime_len))
input2_list = np.zeros((n_sim, simtime_len))
output_list = np.zeros((n_sim, simtime_len))
# Column loop*simtime_len + i of r_list holds r at step i of trial `loop`.
r_list = np.zeros((N, n_sim * simtime_len))

time_axis = np.arange(simtime_len)
# Cue 1 does not depend on the delay, so it is built once.
input1_vec = np.exp(-(time_axis - 30) ** 2 / T ** 2)

# Progress label was "[training]" although no learning happens here.
for loop in tqdm(range(n_sim), desc="[testing]"):
    # Same initial state as in training (x/r are rebound, never mutated).
    x = x0
    r = r0
    target_delay = 30 + loop
    input2_vec = np.exp(-(time_axis - 30 - target_delay) ** 2 / T ** 2)
    target_vec = np.exp(-(time_axis - 30 - 2 * target_delay) ** 2 / T ** 2)

    for i in range(simtime_len):
        input1 = input1_vec[i]
        input2 = input2_vec[i]
        target = target_vec[i]  # not used here (no learning)
        M_term = np.dot(M, r)
        Chaos_term = np.dot(M_chaos, r)
        rec_term = M_term + Chaos_term
        x = (1.0 - dt / tau) * x + (rec_term + input1 * w_input1 + input2 * w_input2) / tau

        r = np.tanh(x)
        y = np.dot(w_readout, r)

        r_list[:, loop * simtime_len + i] = r

        input1_list[loop, i] = input1
        input2_list[loop, i] = input2
        output_list[loop, i] = y


np.savetxt("delay_matching_input1_list.txt", input1_list)
np.savetxt("delay_matching_input2_list.txt", input2_list)
np.savetxt("delay_matching_output_list.txt", output_list)
np.savetxt("delay_matching_r_list.txt", r_list)


# --- Analysis: reload saved results ------------------------------------------
input1_list = np.loadtxt("delay_matching_input1_list.txt")
input2_list = np.loadtxt("delay_matching_input2_list.txt")

n_sim = len(input1_list[:, 0])
simtime_len = 300

# Readout traces to average across runs.  This script only writes
# "delay_matching_output_list.txt"; the "...1.txt"/"...2.txt" files are
# presumably kept from earlier runs.  Files that do not exist are skipped, so a
# fresh run averages just the current output instead of crashing with
# FileNotFoundError.
output_files = [
    "delay_matching_output_list1.txt",
    "delay_matching_output_list2.txt",
    "delay_matching_output_list.txt",
]
output_list_ = np.array(
    [np.loadtxt(fname) for fname in output_files if os.path.exists(fname)]
)

output_list = np.mean(output_list_, axis=0)

r_list = np.loadtxt("delay_matching_r_list.txt")




N = len(r_list[:, 0])

# Min-max normalise each neuron with its range over ALL test trials, then keep
# only the first trial (the first simtime_len columns of r_list).
max1 = r_list.max(axis=1)
min1 = r_list.min(axis=1)
span = max1 - min1
# A neuron whose activity never changes would give 0/0 = NaN and corrupt the
# sort below; its normalised row becomes all zeros instead.
span = np.where(span > 0, span, 1.0)
avg_norm1 = (r_list[:, 0:simtime_len] - min1[:, None]) / span[:, None]

# Circular-mean time of each normalised trace: phase of
# sum_k a[k] * exp(2*pi*i*k/simtime_len), mapped to [0, 2*pi] and converted
# back to a time step.  (The original division by sum(a) is a positive scale
# that cannot change the phase, and produced NaN for all-zero rows.)
phasor = np.exp(np.arange(simtime_len) / simtime_len * 2 * np.pi * 1j)
arg = np.angle(avg_norm1 @ phasor)
arg = np.where(arg < 0, arg + 2 * np.pi, arg)
t = simtime_len / (2 * np.pi) * arg

# Neurons ordered by that preferred time.
index = np.argsort(t)
avg_sorted = avg_norm1[index, :]

#colors = ["tomato","coral","darkorange","orange","gold","yellowgreen","limegreen","mediumturquoise","lightskyblue","deepskyblue","cornflowerblue"]
# One colour per trained delay (trials 20, 30, 40, 50); shared by panels B and C.
colors = ["orangered", "limegreen", "#ffd343", "royalblue"]

# Step of the peak of each trace, per test trial (float, as in the original
# preallocated buffers).
max_time_input1 = np.argmax(input1_list, axis=1).astype(float)
max_time_input2 = np.argmax(input2_list, axis=1).astype(float)
max_time_output = np.argmax(output_list, axis=1).astype(float)

# Presented interval (cue 1 -> cue 2) and reproduced interval (cue 2 -> readout peak).
input_interval = max_time_input2 - max_time_input1
output_interval = max_time_output - max_time_input2

fig = plt.figure(figsize=(8, 2))
gs = GridSpec(nrows=5, ncols=15)

# --- Panel A: two example test trials ----------------------------------------
# Cue 1, cue 2 and the run-averaged readout for trial 20 (top) and trial
# n_sim - 20 (bottom), drawn without spines or ticks.
example_axes = []
for grid_rows, trial in ((slice(0, 2), 20), (slice(2, 4), n_sim - 20)):
    ax = fig.add_subplot(gs[grid_rows, 0:7])
    for trace in (input1_list, input2_list, output_list):
        ax.plot(trace[trial, :])
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    ax.get_yaxis().set_ticks([])
    ax.get_xaxis().set_ticks([])
    example_axes.append(ax)
ax0, ax1 = example_axes
ax0.set_title('A', loc='left', size=15)
plt.tight_layout()

# --- Panel B: reproduced vs. presented interval ------------------------------
ax2 = fig.add_subplot(gs[0:4, 7:11])
ax2.plot(input_interval, output_interval, lw=1.5, c='gray')
# Shaded band: delays used in training (50..80).
ax2.axvspan(50, 80, 0., 1, color="gray", alpha=0.1)
# Identity line = perfect reproduction.
ax2.plot(np.arange(30, 100, 1), np.arange(30, 100, 1), c='k', ls='--', lw=0.7)
# Squares mark the four trained delays, using the same test trials (and
# colours) as panel C: test trial k has delay 30 + k, so delay 50 + 10*i is
# trial 20 + 10*i.  (The previous index 21 + 10*i took the trial one step
# longer while drawing it at x = 50 + 10*i.)
for i in range(4):
    trial = 20 + 10 * i
    ax2.plot(input_interval[trial], output_interval[trial], 's', markersize=7,
             c=colors[i], mec='k', markeredgewidth=0.5)
ax2.set_xlim(30, 100)
ax2.set_ylim(30, 100)
ax2.set_xticks(np.arange(30, 100, 20))
ax2.set_yticks(np.arange(30, 100, 20))
ax2.spines['top'].set_visible(False)
ax2.spines['right'].set_visible(False)
ax2.set_title('B', loc='left', size=15)

ax2.set_aspect("equal")

#####


###


####
ax3 = fig.add_subplot(gs[0:4, 11:15], projection='3d')

# --- Panel C: population trajectories in PC space ----------------------------
# PCA over every time step of every test trial: row trial*simtime_len + step
# of firing_mat (and of X_pca) is that step of that trial; columns are neurons.
pca = sklearn.decomposition.PCA(n_components=6)
firing_mat = r_list[:, :].T
pca.fit(firing_mat)

X_pca = pca.transform(firing_mat)


def _first_upward_crossing(trace, level=0.9):
    """Return the first step jj >= 1 with trace[jj - 1] < level < trace[jj].

    Falls back to the step of the maximum when `level` is never crossed
    (possible for a weak averaged readout).  The previous inline search left
    the onset undefined (NameError) or stale from the previous trial in that
    case, and its jj = 0 probe compared against trace[-1].
    """
    hits = np.flatnonzero((trace[:-1] < level) & (trace[1:] > level))
    if hits.size:
        return int(hits[0]) + 1
    return int(np.argmax(trace))


for i in range(4):
    # Same trials / colours as panel B: trial 20 + 10*i has delay 50 + 10*i.
    trial = 20 + 10 * i
    start = trial * simtime_len  # first row of this trial in X_pca
    onset_time = _first_upward_crossing(input2_list[trial, :])   # cue-2 onset
    onset_time2 = _first_upward_crossing(output_list[trial, :])  # readout onset

    # Trajectory from step 20 of the trial up to the readout onset.
    seg = X_pca[start + 20:start + onset_time2]
    ax3.plot(seg[:, 0], seg[:, 1], seg[:, 2], color=colors[i], lw=1)

    # Open circle at cue-2 onset, open square at readout onset.
    for step, marker in ((onset_time, 'o'), (onset_time2, 's')):
        point = X_pca[start + step]
        ax3.plot(point[0], point[1], point[2], markeredgecolor=colors[i],
                 markeredgewidth=2, markerfacecolor='white', lw=0,
                 marker=marker, markersize=7)

ax3.set_xticks([])
ax3.set_yticks([])
ax3.set_zticks([])
# Transparent background panes.  Axes3D.w_xaxis/w_yaxis/w_zaxis were
# deprecated aliases removed in Matplotlib 3.8; .xaxis/.yaxis/.zaxis are the
# same axis objects.
for pane_axis in (ax3.xaxis, ax3.yaxis, ax3.zaxis):
    pane_axis.set_pane_color((1.0, 1.0, 1.0, 0.0))

ax3.set_xlabel('PC1')
ax3.set_ylabel('PC2')
ax3.set_zlabel('PC3')
# Pull the axis labels in toward the tick-less axes.
ax3.xaxis.labelpad = -12
ax3.yaxis.labelpad = -12
ax3.zaxis.labelpad = -12
ax3.set_title('C', loc='left', size=15)

plt.tight_layout()
plt.savefig('Figure_delay_matching_task.pdf', format='pdf', dpi=350)
pl.show()