#!/usr/bin/env python3

"""
spectral_compactness_example.py

Minimal, self-contained experiments illustrating how singular-spectrum tail shape
(algebraic vs exponential decay) affects:
  - effective rank (trace-norm mass concentration)
  - "eigenvalue clarity" after coarse quantization
  - iterative finite-precision error accumulation
  - a simple denoising benchmark via spectral filtering

This is a synthetic demo intended to accompany an NPL-style manuscript draft.
It does NOT train a neural network; it isolates spectral effects only.

Usage:
  python spectral_compactness_example.py --help
"""

from __future__ import annotations

import argparse
import math
from dataclasses import dataclass
from typing import Tuple, Dict, Any

import numpy as np
import matplotlib.pyplot as plt


# -----------------------------
# Spectra + metrics
# -----------------------------

def algebraic_spectrum(n: int, alpha: float = 0.5) -> np.ndarray:
    """Return the algebraically decaying spectrum sigma_i = i^{-alpha}, i = 1..n."""
    indices = np.arange(1, n + 1, dtype=np.float64)
    return np.power(indices, -alpha)


def exponential_spectrum(n: int, beta: float = 0.15) -> np.ndarray:
    """Return the exponentially decaying spectrum sigma_i = exp(-beta*(i-1)), i = 1..n."""
    exponents = -beta * np.arange(n, dtype=np.float64)
    return np.exp(exponents)


def effective_rank_s1(sigma: np.ndarray, tau: float = 0.9) -> int:
    """
    Effective rank capturing a ``tau`` fraction of Schatten-1 mass
    (sum of singular values).

    Parameters:
        sigma: singular values, in any order and sign (absolute values are used).
        tau: fraction of the total trace-norm mass to capture, in (0, 1].

    Returns:
        The smallest r such that the top-r singular values sum to at least
        tau * total mass, clamped to [0, len(sigma)]. Returns 0 for an
        all-zero (or empty) spectrum.
    """
    s = np.sort(np.abs(sigma))[::-1]
    total = float(np.sum(s))
    if total <= 0:
        return 0
    cum = np.cumsum(s)
    # Bug fix: searchsorted can land one past the last index when tau*total
    # slightly exceeds cum[-1] due to floating-point rounding (e.g. tau == 1.0),
    # which would yield r == n + 1; clamp to the actual spectrum length.
    r = int(np.searchsorted(cum, tau * total) + 1)
    return min(r, s.size)


def spectral_entropy(sigma: np.ndarray, eps: float = 1e-12) -> float:
    """
    Shannon entropy (natural-log units) of the normalized singular values
    p_i = sigma_i / sum_j sigma_j. Returns 0.0 for an all-zero spectrum.
    """
    magnitudes = np.abs(sigma).astype(np.float64)
    mass = float(np.sum(magnitudes))
    if mass <= 0:
        return 0.0
    # Clip into [eps, 1] so log never sees an exact zero.
    probs = np.clip(magnitudes / mass, eps, 1.0)
    return float(-np.sum(probs * np.log(probs)))


def uniform_quantize(x: np.ndarray, bits: int, xmin: float, xmax: float) -> np.ndarray:
    """
    Quantize x onto 2**bits uniformly spaced levels spanning [xmin, xmax].

    Values are snapped to the nearest level index and clipped into range;
    the result is returned in float64. A degenerate range (xmax <= xmin)
    collapses every value to xmin.

    Raises:
        ValueError: if bits < 1.
    """
    if bits <= 0:
        raise ValueError("bits must be >= 1")
    n_levels = 1 << bits
    if xmax <= xmin:
        return np.full_like(x, fill_value=xmin, dtype=np.float64)
    step = (xmax - xmin) / (n_levels - 1)
    level_idx = np.clip(np.round((x - xmin) / step), 0, n_levels - 1)
    return xmin + level_idx * step


def eigenvalue_clarity(
    sigma: np.ndarray,
    bits: int = 4,
    tau: float = 0.9,
) -> float:
    """
    Fraction of the *important* spectral modes that stay distinguishable
    after coarse quantization — one operationalization of "eigenvalue clarity".

    Steps:
      1) find the smallest r whose top-r singular values hold a tau fraction
         of the Schatten-1 mass;
      2) uniformly quantize those r values to ``bits`` bits over their range;
      3) report (#distinct quantized values) / r.
    """
    ordered = np.sort(np.abs(sigma))[::-1]
    r_eff = effective_rank_s1(ordered, tau=tau)
    if r_eff <= 1:
        return 1.0
    # `ordered` is descending, so the last/first entries are the min/max.
    leading = ordered[:r_eff]
    quantized = uniform_quantize(
        leading, bits=bits, xmin=float(leading[-1]), xmax=float(leading[0])
    )
    return float(np.unique(quantized).size / r_eff)


# -----------------------------
# Iterative stability demo
# -----------------------------

def random_orthogonal(n: int, rng: np.random.Generator) -> np.ndarray:
    """
    Draw a random orthogonal n x n matrix via QR decomposition of a Gaussian.

    The columns of Q are sign-corrected using sign(diag(R)) so the result is
    a deterministic function of the rng state (plain QR has a sign ambiguity).
    """
    gaussian = rng.standard_normal((n, n))
    q, r = np.linalg.qr(gaussian)
    signs = np.sign(np.diag(r))
    signs[signs == 0] = 1.0
    return q * signs


def build_operator_from_spectrum(sigma: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """
    Construct the symmetric operator A = Q diag(sigma) Q^T with Q a random
    orthogonal matrix: A keeps the prescribed spectrum while mixing axes.
    """
    q = random_orthogonal(sigma.size, rng)
    # (q * sigma) scales column j of q by sigma[j], i.e. q @ diag(sigma).
    return (q * sigma) @ q.T


def iterative_relative_error(
    A: np.ndarray,
    steps: int = 1000,
    seed_vec: int = 0,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Track divergence between float32 and float64 power iterations v <- A v.

    Both iterations start from the same Gaussian vector (seeded by
    ``seed_vec``). Returns (iters, rel_errors), where rel_errors[t] is
    ||v32 - v64|| / ||v64|| after t+1 applications of A.
    """
    rng = np.random.default_rng(seed_vec)
    start = rng.standard_normal(A.shape[0]).astype(np.float64)

    vec_hi = start.copy()
    vec_lo = start.astype(np.float32)
    mat_hi = A.astype(np.float64)
    mat_lo = A.astype(np.float32)

    rel = np.zeros(steps, dtype=np.float64)
    for step in range(steps):
        vec_hi = mat_hi @ vec_hi
        vec_lo = mat_lo @ vec_lo
        # Tiny denominator guard in case the float64 iterate decays toward zero.
        denom = np.linalg.norm(vec_hi) + 1e-30
        rel[step] = np.linalg.norm(vec_lo.astype(np.float64) - vec_hi) / denom

    return np.arange(1, steps + 1), rel


# -----------------------------
# Simple denoising via spectral filtering
# -----------------------------

def denoise_benchmark(alpha: float, beta: float, n_fft: int = 512, seed: int = 0) -> Dict[str, float]:
    """
    Score two tail-shaped frequency-domain filters on a noisy 5 Hz sine wave.

    A clean sinusoid (2 s at 200 Hz) is corrupted with Gaussian noise, then
    filtered with magnitude responses that decay algebraically (k^{-alpha})
    vs exponentially (exp(-beta*(k-1))) over frequency bins.

    Returns:
        {"rmse_algebraic": ..., "rmse_exponential": ...} against the clean signal.
    """
    rng = np.random.default_rng(seed)
    sample_rate = 200.0
    t = np.arange(0, 2.0, 1.0 / sample_rate)  # 2 seconds of samples
    clean = np.sin(2 * math.pi * 5.0 * t)
    noisy = clean + 0.5 * rng.standard_normal(clean.shape)

    # Forward transform (zero-padded to n_fft).
    spectrum = np.fft.rfft(noisy, n=n_fft)
    freqs = np.fft.rfftfreq(n_fft, d=1.0 / sample_rate)

    # Bin indices; the DC bin is remapped to 1 so k**(-alpha) never divides by zero.
    bins = np.arange(freqs.size, dtype=np.float64)
    bins[0] = 1.0

    responses = {
        "rmse_algebraic": bins ** (-alpha),
        "rmse_exponential": np.exp(-beta * (bins - 1.0)),
    }

    results: Dict[str, float] = {}
    for key, resp in responses.items():
        resp = resp / resp[0]  # normalize to unit DC gain
        filtered = np.fft.irfft(spectrum * resp, n=n_fft)[: clean.size]
        results[key] = float(np.sqrt(np.mean((filtered - clean) ** 2)))
    return results


# -----------------------------
# Main
# -----------------------------

@dataclass
class Config:
    """Experiment hyperparameters; populated from the CLI flags parsed in main()."""

    # Synthetic matrix dimension / spectrum length.
    n: int = 100
    # Algebraic tail decay exponent: sigma_i = i^{-alpha}.
    alpha: float = 0.5
    # Exponential tail decay rate: sigma_i = exp(-beta*(i-1)).
    beta: float = 0.15
    # Trace-norm mass fraction used for effective rank / clarity.
    tau: float = 0.9
    # Quantization bit-width for the eigenvalue-clarity metric.
    bits: int = 4
    # Iteration count for the float32-vs-float64 stability demo.
    steps: int = 1000
    # Base random seed for all stochastic components.
    seed: int = 0


def main() -> None:
    """
    CLI entry point: build both spectra, compute spectrum metrics, run the
    finite-precision stability and denoising demos, print a summary, and
    save two diagnostic figures named from --out_prefix.
    """
    # Use a Config instance as the single source of truth for defaults, so
    # argparse defaults can never silently drift from the dataclass ones.
    defaults = Config()
    p = argparse.ArgumentParser()
    p.add_argument("--n", type=int, default=defaults.n, help="Matrix dimension (synthetic)")
    p.add_argument("--alpha", type=float, default=defaults.alpha, help="Algebraic decay exponent")
    p.add_argument("--beta", type=float, default=defaults.beta, help="Exponential decay rate")
    p.add_argument("--tau", type=float, default=defaults.tau, help="Trace-norm mass fraction for effective rank")
    p.add_argument("--bits", type=int, default=defaults.bits, help="Quantization bits for eigenvalue clarity")
    p.add_argument("--steps", type=int, default=defaults.steps, help="Iteration steps for stability demo")
    p.add_argument("--seed", type=int, default=defaults.seed, help="Random seed")
    p.add_argument("--out_prefix", type=str, default="spectral_compactness", help="Output prefix for figures")
    args = p.parse_args()

    cfg = Config(
        n=args.n, alpha=args.alpha, beta=args.beta, tau=args.tau,
        bits=args.bits, steps=args.steps, seed=args.seed
    )

    sig_alg = algebraic_spectrum(cfg.n, alpha=cfg.alpha)
    sig_exp = exponential_spectrum(cfg.n, beta=cfg.beta)

    # Spectrum-shape metrics for both tails.
    metrics: Dict[str, Any] = {}
    metrics["effective_rank_alg"] = effective_rank_s1(sig_alg, tau=cfg.tau)
    metrics["effective_rank_exp"] = effective_rank_s1(sig_exp, tau=cfg.tau)
    metrics["entropy_alg"] = spectral_entropy(sig_alg)
    metrics["entropy_exp"] = spectral_entropy(sig_exp)
    metrics["clarity_alg"] = eigenvalue_clarity(sig_alg, bits=cfg.bits, tau=cfg.tau)
    metrics["clarity_exp"] = eigenvalue_clarity(sig_exp, bits=cfg.bits, tau=cfg.tau)

    # Iterative stability: compare float32 vs float64 relative error.
    rng = np.random.default_rng(cfg.seed)
    A_alg = build_operator_from_spectrum(sig_alg, rng)
    A_exp = build_operator_from_spectrum(sig_exp, rng)

    it_alg, rel_alg = iterative_relative_error(A_alg, steps=cfg.steps, seed_vec=cfg.seed + 1)
    it_exp, rel_exp = iterative_relative_error(A_exp, steps=cfg.steps, seed_vec=cfg.seed + 1)

    # Denoising benchmark.
    den = denoise_benchmark(alpha=cfg.alpha, beta=cfg.beta, seed=cfg.seed)

    # Print summary
    print("\n=== Spectral compactness demo (synthetic) ===")
    print(f"n={cfg.n}, alpha={cfg.alpha}, beta={cfg.beta}, tau={cfg.tau}, bits={cfg.bits}, steps={cfg.steps}\n")
    print("Effective rank (trace-norm mass):")
    print(f"  algebraic:   {metrics['effective_rank_alg']}")
    print(f"  exponential: {metrics['effective_rank_exp']}\n")
    print("Spectral entropy (nats, p_i ∝ sigma_i):")
    print(f"  algebraic:   {metrics['entropy_alg']:.4f}")
    print(f"  exponential: {metrics['entropy_exp']:.4f}\n")
    print(f"Eigenvalue clarity @ {cfg.bits}-bit quantization (top r_eff distinguishability):")
    print(f"  algebraic:   {100*metrics['clarity_alg']:.1f}%")
    print(f"  exponential: {100*metrics['clarity_exp']:.1f}%\n")
    print("Denoising benchmark RMSE (frequency-domain tail-shaped filters):")
    print(f"  algebraic:   {den['rmse_algebraic']:.4f}")
    print(f"  exponential: {den['rmse_exponential']:.4f}\n")

    # Plot 1: spectra
    plt.figure()
    plt.semilogy(np.arange(1, cfg.n + 1), sig_alg, marker="o", markersize=2, linewidth=1, label="Algebraic")
    plt.semilogy(np.arange(1, cfg.n + 1), sig_exp, marker="o", markersize=2, linewidth=1, label="Exponential")
    plt.xlabel("Index i")
    plt.ylabel("Singular value σ_i (log scale)")
    plt.title("Singular-spectrum tail shapes")
    plt.legend()
    fig1 = f"{args.out_prefix}_spectrum.png"
    plt.savefig(fig1, dpi=200, bbox_inches="tight")
    plt.close()  # release the figure after saving instead of leaking it
    print(f"Saved: {fig1}")

    # Plot 2: iterative relative error
    plt.figure()
    plt.loglog(it_alg, rel_alg + 1e-30, marker="o", markersize=2, linewidth=1, label="Algebraic tail")
    plt.loglog(it_exp, rel_exp + 1e-30, marker="o", markersize=2, linewidth=1, label="Exponential tail")
    plt.xlabel("Iteration step t")
    plt.ylabel("Relative error ||v32 - v64|| / ||v64|| (log-log)")
    plt.title("Finite-precision error accumulation (float32 vs float64)")
    plt.legend()
    fig2 = f"{args.out_prefix}_iter_error.png"
    plt.savefig(fig2, dpi=200, bbox_inches="tight")
    plt.close()  # release the figure after saving instead of leaking it
    print(f"Saved: {fig2}")

    print("\nDone.\n")


# Run the demo only when executed as a script (not when imported as a module).
if __name__ == "__main__":
    main()
