{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Supplementary Script S2: Training Sample Generation and Data Extraction\n",
    "\n",
    "**Study:** Assessing Drivers of Forest Loss as Indicators of Habitat Degradation in the Greater Kafue Ecosystem of Zambia Using a Random Forest Approach  \n",
    "**Authors:** Gift Mulenga, Darius Phiri, Matamyo Simwanda, Vincent R. Nyirenda \n",
    "\n",
    "---\n",
    "\n",
    "## Purpose\n",
    "\n",
    "This script (a) defines the binary response variable (habitat loss vs. persistence), (b) draws 10,000 class-balanced candidate sample points (5,000 per class) and thins them to a 500 m minimum spacing, (c) extracts predictor raster values at each retained point, (d) screens predictors for multicollinearity (pairwise correlation and VIF), and (e) partitions the samples into spatially blocked training and test sets.\n",
    "\n",
    "## Inputs Required\n",
    "\n",
    "| Input | Source |\n",
    "|-------|--------|\n",
    "| LULC raster 1984 | Local file: `GKE_1984.tif` (from the companion land-cover mapping paper, Mulenga et al., in press) |\n",
    "| LULC raster 2024 | Local file: `GKE_2024.tif` |\n",
    "| 15 driver rasters (14 retained as predictors) | Exported from Script S1 and organised by category in `DRIVER_DIR` (paths listed in `Config.DRIVER_FILES`) |\n",
    "\n",
    "## Outputs Produced\n",
    "\n",
    "| Output | Location | Description |\n",
    "|--------|----------|-------------|\n",
    "| `habitat_loss_mask.tif` | `OUTPUT_DIR` | Categorical raster: 1 = loss, 2 = persistence, 0 = excluded |\n",
    "| `sample_points.shp` | `samples/` | Sample points retained after spatial thinning and NoData removal, with extracted values |\n",
    "| `driver_data_full.csv` | `samples/` | Full extracted dataset (all retained samples) |\n",
    "| `train_data.csv` | `samples/` | Training set (~80% of samples, spatially blocked) |\n",
    "| `test_data.csv` | `samples/` | Test set (~20% of samples, held-out spatial blocks) |\n",
    "| `correlation_matrix.csv` | `metadata/` | Pairwise Pearson correlations among the retained predictors |\n",
    "| `vif_analysis.csv` | `metadata/` | Variance Inflation Factors for all candidate predictors (confirmatory multicollinearity diagnostic) |\n",
    "| `sample_metadata.json` | `metadata/` | Record of key configuration parameters and sample counts for reproducibility |\n",
    "\n",
    "## How to Use\n",
    "\n",
    "> **⚠ USER ACTION REQUIRED:** Update `Config.LULC_DIR`, `Config.DRIVER_DIR`, and `Config.OUTPUT_DIR` before running.\n",
    "\n",
    "1. Ensure Script S1 has been run and all driver rasters are downloaded to `DRIVER_DIR`\n",
    "2. Update paths in `Config` (Section 1)\n",
    "3. Run all cells in order — the script is self-contained and produces all necessary outputs\n",
    "4. Outputs in `samples/` are read directly by Script S3\n",
    "\n",
    "## Scientific Notes\n",
    "\n",
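    "- **Response variable:** Pixels transitioning from *natural* cover (Forest=2, Grassland=4, Water=6) in 1984 to *disturbed* cover (Built-up=1, Cropland=3, Bareland=5) in 2024 are coded as habitat loss (1); pixels retaining natural cover are coded as persistence (0). In `habitat_loss_mask.tif` persistence is stored as 2 so that 0 can mark excluded pixels (e.g. pixels already disturbed in 1984, or NoData).  \n",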
    "  *Water class (6) is retained as natural habitat for consistency with the companion land-cover mapping paper (Mulenga et al., in press).*\n",
    "- **Sample size:** The minimum required sample size is estimated from Equation 1 (manuscript):  \n",
    "  `n = Z²·p(1-p)·D / E²` where Z=1.96, p=0.5, D=1.5, E=0.05 → n ≈ 576.  \n",
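    "  We drew 10,000 candidate points (5,000 per class), well above this minimum, to support precise estimates and robust model training; the number retained after spatial thinning and NoData removal is recorded in `sample_metadata.json`.\n",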
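    "- **Spatial thinning:** A 500 m minimum inter-point distance is enforced by greedy k-d tree thinning to reduce spatial autocorrelation among samples. Points are visited in array order with loss points first, so where a loss point and a persistence point lie within 500 m of each other the loss point is kept; the retained classes are therefore not guaranteed to stay balanced.\n",
    "- **Spatial blocking:** The sample extent is divided into a 5×5 grid, and 20% of the occupied blocks (5 of 25 when all are occupied) are held out as the test set; block IDs are saved in the `spatial_block` column. Holding out whole blocks limits spatial data leakage: training and test samples from nearby locations would share correlated predictor values and artificially inflate model performance. See Roberts et al. (2017) [ref 31].\n",
    "- **Predictor exclusion and VIF screening:** `years_protected` was excluded a priori from the final model because it is a *deterministic function* of `protection_status` (0 = unprotected → 0 yrs; 1 = GMA → 52 yrs; 2 = KNP → 74 yrs) and therefore adds no information. It is exactly collinear with a one-hot encoding of `protection_status`; with the status entered as a numeric code the mapping is not linear, so its linear VIF is high but finite. Correlation (|r| > 0.7) and VIF (> 5 moderate, > 10 high) analyses were retained as confirmatory diagnostics only, yielding the final 14-predictor set.\n",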
    "- **Random seed 42** is fixed throughout for full reproducibility.\n",
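    "\n",
    "As an arithmetic check on Equation 1, the short snippet below reproduces n ≈ 576 from the parameter values above (Z = 1.96, p = 0.5, D = 1.5, E = 0.05). The helper `min_sample_size` is a hypothetical name used only for this illustration; it is not called anywhere in the pipeline.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def min_sample_size(z, p, d, e):\n",
    "    # Equation 1: n = Z^2 * p(1 - p) * D / E^2\n",
    "    return z ** 2 * p * (1 - p) * d / e ** 2\n",
    "\n",
    "n = min_sample_size(z=1.96, p=0.5, d=1.5, e=0.05)\n",
    "print(round(n, 2))   # 576.24\n",
    "print(math.ceil(n))  # 577 if rounded up to a whole number of samples\n",
    "```\n",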
    "\n",
    "## Dependencies\n",
    "\n",
    "```\n",
    "numpy >= 1.23.0\n",
    "pandas >= 1.5.0\n",
    "geopandas >= 0.12.0\n",
    "rasterio >= 1.3.0\n",
    "scipy >= 1.9.0\n",
    "scikit-learn >= 1.1.0\n",
    "shapely >= 2.0.0\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4510baa7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import geopandas as gpd\n",
    "import rasterio\n",
    "from rasterio.warp import reproject, Resampling\n",
    "from shapely.geometry import Point\n",
    "from scipy.spatial import cKDTree\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.linear_model import LinearRegression\n",
    "import json\n",
    "from datetime import datetime\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d9bf77a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 1. CONFIGURATION\n",
    "# =============================================================================\n",
    "\n",
    "class Config:\n",
    "    # UPDATE THESE PATHS\n",
    "    LULC_DIR = r\"D:/Publication/Habitat Loss in the Greater Kafue Ecosystem/Thesis_anlysis/Final_results for defence/data/data\"\n",
    "    DRIVER_DIR = r\"E:/Gift_Mulenga/Research/Msc_Tropical_Ecology/GKE_Objective_B/data\"\n",
    "    OUTPUT_DIR = r\"D:/Publication/Habitat Loss in the Greater Kafue Ecosystem/Thesis_anlysis/Final_results for defence/outputs/objective_b\"\n",
    "    \n",
    "    # Water (6) is treated as natural habitat, consistent with the companion land-cover mapping paper\n",
    "    NATURAL_CLASSES = [2, 4, 6]  # Forest, Grassland, Water\n",
    "    DISTURBED_CLASSES = [1, 3, 5]  # Built-up, Cropland, Bareland\n",
    "    \n",
    "    CLASS_NAMES = {1: 'Built-up', 2: 'Forest', 3: 'Cropland', \n",
    "                   4: 'Grassland', 5: 'Bareland', 6: 'Water'}\n",
    "    \n",
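    "    TOTAL_SAMPLES = 10000    # candidate points before thinning (LOSS + NO_LOSS); not used by the code below\n",
    "    LOSS_SAMPLES = 5000      # candidates drawn from loss pixels\n",
    "    NO_LOSS_SAMPLES = 5000   # candidates drawn from persistence pixels\n",
    "    MIN_DISTANCE = 500       # metres; assumes the LULC rasters use a projected CRS in metres\n",
    "    RANDOM_STATE = 42\n",
    "    TEST_SIZE = 0.20         # fraction of spatial blocks (or of samples, if not blocking) held out\n",
    "    \n",
    "    # Spatial blocking parameters\n",
    "    USE_SPATIAL_BLOCKING = True\n",
    "    N_SPATIAL_BLOCKS = 25    # 5x5 grid; must be a perfect square\n",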
    "    \n",
    "    # Driver rasters, relative to DRIVER_DIR (15 candidates; years_protected is dropped later)\n",
    "    DRIVER_FILES = {\n",
    "        'dist_roads': 'proximity/dist_roads.tif',\n",
    "        'dist_settlements': 'proximity/dist_settlements.tif',\n",
    "        'dist_rivers': 'proximity/dist_rivers.tif',\n",
    "        'dist_knp': 'proximity/dist_knp.tif',\n",
    "        'pop_density': 'socioeconomic/pop_density.tif',\n",
    "        'pop_change': 'socioeconomic/pop_change.tif',\n",
    "        'pct_cultivated': 'socioeconomic/pct_cultivated.tif',\n",
    "        'protection_status': 'conservation/protection_status.tif',\n",
    "        'years_protected': 'conservation/years_protected.tif',  # Will be dropped\n",
    "        'elevation': 'topographic/elevation.tif',\n",
    "        'slope': 'topographic/slope.tif',\n",
    "        'aspect': 'topographic/aspect.tif',\n",
    "        'twi': 'topographic/twi.tif',\n",
    "        'mean_rainfall': 'climatic/mean_rainfall.tif',\n",
    "        'mean_temp': 'climatic/mean_temp.tif',\n",
    "    }\n",
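    "    \n",
    "    # Excluded a priori (deterministic function of protection_status);\n",
    "    # the VIF analysis in Section 6 is a confirmatory diagnostic only\n",
    "    VARS_TO_DROP_VIF = ['years_protected']\n",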
    "    \n",
    "    VARIABLE_CATEGORIES = {\n",
    "        'dist_roads': 'Proximity', 'dist_settlements': 'Proximity',\n",
    "        'dist_rivers': 'Proximity', 'dist_knp': 'Proximity',\n",
    "        'pop_density': 'Socio-economic', 'pop_change': 'Socio-economic',\n",
    "        'pct_cultivated': 'Socio-economic', 'protection_status': 'Conservation',\n",
    "        'years_protected': 'Conservation', 'elevation': 'Topographic',\n",
    "        'slope': 'Topographic', 'aspect': 'Topographic', 'twi': 'Topographic',\n",
    "        'mean_rainfall': 'Climatic', 'mean_temp': 'Climatic',\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2df3d17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 2. UTILITY FUNCTIONS\n",
    "# =============================================================================\n",
    "\n",
    "def create_output_directories():\n",
    "    dirs = [Config.OUTPUT_DIR, os.path.join(Config.OUTPUT_DIR, 'samples'),\n",
    "            os.path.join(Config.OUTPUT_DIR, 'metadata'),\n",
    "            os.path.join(Config.OUTPUT_DIR, 'figures'),\n",
    "            os.path.join(Config.OUTPUT_DIR, 'tables')]\n",
    "    for d in dirs:\n",
    "        os.makedirs(d, exist_ok=True)\n",
    "    print(\"✓ Output directories ready\")\n",
    "\n",
    "\n",
    "def align_raster_to_reference(source_path, reference_path, output_path):\n",
    "    print(f\"\\n--- Aligning Raster to Reference ---\")\n",
    "    with rasterio.open(reference_path) as ref:\n",
    "        ref_transform, ref_crs, ref_shape = ref.transform, ref.crs, ref.shape\n",
    "        \n",
    "    with rasterio.open(source_path) as src:\n",
    "        dst_array = np.zeros(ref_shape, dtype=src.dtypes[0])  # start from 0 (excluded), not uninitialised memory\n",
    "        reproject(source=rasterio.band(src, 1), destination=dst_array,\n",
    "                  src_transform=src.transform, src_crs=src.crs,\n",
    "                  dst_transform=ref_transform, dst_crs=ref_crs,\n",
    "                  resampling=Resampling.nearest)\n",
    "        \n",
    "        profile = src.profile.copy()\n",
    "        profile.update({'height': ref_shape[0], 'width': ref_shape[1],\n",
    "                        'transform': ref_transform, 'crs': ref_crs})\n",
    "        \n",
    "        with rasterio.open(output_path, 'w', **profile) as dst:\n",
    "            dst.write(dst_array, 1)\n",
    "    \n",
    "    print(f\"  ✓ Alignment complete: {output_path}\")\n",
    "    return output_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2513b657",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 3. HABITAT LOSS MASK (NATURAL CLASSES INCLUDE WATER)\n",
    "# =============================================================================\n",
    "\n",
    "def create_habitat_loss_mask():\n",
    "    print(\"\\n\" + \"=\"*60)\n",
    "    print(\"CREATING HABITAT LOSS MASK (NATURAL CLASSES INCLUDE WATER)\")\n",
    "    print(\"=\"*60)\n",
    "    \n",
    "    lulc_1984_path = os.path.join(Config.LULC_DIR, \"GKE_1984.tif\")\n",
    "    lulc_2024_path = os.path.join(Config.LULC_DIR, \"GKE_2024.tif\")\n",
    "    \n",
    "    with rasterio.open(lulc_1984_path) as src_1984:\n",
    "        lulc_1984 = src_1984.read(1)\n",
    "        profile = src_1984.profile.copy()\n",
    "        transform, crs, res = src_1984.transform, src_1984.crs, src_1984.res[0]\n",
    "        \n",
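    "    with rasterio.open(lulc_2024_path) as src_2024:\n",
    "        lulc_2024 = src_2024.read(1)\n",
    "        same_grid = (src_2024.transform == transform) and (src_2024.crs == crs)\n",
    "    \n",
    "    # Align the 2024 raster to the 1984 grid if shape, transform or CRS differ\n",
    "    if lulc_1984.shape != lulc_2024.shape or not same_grid:\n",
    "        print(f\"  ⚠ Grid mismatch - aligning...\")\n",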
    "        aligned_path = os.path.join(Config.OUTPUT_DIR, \"GKE_2024_aligned.tif\")\n",
    "        align_raster_to_reference(lulc_2024_path, lulc_1984_path, aligned_path)\n",
    "        with rasterio.open(aligned_path) as src:\n",
    "            lulc_2024 = src.read(1)\n",
    "    \n",
    "    print(f\"\\n  Natural classes: {Config.NATURAL_CLASSES} (Forest, Grassland, Water)\")\n",
    "    print(f\"  Disturbed classes: {Config.DISTURBED_CLASSES}\")\n",
    "    \n",
    "    natural_1984 = np.isin(lulc_1984, Config.NATURAL_CLASSES)\n",
    "    natural_2024 = np.isin(lulc_2024, Config.NATURAL_CLASSES)\n",
    "    disturbed_2024 = np.isin(lulc_2024, Config.DISTURBED_CLASSES)\n",
    "    \n",
    "    loss_mask = natural_1984 & disturbed_2024\n",
    "    no_loss_mask = natural_1984 & natural_2024\n",
    "    \n",
    "    print(f\"\\n  Natural pixels 1984: {np.sum(natural_1984):,}\")\n",
    "    print(f\"  Loss pixels: {np.sum(loss_mask):,}\")\n",
    "    print(f\"  No-loss pixels: {np.sum(no_loss_mask):,}\")\n",
    "    print(f\"  Loss rate: {np.sum(loss_mask)/np.sum(natural_1984)*100:.2f}%\")\n",
    "    \n",
    "    combined_mask = np.zeros_like(lulc_1984, dtype=np.uint8)\n",
    "    combined_mask[loss_mask] = 1\n",
    "    combined_mask[no_loss_mask] = 2\n",
    "    \n",
    "    output_path = os.path.join(Config.OUTPUT_DIR, \"habitat_loss_mask.tif\")\n",
    "    profile.update(dtype=rasterio.uint8, count=1, nodata=0)\n",
    "    with rasterio.open(output_path, 'w', **profile) as dst:\n",
    "        dst.write(combined_mask, 1)\n",
    "    \n",
    "    print(f\"  ✓ Loss mask saved: {output_path}\")\n",
    "    \n",
    "    return {'mask': combined_mask, 'transform': transform, 'crs': crs,\n",
    "            'shape': lulc_1984.shape, 'loss_pixels': np.sum(loss_mask),\n",
    "            'no_loss_pixels': np.sum(no_loss_mask)}"
   ]
  },
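  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The transition rule in `create_habitat_loss_mask()` can be checked on a toy grid. This is an illustrative sketch only: the 3×3 arrays `lulc_84` and `lulc_24` are hypothetical values, not study data, and the class lists mirror `Config`. Pixels that were already disturbed in 1984, or NoData (0), fall outside both masks and stay 0.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "NATURAL = [2, 4, 6]    # Forest, Grassland, Water (as in Config.NATURAL_CLASSES)\n",
    "DISTURBED = [1, 3, 5]  # Built-up, Cropland, Bareland (as in Config.DISTURBED_CLASSES)\n",
    "\n",
    "# Hypothetical 3x3 LULC grids; 0 stands for NoData\n",
    "lulc_84 = np.array([[2, 2, 4],\n",
    "                    [6, 3, 2],\n",
    "                    [0, 4, 1]])\n",
    "lulc_24 = np.array([[3, 2, 1],\n",
    "                    [6, 3, 5],\n",
    "                    [2, 4, 3]])\n",
    "\n",
    "natural_84 = np.isin(lulc_84, NATURAL)\n",
    "loss = natural_84 & np.isin(lulc_24, DISTURBED)   # natural -> disturbed\n",
    "persist = natural_84 & np.isin(lulc_24, NATURAL)  # natural -> natural\n",
    "\n",
    "coded = np.zeros_like(lulc_84, dtype=np.uint8)     # 0 = excluded\n",
    "coded[loss] = 1\n",
    "coded[persist] = 2\n",
    "print(coded.tolist())  # [[1, 2, 1], [2, 0, 1], [0, 2, 0]]\n",
    "```\n"
   ]
  },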
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05906143",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 4. STRATIFIED SAMPLING\n",
    "# =============================================================================\n",
    "\n",
    "def generate_stratified_samples(mask_data):\n",
    "    print(\"\\n\" + \"=\"*60)\n",
    "    print(\"GENERATING STRATIFIED RANDOM SAMPLES\")\n",
    "    print(\"=\"*60)\n",
    "    \n",
    "    mask, transform = mask_data['mask'], mask_data['transform']\n",
    "    \n",
    "    loss_indices = np.where(mask == 1)\n",
    "    no_loss_indices = np.where(mask == 2)\n",
    "    \n",
    "    np.random.seed(Config.RANDOM_STATE)\n",
    "    \n",
    "    loss_idx = np.random.choice(len(loss_indices[0]), \n",
    "                                min(Config.LOSS_SAMPLES, len(loss_indices[0])), replace=False)\n",
    "    no_loss_idx = np.random.choice(len(no_loss_indices[0]),\n",
    "                                   min(Config.NO_LOSS_SAMPLES, len(no_loss_indices[0])), replace=False)\n",
    "    \n",
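    "    def rowcol_to_coords(rows, cols, t):\n",
    "        # Pixel-centre coordinates (+0.5 offset); without it the upper-left corner\n",
    "        # is returned, which lies on a pixel boundary when sampling the drivers\n",
    "        xs = t.c + (cols + 0.5) * t.a + (rows + 0.5) * t.b\n",
    "        ys = t.f + (cols + 0.5) * t.d + (rows + 0.5) * t.e\n",
    "        return xs, ys\n",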
    "    \n",
    "    loss_x, loss_y = rowcol_to_coords(loss_indices[0][loss_idx], loss_indices[1][loss_idx], transform)\n",
    "    no_loss_x, no_loss_y = rowcol_to_coords(no_loss_indices[0][no_loss_idx], no_loss_indices[1][no_loss_idx], transform)\n",
    "    \n",
    "    all_x = np.concatenate([loss_x, no_loss_x])\n",
    "    all_y = np.concatenate([loss_y, no_loss_y])\n",
    "    all_class = np.concatenate([np.ones(len(loss_x)), np.zeros(len(no_loss_x))])\n",
    "    all_rows = np.concatenate([loss_indices[0][loss_idx], no_loss_indices[0][no_loss_idx]])\n",
    "    all_cols = np.concatenate([loss_indices[1][loss_idx], no_loss_indices[1][no_loss_idx]])\n",
    "    \n",
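    "    # Greedy minimum distance filter: points are visited in array order, so loss\n",
    "    # points (concatenated first) are retained over nearby persistence points\n",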
    "    print(f\"\\n--- Applying {Config.MIN_DISTANCE}m Minimum Distance Filter ---\")\n",
    "    coords = np.column_stack([all_x, all_y])\n",
    "    tree = cKDTree(coords)\n",
    "    keep_mask = np.ones(len(coords), dtype=bool)\n",
    "    \n",
    "    for i in range(len(coords)):\n",
    "        if not keep_mask[i]:\n",
    "            continue\n",
    "        neighbors = tree.query_ball_point(coords[i], Config.MIN_DISTANCE)\n",
    "        for j in neighbors:\n",
    "            if j != i and keep_mask[j]:\n",
    "                keep_mask[j] = False\n",
    "    \n",
    "    print(f\"  Before: {len(all_x)}, After: {np.sum(keep_mask)}\")\n",
    "    \n",
    "    return {'x': all_x[keep_mask], 'y': all_y[keep_mask], 'class': all_class[keep_mask],\n",
    "            'rows': all_rows[keep_mask], 'cols': all_cols[keep_mask], 'crs': mask_data['crs']}"
   ]
  },
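  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The greedy thinning loop in `generate_stratified_samples()` can be verified independently: after thinning, every retained point's nearest retained neighbour should lie more than `MIN_DISTANCE` away, because `query_ball_point` removes neighbours at distances ≤ r. The sketch below repeats the same greedy logic on hypothetical uniformly scattered points (the 20 km square, the point count and the function name `greedy_thin` are assumptions for illustration) and checks that property with a k = 2 nearest-neighbour query. Because points are visited in array order, earlier points win every conflict, which is why loss samples (concatenated first) take priority in the script.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from scipy.spatial import cKDTree\n",
    "\n",
    "def greedy_thin(coords, min_dist):\n",
    "    # Visit points in array order; each kept point removes all others within min_dist\n",
    "    tree = cKDTree(coords)\n",
    "    keep = np.ones(len(coords), dtype=bool)\n",
    "    for i in range(len(coords)):\n",
    "        if not keep[i]:\n",
    "            continue\n",
    "        for j in tree.query_ball_point(coords[i], min_dist):\n",
    "            if j != i:\n",
    "                keep[j] = False\n",
    "    return keep\n",
    "\n",
    "rng = np.random.default_rng(42)\n",
    "pts = rng.uniform(0, 20_000, size=(2_000, 2))  # hypothetical x/y in metres\n",
    "keep = greedy_thin(pts, 500)\n",
    "\n",
    "# k=2 because each point's nearest neighbour (k=1) is itself\n",
    "d, _ = cKDTree(pts[keep]).query(pts[keep], k=2)\n",
    "print(f'kept {keep.sum()} of {len(pts)}; min spacing > 500 m: {d[:, 1].min() > 500}')\n",
    "```\n"
   ]
  },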
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1298e0c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 5. DRIVER VALUE EXTRACTION\n",
    "# =============================================================================\n",
    "\n",
    "def extract_driver_values(sample_data):\n",
    "    print(\"\\n\" + \"=\"*60)\n",
    "    print(\"EXTRACTING DRIVER VALUES\")\n",
    "    print(\"=\"*60)\n",
    "    \n",
    "    df = pd.DataFrame({'x': sample_data['x'], 'y': sample_data['y'],\n",
    "                       'habitat_loss': sample_data['class'].astype(int)})\n",
    "    \n",
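    "    for var_name, rel_path in Config.DRIVER_FILES.items():\n",
    "        path = os.path.join(Config.DRIVER_DIR, rel_path)\n",
    "        if os.path.exists(path):\n",
    "            with rasterio.open(path) as src:\n",
    "                # Sample coordinates are in the LULC CRS; drivers must share it\n",
    "                if src.crs != sample_data['crs']:\n",
    "                    print(f\"  ⚠ {var_name}: CRS differs from the LULC rasters\")\n",
    "                coords = list(zip(sample_data['x'], sample_data['y']))\n",
    "                values = np.array([val[0] for val in src.sample(coords)], dtype=float)\n",
    "                # src.sample() returns the NoData value itself (not NaN), so convert\n",
    "                # it to NaN for the dropna() step below\n",
    "                if src.nodata is not None:\n",
    "                    values[values == src.nodata] = np.nan\n",
    "                df[var_name] = values\n",
    "                valid = np.sum(~np.isnan(values))\n",
    "                print(f\"  ✓ {var_name}: {valid}/{len(values)} valid\")\n",
    "        else:\n",
    "            print(f\"  ✗ {var_name}: NOT FOUND\")\n",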
    "    \n",
    "    # Remove NoData rows\n",
    "    original = len(df)\n",
    "    df = df.dropna()\n",
    "    print(f\"\\n  Removed {original - len(df)} rows with NoData\")\n",
    "    print(f\"  Final samples: {len(df)}\")\n",
    "    \n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa8db774",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 6. VIF ANALYSIS\n",
    "# =============================================================================\n",
    "\n",
    "def calculate_vif(df, features):\n",
    "    \"\"\"Variance Inflation Factor: VIF_j = 1 / (1 - R_j^2), with R_j^2 from an OLS\n",
    "    regression of feature j on all other features.\"\"\"\n",
    "    vif_data = []\n",
    "    for feature in features:\n",
    "        other_features = [f for f in features if f != feature]\n",
    "        X = df[other_features].values\n",
    "        y = df[feature].values\n",
    "        mask = ~(np.isnan(X).any(axis=1) | np.isnan(y))\n",
    "        X, y = X[mask], y[mask]\n",
    "        \n",
    "        if len(X) > 0:\n",
    "            reg = LinearRegression().fit(X, y)\n",
    "            r_squared = reg.score(X, y)\n",
    "            vif = 1 / (1 - r_squared) if r_squared < 1 else np.inf\n",
    "        else:\n",
    "            vif = np.nan\n",
    "        \n",
    "        vif_data.append({'Variable': feature, 'VIF': vif})\n",
    "    \n",
    "    return pd.DataFrame(vif_data)\n",
    "\n",
    "\n",
    "def analyze_multicollinearity(df, features):\n",
    "    \"\"\"Analyze multicollinearity. NOTE: VARS_TO_DROP_VIF is decided a priori\n",
    "    (see Config); this routine is a confirmatory diagnostic, not the decision rule.\"\"\"\n",
    "    print(\"\\n--- Multicollinearity Analysis ---\")\n",
    "    \n",
    "    # Correlation check\n",
    "    corr = df[features].corr()\n",
    "    print(\"\\n  Highly correlated pairs (|r| > 0.7):\")\n",
    "    for i in range(len(features)):\n",
    "        for j in range(i+1, len(features)):\n",
    "            if abs(corr.iloc[i, j]) > 0.7:\n",
    "                print(f\"    {features[i]} - {features[j]}: r = {corr.iloc[i, j]:.2f}\")\n",
    "    \n",
    "    # VIF calculation\n",
    "    print(\"\\n  Variance Inflation Factors:\")\n",
    "    vif_df = calculate_vif(df, features)\n",
    "    for _, row in vif_df.sort_values('VIF', ascending=False).iterrows():\n",
    "        status = \" ⚠ HIGH\" if row['VIF'] > 10 else (\" (moderate)\" if row['VIF'] > 5 else \"\")\n",
    "        print(f\"    {row['Variable']}: {row['VIF']:.2f}{status}\")\n",
    "    \n",
    "    print(f\"\\n  Variables to drop: {Config.VARS_TO_DROP_VIF}\")\n",
    "    \n",
    "    return vif_df\n"
   ]
  },
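  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate why `years_protected` is excluded a priori, the sketch below uses synthetic data (the array names, sample size and class frequencies are hypothetical) and computes R² with plain NumPy least squares rather than `LinearRegression`. With `protection_status` entered as one numeric column the fit is strong but not exact, because 0 → 52 → 74 years is not a linear mapping (R² is about 0.95, a VIF near 20, for near-equal class frequencies); with the status one-hot encoded the fit is exact and the VIF is infinite. In `calculate_vif()` the regression also includes the other predictors, so the R² for `years_protected` can only be at least as high as on `protection_status` alone.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(42)\n",
    "status = rng.integers(0, 3, size=1_000)  # hypothetical codes: 0 = unprotected, 1 = GMA, 2 = KNP\n",
    "years = np.array([0.0, 52.0, 74.0])[status]  # deterministic mapping from the Scientific Notes\n",
    "\n",
    "def r_squared(X, y):\n",
    "    # OLS with an intercept via least squares; R^2 = 1 - SS_res / SS_tot\n",
    "    X1 = np.column_stack([np.ones(len(y)), X])\n",
    "    beta, *_ = np.linalg.lstsq(X1, y, rcond=None)\n",
    "    resid = y - X1 @ beta\n",
    "    centred = y - y.mean()\n",
    "    return 1 - (resid @ resid) / (centred @ centred)\n",
    "\n",
    "# (a) status as a single numeric column: high but finite VIF\n",
    "r2_numeric = r_squared(status.astype(float), years)\n",
    "# (b) status one-hot encoded (GMA and KNP dummies, unprotected as baseline): exact fit\n",
    "dummies = np.column_stack([status == 1, status == 2]).astype(float)\n",
    "r2_onehot = r_squared(dummies, years)\n",
    "\n",
    "print(f'numeric code: R2 = {r2_numeric:.3f}, VIF = {1 / (1 - r2_numeric):.1f}')\n",
    "print(f'one-hot: R2 = {r2_onehot:.6f}')  # 1.000000, i.e. VIF -> infinity\n",
    "```\n"
   ]
  },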
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "836bdfb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 7. SPATIAL BLOCKING FOR TRAIN-TEST SPLIT\n",
    "# =============================================================================\n",
    "\n",
    "def create_spatial_blocks(df, n_blocks=25):\n",
    "    \"\"\"Assign each sample to a cell of a regular sqrt(n_blocks) x sqrt(n_blocks) grid\n",
    "    spanning the sample extent; returns block IDs 0 .. n_blocks-1.\"\"\"\n",
    "    n_side = int(np.sqrt(n_blocks))\n",
    "    x_bins = np.linspace(df['x'].min(), df['x'].max(), n_side + 1)\n",
    "    y_bins = np.linspace(df['y'].min(), df['y'].max(), n_side + 1)\n",
    "    \n",
    "    x_block = np.digitize(df['x'], x_bins[:-1]) - 1\n",
    "    y_block = np.digitize(df['y'], y_bins[:-1]) - 1\n",
    "    \n",
    "    return x_block + y_block * n_side\n",
    "\n",
    "\n",
    "def spatial_train_test_split(df, test_size=0.2, n_blocks=25, random_state=42):\n",
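    "    \"\"\"Hold out a random subset of whole spatial blocks (test_size of the occupied\n",
    "    blocks) as the test set. Adds a 'spatial_block' column to df in place.\"\"\"\n",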
    "    print(\"\\n--- Spatial Train-Test Split ---\")\n",
    "    \n",
    "    block_ids = create_spatial_blocks(df, n_blocks)\n",
    "    df['spatial_block'] = block_ids\n",
    "    \n",
    "    unique_blocks = np.unique(block_ids)\n",
    "    np.random.seed(random_state)\n",
    "    n_test = max(1, int(len(unique_blocks) * test_size))\n",
    "    test_blocks = np.random.choice(unique_blocks, n_test, replace=False)\n",
    "    \n",
    "    test_mask = df['spatial_block'].isin(test_blocks)\n",
    "    train_df = df[~test_mask].copy()\n",
    "    test_df = df[test_mask].copy()\n",
    "    \n",
    "    print(f\"  Spatial blocks: {len(unique_blocks)}\")\n",
    "    print(f\"  Test blocks: {len(test_blocks)}\")\n",
    "    print(f\"  Train: {len(train_df)} ({len(train_df)/len(df)*100:.1f}%)\")\n",
    "    print(f\"  Test: {len(test_df)} ({len(test_df)/len(df)*100:.1f}%)\")\n",
    "    \n",
    "    return train_df, test_df\n",
    "\n",
    "\n",
    "def standard_train_test_split(df, test_size=0.2, random_state=42):\n",
    "    \"\"\"Non-spatial, class-stratified random split (used when USE_SPATIAL_BLOCKING is False)\"\"\"\n",
    "    print(\"\\n--- Standard Train-Test Split ---\")\n",
    "    train_df, test_df = train_test_split(df, test_size=test_size, \n",
    "                                          random_state=random_state,\n",
    "                                          stratify=df['habitat_loss'])\n",
    "    print(f\"  Train: {len(train_df)}, Test: {len(test_df)}\")\n",
    "    return train_df, test_df\n"
   ]
  },
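  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The block assignment in `create_spatial_blocks()` and the block hold-out in `spatial_train_test_split()` can be sanity-checked on synthetic coordinates. In this sketch the 100 km square, the point count and the use of NumPy's `default_rng` (rather than the legacy `np.random.seed` used above) are assumptions for illustration. Binning against the lower bin edges only keeps the largest coordinate in the last column instead of creating a sixth one. Because whole blocks are held out, the test share is close to, but not exactly, `TEST_SIZE`, which is why the outputs table describes the split as ~80/20.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical coordinates (metres) for 1,000 points in a 100 km x 100 km box\n",
    "rng = np.random.default_rng(42)\n",
    "toy = pd.DataFrame({'x': rng.uniform(0, 100_000, 1_000),\n",
    "                    'y': rng.uniform(0, 100_000, 1_000)})\n",
    "\n",
    "n_side = 5\n",
    "x_bins = np.linspace(toy['x'].min(), toy['x'].max(), n_side + 1)\n",
    "y_bins = np.linspace(toy['y'].min(), toy['y'].max(), n_side + 1)\n",
    "# Lower edges only, so the maximum coordinate falls in column/row index 4\n",
    "col = np.digitize(toy['x'], x_bins[:-1]) - 1\n",
    "row = np.digitize(toy['y'], y_bins[:-1]) - 1\n",
    "block = col + row * n_side\n",
    "\n",
    "# Hold out 20% of the occupied blocks, as in spatial_train_test_split()\n",
    "blocks = np.unique(block)\n",
    "test_blocks = rng.choice(blocks, max(1, int(len(blocks) * 0.2)), replace=False)\n",
    "is_test = np.isin(block, test_blocks)\n",
    "print(len(blocks), len(test_blocks), round(float(is_test.mean()), 3))  # 25 5, share near 0.2\n",
    "```\n"
   ]
  },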
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1dcb211",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# 8. SAVE OUTPUTS\n",
    "# =============================================================================\n",
    "\n",
    "def save_outputs(df, train_df, test_df, features, vif_df, sample_data):\n",
    "    print(\"\\n\" + \"=\"*60)\n",
    "    print(\"SAVING OUTPUTS\")\n",
    "    print(\"=\"*60)\n",
    "    \n",
    "    samples_dir = os.path.join(Config.OUTPUT_DIR, 'samples')\n",
    "    metadata_dir = os.path.join(Config.OUTPUT_DIR, 'metadata')\n",
    "    \n",
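    "    # Shapefile (the format truncates field names to 10 characters, e.g.\n",
    "    # 'habitat_loss' -> 'habitat_lo'; the CSVs keep the full column names)\n",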
    "    geometry = [Point(x, y) for x, y in zip(df['x'], df['y'])]\n",
    "    gdf = gpd.GeoDataFrame(df, geometry=geometry, crs=sample_data['crs'])\n",
    "    gdf.to_file(os.path.join(samples_dir, 'sample_points.shp'))\n",
    "    print(f\"  ✓ Sample points shapefile saved\")\n",
    "    \n",
    "    # CSVs\n",
    "    df.to_csv(os.path.join(samples_dir, 'driver_data_full.csv'), index=False)\n",
    "    train_df.to_csv(os.path.join(samples_dir, 'train_data.csv'), index=False)\n",
    "    test_df.to_csv(os.path.join(samples_dir, 'test_data.csv'), index=False)\n",
    "    print(f\"  ✓ Train/test CSVs saved\")\n",
    "    \n",
    "    # Correlation matrix\n",
    "    df[features].corr().to_csv(os.path.join(metadata_dir, 'correlation_matrix.csv'))\n",
    "    \n",
    "    # VIF\n",
    "    vif_df.to_csv(os.path.join(metadata_dir, 'vif_analysis.csv'), index=False)\n",
    "    \n",
    "    # Metadata\n",
    "    metadata = {\n",
    "        'timestamp': datetime.now().isoformat(),\n",
    "        'total_samples': len(df),\n",
    "        'train_samples': len(train_df),\n",
    "        'test_samples': len(test_df),\n",
    "        'natural_classes': Config.NATURAL_CLASSES,\n",
    "        'spatial_blocking': Config.USE_SPATIAL_BLOCKING,\n",
    "        'features_used': features,\n",
    "        'features_dropped': Config.VARS_TO_DROP_VIF,\n",
    "    }\n",
    "    with open(os.path.join(metadata_dir, 'sample_metadata.json'), 'w') as f:\n",
    "        json.dump(metadata, f, indent=2)\n",
    "    print(f\"  ✓ Metadata saved\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e3a3015",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "\n",
    "# =============================================================================\n",
    "# 9. MAIN EXECUTION\n",
    "# =============================================================================\n",
    "\n",
    "def main():\n",
    "    print(\"=\"*60)\n",
    "    print(\"SCRIPT S2: SAMPLE GENERATION\")\n",
    "    print(\"=\"*60)\n",
    "    \n",
    "    print(f\"\\nConfiguration:\")\n",
    "    print(f\"  - Natural classes: {Config.NATURAL_CLASSES} (INCLUDES WATER)\")\n",
    "    print(f\"  - Spatial blocking: {Config.USE_SPATIAL_BLOCKING}\")\n",
    "    print(f\"  - Variables to drop (VIF): {Config.VARS_TO_DROP_VIF}\")\n",
    "    \n",
    "    create_output_directories()\n",
    "    \n",
    "    # Create loss mask\n",
    "    mask_data = create_habitat_loss_mask()\n",
    "    \n",
    "    # Generate samples\n",
    "    sample_data = generate_stratified_samples(mask_data)\n",
    "    \n",
    "    # Extract driver values\n",
    "    df = extract_driver_values(sample_data)\n",
    "    \n",
    "    # Define features (all drivers)\n",
    "    all_features = [f for f in Config.DRIVER_FILES.keys() if f in df.columns]\n",
    "    \n",
    "    # VIF analysis\n",
    "    vif_df = analyze_multicollinearity(df, all_features)\n",
    "    \n",
    "    # Drop variables excluded a priori (Config.VARS_TO_DROP_VIF)\n",
    "    features = [f for f in all_features if f not in Config.VARS_TO_DROP_VIF]\n",
    "    print(f\"\\n  Features retained after exclusion: {len(features)}\")\n",
    "    \n",
    "    # Train-test split\n",
    "    if Config.USE_SPATIAL_BLOCKING:\n",
    "        train_df, test_df = spatial_train_test_split(df, Config.TEST_SIZE, \n",
    "                                                      Config.N_SPATIAL_BLOCKS,\n",
    "                                                      Config.RANDOM_STATE)\n",
    "    else:\n",
    "        train_df, test_df = standard_train_test_split(df, Config.TEST_SIZE, \n",
    "                                                       Config.RANDOM_STATE)\n",
    "    \n",
    "    # Save outputs\n",
    "    save_outputs(df, train_df, test_df, features, vif_df, sample_data)\n",
    "    \n",
    "    print(\"\\n\" + \"=\"*60)\n",
    "    print(\"SAMPLE GENERATION COMPLETE\")\n",
    "    print(\"=\"*60)\n",
    "    print(f\"  Total samples: {len(df)}\")\n",
    "    print(f\"  Train: {len(train_df)}, Test: {len(test_df)}\")\n",
    "    print(f\"  Features: {len(features)} (dropped: {Config.VARS_TO_DROP_VIF})\")\n",
    "    \n",
    "    return {'df': df, 'train_df': train_df, 'test_df': test_df, 'features': features}\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    results = main()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
