"""
Analytical Framework v4: TID Routing Attack — Paper-Matched Parameters
=======================================================================

"""

import numpy as np
import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
from scipy.optimize import brentq

#core functions

def h(p):
    """Return the binary entropy h(p) = -p log2(p) - (1-p) log2(1-p) in bits.

    Values of ``p`` within 1e-15 of 0 or 1 are treated as deterministic and
    yield exactly 0.0, which also avoids evaluating log2(0).
    """
    q = 1 - p
    is_deterministic = p <= 1e-15 or p >= 1 - 1e-15
    if is_deterministic:
        return 0.0
    return -p * np.log2(p) - q * np.log2(q)


def capacity_thermal(nth):
    """Return the TID information recording capacity of a thermal mode (Eq. 2).

    ``nth`` is the mode's thermal occupation number; the capacity is
    1 / (2*pi * (2*nth + 2)), so it decreases as the occupation grows.
    """
    denominator = 2 * np.pi * (2 * nth + 2)
    return 1.0 / denominator


# physical constants
h_planck = 6.626e-34   # Planck's constant (J·s)
kB = 1.38e-23          # Boltzmann constant (J/K)


def nth_bose(freq, temp):
    """Bose-Einstein thermal occupation number (Eq. 1).

    Computes 1 / (exp(x) - 1) with x = h_planck * freq / (kB * temp).

    Parameters
    ----------
    freq : float
        Mode frequency, combined with ``h_planck`` (J·s).
    temp : float
        Temperature, combined with ``kB`` (J/K).

    Returns
    -------
    float
        The occupation number. 0.0 is returned for ``temp == 0`` (the
        zero-temperature limit) and whenever x > 500, where the exact
        value is negligibly small.
    """
    # At zero temperature x diverges; return the limiting value instead of
    # dividing by zero.
    if temp == 0:
        return 0.0
    x = h_planck * freq / (kB * temp)
    if x > 500:
        return 0.0
    # expm1 avoids the catastrophic cancellation of exp(x) - 1 for x << 1,
    # which is exactly the room-temperature microwave regime.
    return 1.0 / np.expm1(x)


# CHANNEL PARAMETERS

FREQ = 5e9              # carrier frequency: 5 GHz microwave
T_ENV = 300             # environment at room temperature (K)
T_ADV = 0.020           # adversary's cryostat at 20 mK (K)
LINK_LOSS_DB = 10.0     # channel loss in dB (paper's primary scenario)
QBER_BASE = 0.005       # baseline QBER of 0.5% (paper's primary scenario)

# Derived quantities: transmittance and loss probability of the link
ETA = 10**(-LINK_LOSS_DB / 10)
P_LOST = 1 - ETA

# Thermal occupations of the environment (~1249) and adversary (~0) modes,
# their recording capacities, and the resulting capacity ratio.
nth_env, nth_adv = (nth_bose(FREQ, temperature) for temperature in (T_ENV, T_ADV))
C_env, C_adv = (capacity_thermal(occupation) for occupation in (nth_env, nth_adv))
C_RATIO = C_adv / C_env


# ANALYSIS FUNCTION


def analyse_attack(f_int, loss_db=LINK_LOSS_DB, qber=QBER_BASE, use_tid=True):
    """
    Full per-bit analysis for given parameters.

    Parameters
    ----------
    f_int : float
        Adversary's coupling fraction (fraction of total env coupling).
    loss_db : float
        Channel loss in dB.
    qber : float
        Quantum bit error rate.
    use_tid : bool
        If True, use TID capacity-weighted routing; if False, standard routing.

    Returns
    -------
    dict
        The inputs (``f_int``, ``loss_db``, ``qber``) together with every
        computed quantity: ``p_lost``, ``f_route_std``, ``f_route_tid``,
        ``f_route`` (the one selected by ``use_tid``), ``enhancement``,
        ``p_capture``, ``p_err_helstrom``, ``p_click``, ``I_Z``, ``I_X``,
        ``I_sifted``, ``h_qber``, ``R_std``, ``R_actual``, ``gap`` and
        ``uncompensated``.

    Notes
    -----
    The TID routing weights use the module-level capacities ``C_adv`` and
    ``C_env``.
    """
    # Channel loss
    eta = 10**(-loss_db / 10)
    p_lost = 1 - eta

    # Routing fractions (Eqs. 4-5): standard routing follows the coupling
    # fraction; TID routing weights each coupling by its recording capacity.
    f_route_std = f_int
    w_adv = C_adv * f_int
    w_env = C_env * (1.0 - f_int)
    f_route_tid = w_adv / (w_adv + w_env)

    f_route = f_route_tid if use_tid else f_route_std
    enhancement = f_route_tid / f_route_std if f_route_std > 0 else 0

    # Probability adversary captures the photon
    p_capture = p_lost * f_route

    # Helstrom bound verification (Eq. 20 in revised paper)
    # For states rho_0 = |0><0| and rho_1 = (1-p)|0><0| + p|1><1| the
    # difference rho_0 - rho_1 has eigenvalues +p and -p, so the trace
    # distance (1/2)||rho_0 - rho_1||_1 equals p_capture and, for equal
    # priors, P_err_Helstrom = (1 - p_capture) / 2.  This coincides with the
    # MAP error of the click/no-click analysis below,
    # p_no_click * (1 - p0_given_no_click).
    p_err_helstrom = 0.5 * (1.0 - p_capture)

    # Bayesian analysis (Eqs. 8-10): a click occurs with probability
    # p_capture for one of the two equiprobable bit values only.
    p_click = 0.5 * p_capture
    p_no_click = 1.0 - p_click
    p0_given_no_click = 0.5 / (1.0 - p_capture / 2.0)

    # Mutual information I(K_Z; E) (Eqs. 11-13); a click leaves no residual
    # uncertainty, hence the zero entropy term.
    H_K_given_no_click = h(p0_given_no_click)
    H_K_given_E = p_click * 0 + p_no_click * H_K_given_no_click
    I_Z = 1.0 - H_K_given_E

    # X-basis: no information (Eq. 14)
    I_X = 0.0

    # Sifted key information (Eq. 15)
    I_sifted = 0.5 * I_Z + 0.5 * I_X

    # Key rate analysis (Eqs. 16-18)
    h_qber = h(qber)
    R_std_estimate = 1.0 - 2.0 * h_qber       # What Alice & Bob compute
    R_actual = max(0, 1.0 - h_qber - I_sifted)  # Actual secure rate

    # Security gap (Eq. 18): positive = adversary wins
    security_gap = I_sifted - h_qber
    uncompensated = max(0, security_gap)

    return {
        'f_int': f_int,
        'loss_db': loss_db,
        'qber': qber,
        'p_lost': p_lost,
        'f_route_std': f_route_std,
        'f_route_tid': f_route_tid,
        'f_route': f_route,
        'enhancement': enhancement,
        'p_capture': p_capture,
        'p_err_helstrom': p_err_helstrom,
        'p_click': p_click,
        'I_Z': I_Z,
        'I_X': I_X,
        'I_sifted': I_sifted,
        'h_qber': h_qber,
        'R_std': R_std_estimate,
        'R_actual': R_actual,
        'gap': security_gap,
        'uncompensated': uncompensated,
    }



#  HEADER: banner followed by a summary of the channel and adversary settings

banner_rule = "=" * 80
header_lines = [
    banner_rule,
    "TID ROUTING ATTACK: MICROWAVE QUANTUM CHANNEL",
    "Analytical Framework v4 — Paper-Matched Parameters",
    banner_rule,
    f"\nChannel: {FREQ/1e9:.0f} GHz microwave link, {LINK_LOSS_DB:.0f} dB loss",
    f"Environment: T = {T_ENV} K, nth = {nth_env:.1f}",
    f"Adversary mode: T = {T_ADV*1000:.0f} mK, nth = {nth_adv:.6f}",
    f"Capacity ratio C_adv/C_env: {C_RATIO:.1f}",
    f"Channel transmittance: eta = {ETA:.4f}",
    f"P(photon lost): {P_LOST:.4f}",
    f"Baseline QBER: {QBER_BASE}",
]
# One print of the joined lines emits the same text as one print per line.
print("\n".join(header_lines))





# TABLE 1: QBER SCAN (10 dB, f_int=0.01)
# For each QBER, compare the adversary's sifted-key information with h(QBER)
# and flag a breach whenever the gap is positive.

print("\n" + "=" * 80)
print("TABLE 1: Security breach analysis at 10 dB loss, f_int = 0.01, nu = 5 GHz")
print("=" * 80)

print(f"\n{'QBER':>7} {'h(QBER)':>8} {'I_sift':>7} {'Delta_R':>8} {'Status':>8}")
print("-" * 42)

qber_values = [0.001, 0.002, 0.005, 0.010, 0.020, 0.050, 0.060, 0.065, 0.080, 0.110]
for q_err in qber_values:
    row = analyse_attack(0.01, loss_db=10, qber=q_err, use_tid=True)
    if row['gap'] > 0:
        verdict = "Breach"
    else:
        verdict = "Secure"
    print(f"{q_err:7.3f} {row['h_qber']:8.3f} {row['I_sifted']:7.3f} {row['gap']:8.3f} {verdict:>8}")

# Find exact crossover: the QBER at which h(QBER) equals the adversary's
# sifted-key information.  I_sifted does not depend on QBER, so it is
# evaluated once at the reference point and held fixed during root finding.
# (brentq is imported with the other dependencies at the top of the file.)
r_ref = analyse_attack(0.01, loss_db=10, qber=0.005, use_tid=True)
I_sift_10dB = r_ref['I_sifted']


def crossover_func(q):
    """Return h(q) - I_sift_10dB; its root is the crossover QBER."""
    return h(q) - I_sift_10dB


# brentq requires a sign change of crossover_func over the bracket [0.01, 0.11].
qber_crossover = brentq(crossover_func, 0.01, 0.11)
print(f"\nCrossover QBER (exact): {qber_crossover:.4f} ({qber_crossover*100:.2f}%)")
print(f"I_sift at 10 dB, f_int=0.01: {I_sift_10dB:.4f}")


# TABLE 2: LOSS SCAN (QBER=0.5%, f_int=0.01)
# Shows how the capture probability and the uncompensated leakage per
# 256-bit key change with the channel loss.

print("\n" + "=" * 80)
print("TABLE 2: Security gap vs channel loss at QBER = 0.5%, f_int = 0.01")
print("=" * 80)

print(f"\n{'Loss':>6} {'p_lost':>7} {'p_capt':>7} {'I_sift':>7} {'Delta_R':>8} {'Bits/256':>9}")
print("-" * 48)

for loss_level in (1, 2, 5, 10, 15, 20):
    row = analyse_attack(0.01, loss_db=loss_level, qber=0.005, use_tid=True)
    leaked_bits = 256 * row['uncompensated']
    left_half = f"{loss_level:6d} {row['p_lost']:7.3f} {row['p_capture']:7.3f} "
    right_half = f"{row['I_sifted']:7.3f} {row['gap']:8.3f} {leaked_bits:9.1f}"
    print(left_half + right_half)


# TABLE 3: COUPLING FRACTION SCAN (10 dB, QBER=0.5%)
# Compares standard and TID routing fractions for several coupling fractions.

print("\n" + "=" * 80)
print("TABLE 3: TID routing enhancement vs coupling fraction (10 dB, QBER = 0.5%)")
print("=" * 80)

print(f"\n{'f_int':>6} {'f_std':>6} {'f_tid':>6} {'Enh':>6} {'I_sift':>7} {'Delta_R':>8}")
print("-" * 43)

for coupling in (0.001, 0.005, 0.010, 0.020, 0.050, 0.100):
    row = analyse_attack(coupling, loss_db=10, qber=0.005, use_tid=True)
    routing_cols = f"{coupling:6.3f} {row['f_route_std']:6.3f} {row['f_route_tid']:6.3f} "
    leakage_cols = f"{row['enhancement']:5.0f}x {row['I_sifted']:7.3f} {row['gap']:8.3f}"
    print(routing_cols + leakage_cols)


# HELSTROM BOUND VERIFICATION
# Compares the Helstrom error reported by analyse_attack with the error of
# the photon-counting strategy, computed independently here so that the
# comparison is a real check rather than the same number printed twice.

print("\n" + "=" * 80)
print("HELSTROM BOUND VERIFICATION")
print("=" * 80)

r_hel = analyse_attack(0.01, loss_db=10, qber=0.005, use_tid=True)

# Photon counting guesses the no-photon bit on "no click"; it is wrong only
# when the photon-carrying bit was sent (prior 1/2) and was not captured.
p_err_counting = 0.5 * (1.0 - r_hel['p_capture'])
counting_is_optimal = bool(np.isclose(p_err_counting, r_hel['p_err_helstrom']))

print(f"\nAt f_int=0.01, 10 dB loss:")
print(f"  p_capture = {r_hel['p_capture']:.4f}")
# rho_0 - rho_1 has eigenvalues +/- p_capture, so it is the trace distance
# (half the trace norm) that equals p_capture.
print(f"  (1/2)||rho_0 - rho_1||_1 = p_capture = {r_hel['p_capture']:.4f}")
print(f"  P_err (Helstrom) = {r_hel['p_err_helstrom']:.4f}")
if counting_is_optimal:
    print(f"  P_err (photon counting) = {p_err_counting:.4f} (same — optimal)")
    print(f"  Photon number measurement achieves the Helstrom bound.")
else:
    print(f"  P_err (photon counting) = {p_err_counting:.4f} (differs from Helstrom value)")
    print(f"  WARNING: photon counting and Helstrom error probabilities disagree.")



# KEY FINDINGS
# Summarises the reference scenario (f_int = 0.01, 10 dB loss, QBER = 0.5%)
# for the TID adversary and, where relevant, the standard adversary.

print("\n" + "=" * 80)
print("KEY FINDINGS")
print("=" * 80)

tid_case = analyse_attack(0.01, loss_db=10, qber=0.005, use_tid=True)
std_case = analyse_attack(0.01, loss_db=10, qber=0.005, use_tid=False)
std_route = std_case['f_route']
tid_route = tid_case['f_route']

print("\n".join([
    f"\n1. TID ROUTING ENHANCEMENT",
    f"   At f_int = 0.01 (1% coupling), 10 dB loss:",
    f"   Standard routing: {std_route:.4f} ({std_route*100:.2f}%)",
    f"   TID routing:      {tid_route:.4f} ({tid_route*100:.2f}%)",
    f"   Enhancement: {tid_case['enhancement']:.1f}x",
]))

print("\n".join([
    f"\n2. INFORMATION LEAKAGE",
    f"   Per Z-basis bit: I_Z = {tid_case['I_Z']:.4f} bits",
    f"   Per sifted bit:  I_sift = {tid_case['I_sifted']:.4f} bits",
    f"   Standard PA removes: h(QBER) = {tid_case['h_qber']:.4f} bits",
    f"   Uncompensated: {tid_case['uncompensated']:.4f} bits/sifted bit",
]))

print("\n".join([
    f"\n3. SECURITY BREACH",
    f"   Alice & Bob's estimated R: {tid_case['R_std']:.4f}",
    f"   Actual R under TID attack: {tid_case['R_actual']:.4f}",
    f"   Security gap: {tid_case['gap']:.4f} bits/use",
    f"   Crossover QBER: {qber_crossover*100:.1f}% (attack fails above this)",
]))

print(f"\n4. PER-KEY IMPACT (QBER=0.5%, 10 dB, f_int=0.01)")
for key_length in (128, 256):
    leaked_bits = key_length * tid_case['uncompensated']
    print(f"   {key_length}-bit key: {leaked_bits:.1f} uncompensated bits leaked")

# Contrast hv/kT and the thermal occupation at 300 K for a 1550 nm optical
# carrier and for the microwave carrier.
print(f"\n5. OPTICAL vs MICROWAVE")
optical_freq = 3e8 / 1550e-9
optical_nth = nth_bose(optical_freq, 300)
optical_x = h_planck * optical_freq / (kB * 300)
microwave_x = h_planck*FREQ/(kB*300)
print(f"   Optical (1550nm): hv/kT = {optical_x:.1f}, nth = {optical_nth:.2e} -> C_ratio ~ 1")
print(f"   Microwave (5GHz): hv/kT = {microwave_x:.4f}, nth = {nth_env:.1f} -> C_ratio = {C_RATIO:.0f}")
print(f"   Attack surface: microwave quantum networks ONLY")


# COMPARISON: STANDARD vs TID ADVERSARY
# Side-by-side table of the same quantities for both routing models.

print("\n" + "=" * 80)
print("COMPARISON: STANDARD vs TID ADVERSARY (f_int=0.01, 10 dB, QBER=0.5%)")
print("=" * 80)

rt = analyse_attack(0.01, loss_db=10, qber=0.005, use_tid=True)
rs = analyse_attack(0.01, loss_db=10, qber=0.005, use_tid=False)

print(f"\n  {'Quantity':<35} {'Standard':>10} {'TID':>10}")
print(f"  {'-'*55}")

# (row label, key in the analyse_attack result) in display order
comparison_rows = [
    ('Routing to adversary', 'f_route'),
    ('Photon capture probability', 'p_capture'),
    ('I(K_Z; E) per Z-basis bit', 'I_Z'),
    ('I(K; E) per sifted bit', 'I_sifted'),
    ('Actual key rate R', 'R_actual'),
    ('Uncompensated leakage', 'uncompensated'),
    ('Security gap (I_sift - h(QBER))', 'gap'),
]
for row_label, result_key in comparison_rows:
    print(f"  {row_label:<35} {rs[result_key]:>10.4f} {rt[result_key]:>10.4f}")



# PLOTS


fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# --- Row 1: f_int sweeps at paper parameters ---

f_vals = np.linspace(0.001, 0.3, 300)
d_tid = [analyse_attack(f, loss_db=10, qber=0.005, use_tid=True) for f in f_vals]
d_std = [analyse_attack(f, loss_db=10, qber=0.005, use_tid=False) for f in f_vals]


def _finish_panel(panel, ylabel, title):
    """Apply the axis labels, title, legend and grid shared by panels (a)-(c)."""
    panel.set_xlabel('$f_{\\mathrm{int}}$')
    panel.set_ylabel(ylabel)
    panel.set_title(title)
    panel.legend(fontsize=9)
    panel.grid(True, alpha=0.3)


panel_a, panel_b, panel_c = axes[0]

# (a) Routing fraction, with the identity line f = f_int as a reference
route_std = [point['f_route'] for point in d_std]
route_tid = [point['f_route'] for point in d_tid]
panel_a.plot(f_vals, route_std, 'r-', lw=2, label='Standard')
panel_a.plot(f_vals, route_tid, 'b-', lw=2, label='TID')
panel_a.plot(f_vals, f_vals, 'k--', lw=1, alpha=0.3, label='f = f_int')
_finish_panel(panel_a, 'Routing fraction', '(a) Information Routing to Adversary')

# (b) Mutual information vs f_int, against the h(QBER) level at QBER = 0.5%
info_std = [point['I_sifted'] for point in d_std]
info_tid = [point['I_sifted'] for point in d_tid]
panel_b.plot(f_vals, info_std, 'r-', lw=2, label='$I_{\\mathrm{sift}}$ Standard')
panel_b.plot(f_vals, info_tid, 'b-', lw=2, label='$I_{\\mathrm{sift}}$ TID')
panel_b.axhline(h(0.005), color='g', ls='--', lw=1.5, label=f'$h$(QBER) = {h(0.005):.3f}')
_finish_panel(panel_b, 'Bits per sifted bit', '(b) Adversary Information vs Coupling')

# (c) Uncompensated leakage per 256-bit key
leak_std = [256*point['uncompensated'] for point in d_std]
leak_tid = [256*point['uncompensated'] for point in d_tid]
panel_c.plot(f_vals, leak_std, 'r-', lw=2, label='Standard')
panel_c.plot(f_vals, leak_tid, 'b-', lw=2, label='TID')
panel_c.axhline(0, color='k', lw=0.5)
_finish_panel(panel_c, 'Bits', '(c) Uncompensated Leakage per 256-bit Key')

# --- Row 2: QBER sweep, loss sweep, key rate ---

# (d) QBER sweep (Table 1 visualisation)
# Plots I_sift and h(QBER) against QBER and shades the region where the
# adversary's information exceeds h(QBER).
qber_range = np.linspace(0.001, 0.12, 200)
qber_sweep = [analyse_attack(0.01, loss_db=10, qber=q_point, use_tid=True)
              for q_point in qber_range]
i_sift_vals = [point['I_sifted'] for point in qber_sweep]
h_qber_vals = [point['h_qber'] for point in qber_sweep]

qber_percent = qber_range * 100
crossover_percent = qber_crossover * 100
breach_mask = [info > entropy for info, entropy in zip(i_sift_vals, h_qber_vals)]

ax = axes[1][0]
ax.plot(qber_percent, i_sift_vals, 'b-', lw=2, label='$I_{\\mathrm{sift}}$ (TID)')
ax.plot(qber_percent, h_qber_vals, 'g--', lw=2, label='$h$(QBER)')
ax.axvline(crossover_percent, color='red', ls=':', lw=1.5,
           label=f'Crossover = {qber_crossover*100:.1f}%')
ax.fill_between(qber_percent, h_qber_vals, i_sift_vals,
                where=breach_mask,
                alpha=0.15, color='red', label='Security breach')
ax.set_xlabel('QBER (%)')
ax.set_ylabel('Bits per sifted bit')
ax.set_title('(d) Low-QBER Vulnerability (Table 1)')
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

# (e) Loss sweep (Table 2 visualisation)
# Leakage per 256-bit key as a function of channel loss, for the TID
# adversary and for the standard adversary as a comparison.
loss_range = np.linspace(0.5, 25, 200)
tid_loss_sweep = [analyse_attack(0.01, loss_db=loss_point, qber=0.005, use_tid=True)
                  for loss_point in loss_range]
std_loss_sweep = [analyse_attack(0.01, loss_db=loss_point, qber=0.005, use_tid=False)
                  for loss_point in loss_range]

i_sift_loss = [point['I_sifted'] for point in tid_loss_sweep]
uncomp_loss = [256 * point['uncompensated'] for point in tid_loss_sweep]
uncomp_loss_std = [256 * point['uncompensated'] for point in std_loss_sweep]

ax = axes[1][1]
# The TID curve is drawn first so that it is listed first in the legend.
ax.plot(loss_range, uncomp_loss, 'b-', lw=2, label='TID adversary')
ax.plot(loss_range, uncomp_loss_std, 'r-', lw=2, label='Standard adversary')
ax.axhline(0, color='k', lw=0.5)
ax.set_xlabel('Channel loss (dB)')
ax.set_ylabel('Bits per 256-bit key')
ax.set_title('(e) Leakage vs Channel Loss (Table 2)')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

# (f) Key rate comparison
# analyse_attack already returns both rates: 'R_std' is 1 - 2 h(QBER) and
# 'R_actual' is max(0, 1 - h(QBER) - I_sifted), so one sweep supplies both curves.
rate_sweep = [analyse_attack(0.01, loss_db=10, qber=q_point, use_tid=True)
              for q_point in qber_range]
R_std_vals = [point['R_std'] for point in rate_sweep]
R_actual_vals = [point['R_actual'] for point in rate_sweep]

rate_x = qber_range * 100
overestimate_mask = [actual < estimate
                     for actual, estimate in zip(R_actual_vals, R_std_vals)]

ax = axes[1][2]
ax.plot(rate_x, R_std_vals, 'g--', lw=2, label='$R_{\\mathrm{std}}$ (Alice & Bob estimate)')
ax.plot(rate_x, R_actual_vals, 'b-', lw=2, label='$R_{\\mathrm{actual}}$ (TID attack)')
ax.fill_between(rate_x, R_actual_vals, R_std_vals,
                where=overestimate_mask,
                alpha=0.15, color='red', label='Security gap')
ax.set_xlabel('QBER (%)')
ax.set_ylabel('Key rate (bits/use)')
ax.set_title('(f) Secret Key Rate vs QBER')
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)
ax.set_ylim([0, 1])

# Figure title summarising the default scenario, then write the PNG.
figure_title = (
    f'TID Routing Attack on {FREQ/1e9:.0f} GHz Microwave Quantum Channel\n'
    f'Default: Loss={LINK_LOSS_DB:.0f} dB, T_env={T_ENV} K, T_adv={T_ADV*1000:.0f} mK, '
    f'C_ratio={C_RATIO:.0f}, QBER={QBER_BASE}'
)
plt.suptitle(figure_title, fontsize=12, fontweight='bold')
plt.tight_layout()

# A single path variable keeps the saved file and the printed message in sync.
figure_path = '/mnt/user-data/outputs/tid_attack_analysis_v4.png'
plt.savefig(figure_path, dpi=150, bbox_inches='tight')
print(f"\n\nPlot saved to {figure_path}")


# LATEX TABLE DATA (for verification)
# Re-emits Tables 1-3 as "&"-separated rows terminated by a LaTeX line break.

print("\n" + "=" * 80)
print("LATEX TABLE DATA")
print("=" * 80)

latex_row_end = " \\\\"

print("\n--- Table 1 (QBER scan) ---")
print("QBER & h(QBER) & I_sift & Delta_R & Status" + latex_row_end)
for q_err in qber_values:
    row = analyse_attack(0.01, loss_db=10, qber=q_err, use_tid=True)
    gap_value = row['gap']
    verdict = "Breach" if gap_value > 0 else "Secure"
    # Negative gaps get a math-mode minus sign in front of the magnitude.
    if gap_value >= 0:
        gap_cell = f"{gap_value:.3f}"
    else:
        gap_cell = f"$-${abs(gap_value):.3f}"
    cells = [f"{q_err:.3f}", f"{row['h_qber']:.3f}", f"{row['I_sifted']:.3f}",
             gap_cell, verdict]
    print(" & ".join(cells) + latex_row_end)

print("\n--- Table 2 (Loss scan) ---")
print("Loss & p_lost & p_capt & I_sift & Delta_R & Bits/256" + latex_row_end)
for loss_level in (1, 2, 5, 10, 15, 20):
    row = analyse_attack(0.01, loss_db=loss_level, qber=0.005, use_tid=True)
    leaked_bits = 256 * row['uncompensated']
    cells = [f"{loss_level}", f"{row['p_lost']:.3f}", f"{row['p_capture']:.3f}",
             f"{row['I_sifted']:.3f}", f"{row['gap']:.3f}", f"{leaked_bits:.1f}"]
    print(" & ".join(cells) + latex_row_end)

print("\n--- Table 3 (Coupling scan) ---")
print("f_int & f_std & f_tid & Enhancement & I_sift & Delta_R" + latex_row_end)
for coupling in (0.001, 0.005, 0.010, 0.020, 0.050, 0.100):
    row = analyse_attack(coupling, loss_db=10, qber=0.005, use_tid=True)
    cells = [f"{coupling:.3f}", f"{row['f_route_std']:.3f}", f"{row['f_route_tid']:.3f}",
             f"{row['enhancement']:.0f}x", f"{row['I_sifted']:.3f}", f"{row['gap']:.3f}"]
    print(" & ".join(cells) + latex_row_end)
