"""
Reproduction script for:
"Non-Hermitian Skin Effect on Directed Networks: Rigorous Open-Boundary
Turing Results and a Testable Hypothesis for the Kinetics-Independence
of Tau Pathology Localisation in Alzheimer's Disease"
(Liu & Liang, manuscript NHSE_tau)

This script reproduces all numerical results in Section 5 (verification of
Theorems 1-3, Figures 1 and 2) using only NumPy and Matplotlib.

Convention note. Following the Hatano-Nelson convention used in the
non-Hermitian skin-effect literature (Yao & Wang 2018; Yokomizo &
Murakami 2019), the directed-ring Laplacian under open boundary
conditions is implemented with a uniform on-site term -(t_R + t_L)
applied to every node, including the two boundary nodes. This is
equivalent to compensating the lost out-degree at the boundaries with
virtual self-loops and is required for consistency with the analytical
spectrum of Equation (13) in the main text.

Author: Han Liu
Email: uctqhhl@ucl.ac.uk
"""

import numpy as np
import matplotlib.pyplot as plt


# A. Operators -------------------------------------------------------------

def directed_ring_pbc(N, tR, tL):
    """Directed-ring Laplacian under PBC (Equations 2-3 of the manuscript).

    Row ``i`` couples node ``i`` to its left neighbour ``(i - 1) mod N`` with
    weight ``tR`` and to its right neighbour ``(i + 1) mod N`` with weight
    ``tL``. The diagonal carries ``-(tR + tL)``, so every row sums to zero.

    Parameters
    ----------
    N : int
        Number of nodes on the ring.
    tR, tL : float
        Coupling weights to the left and right neighbour, respectively.

    Returns
    -------
    numpy.ndarray
        Dense ``(N, N)`` float matrix.

    Notes
    -----
    Contributions are accumulated (``+=``) rather than assigned. For
    ``N = 2`` both neighbours of a node are the same node, so the
    off-diagonal entry must be ``tR + tL``. For ``N = 1`` both "neighbours"
    are the node itself, and the self-loops cancel the on-site term, giving
    ``[[0]]``. Plain assignment would overwrite one contribution and break
    the zero row-sum.
    """
    L = np.zeros((N, N))
    for i in range(N):
        L[i, i] -= tR + tL
        L[i, (i - 1) % N] += tR
        L[i, (i + 1) % N] += tL
    return L


def directed_ring_obc(N, tR, tL):
    """Directed-ring Laplacian under OBC, Hatano-Nelson convention.

    Builds the tridiagonal ``(N, N)`` matrix with ``tR`` on the first
    sub-diagonal (``L[i, i - 1]``), ``tL`` on the first super-diagonal
    (``L[i, i + 1]``) and the uniform on-site term ``-(tR + tL)`` on every
    diagonal entry, the two boundary nodes included. No wrap-around
    couplings are present.
    """
    sites = np.arange(N)
    bonds = sites[:-1]  # bond b joins node b and node b + 1
    L = np.zeros((N, N))
    L[sites, sites] = -(tR + tL)
    L[bonds + 1, bonds] = tR
    L[bonds, bonds + 1] = tL
    return L


# B. Theorem 1 -------------------------------------------------------------

def verify_theorem1(N_list=(20, 30, 50), tR=3.0, tL=1.0):
    """Print a numerical check of Theorem 1 (real OBC spectrum).

    The analytical interval endpoints ``-(sqrt(tR) +/- sqrt(tL))**2`` are
    printed first. Then, for every chain length in ``N_list``, the OBC
    Laplacian is diagonalised numerically. The largest ``|Im(lambda)|`` is
    reported, together with the largest gap between the sorted real parts
    and the sorted closed form
    ``2 sqrt(tR tL) cos(n pi / (N + 1)) - (tR + tL)`` for n = 1..N.
    Diagnostics are printed only; no tolerance is enforced.
    """
    root_R, root_L = np.sqrt(tR), np.sqrt(tL)
    edge_lo = -(root_R + root_L) ** 2
    edge_hi = -(root_R - root_L) ** 2
    print(f"Analytical OBC interval Sigma_OBC = [{edge_lo:.6f}, {edge_hi:.6f}]")
    for N in N_list:
        spectrum = np.linalg.eigvals(directed_ring_obc(N, tR, tL))
        imag_peak = np.abs(spectrum.imag).max()
        modes = np.arange(1, N + 1)
        exact = np.sort(
            2 * np.sqrt(tR * tL) * np.cos(modes * np.pi / (N + 1)) - (tR + tL)
        )
        mismatch = np.abs(np.sort(spectrum.real) - exact).max()
        print(f"  N = {N:>3d}: max|Im(lambda)| = {imag_peak:.2e}, "
              f"max|num - analytical| = {mismatch:.3e}")


# C. Theorem 2 + terminal node --------------------------------------------

def obc_eigenstate(N, n, tR, tL):
    """Eq. (14): psi_j^(n) = (r*)^j sin(j theta_n) / sin(theta_n).

    Returns the unnormalised eigenvector of the OBC (Hatano-Nelson)
    tridiagonal Laplacian with ``tR`` on the sub-diagonal and ``tL`` on the
    super-diagonal. Here ``r* = sqrt(tR / tL)`` and
    ``theta_n = n pi / (N + 1)``. Entry ``j - 1`` of the returned length-N
    array holds site ``j = 1..N``.

    Raises
    ------
    ValueError
        If ``n`` is not in ``1..N``. Outside that range the index either
        makes ``sin(theta_n)`` vanish (n a multiple of N + 1, giving 0/0 =
        NaN) or duplicates a mode already in range.
    """
    if not 1 <= n <= N:
        raise ValueError(
            f"obc_eigenstate: mode index n={n} must satisfy 1 <= n <= N={N}")
    rstar = np.sqrt(tR / tL)
    theta = n * np.pi / (N + 1)
    j = np.arange(1, N + 1)
    return (rstar ** j) * np.sin(j * theta) / np.sin(theta)


def verify_theorem2_and_terminal_node(N=50, tR=3.0, tL=1.0, atol=1e-9):
    """Check the terminal-node identity behind Theorem 2 for every OBC mode.

    For each mode n = 1..N, with theta_n = n pi / (N + 1), two checks are made:

    * the algebraic identity |sin(N theta_n)| = |sin(theta_n)|, to absolute
      tolerance ``atol``;
    * that the terminal-node amplitude F_N = |psi_N^(n)| of the Eq. (14)
      eigenstate equals (r*)^N, with r* = sqrt(tR / tL). The deviation is
      measured relative to max(1, (r*)^N) and must be below ``atol``.

    Summary lines with the largest deviations are printed on success.

    Raises
    ------
    AssertionError
        If either check fails for any mode. An explicit ``raise`` is used
        instead of ``assert`` so the checks still run under ``python -O``.
    """
    rstar = np.sqrt(tR / tL)
    expected = rstar ** N  # target amplitude, identical for every mode
    print(f"r* = sqrt(tR/tL) = {rstar:.6f}, (r*)^N = {expected:.6e}")
    max_dev_id = 0.0
    max_dev_amp = 0.0
    for n in range(1, N + 1):
        theta = n * np.pi / (N + 1)
        sin_N, sin_1 = abs(np.sin(N * theta)), abs(np.sin(theta))
        dev_id = abs(sin_N - sin_1)
        # `not (x < tol)` also rejects NaN deviations.
        if not dev_id < atol:
            raise AssertionError(
                f"Mode n={n}: |sin(N theta)|={sin_N:.3e} != "
                f"|sin(theta)|={sin_1:.3e}")
        max_dev_id = max(max_dev_id, dev_id)
        F_N = abs(obc_eigenstate(N, n, tR, tL)[-1])
        rel = abs(F_N - expected) / max(1.0, expected)
        if not rel < atol:
            raise AssertionError(
                f"Mode n={n}: F_N={F_N:.3e} != (r*)^N={expected:.3e}")
        max_dev_amp = max(max_dev_amp, rel)
    print(f"Algebraic identity verified for all {N} modes "
          f"(max deviation = {max_dev_id:.2e}).")
    print(f"Terminal-node amplitude F_N = (r*)^N for every mode "
          f"(max relative deviation = {max_dev_amp:.2e}).")


# D. Theorem 3 -------------------------------------------------------------

def boundary_region(profile, eps=0.5):
    """Return the 1-based sites where |profile| >= eps * max|profile|.

    This is the region Omega_eps used for Theorem 3: the sites whose
    amplitude is at least a fraction ``eps`` of the peak amplitude. Both
    sides of the comparison scale with ``|profile|``, so the region is
    (mathematically) unchanged by any nonzero rescaling of ``profile``.

    Parameters
    ----------
    profile : array_like
        Site amplitudes, real or complex; entry 0 is site 1.
    eps : float
        Relative threshold.

    Returns
    -------
    tuple of int
        Ascending 1-based site indices. An empty profile yields ``()``; an
        all-zero profile yields every site.
    """
    amplitude = np.abs(np.asarray(profile))
    if amplitude.size == 0:
        return ()
    threshold = eps * amplitude.max()
    return tuple(int(j) + 1 for j in np.nonzero(amplitude >= threshold)[0])


def verify_theorem3(N=50, tR=3.0, tL=1.0, n_star=None, eps=0.5,
                    chi_values=None):
    """Check that the boundary region Omega_eps is invariant under rescaling.

    The Eq. (14) eigenstate for mode ``n_star`` (default ``N``) is multiplied
    by each factor ``chi`` in ``chi_values``. The default is five
    log-spaced values from 0.1 to 10. The region returned by
    ``boundary_region(chi * base, eps=eps)`` must be identical for every
    factor. The factors are expected to be nonzero; a zero factor makes
    every site qualify.

    Raises
    ------
    ValueError
        If ``chi_values`` is empty, since invariance would then hold
        vacuously.
    AssertionError
        If any rescaled profile yields a different region. An explicit
        ``raise`` is used instead of ``assert`` so the check is not stripped
        under ``python -O``.
    """
    if n_star is None:
        n_star = N
    if chi_values is None:
        chi_values = np.logspace(-1, 1, 5)
    base = obc_eigenstate(N, n_star, tR, tL)
    regions = [boundary_region(chi * base, eps=eps) for chi in chi_values]
    if not regions:
        raise ValueError("chi_values must contain at least one rescaling factor")
    print(f"Boundary regions Omega_{eps} across kinetic rescalings "
          f"(n* = {n_star}, fixed):")
    for chi, region in zip(chi_values, regions):
        print(f"  |chi| = {abs(chi):>6.3g}: Omega = {region}")
    if any(region != regions[0] for region in regions):
        raise AssertionError("Theorem 3 violated")
    print("Theorem 3 verified: Omega_eps invariant under kinetic deformation.")


# E. Figures ---------------------------------------------------------------

def make_figure1(tR=3.0, tL=1.0, N_list=(20, 30, 50),
                 savepath="figure1_spectrum.png"):
    """Plot PBC vs OBC spectra in the complex plane and save the figure.

    One panel is drawn per entry of ``N_list``; panels share the y-axis.
    Each panel shows:

    * the analytical PBC curve
      ``(tR + tL)(cos k - 1) + i (tL - tR) sin k`` for k in [0, 2 pi];
    * the analytical OBC interval
      ``[-(sqrt(tR) + sqrt(tL))**2, -(sqrt(tR) - sqrt(tL))**2]`` on the
      real axis;
    * the numerical PBC and OBC eigenvalues for that size.

    The figure is written to ``savepath`` at 200 dpi and always closed.

    Raises
    ------
    ValueError
        If ``N_list`` is empty.
    """
    N_list = tuple(N_list)
    if not N_list:
        raise ValueError("N_list must contain at least one system size")
    # squeeze=False keeps `axes` 2-D even for a single panel, so indexing is
    # uniform for any number of system sizes.
    fig, axes = plt.subplots(1, len(N_list), figsize=(5 * len(N_list), 5),
                             sharey=True, squeeze=False)
    axes = axes[0]
    try:
        k = np.linspace(0, 2 * np.pi, 400)
        pbc_real = (tR + tL) * (np.cos(k) - 1)
        pbc_imag = (tL - tR) * np.sin(k)
        lam_min = -(np.sqrt(tR) + np.sqrt(tL)) ** 2
        lam_max = -(np.sqrt(tR) - np.sqrt(tL)) ** 2
        for idx, (ax, N) in enumerate(zip(axes, N_list)):
            ev_pbc = np.linalg.eigvals(directed_ring_pbc(N, tR, tL))
            ev_obc = np.linalg.eigvals(directed_ring_obc(N, tR, tL))
            ax.plot(pbc_real, pbc_imag, "k-", lw=1.2, label="PBC ellipse (Eq. 9)")
            ax.plot([lam_min, lam_max], [0, 0], color="darkred", lw=2.5,
                    label="OBC interval (Eq. 16)")
            ax.scatter(ev_pbc.real, ev_pbc.imag, s=30, alpha=0.7,
                       label="PBC eigenvalues")
            ax.scatter(ev_obc.real, ev_obc.imag, marker="D", s=40,
                       color="red", edgecolors="k", label="OBC eigenvalues")
            ax.set_title(f"N = {N}")
            ax.set_xlabel(r"Re($\lambda$)")
            if idx == 0:  # shared y-axis: label only the leftmost panel
                ax.set_ylabel(r"Im($\lambda$)")
            ax.axhline(0, color="grey", lw=0.5)
            ax.legend(fontsize=8, loc="best")
        fig.tight_layout()
        fig.savefig(savepath, dpi=200)
    finally:
        plt.close(fig)
    print(f"Figure 1 saved to {savepath}")


def make_figure2(N=50, tR=3.0, tL=1.0, savepath="figure2_eigenstates.png"):
    """Plot |psi_j^(n)| for every OBC mode against the (r*)^j envelope.

    The left panel uses a linear y-scale and the right panel a logarithmic
    one. All N modes are drawn as translucent steel-blue lines, and the
    envelope ``(r*)^j`` with ``r* = sqrt(tR / tL)`` as a dashed black line.
    The figure is saved to ``savepath`` at 200 dpi and then closed.
    """
    rstar = np.sqrt(tR / tL)
    sites = np.arange(1, N + 1)
    envelope = rstar ** sites
    fig, (ax_lin, ax_log) = plt.subplots(1, 2, figsize=(13, 5))
    mode_style = dict(color="steelblue", alpha=0.25)
    for n in range(1, N + 1):
        amp = np.abs(obc_eigenstate(N, n, tR, tL))
        ax_lin.plot(sites, amp, **mode_style)
        ax_log.semilogy(sites, amp, **mode_style)
    env_label = r"Envelope $(r^*)^j$"
    ax_lin.plot(sites, envelope, "k--", lw=2, label=env_label)
    ax_log.semilogy(sites, envelope, "k--", lw=2, label=env_label)
    for panel in (ax_lin, ax_log):
        panel.set_xlabel("Node index j")
        panel.legend(fontsize=10)
    ax_lin.set_ylabel(r"$|\psi_j^{(n)}|$ (linear)")
    ax_log.set_ylabel(r"$|\psi_j^{(n)}|$ (log)")
    ax_lin.set_title("Linear scale: shared envelope, mode-dependent oscillation")
    ax_log.set_title(rf"Log scale: parallel slopes "
                     rf"$\ln(r^*) \approx {np.log(rstar):.3f}$")
    fig.tight_layout()
    fig.savefig(savepath, dpi=200)
    plt.close(fig)
    print(f"Figure 2 saved to {savepath}")


if __name__ == "__main__":
    np.random.seed(0)
    # Each section: banner title followed by the steps it runs, in order.
    # A blank line separates consecutive sections.
    banner_rule = "=" * 70
    sections = [
        ("Theorem 1: real OBC spectrum", (verify_theorem1,)),
        ("Theorem 2 + terminal-node algebraic identity",
         (verify_theorem2_and_terminal_node,)),
        ("Theorem 3: boundary-region invariance under kinetic deformation",
         (verify_theorem3,)),
        ("Generating Figure 1 and Figure 2", (make_figure1, make_figure2)),
    ]
    for section_idx, (title, steps) in enumerate(sections):
        if section_idx > 0:
            print()
        print(banner_rule)
        print(title)
        print(banner_rule)
        for step in steps:
            step()
    print("All numerical checks completed.")
