"""
2D Neutron-3He Scattering Simulation
Nuclear force proportional absorption model with 1/r^4 potential
"""
import numpy as np
import matplotlib.pyplot as plt
from numpy.fft import fft2, ifft2, fftfreq
from datetime import datetime

# --- Physical Constants ---
m_n = 1.675e-27       # Neutron mass [kg]
hbar = 1.055e-34      # Reduced Planck constant [J·s]
NA = 1.935e-71        # Nuclear potential strength [J·m^4]
V_core = -3.297e-12   # Core potential at r=0 [J]

# --- Energy Selection ---
ENERGY_EV = 1000  # Energy setting: 250, 1000, or 4000 eV

# --- Energy-dependent Parameters ---
# One preset per supported energy [eV].  Tuple columns:
#   dt [s], t_max [zs], snapshot times [zs], (r_max, z_min, z_max) [m],
#   sigma_z [m], z0 [m], (Nx, Nz)
# Nx is odd in every preset so that x = 0 is a grid node.
_ENERGY_PRESETS = {
    250: (20e-22, 48000, [0, 8000, 16000, 32000, 48000],
          (20e-12, -10e-12, 10e-12), 2000e-15, -5e-12, (2293, 1380)),
    1000: (5e-22, 12000, [0, 2000, 4000, 8000, 12000],
           (10e-12, -5e-12, 5e-12), 1000e-15, -2.5e-12, (1621, 976)),
    4000: (1.25e-22, 3000, [0, 500, 1000, 2000, 3000],
           (5e-12, -2.5e-12, 2.5e-12), 500e-15, -1.25e-12, (1147, 690)),
}

# Any energy without a preset falls back to the 4000 eV parameters.
if ENERGY_EV not in _ENERGY_PRESETS:
    print(f"Warning: Energy {ENERGY_EV} eV not defined, using 4000 eV parameters")
    ENERGY_EV = 4000

E_J = ENERGY_EV * 1.602e-19   # Kinetic energy [J] (1 eV = 1.602e-19 J)
(dt, t_max_zs, snap_zs, (r_max, z_min, z_max),
 sigma_z, z0, (Nx, Nz)) = _ENERGY_PRESETS[ENERGY_EV]

# Derived parameters
k0 = np.sqrt(2*m_n*E_J) / hbar   # Central wave number [1/m], from E = (hbar k)^2 / (2 m)
dt_zs = dt * 1e21                # Time step in zeptoseconds (1 zs = 1e-21 s)
# Step counts use round(), not int(): t / dt_zs is meant to be a whole number
# of steps, but floating-point division can land just below it (e.g. a
# quotient like 39999.99999999999).  int() would then truncate to the previous
# step and silently shift the final time and every snapshot index.
nt = round(t_max_zs / dt_zs)
print_interval = max(1, nt//10)
snap_steps = {round(t / dt_zs): t for t in snap_zs}  # step index -> snapshot time [zs]

print(f"=== Simulating {ENERGY_EV} eV neutron scattering ===")
print(f"Wave vector k0 = {k0:.3e} 1/m")
print(f"de Broglie wavelength = {2*np.pi/k0*1e15:.1f} fm")

# --- Grid Setup ---
# Cartesian grid: x is the transverse axis (symmetric about 0), z the beam axis.
x = np.linspace(-r_max, r_max, Nx)
z = np.linspace(z_min, z_max, Nz)
dx = x[1] - x[0]   # linspace spacing is uniform, so the first gap is the step
dz = z[1] - z[0]
# 'xy' indexing: 2D arrays have shape (Nz, Nx); rows follow z, columns follow x.
X, Z = np.meshgrid(x, z, indexing='xy')
r_grid = np.sqrt(X * X + Z * Z)   # Distance of each node from the origin [m]

print(f"Grid: {Nx}×{Nz} points")
print(f"Resolution: dx={dx*1e15:.2f} fm, dz={dz*1e15:.2f} fm")

# --- Nuclear Potential Setup ---
def V_nuclear_2D(r_grid, zero_tol=None, strength=None, core=None):
    """Build the purely absorptive complex nuclear potential on a grid.

    The real part is zero.  The imaginary part is ``-strength / r**4`` at every
    point with ``r >= zero_tol``; points with ``r < zero_tol`` are treated as
    the origin and get the finite value ``core`` instead of the divergent term.
    With a negative imaginary part, the propagator ``exp(-i V dt / hbar)`` has
    magnitude below 1, so probability is removed (absorbed).

    NOTE: if no grid node lies at the origin (e.g. an even number of points on a
    symmetric z range), no point satisfies ``r < zero_tol`` and ``core`` is unused.

    Parameters
    ----------
    r_grid : array_like
        Distance from the origin at each grid point [m].
    zero_tol : float, optional
        Radius below which a point counts as r = 0.  Defaults to ``dx / 10``.
    strength : float, optional
        Coefficient of the 1/r^4 term [J·m^4].  Defaults to ``NA``.
    core : float, optional
        Imaginary potential assigned at r = 0 [J].  Defaults to ``V_core``.

    Returns
    -------
    ndarray of complex
        Potential with the same shape as ``r_grid``.
    """
    # Defaults are resolved at call time so the module-level grid/constants
    # are used exactly as before when no override is given.
    if zero_tol is None:
        zero_tol = dx / 10
    if strength is None:
        strength = NA
    if core is None:
        core = V_core

    r = np.asarray(r_grid, dtype=float)
    V = np.zeros(r.shape, dtype=complex)  # real part stays 0: no elastic scattering term

    at_origin = r < zero_tol
    off_origin = ~at_origin

    # V.imag is a writable view, so these assignments fill V in place.
    V.imag[at_origin] = core
    V.imag[off_origin] = -strength / r[off_origin]**4

    return V

V = V_nuclear_2D(r_grid)   # Complex absorbing potential, same (Nz, Nx) layout as r_grid

# --- Wave number space ---
# Angular wave numbers (2π × spatial frequency) in numpy's FFT ordering, so
# they line up element-wise with the output of fft2 on the (Nz, Nx) grid.
two_pi = 2*np.pi
kx = two_pi * fftfreq(Nx, d=dx)
kz = two_pi * fftfreq(Nz, d=dz)
KX, KZ = np.meshgrid(kx, kz)   # default 'xy' indexing -> shape (Nz, Nx)

# --- Initial Wave Function ---
def init_psi():
    """Return the normalized initial wave packet on the (Nz, Nx) grid.

    The packet is a Gaussian envelope of width ``sigma_z`` centred at ``z0``
    along z, multiplied by the plane-wave phase ``exp(i k0 z)``.  It does not
    depend on x, i.e. it is uniform across the transverse direction.  The
    result is scaled so that ``sum(|psi|^2) * dx * dz == 1``.
    """
    envelope = np.exp(-((Z - z0)**2)/(2*sigma_z**2))
    carrier = np.exp(1j*k0*Z)
    packet = envelope * carrier
    norm = np.sqrt(np.sum(np.abs(packet)**2)*dx*dz)
    return packet / norm

# --- Split-operator Method ---
# The propagator factors depend only on the time step and on the module-level
# arrays V, KX, KZ, so they are the same on every iteration of the main loop.
# Rebuilding them on each call cost a full-grid complex exponential per
# substep.  Each step function therefore keeps a single-entry cache of its last
# factor.  The cache stores the source arrays by identity, so rebinding V / KX /
# KZ or passing a different time step triggers a rebuild.
# NOTE: modifying V, KX or KZ *in place* (or changing hbar / m_n) is not
# detected; rebind the arrays instead.
_potential_cache = [None, None, None]      # [V used, dt_half used, exp factor]
_kinetic_cache = [None, None, None, None]  # [KX used, KZ used, dt_full used, exp factor]


def potential_step(psi, dt_half):
    """Apply the potential substep ``psi * exp(-i V dt_half / hbar)``.

    Parameters
    ----------
    psi : ndarray of complex
        Wave function with the same shape as ``V``.
    dt_half : float
        Time increment of this potential substep [s].

    Returns
    -------
    ndarray of complex
        A new array; ``psi`` itself is not modified.
    """
    v_src, cached_dt, factor = _potential_cache
    if v_src is not V or cached_dt != dt_half:
        factor = np.exp(-1j * V * dt_half / hbar)
        _potential_cache[:] = [V, dt_half, factor]
    return psi * factor


def kinetic_step(psi, dt_full):
    """Apply the free-particle propagator for one step in Fourier space.

    Computes ``ifft2(fft2(psi) * exp(-i hbar (kx^2 + kz^2) dt_full / (2 m_n)))``.

    Parameters
    ----------
    psi : ndarray of complex
        Wave function with the same shape as ``KX`` / ``KZ``.
    dt_full : float
        Time step [s].

    Returns
    -------
    ndarray of complex
        The propagated wave function (new array).
    """
    kx_src, kz_src, cached_dt, factor = _kinetic_cache
    if kx_src is not KX or kz_src is not KZ or cached_dt != dt_full:
        factor = np.exp(-1j * hbar * (KX**2 + KZ**2) * dt_full / (2*m_n))
        _kinetic_cache[:] = [KX, KZ, dt_full, factor]
    psi_k = fft2(psi)
    psi_k *= factor
    return ifft2(psi_k)

# --- Main Simulation ---
print("\n=== Starting simulation ===")
psi_full = init_psi()       # Evolves under the absorbing potential + kinetic term
psi_inc = psi_full.copy()   # Incident reference: same start, kinetic term only
results = []                # One dict per snapshot, appended in time order

for step in range(nt+1):
    current_time_zs = step * dt * 1e21

    # Progress display (max(nt, 1) keeps the percentage finite when nt == 0)
    if step % print_interval == 0:
        norm_current = np.sum(np.abs(psi_full)**2)*dx*dz
        print(f"Progress: {step}/{nt} ({100*step/max(nt, 1):.1f}%), t={current_time_zs:.0f} zs, Norm={norm_current:.6e}")

    # Snapshot analysis
    if step in snap_steps:
        t_zs = snap_steps[step]
        norm_total = np.sum(np.abs(psi_full)**2)*dx*dz

        # Extract scattered wave: rotate the incident packet onto psi_full's
        # global phase, then subtract.  np.vdot conjugates its FIRST argument,
        # so phase = arg(<psi_inc|psi_full>) and the aligning factor is
        # exp(+i*phase).  Using exp(-i*phase) would double the mismatch.
        phase = np.angle(np.vdot(psi_inc.ravel(), psi_full.ravel()))
        psi_inc_aligned = psi_inc * np.exp(1j*phase)
        psi_sc = psi_full - psi_inc_aligned

        # 2D probabilities
        R2_2d = np.sum(np.abs(psi_sc)**2)*dx*dz      # Scattering probability
        A2_2d = 1.0 - norm_total                     # Absorption probability (norm lost; start norm is 1)
        T2_2d = norm_total - R2_2d                   # Transmission probability

        results.append({
            'time_zs': t_zs,
            'R2_2d': R2_2d,
            'A2_2d': A2_2d,
            'T2_2d': T2_2d,
            'norm_total': norm_total,
        })

        print(f"[t={t_zs} zs] Scattering={R2_2d:.3e}, Absorption={A2_2d:.3e}, Transmission={T2_2d:.3e}")

    # Time evolution: symmetric (Strang) splitting V/2 -> T -> V/2 for the full
    # wave; the incident reference only gets the free kinetic propagator.
    if step < nt:
        psi_full = potential_step(psi_full, dt/2)
        psi_full = kinetic_step(psi_full, dt)
        psi_full = potential_step(psi_full, dt/2)
        psi_inc = kinetic_step(psi_inc, dt)

# --- Final Results ---
print("\n=== Final Results ===")
final_result = results[-1]  # Most recent snapshot recorded by the main loop

print(f"Energy: {ENERGY_EV} eV")
for label, key in (("Final normalization", "norm_total"),
                   ("Scattering probability", "R2_2d"),
                   ("Absorption probability", "A2_2d"),
                   ("Transmission probability", "T2_2d")):
    print(f"{label}: {final_result[key]:.6e}")

# NOTE: the main loop defines A = 1 - norm and T = norm - R, so this sum is 1
# by construction (up to rounding).  It checks bookkeeping, not physics.
prob_sum = final_result['R2_2d'] + final_result['A2_2d'] + final_result['T2_2d']
print(f"Sum check: {prob_sum:.6e}")

# --- Visualization (for publication) ---
# Final scattered wave: rotate the incident packet onto psi_full's global
# phase, then subtract.  np.vdot conjugates its FIRST argument, so
# phase_final = arg(<psi_inc|psi_full>) and the aligning factor is exp(+i*phase).
phase_final = np.angle(np.vdot(psi_inc.ravel(), psi_full.ravel()))
psi_inc_aligned_final = psi_inc * np.exp(1j*phase_final)
psi_sc_final = psi_full - psi_inc_aligned_final

# Two panels stacked vertically (2 rows x 1 column), enlarged for publication
# (twice the width and height, four times the area).
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

# Both panels share the same physical extent, in fm.
extent_fm = [x.min()*1e15, x.max()*1e15, z.min()*1e15, z.max()*1e15]


def _draw_density(ax, density, title, cmap):
    """Draw a |psi|^2 map on ``ax`` with shared extent, labels and colorbar.

    Returns the image and its colorbar.
    """
    # aspect='equal': one data unit has the same length on both axes, so
    # circular scattering patterns are drawn as circles.
    img = ax.imshow(density, extent=extent_fm, aspect='equal',
                    origin='lower', cmap=cmap)
    ax.set_title(title, fontsize=12)
    ax.set_xlabel('x [fm]', fontsize=11)
    ax.set_ylabel('z [fm]', fontsize=11)
    ax.tick_params(axis='both', labelsize=10)
    cbar = plt.colorbar(img, ax=ax)
    cbar.ax.tick_params(labelsize=10)
    return img, cbar


# Total wave function
img1, cbar1 = _draw_density(ax1, np.abs(psi_full)**2,
                            f'Total |ψ|² at {ENERGY_EV} eV (t={t_max_zs} zs)',
                            'viridis')

# Scattered wave
img2, cbar2 = _draw_density(ax2, np.abs(psi_sc_final)**2,
                            'Scattered |ψ_sc|²', 'plasma')

# Add text with results
textstr = f'Scattering: {final_result["R2_2d"]:.3e}\nAbsorption: {final_result["A2_2d"]:.3e}'
ax2.text(0.95, 0.95, textstr, transform=ax2.transAxes,
         verticalalignment='top', horizontalalignment='right',
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
         fontsize=12)

plt.tight_layout()

# Save figure with 300 dpi
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"neutron_3He_{ENERGY_EV}eV_{timestamp}.png"
plt.savefig(filename, dpi=300, bbox_inches='tight')
print(f"Figure saved: {filename} at 300 dpi")

plt.show()

# 2. Save a copy to Google Drive (optional; only possible inside Colab).
try:
    from google.colab import drive
except ImportError:
    # Not running in Colab: there is no Drive to mount.
    print("Drive mount skipped")
else:
    import os
    import shutil

    drive_dir = '/content/drive/MyDrive/neutron_simulation'
    drive_path = f'{drive_dir}/{filename}'
    try:
        drive.mount('/content/drive')
        # shutil.copy does not create missing directories, so make sure the
        # target folder exists before copying the figure into it.
        os.makedirs(drive_dir, exist_ok=True)
        shutil.copy(filename, drive_path)
        print(f"Saved to Drive: {drive_path}")
    except Exception as exc:
        # Best effort: the figure is already saved locally, so report the
        # actual failure instead of aborting the script.
        print(f"Drive mount skipped ({type(exc).__name__}: {exc})")

# 3. Download the figure directly to the local PC (Colab only).
try:
    from google.colab import files
except ImportError:
    print("Not in Colab environment")
else:
    try:
        files.download(filename)
        print(f"Downloaded: {filename}")
    except Exception as exc:
        # Best effort: say why the browser download failed, but keep going.
        print(f"Download failed ({type(exc).__name__}: {exc})")

print("\n=== Simulation completed ===")