import numpy as np
import matplotlib.pyplot as plt
from qutip import *

# --------------------------
# 全局设置（APS期刊规范）
# --------------------------
plt.rcParams.update({
    "font.family": "serif",
    "font.size": 10,
    "axes.labelsize": 11,
    "axes.titlesize": 11,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "legend.fontsize": 9,
    "figure.dpi": 300,
    "figure.figsize": (8, 3),
    "savefig.bbox": "tight",
    "savefig.dpi": 300
})

# --------------------------
# 模型参数（与论文完全一致）
# --------------------------
omega = 1.0       # 两能级跃迁频率
tau_A = 10.0      # 分支A的proper time
gamma_list = np.linspace(0, 0.5, 60)       # 阻尼率扫描范围
delta_tau_list = np.linspace(0, 25, 60)    # 固有时差扫描范围（覆盖完整振荡周期）
gamma_fixed = 0.1  # 切片图固定阻尼率

# 初始内态：(|g> + |e>)/√2
psi0 = (basis(2, 0) + basis(2, 1)).unit()

# --------------------------
# 严格解析计算（正确公式，满足完备性）
# --------------------------
def visibility_echo_analytic_final(gamma, delta_tau):
    """最终正确的回波可见度解析公式，严格满足Δτ=0时Γ=1，Kraus完备性"""
    tau_B = tau_A + delta_tau
    q_A = np.exp(-gamma * tau_A / 2)
    q_B = np.exp(-gamma * tau_B / 2)
    p_A = 1 - q_A
    p_B = 1 - q_B
    
    # 四项贡献（正确的相位与振幅）
    term00 = np.sqrt(q_A * q_B)  # 0跳贡献
    term10 = 0.5 * np.sqrt(p_A * q_A * p_B * q_B) * np.exp(-1j * omega * delta_tau / 4)  # 1跳-第一段衰变
    term01 = 0.5 * np.sqrt(p_A * p_B) * np.exp(1j * omega * delta_tau / 4)  # 1跳-第二段衰变
    term11 = 0.5 * p_A * p_B  # 2跳贡献
    
    kappa = term00 + term10 + term01 + term11
    return np.abs(kappa)

def echo_upper_bound_final(gamma, delta_tau):
    """紧致的三角不等式上界，相位对齐时取等号"""
    tau_B = tau_A + delta_tau
    q_A = np.exp(-gamma * tau_A / 2)
    q_B = np.exp(-gamma * tau_B / 2)
    p_A = 1 - q_A
    p_B = 1 - q_B
    
    bound = np.sqrt(q_A*q_B) + 0.5*np.sqrt(p_A*q_A*p_B*q_B) + 0.5*np.sqrt(p_A*p_B) + 0.5*p_A*p_B
    return bound

# --------------------------
# 修正后的QuTiP仿真（正确的正交Kraus算子）
# --------------------------
def echo_kraus_final(gamma, tau):
    """正确的回波Kraus算子，严格满足完备性，物理过程正确"""
    q = np.exp(-gamma * tau / 2)
    p = 1 - q
    # 4类正交Kraus算子，带正确相位
    K00 = np.sqrt(q) * sigmax()
    K10 = np.sqrt(p*q) * np.exp(-1j * omega * tau /4) * basis(2,1) * basis(2,1).dag()
    K01 = np.sqrt(p) * np.exp(1j * omega * tau /4) * basis(2,0) * basis(2,0).dag()
    K11 = p * basis(2,0) * basis(2,1).dag()
    return [K00, K10, K01, K11]

def visibility_echo_qutip_final(gamma, delta_tau):
    """严格按相同跳变历史的Kraus匹配求和计算可见度"""
    tau_B = tau_A + delta_tau
    K_A = echo_kraus_final(gamma, tau_A)
    K_B = echo_kraus_final(gamma, tau_B)
    
    # 仅相同跳变历史的算子求和，交叉项环境态正交，贡献为0
    kappa = 0j
    for Ka, Kb in zip(K_A, K_B):
        kappa += (psi0.dag() * Ka.dag() * Kb * psi0)
    return np.abs(kappa)

# --------------------------
# 生成最终相图
# --------------------------
Gamma_echo_map = np.zeros((len(gamma_list), len(delta_tau_list)))
for i, gamma in enumerate(gamma_list):
    for j, dt in enumerate(delta_tau_list):
        Gamma_echo_map[i,j] = visibility_echo_analytic_final(gamma, dt)

fig1, ax1 = plt.subplots(figsize=(4, 3))
contour = ax1.contourf(delta_tau_list, gamma_list, Gamma_echo_map, 50, cmap='viridis', vmin=0, vmax=1)
cbar = fig1.colorbar(contour, ax=ax1)
cbar.set_label(r'$\Gamma_{\mathrm{echo}}$')
ax1.set_xlabel(r'$\Delta\tau$')
ax1.set_ylabel(r'$\gamma$')
ax1.set_title(r'Echo-recovered visibility')
ax1.axvline(x=0, color='white', linestyle='--', linewidth=1, label=r'$\Delta\tau=0, \Gamma=1$')
ax1.legend(fontsize=7)
fig1.savefig('fig_echo_phasemap_final.png')

# 恒等式验证
print("=== 核心恒等式验证（Δτ=0时Γ=1） ===")
for gamma_test in [0.0, 0.1, 0.3, 0.5]:
    val = visibility_echo_analytic_final(gamma_test, 0)
    print(f"γ={gamma_test:.1f}, 解析可见度={val:.10f}")

# --------------------------
# 生成最终切片图
# --------------------------
dt_slice = np.linspace(0, 25, 100)
Gamma_echo_analytic = []
Gamma_echo_bound = []
Gamma_echo_num = []

for dt in dt_slice:
    Gamma_echo_analytic.append(visibility_echo_analytic_final(gamma_fixed, dt))
    Gamma_echo_bound.append(echo_upper_bound_final(gamma_fixed, dt))
    # 每2个点做一次数值验证，加速计算
    if dt % 2 == 0:
        Gamma_echo_num.append(visibility_echo_qutip_final(gamma_fixed, dt))
    else:
        Gamma_echo_num.append(np.nan)

fig2, ax2 = plt.subplots(figsize=(4, 3))
ax2.plot(dt_slice, Gamma_echo_analytic, label='Exact analytic', linewidth=1.5, color='#2ca02c')
ax2.plot(dt_slice, Gamma_echo_num, 'o', label='Numerical (Kraus)', markersize=4, color='#1f77b4')
ax2.plot(dt_slice, Gamma_echo_bound, '--', label='Analytic bound', linewidth=1.5, color='#ff7f0e')
ax2.set_xlabel(r'$\Delta\tau$')
ax2.set_ylabel(r'$\Gamma_{\mathrm{echo}}$')
ax2.set_title(r'$\gamma=0.1$, $\tau_A=10$')
ax2.legend(framealpha=1)
ax2.set_ylim(0.2, 1.05)
fig2.savefig('fig_echo_bound_final.png')

# 数值与解析匹配验证
print("\n=== 数值与解析匹配验证 ===")
for dt_test in [0, 5, 10, 15, 20]:
    ana_val = visibility_echo_analytic_final(gamma_fixed, dt_test)
    num_val = visibility_echo_qutip_final(gamma_fixed, dt_test)
    print(f"Δτ={dt_test:.0f}, 解析={ana_val:.4f}, 数值={num_val:.4f}, 差值={abs(ana_val-num_val):.6f}")

plt.show()