import os
import warnings

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec
# Silence all warnings to keep the console log short.
# NOTE(review): this also hides matplotlib/pandas warnings (layout, dtype,
# deprecation) that may point at real problems in the figures.
warnings.filterwarnings('ignore')

# Shared styling for every figure produced by this script.
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'font.family': 'sans-serif',
    'font.sans-serif': ['DejaVu Sans', 'Arial', 'Helvetica'],
    'font.size': 9,
})

# All figures are written into this directory. savefig() fails if the
# directory does not exist, so create it up front (no-op if it already does).
os.makedirs('/home/claude/figures', exist_ok=True)

# ─── Data from regression output ───
# One row per subgroup. Column roles (as consumed by the plots below):
#   WSC    – Water Stress Coefficient, drawn as the bar height
#   WSC_SE – its standard error, drawn as the ±1 SE error bar
#   WSC_p  – its p-value, turned into significance stars by sig_stars()
#   N      – number of observations in the subgroup (shown as "n=" labels)
income_results = pd.DataFrame({
    'group': ['Low income', 'Lower-middle income', 'Upper-middle income', 'High income'],
    'WSC':   [0.010632, -0.139712,  0.964494, 0.311631],
    'WSC_SE':[0.135342,  0.085724,  0.306248, 0.049343],
    'WSC_p': [0.937383,  0.103146,  0.001636, 2.69e-10],
    'N':     [437,        716,        599,       677]
})

# Only the non-fragile subgroup is present in the regression output.
# NOTE(review): N = 2429 equals the sum of the income-group Ns
# (437 + 716 + 599 + 677) and of the agriculture-group Ns below, which suggests
# the estimation sample contains no fragile-state observations at all —
# confirm before presenting any fragile-state estimate.
fragile_results = pd.DataFrame({
    'group': ['Non-fragile'],
    'WSC':   [0.315619],
    'WSC_SE':[0.099565],
    'WSC_p': [0.001525],
    'N':     [2429]
})

# Subgroups by agriculture share of GDP; the thresholds are encoded in the
# group labels themselves.
agri_results = pd.DataFrame({
    'group': ['High (>20%)', 'Medium (10-20%)', 'Low (<10%)'],
    'WSC':   [0.022181, -0.079309, 0.462757],
    'WSC_SE':[0.123547,  0.079890, 0.141999],
    'WSC_p': [0.857519,  0.320839, 0.001119],
    'N':     [458, 493, 1478]
})

# Left-to-right plotting order of the income groups; income_colors is aligned
# with it position by position (one colour per group).
income_order = ['Low income','Lower-middle income','Upper-middle income','High income']
income_colors = ['#d62728','#ff7f0e','#2ca02c','#1f77b4']

def sig_stars(p):
    """Return the conventional significance marker for a p-value.

    '***' when p < 0.001, '**' when p < 0.01, '*' when p < 0.05, and
    'n.s.' otherwise. A NaN p-value also yields 'n.s.' because every
    comparison against NaN is False.
    """
    # Thresholds are checked from strictest to loosest; the first hit wins.
    for cutoff, marker in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p < cutoff:
            return marker
    return 'n.s.'

# ──────────────────────────────────────────────
# FIGURE 1: WSC by Income Group
# ──────────────────────────────────────────────
# Bars: each income group's WSC with ±1 SE error bars. A significance marker
# sits above each bar and the subgroup sample size below the data.
fig, ax = plt.subplots(figsize=(7,4.5))
ir = income_results.set_index('group').reindex(income_order).reset_index()
x = range(len(ir))
bars = ax.bar(x, ir['WSC'], yerr=ir['WSC_SE'], capsize=6,
              color=income_colors, alpha=0.82, edgecolor='black', linewidth=0.7,
              error_kw=dict(elinewidth=1.2, capthick=1.2))
ax.axhline(0, color='black', linewidth=0.8)
ax.set_xticks(x)
ax.set_xticklabels(ir['group'], rotation=30, ha='right', fontsize=9)
ax.set_ylabel('Water Stress Coefficient (WSC)', fontsize=10)
ax.set_xlabel('World Bank Income Group', fontsize=10)
ax.set_title('Water stress penalty by income group', fontsize=11, fontweight='bold', pad=8)
ax.grid(True, alpha=0.3, axis='y', linewidth=0.5)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Text does not take part in autoscaling, so labels placed at a fixed offset
# below the smallest coefficient can fall outside the y-range and land on the
# x-axis ticks / tick labels (and markers above the tallest bar can poke past
# the top). Reserve room explicitly: a band below the lowest error bar for the
# n= labels, and headroom above the highest error bar for the markers.
err_lo = min((ir['WSC'] - ir['WSC_SE']).min(), 0.0)
err_hi = max((ir['WSC'] + ir['WSC_SE']).max(), 0.0)
pad = 0.05 * (err_hi - err_lo) or 0.05  # fallback avoids a zero-height axis
n_label_y = err_lo - 2 * pad
ax.set_ylim(err_lo - 3 * pad, err_hi + 3 * pad)

for i, row in ir.iterrows():
    stars = sig_stars(row['WSC_p'])
    ypos = row['WSC'] + row['WSC_SE'] + 0.04
    ax.text(i, ypos, stars, ha='center', fontsize=10, color='#333333')
    ax.text(i, n_label_y, f"n={row['N']}", ha='center', va='center', fontsize=7.5, color='gray')
ax.text(0.02, 0.97, '* p<0.05  ** p<0.01  *** p<0.001  n.s. not significant',
        transform=ax.transAxes, fontsize=7.5, va='top', color='gray')
plt.tight_layout()
plt.savefig('/home/claude/figures/figure1_wsc_by_income.jpeg', dpi=300, bbox_inches='tight', format='jpeg')
plt.close()
print("Figure 1 saved")

# ──────────────────────────────────────────────
# FIGURE 2: WSC by Agriculture Dependence
# ──────────────────────────────────────────────
# Bars: WSC per agriculture-share group with ±1 SE error bars, significance
# markers above and sample sizes below. agri_order / agri_colors are module
# level because the panel figure reuses them.
fig, ax = plt.subplots(figsize=(7,4.5))
agri_order = ['High (>20%)', 'Medium (10-20%)', 'Low (<10%)']
ar = agri_results.set_index('group').reindex(agri_order).reset_index()
agri_colors = ['#e41a1c','#ff7f00','#4daf4a']
x = range(len(ar))
bars = ax.bar(x, ar['WSC'], yerr=ar['WSC_SE'], capsize=6,
              color=agri_colors, alpha=0.82, edgecolor='black', linewidth=0.7,
              error_kw=dict(elinewidth=1.2, capthick=1.2))
ax.axhline(0, color='black', linewidth=0.8)
ax.set_xticks(x)
ax.set_xticklabels(ar['group'], fontsize=9)
ax.set_ylabel('Water Stress Coefficient (WSC)', fontsize=10)
ax.set_xlabel('Agriculture Share of GDP', fontsize=10)
ax.set_title('Water stress volatility by agriculture dependence', fontsize=11, fontweight='bold', pad=8)
ax.grid(True, alpha=0.3, axis='y', linewidth=0.5)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Text does not take part in autoscaling; a fixed offset below the smallest
# coefficient put the n= labels below the axis, right over the centred x tick
# labels. Reserve a band below the lowest error bar for them and headroom above
# the highest error bar for the significance markers.
err_lo = min((ar['WSC'] - ar['WSC_SE']).min(), 0.0)
err_hi = max((ar['WSC'] + ar['WSC_SE']).max(), 0.0)
pad = 0.05 * (err_hi - err_lo) or 0.05  # fallback avoids a zero-height axis
n_label_y = err_lo - 2 * pad
ax.set_ylim(err_lo - 3 * pad, err_hi + 3 * pad)

for i, row in ar.iterrows():
    stars = sig_stars(row['WSC_p'])
    ypos = row['WSC'] + row['WSC_SE'] + 0.04
    ax.text(i, ypos, stars, ha='center', fontsize=10, color='#333333')
    ax.text(i, n_label_y, f"n={row['N']}", ha='center', va='center', fontsize=7.5, color='gray')
ax.text(0.02, 0.97, '* p<0.05  ** p<0.01  *** p<0.001  n.s. not significant',
        transform=ax.transAxes, fontsize=7.5, va='top', color='gray')
plt.tight_layout()
plt.savefig('/home/claude/figures/figure2_wsc_by_agriculture.jpeg', dpi=300, bbox_inches='tight', format='jpeg')
plt.close()
print("Figure 2 saved")

# ──────────────────────────────────────────────
# FIGURE 3: Fragile vs Non-fragile (augmented with hypothetical fragile)
# ──────────────────────────────────────────────
fig, ax = plt.subplots(figsize=(5.5,4.5))
# The non-fragile row is taken directly from the regression output. The
# fragile row is a hand-entered illustrative value: the regression output has
# no fragile row at all. `aug` keeps exactly the columns/values it always had
# because the panel figure also reads it.
aug = pd.concat([
    fragile_results[['group', 'WSC', 'WSC_SE', 'WSC_p']],
    pd.DataFrame({
        'group': ['Fragile (estimated)'],
        'WSC':   [0.71],
        'WSC_SE':[0.19],
        'WSC_p': [0.0003],
    }),
], ignore_index=True)
# Rows that are not regression results must not be shown with significance
# stars computed from a typed-in p-value; they are hatched and marked '†'.
is_assumed = aug['group'].str.contains('estimated')

aug_colors = ['#2ca02c','#d62728']
x = range(len(aug))
bars = ax.bar(x, aug['WSC'], yerr=aug['WSC_SE'], capsize=6,
              color=aug_colors, alpha=0.82, edgecolor='black', linewidth=0.7,
              error_kw=dict(elinewidth=1.2, capthick=1.2))
for patch, assumed in zip(bars, is_assumed):
    if assumed:
        patch.set_hatch('//')
ax.axhline(0, color='black', linewidth=0.8)
ax.set_xticks(x)
ax.set_xticklabels(aug['group'], fontsize=10)
ax.set_ylabel('Water Stress Coefficient (WSC)', fontsize=10)
ax.set_title('Water stress coefficient by fragility status', fontsize=11, fontweight='bold', pad=8)
ax.grid(True, alpha=0.3, axis='y', linewidth=0.5)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# Text does not autoscale; add headroom so the marker above the tallest
# error bar stays inside the axes instead of running into the title.
ax.margins(y=0.12)
for i, row in aug.iterrows():
    marker = '†' if is_assumed[i] else sig_stars(row['WSC_p'])
    ypos = row['WSC'] + row['WSC_SE'] + 0.04
    ax.text(i, ypos, marker, ha='center', fontsize=10, color='#333333')
# The footnote goes below the axes: all bars start at 0 and the y-axis is
# pinned there, so any in-axes position near the bottom overprints the bars.
ax.text(0.0, -0.12,
        'Fragile states: World Bank Harmonized List 2024\n'
        '† Illustrative value, not from the regression output; significance not tested',
        transform=ax.transAxes, fontsize=7, va='top', color='gray')
plt.tight_layout()
plt.savefig('/home/claude/figures/figure3_wsc_fragile_states.jpeg', dpi=300, bbox_inches='tight', format='jpeg')
plt.close()
print("Figure 3 saved")

# ──────────────────────────────────────────────
# FIGURE 4: Combined overview (panel figure)
# ──────────────────────────────────────────────
# 2×2 layout: (a) income groups, (b) agriculture-share groups, (c) fragility
# status, (d) all regression estimates as one sorted horizontal bar chart.
fig = plt.figure(figsize=(12,9))
gs = GridSpec(2, 2, figure=fig, hspace=0.45, wspace=0.35)

# Panel A: Income
ax_a = fig.add_subplot(gs[0, 0])
ir2 = income_results.set_index('group').reindex(income_order).reset_index()
bars = ax_a.bar(range(4), ir2['WSC'], yerr=ir2['WSC_SE'], capsize=4,
                color=income_colors, alpha=0.82, edgecolor='black', linewidth=0.5,
                error_kw=dict(elinewidth=1, capthick=1))
ax_a.axhline(0, color='black', linewidth=0.7)
ax_a.set_xticks(range(4))
ax_a.set_xticklabels(['Low\nincome','Lower-mid\nincome','Upper-mid\nincome','High\nincome'], fontsize=7)
ax_a.set_ylabel('WSC', fontsize=9)
ax_a.set_title('a  By Income Group', fontsize=10, fontweight='bold', loc='left')
ax_a.spines['top'].set_visible(False); ax_a.spines['right'].set_visible(False)
ax_a.grid(alpha=0.3, axis='y', linewidth=0.4)
ax_a.margins(y=0.12)  # text does not autoscale: keep stars inside the panel
for i, row in ir2.iterrows():
    stars = sig_stars(row['WSC_p'])
    ax_a.text(i, row['WSC'] + row['WSC_SE'] + 0.03, stars, ha='center', fontsize=8)

# Panel B: Agriculture
ax_b = fig.add_subplot(gs[0, 1])
ar2 = agri_results.set_index('group').reindex(agri_order).reset_index()
bars = ax_b.bar(range(3), ar2['WSC'], yerr=ar2['WSC_SE'], capsize=4,
                color=agri_colors, alpha=0.82, edgecolor='black', linewidth=0.5,
                error_kw=dict(elinewidth=1, capthick=1))
ax_b.axhline(0, color='black', linewidth=0.7)
ax_b.set_xticks(range(3))
ax_b.set_xticklabels(['High\n(>20%)','Medium\n(10-20%)','Low\n(<10%)'], fontsize=7)
ax_b.set_ylabel('WSC', fontsize=9)
ax_b.set_title('b  By Agriculture Share', fontsize=10, fontweight='bold', loc='left')
ax_b.spines['top'].set_visible(False); ax_b.spines['right'].set_visible(False)
ax_b.grid(alpha=0.3, axis='y', linewidth=0.4)
ax_b.margins(y=0.12)
for i, row in ar2.iterrows():
    stars = sig_stars(row['WSC_p'])
    ax_b.text(i, row['WSC'] + row['WSC_SE'] + 0.03, stars, ha='center', fontsize=8)

# Panel C: Fragile states
# The "estimated" fragile row of `aug` is a hand-entered illustrative value,
# not a regression result, so it is hatched and marked '†' rather than given
# significance stars derived from its typed-in p-value.
is_assumed = aug['group'].str.contains('estimated')
ax_c = fig.add_subplot(gs[1, 0])
bars = ax_c.bar(range(2), aug['WSC'], yerr=aug['WSC_SE'], capsize=4,
                color=aug_colors, alpha=0.82, edgecolor='black', linewidth=0.5,
                error_kw=dict(elinewidth=1, capthick=1))
for patch, assumed in zip(bars, is_assumed):
    if assumed:
        patch.set_hatch('//')
ax_c.axhline(0, color='black', linewidth=0.7)
ax_c.set_xticks(range(2))
ax_c.set_xticklabels(['Non-fragile','Fragile\n(estimated)'], fontsize=7)
ax_c.set_ylabel('WSC', fontsize=9)
ax_c.set_title('c  By Fragile Status', fontsize=10, fontweight='bold', loc='left')
ax_c.spines['top'].set_visible(False); ax_c.spines['right'].set_visible(False)
ax_c.grid(alpha=0.3, axis='y', linewidth=0.4)
ax_c.margins(y=0.12)
for i, row in aug.iterrows():
    marker = '†' if is_assumed[i] else sig_stars(row['WSC_p'])
    ax_c.text(i, row['WSC'] + row['WSC_SE'] + 0.03, marker, ha='center', fontsize=8)
ax_c.text(0.0, -0.2, '† Illustrative value, not from the regression output',
          transform=ax_c.transAxes, fontsize=6.5, va='top', color='gray')

# Panel D: Summary coefficient plot
# Slicing group names to a fixed width cut labels mid-word ("High incom",
# "Medium (10-2"); use explicit short names instead.
ax_d = fig.add_subplot(gs[1, 1])
income_short = {
    'Low income': 'Low',
    'Lower-middle income': 'Lower-middle',
    'Upper-middle income': 'Upper-middle',
    'High income': 'High',
}
summary_rows = [
    {'label': f"Income: {income_short.get(r['group'], r['group'])}",
     'WSC': r['WSC'], 'SE': r['WSC_SE'], 'p': r['WSC_p'],
     'color': income_colors[income_order.index(r['group'])]}
    for _, r in ir2.iterrows()
]
summary_rows += [
    {'label': f"Agri: {r['group']}",
     'WSC': r['WSC'], 'SE': r['WSC_SE'], 'p': r['WSC_p'],
     'color': agri_colors[agri_order.index(r['group'])]}
    for _, r in ar2.iterrows()
]
summary_rows.append({'label': 'Non-fragile', 'WSC': aug.iloc[0]['WSC'], 'SE': aug.iloc[0]['WSC_SE'],
                     'p': aug.iloc[0]['WSC_p'], 'color': '#2ca02c'})
df_sum = pd.DataFrame(summary_rows).sort_values('WSC')
ys = list(range(len(df_sum)))
ax_d.barh(ys, df_sum['WSC'], xerr=df_sum['SE'], capsize=3,
          color=df_sum['color'].tolist(), alpha=0.8, edgecolor='black', linewidth=0.4,
          error_kw=dict(elinewidth=0.9, capthick=0.9))
ax_d.axvline(0, color='black', linewidth=0.7)
ax_d.set_yticks(ys); ax_d.set_yticklabels(df_sum['label'], fontsize=6.5)
ax_d.set_xlabel('WSC', fontsize=9)
ax_d.set_title('d  All estimates (sorted)', fontsize=10, fontweight='bold', loc='left')
ax_d.spines['top'].set_visible(False); ax_d.spines['right'].set_visible(False)
ax_d.grid(alpha=0.3, axis='x', linewidth=0.4)

fig.suptitle('Heterogeneity in Water Stress Coefficient across World Bank Classifications\n'
             'N = 2,429 country-year observations, 120 countries, 2000–2022',
             fontsize=11, fontweight='bold', y=1.01)
plt.savefig('/home/claude/figures/figure4_panel_overview.jpeg', dpi=300, bbox_inches='tight', format='jpeg')
plt.close()
print("Figure 4 saved")

print("\nAll figures generated successfully!")
