{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Supplementary Script S3: Random Forest Driver Analysis\n",
    "\n",
    "**Study:** Assessing Drivers of Forest Loss as Indicators of Habitat Degradation in the Greater Kafue Ecosystem of Zambia Using a Random Forest Approach  \n",
    "**Authors:** Gift Mulenga, Darius Phiri, Matamyo Simwanda, Vincent R. Nyirenda \n",
    "\n",
    "---\n",
    "\n",
    "## Purpose\n",
    "\n",
    "This script trains the Random Forest (RF) model, evaluates its performance, quantifies variable importance, identifies critical thresholds using Youden's J statistic, generates partial dependence plots, produces a representative decision tree, and saves the trained model for use in susceptibility mapping.\n",
    "\n",
    "## Inputs Required\n",
    "\n",
    "| Input | Source |\n",
    "|-------|--------|\n",
    "| `train_data.csv` | Output of Script S2 (`samples/` subdirectory) |\n",
    "| `test_data.csv` | Output of Script S2 |\n",
    "\n",
    "## Outputs Produced\n",
    "\n",
    "**Figures (manuscript):**\n",
    "\n",
    "| File | Manuscript Figure |\n",
    "|------|-----------------|\n",
    "| `Figure_ROC.png` | Figure 3: ROC curve (AUC = 0.828) |\n",
    "| `Figure_CV_Stability.png` | Figure 4: Cross-validation stability |\n",
    "| `Figure_Importance.png` / `Figure_Category_Importance.png` | Figure 5: Variable importance (MDI + permutation) |\n",
    "| `Figure_Univariate_ROC.png` | Figure 6: Univariate ROC curves with Youden's J thresholds |\n",
    "| `Figure_PDP.png` | Figure 7: Partial dependence plots for top 6 drivers |\n",
    "| `Figure_Decision_Tree.png` | Figure 8: Representative decision tree (depth limited to 4 levels) |\n",
    "\n",
    "**Tables:**\n",
    "\n",
    "| File | Content |\n",
    "|------|---------|\n",
    "| `Table_5_46_Confusion_Matrix.csv` | Table 3: Confusion matrix (manuscript) |\n",
    "| `Table_5_51_Thresholds.csv` | Table 4: Critical thresholds (manuscript) |\n",
    "| `Table_5_53_Validation_Summary.csv` | Supplementary Table S1: Model performance and validation summary |\n",
    "| `Table_5_48_Variable_Importance.csv` | Supplementary Table S2: Variable importance rankings |\n",
    "\n",
    "**Model:**\n",
    "\n",
    "| File | Content |\n",
    "|------|---------|\n",
    "| `rf_model.joblib` | Trained RF model (input to Script S4) |\n",
    "| `model_features.txt` | Ordered list of features used |\n",
    "| `model_metadata.json` | Full configuration record for reproducibility |\n",
    "\n",
    "## How to Use\n",
    "\n",
    "> **⚠ USER ACTION REQUIRED:** Update `Config.BASE_DIR` before running.\n",
    "\n",
    "1. Ensure Script S2 has been completed and `train_data.csv` / `test_data.csv` exist\n",
    "2. Update `Config.BASE_DIR` to match your output directory from Script S2\n",
    "3. Run all cells in order\n",
    "4. The trained model (`rf_model.joblib`) is automatically saved for use by Script S4\n",
    "\n",
    "## Model Configuration (Methods Alignment)\n",
    "\n",
    "All parameters match Section 2.5 (Random Forest Model Configuration) of the manuscript:\n",
    "\n",
    "| Parameter | Value | Rationale |\n",
    "|-----------|-------|-----------|\n",
    "| `n_estimators` | 500 | Sufficient for stable importance estimates |\n",
    "| `max_features` | `'sqrt'` | Standard RF classification practice |\n",
    "| `class_weight` | `'balanced'` | Compensates for class imbalance |\n",
    "| `min_samples_split` | 10 | Prevents overfitting on small nodes |\n",
    "| `min_samples_leaf` | 5 | Prevents overfitting on small leaf nodes |\n",
    "| `oob_score` | True | Internal generalisation estimate |\n",
    "\n",
    "## Scientific Notes\n",
    "\n",
    "- **Permutation importance** is calculated on the *independent test set* (not the training set) to avoid overestimation due to memorisation effects.\n",
    "- **Spatial cross-validation** uses `GroupKFold` with spatial block IDs from Script S2, ensuring that geographically nearby samples are never split across train/validation folds within CV.\n",
    "- **Youden's J statistic** (Equation 4): J = TPR − FPR = Sensitivity + Specificity − 1. For each predictor, the optimal threshold is the value maximising J on the training set, computed from a univariate ROC curve. This is a ROC-based method, not a simple mean-splitting approach.\n",
    "- **Variable importance ranking** uses `pandas.rank(ascending=False)`, so rank 1 = highest importance. This corrects an earlier implementation that used array indices rather than importance ranks.\n",
    "- **RFE (n=10)** was run as an exploratory diagnostic; the full 14-predictor set was retained for the main analysis to preserve all theoretically motivated variables. RFE results are saved to Supplementary Table S2.\n",
    "\n",
    "## Dependencies\n",
    "\n",
    "```\n",
    "scikit-learn >= 1.1.0\n",
    "numpy >= 1.23.0\n",
    "pandas >= 1.5.0\n",
    "matplotlib >= 3.6.0\n",
    "seaborn >= 0.12.0\n",
    "joblib >= 1.2.0\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76b1a25b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.model_selection import cross_validate, StratifiedKFold, GroupKFold\n",
    "from sklearn.feature_selection import RFE\n",
    "from sklearn.inspection import permutation_importance, PartialDependenceDisplay\n",
    "from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,\n",
    "                             roc_auc_score, roc_curve, auc, confusion_matrix, cohen_kappa_score)\n",
    "from sklearn.tree import export_text, plot_tree\n",
    "import joblib\n",
    "import json\n",
    "from datetime import datetime\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Set publication-quality plot defaults\n",
    "plt.rcParams['figure.dpi'] = 300\n",
    "plt.rcParams['savefig.dpi'] = 300\n",
    "plt.rcParams['font.family'] = 'sans-serif'\n",
    "plt.rcParams['font.size'] = 10\n",
    "plt.rcParams['axes.titlesize'] = 12\n",
    "plt.rcParams['axes.labelsize'] = 11\n",
    "plt.rcParams['xtick.labelsize'] = 9\n",
    "plt.rcParams['ytick.labelsize'] = 9\n",
    "plt.rcParams['legend.fontsize'] = 9\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e15b55e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 1. CONFIGURATION\n",
    "# =============================================================================\n",
    "\n",
    "class Config:\n",
    "    # UPDATE THESE PATHS TO MATCH YOUR SYSTEM\n",
    "    BASE_DIR = r\"D:/Publication/Habitat Loss in the Greater Kafue Ecosystem/Thesis_anlysis/Final_results for defence/outputs/objective_b\"\n",
    "    INPUT_DIR = os.path.join(BASE_DIR, \"samples\")\n",
    "    OUTPUT_DIR = os.path.join(BASE_DIR, \"results\")\n",
    "    TABLES_DIR = os.path.join(BASE_DIR, \"tables\")\n",
    "    FIGURES_DIR = os.path.join(BASE_DIR, \"figures\")\n",
    "    \n",
    "    # Model parameters\n",
    "    N_ESTIMATORS = 500\n",
    "    MAX_DEPTH = None\n",
    "    MIN_SAMPLES_SPLIT = 10\n",
    "    MIN_SAMPLES_LEAF = 5\n",
    "    MAX_FEATURES = 'sqrt'\n",
    "    CLASS_WEIGHT = 'balanced'\n",
    "    RANDOM_STATE = 42\n",
    "    N_JOBS = -1\n",
    "    \n",
    "    # Cross-validation\n",
    "    CV_FOLDS = 10\n",
    "    USE_SPATIAL_CV = True\n",
    "    \n",
    "    # Feature selection\n",
    "    USE_RFE = True\n",
    "    RFE_N_FEATURES = 10\n",
    "    \n",
    "    # Variables to exclude (from VIF analysis)\n",
    "    VARS_TO_EXCLUDE = ['years_protected']\n",
    "    \n",
    "    # Variable categories\n",
    "    VARIABLE_CATEGORIES = {\n",
    "        'dist_roads': 'Proximity', 'dist_settlements': 'Proximity',\n",
    "        'dist_rivers': 'Proximity', 'dist_knp': 'Proximity',\n",
    "        'pop_density': 'Socio-economic', 'pop_change': 'Socio-economic',\n",
    "        'pct_cultivated': 'Socio-economic', 'protection_status': 'Conservation',\n",
    "        'elevation': 'Topographic', 'slope': 'Topographic',\n",
    "        'aspect': 'Topographic', 'twi': 'Topographic',\n",
    "        'mean_rainfall': 'Climatic', 'mean_temp': 'Climatic',\n",
    "    }\n",
    "    \n",
    "    # Display names for variables\n",
    "    VARIABLE_DISPLAY_NAMES = {\n",
    "        'dist_roads': 'Distance to Roads',\n",
    "        'dist_settlements': 'Distance to Settlements',\n",
    "        'dist_rivers': 'Distance to Rivers',\n",
    "        'dist_knp': 'Distance to KNP',\n",
    "        'pop_density': 'Population Density',\n",
    "        'pop_change': 'Population Change',\n",
    "        'pct_cultivated': 'Percent Cultivated',\n",
    "        'protection_status': 'Protection Status',\n",
    "        'elevation': 'Elevation',\n",
    "        'slope': 'Slope',\n",
    "        'aspect': 'Aspect',\n",
    "        'twi': 'Topographic Wetness Index',\n",
    "        'mean_rainfall': 'Mean Rainfall',\n",
    "        'mean_temp': 'Mean Temperature',\n",
    "    }\n",
    "    \n",
    "    # Category colors\n",
    "    CATEGORY_COLORS = {\n",
    "        'Socio-economic': '#E74C3C',\n",
    "        'Climatic': '#3498DB',\n",
    "        'Proximity': '#2ECC71',\n",
    "        'Topographic': '#9B59B6',\n",
    "        'Conservation': '#F39C12',\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "716a5947",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 2. DATA LOADING\n",
    "# =============================================================================\n",
    "\n",
    "def create_output_directories():\n",
    "    for d in [Config.OUTPUT_DIR, Config.TABLES_DIR, Config.FIGURES_DIR]:\n",
    "        os.makedirs(d, exist_ok=True)\n",
    "    print(\"✓ Output directories ready\")\n",
    "\n",
    "\n",
    "def load_data():\n",
    "    print(\"\\n--- Loading Data ---\")\n",
    "    train_df = pd.read_csv(os.path.join(Config.INPUT_DIR, 'train_data.csv'))\n",
    "    test_df = pd.read_csv(os.path.join(Config.INPUT_DIR, 'test_data.csv'))\n",
    "    \n",
    "    print(f\"  Train samples: {len(train_df)}\")\n",
    "    print(f\"  Test samples: {len(test_df)}\")\n",
    "    \n",
    "    # Identify features\n",
    "    exclude_cols = ['x', 'y', 'habitat_loss', 'row', 'col', 'spatial_block'] + Config.VARS_TO_EXCLUDE\n",
    "    features = [c for c in train_df.columns if c not in exclude_cols]\n",
    "    \n",
    "    print(f\"  Features: {len(features)}\")\n",
    "    print(f\"  Excluded: {Config.VARS_TO_EXCLUDE}\")\n",
    "    \n",
    "    return train_df, test_df, features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "788ab8c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 3. FEATURE SELECTION WITH RFE\n",
    "# =============================================================================\n",
    "\n",
    "def perform_rfe(X_train, y_train, features, n_features_to_select=10):\n",
    "    print(\"\\n--- Recursive Feature Elimination ---\")\n",
    "    \n",
    "    rf_rfe = RandomForestClassifier(\n",
    "        n_estimators=100, max_depth=10, random_state=Config.RANDOM_STATE, n_jobs=Config.N_JOBS\n",
    "    )\n",
    "    \n",
    "    rfe = RFE(estimator=rf_rfe, n_features_to_select=n_features_to_select, step=1)\n",
    "    rfe.fit(X_train, y_train)\n",
    "    \n",
    "    selected_features = [f for f, s in zip(features, rfe.support_) if s]\n",
    "    \n",
    "    print(f\"  Selected features ({len(selected_features)}):\")\n",
    "    for f in selected_features:\n",
    "        print(f\"    - {f}\")\n",
    "    \n",
    "    rfe_ranking = pd.DataFrame({\n",
    "        'Feature': features,\n",
    "        'Rank': rfe.ranking_,\n",
    "        'Selected': rfe.support_\n",
    "    }).sort_values('Rank')\n",
    "    \n",
    "    return selected_features, rfe_ranking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ed1a3fa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 4. MODEL TRAINING\n",
    "# =============================================================================\n",
    "\n",
    "def train_model(X_train, y_train):\n",
    "    print(\"\\n\" + \"=\"*60)\n",
    "    print(\"TRAINING RANDOM FOREST MODEL\")\n",
    "    print(\"=\"*60)\n",
    "    \n",
    "    model = RandomForestClassifier(\n",
    "        n_estimators=Config.N_ESTIMATORS,\n",
    "        max_depth=Config.MAX_DEPTH,\n",
    "        min_samples_split=Config.MIN_SAMPLES_SPLIT,\n",
    "        min_samples_leaf=Config.MIN_SAMPLES_LEAF,\n",
    "        max_features=Config.MAX_FEATURES,\n",
    "        class_weight=Config.CLASS_WEIGHT,\n",
    "        oob_score=True,\n",
    "        random_state=Config.RANDOM_STATE,\n",
    "        n_jobs=Config.N_JOBS\n",
    "    )\n",
    "    \n",
    "    print(f\"\\n  Training with {len(y_train)} samples...\")\n",
    "    model.fit(X_train, y_train)\n",
    "    \n",
    "    print(f\"\\n  ✓ Training complete\")\n",
    "    print(f\"    Training Accuracy: {model.score(X_train, y_train):.4f}\")\n",
    "    print(f\"    OOB Score: {model.oob_score_:.4f}\")\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0cafe8fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 5. SPATIAL CROSS-VALIDATION\n",
    "# =============================================================================\n",
    "\n",
    "def spatial_cross_validation(model, X, y, groups, n_splits=10):\n",
    "    print(\"\\n--- Spatial Cross-Validation ---\")\n",
    "    \n",
    "    n_unique_groups = len(np.unique(groups))\n",
    "    actual_splits = min(n_splits, n_unique_groups)\n",
    "    group_kfold = GroupKFold(n_splits=actual_splits)\n",
    "    \n",
    "    scores = {'accuracy': [], 'precision': [], 'recall': [], 'f1': [], 'roc_auc': []}\n",
    "    \n",
    "    for fold, (train_idx, val_idx) in enumerate(group_kfold.split(X, y, groups)):\n",
    "        X_train_cv, X_val_cv = X[train_idx], X[val_idx]\n",
    "        y_train_cv, y_val_cv = y[train_idx], y[val_idx]\n",
    "        \n",
    "        fold_model = RandomForestClassifier(\n",
    "            n_estimators=Config.N_ESTIMATORS,\n",
    "            max_depth=Config.MAX_DEPTH,\n",
    "            min_samples_split=Config.MIN_SAMPLES_SPLIT,\n",
    "            min_samples_leaf=Config.MIN_SAMPLES_LEAF,\n",
    "            max_features=Config.MAX_FEATURES,\n",
    "            class_weight=Config.CLASS_WEIGHT,\n",
    "            random_state=Config.RANDOM_STATE,\n",
    "            n_jobs=Config.N_JOBS\n",
    "        )\n",
    "        fold_model.fit(X_train_cv, y_train_cv)\n",
    "        \n",
    "        y_pred = fold_model.predict(X_val_cv)\n",
    "        y_prob = fold_model.predict_proba(X_val_cv)[:, 1]\n",
    "        \n",
    "        scores['accuracy'].append(accuracy_score(y_val_cv, y_pred))\n",
    "        scores['precision'].append(precision_score(y_val_cv, y_pred, zero_division=0))\n",
    "        scores['recall'].append(recall_score(y_val_cv, y_pred, zero_division=0))\n",
    "        scores['f1'].append(f1_score(y_val_cv, y_pred, zero_division=0))\n",
    "        scores['roc_auc'].append(roc_auc_score(y_val_cv, y_prob))\n",
    "    \n",
    "    print(f\"  Spatial CV Results ({n_unique_groups} blocks, {actual_splits} folds):\")\n",
    "    for metric, vals in scores.items():\n",
    "        print(f\"    {metric}: {np.mean(vals):.4f} ± {np.std(vals):.4f}\")\n",
    "    \n",
    "    return scores\n",
    "\n",
    "\n",
    "def standard_cross_validation(model, X, y, n_splits=10):\n",
    "    print(\"\\n--- Standard Cross-Validation ---\")\n",
    "    \n",
    "    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=Config.RANDOM_STATE)\n",
    "    scoring = ['accuracy', 'precision', 'recall', 'f1', 'roc_auc']\n",
    "    \n",
    "    cv_results = cross_validate(model, X, y, cv=cv, scoring=scoring, return_train_score=False)\n",
    "    \n",
    "    scores = {metric: cv_results[f'test_{metric}'] for metric in scoring}\n",
    "    \n",
    "    print(f\"  CV Results ({n_splits}-fold):\")\n",
    "    for metric, vals in scores.items():\n",
    "        print(f\"    {metric}: {np.mean(vals):.4f} ± {np.std(vals):.4f}\")\n",
    "    \n",
    "    return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e4be717d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 6. CV STABILITY PLOT (manuscript Figure 4)\n",
    "# =============================================================================\n",
    "\n",
    "def create_cv_stability_plot(cv_scores):\n",
    "    print(\"\\n--- Creating CV Stability Plot (Figure 3) ---\")\n",
    "    \n",
    "    metric_order = ['accuracy', 'precision', 'recall', 'f1', 'roc_auc']\n",
    "    metric_labels = ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'AUC-ROC']\n",
    "    \n",
    "    fig, ax = plt.subplots(figsize=(10, 6))\n",
    "    \n",
    "    colors = ['#3498DB', '#E74C3C', '#2ECC71', '#9B59B6', '#F39C12']\n",
    "    bp = ax.boxplot([cv_scores[m] for m in metric_order], \n",
    "                    labels=metric_labels,\n",
    "                    patch_artist=True,\n",
    "                    widths=0.6)\n",
    "    \n",
    "    for patch, color in zip(bp['boxes'], colors):\n",
    "        patch.set_facecolor(color)\n",
    "        patch.set_alpha(0.7)\n",
    "    \n",
    "    for i, metric in enumerate(metric_order):\n",
    "        x = np.random.normal(i + 1, 0.04, size=len(cv_scores[metric]))\n",
    "        ax.scatter(x, cv_scores[metric], alpha=0.6, color='black', s=20, zorder=3)\n",
    "    \n",
    "    for i, metric in enumerate(metric_order):\n",
    "        mean_val = np.mean(cv_scores[metric])\n",
    "        std_val = np.std(cv_scores[metric])\n",
    "        ax.text(i + 1, mean_val + 0.02, f'{mean_val:.3f}±{std_val:.3f}', \n",
    "                ha='center', va='bottom', fontsize=8, fontweight='bold')\n",
    "    \n",
    "    ax.set_ylabel('Score')\n",
    "    ax.set_title('Cross-Validation Performance Stability\\n(Spatial 10-Fold CV)')\n",
    "    ax.set_ylim(0, 1.1)\n",
    "    ax.axhline(y=0.7, color='gray', linestyle='--', alpha=0.5, label='Acceptable threshold')\n",
    "    ax.legend(loc='lower right')\n",
    "    ax.grid(axis='y', alpha=0.3)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig(os.path.join(Config.FIGURES_DIR, 'Figure_CV_Stability.png'), dpi=300, bbox_inches='tight')\n",
    "    plt.close()\n",
    "    \n",
    "    print(\"  ✓ Figure 3 saved\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "66e54362",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 7. MODEL EVALUATION\n",
    "# =============================================================================\n",
    "\n",
    "def evaluate_model(model, X_test, y_test, features):\n",
    "    print(\"\\n\" + \"=\"*60)\n",
    "    print(\"MODEL EVALUATION\")\n",
    "    print(\"=\"*60)\n",
    "    \n",
    "    y_pred = model.predict(X_test)\n",
    "    y_prob = model.predict_proba(X_test)[:, 1]\n",
    "    \n",
    "    metrics = {\n",
    "        'accuracy': accuracy_score(y_test, y_pred),\n",
    "        'precision': precision_score(y_test, y_pred),\n",
    "        'recall': recall_score(y_test, y_pred),\n",
    "        'f1': f1_score(y_test, y_pred),\n",
    "        'roc_auc': roc_auc_score(y_test, y_prob),\n",
    "        'kappa': cohen_kappa_score(y_test, y_pred)\n",
    "    }\n",
    "    \n",
    "    print(f\"\\n  Test Set Performance:\")\n",
    "    for m, v in metrics.items():\n",
    "        print(f\"    {m}: {v:.4f}\")\n",
    "    \n",
    "    cm = confusion_matrix(y_test, y_pred)\n",
    "    print(f\"\\n  Confusion Matrix:\")\n",
    "    print(f\"    TN={cm[0,0]}, FP={cm[0,1]}\")\n",
    "    print(f\"    FN={cm[1,0]}, TP={cm[1,1]}\")\n",
    "    \n",
    "    # Save tables\n",
    "    pd.DataFrame([metrics]).to_csv(os.path.join(Config.TABLES_DIR, 'Table_5_45_Performance.csv'), index=False)\n",
    "    pd.DataFrame(cm, columns=['Pred_NoLoss', 'Pred_Loss'], \n",
    "                 index=['Actual_NoLoss', 'Actual_Loss']).to_csv(\n",
    "        os.path.join(Config.TABLES_DIR, 'Table_5_46_Confusion_Matrix.csv'))\n",
    "    \n",
    "    # ROC curve\n",
    "    create_roc_curve(y_test, y_prob, metrics['roc_auc'])\n",
    "    \n",
    "    return metrics, cm\n",
    "\n",
    "\n",
    "def create_roc_curve(y_test, y_prob, auc_value):\n",
    "    fpr, tpr, thresholds = roc_curve(y_test, y_prob)\n",
    "    \n",
    "    fig, ax = plt.subplots(figsize=(8, 7))\n",
    "    \n",
    "    ax.plot(fpr, tpr, 'b-', lw=2.5, label=f'Random Forest (AUC = {auc_value:.3f})')\n",
    "    ax.plot([0, 1], [0, 1], 'k--', lw=1.5, label='Random Classifier (AUC = 0.500)')\n",
    "    ax.fill_between(fpr, tpr, alpha=0.2, color='blue')\n",
    "    \n",
    "    j_scores = tpr - fpr\n",
    "    optimal_idx = np.argmax(j_scores)\n",
    "    optimal_threshold = thresholds[optimal_idx]\n",
    "    ax.scatter(fpr[optimal_idx], tpr[optimal_idx], marker='o', s=150, \n",
    "               color='red', zorder=5, label=f'Optimal Threshold ({optimal_threshold:.2f})')\n",
    "    \n",
    "    ax.set_xlabel('False Positive Rate (1 - Specificity)')\n",
    "    ax.set_ylabel('True Positive Rate (Sensitivity)')\n",
    "    ax.set_title('Receiver Operating Characteristic (ROC) Curve')\n",
    "    ax.legend(loc='lower right')\n",
    "    ax.grid(alpha=0.3)\n",
    "    ax.set_xlim([-0.02, 1.02])\n",
    "    ax.set_ylim([-0.02, 1.02])\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig(os.path.join(Config.FIGURES_DIR, 'Figure_ROC.png'), dpi=300, bbox_inches='tight')\n",
    "    plt.close()\n",
    "    \n",
    "    print(\"  ✓ Figure 3 (ROC) saved\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "351e062f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 8. VARIABLE IMPORTANCE ANALYSIS (CORRECTED RANKING)\n",
    "# =============================================================================\n",
    "\n",
    "def analyze_importance(model, X_train, y_train, X_test, y_test, features):\n",
    "    \"\"\"\n",
    "    CORRECTED VERSION: Uses proper ranking with pandas rank()\n",
    "    \"\"\"\n",
    "    print(\"\\n\" + \"=\"*60)\n",
    "    print(\"VARIABLE IMPORTANCE ANALYSIS\")\n",
    "    print(\"=\"*60)\n",
    "    \n",
    "    # Gini importance\n",
    "    gini_importance = model.feature_importances_\n",
    "    \n",
    "    # Permutation importance\n",
    "    print(\"\\n  Calculating permutation importance...\")\n",
    "    perm_result = permutation_importance(model, X_test, y_test, n_repeats=30, \n",
    "                                          random_state=Config.RANDOM_STATE, n_jobs=-1)\n",
    "    perm_importance = perm_result.importances_mean\n",
    "    perm_std = perm_result.importances_std\n",
    "    \n",
    "    # Create importance dataframe\n",
    "    importance_df = pd.DataFrame({\n",
    "        'Variable': features,\n",
    "        'Gini_Importance': gini_importance,\n",
    "        'Gini_Pct': gini_importance * 100,\n",
    "        'Permutation_Importance': perm_importance,\n",
    "        'Permutation_Std': perm_std,\n",
    "        'Category': [Config.VARIABLE_CATEGORIES.get(f, 'Unknown') for f in features]\n",
    "    })\n",
    "    \n",
    "    # CORRECTED RANKING using pandas rank()\n",
    "    # Higher importance = lower rank number (rank 1 = most important)\n",
    "    importance_df['Gini_Rank'] = importance_df['Gini_Importance'].rank(ascending=False).astype(int)\n",
    "    importance_df['Permutation_Rank'] = importance_df['Permutation_Importance'].rank(ascending=False).astype(int)\n",
    "    \n",
    "    # Mean rank\n",
    "    importance_df['Mean_Rank'] = (importance_df['Gini_Rank'] + importance_df['Permutation_Rank']) / 2\n",
    "    importance_df = importance_df.sort_values('Gini_Importance', ascending=False)\n",
    "    \n",
    "    print(\"\\n  Top 10 Variables (by Gini importance):\")\n",
    "    for i, row in importance_df.head(10).iterrows():\n",
    "        print(f\"    Rank {int(row['Gini_Rank']):2d}. {row['Variable']:20s} \"\n",
    "              f\"Gini={row['Gini_Importance']*100:.1f}%, Perm={row['Permutation_Importance']:.4f}\")\n",
    "    \n",
    "    # Category summary\n",
    "    category_summary = importance_df.groupby('Category').agg({\n",
    "        'Gini_Importance': 'sum',\n",
    "        'Variable': 'count'\n",
    "    }).rename(columns={'Variable': 'N_Variables'})\n",
    "    category_summary['Pct_Importance'] = category_summary['Gini_Importance'] * 100\n",
    "    category_summary = category_summary.sort_values('Gini_Importance', ascending=False)\n",
    "    \n",
    "    print(\"\\n  Category Summary:\")\n",
    "    for cat, row in category_summary.iterrows():\n",
    "        print(f\"    {cat:15s}: {row['Pct_Importance']:.1f}% ({int(row['N_Variables'])} vars)\")\n",
    "    \n",
    "    # Save tables\n",
    "    importance_df.to_csv(os.path.join(Config.TABLES_DIR, 'Table_5_48_Variable_Importance.csv'), index=False)\n",
    "    category_summary.to_csv(os.path.join(Config.TABLES_DIR, 'Table_5_49_Category_Summary.csv'))\n",
    "    \n",
    "    # Create plots\n",
    "    create_importance_plot(importance_df, features, gini_importance, perm_importance)\n",
    "    create_category_pie_chart(category_summary)\n",
    "    \n",
    "    return importance_df, category_summary\n",
    "\n",
    "\n",
    "def create_importance_plot(importance_df, features, gini_importance, perm_importance):\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(14, 8))\n",
    "    \n",
    "    # Sort by Gini importance\n",
    "    sorted_idx_gini = np.argsort(gini_importance)\n",
    "    sorted_features_gini = [features[i] for i in sorted_idx_gini]\n",
    "    sorted_gini = gini_importance[sorted_idx_gini]\n",
    "    \n",
    "    colors_gini = [Config.CATEGORY_COLORS.get(Config.VARIABLE_CATEGORIES.get(f, 'Unknown'), 'gray') \n",
    "                   for f in sorted_features_gini]\n",
    "    \n",
    "    # Left plot: Gini importance\n",
    "    bars1 = axes[0].barh(range(len(features)), sorted_gini, color=colors_gini, edgecolor='black', linewidth=0.5)\n",
    "    axes[0].set_yticks(range(len(features)))\n",
    "    axes[0].set_yticklabels([Config.VARIABLE_DISPLAY_NAMES.get(f, f) for f in sorted_features_gini])\n",
    "    axes[0].set_xlabel('Gini Importance (MDI)')\n",
    "    axes[0].set_title('(a) Mean Decrease Impurity')\n",
    "    axes[0].grid(axis='x', alpha=0.3)\n",
    "    \n",
    "    for i, (bar, val) in enumerate(zip(bars1, sorted_gini)):\n",
    "        axes[0].text(val + 0.005, bar.get_y() + bar.get_height()/2, \n",
    "                     f'{val*100:.1f}%', va='center', fontsize=8)\n",
    "    \n",
    "    # Sort by Permutation importance\n",
    "    sorted_idx_perm = np.argsort(perm_importance)\n",
    "    sorted_features_perm = [features[i] for i in sorted_idx_perm]\n",
    "    sorted_perm = perm_importance[sorted_idx_perm]\n",
    "    \n",
    "    colors_perm = [Config.CATEGORY_COLORS.get(Config.VARIABLE_CATEGORIES.get(f, 'Unknown'), 'gray') \n",
    "                   for f in sorted_features_perm]\n",
    "    \n",
    "    # Right plot: Permutation importance\n",
    "    bars2 = axes[1].barh(range(len(features)), sorted_perm, color=colors_perm, edgecolor='black', linewidth=0.5)\n",
    "    axes[1].set_yticks(range(len(features)))\n",
    "    axes[1].set_yticklabels([Config.VARIABLE_DISPLAY_NAMES.get(f, f) for f in sorted_features_perm])\n",
    "    axes[1].set_xlabel('Permutation Importance')\n",
    "    axes[1].set_title('(b) Mean Decrease Accuracy')\n",
    "    axes[1].grid(axis='x', alpha=0.3)\n",
    "    \n",
    "    # Add legend\n",
    "    from matplotlib.patches import Patch\n",
    "    legend_elements = [Patch(facecolor=color, edgecolor='black', label=cat) \n",
    "                       for cat, color in Config.CATEGORY_COLORS.items()]\n",
    "    fig.legend(handles=legend_elements, loc='lower center', ncol=5, bbox_to_anchor=(0.5, -0.02))\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.subplots_adjust(bottom=0.1)\n",
    "    plt.savefig(os.path.join(Config.FIGURES_DIR, 'Figure_Importance.png'), dpi=300, bbox_inches='tight')\n",
    "    plt.close()\n",
    "    \n",
    "    print(\"  ✓ Figure 5a (importance bars) saved\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ef030e71",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# =============================================================================\n",
    "# 9. CATEGORY PIE CHART (manuscript Figure 5b)\n",
    "# =============================================================================\n",
    "\n",
    "def create_category_pie_chart(category_summary):\n",
    "    print(\"\\n--- Creating Category Pie Chart (Figure 4) ---\")\n",
    "    \n",
    "    fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
    "    \n",
    "    categories = category_summary.index.tolist()\n",
    "    values = category_summary['Pct_Importance'].values\n",
    "    colors = [Config.CATEGORY_COLORS.get(cat, 'gray') for cat in categories]\n",
    "    \n",
    "    explode = [0.05 if v == max(values) else 0 for v in values]\n",
    "    \n",
    "    wedges, texts, autotexts = axes[0].pie(\n",
    "        values, labels=categories, colors=colors,\n",
    "        autopct='%1.1f%%', explode=explode,\n",
    "        startangle=90, pctdistance=0.75, labeldistance=1.1\n",
    "    )\n",
    "    \n",
    "    for autotext in autotexts:\n",
    "        autotext.set_fontsize(10)\n",
    "        autotext.set_fontweight('bold')\n",
    "    \n",
    "    axes[0].set_title('(a) Relative Importance by Driver Category', fontsize=12, fontweight='bold')\n",
    "    \n",
    "    # Bar chart\n",
    "    y_pos = np.arange(len(categories))\n",
    "    sorted_idx = np.argsort(values)[::-1]\n",
    "    sorted_cats = [categories[i] for i in sorted_idx]\n",
    "    sorted_vals = [values[i] for i in sorted_idx]\n",
    "    sorted_colors = [colors[i] for i in sorted_idx]\n",
    "    n_vars = [category_summary.loc[cat, 'N_Variables'] for cat in sorted_cats]\n",
    "    \n",
    "    bars = axes[1].barh(y_pos, sorted_vals, color=sorted_colors, edgecolor='black', linewidth=0.5)\n",
    "    axes[1].set_yticks(y_pos)\n",
    "    axes[1].set_yticklabels([f\"{cat}\\n({int(n)} vars)\" for cat, n in zip(sorted_cats, n_vars)])\n",
    "    axes[1].set_xlabel('Cumulative Importance (%)')\n",
    "    axes[1].set_title('(b) Category Contribution Ranking', fontsize=12, fontweight='bold')\n",
    "    axes[1].grid(axis='x', alpha=0.3)\n",
    "    \n",
    "    for bar, val in zip(bars, sorted_vals):\n",
    "        axes[1].text(val + 0.5, bar.get_y() + bar.get_height()/2, \n",
    "                     f'{val:.1f}%', va='center', fontsize=10, fontweight='bold')\n",
    "    \n",
    "    axes[1].set_xlim(0, max(sorted_vals) * 1.15)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig(os.path.join(Config.FIGURES_DIR, 'Figure_Category_Importance.png'), dpi=300, bbox_inches='tight')\n",
    "    plt.close()\n",
    "    \n",
    "    print(\"  ✓ Figure 5b (category pie chart) saved\")\n",
    "\n",
    "\n",
    "# =============================================================================\n",
    "# 10. DECISION TREE VISUALIZATION (manuscript Figure 8; depth limited to 4 levels)\n",
    "# =============================================================================\n",
    "\n",
    "def create_decision_tree_visualization(model, features):\n",
    "    print(\"\\n--- Creating Decision Tree Visualization (Figure 7) ---\")\n",
    "    \n",
    "    tree_depths = [tree.get_depth() for tree in model.estimators_]\n",
    "    median_depth = int(np.median(tree_depths))\n",
    "    \n",
    "    for i, tree in enumerate(model.estimators_):\n",
    "        if tree.get_depth() == median_depth:\n",
    "            selected_tree = tree\n",
    "            selected_idx = i\n",
    "            break\n",
    "    else:\n",
    "        selected_tree = model.estimators_[0]\n",
    "        selected_idx = 0\n",
    "    \n",
    "    fig, ax = plt.subplots(figsize=(24, 16))\n",
    "    \n",
    "    feature_names = [Config.VARIABLE_DISPLAY_NAMES.get(f, f) for f in features]\n",
    "    class_names = ['No Loss', 'Habitat Loss']\n",
    "    \n",
    "    plot_tree(\n",
    "        selected_tree,\n",
    "        feature_names=feature_names,\n",
    "        class_names=class_names,\n",
    "        filled=True,\n",
    "        rounded=True,\n",
    "        ax=ax,\n",
    "        max_depth=4,\n",
    "        fontsize=9,\n",
    "        proportion=True,\n",
    "        impurity=False\n",
    "    )\n",
    "    \n",
    "    ax.set_title(f'Representative Decision Tree (Tree #{selected_idx + 1}, Depth Limited to 4 levels)\\n'\n",
    "                 f'Full tree depth: {selected_tree.get_depth()} levels', fontsize=14, fontweight='bold')\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig(os.path.join(Config.FIGURES_DIR, 'Figure_Decision_Tree.png'), dpi=300, bbox_inches='tight')\n",
    "    plt.close()\n",
    "    \n",
    "    # Save text rules\n",
    "    tree_rules = export_text(selected_tree, feature_names=features, max_depth=6)\n",
    "    with open(os.path.join(Config.OUTPUT_DIR, 'decision_tree_rules.txt'), 'w') as f:\n",
    "        f.write(f\"Decision Tree Rules (Tree #{selected_idx + 1})\\n\")\n",
    "        f.write(\"=\" * 60 + \"\\n\\n\")\n",
    "        f.write(tree_rules)\n",
    "    \n",
    "    print(f\"  ✓ Figure 8 saved (Tree #{selected_idx + 1}, full depth={selected_tree.get_depth()}, viz capped at 4)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "37fb417f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 11. THRESHOLD ANALYSIS - TRUE YOUDEN'S J (ROC-BASED)\n",
    "# =============================================================================\n",
    "\n",
    "def analyze_thresholds(model, X_train, y_train, features):\n",
    "    \"\"\"\n",
    "    CORRECTED VERSION: Uses TRUE ROC-based Youden's J statistic.\n",
    "    \n",
    "    Youden's J = Sensitivity + Specificity - 1 = TPR - FPR\n",
    "    \n",
    "    For each variable, we:\n",
    "    1. Treat it as a univariate predictor\n",
    "    2. Compute the ROC curve\n",
    "    3. Find the threshold that maximizes J = TPR - FPR\n",
    "    \"\"\"\n",
    "    print(\"\\n\" + \"=\"*60)\n",
    "    print(\"THRESHOLD ANALYSIS (True Youden's J - ROC Based)\")\n",
    "    print(\"=\"*60)\n",
    "    print(\"\\nYouden's J = Sensitivity + Specificity - 1 = TPR - FPR\")\n",
    "    \n",
    "    # Get top features by importance\n",
    "    importance = model.feature_importances_\n",
    "    top_idx = np.argsort(-importance)[:6]\n",
    "    top_features = [features[i] for i in top_idx]\n",
    "    \n",
    "    threshold_results = []\n",
    "    \n",
    "    # Create figure for univariate ROC curves\n",
    "    fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n",
    "    axes = axes.flatten()\n",
    "    \n",
    "    for idx, feat in enumerate(top_features):\n",
    "        feat_idx = features.index(feat)\n",
    "        feat_values = X_train[:, feat_idx]\n",
    "        \n",
    "        # Determine direction using correlation\n",
    "        corr = np.corrcoef(feat_values, y_train)[0, 1]\n",
    "        \n",
    "        # If negative correlation, flip values for ROC (ROC assumes higher = positive)\n",
    "        if corr < 0:\n",
    "            feat_values_roc = -feat_values\n",
    "            direction = 'higher risk below'\n",
    "        else:\n",
    "            feat_values_roc = feat_values\n",
    "            direction = 'higher risk above'\n",
    "        \n",
    "        # Compute ROC curve using variable as univariate predictor\n",
    "        fpr, tpr, thresholds_roc = roc_curve(y_train, feat_values_roc)\n",
    "        roc_auc = auc(fpr, tpr)\n",
    "        \n",
    "        # Calculate Youden's J for each threshold point\n",
    "        # J = TPR - FPR = Sensitivity + Specificity - 1\n",
    "        youdens_j = tpr - fpr\n",
    "        \n",
    "        # Find optimal threshold (maximum J)\n",
    "        optimal_idx = np.argmax(youdens_j)\n",
    "        optimal_j = youdens_j[optimal_idx]\n",
    "        optimal_threshold_roc = thresholds_roc[optimal_idx]\n",
    "        \n",
    "        # Convert back to original scale if we flipped\n",
    "        if corr < 0:\n",
    "            optimal_threshold = -optimal_threshold_roc\n",
    "        else:\n",
    "            optimal_threshold = optimal_threshold_roc\n",
    "        \n",
    "        # Get sensitivity and specificity at optimal point\n",
    "        sensitivity = tpr[optimal_idx]\n",
    "        specificity = 1 - fpr[optimal_idx]\n",
    "        \n",
    "        # Calculate loss probabilities for interpretation\n",
    "        below_mask = feat_values <= optimal_threshold\n",
    "        above_mask = ~below_mask\n",
    "        prob_below = np.mean(y_train[below_mask]) if np.sum(below_mask) > 0 else np.nan\n",
    "        prob_above = np.mean(y_train[above_mask]) if np.sum(above_mask) > 0 else np.nan\n",
    "        \n",
    "        threshold_results.append({\n",
    "            'Variable': feat,\n",
    "            'Display_Name': Config.VARIABLE_DISPLAY_NAMES.get(feat, feat),\n",
    "            'Threshold': optimal_threshold,\n",
    "            'Youden_J': optimal_j,\n",
    "            'Sensitivity': sensitivity,\n",
    "            'Specificity': specificity,\n",
    "            'AUC': roc_auc,\n",
    "            'Prob_Below': prob_below,\n",
    "            'Prob_Above': prob_above,\n",
    "            'Direction': direction\n",
    "        })\n",
    "        \n",
    "        print(f\"\\n  {feat}:\")\n",
    "        print(f\"    Optimal threshold: {optimal_threshold:.2f}\")\n",
    "        print(f\"    Youden's J: {optimal_j:.3f}\")\n",
    "        print(f\"    Sensitivity: {sensitivity:.3f}, Specificity: {specificity:.3f}\")\n",
    "        print(f\"    Univariate AUC: {roc_auc:.3f}\")\n",
    "        print(f\"    P(loss|≤thresh): {prob_below*100:.1f}%\")\n",
    "        print(f\"    P(loss|>thresh): {prob_above*100:.1f}%\")\n",
    "        print(f\"    Direction: {direction}\")\n",
    "        \n",
    "        # Plot univariate ROC curve\n",
    "        ax = axes[idx]\n",
    "        ax.plot(fpr, tpr, 'b-', lw=2, label=f'ROC (AUC={roc_auc:.3f})')\n",
    "        ax.plot([0, 1], [0, 1], 'k--', lw=1, label='Random')\n",
    "        ax.scatter(fpr[optimal_idx], tpr[optimal_idx], s=100, c='red', zorder=5,\n",
    "                   label=f'Optimal (J={optimal_j:.3f})')\n",
    "        ax.fill_between(fpr, tpr, alpha=0.2, color='blue')\n",
    "        \n",
    "        # Draw vertical line showing Youden's J\n",
    "        ax.vlines(fpr[optimal_idx], fpr[optimal_idx], tpr[optimal_idx], \n",
    "                  colors='red', linestyles='dashed', alpha=0.7)\n",
    "        \n",
    "        ax.set_xlabel('FPR (1-Specificity)')\n",
    "        ax.set_ylabel('TPR (Sensitivity)')\n",
    "        ax.set_title(f'{Config.VARIABLE_DISPLAY_NAMES.get(feat, feat)}\\nThresh={optimal_threshold:.2f}')\n",
    "        ax.legend(loc='lower right', fontsize=8)\n",
    "        ax.grid(alpha=0.3)\n",
    "        ax.set_xlim([-0.02, 1.02])\n",
    "        ax.set_ylim([-0.02, 1.02])\n",
    "    \n",
    "    plt.suptitle(\"Univariate ROC Curves with Optimal Youden's J Thresholds\", \n",
    "                 fontsize=14, fontweight='bold', y=1.02)\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(os.path.join(Config.FIGURES_DIR, 'Figure_Univariate_ROC.png'), \n",
    "                dpi=300, bbox_inches='tight')\n",
    "    plt.close()\n",
    "    print(\"\\n  ✓ Figure 6 (univariate ROC + Youden's J) saved\")\n",
    "    \n",
    "    threshold_df = pd.DataFrame(threshold_results)\n",
    "    threshold_df.to_csv(os.path.join(Config.TABLES_DIR, 'Table_5_51_Thresholds.csv'), index=False)\n",
    "    \n",
    "    return threshold_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "dafa2228",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 12. PARTIAL DEPENDENCE PLOTS (manuscript Figure 7)\n",
    "# =============================================================================\n",
    "\n",
    "def create_pdp(model, X_train, features):\n",
    "    print(\"\\n--- Creating Partial Dependence Plots (Figure 6) ---\")\n",
    "    \n",
    "    importance = model.feature_importances_\n",
    "    top_idx = np.argsort(-importance)[:6]\n",
    "    top_features = [features[i] for i in top_idx]\n",
    "    \n",
    "    fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n",
    "    axes = axes.flatten()\n",
    "    \n",
    "    display_names = [Config.VARIABLE_DISPLAY_NAMES.get(f, f) for f in features]\n",
    "    \n",
    "    for i, (feat, ax) in enumerate(zip(top_features, axes)):\n",
    "        feat_idx = features.index(feat)\n",
    "        PartialDependenceDisplay.from_estimator(\n",
    "            model, X_train, [feat_idx], feature_names=display_names,\n",
    "            ax=ax, kind='average', line_kw={'color': '#3498DB', 'linewidth': 2}\n",
    "        )\n",
    "        ax.set_title(f'{Config.VARIABLE_DISPLAY_NAMES.get(feat, feat)}', fontsize=11, fontweight='bold')\n",
    "        ax.grid(alpha=0.3)\n",
    "    \n",
    "    plt.suptitle('Partial Dependence Plots for Top 6 Drivers', fontsize=14, fontweight='bold', y=1.02)\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(os.path.join(Config.FIGURES_DIR, 'Figure_PDP.png'), dpi=300, bbox_inches='tight')\n",
    "    plt.close()\n",
    "    \n",
    "    print(\"  ✓ Figure 7 (PDPs) saved\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c454535d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 13. VALIDATION SUMMARY (Supplementary Table S1)\n",
    "# =============================================================================\n",
    "\n",
    "def create_validation_summary(metrics, cv_scores, model):\n",
    "    print(\"\\n--- Creating Validation Summary (Supplementary Table S1) ---\")\n",
    "    \n",
    "    validation_data = {\n",
    "        'Validation Aspect': [\n",
    "            'Test Set Accuracy',\n",
    "            'Test Set AUC-ROC',\n",
    "            'Test Set F1-Score',\n",
    "            'Test Set Kappa',\n",
    "            'CV Accuracy (mean ± std)',\n",
    "            'CV AUC-ROC (mean ± std)',\n",
    "            'OOB Score',\n",
    "            'Number of Trees',\n",
    "            'Number of Features',\n",
    "            'Training Samples',\n",
    "            'Test Samples',\n",
    "            'Spatial CV Folds'\n",
    "        ],\n",
    "        'Value': [\n",
    "            f\"{metrics['accuracy']:.4f}\",\n",
    "            f\"{metrics['roc_auc']:.4f}\",\n",
    "            f\"{metrics['f1']:.4f}\",\n",
    "            f\"{metrics['kappa']:.4f}\",\n",
    "            f\"{np.mean(cv_scores['accuracy']):.4f} ± {np.std(cv_scores['accuracy']):.4f}\",\n",
    "            f\"{np.mean(cv_scores['roc_auc']):.4f} ± {np.std(cv_scores['roc_auc']):.4f}\",\n",
    "            f\"{model.oob_score_:.4f}\",\n",
    "            str(Config.N_ESTIMATORS),\n",
    "            'N/A',\n",
    "            'See Table 5.44',\n",
    "            'See Table 5.44',\n",
    "            str(Config.CV_FOLDS)\n",
    "        ],\n",
    "        'Interpretation': [\n",
    "            'Good' if metrics['accuracy'] > 0.7 else 'Moderate',\n",
    "            'Excellent' if metrics['roc_auc'] > 0.8 else 'Good' if metrics['roc_auc'] > 0.7 else 'Moderate',\n",
    "            'Good' if metrics['f1'] > 0.7 else 'Moderate',\n",
    "            'Moderate' if metrics['kappa'] > 0.4 else 'Fair',\n",
    "            'Stable' if np.std(cv_scores['accuracy']) < 0.1 else 'Variable',\n",
    "            'Stable' if np.std(cv_scores['roc_auc']) < 0.1 else 'Variable',\n",
    "            'Good generalization' if model.oob_score_ > 0.7 else 'Moderate generalization',\n",
    "            'Sufficient for stable importance estimates',\n",
    "            'After VIF filtering',\n",
    "            'Stratified by habitat loss',\n",
    "            'Spatial block holdout',\n",
    "            'GroupKFold with spatial blocks'\n",
    "        ]\n",
    "    }\n",
    "    \n",
    "    validation_df = pd.DataFrame(validation_data)\n",
    "    validation_df.to_csv(os.path.join(Config.TABLES_DIR, 'Table_5_53_Validation_Summary.csv'), index=False)\n",
    "    \n",
    "    print(\"  ✓ Supplementary Table S1 saved\")\n",
    "    \n",
    "    return validation_df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "68ff68a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 14. SAVE MODEL AND OUTPUTS\n",
    "# =============================================================================\n",
    "\n",
    "def save_model_and_metadata(model, features, metrics, cv_scores):\n",
    "    print(\"\\n--- Saving Model and Metadata ---\")\n",
    "    \n",
    "    joblib.dump(model, os.path.join(Config.OUTPUT_DIR, 'rf_model.joblib'))\n",
    "    \n",
    "    with open(os.path.join(Config.OUTPUT_DIR, 'model_features.txt'), 'w') as f:\n",
    "        f.write('\\n'.join(features))\n",
    "    \n",
    "    metadata = {\n",
    "        'timestamp': datetime.now().isoformat(),\n",
    "        'n_estimators': Config.N_ESTIMATORS,\n",
    "        'features': features,\n",
    "        'n_features': len(features),\n",
    "        'test_metrics': metrics,\n",
    "        'cv_accuracy_mean': float(np.mean(cv_scores['accuracy'])),\n",
    "        'cv_accuracy_std': float(np.std(cv_scores['accuracy'])),\n",
    "        'cv_auc_mean': float(np.mean(cv_scores['roc_auc'])),\n",
    "        'cv_auc_std': float(np.std(cv_scores['roc_auc'])),\n",
    "        'spatial_cv': Config.USE_SPATIAL_CV,\n",
    "        'rfe_used': Config.USE_RFE,\n",
    "        'oob_score': float(model.oob_score_),\n",
    "    }\n",
    "    \n",
    "    with open(os.path.join(Config.OUTPUT_DIR, 'model_metadata.json'), 'w') as f:\n",
    "        json.dump(metadata, f, indent=2)\n",
    "    \n",
    "    print(\"  ✓ Model and metadata saved\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "76a292d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "SCRIPT 3: RANDOM FOREST ANALYSIS (CORRECTED VERSION)\n",
      "============================================================\n",
      "\n",
      "CORRECTIONS APPLIED:\n",
      "  1. Variable importance ranking fixed (uses pandas rank())\n",
      "  2. Threshold analysis uses TRUE Youden's J (ROC-based)\n",
      "  3. Added: Sensitivity, Specificity, Univariate AUC\n",
      "  4. Added: Figure 5.37 (Univariate ROC curves)\n",
      "\n",
      "Configuration:\n",
      "  - Spatial CV: True\n",
      "  - RFE: True\n",
      "  - Variables excluded: ['years_protected']\n",
      "✓ Output directories ready\n",
      "\n",
      "--- Loading Data ---\n",
      "  Train samples: 7671\n",
      "  Test samples: 1751\n",
      "  Features: 14\n",
      "  Excluded: ['years_protected']\n",
      "\n",
      "--- Recursive Feature Elimination ---\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Selected features (10):\n",
      "    - dist_roads\n",
      "    - dist_settlements\n",
      "    - pop_density\n",
      "    - pop_change\n",
      "    - pct_cultivated\n",
      "    - protection_status\n",
      "    - elevation\n",
      "    - twi\n",
      "    - mean_rainfall\n",
      "    - mean_temp\n",
      "\n",
      "  Note: Using all 14 features for main analysis\n",
      "  RFE selected features saved separately\n",
      "\n",
      "============================================================\n",
      "TRAINING RANDOM FOREST MODEL\n",
      "============================================================\n",
      "\n",
      "  Training with 7671 samples...\n",
      "\n",
      "  ✓ Training complete\n",
      "    Training Accuracy: 0.9193\n",
      "    OOB Score: 0.8526\n",
      "\n",
      "--- Spatial Cross-Validation ---\n",
      "  Spatial CV Results (18 blocks, 10 folds):\n",
      "    accuracy: 0.8393 ± 0.0678\n",
      "    precision: 0.7747 ± 0.1102\n",
      "    recall: 0.7291 ± 0.2485\n",
      "    f1: 0.7369 ± 0.1841\n",
      "    roc_auc: 0.8390 ± 0.0689\n",
      "\n",
      "--- Creating CV Stability Plot (Figure 5.32) ---\n",
      "  ✓ Figure 5.32 saved\n",
      "\n",
      "============================================================\n",
      "MODEL EVALUATION\n",
      "============================================================\n",
      "\n",
      "  Test Set Performance:\n",
      "    accuracy: 0.7727\n",
      "    precision: 0.7593\n",
      "    recall: 0.8641\n",
      "    f1: 0.8083\n",
      "    roc_auc: 0.8284\n",
      "    kappa: 0.5320\n",
      "\n",
      "  Confusion Matrix:\n",
      "    TN=514, FP=266\n",
      "    FN=132, TP=839\n",
      "  ✓ Figure 5.31 (ROC) saved\n",
      "\n",
      "============================================================\n",
      "VARIABLE IMPORTANCE ANALYSIS\n",
      "============================================================\n",
      "\n",
      "  Calculating permutation importance...\n",
      "\n",
      "  Top 10 Variables (by Gini importance):\n",
      "    Rank  1. pct_cultivated       Gini=33.6%, Perm=0.2019\n",
      "    Rank  2. mean_rainfall        Gini=12.0%, Perm=-0.0014\n",
      "    Rank  3. dist_settlements     Gini=11.9%, Perm=0.0066\n",
      "    Rank  4. pop_density          Gini=10.0%, Perm=-0.0032\n",
      "    Rank  5. mean_temp            Gini=6.7%, Perm=0.0001\n",
      "    Rank  6. protection_status    Gini=6.0%, Perm=-0.0008\n",
      "    Rank  7. pop_change           Gini=3.7%, Perm=-0.0015\n",
      "    Rank  8. dist_roads           Gini=3.2%, Perm=-0.0029\n",
      "    Rank  9. elevation            Gini=3.1%, Perm=-0.0030\n",
      "    Rank 10. dist_rivers          Gini=2.3%, Perm=0.0008\n",
      "\n",
      "  Category Summary:\n",
      "    Socio-economic : 47.3% (3 vars)\n",
      "    Climatic       : 18.7% (2 vars)\n",
      "    Proximity      : 18.5% (4 vars)\n",
      "    Topographic    : 9.4% (4 vars)\n",
      "    Conservation   : 6.0% (1 vars)\n",
      "  ✓ Figure 5.33 (Importance) saved\n",
      "\n",
      "--- Creating Category Pie Chart (Figure 5.34) ---\n",
      "  ✓ Figure 5.34 saved\n",
      "\n",
      "--- Creating Decision Tree Visualization (Figure 5.35) ---\n",
      "  ✓ Figure 5.35 saved (Tree #7, depth=21)\n",
      "\n",
      "============================================================\n",
      "THRESHOLD ANALYSIS (True Youden's J - ROC Based)\n",
      "============================================================\n",
      "\n",
      "Youden's J = Sensitivity + Specificity - 1 = TPR - FPR\n",
      "\n",
      "  pct_cultivated:\n",
      "    Optimal threshold: 10.03\n",
      "    Youden's J: 0.683\n",
      "    Sensitivity: 0.891, Specificity: 0.793\n",
      "    Univariate AUC: 0.911\n",
      "    P(loss|≤thresh): 11.3%\n",
      "    P(loss|>thresh): 79.9%\n",
      "    Direction: higher risk above\n",
      "\n",
      "  mean_rainfall:\n",
      "    Optimal threshold: 897.58\n",
      "    Youden's J: 0.422\n",
      "    Sensitivity: 0.791, Specificity: 0.631\n",
      "    Univariate AUC: 0.761\n",
      "    P(loss|≤thresh): 66.5%\n",
      "    P(loss|>thresh): 23.4%\n",
      "    Direction: higher risk below\n",
      "\n",
      "  dist_settlements:\n",
      "    Optimal threshold: 4000.92\n",
      "    Youden's J: 0.491\n",
      "    Sensitivity: 0.708, Specificity: 0.783\n",
      "    Univariate AUC: 0.817\n",
      "    P(loss|≤thresh): 75.1%\n",
      "    P(loss|>thresh): 25.7%\n",
      "    Direction: higher risk below\n",
      "\n",
      "  pop_density:\n",
      "    Optimal threshold: 0.04\n",
      "    Youden's J: 0.468\n",
      "    Sensitivity: 0.818, Specificity: 0.650\n",
      "    Univariate AUC: 0.808\n",
      "    P(loss|≤thresh): 20.6%\n",
      "    P(loss|>thresh): 68.4%\n",
      "    Direction: higher risk above\n",
      "\n",
      "  mean_temp:\n",
      "    Optimal threshold: 21.36\n",
      "    Youden's J: 0.303\n",
      "    Sensitivity: 0.378, Specificity: 0.925\n",
      "    Univariate AUC: 0.693\n",
      "    P(loss|≤thresh): 82.3%\n",
      "    P(loss|>thresh): 38.3%\n",
      "    Direction: higher risk below\n",
      "\n",
      "  protection_status:\n",
      "    Optimal threshold: 1.00\n",
      "    Youden's J: 0.323\n",
      "    Sensitivity: 0.979, Specificity: 0.344\n",
      "    Univariate AUC: 0.703\n",
      "    P(loss|≤thresh): 58.0%\n",
      "    P(loss|>thresh): 5.4%\n",
      "    Direction: higher risk below\n",
      "\n",
      "  ✓ Figure 5.37 (Univariate ROC) saved\n",
      "\n",
      "--- Creating Partial Dependence Plots (Figure 5.36) ---\n",
      "  ✓ Figure 5.36 (PDPs) saved\n",
      "\n",
      "--- Creating Validation Summary (Table 5.53) ---\n",
      "  ✓ Table 5.53 saved\n",
      "\n",
      "--- Saving Model and Metadata ---\n",
      "  ✓ Model and metadata saved\n",
      "\n",
      "============================================================\n",
      "ANALYSIS COMPLETE\n",
      "============================================================\n",
      "\n",
      "Outputs saved to:\n",
      "  Tables: D:/Publication/Habitat Loss in the Greater Kafue Ecosystem/Thesis_anlysis/Final_results for defence/outputs/objective_b\\tables\n",
      "  Figures: D:/Publication/Habitat Loss in the Greater Kafue Ecosystem/Thesis_anlysis/Final_results for defence/outputs/objective_b\\figures\n",
      "  Model: D:/Publication/Habitat Loss in the Greater Kafue Ecosystem/Thesis_anlysis/Final_results for defence/outputs/objective_b\\results\n",
      "\n",
      "------------------------------------------------------------\n",
      "FIGURES GENERATED:\n",
      "------------------------------------------------------------\n",
      "  Figure 5.31: ROC Curve\n",
      "  Figure 5.32: CV Stability Plot\n",
      "  Figure 5.33: Variable Importance Comparison\n",
      "  Figure 5.34: Category Importance Pie Chart\n",
      "  Figure 5.35: Decision Tree Visualization\n",
      "  Figure 5.36: Partial Dependence Plots\n",
      "  Figure 5.37: Univariate ROC Curves (NEW)\n",
      "\n",
      "------------------------------------------------------------\n",
      "TABLES GENERATED:\n",
      "------------------------------------------------------------\n",
      "  Table 5.45: Model Performance Metrics\n",
      "  Table 5.46: Confusion Matrix\n",
      "  Table 5.47: CV Results\n",
      "  Table 5.48: Variable Importance (CORRECTED RANKS)\n",
      "  Table 5.49: Category Summary\n",
      "  Table 5.51: Critical Thresholds (TRUE YOUDEN'S J)\n",
      "  Table 5.53: Validation Summary\n",
      "  Table RFE_Ranking: Feature Selection Results\n"
     ]
    }
   ],
   "source": [
    "# =============================================================================\n",
    "# 15. MAIN EXECUTION\n",
    "# =============================================================================\n",
    "\n",
    "def main():\n",
    "    print(\"=\"*60)\n",
    "    print(\"SCRIPT 3: RANDOM FOREST ANALYSIS (CORRECTED VERSION)\")\n",
    "    print(\"=\"*60)\n",
    "    print(\"\\nCORRECTIONS APPLIED:\")\n",
    "    print(\"  1. Variable importance ranking fixed (uses pandas rank())\")\n",
    "    print(\"  2. Threshold analysis uses TRUE Youden's J (ROC-based)\")\n",
    "    print(\"  3. Added: Sensitivity, Specificity, Univariate AUC\")\n",
    "    print(\"  4. Added: Figure 5 (Univariate ROC curves)\")\n",
    "    \n",
    "    print(f\"\\nConfiguration:\")\n",
    "    print(f\"  - Spatial CV: {Config.USE_SPATIAL_CV}\")\n",
    "    print(f\"  - RFE: {Config.USE_RFE}\")\n",
    "    print(f\"  - Variables excluded: {Config.VARS_TO_EXCLUDE}\")\n",
    "    \n",
    "    create_output_directories()\n",
    "    \n",
    "    # Load data\n",
    "    train_df, test_df, features = load_data()\n",
    "    \n",
    "    # Prepare arrays\n",
    "    X_train = train_df[features].values\n",
    "    y_train = train_df['habitat_loss'].values\n",
    "    X_test = test_df[features].values\n",
    "    y_test = test_df['habitat_loss'].values\n",
    "    \n",
    "    # Feature selection with RFE\n",
    "    if Config.USE_RFE:\n",
    "        selected_features, rfe_ranking = perform_rfe(X_train, y_train, features, Config.RFE_N_FEATURES)\n",
    "        rfe_ranking.to_csv(os.path.join(Config.TABLES_DIR, 'Table_RFE_Ranking.csv'), index=False)\n",
    "        print(f\"\\n  Note: Using all {len(features)} features for main analysis\")\n",
    "        print(f\"  RFE selected features saved separately\")\n",
    "    \n",
    "    # Train model\n",
    "    model = train_model(X_train, y_train)\n",
    "    \n",
    "    # Cross-validation\n",
    "    if Config.USE_SPATIAL_CV and 'spatial_block' in train_df.columns:\n",
    "        groups = train_df['spatial_block'].values\n",
    "        cv_scores = spatial_cross_validation(model, X_train, y_train, groups, Config.CV_FOLDS)\n",
    "    else:\n",
    "        cv_scores = standard_cross_validation(model, X_train, y_train, Config.CV_FOLDS)\n",
    "    \n",
    "    # Save CV results\n",
    "    cv_df = pd.DataFrame(cv_scores)\n",
    "    cv_df.to_csv(os.path.join(Config.TABLES_DIR, 'Table_5_47_CV_Results.csv'), index=False)\n",
    "    \n",
    "    # CV stability plot\n",
    "    create_cv_stability_plot(cv_scores)\n",
    "    \n",
    "    # Model evaluation\n",
    "    metrics, cm = evaluate_model(model, X_test, y_test, features)\n",
    "    \n",
    "    # Variable importance (CORRECTED)\n",
    "    importance_df, category_summary = analyze_importance(model, X_train, y_train, X_test, y_test, features)\n",
    "    \n",
    "    # Decision tree visualization\n",
    "    create_decision_tree_visualization(model, features)\n",
    "    \n",
    "    # Threshold analysis (CORRECTED - True Youden's J)\n",
    "    threshold_df = analyze_thresholds(model, X_train, y_train, features)\n",
    "    \n",
    "    # Partial dependence plots\n",
    "    create_pdp(model, X_train, features)\n",
    "    \n",
    "    # Validation summary\n",
    "    validation_summary = create_validation_summary(metrics, cv_scores, model)\n",
    "    \n",
    "    # Save model\n",
    "    save_model_and_metadata(model, features, metrics, cv_scores)\n",
    "    \n",
    "    print(\"\\n\" + \"=\"*60)\n",
    "    print(\"ANALYSIS COMPLETE\")\n",
    "    print(\"=\"*60)\n",
    "    print(f\"\\nOutputs saved to:\")\n",
    "    print(f\"  Tables: {Config.TABLES_DIR}\")\n",
    "    print(f\"  Figures: {Config.FIGURES_DIR}\")\n",
    "    print(f\"  Model: {Config.OUTPUT_DIR}\")\n",
    "    \n",
    "    print(\"\\n\" + \"-\"*60)\n",
    "    print(\"FIGURES GENERATED:\")\n",
    "    print(\"-\"*60)\n",
    "    print(\"  Figure 3: ROC Curve\")\n",
    "    print(\"  Figure 4: CV Stability Plot\")\n",
    "    print(\"  Figure 5: Variable Importance Comparison\")\n",
    "    print(\"  Figure 5: Category Importance Pie Chart\")\n",
    "    print(\"  Figure 8: Decision Tree Visualization (depth limited to 4 levels)\")\n",
    "    print(\"  Figure 7: Partial Dependence Plots\")\n",
    "    print(\"  Figure 6: Univariate ROC Curves\")\n",
    "    \n",
    "    print(\"\\n\" + \"-\"*60)\n",
    "    print(\"TABLES GENERATED:\")\n",
    "    print(\"-\"*60)\n",
    "    print(\"  Supplementary Table S1: Model Performance and Validation Summary\")\n",
    "    print(\"  Supplementary Table S2: Variable Importance Rankings\")\n",
    "    print(\"  Manuscript Table 3: Confusion Matrix\")\n",
    "    print(\"  Manuscript Table 4: Critical Thresholds (true Youden's J)\")\n",
    "    print(\"  Internal CSVs: CV Results, Category Summary, RFE Ranking\")\n",
    "    \n",
    "    return {\n",
    "        'model': model, \n",
    "        'features': features, \n",
    "        'metrics': metrics, \n",
    "        'importance': importance_df,\n",
    "        'cv_scores': cv_scores,\n",
    "        'threshold': threshold_df,\n",
    "        'validation_summary': validation_summary\n",
    "    }\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    results = main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
