import os

import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp

# Model and simulation parameters
a = 1.0  # multiplies x inside the sigmoid's exponential
R = 1.0  # coefficient of the linear -R * x term in the dynamics
K_values = [0.2, 0.8, 3.0, 6.0]  # gains swept in the trajectory / phase plots
x0_values = [-5, -2, 0, 2, 5]  # initial conditions
t_span = (0, 20)  # integration interval handed to solve_ivp
# 2000 evenly spaced output times covering the whole integration interval
t_eval = np.linspace(t_span[0], t_span[1], 2000)

# Logistic sigmoid
def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-a * x)).

    The factor ``a`` is read from module scope. ``x`` may be a scalar or a
    NumPy array; the expression is evaluated element-wise.
    """
    exp_term = np.exp(-a * x)
    return 1 / (1 + exp_term)

# Wetware dynamics
def wetware(t, x, K):
    """Right-hand side of the model: dx/dt = K * sigmoid(x) - R * x.

    ``t`` is not used by the expression; it is part of the signature because
    the function is passed to ``solve_ivp``, which calls it as ``fun(t, y, *args)``.
    ``R`` is read from module scope.
    """
    drive = K * sigmoid(x)
    decay = R * x
    return drive - decay
def settling_time(t, x, x_star, eps=0.01):
    """Return the first time after which ``x`` stays within ``eps`` of ``x_star``.

    Parameters
    ----------
    t : array-like
        Sample times.
    x : array-like
        Sampled values, one per entry of ``t``.
    x_star : float
        Target value the signal is expected to settle to.
    eps : float, optional
        Half-width of the band around ``x_star`` (strict inequality).

    Returns
    -------
    float
        ``t[i]`` for the smallest ``i`` such that every sample from ``i``
        onwards satisfies ``|x - x_star| < eps``; ``np.nan`` if the final
        sample is outside the band or the input is empty.
    """
    t = np.asarray(t)
    x = np.asarray(x)
    if t.size == 0 or x.size == 0:
        return np.nan

    # Negating "<" (rather than testing ">=") makes NaN samples count as
    # outside the band, exactly like the element-wise comparison would.
    outside = ~(np.abs(x - x_star) < eps)
    violations = np.flatnonzero(outside)
    if violations.size == 0:
        # Never left the band: settled from the very first sample.
        return t[0]

    # The signal is settled from the sample right after the last violation.
    first_settled = violations[-1] + 1
    if first_settled >= x.size or first_settled >= t.size:
        return np.nan
    return t[first_settled]
# Lyapunov function
def lyapunov(x, K):
    """Evaluate V(x) = integral from 0 to x of (R * s - K * sigmoid(s)) ds.

    The integrand is the negative of the ``wetware`` right-hand side. The
    integral is approximated with the trapezoidal rule on 1000 evenly spaced
    points between 0 and ``x``. ``R`` is read from module scope.

    NOTE(review): callers in this file pass a scalar ``x``; array input is
    not handled specially here.
    """
    s = np.linspace(0, x, 1000)
    integrand = R * s - K * sigmoid(s)
    # np.trapezoid only exists in NumPy >= 2.0; older releases expose the
    # same routine as np.trapz. Prefer the new name so no deprecated
    # attribute is touched on NumPy 2.x.
    trapezoid_rule = getattr(np, "trapezoid", None) or np.trapz
    return trapezoid_rule(integrand, s)

# ---- State trajectories ----
# One curve per (K, x0) combination, all drawn on the same figure.
plt.figure()
for K, x0 in [(gain, start) for gain in K_values for start in x0_values]:
    sol = solve_ivp(wetware, t_span, [x0], args=(K,), t_eval=t_eval)
    # The state is one-dimensional, so row 0 of sol.y is the trajectory.
    plt.plot(sol.t, sol.y[0])

plt.title("State trajectories")
plt.xlabel("Time")
plt.ylabel("x(t)")
plt.grid(True)
plt.show()

# ---- Lyapunov decay ----
# V evaluated along each simulated trajectory for a single fixed gain.
plt.figure()
K = 2.0
for x0 in x0_values:
    sol = solve_ivp(wetware, t_span, [x0], args=(K,), t_eval=t_eval)
    # lyapunov() is called once per sample of the trajectory
    V_vals = list(map(lambda state: lyapunov(state, K), sol.y[0]))
    plt.plot(sol.t, V_vals)

plt.title("Lyapunov function decay")
plt.xlabel("Time")
plt.ylabel("V(x)")
plt.grid(True)
plt.show()

# ---- Phase portrait ----
# dx/dt as a function of x for each gain; zero crossings are the equilibria.
plt.figure()
x_vals = np.linspace(-6, 6, 1000)
for K in K_values:
    # wetware() evaluates K * sigmoid(x) - R * x and does not use its time argument
    dx = wetware(None, x_vals, K)
    plt.plot(x_vals, dx, label=f"K={K}")

plt.axhline(0)
plt.title("Phase portrait")
plt.xlabel("x")
plt.ylabel("dx/dt")
plt.legend()
plt.grid(True)
plt.show()
# -----------------------------
# Quantitative convergence benchmark
# Settling time vs K
# -----------------------------

eps = 0.01  # half-width of the band around x_star that counts as settled
K_test = [0.2, 0.5, 1.0, 2.0, 4.0, 6.0]
x0 = 5.0

settling_times = []

for K in K_test:
    sol = solve_ivp(
        wetware,
        t_span,
        [x0],
        t_eval=t_eval,
        args=(K,),
    )

    # Equilibrium is estimated numerically as the last sample of the
    # trajectory, i.e. the state at the end of t_span.
    x_star = sol.y[0][-1]

    Ts = settling_time(sol.t, sol.y[0], x_star, eps)
    settling_times.append(Ts)

# Plot settling time vs K
plt.figure()
plt.plot(K_test, settling_times, marker='o')
plt.xlabel("K (excitation / plasticity)")
plt.ylabel("Settling time T_s")
plt.title("Quantitative convergence benchmark")
plt.grid(True)

# savefig does not create missing directories; without this the script would
# raise FileNotFoundError here, after all the simulations have already run.
fig_path = "../figures/Fig7_settling_time_vs_K.png"
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
# Save before show(): the figure is written while it is still the current figure.
plt.savefig(fig_path, dpi=300)
plt.show()
# Quantitative convergence benchmark
# Settling time vs K
# NOTE(review): this repeats the section header above; only a placeholder follows.
...

