# -*- coding: utf-8 -*-
"""Untitled76.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1ewJUf4JPAjI_rWCDclN_GiWBr4PZyF0b
"""

# ==============================================================================
# COMPLETE BAYESIAN STATE-SPACE MODEL FOR UHC AS HEALTH SECURITY INDEX
# PUBLICATION-READY VERSION - FULLY CORRECTED
# ==============================================================================

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az
import pymc as pm
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from scipy import stats
from scipy.stats import pearsonr, spearmanr
import warnings
from datetime import datetime

# Time series diagnostics
from statsmodels.tsa.stattools import adfuller, kpss, coint
from statsmodels.stats.diagnostic import acorr_ljungbox
from statsmodels.graphics.tsaplots import plot_acf

# NOTE(review): this suppresses every warning raised through the `warnings`
# module for the rest of the run — confirm that hiding them is intended.
warnings.filterwarnings('ignore')

# Create output directory
# Every table and figure written below is saved inside this folder;
# exist_ok=True lets the script be re-run when the folder already exists.
output_dir = 'UHC_Q1_Publication_Outputs_Complete'
os.makedirs(output_dir, exist_ok=True)
print(f"✅ Output directory created: {output_dir}")

# ==============================================================================
# PUBLICATION-READY PLOTTING CONFIGURATION
# ==============================================================================

# Global matplotlib defaults applied to every figure created after this point:
# serif font stack, font sizes, 300 dpi rendering/saving, tight bounding box,
# grid off, and thicker default lines.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'DejaVu Serif', 'Georgia'],
    'font.size': 12,
    'axes.labelsize': 13,
    'axes.titlesize': 14,
    'legend.fontsize': 11,
    'legend.title_fontsize': 12,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'axes.grid': False,
    'lines.linewidth': 2.5,
})

# Named hex colours shared by the plotting code; figures look colours up by
# key, so a key that is not listed here raises KeyError at plot time.
COLORS = {
    'primary': '#1B4F72',
    'secondary': '#117A65',
    'accent': '#C0392B',
    'highlight': '#E67E22',
    'signif_0.001': '#8B0000',
    'signif_0.01': '#CD5C5C',
    'signif_0.05': '#F08080',
    'signif_0.10': '#FFB6C1',
    'conf_int': '#2E86C1',
}

def significance_star(p_value):
    """Translate a p-value into its significance marker.

    The marker is '***' below 0.001, '**' below 0.01, '*' below 0.05 and
    '†' below 0.10. Anything else, including a value for which no
    comparison holds (such as NaN), yields 'ns'.
    """
    # Ordered from strictest to loosest so the first match wins.
    cutoffs = ((0.001, '***'), (0.01, '**'), (0.05, '*'), (0.10, '†'))
    for limit, marker in cutoffs:
        if p_value < limit:
            return marker
    return 'ns'

def posterior_probability(samples, direction='positive'):
    """Return the share of posterior draws on one side of zero.

    With ``direction='positive'`` the fraction of draws strictly greater
    than zero is returned; any other ``direction`` value gives the fraction
    strictly less than zero. Draws equal to zero count toward neither side.
    """
    on_requested_side = samples > 0 if direction == 'positive' else samples < 0
    return np.mean(on_requested_side)

# ==============================================================================
# DATA LOADING AND PREPARATION
# ==============================================================================

# Source workbook; the path is relative to the current working directory and
# only the 'SSA_Panel' sheet is read.
file_path = 'SSA_UHC_HealthSecurity_Panel_Curated_with_Physicians.xlsx'
df = pd.read_excel(file_path, sheet_name='SSA_Panel')

# Banner followed by a quick overview of the raw panel. The "(2000-2024)" in
# the banner is fixed text; the "Time Period" line reports the actual
# minimum and maximum of the 'year' column.
print("=" * 80)
print("BAYESIAN STATE-SPACE MODEL FOR UHC HEALTH SECURITY INDEX")
print("Sub-Saharan Africa Panel Analysis (2000-2024) - Q1 Publication Version")
print("=" * 80)
print(f"\nDataset Shape: {df.shape[0]} observations, {df.shape[1]} variables")
print(f"Countries Included: {df['country'].nunique()}")
print(f"Time Period: {df['year'].min()} - {df['year'].max()}")

def prepare_data(df):
    """Return a cleaned copy of the raw panel with derived columns added.

    The input frame is left untouched. On the copy, 'year' is cast to int,
    the eight indicator columns are coerced to numeric (unparseable entries
    become NaN), and three columns are appended in this order:
    'log_gdp_pc' = log(gdp_pc + 1), 'log_conflict' = log(conflict_events + 1)
    and 'financial_vuln' = oop_che / 100.
    """
    indicator_columns = ['uhc_sci', 'oop_che', 'gghe_gdp', 'external_che',
                         'gov_effectiveness', 'gdp_pc', 'physicians_1000', 'conflict_events']

    data = df.copy()
    data['year'] = data['year'].astype(int)

    # Coerce every indicator in one pass; assign() overwrites the existing
    # columns in place, so the column order is unchanged.
    coerced = {name: pd.to_numeric(data[name], errors='coerce')
               for name in indicator_columns}
    data = data.assign(**coerced)

    # The +1 offset keeps zero-valued GDP / conflict counts finite under log.
    derived = {
        'log_gdp_pc': np.log(data['gdp_pc'] + 1),
        'log_conflict': np.log(data['conflict_events'] + 1),
        'financial_vuln': data['oop_che'] / 100,
    }
    return data.assign(**derived)

# Cleaned panel used by every later section.
df_clean = prepare_data(df)

# ==============================================================================
# SECTION 1: DESCRIPTIVE STATISTICS (Table 1)
# ==============================================================================

print("\n" + "=" * 80)
print("SECTION 1: DESCRIPTIVE STATISTICS")
print("=" * 80)

# describe() on these eight numeric columns yields eight summary rows
# (count, mean, std, min, 25%, 50%, 75%, max), rounded to two decimals.
desc_stats = df_clean[['uhc_sci', 'oop_che', 'gghe_gdp', 'external_che',
                       'gov_effectiveness', 'gdp_pc', 'physicians_1000',
                       'conflict_events']].describe().round(2)

# Relabel rows and columns positionally for presentation; both lists must
# keep the same length and order as the describe() output above.
desc_stats.index = ['Count', 'Mean', 'Std Dev', 'Min', '25th', 'Median', '75th', 'Max']
desc_stats.columns = ['UHC Index', 'OOP (%)', 'GGHE (% GDP)', 'External (%)',
                      'Govt Effectiveness', 'GDP per capita', 'Physicians/1000', 'Conflict Events']

# Write Table 1 to the output folder and echo it to the console.
desc_stats_path = os.path.join(output_dir, 'Table1_Descriptive_Statistics.csv')
desc_stats.to_csv(desc_stats_path)
print(f"  ✓ Saved: Table1_Descriptive_Statistics.csv")
print(desc_stats)

# ==============================================================================
# SECTION 2: TIME SERIES DIAGNOSTICS (Stationarity & Cointegration)
# ==============================================================================

print("\n" + "=" * 80)
print("SECTION 2: TIME SERIES DIAGNOSTICS")
print("=" * 80)

# One dict per country for which the respective tests completed.
stationarity_results = []
cointegration_results = []
# (country, stage, reason) for every country whose tests raised, so that
# omissions from the tables are visible in the log instead of silent.
skipped_diagnostics = []

for country in df_clean['country'].unique():
    country_data = df_clean[df_clean['country'] == country].sort_values('year')
    country_data = country_data.dropna(subset=['uhc_sci', 'gghe_gdp', 'log_gdp_pc'])

    # Countries with fewer than 8 complete yearly observations are not tested.
    if len(country_data) < 8:
        continue

    uhc_series = country_data['uhc_sci'].values
    gghe_series = country_data['gghe_gdp'].values
    gdp_series = country_data['log_gdp_pc'].values

    try:
        # ADF Test (null hypothesis: unit root)
        adf_uhc = adfuller(uhc_series, autolag='AIC')
        adf_gghe = adfuller(gghe_series, autolag='AIC')
        adf_gdp = adfuller(gdp_series, autolag='AIC')

        # KPSS Test (null hypothesis: level stationarity), UHC only
        kpss_uhc = kpss(uhc_series, regression='c', nlags='auto')
    except Exception as exc:
        # Best effort: one failing country must not abort the whole run, but
        # the failure is recorded. Unlike a bare `except:`, this does not
        # swallow KeyboardInterrupt / SystemExit.
        skipped_diagnostics.append((country, 'stationarity', str(exc)))
        # As before, cointegration is not attempted when stationarity failed.
        continue

    stationarity_results.append({
        'Country': country,
        'ADF_UHC_stat': adf_uhc[0],
        'ADF_UHC_pval': adf_uhc[1],
        'ADF_UHC_stationary': 'Yes' if adf_uhc[1] < 0.05 else 'No',
        'ADF_GGHE_pval': adf_gghe[1],
        'ADF_GGHE_stationary': 'Yes' if adf_gghe[1] < 0.05 else 'No',
        'ADF_GDP_pval': adf_gdp[1],
        'ADF_GDP_stationary': 'Yes' if adf_gdp[1] < 0.05 else 'No',
        'KPSS_UHC_stat': kpss_uhc[0],
        'KPSS_UHC_pval': kpss_uhc[1],
        # KPSS reverses the null, so a large p-value is read as stationary.
        'KPSS_UHC_stationary': 'Yes' if kpss_uhc[1] > 0.05 else 'No',
        'Observations': len(country_data)
    })

    # Cointegration tests
    try:
        coint_gghe = coint(uhc_series, gghe_series)
        coint_gdp = coint(uhc_series, gdp_series)
    except Exception as exc:
        skipped_diagnostics.append((country, 'cointegration', str(exc)))
        continue

    cointegration_results.append({
        'Country': country,
        'Cointegration_UHC_GGHE_stat': coint_gghe[0],
        'Cointegration_UHC_GGHE_pval': coint_gghe[1],
        'Cointegration_UHC_GGHE': 'Yes' if coint_gghe[1] < 0.05 else 'No',
        'Cointegration_UHC_GDP_stat': coint_gdp[0],
        'Cointegration_UHC_GDP_pval': coint_gdp[1],
        'Cointegration_UHC_GDP': 'Yes' if coint_gdp[1] < 0.05 else 'No',
    })

if skipped_diagnostics:
    print(f"\n⚠ Diagnostics skipped for {len(skipped_diagnostics)} country/test combination(s):")
    for skipped_country, skipped_stage, skipped_reason in skipped_diagnostics:
        print(f"  - {skipped_country} ({skipped_stage}): {skipped_reason}")

stationarity_df = pd.DataFrame(stationarity_results)
cointegration_df = pd.DataFrame(cointegration_results)

# Calculate summary statistics (initialize with defaults in case of empty results)
# Each count is the number of countries flagged 'Yes' by the ADF test for
# that series; .get('Yes', 0) covers the case where no country was flagged.
if len(stationarity_df) > 0:
    stationary_uhc = stationarity_df['ADF_UHC_stationary'].value_counts().get('Yes', 0)
    stationary_gghe = stationarity_df['ADF_GGHE_stationary'].value_counts().get('Yes', 0)
    stationary_gdp = stationarity_df['ADF_GDP_stationary'].value_counts().get('Yes', 0)
    total_stationarity = len(stationarity_df)
else:
    stationary_uhc = 0
    stationary_gghe = 0
    stationary_gdp = 0
    total_stationarity = 0

# Same idea for the cointegration flags (p < 0.05 recorded as 'Yes').
if len(cointegration_df) > 0:
    coint_gghe_yes = cointegration_df['Cointegration_UHC_GGHE'].value_counts().get('Yes', 0)
    coint_gdp_yes = cointegration_df['Cointegration_UHC_GDP'].value_counts().get('Yes', 0)
    total_cointegration = len(cointegration_df)
else:
    coint_gghe_yes = 0
    coint_gdp_yes = 0
    total_cointegration = 0

# The inline conditionals guard the percentage against division by zero
# when no country was tested.
print(f"\n📊 Stationarity Summary:")
print(f"  UHC series stationary: {stationary_uhc}/{total_stationarity} countries ({stationary_uhc/total_stationarity*100 if total_stationarity > 0 else 0:.1f}%)")
print(f"  GGHE series stationary: {stationary_gghe}/{total_stationarity} countries ({stationary_gghe/total_stationarity*100 if total_stationarity > 0 else 0:.1f}%)")
print(f"  GDP series stationary: {stationary_gdp}/{total_stationarity} countries ({stationary_gdp/total_stationarity*100 if total_stationarity > 0 else 0:.1f}%)")

if total_cointegration > 0:
    print(f"\n🔗 Cointegration Summary:")
    print(f"  UHC-GGHE cointegrated: {coint_gghe_yes}/{total_cointegration} countries ({coint_gghe_yes/total_cointegration*100:.1f}%)")
    print(f"  UHC-GDP cointegrated: {coint_gdp_yes}/{total_cointegration} countries ({coint_gdp_yes/total_cointegration*100:.1f}%)")

# Table 2: Time Series Diagnostics
# NOTE(review): .head(20) keeps only the first 20 countries in the saved
# table, while the summary counts above use all tested countries — confirm
# the truncation is intended.
if len(stationarity_df) > 0:
    stationarity_summary = stationarity_df[['Country', 'ADF_UHC_pval', 'ADF_UHC_stationary',
                                             'ADF_GGHE_pval', 'ADF_GGHE_stationary',
                                             'KPSS_UHC_pval', 'KPSS_UHC_stationary']].head(20)
    stationarity_summary.columns = ['Country', 'ADF UHC p-value', 'ADF UHC Stationary',
                                     'ADF GGHE p-value', 'ADF GGHE Stationary',
                                     'KPSS UHC p-value', 'KPSS UHC Stationary']
    stationarity_path = os.path.join(output_dir, 'Table2_Time_Series_Diagnostics.csv')
    stationarity_summary.to_csv(stationarity_path, index=False)
    print(f"  ✓ Saved: Table2_Time_Series_Diagnostics.csv")

# Table 3: Cointegration Results
# Written in full (one row per country) and only when at least one country
# produced cointegration results.
if len(cointegration_df) > 0:
    cointegration_path = os.path.join(output_dir, 'Table3_Cointegration_Results.csv')
    cointegration_df.to_csv(cointegration_path, index=False)
    print(f"  ✓ Saved: Table3_Cointegration_Results.csv")

# ==============================================================================
# SECTION 3: PANEL UNIT ROOT TESTS (IPS)
# ==============================================================================

print("\n" + "=" * 80)
print("SECTION 3: PANEL UNIT ROOT TESTS")
print("=" * 80)

# Keep only rows where all three test variables are present (listwise
# deletion), then list the countries that remain in first-appearance order.
panel_data = df_clean[['country', 'year', 'uhc_sci', 'gghe_gdp', 'log_gdp_pc']].dropna()
countries_panel = panel_data['country'].unique()

def create_panel_matrix(data, var_name, min_years=6):
    """Build a balanced country-by-year matrix for one variable.

    Each country in ``data`` with at least ``min_years`` rows contributes one
    row: its ``var_name`` values sorted by 'year' and truncated to the first
    ``min_years`` observations, so that all rows have equal length.
    Countries with fewer rows are left out.

    Args:
        data: DataFrame with 'country' and 'year' columns plus ``var_name``.
        var_name: Name of the column to extract.
        min_years: Minimum number of rows a country needs, and also the
            number of leading observations kept per country.

    Returns:
        A 2-D array of shape (n_countries_kept, min_years), with rows in the
        order countries first appear in ``data``. When no country qualifies
        the result has shape (0, min_years).
    """
    rows = []
    # Take the countries from `data` itself rather than from a module-level
    # list, so the function works for whichever frame the caller passes in.
    for country in data['country'].unique():
        country_data = data[data['country'] == country].sort_values('year')
        if len(country_data) >= min_years:
            rows.append(country_data[var_name].values[:min_years])

    if not rows:
        # Keep the result two-dimensional even when it is empty.
        return np.empty((0, min_years))
    return np.array(rows)

def ips_test(panel_matrix):
    """Average per-row ADF statistics and p-values over a panel matrix.

    Each row is treated as one series. Rows shorter than 6 values or
    containing NaN are skipped, as are rows for which ``adfuller`` raises.
    The result is the plain arithmetic mean of the individual ADF statistics
    and of the individual p-values; no further standardisation is applied to
    that mean here.

    Args:
        panel_matrix: 2-D array with one series per row (an empty array is
            accepted and yields the "no result" tuple).

    Returns:
        ``(mean_adf_stat, mean_p_value, n_series_used)``, or
        ``(None, None, 0)`` when no row could be tested.
    """
    adf_stats = []
    p_values = []
    for series in panel_matrix:
        if len(series) < 6 or np.any(np.isnan(series)):
            continue
        try:
            adf_result = adfuller(series, autolag='AIC')
        except Exception:
            # Best effort: a series the ADF routine cannot handle is left out
            # of the average. Unlike a bare `except:`, this still lets
            # KeyboardInterrupt / SystemExit propagate.
            continue
        adf_stats.append(adf_result[0])
        p_values.append(adf_result[1])

    if not adf_stats:
        return None, None, 0
    return np.mean(adf_stats), np.mean(p_values), len(adf_stats)

# One balanced matrix per variable, each limited to 6 observations per country.
uhc_matrix = create_panel_matrix(panel_data, 'uhc_sci', 6)
gghe_matrix = create_panel_matrix(panel_data, 'gghe_gdp', 6)
gdp_matrix = create_panel_matrix(panel_data, 'log_gdp_pc', 6)

# Each result is a tuple (average ADF statistic, average p-value, N series),
# or (None, None, 0) when no series could be tested.
ips_uhc = ips_test(uhc_matrix)
ips_gghe = ips_test(gghe_matrix)
ips_gdp = ips_test(gdp_matrix)

# When a result is missing, the printout and the table below substitute
# 0 for the statistic and 1 for the p-value.
print(f"\n📊 Im-Pesaran-Shin (IPS) Panel Unit Root Tests:")
print(f"  UHC Index: avg ADF = {ips_uhc[0] if ips_uhc[0] is not None else 0:.3f}, avg p = {ips_uhc[1] if ips_uhc[1] is not None else 1:.4f}, N = {ips_uhc[2]}")
print(f"  GGHE: avg ADF = {ips_gghe[0] if ips_gghe[0] is not None else 0:.3f}, avg p = {ips_gghe[1] if ips_gghe[1] is not None else 1:.4f}, N = {ips_gghe[2]}")
print(f"  GDP: avg ADF = {ips_gdp[0] if ips_gdp[0] is not None else 0:.3f}, avg p = {ips_gdp[1] if ips_gdp[1] is not None else 1:.4f}, N = {ips_gdp[2]}")

# Table 4: Panel Unit Root Tests
# NOTE(review): 'Stationary' is decided by comparing the *average* of the
# per-country p-values with 0.05 — confirm this decision rule is the one
# intended for the reported panel test.
panel_roots = pd.DataFrame({
    'Variable': ['UHC Index', 'GGHE (% GDP)', 'Log GDP per capita'],
    'IPS_Avg_ADF': [ips_uhc[0] if ips_uhc[0] is not None else 0,
                    ips_gghe[0] if ips_gghe[0] is not None else 0,
                    ips_gdp[0] if ips_gdp[0] is not None else 0],
    'IPS_Avg_p-value': [ips_uhc[1] if ips_uhc[1] is not None else 1,
                        ips_gghe[1] if ips_gghe[1] is not None else 1,
                        ips_gdp[1] if ips_gdp[1] is not None else 1],
    'N_Countries': [ips_uhc[2], ips_gghe[2], ips_gdp[2]],
    'Stationary': ['Yes' if (ips_uhc[1] is not None and ips_uhc[1] < 0.05) else 'No',
                   'Yes' if (ips_gghe[1] is not None and ips_gghe[1] < 0.05) else 'No',
                   'Yes' if (ips_gdp[1] is not None and ips_gdp[1] < 0.05) else 'No']
})
panel_roots_path = os.path.join(output_dir, 'Table4_Panel_Unit_Root_Tests.csv')
panel_roots.to_csv(panel_roots_path, index=False)
print(f"  ✓ Saved: Table4_Panel_Unit_Root_Tests.csv")

# ==============================================================================
# SECTION 4: CORRELATION MATRIX
# ==============================================================================

print("\n" + "=" * 80)
print("SECTION 4: CORRELATION MATRIX")
print("=" * 80)

# Column names and their display labels; the two lists are matched by position.
corr_vars = ['uhc_sci', 'gghe_gdp', 'log_gdp_pc', 'gov_effectiveness', 'log_conflict', 'oop_che']
var_labels = ['UHC Index', 'GGHE (% GDP)', 'GDP per capita (log)',
              'Govt Effectiveness', 'Conflict (log)', 'OOP (%)']

# Rows with a missing value in any of the six variables are dropped first,
# so every pairwise correlation is computed on the same set of rows.
corr_data = df_clean[corr_vars].dropna()
corr_matrix = corr_data.corr()

# Compute correlation p-values
# The diagonal stays at its initial 0; only off-diagonal cells are filled,
# and each pair is evaluated in both (i, j) and (j, i) order.
corr_pvalues = pd.DataFrame(np.zeros_like(corr_matrix), index=corr_matrix.index, columns=corr_matrix.columns)
for i in range(len(corr_vars)):
    for j in range(len(corr_vars)):
        if i != j:
            pearson_r, pearson_p = pearsonr(corr_data[corr_vars[i]], corr_data[corr_vars[j]])
            corr_pvalues.iloc[i, j] = pearson_p

# Create correlation matrix with significance stars
# Starts from a string copy of the matrix, then every cell is overwritten:
# off-diagonal cells get "r" to two decimals plus a marker, the diagonal "1.00".
corr_with_stars = corr_matrix.copy().round(3).astype(str)
for i in range(len(corr_vars)):
    for j in range(len(corr_vars)):
        if i != j:
            p_val = corr_pvalues.iloc[i, j]
            star = significance_star(p_val)
            corr_with_stars.iloc[i, j] = f"{corr_matrix.iloc[i, j]:.2f}{star}"
        else:
            corr_with_stars.iloc[i, j] = "1.00"

corr_with_stars.index = var_labels
corr_with_stars.columns = var_labels

# Table 5: Correlation Matrix
# Two files: the numeric matrix (labelled with the raw column names) and the
# starred string version (labelled with the display names).
corr_matrix_path = os.path.join(output_dir, 'Table5_Correlation_Matrix.csv')
corr_matrix.to_csv(corr_matrix_path)
corr_stars_path = os.path.join(output_dir, 'Table5_Correlation_Matrix_With_Stars.csv')
corr_with_stars.to_csv(corr_stars_path)
print(f"  ✓ Saved: Table5_Correlation_Matrix.csv")
print(f"  ✓ Saved: Table5_Correlation_Matrix_With_Stars.csv")

# ==============================================================================
# SECTION 5: BAYESIAN HIERARCHICAL MODEL
# ==============================================================================

print("\n" + "=" * 80)
print("SECTION 5: BAYESIAN HIERARCHICAL MODEL")
print("=" * 80)

# Fill gaps in each model variable with the country's own mean, then with the
# overall column mean for anything still missing.
# NOTE(review): this also imputes the outcome 'uhc_sci' — confirm that
# mean-imputing the dependent variable is intended.
balanced_df = df_clean.copy()
for col in ['uhc_sci', 'gghe_gdp', 'log_gdp_pc', 'log_conflict', 'gov_effectiveness']:
    if col in balanced_df.columns:
        balanced_df[col] = balanced_df.groupby('country')[col].transform(
            lambda x: x.fillna(x.mean())).fillna(balanced_df[col].mean())

# Drop any row that still lacks one of the model variables.
balanced_df = balanced_df.dropna(subset=['uhc_sci', 'gghe_gdp', 'log_gdp_pc', 'log_conflict', 'gov_effectiveness'])

# Outcome vector and predictor matrix; the column order here fixes the order
# of the coefficients in `beta` and of `predictor_names` below.
y_data = balanced_df['uhc_sci'].values
X_data = balanced_df[['gghe_gdp', 'log_gdp_pc', 'log_conflict', 'gov_effectiveness']].values

# Standardise outcome and predictors; scaler_y is kept so that results can be
# converted back to the original outcome units later.
scaler_y = StandardScaler()
scaler_X = StandardScaler()
y_scaled = scaler_y.fit_transform(y_data.reshape(-1, 1)).flatten()
X_scaled = scaler_X.fit_transform(X_data)

predictor_names = ['Government Health Expenditure (% GDP)',
                   'GDP per capita (log)',
                   'Conflict Events (log)',
                   'Government Effectiveness']

print(f"Observations: {len(y_scaled)}")
print(f"Predictors: {X_scaled.shape[1]}")

print("\nFitting Bayesian Model (this may take 3-5 minutes)...")

# Linear regression on the standardised data: one shared intercept (alpha),
# four shared slopes (beta) and one residual scale (sigma). No country- or
# year-specific terms appear in this model.
with pm.Model() as model:
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    beta = pm.Normal('beta', mu=0, sigma=1, shape=4)
    sigma = pm.HalfNormal('sigma', sigma=0.5)
    mu = alpha + pm.math.dot(X_scaled, beta)
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y_scaled)
    # 2 chains of 2000 draws after 1000 tuning steps, run on a single core
    # with a fixed seed.
    trace = pm.sample(draws=2000, tune=1000, chains=2, cores=1, random_seed=42,
                      progressbar=True)

print("\n✓ Model fitting complete!")

# Pool chains and draws into one (n_samples, 4) array of coefficient draws.
beta_samples = trace.posterior['beta'].values.reshape(-1, 4)
beta_means = beta_samples.mean(axis=0)
# Despite the name, this is the equal-tailed 2.5%-97.5% percentile interval.
beta_hdi = np.percentile(beta_samples, [2.5, 97.5], axis=0)

beta_results = []
for i, name in enumerate(predictor_names):
    prob_pos = posterior_probability(beta_samples[:, i], 'positive')
    prob_neg = posterior_probability(beta_samples[:, i], 'negative')
    # Two-sided tail probability: twice the smaller share of draws on either
    # side of zero.
    bayesian_p = 2 * min(prob_pos, prob_neg)
    # Multiplying by the outcome's standard deviation expresses the effect in
    # original outcome units per one-standard-deviation change in the
    # predictor (stored under the '_pp' names).
    effect_pp = beta_means[i] * scaler_y.scale_[0]
    lower_pp = beta_hdi[0, i] * scaler_y.scale_[0]
    upper_pp = beta_hdi[1, i] * scaler_y.scale_[0]

    beta_results.append({
        'Variable': name,
        'Coefficient_Std': beta_means[i],
        'CI_95_Lower_Std': beta_hdi[0, i],
        'CI_95_Upper_Std': beta_hdi[1, i],
        'Effect_pp': effect_pp,
        'CI_95_Lower_pp': lower_pp,
        'CI_95_Upper_pp': upper_pp,
        'Posterior_Prob_Positive': prob_pos,
        'Posterior_Prob_Negative': prob_neg,
        'Bayesian_p_value': bayesian_p,
        'Significance': significance_star(bayesian_p)
    })

# Table 6: one row per predictor, printed and written to CSV.
beta_results_df = pd.DataFrame(beta_results)
print("\n📊 Bayesian Model Results:")
print(beta_results_df.to_string(index=False))

beta_results_path = os.path.join(output_dir, 'Table6_Bayesian_Results_Significance.csv')
beta_results_df.to_csv(beta_results_path, index=False)
print(f"  ✓ Saved: Table6_Bayesian_Results_Significance.csv")

# ==============================================================================
# SECTION 6: MODEL COMPARISON (WAIC)
# ==============================================================================

print("\n" + "=" * 80)
print("SECTION 6: MODEL COMPARISON")
print("=" * 80)

# The whole comparison is wrapped in one try block: any exception raised in
# it is reported and a placeholder Table 7 is written instead (see the
# except clause at the end). In that case waic_full / waic_null / waic_diff
# may be left undefined for later code.
# NOTE(review): confirm that `az.from_pymc` exists in the installed ArviZ,
# that the sampled trace carries the pointwise log-likelihood WAIC needs, and
# that the WAIC result exposes `.waic` / `.waic_se` — if any of these does not
# hold, this section always ends in the except branch.
try:
    idata = az.from_pymc(trace)
    waic_full = az.waic(idata)

    print(f"\n📊 Full Model WAIC: {waic_full.waic:.2f} (SE = {waic_full.waic_se:.2f})")

    # Intercept-only baseline with the same priors, likelihood and sampler
    # settings as the full model.
    print("\nFitting null model for comparison...")
    with pm.Model() as null_model:
        alpha_null = pm.Normal('alpha', mu=0, sigma=1)
        sigma_null = pm.HalfNormal('sigma', sigma=0.5)
        mu_null = alpha_null
        y_obs_null = pm.Normal('y_obs', mu=mu_null, sigma=sigma_null, observed=y_scaled)
        trace_null = pm.sample(draws=2000, tune=1000, chains=2, cores=1, random_seed=42, progressbar=False)

    idata_null = az.from_pymc(trace_null)
    waic_null = az.waic(idata_null)

    # Difference is null minus full. The SE of the difference combines the
    # two model SEs as if they were independent.
    waic_diff = waic_null.waic - waic_full.waic
    waic_se_diff = np.sqrt(waic_full.waic_se**2 + waic_null.waic_se**2)

    print(f"\n📊 Model Comparison:")
    print(f"  Null model WAIC: {waic_null.waic:.2f} (SE = {waic_null.waic_se:.2f})")
    print(f"  Full model WAIC: {waic_full.waic:.2f} (SE = {waic_full.waic_se:.2f})")
    print(f"  ΔWAIC: {waic_diff:.2f} (SE = {waic_se_diff:.2f})")

    # Table 7: the ΔWAIC cell is left blank on the null-model row.
    model_comp = pd.DataFrame({
        'Model': ['Null Model (Intercept only)', 'Full Bayesian Model'],
        'WAIC': [waic_null.waic, waic_full.waic],
        'WAIC_SE': [waic_null.waic_se, waic_full.waic_se],
        'ΔWAIC': ['', waic_diff]
    })
    model_comp_path = os.path.join(output_dir, 'Table7_Model_Comparison.csv')
    model_comp.to_csv(model_comp_path, index=False)
    print(f"  ✓ Saved: Table7_Model_Comparison.csv")

except Exception as e:
    # Fallback: print the error and write a one-line placeholder table.
    # The note text is fixed and does not depend on the actual exception.
    print(f"  Note: WAIC calculation issue: {e}")
    model_comp = pd.DataFrame({'Note': ['WAIC calculation requires additional dependencies']})
    model_comp_path = os.path.join(output_dir, 'Table7_Model_Comparison.csv')
    model_comp.to_csv(model_comp_path, index=False)

# ==============================================================================
# SECTION 7: MODEL PREDICTIONS
# ==============================================================================

print("\n" + "=" * 80)
print("SECTION 7: MODEL PREDICTIONS")
print("=" * 80)

# Draw posterior-predictive samples from the fitted full model and average
# them over chains and draws to get one predicted value per observation
# (still on the standardised scale).
with model:
    ppc = pm.sample_posterior_predictive(trace, random_seed=42, progressbar=False)
    y_pred_scaled = ppc.posterior_predictive['y_obs'].mean(dim=['chain', 'draw']).values

# Undo the outcome standardisation so predictions and residuals are in the
# original units of the outcome.
y_pred_original = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
y_true_original = y_data
residuals = y_true_original - y_pred_original

# In-sample fit: predictions are compared with the same observations the
# model was fitted on.
r2 = r2_score(y_true_original, y_pred_original)
rmse = np.sqrt(np.mean(residuals**2))
mae = np.mean(np.abs(residuals))

print(f"\n📊 Model Fit Statistics:")
print(f"  R-squared (R²): {r2:.4f}")
print(f"  RMSE: {rmse:.2f}")
print(f"  MAE: {mae:.2f}")

# Table 8: all values are stored as formatted strings. 'Time Period' is a
# fixed string rather than being derived from the data.
fit_stats = pd.DataFrame({
    'Metric': ['R-squared (R²)', 'Root Mean Square Error (RMSE)', 'Mean Absolute Error (MAE)',
               'Number of Observations', 'Number of Countries', 'Time Period'],
    'Value': [f"{r2:.4f}", f"{rmse:.2f}", f"{mae:.2f}",
              f"{len(y_data)}", f"{df_clean['country'].nunique()}", "2000-2024"]
})
fit_stats_path = os.path.join(output_dir, 'Table8_Model_Fit_Statistics.csv')
fit_stats.to_csv(fit_stats_path, index=False)
print(f"  ✓ Saved: Table8_Model_Fit_Statistics.csv")

# ==============================================================================
# SECTION 8: COUNTRY RANKINGS WITH CI
# ==============================================================================

print("\n" + "=" * 80)
print("SECTION 8: COUNTRY RANKINGS WITH 95% CI")
print("=" * 80)

# Window: the latest year in the data and the four years before it.
recent_years = df_clean[df_clean['year'] >= df_clean['year'].max() - 4]

# Per-country mean, spread, number of observations and standard error of the
# UHC index within the window. dropna() removes countries for which any of
# these statistics is undefined (e.g. std/sem with a single observation).
country_stats = recent_years.groupby('country')['uhc_sci'].agg(['mean', 'std', 'count', 'sem']).dropna()

# Normal-approximation 95% interval around each country mean.
country_stats['ci_lower'] = country_stats['mean'] - 1.96 * country_stats['sem']
country_stats['ci_upper'] = country_stats['mean'] + 1.96 * country_stats['sem']
country_stats = country_stats.sort_values('mean', ascending=False)
country_stats['Rank'] = range(1, len(country_stats) + 1)

country_ranking_table = country_stats[['Rank', 'mean', 'ci_lower', 'ci_upper', 'std', 'count']].copy()
country_ranking_table.columns = ['Rank', 'Mean UHC', 'CI Lower', 'CI Upper', 'Std Dev', 'N Years']
for rounded_col in ('Mean UHC', 'CI Lower', 'CI Upper'):
    country_ranking_table[rounded_col] = country_ranking_table[rounded_col].round(1)

# The country names live only in the index, and both outputs below are
# written with index=False. Expose the names as the first column so that the
# saved table and the printed top 10 actually identify each country.
country_ranking_table.insert(0, 'Country', country_ranking_table.index.to_numpy())

country_rankings_path = os.path.join(output_dir, 'Table9_Country_Rankings_with_CI.csv')
country_ranking_table.to_csv(country_rankings_path, index=False)
print(f"  ✓ Saved: Table9_Country_Rankings_with_CI.csv")

print("\n📊 Top 10 Countries with 95% Confidence Intervals:")
print(country_ranking_table.head(10).to_string(index=False))

# ==============================================================================
# SECTION 9: HYPOTHESIS TESTING
# ==============================================================================

print("\n" + "=" * 80)
print("SECTION 9: HYPOTHESIS TESTING")
print("=" * 80)

# Expected direction of each effect, in the same order as predictor_names
# and beta_results.
expected_signs = ['+', '+', '-', '+']

# A hypothesis counts as supported only when the estimated effect points in
# the expected direction AND its two-sided Bayesian p-value is below 0.05.
# Checking the p-value alone would report a significant effect in the wrong
# direction (e.g. a positive conflict effect under H3) as support.
hypothesis_supported = []
for result, expected_sign in zip(beta_results, expected_signs):
    if expected_sign == '+':
        sign_matches = result['Effect_pp'] > 0
    else:
        sign_matches = result['Effect_pp'] < 0
    is_supported = sign_matches and result['Bayesian_p_value'] < 0.05
    hypothesis_supported.append('Yes' if is_supported else 'No')

# Table 10: one row per hypothesis, combining the expectation with the
# estimates from the Bayesian results.
hypothesis_testing = pd.DataFrame({
    'Hypothesis': ['H1: Domestic financing matters', 'H2: Economic development enables UHC',
                   'H3: Conflict reduces UHC', 'H4: Governance conditions outcomes'],
    'Predictor': predictor_names,
    'Expected Sign': expected_signs,
    'Estimated Effect (pp)': [result['Effect_pp'] for result in beta_results],
    '95% CI Lower': [result['CI_95_Lower_pp'] for result in beta_results],
    '95% CI Upper': [result['CI_95_Upper_pp'] for result in beta_results],
    'Bayesian p-value': [result['Bayesian_p_value'] for result in beta_results],
    'Significance': [result['Significance'] for result in beta_results],
    'Posterior Prob >0': [result['Posterior_Prob_Positive'] for result in beta_results],
    'Supported?': hypothesis_supported
})
hypothesis_path = os.path.join(output_dir, 'Table10_Hypothesis_Testing.csv')
hypothesis_testing.to_csv(hypothesis_path, index=False)
print(f"  ✓ Saved: Table10_Hypothesis_Testing.csv")

# ==============================================================================
# SECTION 10: FIGURE GENERATION
# ==============================================================================

print("\n" + "=" * 80)
print("SECTION 10: GENERATING FIGURES")
print("=" * 80)

# FIGURE 1: Time Series Diagnostics
# 2x2 grid: (A) ADF p-values, (B) cointegration p-values,
# (C) averaged panel p-values, (D) share of countries flagged stationary.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Panel A: ADF test p-values
# Left empty when no country produced stationarity results.
ax = axes[0, 0]
if len(stationarity_df) > 0:
    adf_pvals = stationarity_df['ADF_UHC_pval'].dropna()
    ax.hist(adf_pvals, bins=15, color=COLORS['primary'], edgecolor='black', alpha=0.7)
    ax.axvline(x=0.05, color=COLORS['accent'], linestyle='--', linewidth=2, label='α = 0.05')
    ax.set_xlabel('ADF Test p-value', fontweight='bold')
    ax.set_ylabel('Frequency', fontweight='bold')
    ax.set_title('A: ADF Test p-values (UHC Series)', fontweight='bold')
    ax.legend()
    # Count box placed in axes coordinates (top-left corner of the panel).
    ax.text(0.05, 0.95, f'Stationary: {stationary_uhc}/{total_stationarity}',
            transform=ax.transAxes, fontsize=10, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

# Panel B: Cointegration p-values
# Left empty when no country produced cointegration results.
ax = axes[0, 1]
if len(cointegration_df) > 0:
    coint_pvals = cointegration_df['Cointegration_UHC_GGHE_pval'].dropna()
    ax.hist(coint_pvals, bins=15, color=COLORS['secondary'], edgecolor='black', alpha=0.7)
    ax.axvline(x=0.05, color=COLORS['accent'], linestyle='--', linewidth=2, label='α = 0.05')
    ax.set_xlabel('Cointegration p-value', fontweight='bold')
    ax.set_ylabel('Frequency', fontweight='bold')
    ax.set_title('B: Cointegration p-values (UHC-GGHE)', fontweight='bold')
    ax.legend()
    ax.text(0.05, 0.95, f'Cointegrated: {coint_gghe_yes}/{total_cointegration}',
            transform=ax.transAxes, fontsize=10, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

# Panel C: IPS Panel Unit Root Tests
# Missing results are drawn as p = 1; bars below 0.05 use the primary
# colour, all others the accent colour.
ax = axes[1, 0]
ips_vars = ['UHC', 'GGHE', 'GDP']
ips_pvals = [ips_uhc[1] if ips_uhc[1] is not None else 1,
             ips_gghe[1] if ips_gghe[1] is not None else 1,
             ips_gdp[1] if ips_gdp[1] is not None else 1]
colors_ips = [COLORS['primary'] if p < 0.05 else COLORS['accent'] for p in ips_pvals]
bars = ax.bar(ips_vars, ips_pvals, color=colors_ips, edgecolor='black', alpha=0.7)
ax.axhline(y=0.05, color=COLORS['accent'], linestyle='--', linewidth=2, label='α = 0.05')
ax.set_ylabel('p-value', fontweight='bold')
ax.set_title('C: IPS Panel Unit Root Tests', fontweight='bold')
ax.legend()
# Value label just above each bar.
for bar, p in zip(bars, ips_pvals):
    ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.02,
            f'p = {p:.4f}', ha='center', va='bottom', fontsize=9)

# Panel D: Percentage Stationary
# Percentages fall back to 0 when no country was tested.
ax = axes[1, 1]
stationary_counts = [stationary_uhc, stationary_gghe, stationary_gdp]
stationary_pct = [s/total_stationarity*100 if total_stationarity > 0 else 0 for s in stationary_counts]
bars = ax.bar(ips_vars, stationary_pct, color=COLORS['primary'], edgecolor='black', alpha=0.7)
ax.set_ylabel('Percentage Stationary (%)', fontweight='bold')
ax.set_title('D: Percentage of Countries Stationary (ADF, p < 0.05)', fontweight='bold')
for bar, pct in zip(bars, stationary_pct):
    ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 1,
            f'{pct:.0f}%', ha='center', va='bottom', fontweight='bold')

# Save at 300 dpi and close the figure so it does not stay open in memory.
plt.suptitle('Figure 1: Time Series Diagnostics - Stationarity and Cointegration Tests',
             fontweight='bold', fontsize=14, y=1.02)
fig1_path = os.path.join(output_dir, 'Figure1_Time_Series_Diagnostics.png')
plt.savefig(fig1_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"  ✓ Saved: Figure1_Time_Series_Diagnostics.png")

# FIGURE 2: Correlation Heatmap
fig, ax = plt.subplots(figsize=(10, 8))

# Cell annotations: correlation to two decimals plus a significance marker
# off the diagonal, and a plain "1.00" on the diagonal.
annot_matrix = corr_matrix.copy().round(2).astype(str)
for i in range(len(corr_vars)):
    for j in range(len(corr_vars)):
        if i != j:
            p_val = corr_pvalues.iloc[i, j]
            star = significance_star(p_val)
            annot_matrix.iloc[i, j] = f"{corr_matrix.iloc[i, j]:.2f}{star}"
        else:
            annot_matrix.iloc[i, j] = "1.00"

# Colours come from the numeric matrix; fmt='' is needed because the
# annotations are already formatted strings.
sns.heatmap(corr_matrix, annot=annot_matrix, fmt='', cmap='RdBu_r', center=0,
            square=True, linewidths=0.5, cbar_kws={"shrink": 0.8, "label": "Correlation"},
            annot_kws={'size': 9}, ax=ax)

# Replace the raw column names on both axes with the display labels.
ax.set_xticklabels(var_labels, rotation=45, ha='right', fontsize=10)
ax.set_yticklabels(var_labels, rotation=0, fontsize=10)
ax.set_title('Figure 2: Correlation Matrix with Significance Levels\n*** p<0.001, ** p<0.01, * p<0.05, † p<0.10',
             fontweight='bold', fontsize=13)

fig2_path = os.path.join(output_dir, 'Figure2_Correlation_Heatmap.png')
plt.savefig(fig2_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"  ✓ Saved: Figure2_Correlation_Heatmap.png")

# FIGURE 3: Posterior Distributions
# One panel per predictor, filled in the order of predictor_names.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

for idx, (name, ax) in enumerate(zip(predictor_names, axes.flatten())):
    # Rescale the standardised coefficient draws by the outcome's standard
    # deviation so the x-axis is in original outcome units.
    samples_pp = beta_samples[:, idx] * scaler_y.scale_[0]
    az.plot_posterior(samples_pp, ax=ax, color=COLORS['primary'], hdi_prob=0.95,
                      point_estimate='mean', ref_val=0)

    # Summary values computed earlier for the same predictor.
    p_val = beta_results[idx]['Bayesian_p_value']
    star = significance_star(p_val)
    prob_pos = beta_results[idx]['Posterior_Prob_Positive']

    # Dashed zero line plus a text box with the p-value and P(β>0).
    ax.axvline(x=0, color='gray', linestyle='--', linewidth=1.5, alpha=0.7)
    ax.text(0.05, 0.95, f'p = {p_val:.4f} {star}\nP(β>0) = {prob_pos:.3f}',
            transform=ax.transAxes, fontsize=10, verticalalignment='top', fontweight='bold',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    ax.set_title(name, fontweight='bold')
    ax.set_xlabel('Effect on UHC (percentage points)', fontweight='bold')

plt.suptitle('Figure 3: Posterior Distributions with Significance Levels',
             fontweight='bold', fontsize=14, y=1.02)
fig3_path = os.path.join(output_dir, 'Figure3_Posterior_Distributions.png')
plt.savefig(fig3_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"  ✓ Saved: Figure3_Posterior_Distributions.png")

# FIGURE 4: Model Comparison
fig, ax = plt.subplots(figsize=(10, 6))

# The WAIC objects only exist when the model-comparison section completed;
# otherwise the axes stay empty and an empty figure is saved.
if 'waic_full' in dir() and 'waic_null' in dir():
    models = ['Null Model', 'Full Model']
    waic_values = [waic_null.waic, waic_full.waic]
    waic_se = [waic_null.waic_se, waic_full.waic_se]
    # Computed here so the label does not depend on a separately defined
    # variable from the comparison section.
    delta_waic = waic_null.waic - waic_full.waic

    # COLORS has no 'muted_blue' entry, so looking it up raised KeyError;
    # the existing lighter blue ('conf_int') is used for the null model.
    bars = ax.bar(models, waic_values, yerr=waic_se, color=[COLORS['conf_int'], COLORS['primary']],
                  edgecolor='black', linewidth=1.5, capsize=5, error_kw={'linewidth': 2})
    ax.set_ylabel('WAIC (lower is better)', fontweight='bold')
    ax.set_title('Figure 4: Model Comparison - WAIC', fontweight='bold')

    # Value label above each bar.
    for bar, val in zip(bars, waic_values):
        ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 10, f'{val:.0f}',
                ha='center', va='bottom', fontweight='bold')

    # Horizontal reference line drawn at the ΔWAIC value, with the same
    # value shown in the legend.
    ax.axhline(y=delta_waic, color=COLORS['accent'], linestyle='--',
               linewidth=1.5, alpha=0.7, label=f'ΔWAIC = {delta_waic:.1f}')
    ax.legend()

fig4_path = os.path.join(output_dir, 'Figure4_Model_Comparison.png')
plt.savefig(fig4_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"  ✓ Saved: Figure4_Model_Comparison.png")

# FIGURE 5: Country Rankings
# Horizontal bar chart of the first 25 rows of country_stats with 1.96*SEM
# error bars, a "mean [ci_lower, ci_upper]" label per bar, and a dashed
# vertical reference line for the overall average.
fig, ax = plt.subplots(figsize=(14, 12))

# Assumes country_stats is already ordered best-first so head(25) yields the
# top 25 countries — TODO confirm against where country_stats is built.
# Re-sorting ascending puts the highest mean at the top of the barh plot.
top_25 = country_stats.head(25).sort_values('mean', ascending=True)
y_pos = np.arange(len(top_25))
# Alternate two palette colours so adjacent bars are easy to tell apart.
colors_bar = [COLORS['primary'] if i % 2 == 0 else COLORS['secondary'] for i in range(len(top_25))]

ax.barh(y_pos, top_25['mean'].values, color=colors_bar, edgecolor='black', height=0.7)
ax.errorbar(top_25['mean'].values, y_pos, xerr=1.96 * top_25['sem'].values,
            fmt='none', color='black', capsize=4, capthick=2, elinewidth=2, alpha=0.8)

# Numeric label to the right of each bar end.
for i, (idx, row) in enumerate(top_25.iterrows()):
    ax.text(row['mean'] + 0.5, i, f"{row['mean']:.0f} [{row['ci_lower']:.0f}, {row['ci_upper']:.0f}]",
            va='center', fontsize=8)

ax.set_yticks(y_pos)
ax.set_yticklabels(top_25.index, fontsize=10)
ax.set_xlabel('UHC Service Coverage Index (0-100)', fontweight='bold')
ax.set_title('Figure 5: Country Ranking with 95% Confidence Intervals\n5-Year Average (2020-2024)',
             fontweight='bold', fontsize=14)

# The reference line is labelled as the regional average, so it is computed
# over every country in country_stats rather than only the 25 plotted ones
# (averaging just the displayed subset would overstate it).
ssa_avg = country_stats['mean'].mean()
ax.axvline(x=ssa_avg, color=COLORS['accent'], linestyle='--', linewidth=2.5, alpha=0.8,
           label=f'SSA Average: {ssa_avg:.1f}')
ax.legend(loc='lower right')

fig5_path = os.path.join(output_dir, 'Figure5_Country_Ranking_CI.png')
plt.savefig(fig5_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"  ✓ Saved: Figure5_Country_Ranking_CI.png")

# FIGURE 6: Model Predictions
# Two-panel fit diagnostic: (A) predicted vs. observed values with the
# identity line, (B) residuals plotted against the predictions.
fig, axes = plt.subplots(1, 2, figsize=(15, 6.5))

# Panel A: predicted vs. observed.
ax = axes[0]
ax.scatter(y_true_original, y_pred_original, alpha=0.5, color=COLORS['primary'], s=30)
# The identity line spans the observed range; compute its endpoints once.
obs_min, obs_max = np.min(y_true_original), np.max(y_true_original)
ax.plot([obs_min, obs_max], [obs_min, obs_max],
        color=COLORS['accent'], linestyle='--', linewidth=2, label='Perfect Prediction')
ax.set_xlabel('Observed UHC', fontweight='bold')
ax.set_ylabel('Predicted UHC', fontweight='bold')
ax.set_title('A: Predicted vs. Observed UHC', fontweight='bold')
# Pin the legend away from the upper-left corner, where the R² box is placed.
ax.legend(loc='lower right')
# Anchor the box by its top edge, matching panel B and the Figure 3 boxes.
ax.text(0.05, 0.95, f'R² = {r2:.3f}', transform=ax.transAxes, fontsize=11, fontweight='bold',
        verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

# Panel B: residuals vs. predictions with a zero reference line.
ax = axes[1]
ax.scatter(y_pred_original, residuals, alpha=0.5, color=COLORS['secondary'], s=30)
ax.axhline(y=0, color=COLORS['accent'], linestyle='-', linewidth=2)
ax.set_xlabel('Predicted UHC', fontweight='bold')
ax.set_ylabel('Residuals', fontweight='bold')
ax.set_title('B: Residual Analysis', fontweight='bold')
ax.text(0.05, 0.95, f'RMSE = {rmse:.2f}\nMAE = {mae:.2f}', transform=ax.transAxes, fontsize=11,
        verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

plt.suptitle('Figure 6: Model Fit Diagnostics', fontweight='bold', fontsize=14, y=1.02)
fig6_path = os.path.join(output_dir, 'Figure6_Model_Predictions.png')
plt.savefig(fig6_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"  ✓ Saved: Figure6_Model_Predictions.png")

# ==============================================================================
# FINAL SUMMARY
# ==============================================================================

# Closing console report: completion banner, output location, fit metrics,
# time-series diagnostics, per-coefficient findings, and the file tally.
print("\n" + "=" * 80, "✅ ANALYSIS COMPLETE - Q1 PUBLICATION READY", "=" * 80, sep="\n")

print(f"\n📁 All outputs saved in: {output_dir}/")

print(
    "\n📊 Complete Statistical Summary:",
    f"   • R-squared: {r2:.4f}",
    f"   • RMSE: {rmse:.2f}",
    f"   • MAE: {mae:.2f}",
    sep="\n",
)
# WAIC is reported only when the full-model result exists at module level.
if 'waic_full' in dir():
    print(f"   • WAIC (Full model): {waic_full.waic:.1f}")

print(
    "\n📈 Time Series Diagnostics:",
    f"   • Stationary (ADF): {stationary_uhc}/{total_stationarity} countries",
    sep="\n",
)
# The cointegration line is omitted when no country was counted for that test.
if total_cointegration > 0:
    print(f"   • Cointegrated: {coint_gghe_yes}/{total_cointegration} countries")

print("\n📈 Key Findings with Significance:")
for r in beta_results:
    # One bullet per coefficient: effect size, significance marker, p-value.
    effect_text = f"{r['Effect_pp']:.2f} pp {r['Significance']}"
    p_text = f"(p = {r['Bayesian_p_value']:.4f})"
    print(f"   • {r['Variable']}: {effect_text} {p_text}")

print(
    "\n📁 Output Files Generated:",
    "   📊 10 Tables (CSV)",
    "   📊 6 Figures (PNG, 300 DPI)",
    sep="\n",
)