#=============================================================================
# PROLONG (PRobabilistic mOdeL Of techNology Growth)
#=============================================================================

#--------------------
# 1. Setup and Dependencies
#--------------------

suppressPackageStartupMessages({
  required_packages <- c(
    "tidyverse", "yaml", "parallel", "foreach", "doParallel", "ranger", "mgcv", "purrr",
    "splines", "ggplot2", "patchwork", "cluster", "fs", "lhs", "Rtsne"
  )
  
  for (pkg in required_packages) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      install.packages(pkg)
    }
    library(pkg, character.only = TRUE)
  }
})

#--------------------
# 2. Parameter Distribution Generation Functions
#--------------------

#' Create a hybrid (grid + face-centered) coverage of parameter means
#'
#' Derives a range for the mean of the growth rate (`k`), the saturation level
#' (`L`) and the takeoff year from `base_config`, places cell-centred grid
#' points in each range, combines them into a full-factorial grid, optionally
#' adds the midpoints between adjacent `k` and `L` grid values, and returns one
#' modified copy of `base_config` per point.
#'
#' @param base_config Base configuration to modify. Must be a list with a
#'   `parameter_ranges` element holding `k`, `L` and `takeoff`; this function
#'   reads `k$shape`, `k$rate`, `L$shape`, `L$rate`, `L$max` and `takeoff$mean`.
#' @param k_divisions Number of divisions for growth rate parameter
#' @param L_divisions Number of divisions for saturation level parameter
#' @param takeoff_divisions Number of divisions for takeoff year parameter
#' @param add_face_centered Whether to add face-centered points (only done when
#'   both `k_divisions` and `L_divisions` are at least 2)
#' @return List with `configurations` (list of configs), `summary` (data frame
#'   with `k_new_mean`, `L_new_mean`, `takeoff_new_mean`, `variation_type`),
#'   `n_grid`, `n_face_centered` and `total_configs`.
create_hybrid_coverage <- function(base_config,
                                   k_divisions = 3,
                                   L_divisions = 3,
                                   takeoff_divisions = 2,
                                   add_face_centered = TRUE) {
  # Input validation
  if (!is.list(base_config) || is.null(base_config$parameter_ranges)) {
    stop("Invalid base_config: missing parameter_ranges")
  }

  for (param in c("k", "L", "takeoff")) {
    if (is.null(base_config$parameter_ranges[[param]])) {
      stop(paste("Missing parameter range for", param))
    }
  }

  is_count <- function(n) {
    is.numeric(n) && length(n) == 1 && !is.na(n) && is.finite(n) &&
      n >= 1 && n == round(n)
  }
  if (!is_count(k_divisions) || !is_count(L_divisions) ||
      !is_count(takeoff_divisions)) {
    stop("k_divisions, L_divisions and takeoff_divisions must each be a ",
         "single whole number >= 1")
  }

  # Base distribution parameters; a gamma(shape, rate) has mean shape / rate
  k_shape <- base_config$parameter_ranges$k$shape
  k_rate <- base_config$parameter_ranges$k$rate
  k_base_mean <- k_shape / k_rate

  L_shape <- base_config$parameter_ranges$L$shape
  L_rate <- base_config$parameter_ranges$L$rate
  L_base_mean <- L_shape / L_rate
  L_max <- base_config$parameter_ranges$L$max

  takeoff_mean <- base_config$parameter_ranges$takeoff$mean

  # Ranges over which the distribution means are varied
  k_mean_range <- c(max(0.05, k_base_mean * 0.1), k_base_mean * 1.25)
  L_mean_range <- c(max(0.05, L_max * 0.05, L_base_mean), L_max * 0.75)
  takeoff_mean_range <- c(max(1, takeoff_mean), takeoff_mean + 10)

  # Centres of `n` equal-width cells spanning `range`
  cell_centres <- function(n, range) {
    positions <- seq(1 / (2 * n), 1 - 1 / (2 * n), length.out = n)
    qunif(positions, range[1], range[2])
  }

  k_new_means <- cell_centres(k_divisions, k_mean_range)
  L_new_means <- cell_centres(L_divisions, L_mean_range)
  takeoff_new_means <- cell_centres(takeoff_divisions, takeoff_mean_range)

  # 1. Full-factorial grid (k varies fastest, then L, then takeoff)
  grid_points <- expand.grid(
    k_idx = seq_len(k_divisions),
    L_idx = seq_len(L_divisions),
    takeoff_idx = seq_len(takeoff_divisions)
  )
  grid_params <- data.frame(
    k_new_mean = k_new_means[grid_points$k_idx],
    L_new_mean = L_new_means[grid_points$L_idx],
    takeoff_new_mean = takeoff_new_means[grid_points$takeoff_idx],
    variation_type = "grid"
  )

  # 2. Face-centered points: midpoint of every adjacent (k, L) pair, repeated
  #    for each takeoff value. Built in one shot instead of growing a data
  #    frame row by row; takeoff varies fastest, then L, then k.
  face_centered_params <- data.frame()
  if (add_face_centered && k_divisions >= 2 && L_divisions >= 2) {
    mid_k <- (k_new_means[-k_divisions] + k_new_means[-1]) / 2
    mid_L <- (L_new_means[-L_divisions] + L_new_means[-1]) / 2

    face_points <- expand.grid(
      takeoff_idx = seq_len(takeoff_divisions),
      L_idx = seq_len(L_divisions - 1),
      k_idx = seq_len(k_divisions - 1)
    )
    face_centered_params <- data.frame(
      k_new_mean = mid_k[face_points$k_idx],
      L_new_mean = mid_L[face_points$L_idx],
      takeoff_new_mean = takeoff_new_means[face_points$takeoff_idx],
      variation_type = "face_centered"
    )
  }

  # 3. Combine all points
  all_params <- rbind(grid_params, face_centered_params)

  # 4. One configuration per point. Rates are clamped to [0.1, 100], so the
  #    realised mean of a config can differ from the value in `summary` when
  #    the clamp is active.
  clamp_rate <- function(rate) max(0.1, min(100, rate))

  configs <- lapply(seq_len(nrow(all_params)), function(i) {
    new_config <- base_config
    new_config$parameter_ranges$k$rate <-
      clamp_rate(k_shape / all_params$k_new_mean[i])
    new_config$parameter_ranges$L$rate <-
      clamp_rate(L_shape / all_params$L_new_mean[i])

    # Takeoff bounds are re-centred on the new mean (+/- 5, floor of 1)
    new_takeoff_mean <- all_params$takeoff_new_mean[i]
    new_config$parameter_ranges$takeoff$mean <- new_takeoff_mean
    new_config$parameter_ranges$takeoff$min <- max(1, new_takeoff_mean - 5)
    new_config$parameter_ranges$takeoff$max <- new_takeoff_mean + 5

    new_config
  })

  list(
    configurations = configs,
    summary = all_params,
    n_grid = nrow(grid_params),
    n_face_centered = nrow(face_centered_params),
    total_configs = nrow(all_params)
  )
}

#' Plot the mean logistic trajectory implied by each configuration
#'
#' For every configuration the gamma means `shape / rate` of `k` and `L` and
#' the takeoff mean are turned into a logistic curve over years 1 to 50. The
#' inflection point is chosen so that the curve equals 0.01 at the takeoff year.
#'
#' @param configs Non-empty list of configurations, each with
#'   `parameter_ranges$k$shape`, `$k$rate`, `$L$shape`, `$L$rate` and
#'   `$takeoff$mean`.
#' @return A ggplot object with one line per configuration.
plot_trajectory_coverage <- function(configs) {
  # `1:length(configs)` used to turn an empty list into a confusing
  # row-count mismatch; fail early with a clear message instead.
  if (length(configs) == 0) {
    stop("`configs` must contain at least one configuration", call. = FALSE)
  }

  years <- 1:50

  mean_trajectory <- function(config) {
    ranges <- config$parameter_ranges
    k <- ranges$k$shape / ranges$k$rate
    L <- ranges$L$shape / ranges$L$rate
    takeoff <- ranges$takeoff$mean

    # Inflection point such that deployment at `takeoff` is 0.01
    t0 <- takeoff + (1 / k) * log(L / 0.01 - 1)

    L / (1 + exp(-k * (years - t0)))
  }

  # One column per configuration, one row per year
  trajectories <- vapply(configs, mean_trajectory, numeric(length(years)))

  plot_data <- data.frame(
    year = rep(years, times = length(configs)),
    config = rep(seq_along(configs), each = length(years)),
    deployment = as.vector(trajectories)
  )

  library(ggplot2)

  ggplot(plot_data, aes(x = year, y = deployment, group = config)) +
    geom_line(alpha = 0.5, color = "blue") +
    labs(title = "Trajectory Coverage from Parameter Configurations",
         x = "Year",
         y = "Global Deployment") +
    theme_minimal() +
    ylim(0, max(trajectories) * 1.05)
}

#' Plot how a set of configurations covers the k / L / takeoff parameter space
#'
#' @param summary_data Either a data frame with columns `k_new_mean`,
#'   `L_new_mean`, `takeoff_new_mean` and `variation_type`, or a list whose
#'   `summary` element is such a data frame (e.g. the result of
#'   `create_hybrid_coverage()`).
#' @return A list with the ggplot objects `parameter_plot`, `takeoff_plot`,
#'   `k_L_plot`, `k_takeoff_plot`, `L_takeoff_plot`, `trajectory_preview`
#'   (`NULL` when the preview could not be generated) and `data`, the summary
#'   data frame that was plotted.
visualize_parameter_coverage <- function(summary_data) {
  # Accept the full result list as well as the bare summary data frame
  if (is.list(summary_data) && !is.data.frame(summary_data)) {
    summary_data <- summary_data$summary
  }

  # Kept from the original: installs ggplot2 on demand, then attaches it
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    install.packages("ggplot2")
  }
  library(ggplot2)

  # Visual encodings per point type
  point_colors <- c("grid" = "steelblue", "face_centered" = "darkred",
                    "base" = "darkgreen", "corner" = "purple")
  point_shapes <- c("grid" = 16, "face_centered" = 17, "base" = 15, "corner" = 18)
  point_sizes <- c("grid" = 3, "face_centered" = 4, "base" = 5, "corner" = 4)

  # Scale constructors; called once per plot so no scale object is shared
  type_color_scale <- function() {
    scale_color_manual(values = point_colors, name = "Point Type")
  }
  type_shape_size_scales <- function() {
    list(
      scale_shape_manual(values = point_shapes, name = "Point Type"),
      scale_size_manual(values = point_sizes, name = "Point Type")
    )
  }

  # 1. Main parameter coverage plot (row numbers drawn inside the points)
  parameter_plot <- ggplot(summary_data,
                           aes(x = k_new_mean,
                               y = L_new_mean,
                               color = variation_type,
                               shape = variation_type,
                               size = variation_type)) +
    geom_point() +
    geom_text(aes(label = seq_len(nrow(summary_data))),
              color = "white",
              size = 3,
              vjust = 0.35,
              hjust = 0.5) +
    type_color_scale() +
    type_shape_size_scales() +
    labs(
      title = "Parameter distribution mean coverage",
      subtitle = paste0("Displaying ", nrow(summary_data), " configurations with hybrid coverage strategy"),
      x = "Growth Rate (k)",
      y = "Saturation Level (L)"
    ) +
    theme_minimal()

  # 2. Same axes, colored by takeoff year instead of point type
  takeoff_plot <- ggplot(summary_data,
                         aes(x = k_new_mean,
                             y = L_new_mean,
                             color = takeoff_new_mean,
                             shape = variation_type,
                             size = variation_type)) +
    geom_point() +
    scale_color_viridis_c(name = "Takeoff\nYear") +
    type_shape_size_scales() +
    labs(
      title = "Parameter Coverage for Takeoff Year",
      x = "Growth Rate (k)",
      y = "Saturation Level (L)"
    ) +
    theme_minimal()

  # 3. 2D projections of the parameter space; all three share one layout
  projection_plot <- function(mapping, title, x_label, y_label) {
    ggplot(summary_data, mapping) +
      geom_point(aes(color = variation_type, shape = variation_type, size = variation_type)) +
      geom_text(aes(label = seq_len(nrow(summary_data))),
                color = "black",
                size = 3,
                vjust = -1.2) +
      type_color_scale() +
      type_shape_size_scales() +
      labs(title = title, x = x_label, y = y_label) +
      theme_minimal()
  }

  k_L_plot <- projection_plot(
    aes(x = k_new_mean, y = L_new_mean),
    "k vs L Coverage", "Growth Rate (k)", "Saturation Level (L)"
  )
  k_takeoff_plot <- projection_plot(
    aes(x = k_new_mean, y = takeoff_new_mean),
    "k vs Takeoff Coverage", "Growth Rate (k)", "Takeoff Year"
  )
  L_takeoff_plot <- projection_plot(
    aes(x = L_new_mean, y = takeoff_new_mean),
    "L vs Takeoff Coverage", "Saturation Level (L)", "Takeoff Year"
  )

  # 4. Simplified trajectory preview: one logistic curve per configuration,
  #    using the distribution means directly as k, L and takeoff.
  preview_trajectories <- function() {
    years <- 1:50
    if (nrow(summary_data) == 0) {
      stop("no configurations to preview")
    }

    curves <- lapply(seq_len(nrow(summary_data)), function(i) {
      k <- summary_data$k_new_mean[i]
      L <- summary_data$L_new_mean[i]
      takeoff <- summary_data$takeoff_new_mean[i]

      # Inflection point such that deployment at `takeoff` is 0.01
      t0 <- takeoff + (1 / k) * log(L / 0.01 - 1)

      data.frame(
        year = years,
        deployment = L / (1 + exp(-k * (years - t0))),
        config_id = i,
        variation_type = summary_data$variation_type[i]
      )
    })

    # Bind once rather than growing a data frame inside the loop
    do.call(rbind, curves)
  }

  # The preview is best-effort: failures are reported and yield NULL
  trajectory_preview <- tryCatch({
    curves_data <- preview_trajectories()

    ggplot(curves_data, aes(x = year, y = deployment,
                            group = config_id,
                            color = variation_type)) +
      geom_line(alpha = 0.7) +
      scale_color_manual(values = point_colors, name = "Point Type") +
      labs(
        title = "Trajectory Preview Based on Configurations",
        x = "Year",
        y = "Global Deployment",
        color = "Config Type"
      ) +
      theme_minimal()
  }, error = function(e) {
    message("Could not generate trajectory preview: ", e$message)
    NULL
  })

  list(
    parameter_plot = parameter_plot,
    takeoff_plot = takeoff_plot,
    k_L_plot = k_L_plot,
    k_takeoff_plot = k_takeoff_plot,
    L_takeoff_plot = L_takeoff_plot,
    trajectory_preview = trajectory_preview,
    data = summary_data
  )
}


#' Visualize parameter distributions
#'
#' Generates plots to visualize the parameter distributions and their
#' relationships: density curves for up to six evenly spaced variations, a
#' k-vs-L scatter colored by takeoff year, and the logistic growth curve
#' implied by the distribution means of every variation.
#'
#' @param variations Output from create_distribution_variations. This function
#'   reads `variations$summary`, which must have at least one row and the
#'   columns `variation_id`, `scenario`, `k_shape`, `k_new_rate`, `k_new_mean`,
#'   `L_shape`, `L_new_rate`, `L_new_mean`, `takeoff_new_mean`, `takeoff_sd`.
#' @return List of plots and visualization data: `distribution_plot`,
#'   `param_relationships`, `growth_plot`, `distribution_data`, `growth_data`.
visualize_parameter_distributions <- function(variations) {
  summary_data <- variations$summary
  if (is.null(summary_data) || isTRUE(nrow(summary_data) == 0)) {
    stop("`variations$summary` must contain at least one row", call. = FALSE)
  }

  # Select a subset of variations to visualize (to avoid overcrowding)
  n_to_visualize <- min(6, nrow(summary_data))
  indices_to_visualize <- round(seq(1, nrow(summary_data), length.out = n_to_visualize))
  viz_subset <- summary_data[indices_to_visualize, ]

  # Build one tibble of density values per row of `viz_subset`.
  # `x_for(i)` gives the evaluation grid, `density_for(x, i)` the density.
  build_curves <- function(param, mean_col, x_for, density_for) {
    lapply(seq_len(nrow(viz_subset)), function(i) {
      x <- x_for(i)
      tibble(
        x = x,
        density = density_for(x, i),
        variation_id = viz_subset$variation_id[i],
        param = param,
        mean = viz_subset[[mean_col]][i],
        scenario = viz_subset$scenario[i]
      )
    })
  }

  # Gamma-distributed parameters share one x grid: 0 to twice the largest mean
  k_grid <- seq(0, max(viz_subset$k_new_mean) * 2, length.out = 100)
  L_grid <- seq(0, max(viz_subset$L_new_mean) * 2, length.out = 100)

  k_curves <- build_curves(
    "k", "k_new_mean",
    x_for = function(i) k_grid,
    density_for = function(x, i) {
      dgamma(x, shape = viz_subset$k_shape[i], rate = viz_subset$k_new_rate[i])
    }
  )

  L_curves <- build_curves(
    "L", "L_new_mean",
    x_for = function(i) L_grid,
    density_for = function(x, i) {
      dgamma(x, shape = viz_subset$L_shape[i], rate = viz_subset$L_new_rate[i])
    }
  )

  # Takeoff is normal; each curve spans mean +/- 4 sd, truncated at 0
  takeoff_curves <- build_curves(
    "takeoff", "takeoff_new_mean",
    x_for = function(i) {
      seq(max(0, viz_subset$takeoff_new_mean[i] - 4 * viz_subset$takeoff_sd[i]),
          viz_subset$takeoff_new_mean[i] + 4 * viz_subset$takeoff_sd[i],
          length.out = 100)
    },
    density_for = function(x, i) {
      dnorm(x, mean = viz_subset$takeoff_new_mean[i], sd = viz_subset$takeoff_sd[i])
    }
  )

  all_curves <- bind_rows(k_curves, L_curves, takeoff_curves)

  # Facet titles
  param_labels <- c(
    "k" = "Growth Rate (k) Distribution",
    "L" = "Saturation Level (L) Distribution",
    "takeoff" = "Takeoff Year Distribution"
  )

  distribution_plot <- ggplot(all_curves, aes(x = x, y = density, color = factor(variation_id), group = variation_id)) +
    geom_line(linewidth = 1) +
    geom_vline(aes(xintercept = mean, color = factor(variation_id)), linetype = "dashed") +
    facet_wrap(~ param, scales = "free", labeller = labeller(param = param_labels)) +
    labs(
      title = "Parameter Distributions Across Variations",
      subtitle = "Dashed lines indicate distribution means",
      x = "Parameter Value",
      y = "Density",
      color = "Variation ID"
    ) +
    theme_minimal() +
    theme(
      legend.position = "bottom",
      panel.grid.minor = element_blank()
    )

  # Scatter plot to show relationship between parameters
  scatter_data <- summary_data %>%
    select(variation_id, k_new_mean, L_new_mean, takeoff_new_mean, scenario)

  # `label` is not a geom_point aesthetic (it was ignored with a warning),
  # so it is only mapped in geom_text.
  param_relationships <- ggplot(scatter_data) +
    geom_point(aes(x = k_new_mean, y = L_new_mean, color = takeoff_new_mean), size = 4) +
    geom_text(aes(x = k_new_mean, y = L_new_mean, label = variation_id), size = 3, hjust = 0.5, vjust = -1) +
    scale_color_viridis_c(name = "Takeoff Year") +
    labs(
      title = "Parameter Relationships Across Variations",
      x = "Growth Rate (k)",
      y = "Saturation Level (L)"
    ) +
    theme_minimal()

  # Logistic growth curves implied by the distribution means of every variation
  years <- 0:50
  growth_curves <- lapply(seq_len(nrow(summary_data)), function(i) {
    k <- summary_data$k_new_mean[i]
    L <- summary_data$L_new_mean[i]
    # Inflection point such that deployment at the takeoff year is 0.01
    t0 <- summary_data$takeoff_new_mean[i] + (1 / k) * log(L / 0.01 - 1)

    tibble(
      year = years,
      deployment = L / (1 + exp(-k * (years - t0))),
      variation_id = summary_data$variation_id[i],
      scenario = summary_data$scenario[i],
      k = k,
      L = L,
      t0 = t0
    )
  })

  all_growth_curves <- bind_rows(growth_curves)

  growth_plot <- ggplot(all_growth_curves, aes(x = year, y = deployment, color = factor(variation_id), group = variation_id)) +
    geom_line(linewidth = 1) +
    labs(
      title = "Simulated Growth Curves from Parameter Variations",
      x = "Year",
      y = "Deployment Level",
      color = "Variation ID"
    ) +
    theme_minimal() +
    theme(legend.position = "bottom")

  list(
    distribution_plot = distribution_plot,
    param_relationships = param_relationships,
    growth_plot = growth_plot,
    distribution_data = all_curves,
    growth_data = all_growth_curves
  )
}

#' Generate growth parameters for a country simulation
#'
#' Draws `L` and `k` from gamma distributions and the takeoff year from a
#' normal distribution, each by rejection sampling within the `min`/`max`
#' bounds given in `config$parameter_ranges`. The inflection point `t0` is set
#' so that the logistic curve equals 0.01 at the takeoff year. For
#' `"bilogistic"` the saturation level is split into two phases and the whole
#' draw is repeated until all validation checks pass.
#'
#' @param model_type Type of growth model; only `"bilogistic"` adds the
#'   second-phase parameters.
#' @param config Configuration with parameter distributions. Reads
#'   `parameter_ranges$L` and `$k` (`shape`, `rate`, `min`, `max`) and
#'   `parameter_ranges$takeoff` (`mean`, `sd`, `min`, `max`).
#' @param max_tries Maximum rejection-sampling draws per parameter before
#'   falling back to a value near one of the bounds.
#' @param warn_on_bound_clustering Whether to warn when the fallback is used
#'   and more than 5% of a test sample lies beyond a bound.
#' @param max_regenerations Maximum number of times an invalid bilogistic
#'   draw is discarded and redrawn before an error is raised.
#' @return List of growth parameters: `L`, `k`, `takeoff_year`, `t0`, and for
#'   bilogistic additionally `L1`, `k1`, `takeoff_year1`, `t01`,
#'   `takeoff_year2`, `L2`, `k2`, `t02`, `G1`, `G2`.
generate_growth_parameters <- function(model_type, config,
                                       max_tries = 100,
                                       warn_on_bound_clustering = TRUE,
                                       max_regenerations = 10000) {

  # Helper function for rejection sampling within bounds
  sample_within_bounds <- function(sampler_fn, min_val, max_val, parameter_name) {
    for (i in seq_len(max_tries)) {
      value <- sampler_fn()
      if (value >= min_val && value <= max_val) {
        return(value)
      }
    }

    # No in-bounds draw after max_tries: estimate how much of the
    # distribution lies outside each bound from a fresh test sample
    test_sample <- replicate(1000, sampler_fn())
    pct_below <- mean(test_sample < min_val) * 100
    pct_above <- mean(test_sample > max_val) * 100

    if (warn_on_bound_clustering && (pct_below > 5 || pct_above > 5)) {
      warning(sprintf("Distribution for %s has %.1f%% probability below min and %.1f%% above max, ",
                      parameter_name, pct_below, pct_above),
              "suggesting potential clustering at bounds. Consider adjusting distribution parameters.")
    }

    # Fall back to the bound on the heavier side, scaled by up to 1% so that
    # repeated fallbacks do not all land on exactly the same value
    if (pct_below > pct_above) {
      return(min_val * (1 + runif(1, 0, 0.01)))
    } else {
      return(max_val * (1 - runif(1, 0, 0.01)))
    }
  }

  # Samplers for each parameter
  L_sampler <- function() {
    rgamma(1, shape = config$parameter_ranges$L$shape,
           rate = config$parameter_ranges$L$rate)
  }

  k_sampler <- function() {
    rgamma(1, shape = config$parameter_ranges$k$shape,
           rate = config$parameter_ranges$k$rate)
  }

  takeoff_sampler <- function() {
    rnorm(1, mean = config$parameter_ranges$takeoff$mean,
          sd = config$parameter_ranges$takeoff$sd)
  }

  # Draw L, k, takeoff_year (in that order) and derive t0
  draw_base_parameters <- function() {
    params <- list(
      L = sample_within_bounds(
        L_sampler,
        config$parameter_ranges$L$min,
        config$parameter_ranges$L$max,
        "L"
      ),

      k = sample_within_bounds(
        k_sampler,
        config$parameter_ranges$k$min,
        config$parameter_ranges$k$max,
        "k"
      ),

      takeoff_year = sample_within_bounds(
        takeoff_sampler,
        config$parameter_ranges$takeoff$min,
        config$parameter_ranges$takeoff$max,
        "takeoff_year"
      )
    )

    # Inflection point such that deployment at takeoff_year is 0.01
    params$t0 <- round(params$takeoff_year + (1/params$k) * log(params$L/0.01 - 1), digits = 2)
    params
  }

  # Extend base parameters with the two bilogistic phases.
  # Returns NULL (after a message) when a validation check fails.
  add_bilogistic_phases <- function(params) {
    min_phase_L <- config$parameter_ranges$L$min * 1.1

    # First phase takes 20-80% of the total saturation level
    L1_fraction <- runif(1, 0.20, 0.80)
    params$L1 <- params$L * L1_fraction
    params$k1 <- params$k
    params$takeoff_year1 <- params$takeoff_year

    # VALIDATION CHECK 1: L1 must not be too close to the minimum L
    if (params$L1 < min_phase_L) {
      message(sprintf("Invalid L1 value < %s, regenerating parameters",
                      format(signif(min_phase_L, 3))))
      return(NULL)
    }

    # First inflection point
    params$t01 <- round(params$takeoff_year1 + (1/params$k1) * log(params$L1/0.01 - 1), digits = 2)

    # VALIDATION CHECK 2: t01 must not be negative
    if (params$t01 < 0) {
      message("Invalid t01 outside simulation range, regenerating parameters")
      return(NULL)
    }

    # Maximum growth rate of the first phase (logistic slope at inflection)
    G1 <- (params$L1 * params$k1) / 4

    # Second phase takes off 1-10 time units after the first inflection point
    delay <- runif(1, 1, 10)
    params$takeoff_year2 <- params$t01 + delay

    # (Check 3, second phase starting too late, is intentionally disabled.)

    params$L2 <- params$L * (1 - L1_fraction)

    # VALIDATION CHECK 4: same lower bound for L2
    if (params$L2 < min_phase_L) {
      message(sprintf("Invalid L2 value < %s, regenerating parameters",
                      format(signif(min_phase_L, 3))))
      return(NULL)
    }

    # Second-phase maximum growth rate: G1 times a lognormal factor whose
    # meanlog of -sdlog^2 / 2 gives the factor an expected value of 1
    G2 <- G1 * rlnorm(1, meanlog = -0.35^2/2, sdlog = 0.35)

    params$k2 <- (4 * G2) / params$L2

    # VALIDATION CHECK 5: k2 within reasonable bounds
    if (params$k2 < 0.05 || params$k2 > 1.5) {
      message("Invalid k2 value outside reasonable range, regenerating parameters")
      return(NULL)
    }

    # Second inflection point
    params$t02 <- round(params$takeoff_year2 + (1/params$k2) * log(params$L2/0.01 - 1), digits = 2)

    # (Checks 6 and 7, on t02 and on L1 + L2, are intentionally disabled.)

    params$G1 <- G1
    params$G2 <- G2
    params
  }

  # Redraw in a loop rather than by self-recursion, so that a configuration
  # that rarely validates cannot exhaust R's expression nesting limit.
  regenerations <- 0
  repeat {
    params <- draw_base_parameters()

    if (!(model_type == "bilogistic")) {
      return(params)
    }

    bilogistic_params <- add_bilogistic_phases(params)
    if (!is.null(bilogistic_params)) {
      return(bilogistic_params)
    }

    regenerations <- regenerations + 1
    if (regenerations > max_regenerations) {
      stop(sprintf(
        "No valid bilogistic parameters after %d regeneration attempts; check the parameter ranges",
        as.integer(max_regenerations)
      ), call. = FALSE)
    }
  }
}



#--------------------
# 3. Simulation Functions
#--------------------

#' Simulate logistic growth
#'
#' Deployment is 0 before the (rounded) takeoff year and follows
#' `L / (1 + exp(-k * (t - t0)))` from then on, multiplied by `1 + noise`
#' where `noise` is a correlated series, and finally clamped to
#' `[0, 1.1 * L]`.
#'
#' @param t Vector of time points
#' @param params Growth parameters; reads `L`, `k`, `takeoff_year`, `t0`
#' @param noise_params Parameters controlling noise: `amplitude` (sd of the
#'   first shock, relative to the current value), `correlation` (weight of the
#'   previous year's noise) and `decay` (per-step multiplier on the amplitude)
#' @return Vector of deployment values, same length as `t`
#'
simulate_logistic <- function(t, params, noise_params = list(
  amplitude = 0.05,     # Base noise amplitude (as fraction of current value)
  correlation = 0.7,    # Year-to-year correlation (0-1)
  decay = 0.98         # How noise amplitude decreases over time
)) {
  L <- params$L
  k <- params$k
  takeoff_year <- round(params$takeoff_year)
  t0 <- params$t0

  # Zero growth before takeoff
  y <- rep(0, length(t))

  post_takeoff <- t >= takeoff_year
  if (any(post_takeoff)) {
    # Noise-free logistic values
    base_values <- L / (1 + exp(-k * (t[post_takeoff] - t0)))

    # Correlated noise: each step mixes the previous value with a fresh
    # shock whose amplitude decays over time
    n_years <- sum(post_takeoff)
    noise <- numeric(n_years)
    noise[1] <- rnorm(1, 0, noise_params$amplitude)

    # seq_len(n_years - 1) + 1 is empty when n_years == 1; `2:n_years` would
    # have iterated over c(2, 1) and failed
    for (i in seq_len(n_years - 1) + 1) {
      current_amplitude <- noise_params$amplitude * (noise_params$decay ^ (i - 1))
      noise[i] <- noise_params$correlation * noise[i - 1] +
        (1 - noise_params$correlation) * rnorm(1, 0, current_amplitude)
    }

    # Apply noise multiplicatively, then keep values within [0, 1.1 * L]
    y[post_takeoff] <- pmin(pmax(base_values * (1 + noise), 0), L * 1.1)
  }

  y
}

#' Simulate logistic-linear growth
#'
#' Deployment is 0 before the (rounded) takeoff year, logistic up to the
#' inflection point `t0`, and linear with slope `G = L * k / 4` afterwards
#' (capped at `L`). The result is multiplied by `1 + noise`, where `noise` is
#' a correlated series, and clamped to `[0, 1.1 * L]`.
#'
#' @param t Vector of time points
#' @param params Growth parameters; reads `L`, `k`, `takeoff_year`, `t0`
#' @param noise_params Parameters controlling noise: `amplitude`,
#'   `correlation` and `decay`
#' @return Vector of deployment values, same length as `t`
simulate_logistic_linear <- function(t, params, noise_params = list(
  amplitude = 0.05,
  correlation = 0.7,
  decay = 0.98
)) {
  L <- params$L
  k <- params$k
  takeoff_year <- round(params$takeoff_year)
  t0 <- params$t0
  # Slope of the logistic curve at its inflection point
  G <- (L * k) / 4

  # Zero growth before takeoff
  y <- rep(0, length(t))

  post_takeoff <- t >= takeoff_year
  if (any(post_takeoff)) {
    t_post <- t[post_takeoff]
    base_values <- numeric(length(t_post))

    # Up to and including the inflection point: logistic growth
    pre_t0 <- t_post <= t0
    if (any(pre_t0)) {
      base_values[pre_t0] <- L / (1 + exp(-k * (t_post[pre_t0] - t0)))
    }

    # After the inflection point: linear growth from L / 2, capped at L
    post_t0 <- t_post > t0
    if (any(post_t0)) {
      base_values[post_t0] <- pmin(L, L / 2 + G * (t_post[post_t0] - t0))
    }

    # Correlated noise for the whole post-takeoff period
    n_years <- sum(post_takeoff)
    noise <- numeric(n_years)
    noise[1] <- rnorm(1, 0, noise_params$amplitude)

    # seq_len(n_years - 1) + 1 is empty when n_years == 1; `2:n_years` would
    # have iterated over c(2, 1) and failed
    for (i in seq_len(n_years - 1) + 1) {
      current_amplitude <- noise_params$amplitude * (noise_params$decay ^ (i - 1))
      noise[i] <- noise_params$correlation * noise[i - 1] +
        (1 - noise_params$correlation) * rnorm(1, 0, current_amplitude)
    }

    # Apply noise, then keep values within [0, 1.1 * L]
    y[post_takeoff] <- pmin(pmax(base_values * (1 + noise), 0), L * 1.1)
  }

  y
}

#' Simulate bilogistic growth
#'
#' Sum of two logistic phases: the first contributes from `takeoff_year1`
#' on, the second is added from `takeoff_year2` on. From the first takeoff
#' onwards the trajectory is multiplied by `1 + noise`, where `noise` is a
#' correlated series, and the whole vector is clamped to `[0, 1.1 * L]`.
#'
#' @param t Vector of time points
#' @param params Growth parameters; reads `L`, `L1`, `k1`, `takeoff_year1`,
#'   `t01`, `L2`, `k2`, `takeoff_year2`, `t02`
#' @param noise_params Parameters controlling noise: `amplitude`,
#'   `correlation` and `decay`
#' @return Vector of deployment values, same length as `t`
simulate_bilogistic <- function(t, params, noise_params = list(
  amplitude = 0.05,
  correlation = 0.7,
  decay = 0.98
)) {
  # Zero growth before first takeoff
  y <- rep(0, length(t))

  # First phase
  post_takeoff1 <- t >= params$takeoff_year1
  if (any(post_takeoff1)) {
    y[post_takeoff1] <- params$L1 / (1 + exp(-params$k1 * (t[post_takeoff1] - params$t01)))
  }

  # Second phase, added on top of the first
  post_takeoff2 <- t >= params$takeoff_year2
  if (any(post_takeoff2)) {
    y2 <- params$L2 / (1 + exp(-params$k2 * (t[post_takeoff2] - params$t02)))
    y[post_takeoff2] <- y[post_takeoff2] + y2
  }

  # Apply noise to the entire trajectory after first takeoff
  if (any(post_takeoff1)) {
    n_years <- sum(post_takeoff1)
    noise <- numeric(n_years)
    noise[1] <- rnorm(1, 0, noise_params$amplitude)

    # seq_len(n_years - 1) + 1 is empty when n_years == 1; `2:n_years` would
    # have iterated over c(2, 1) and failed
    for (i in seq_len(n_years - 1) + 1) {
      # Shock amplitude decays with each step; there is no special handling
      # of the transition between the two phases
      current_amplitude <- noise_params$amplitude * (noise_params$decay ^ (i - 1))

      noise[i] <- noise_params$correlation * noise[i - 1] +
        (1 - noise_params$correlation) * rnorm(1, 0, current_amplitude)
    }

    y[post_takeoff1] <- y[post_takeoff1] * (1 + noise)

    # Keep values within [0, 1.1 * L]
    y <- pmin(pmax(y, 0), params$L * 1.1)
  }

  y
}

#' Simulate country-level technology diffusion
#'
#' Draws growth parameters for the given model, simulates the deployment
#' trajectory on a time axis shifted so that the first year is 0, and returns
#' one row per year with the drawn parameters repeated in every row.
#'
#' @param years Vector of years
#' @param model_type Type of growth model: `"logistic"`, `"logistic-linear"`
#'   or `"bilogistic"`
#' @param market_size Market size of the country
#' @param config Configuration with parameter distributions
#' @return Data frame with country simulation results: `year`, `deployment`,
#'   `market_size`, `weighted_deployment`, `model_type` and one column per
#'   growth parameter
simulate_country <- function(years, model_type, market_size, config) {
  # Resolve the growth function first, so that an unsupported model type
  # fails with a clear message instead of "attempt to apply non-function"
  growth_fn <- switch(model_type,
                      "logistic" = simulate_logistic,
                      "logistic-linear" = simulate_logistic_linear,
                      "bilogistic" = simulate_bilogistic,
                      stop(sprintf("Unknown model_type '%s'", model_type),
                           call. = FALSE))

  # Generate parameters
  params <- generate_growth_parameters(model_type, config)

  # Deployment trajectory on a time axis starting at 0
  deployment <- growth_fn(years - min(years), params)

  # Single-row data frame of parameters; bind_cols recycles it to all years
  param_columns <- as.data.frame(t(unlist(params)))

  data.frame(
    year = years,
    deployment = deployment,
    market_size = market_size,
    weighted_deployment = deployment * market_size,
    model_type = model_type,
    stringsAsFactors = FALSE
  ) %>%
    bind_cols(param_columns)
}

#' Simulate global technology diffusion
#'
#' Assigns a growth model to every country by sampling from
#' `config$available_models` with weights `config$model_weights`, simulates
#' each country over years `1..config$n_years`, and sums the market-weighted
#' deployment per year.
#'
#' @param market_sizes Vector of country market sizes (shares)
#' @param config Configuration with parameter distributions; also reads
#'   `n_years`, `available_models` and `model_weights`
#' @return List containing country and global trajectories
#'   (`country_trajectories`, `global_trajectory`)
simulate_global_growth <- function(market_sizes, config) {
  simulation_years <- seq_len(config$n_years)

  # One growth model per country, drawn with replacement
  models_by_country <- sample(
    config$available_models,
    length(market_sizes),
    replace = TRUE,
    prob = config$model_weights
  )

  # Simulate every country and stack the results, tagging rows by country
  run_country <- function(model, size) {
    simulate_country(simulation_years, model, size, config)
  }
  country_sims <- map2_dfr(
    models_by_country,
    market_sizes,
    run_country,
    .id = "country"
  )

  # Per-year totals and the number of countries using each model
  sims_by_year <- group_by(country_sims, year)
  global_trajectory <- summarise(
    sims_by_year,
    global_deployment = sum(weighted_deployment, na.rm = TRUE),
    n_logistic = sum(model_type == "logistic"),
    n_logistic_linear = sum(model_type == "logistic-linear"),
    n_bilogistic = sum(model_type == "bilogistic"),
    .groups = "drop"
  )

  list(
    country_trajectories = country_sims,
    global_trajectory = global_trajectory
  )
}

#--------------------
# 4. Multi-Configuration Monte Carlo Engine
#--------------------

#' Run Monte Carlo simulation with proper unique run IDs
#'
#' Runs `config$n_runs` simulations through the currently registered foreach
#' backend, retries failed runs in batches, and summarises the global
#' trajectories per year.
#'
#' @param config Configuration with parameter distributions; `n_runs` is the
#'   target number of successful runs
#' @param market_data Market size data; `market_data$market_sizes` is passed
#'   to `simulate_global_growth()`
#' @param config_id ID for this configuration (for tracking). When supplied,
#'   run IDs have the form "<config_id>_<run number>" and a `config_id`
#'   column is added to the results
#' @param run_id_offset Starting number for run IDs (to ensure uniqueness)
#' @param max_attempts Maximum number of attempted runs; additionally capped
#'   at five times `config$n_runs`
#' @param target_success_ratio Target ratio of successful to attempted runs;
#'   retries stop once a batch succeeds at less than a third of this ratio
#' @return List with `summary` (per-year mean, median, and the 5th and 95th
#'   percentiles stored as `lower_95` / `upper_95`), `global_results`,
#'   `country_results` and `n_successful_runs`. If no run succeeds, a warning
#'   is raised and empty tables are returned with `n_successful_runs = 0`.

run_monte_carlo <- function(config, market_data, config_id = NULL, run_id_offset = 0,
                            max_attempts = 10000, target_success_ratio = 0.9) {
  # Use the parallel cluster already created in run_multi_config_simulations

  # Target number of successful runs
  target_runs <- config$n_runs
  if (!is.numeric(target_runs) || length(target_runs) != 1 ||
      is.na(target_runs) || target_runs < 1) {
    stop("config$n_runs must be a single positive number", call. = FALSE)
  }
  max_attempts <- min(max_attempts, target_runs * 5)  # Safety cap to prevent excessive attempts

  # Worker function: simulate every run number in `chunk` and return the
  # successful simulations only
  chunk_func <- function(chunk, config, market_data, config_id, run_id_offset) {
    chunk_results <- vector("list", length(chunk))
    for (i in seq_along(chunk)) {
      # Create a truly unique run ID that combines config_id and run number
      global_run_id <- run_id_offset + chunk[i]

      if (!is.null(config_id)) {
        # Format: configID_globalRunID (e.g., "1_101")
        unique_run_id <- paste(config_id, global_run_id, sep = "_")
      } else {
        unique_run_id <- global_run_id
      }

      tryCatch({
        sim <- simulate_global_growth(market_data$market_sizes, config)

        # Tag both tables with the run ID (and config_id if provided)
        if (!is.null(config_id)) {
          sim$global_trajectory <- sim$global_trajectory %>%
            mutate(run = unique_run_id, config_id = config_id)

          sim$country_trajectories <- sim$country_trajectories %>%
            mutate(run = unique_run_id, config_id = config_id)
        } else {
          sim$global_trajectory <- sim$global_trajectory %>% mutate(run = unique_run_id)
          sim$country_trajectories <- sim$country_trajectories %>% mutate(run = unique_run_id)
        }

        chunk_results[[i]] <- sim
      }, error = function(e) {
        message(sprintf("Error in run %s: %s", unique_run_id, e$message))
        NULL
      })
    }
    # Remove NULL entries (failed runs); vapply stays logical for empty input
    chunk_results[!vapply(chunk_results, is.null, logical(1))]
  }

  # detectCores() may return NA or 1; never let the divisor drop below 1
  detected_cores <- parallel::detectCores()
  n_cores <- if (is.na(detected_cores)) 1L else max(1L, detected_cores - 1L)

  # PHASE 1: Try the original approach first with a more modest batch size
  # Use fewer chunks to reduce overhead
  chunk_size <- ceiling(target_runs / min(n_cores, 16))  # Limit chunks to avoid overhead
  run_numbers <- seq_len(target_runs)
  chunks <- split(run_numbers, ceiling(run_numbers / chunk_size))

  message(sprintf("Starting simulation with %d chunks of size ~%d", length(chunks), chunk_size))

  results <- foreach(chunk = chunks,
                     .packages = c("tidyverse"),
                     .combine = 'c',
                     .multicombine = TRUE,
                     .errorhandling = 'remove') %dopar% {
                       chunk_func(chunk, config, market_data, config_id, run_id_offset)
                     }

  # Count successful runs
  successful_runs <- length(results)
  message(sprintf("Initial pass: %d/%d successful runs (%.1f%%)",
                  successful_runs, target_runs, 100*successful_runs/target_runs))

  # PHASE 2: If we don't have enough runs, use a more controlled retry approach
  if (successful_runs < target_runs) {
    message("Insufficient successful runs, continuing with additional attempts...")

    # Keep track of attempts to avoid duplicating run IDs
    current_offset <- run_id_offset + target_runs
    remaining_runs <- target_runs - successful_runs
    attempts_made <- target_runs

    # Use a fixed batch size for retries rather than the dynamic calculation
    fixed_batch_size <- min(remaining_runs * 2, 500)  # Reasonable fixed size

    while (successful_runs < target_runs && attempts_made < max_attempts) {
      # Never exceed the remaining attempt budget
      batch_size <- min(fixed_batch_size, max_attempts - attempts_made)

      message(sprintf("Attempting %d additional runs", batch_size))

      # Create new chunks - smaller chunks for better load balancing
      new_chunk_size <- ceiling(batch_size / min(n_cores, 8))
      batch_numbers <- seq_len(batch_size)
      new_chunks <- split(batch_numbers, ceiling(batch_numbers / new_chunk_size))

      additional_results <- foreach(chunk = new_chunks,
                                    .packages = c("tidyverse"),
                                    .combine = 'c',
                                    .multicombine = TRUE,
                                    .errorhandling = 'remove') %dopar% {
                                      chunk_func(chunk, config, market_data, config_id, current_offset)
                                    }

      # Add new results and update tracking
      new_successful <- length(additional_results)
      results <- c(results, additional_results)
      successful_runs <- successful_runs + new_successful
      attempts_made <- attempts_made + batch_size
      current_offset <- current_offset + batch_size

      message(sprintf("Progress: %d/%d successful runs (%.1f%%)",
                      successful_runs, target_runs, 100*successful_runs/target_runs))

      # Stop quietly once the target or the attempt budget is reached, so the
      # low-success message below is only shown when retries are abandoned
      if (successful_runs >= target_runs || attempts_made >= max_attempts) break

      # Check if we're making reasonable progress
      current_success_ratio <- new_successful / batch_size
      if (current_success_ratio < target_success_ratio / 3) {
        message(sprintf("Low success rate (%.1f%%), may indicate configuration issues. Stopping retries.",
                        100*current_success_ratio))
        break
      }

      # Force garbage collection after each batch
      gc(verbose = FALSE)
    }

    # Trim excess successful runs if we have more than requested
    if (successful_runs > target_runs) {
      message(sprintf("Trimming excess runs: %d → %d", successful_runs, target_runs))
      results <- results[seq_len(target_runs)]
      successful_runs <- target_runs
    }
  }

  # Nothing succeeded: return empty tables instead of failing further down
  # on missing columns and a 0/0 comparison
  if (length(results) == 0) {
    warning("No successful simulation runs; returning empty results", call. = FALSE)

    empty_run <- if (is.null(config_id)) numeric(0) else character(0)
    empty_global <- tibble::tibble(
      year = integer(0),
      global_deployment = numeric(0),
      run = empty_run
    )
    empty_country <- tibble::tibble(
      country = character(0),
      year = integer(0),
      deployment = numeric(0),
      market_size = numeric(0),
      weighted_deployment = numeric(0),
      model_type = character(0),
      run = empty_run
    )
    if (!is.null(config_id)) {
      empty_global$config_id <- config_id[0]
      empty_country$config_id <- config_id[0]
    }

    return(list(
      summary = tibble::tibble(
        year = integer(0),
        mean_deployment = numeric(0),
        median_deployment = numeric(0),
        lower_95 = numeric(0),
        upper_95 = numeric(0)
      ),
      global_results = empty_global,
      country_results = empty_country,
      n_successful_runs = 0L
    ))
  }

  # Process results
  global_results <- bind_rows(lapply(results, function(x) x$global_trajectory))
  country_results <- bind_rows(lapply(results, function(x) x$country_trajectories))

  # Verify uniqueness of run IDs: every run should contribute one row per year
  unique_runs <- unique(global_results$run)
  years_per_run <- length(unique(global_results$year))
  expected_unique_runs <- nrow(global_results) / years_per_run

  if (length(unique_runs) != expected_unique_runs) {
    warning("Run IDs are not unique within global_results!")
    message(paste("Unique run IDs:", length(unique_runs)))
    message(paste("Expected unique runs:", expected_unique_runs))
  }

  # Calculate summary statistics
  # NOTE: lower_95 / upper_95 hold the 5th and 95th percentiles
  summary_stats <- global_results %>%
    group_by(year) %>%
    summarise(
      mean_deployment = mean(global_deployment, na.rm = TRUE),
      median_deployment = median(global_deployment, na.rm = TRUE),
      lower_95 = quantile(global_deployment, 0.05, na.rm = TRUE),
      upper_95 = quantile(global_deployment, 0.95, na.rm = TRUE),
      .groups = "drop"
    )

  list(
    summary = summary_stats,
    global_results = global_results,
    country_results = country_results,
    n_successful_runs = length(unique_runs)
  )
}

# Collect the distribution settings of one configuration as a named list of
# model features (distribution parameters, derived means, model weights).
.extract_config_features <- function(config) {
  ranges <- config$parameter_ranges

  features <- list(
    # Growth distribution parameters
    k_dist_shape = ranges$k$shape,
    k_dist_rate = ranges$k$rate,
    k_dist_min = ranges$k$min,
    k_dist_max = ranges$k$max,

    L_dist_shape = ranges$L$shape,
    L_dist_rate = ranges$L$rate,
    L_dist_min = ranges$L$min,
    L_dist_max = ranges$L$max,

    takeoff_dist_mean = ranges$takeoff$mean,
    takeoff_dist_sd = ranges$takeoff$sd,
    takeoff_dist_min = ranges$takeoff$min,
    takeoff_dist_max = ranges$takeoff$max
  )

  # Add calculated means (shape / rate) for easier interpretation
  features$k_dist_mean <- features$k_dist_shape / features$k_dist_rate
  features$L_dist_mean <- features$L_dist_shape / features$L_dist_rate

  # Add model type weights (positional: logistic, logistic-linear, bilogistic)
  features$logistic_weight <- config$model_weights[1]
  features$logistic_linear_weight <- config$model_weights[2]
  features$bilogistic_weight <- config$model_weights[3]

  features
}

# Zero-row stand-in used when a configuration produced no usable results.
.empty_config_results <- function() {
  list(
    global_results = tibble(
      year = integer(0),
      global_deployment = numeric(0),
      run = character(0),
      config_id = integer(0)
    ),
    country_results = tibble(
      year = integer(0),
      deployment = numeric(0),
      market_size = numeric(0),
      weighted_deployment = numeric(0),
      model_type = character(0),
      country = character(0),
      run = character(0),
      config_id = integer(0)
    ),
    n_successful_runs = 0
  )
}

#' Run simulations across multiple configurations with guaranteed unique run IDs
#'
#' Creates one parallel cluster, runs `run_monte_carlo()` for every
#' configuration (retrying configurations that yield too few runs) and stacks
#' the results. The cluster is stopped and the sequential foreach backend is
#' re-registered when the function exits.
#'
#' @param configs List of configuration objects (must not be empty)
#' @param market_data Market size data
#' @param runs_per_config Number of runs for each configuration
#' @param batch_size Number of configurations to process in each batch
#' @param memory_threshold Memory threshold in MB to trigger garbage collection
#' @param max_retries Maximum number of attempts per configuration; an attempt
#'   is accepted when at least 80% of `runs_per_config` runs succeeded
#' @return List with combined simulation results (`global_results`,
#'   `country_results`, `config_features`)
run_multi_config_simulations <- function(configs, market_data, 
                                         runs_per_config = 100,
                                         batch_size = 5,
                                         memory_threshold = 5000,
                                         max_retries = 3) {
  total_configs <- length(configs)
  if (total_configs == 0) {
    stop("configs must contain at least one configuration", call. = FALSE)
  }

  # Setup parallel processing ONCE for the entire function.
  # detectCores() may return NA or 1, so always keep at least one worker.
  detected_cores <- parallel::detectCores()
  n_cores <- if (is.na(detected_cores)) 1L else max(1L, detected_cores - 1L)
  cl <- makeCluster(n_cores)
  # Ensure the cluster is stopped when the function exits and that foreach is
  # not left pointing at the stopped cluster
  on.exit({
    stopCluster(cl)
    foreach::registerDoSEQ()
  }, add = TRUE)
  registerDoParallel(cl)

  # Export necessary functions to all workers once
  clusterExport(cl, c("simulate_global_growth",
                      "simulate_country",
                      "simulate_logistic",
                      "simulate_logistic_linear",
                      "simulate_bilogistic",
                      "generate_growth_parameters"),
                envir = environment())

  # Initialize result containers
  all_global_results <- list()
  all_country_results <- list()
  all_config_features <- list()

  # Create batches of configurations to limit memory usage
  config_indices <- seq_len(total_configs)
  batch_indices <- split(config_indices, ceiling(config_indices / batch_size))

  # Process each batch separately
  for (batch_idx in seq_along(batch_indices)) {
    batch <- batch_indices[[batch_idx]]
    message(paste("Processing batch", batch_idx, "of", length(batch_indices), 
                  "- configs", min(batch), "to", max(batch)))

    # One slot per configuration in this batch
    batch_results <- vector("list", length(batch))

    for (slot in seq_along(batch)) {
      i <- batch[[slot]]

      # Set number of runs for this config
      current_config <- configs[[i]]
      current_config$n_runs <- runs_per_config

      # Calculate run ID offset for this configuration
      # This ensures run IDs don't overlap between configurations
      run_id_offset <- (i - 1) * runs_per_config

      # Check memory usage and garbage collect if needed.
      # gc()[2, 2] is the "(Mb)" column of the Vcells row.
      mem_used <- as.numeric(gc(reset = TRUE)[2, 2])
      if (mem_used > memory_threshold) {
        message("Memory usage high (", round(mem_used), "MB). Running garbage collection...")
        gc(verbose = FALSE, full = TRUE)
      }

      # Run Monte Carlo with simplified retry logic
      message(paste("Running configuration", i, "of", total_configs))

      successful_results <- NULL
      attempt <- 1

      # Retry logic with fixed number of attempts
      while (is.null(successful_results) && attempt <= max_retries) {
        if (attempt > 1) {
          message(paste("Retry attempt", attempt, "for configuration", i))
        }

        tryCatch({
          results <- run_monte_carlo(
            current_config,
            market_data,
            config_id = i,
            run_id_offset = run_id_offset
          )

          # Verify reasonable number of successful runs
          if (results$n_successful_runs >= runs_per_config * 0.8) {
            successful_results <- results
          } else {
            message(paste("Too few successful runs:", results$n_successful_runs, 
                          "- retrying"))
          }
        }, error = function(e) {
          message(paste("Error in configuration", i, ":", e$message))
        })

        attempt <- attempt + 1
      }

      # If still no successful results after all retries, create empty placeholder
      if (is.null(successful_results)) {
        message(paste("Failed to get results for configuration", i, "after", 
                      max_retries, "attempts. Creating empty placeholder."))
        successful_results <- .empty_config_results()
      }

      # Store results with config info
      batch_results[[slot]] <- list(
        global_results = successful_results$global_results,
        country_results = successful_results$country_results,
        config_id = i,
        config_features = .extract_config_features(current_config)
      )

      # Verify uniqueness of run IDs
      message(paste("Runs in this configuration:", successful_results$n_successful_runs))
      unique_runs_this_config <- unique(successful_results$global_results$run)
      message(paste("Unique run IDs in this configuration:", length(unique_runs_this_config)))

      # Force garbage collection to free memory
      gc(verbose = FALSE)
    }

    # Process batch results - extract and add to main result containers
    for (result in batch_results) {
      if (is.null(result)) next  # Skip NULL results

      # Add to global results list
      all_global_results[[length(all_global_results) + 1]] <- result$global_results %>%
        mutate(config_id = result$config_id)

      # Add to country results list
      all_country_results[[length(all_country_results) + 1]] <- result$country_results %>%
        mutate(config_id = result$config_id)

      # Add config features
      all_config_features[[length(all_config_features) + 1]] <- tibble(
        config_id = result$config_id
      ) %>% bind_cols(as_tibble(result$config_features))
    }

    # Clear batch results to free memory
    batch_results <- NULL
    gc(verbose = FALSE, full = TRUE)

    message(paste("Completed batch", batch_idx, "of", length(batch_indices)))
  }

  # Combine results while managing memory
  message("Combining all global results...")
  combined_global_results <- bind_rows(all_global_results)
  all_global_results <- NULL  # Free memory
  gc(verbose = FALSE)

  message("Combining all country results...")
  combined_country_results <- bind_rows(all_country_results)
  all_country_results <- NULL  # Free memory
  gc(verbose = FALSE)

  message("Combining all configuration features...")
  combined_config_features <- bind_rows(all_config_features)
  all_config_features <- NULL  # Free memory
  gc(verbose = FALSE)

  # Perform final uniqueness check on the combined results
  unique_runs_total <- unique(combined_global_results$run)
  message(paste("Total unique run IDs across all configurations:", length(unique_runs_total)))

  # A run ID is duplicated when it has more rows than distinct years
  run_counts <- combined_global_results %>%
    group_by(run) %>%
    summarize(
      num_years = length(unique(year)),
      total_rows = n(),
      .groups = "drop"
    ) %>%
    mutate(expected_rows = num_years) %>%
    filter(total_rows > expected_rows)

  if (nrow(run_counts) > 0) {
    warning("Potential duplicate run IDs detected!")
    message(paste("Duplicate run IDs found:", nrow(run_counts)))
    print(head(run_counts))
  } else {
    message("No duplicate run IDs detected in final results.")
  }

  list(
    global_results = combined_global_results,
    country_results = combined_country_results,
    config_features = combined_config_features
  )
}
#--------------------
# 5. Analysis Functions
#--------------------

## Define all helper functions at the top level so they can be properly exported

#' Estimate growth parameters from national data
#'
#' Fits an S-curve (via `fit_curve(..., fit = "S")`) to one country's
#' deployment series and returns the fitted parameters when the fit is
#' flagged as good and sufficiently mature.
#'
#' @param data Country data with `year`, `deployment` and `market_size`
#'   columns
#' @param min_points Minimum number of data points required
#' @param min_deployment Minimum peak deployment required before a fit is
#'   attempted
#' @param min_maturity Minimum `Maturity` reported by the fit for it to be
#'   accepted
#' @return Tibble with estimated parameters (`k`, `L`, `G`, `Maturity`, the
#'   last observed `deployment` and `market_size`), or an empty tibble when
#'   the data are insufficient, the fit is rejected, or fitting fails
estimate_national_parameters_local <- function(data, min_points = 5,
                                               min_deployment = 0.01,
                                               min_maturity = 0.3) {
  # Ensure we return a tibble even if conditions aren't met.
  # Missing values are dropped first so an all-NA series exits cleanly.
  observed <- data$deployment[!is.na(data$deployment)]
  if (nrow(data) < min_points ||
      length(observed) == 0 ||
      max(observed) < min_deployment) {
    return(tibble::tibble())
  }

  fit_data <- data.frame(
    Year = data$year,
    Value = data$deployment
  )

  tryCatch({
    fit <- fit_curve(fit_data, fit = "S")

    # isTRUE() treats NA or malformed flags as a rejected fit
    accepted <- !is.null(fit) &&
      isTRUE(fit$Good == 1) &&
      isTRUE(fit$Maturity >= min_maturity)

    if (accepted) {
      tibble::tibble(
        k = fit$K,
        L = fit$L,
        G = fit$G,
        Maturity = fit$Maturity,
        deployment = tail(data$deployment, 1),
        market_size = unique(data$market_size)
      )
    } else {
      tibble::tibble()
    }
  }, error = function(e) {
    # Any fitting failure counts as "no estimate"
    tibble::tibble()
  })
}

#' Process a single cutoff year for curtailment analysis
#'
#' Truncates the country trajectories at `cutoff_year`, estimates growth
#' parameters per country within every run, and summarises them per run.
#'
#' @param cutoff_year Cutoff year to analyze
#' @param country_results Country results data (needs `run`, `country`,
#'   `year`, `deployment` and `market_size` columns)
#' @param global_results Global results data (needs `run`, `year` and
#'   `global_deployment` columns)
#' @param debug_mode Whether to output debug information
#' @param min_countries Minimum countries required per run
#' @param min_points Minimum number of data points per country, passed to
#'   `estimate_national_parameters_local()`
#' @return Data frame with one row per run holding the median and quartiles
#'   of `k`, `L` and `G`, country counts, the covered market share, the
#'   cutoff year and the global deployment at the cutoff; `NULL` when no run
#'   yields enough country estimates
process_cutoff_year_local <- function(cutoff_year, country_results, global_results, 
                                      debug_mode = FALSE, min_countries = 3,
                                      min_points = 5) {
  if (debug_mode) message(sprintf("Processing cutoff year %d", cutoff_year))

  # Filter data for this cutoff
  curtailed_data <- country_results %>% 
    dplyr::filter(year <= cutoff_year, deployment > 0)

  if (nrow(curtailed_data) == 0) {
    if (debug_mode) message(sprintf("No data for cutoff year %d", cutoff_year))
    return(NULL)
  }

  # Get global deployment at cutoff year
  global_deployment_data <- global_results %>%
    dplyr::filter(year == cutoff_year) %>%
    dplyr::select(run, global_deployment)

  # Get all runs with data at this cutoff
  all_runs <- unique(curtailed_data$run)
  if (debug_mode) message(sprintf("%d unique runs for cutoff year %d", length(all_runs), cutoff_year))

  # Summarise one run; returns NULL when too few countries could be fitted
  summarise_run <- function(current_run) {
    params <- curtailed_data %>%
      dplyr::filter(run == current_run) %>%
      dplyr::group_by(country) %>%
      dplyr::group_modify(function(country_rows, country_key) {
        estimate_national_parameters_local(country_rows, min_points = min_points)
      }) %>%
      dplyr::ungroup() %>%
      tidyr::drop_na() # Remove rows where fitting failed

    if (nrow(params) < min_countries) {
      return(NULL)
    }

    tibble::tibble(
      run = current_run,
      k_median = median(params$k, na.rm = TRUE),
      k_q1 = unname(quantile(params$k, 0.25, na.rm = TRUE)),
      k_q3 = unname(quantile(params$k, 0.75, na.rm = TRUE)),
      L_median = median(params$L, na.rm = TRUE),
      L_q1 = unname(quantile(params$L, 0.25, na.rm = TRUE)),
      L_q3 = unname(quantile(params$L, 0.75, na.rm = TRUE)),
      G_median = median(params$G, na.rm = TRUE),
      G_q1 = unname(quantile(params$G, 0.25, na.rm = TRUE)),
      G_q3 = unname(quantile(params$G, 0.75, na.rm = TRUE)),

      # Every retained country passed the maturity filter of the estimator,
      # so this count equals n_countries below
      n_mature_countries = nrow(params),
      mature_market_share = sum(params$market_size, na.rm = TRUE),  # Total market share

      n_countries = nrow(params),
      curtail_year = cutoff_year
    )
  }

  # Process each run, keeping only those that produced a summary
  params_list <- Filter(Negate(is.null), lapply(all_runs, summarise_run))

  # Combine results for all runs at this cutoff
  if (length(params_list) == 0) {
    if (debug_mode) {
      message(sprintf("No valid parameter estimates for cutoff year %d", cutoff_year))
    }
    return(NULL)
  }

  result <- dplyr::bind_rows(params_list) %>%
    dplyr::left_join(global_deployment_data, by = "run")

  if (debug_mode) {
    message(sprintf("Generated parameters for %d runs at cutoff year %d", 
                    nrow(result), cutoff_year))
  }

  result
}

#' Fit logistic curves to global trajectories
#'
#' For every run in `global_data`, fits an S-curve (via
#' `fit_curve(..., fit = "S")`) to the global deployment series and keeps the
#' fits flagged as good.
#'
#' @param global_data Global results data (needs `run`, `year` and
#'   `global_deployment` columns)
#' @param debug_mode Whether to output debug information
#' @return Data frame with fitted parameters (`run`, `k`, `L`, `G`, `t0`,
#'   `rmse`), or `NULL` when there are no runs or no fit succeeds
fit_global_curves_local <- function(global_data, debug_mode = FALSE) {
  # For each simulation run, fit a logistic curve to global trajectory
  global_runs <- unique(global_data$run)

  fit_one_run <- function(current_run) {
    # Get data for this run
    run_data <- global_data %>% 
      dplyr::filter(run == current_run) %>%
      dplyr::arrange(year)

    # Prepare data for fitting
    fit_data <- data.frame(
      Year = run_data$year,
      Value = run_data$global_deployment,
      Total = 1  # Since we're working with shares
    )

    # Fit logistic curve; failures and rejected fits both yield NULL
    tryCatch({
      fit_result <- fit_curve(fit_data, fit = "S")
      if (!is.null(fit_result) && fit_result$Good == 1) {
        tibble::tibble(
          run = current_run,
          k = fit_result$K,
          L = fit_result$L,
          G = fit_result$G,
          t0 = fit_result$TMax,
          rmse = sqrt(mean((fit_result$fitted - run_data$global_deployment)^2))
        )
      } else {
        NULL
      }
    }, error = function(e) {
      if (debug_mode) {
        message(sprintf("Error fitting global curve for run %s: %s", 
                        current_run, e$message))
      }
      NULL
    })
  }

  # Remove NULL results. Filter() also copes with an empty run list, where
  # sapply() would return a list that cannot be negated.
  global_fits_list <- Filter(Negate(is.null), lapply(global_runs, fit_one_run))

  # Combine results
  if (length(global_fits_list) == 0) {
    return(NULL)
  }

  dplyr::bind_rows(global_fits_list)
}

#' Prepare unified training data for modeling with global helper functions
#'
#' Combines per-cutoff national parameter summaries, global curve fits and
#' configuration features into one training table.
#'
#' @param multi_config_results Results from multi-configuration simulations
#'   (a list with `global_results`, `country_results` and `config_features`)
#' @param cutoff_years Years to use as cutoff points; values outside the
#'   simulated year range are dropped
#' @param use_parallel Whether to use parallel processing. Currently ignored:
#'   cutoff years are always processed sequentially
#' @param debug_mode Enable additional diagnostic output
#' @param min_countries Minimum countries required per run
#' @return Data frame with training data

prepare_unified_training_data <- function(multi_config_results, cutoff_years, 
                                          use_parallel = FALSE, debug_mode = FALSE,
                                          min_countries = 3) {
  # Extract data
  global_results <- multi_config_results$global_results
  country_results <- multi_config_results$country_results
  config_features <- multi_config_results$config_features

  # Without country rows the year range is undefined; fail with a clear
  # message instead of a formatting error further down
  if (is.null(country_results) || nrow(country_results) == 0) {
    stop("country_results contains no rows; cannot derive cutoff years.")
  }

  if (debug_mode) {
    message("Data dimensions:")
    message(sprintf("- global_results: %d rows, %d unique runs", 
                    nrow(global_results), length(unique(global_results$run))))
    message(sprintf("- country_results: %d rows, %d unique runs", 
                    nrow(country_results), length(unique(country_results$run))))
  }

  # Validate cutoff years
  min_year <- min(country_results$year)
  max_year <- max(country_results$year)
  valid_cutoffs <- cutoff_years[cutoff_years >= min_year & cutoff_years <= max_year]

  if (length(valid_cutoffs) == 0) {
    stop(sprintf("No valid cutoff years found. Data range is %d-%d, provided: %s", 
                 min_year, max_year, paste(cutoff_years, collapse=", ")))
  }

  message(sprintf("Using cutoff years: %s", paste(valid_cutoffs, collapse=", ")))

  # Process all cutoff years 
  # Use sequential processing for safety - remove parallel processing option temporarily
  message("Using sequential processing for curtailment analysis")

  # Process sequentially
  curtailed_params_list <- lapply(valid_cutoffs, function(cutoff) {
    process_cutoff_year_local(cutoff, country_results, global_results, 
                              debug_mode = debug_mode, min_countries = min_countries)
  })

  # Remove NULL results
  is_missing <- vapply(curtailed_params_list, is.null, logical(1))
  curtailed_params_list <- curtailed_params_list[!is_missing]

  if (length(curtailed_params_list) == 0) {
    if (debug_mode) {
      # Print diagnostics to help troubleshoot
      message("Run diagnostic info:")
      for (cutoff in valid_cutoffs) {
        filtered_data <- country_results %>% 
          dplyr::filter(year <= cutoff, deployment > 0)

        run_counts <- filtered_data %>%
          dplyr::group_by(run) %>%
          dplyr::summarize(
            n_countries = dplyr::n_distinct(country),
            max_year = max(year),
            max_deployment = max(deployment)
          )

        message(sprintf("Cutoff %d: %d runs, %d with ≥%d countries", 
                        cutoff, nrow(run_counts), 
                        sum(run_counts$n_countries >= min_countries), min_countries))
      }
    }

    stop("No valid parameter estimates generated for any cutoff year. Try reducing min_countries or enabling debug_mode for more info.")
  }

  # Combine results from all cutoff years
  curtailed_params <- dplyr::bind_rows(curtailed_params_list)
  message(sprintf("Combined parameter estimates: %d rows", nrow(curtailed_params)))

  # Fit global curves
  message("Fitting global curves...")
  global_fits <- fit_global_curves_local(global_results, debug_mode = debug_mode)

  if (is.null(global_fits) || nrow(global_fits) == 0) {
    stop("No valid global curve fits generated. Check your data.")
  }

  message(sprintf("Generated %d global curve fits", nrow(global_fits)))

  # Create run-to-config mapping
  run_config_mapping <- global_results %>%
    dplyr::select(run, config_id) %>%
    dplyr::distinct()

  # Join datasets
  message("Joining datasets...")

  # Ensure consistent run column types for joining
  curtailed_params$run <- as.character(curtailed_params$run)
  global_fits$run <- as.character(global_fits$run)
  run_config_mapping$run <- as.character(run_config_mapping$run)

  # Step 1: Join curtailed parameters with global fits
  message("Step 1: Joining curtailed parameters with global fits...")

  # Show sample run IDs to help diagnose join issues
  if (debug_mode) {
    message("Sample run IDs from curtailed_params:")
    print(head(unique(curtailed_params$run)))

    message("Sample run IDs from global_fits:")
    print(head(unique(global_fits$run)))
  }

  step1 <- dplyr::inner_join(curtailed_params, global_fits, by = "run")
  message(sprintf("After step 1 join: %d rows", nrow(step1)))

  if (nrow(step1) == 0) {
    stop("No matches found when joining curtailed parameters with global fits. Run IDs may not be consistent.")
  }

  # Step 2: Join with run/config mapping
  message("Step 2: Joining with run/config mapping...")
  step2 <- dplyr::inner_join(step1, run_config_mapping, by = "run")
  message(sprintf("After step 2 join: %d rows", nrow(step2)))

  if (nrow(step2) == 0) {
    stop("No matches found when joining with run/config mapping.")
  }

  # Step 3: Join with configuration features
  message("Step 3: Joining with configuration features...")
  if (!is.null(config_features) && nrow(config_features) > 0) {
    training_data <- dplyr::left_join(step2, config_features, by = "config_id")
  } else {
    message("No config features available - skipping this join")
    training_data <- step2
  }

  # Final filtering: keep rows with usable national summaries
  final_data <- training_data %>%
    dplyr::filter(!is.na(k_median), !is.na(G_median))

  message(sprintf("Final training dataset: %d rows", nrow(final_data)))

  if (nrow(final_data) == 0) {
    stop("Final dataset is empty after filtering. Check your data.")
  }

  final_data
}

#' Build unified projection model with clear feature separation
#'
#' Trains two quantile-regression random forests (ranger, 1000 trees each):
#' one predicting `k` and one predicting `G`, from the same feature set.
#'
#' @param training_data Prepared training data; must contain `k`, `G` and
#'   every feature column used for training
#' @param keep_forest Whether to keep forest data for prediction intervals
#'   (passed to ranger as `keep.inbag`)
#' @param use_config_features Whether to use configuration features for
#'   training (only those present in `training_data` are used)
#' @return List with model objects and metadata: `k_model`, `G_model`,
#'   `empirical_features`, `config_features`, `all_features` and
#'   `feature_importance`

build_unified_projection_model <- function(training_data, keep_forest = TRUE, use_config_features = FALSE) {
  if (!is.data.frame(training_data) || nrow(training_data) == 0) {
    stop("training_data must be a data frame with at least one row", call. = FALSE)
  }

  # Define basic empirical features (always required for prediction).
  # L_median/L_q1/L_q3 and global_deployment are deliberately not used.
  empirical_features <- c(
    "k_median", "k_q1", "k_q3", 
    "G_median", "G_q1", "G_q3", 
    "n_mature_countries",
    "mature_market_share"
  )

  # Define configuration features (optional, used only during training)
  config_features <- c(
    "k_dist_mean", "L_dist_mean", "takeoff_dist_mean",
    "logistic_weight", "logistic_linear_weight", "bilogistic_weight"
  )

  # Check which config features are available in the data
  available_config_features <- config_features[config_features %in% names(training_data)]

  # Combine features based on whether to use config features
  if (use_config_features && length(available_config_features) > 0) {
    feature_cols <- c(empirical_features, available_config_features)
    message("Using both empirical and configuration features for model training")
  } else {
    feature_cols <- empirical_features
    message("Using only empirical features for model training")
  }

  message(paste("Features used for training:", paste(feature_cols, collapse=", ")))

  # Fail early with a readable message rather than inside ranger
  missing_cols <- setdiff(c("k", "G", feature_cols), names(training_data))
  if (length(missing_cols) > 0) {
    stop("training_data is missing required columns: ",
         paste(missing_cols, collapse = ", "), call. = FALSE)
  }

  # Train one forest for the given response on the shared feature set
  fit_forest <- function(response) {
    ranger(
      formula = as.formula(paste(response, "~", paste(feature_cols, collapse = " + "))),
      data = training_data,
      num.trees = 1000,
      importance = "impurity",
      quantreg = TRUE,
      keep.inbag = keep_forest
    )
  }

  # Tabulate a model's impurity importance under the given column name
  importance_table <- function(model, value_col) {
    scores <- importance(model)
    out <- data.frame(
      feature = names(scores),
      value = scores,
      stringsAsFactors = FALSE
    )
    names(out)[2] <- value_col
    out
  }

  # Train models with feature set (k first, then G)
  k_model <- fit_forest("k")
  G_model <- fit_forest("G")

  # Calculate feature importance, most important (summed) first
  feature_importance <- full_join(
    importance_table(k_model, "importance_k"),
    importance_table(G_model, "importance_G"),
    by = "feature"
  ) %>%
    arrange(desc(importance_k + importance_G))

  # Return models and metadata
  list(
    k_model = k_model,
    G_model = G_model,
    empirical_features = empirical_features,
    config_features = available_config_features,
    all_features = feature_cols,
    feature_importance = feature_importance
  )
}


#' Alternative build model function that only uses empirical features
#'
#' Convenience wrapper around `build_unified_projection_model()` with
#' configuration features switched off.
#'
#' @param training_data Prepared training data
#' @param keep_forest Whether to keep forest data for prediction intervals
#' @return List with model objects and metadata
build_empirical_only_model <- function(training_data, keep_forest = TRUE) {
  build_unified_projection_model(
    training_data,
    keep_forest,
    use_config_features = FALSE
  )
}


#' Summarise and plot feature importance of the k and G models
#'
#' Reads `unified_model$feature_importance` (columns `feature`,
#' `importance_k`, `importance_G`), adds totals, percentage shares and a
#' name-based feature category, then builds one bar chart per feature and one
#' per category.
#'
#' @param unified_model List holding a `feature_importance` data frame
#' @return List with `importance_data`, `feature_type_summary`,
#'   `feature_plot` and `category_plot`
analyze_feature_importance <- function(unified_model) {
  # Per-feature table: totals, percentage shares and a pattern-based category.
  # The category rules are evaluated top to bottom; the first match wins.
  importance_data <- unified_model$feature_importance %>%
    mutate(
      total_importance = importance_k + importance_G,
      importance_k_pct = 100 * importance_k / sum(importance_k),
      importance_G_pct = 100 * importance_G / sum(importance_G),
      feature_type = case_when(
        grepl("k_", feature) & !grepl("dist", feature) ~ "Growth Rate Parameters",
        grepl("G_", feature) & !grepl("dist", feature) ~ "Growth Factor Parameters",
        grepl("global_deployment", feature) ~ "Current Deployment",
        grepl("dist", feature) ~ "Configuration Distribution",
        grepl("weight", feature) ~ "Model Weights",
        TRUE ~ "Other"
      )
    )

  # Both charts overlay the k bars (blue) with semi-transparent G bars (red)
  # and flip the axes so labels read horizontally.
  overlaid_bars <- function(k_mapping, G_mapping) {
    list(
      geom_col(k_mapping, fill = "steelblue"),
      geom_col(G_mapping, fill = "darkred", alpha = 0.7),
      coord_flip()
    )
  }

  feature_plot <- ggplot(
    importance_data,
    aes(x = reorder(feature, total_importance))
  ) +
    overlaid_bars(aes(y = importance_k_pct), aes(y = importance_G_pct)) +
    labs(
      title = "Feature Importance in Unified Model",
      x = "",
      y = "Importance (%)"
    ) +
    theme_minimal() +
    theme(
      panel.grid.major.y = element_blank(),
      axis.text.y = element_text(size = 10)
    )

  # Category roll-up, most important category first
  feature_type_summary <- importance_data %>%
    group_by(feature_type) %>%
    summarize(
      total_importance = sum(total_importance),
      k_importance = sum(importance_k),
      G_importance = sum(importance_G),
      n_features = n()
    ) %>%
    arrange(desc(total_importance))

  category_plot <- ggplot(
    feature_type_summary,
    aes(x = reorder(feature_type, total_importance))
  ) +
    overlaid_bars(aes(y = k_importance), aes(y = G_importance)) +
    labs(
      title = "Importance by Feature Category",
      x = "",
      y = "Importance (sum)"
    ) +
    theme_minimal()

  list(
    importance_data = importance_data,
    feature_type_summary = feature_type_summary,
    feature_plot = feature_plot,
    category_plot = category_plot
  )
}

#' Compare fitted parameters against each configuration's input distribution
#'
#' Joins `diverse_configs$summary` onto the training data by `config_id`,
#' draws per-configuration box plots of `k` and of the derived saturation
#' level `4 * G / k`, and computes Spearman correlations per configuration.
#'
#' @param training_data Prepared training data with `config_id`, `k` and `G`
#' @param diverse_configs List whose `summary` element carries `config_id`,
#'   `description`, `k_dist_mean` and `L_mean`
#' @return List with `k_dist_plot`, `L_dist_plot`, `correlation_plot` and
#'   `correlation_data`
analyze_config_impact <- function(training_data, diverse_configs) {
  analysis_data <- left_join(
    training_data,
    diverse_configs$summary,
    by = "config_id"
  )

  # Layer factories shared by both distribution plots
  config_boxes <- function() {
    geom_boxplot(aes(fill = description), alpha = 0.7)
  }
  input_mean_marker <- function(mapping) {
    geom_point(mapping, color = "red", size = 3, shape = 18)
  }
  marker_note <- "Red diamonds show the mean of the input distribution"

  k_dist_plot <- ggplot(analysis_data, aes(x = factor(config_id), y = k)) +
    config_boxes() +
    input_mean_marker(aes(y = k_dist_mean)) +
    labs(
      title = "Distribution of k parameters by configuration",
      subtitle = marker_note,
      x = "Configuration ID",
      y = "k (growth rate parameter)"
    ) +
    theme_minimal()

  # L is not stored directly; it is derived through L = 4*G/k
  L_dist_plot <- ggplot(analysis_data, aes(x = factor(config_id), y = 4*G/k)) +
    config_boxes() +
    input_mean_marker(aes(y = L_mean)) +
    labs(
      title = "Distribution of L parameters by configuration",
      subtitle = marker_note,
      x = "Configuration ID",
      y = "L (saturation level)"
    ) +
    theme_minimal()

  # NOTE(review): k_dist_mean and L_mean arrive via the per-config join, so
  # within one config_id group they are presumably constant and the two
  # input/output correlations would come out NA - confirm this is intended.
  correlation_data <- analysis_data %>%
    group_by(config_id, description) %>%
    summarize(
      k_G_correlation = cor(k, G, method = "spearman"),
      input_k_output_k = cor(k_dist_mean, k, method = "spearman"),
      input_L_output_L = cor(L_mean, 4*G/k, method = "spearman"),
      n = n(),
      .groups = "drop"
    )

  correlation_plot <- ggplot(correlation_data, aes(x = factor(config_id))) +
    geom_col(aes(y = input_k_output_k), fill = "steelblue") +
    geom_col(aes(y = input_L_output_L), fill = "darkred", alpha = 0.5) +
    geom_hline(yintercept = 0, linetype = "dashed") +
    labs(
      title = "Correlation between input distributions and output parameters",
      subtitle = "Blue = k parameter, Red = L parameter",
      x = "Configuration ID",
      y = "Spearman correlation"
    ) +
    theme_minimal() +
    ylim(-1, 1)

  list(
    k_dist_plot = k_dist_plot,
    L_dist_plot = L_dist_plot,
    correlation_plot = correlation_plot,
    correlation_data = correlation_data
  )
}


#' Summarise the largest countries of one simulation run
#'
#' Keeps the ten countries with the largest `market_size` (taken from each
#' country's first row) and collapses their rows into one record per country.
#'
#' @param run_data Data frame for a single run with columns `country`,
#'   `market_size`, `k`, `L`, `takeoff_year`, `t0`, `deployment` and `year`
#' @return Tibble with one row per retained country, sorted by descending
#'   `market_size`, holding the first-row parameters, `max_deployment`,
#'   `time_to_50pct` (`Inf` when deployment never reaches 0.5) and a
#'   `size_category` factor with levels Small / Medium / Large
extract_country_features <- function(run_data) {
  # Get top countries by market size
  top_countries <- run_data %>%
    group_by(country) %>%
    slice(1) %>%
    ungroup() %>%
    arrange(desc(market_size)) %>%
    head(10) %>%
    pull(country)

  # Extract growth parameters for these countries
  country_features <- run_data %>%
    filter(country %in% top_countries) %>%
    group_by(country) %>%
    summarize(
      market_size = first(market_size),
      k = first(k),
      L = first(L),
      takeoff_year = first(takeoff_year),
      t0 = first(t0),
      max_deployment = max(deployment),
      # which() ignores NA deployments, so a missing observation can no longer
      # turn the whole result into NA; Inf marks "never reached 50%"
      time_to_50pct = if (any(deployment >= 0.5, na.rm = TRUE)) {
        min(year[which(deployment >= 0.5)])
      } else {
        Inf
      },
      .groups = "drop"
    ) %>%
    arrange(desc(market_size))

  # Market size tertiles. The levels are fixed explicitly: with fewer than
  # three countries ntile() produces fewer than three distinct buckets, and
  # factor() would reject three labels for one or two inferred levels.
  country_features %>%
    mutate(
      size_category = factor(
        ntile(market_size, 3),
        levels = 1:3,
        labels = c("Small", "Medium", "Large")
      )
    )
}

#' Relate country size to growth parameters within one run
#'
#' @param country_features Per-country table with `market_size`, `k`, `L`,
#'   `takeoff_year` and `size_category`
#' @return Named list of seven numeric scalars: Spearman correlations of
#'   market size with `k`, `L` and `takeoff_year`; Large/Small ratios of mean
#'   `k` and mean `L`; the takeoff-year spread; and the share of total market
#'   size held by the three earliest-takeoff countries. Every entry is
#'   `NA_real_` when fewer than five countries are supplied.
calculate_relationship_metrics <- function(country_features) {
  metric_names <- c(
    "corr_size_k", "corr_size_L", "corr_size_takeoff",
    "large_vs_small_k_ratio", "large_vs_small_L_ratio",
    "takeoff_spread", "leading_market_share"
  )

  # Too few countries for meaningful statistics
  if (nrow(country_features) < 5) {
    return(setNames(as.list(rep(NA_real_, length(metric_names))), metric_names))
  }

  # Rank correlation of a column with market size; NA if cor() errors
  spearman_with_size <- function(values) {
    tryCatch(
      cor(country_features$market_size, values, method = "spearman"),
      error = function(e) NA_real_
    )
  }

  corr_size_k <- spearman_with_size(country_features$k)
  corr_size_L <- spearman_with_size(country_features$L)
  corr_size_takeoff <- spearman_with_size(country_features$takeoff_year)

  # Mean parameters per size bucket, with an all-NA table as fallback
  by_size <- tryCatch({
    country_features %>%
      group_by(size_category) %>%
      summarize(
        mean_k = mean(k, na.rm = TRUE),
        mean_L = mean(L, na.rm = TRUE),
        mean_takeoff = mean(takeoff_year, na.rm = TRUE),
        .groups = "drop"
      )
  }, error = function(e) {
    data.frame(
      size_category = factor(c("Small", "Medium", "Large")),
      mean_k = NA_real_,
      mean_L = NA_real_,
      mean_takeoff = NA_real_
    )
  })

  # Large-over-Small ratio of a per-bucket mean; NA when either bucket is
  # absent or the Small mean is missing or zero
  large_over_small <- function(bucket_means) {
    large <- bucket_means[by_size$size_category == "Large"]
    small <- bucket_means[by_size$size_category == "Small"]
    if (length(large) > 0 && length(small) > 0 && !is.na(small) && small != 0) {
      large / small
    } else {
      NA_real_
    }
  }

  large_vs_small_k_ratio <- large_over_small(by_size$mean_k)
  large_vs_small_L_ratio <- large_over_small(by_size$mean_L)

  # Leader-follower patterns
  takeoff_spread <- tryCatch(
    diff(range(country_features$takeoff_year, na.rm = TRUE)),
    error = function(e) NA_real_
  )

  # Combined market size of the three earliest adopters
  leader_size <- tryCatch({
    earliest_first <- arrange(country_features, takeoff_year)
    sum(pull(head(earliest_first, 3), market_size))
  }, error = function(e) NA_real_)

  leading_market_share <- if (is.na(leader_size)) {
    NA_real_
  } else {
    leader_size / sum(country_features$market_size, na.rm = TRUE)
  }

  list(
    corr_size_k = corr_size_k,
    corr_size_L = corr_size_L,
    corr_size_takeoff = corr_size_takeoff,
    large_vs_small_k_ratio = large_vs_small_k_ratio,
    large_vs_small_L_ratio = large_vs_small_L_ratio,
    takeoff_spread = takeoff_spread,
    leading_market_share = leading_market_share
  )
}

#' Measure how heterogeneous the countries of one run are
#'
#' @param country_features Per-country table with `k`, `L`, `takeoff_year`
#'   and `max_deployment`
#' @return Named list of five numeric scalars: `cv_k`, `cv_L` (coefficient of
#'   variation), `gini_takeoff`, `gini_deployment` (range divided by mean, a
#'   simple dispersion stand-in rather than a true Gini index) and
#'   `cluster_separation` (between-cluster share of total sum of squares from
#'   a two-centre k-means). Every entry is `NA_real_` when fewer than five
#'   countries are supplied; individual entries fall back to `NA_real_` when
#'   their computation errors.
calculate_heterogeneity_metrics <- function(country_features) {
  # Too few countries for meaningful statistics
  if (nrow(country_features) < 5) {
    metric_names <- c(
      "cv_k", "cv_L", "gini_takeoff", "gini_deployment", "cluster_separation"
    )
    return(setNames(as.list(rep(NA_real_, length(metric_names))), metric_names))
  }

  # Standard deviation relative to the mean; NA for a missing or zero mean
  coefficient_of_variation <- function(values) {
    tryCatch({
      centre <- mean(values, na.rm = TRUE)
      if (is.na(centre) || centre == 0) {
        NA_real_
      } else {
        sd(values, na.rm = TRUE) / centre
      }
    }, error = function(e) NA_real_)
  }

  # Range relative to the mean. A NaN mean makes the `if` condition error,
  # which the handler converts to NA.
  range_over_mean <- function(values) {
    tryCatch({
      spread <- diff(range(values, na.rm = TRUE))
      centre <- mean(values, na.rm = TRUE)
      if (centre == 0) NA_real_ else spread / centre
    }, error = function(e) NA_real_)
  }

  # Two-cluster k-means on the standardised growth parameters
  cluster_separation <- tryCatch({
    growth_params <- country_features[, c("k", "L", "takeoff_year")]
    growth_params <- growth_params[complete.cases(growth_params), ]

    if (nrow(growth_params) < 5) {
      NA_real_
    } else {
      fit <- kmeans(scale(growth_params), centers = 2)
      fit$betweenss / fit$totss
    }
  }, error = function(e) NA_real_)

  list(
    cv_k = coefficient_of_variation(country_features$k),
    cv_L = coefficient_of_variation(country_features$L),
    gini_takeoff = range_over_mean(country_features$takeoff_year),
    gini_deployment = range_over_mean(country_features$max_deployment),
    cluster_separation = cluster_separation
  )
}

#' Build the numeric feature vector describing one simulation run
#'
#' Combines size-weighted parameters of the "Large" countries, the
#' relationship and heterogeneity metrics, a large-before-small takeoff flag
#' and two global deployment indicators.
#'
#' @param run_data Data frame holding the rows of a single run
#' @return Named numeric vector. When the run yields fewer than five country
#'   records every entry is `NA_real_`, but the names are still present so
#'   callers can rely on a stable layout. `years_to_global_50pct` is `Inf`
#'   when a global deployment column exists but never reaches 0.5, and
#'   `NA_real_` when no such column exists.
create_run_feature_vector <- function(run_data) {
  country_features <- extract_country_features(run_data)
  relationship_metrics <- calculate_relationship_metrics(country_features)
  heterogeneity_metrics <- calculate_heterogeneity_metrics(country_features)

  # Too little data: return an all-NA vector carrying the regular names
  if (nrow(country_features) < 5) {
    feature_names <- c(
      "large_country_avg_k", "large_country_avg_L", "large_country_avg_takeoff",
      names(unlist(relationship_metrics)), names(unlist(heterogeneity_metrics)),
      "large_lead_small", "global_max_deployment", "years_to_global_50pct"
    )
    return(setNames(rep(NA_real_, length(feature_names)), feature_names))
  }

  # Market-size-weighted parameters of the large countries
  large_country_data <- country_features %>% filter(size_category == "Large")
  if (nrow(large_country_data) > 0) {
    large_weights <- large_country_data$market_size
    large_country_avg_k <- weighted.mean(
      large_country_data$k, large_weights, na.rm = TRUE
    )
    large_country_avg_L <- weighted.mean(
      large_country_data$L, large_weights, na.rm = TRUE
    )
    large_country_avg_takeoff <- weighted.mean(
      large_country_data$takeoff_year, large_weights, na.rm = TRUE
    )
  } else {
    large_country_avg_k <- NA_real_
    large_country_avg_L <- NA_real_
    large_country_avg_takeoff <- NA_real_
  }

  # Development sequence: do large countries take off before small ones?
  large_takeoff <- mean(
    country_features$takeoff_year[country_features$size_category == "Large"],
    na.rm = TRUE
  )
  small_takeoff <- mean(
    country_features$takeoff_year[country_features$size_category == "Small"],
    na.rm = TRUE
  )
  large_lead_small <- !is.na(large_takeoff) && !is.na(small_takeoff) &&
    large_takeoff < small_takeoff

  # Global pattern indicators. The column may carry a prefix, so match on the
  # substring and use the first hit.
  global_cols <- grep("global_deployment", names(run_data), value = TRUE)

  if (length(global_cols) > 0) {
    global_series <- run_data[[global_cols[1]]]
    global_max_deployment <- max(global_series, na.rm = TRUE)

    # Years from the first observed year until the series first reaches 50%
    years_to_global_50pct <- tryCatch({
      crossing_years <- run_data$year[which(global_series >= 0.5)]
      crossing_years <- crossing_years[!is.na(crossing_years)]
      # Inf marks "never crossed"; computed explicitly so that min() is not
      # called on an empty vector
      first_crossing <- if (length(crossing_years) > 0) {
        min(crossing_years)
      } else {
        Inf
      }
      first_crossing - min(run_data$year, na.rm = TRUE)
    }, error = function(e) NA_real_)
  } else {
    # No global column: approximate the peak from the country rows as the
    # largest yearly sum of deployment * market_size
    global_max_deployment <- tryCatch({
      run_data %>%
        group_by(year) %>%
        summarize(global = sum(deployment * market_size, na.rm = TRUE)) %>%
        pull(global) %>%
        max(na.rm = TRUE)
    }, error = function(e) NA_real_)

    # NOTE(review): left as NA because the 50% threshold for the unnormalised
    # weighted sum is not defined here. Callers that keep only complete cases
    # will therefore discard every run that takes this branch.
    years_to_global_50pct <- NA_real_
  }

  c(
    large_country_avg_k = large_country_avg_k,
    large_country_avg_L = large_country_avg_L,
    large_country_avg_takeoff = large_country_avg_takeoff,
    unlist(relationship_metrics),
    unlist(heterogeneity_metrics),
    large_lead_small = as.numeric(large_lead_small),
    global_max_deployment = global_max_deployment,
    years_to_global_50pct = years_to_global_50pct
  )
}

#' Pick a diverse subset of runs by clustering their feature vectors
#'
#' Computes one feature vector per run, keeps runs whose features are all
#' finite, standardises the varying features, optionally compresses them with
#' PCA, clusters with k-means and returns the run closest to each centroid.
#'
#' @param all_runs Data frame with a `run` column identifying each run
#' @param n_clusters Number of runs to select (capped at the number of runs)
#' @param debug Emit progress messages?
#' @return Vector of selected run ids. Falls back to a random subset (with a
#'   warning) when fewer than ten runs have complete features, when no feature
#'   varies, or when k-means fails. Fewer than `n_clusters` ids may be
#'   returned because only runs with complete features are clustered.
select_diverse_subset <- function(all_runs, n_clusters = 50, debug = FALSE) {
  run_ids <- unique(all_runs$run)
  n_total_runs <- length(run_ids)

  if (n_total_runs == 0) {
    return(run_ids)
  }

  # Adjust n_clusters if fewer runs than requested clusters
  if (n_total_runs < n_clusters) {
    message(sprintf("Only %d runs available, reducing clusters to match", n_total_runs))
    n_clusters <- n_total_runs
  }

  if (debug) message(sprintf("Processing %d runs for diversity analysis", n_total_runs))

  # Random fallback. Sampling indices avoids sample()'s special case where a
  # single number x is silently treated as 1:x.
  random_subset <- function(ids) {
    ids[sample.int(length(ids), min(n_clusters, length(ids)))]
  }

  # One feature vector per run, stacked into a runs x features matrix
  feature_rows <- lapply(seq_along(run_ids), function(i) {
    if (debug && i %% 20 == 0) {
      message(sprintf("Processing run %d of %d", i, n_total_runs))
    }
    run_data <- all_runs %>% filter(run == run_ids[i])
    create_run_feature_vector(run_data)
  })
  feature_matrix <- do.call(rbind, lapply(feature_rows, as.numeric))
  colnames(feature_matrix) <- names(feature_rows[[1]])

  # Treat infinite values (e.g. "never reached 50%") as missing
  feature_matrix[is.infinite(feature_matrix)] <- NA

  complete_rows <- complete.cases(feature_matrix)
  if (sum(complete_rows) < 10) {
    warning("Too few complete cases for clustering, returning random subset")
    return(random_subset(run_ids))
  }

  reduced_feature_matrix <- feature_matrix[complete_rows, , drop = FALSE]
  reduced_run_ids <- run_ids[complete_rows]

  if (debug) message(sprintf("Using %d complete runs for clustering", nrow(reduced_feature_matrix)))

  # Constant columns have zero variance; scale() would turn them into NaN and
  # break both PCA and k-means, so drop them before standardising
  column_sd <- apply(reduced_feature_matrix, 2, sd)
  varying <- !is.na(column_sd) & column_sd > 0
  if (!any(varying)) {
    warning("No varying features for clustering, returning random subset")
    return(random_subset(reduced_run_ids))
  }
  feature_matrix_scaled <- scale(reduced_feature_matrix[, varying, drop = FALSE])

  # Apply PCA to reduce dimensionality
  pca_result <- tryCatch({
    prcomp(feature_matrix_scaled)
  }, error = function(e) {
    warning("PCA failed, using original features")
    NULL
  })

  if (!is.null(pca_result)) {
    var_explained <- cumsum(pca_result$sdev^2 / sum(pca_result$sdev^2))
    components_to_keep <- which(var_explained >= 0.85)[1]
    # Keep at most half the features, but always at least one component
    max_components <- max(1L, as.integer(floor(ncol(feature_matrix_scaled) / 2)))
    components_to_keep <- min(components_to_keep, max_components, na.rm = TRUE)
    # drop = FALSE keeps a matrix even when a single component remains
    reduced_features <- pca_result$x[, seq_len(components_to_keep), drop = FALSE]

    if (debug) message(sprintf("Using %d principal components explaining 85%% of variance", components_to_keep))
  } else {
    reduced_features <- feature_matrix_scaled
  }

  # Apply k-means clustering. The handler only yields NULL; the fallback must
  # be returned from this function, not from inside the handler.
  k_means_result <- tryCatch(
    kmeans(reduced_features, centers = min(n_clusters, nrow(reduced_features) - 1)),
    error = function(e) NULL
  )
  if (is.null(k_means_result)) {
    warning("K-means clustering failed, returning random subset")
    return(random_subset(reduced_run_ids))
  }

  # From each cluster select the member closest to the centroid
  selected_runs <- lapply(seq_len(nrow(k_means_result$centers)), function(cluster_id) {
    cluster_members <- which(k_means_result$cluster == cluster_id)

    if (length(cluster_members) == 0) {
      return(NULL)
    }
    if (length(cluster_members) == 1) {
      return(reduced_run_ids[cluster_members])
    }

    offsets <- sweep(
      reduced_features[cluster_members, , drop = FALSE],
      2,
      k_means_result$centers[cluster_id, ]
    )
    closest_idx <- cluster_members[which.min(rowSums(offsets^2))]
    reduced_run_ids[closest_idx]
  })

  # Empty clusters contribute NULL, which unlist() drops
  selected_run_ids <- unlist(selected_runs)

  if (debug) message(sprintf("Selected %d diverse runs", length(selected_run_ids)))

  selected_run_ids
}

#' Plot a t-SNE map of run feature vectors, highlighting the selected runs
#'
#' @param all_runs Data frame with a `run` column identifying each run
#' @param selected_runs Run ids to highlight
#' @param max_points When more runs than this exist, all selected runs are
#'   kept and the unselected ones are randomly thinned to fit
#' @return A ggplot object, or `NULL` (with a warning) when fewer than ten
#'   runs have complete features or t-SNE fails
#' @details Thinning calls `set.seed(42)`, which resets the session's random
#'   number generator.
visualize_growth_patterns <- function(all_runs, selected_runs, max_points = 1000) {
  run_ids <- unique(all_runs$run)

  # Subsample if too many runs
  if (length(run_ids) > max_points) {
    set.seed(42)
    is_selected <- run_ids %in% selected_runs
    unselected_pool <- which(!is_selected)
    # Never negative, even when more runs are selected than max_points
    n_extra <- max(0, min(max_points - sum(is_selected), length(unselected_pool)))
    # Sample positions rather than values: sample() treats a single number x
    # as 1:x, which would pick runs that are not in the pool
    extra <- unselected_pool[sample.int(length(unselected_pool), n_extra)]
    run_ids <- c(run_ids[is_selected], run_ids[extra])
  }

  if (length(run_ids) == 0) {
    warning("Too few complete cases for t-SNE visualization")
    return(NULL)
  }

  # One feature vector per run, stacked into a runs x features matrix
  feature_rows <- lapply(seq_along(run_ids), function(i) {
    run_data <- all_runs %>% filter(run == run_ids[i])
    as.numeric(create_run_feature_vector(run_data))
  })
  feature_matrix <- do.call(rbind, feature_rows)

  # Handle missing values and only use complete cases
  feature_matrix[is.infinite(feature_matrix)] <- NA
  complete_rows <- complete.cases(feature_matrix)
  feature_matrix <- feature_matrix[complete_rows, , drop = FALSE]
  run_ids <- run_ids[complete_rows]

  if (nrow(feature_matrix) < 10) {
    warning("Too few complete cases for t-SNE visualization")
    return(NULL)
  }

  # Constant columns would become NaN under scale() and make t-SNE fail
  column_sd <- apply(feature_matrix, 2, sd)
  feature_matrix <- feature_matrix[, !is.na(column_sd) & column_sd > 0, drop = FALSE]
  if (ncol(feature_matrix) == 0) {
    warning("t-SNE failed: no varying features")
    return(NULL)
  }

  feature_matrix_scaled <- scale(feature_matrix)

  # Use t-SNE with error handling
  tsne_result <- tryCatch({
    Rtsne::Rtsne(feature_matrix_scaled, perplexity = min(30, nrow(feature_matrix_scaled) / 4))
  }, error = function(e) {
    warning("t-SNE failed: ", e$message)
    NULL
  })

  if (is.null(tsne_result)) {
    return(NULL)
  }

  plot_data <- tibble(
    run_id = run_ids,
    x = tsne_result$Y[, 1],
    y = tsne_result$Y[, 2],
    selected = run_id %in% selected_runs
  )
  # Draw highlighted runs last so they are not hidden behind the others
  plot_data <- plot_data[order(plot_data$selected), ]

  ggplot(plot_data, aes(x = x, y = y, color = selected)) +
    geom_point(
      size = ifelse(plot_data$selected, 3, 1),
      alpha = ifelse(plot_data$selected, 1, 0.3)
    ) +
    # Named values keep red bound to TRUE even when only one level is present
    scale_color_manual(values = c("FALSE" = "gray70", "TRUE" = "red")) +
    labs(
      title = "Diversity of Growth Patterns Across Simulation Runs",
      subtitle = paste(sum(plot_data$selected), "selected runs highlighted"),
      color = "Selected for\nCurve Fitting"
    ) +
    theme_minimal() +
    theme(legend.position = "bottom")
}

#-------------------------------------------------------
# 3. Integrated Workflow - Apply to Each Configuration
#-------------------------------------------------------

#' Select a diverse set of runs separately within every configuration
#'
#' @param multi_config_results List with `global_results` and
#'   `country_results`, both carrying `config_id` and `run` columns
#' @param runs_per_config Upper bound on runs selected per configuration
#' @param debug When TRUE, emits progress messages and prints a diversity
#'   plot for each configuration
#' @return List with `selected_runs` (all selected run ids flattened with
#'   `unlist()`) and `by_config` (selected run ids keyed by configuration id)
select_diverse_runs_by_config <- function(multi_config_results, runs_per_config = 75, debug = TRUE) {
  # Progress reporter that stays silent unless debugging
  report <- function(fmt, ...) {
    if (debug) message(sprintf(fmt, ...))
  }

  global_all <- multi_config_results$global_results
  country_all <- multi_config_results$country_results
  all_configs <- unique(global_all$config_id)

  selected_runs_all <- list()

  for (cfg in all_configs) {
    report("\nProcessing configuration %s", cfg)

    # Rows belonging to this configuration only
    config_global <- global_all %>% filter(config_id == !!cfg)
    config_country <- country_all %>% filter(config_id == !!cfg)

    n_config_runs <- length(unique(config_global$run))
    report("Configuration %s has %d total runs", cfg, n_config_runs)

    # Cannot select more runs than the configuration contains
    n_clusters <- min(runs_per_config, n_config_runs)
    report("Selecting %d diverse runs from configuration %s", n_clusters, cfg)

    selected_runs <- select_diverse_subset(
      all_runs = config_country,
      n_clusters = n_clusters,
      debug = debug
    )

    # The diagnostic plot is tied to the debug flag
    if (debug) {
      vis_plot <- visualize_growth_patterns(
        all_runs = config_country,
        selected_runs = selected_runs
      )

      if (!is.null(vis_plot)) {
        print(vis_plot + ggtitle(sprintf("Configuration %s - Growth Pattern Diversity", cfg)))
      }
    }

    selected_runs_all[[as.character(cfg)]] <- selected_runs
  }

  # Flatten results for easier use
  all_selected_runs <- unlist(selected_runs_all)

  report(
    "\nSelected %d total runs across %d configurations",
    length(all_selected_runs), length(all_configs)
  )

  list(
    selected_runs = all_selected_runs,
    by_config = selected_runs_all
  )
}

#-------------------------------------------------------
# 4. Efficient Training Data Generation
#-------------------------------------------------------

#' Prepare training data from a diverse subset of runs
#'
#' Selects diverse runs within each configuration, restricts the simulation
#' results to those runs and hands the reduced results to
#' `prepare_unified_training_data()`.
#'
#' @param multi_config_results List with `global_results` and
#'   `country_results`, both carrying `config_id` and `run` columns
#' @param cutoff_years Forwarded to `prepare_unified_training_data()`
#' @param runs_per_config Upper bound on runs selected per configuration
#' @param debug Emit progress messages (also forwarded as `debug_mode`)
#' @return List with `training_data` and `diverse_selection`
prepare_efficient_training_data <- function(multi_config_results, cutoff_years,
                                            runs_per_config = 75, debug = TRUE) {
  # Step 1: Select diverse subset of runs
  diverse_selection <- select_diverse_runs_by_config(
    multi_config_results = multi_config_results,
    runs_per_config = runs_per_config,
    debug = debug
  )

  # Step 2: Keep only the selected (configuration, run) pairs. Runs were
  # chosen per configuration, so matching on the run id alone would also keep
  # a same-numbered run from a different configuration.
  by_config <- diverse_selection$by_config
  pair_key <- function(config, run) {
    paste(as.character(config), as.character(run), sep = "\x1f")
  }
  selected_keys <- unlist(
    lapply(names(by_config), function(cfg) {
      runs <- by_config[[cfg]]
      # paste() would turn an empty selection into one bogus key
      if (length(runs) == 0) character(0) else pair_key(cfg, runs)
    }),
    use.names = FALSE
  )
  keep_selected <- function(results) {
    keep <- pair_key(results$config_id, results$run) %in% selected_keys
    results[keep, , drop = FALSE]
  }

  filtered_results <- list(
    global_results = keep_selected(multi_config_results$global_results),
    country_results = keep_selected(multi_config_results$country_results)
  )

  if (debug) {
    message("\nFiltered from original dataset:")
    message(sprintf("Global results: %d -> %d rows",
                    nrow(multi_config_results$global_results),
                    nrow(filtered_results$global_results)))
    message(sprintf("Country results: %d -> %d rows",
                    nrow(multi_config_results$country_results),
                    nrow(filtered_results$country_results)))
  }

  # Step 3: Run the regular training data preparation on filtered subset
  training_data <- prepare_unified_training_data(
    multi_config_results = filtered_results,
    cutoff_years = cutoff_years,
    debug_mode = debug
  )

  list(
    training_data = training_data,
    diverse_selection = diverse_selection
  )
}

#--------------------
# 6. Projection Functions
#--------------------

#' Predict growth-parameter quantiles, tolerating missing model features
#'
#' Any feature listed in `model$all_features` but absent from `new_data` is
#' added to a working copy with the constant value 0 (a message reports which
#' ones). Quantiles of `k` and `G` are then requested from the two fitted
#' models, and `L` is derived level by level as `4 * G / k`.
#'
#' @param model List with `k_model`, `G_model` and `all_features`
#' @param new_data New data with predictor variables; not modified
#' @return List with `k_quantiles`, `G_quantiles`, `L_quantiles` and
#'   `quantile_levels` (0.05, 0.25, 0.5, 0.75, 0.95)
predict_parameters <- function(model, new_data) {
  quantile_levels <- c(0.05, 0.25, 0.5, 0.75, 0.95)

  # Work on a copy so the caller's data stays untouched
  prediction_data <- new_data

  # Zero-fill every feature the model expects but the input lacks
  absent_features <- setdiff(model$all_features, names(new_data))
  if (length(absent_features) > 0) {
    message("Adding missing features for prediction: ", paste(absent_features, collapse=", "))

    for (feature_name in absent_features) {
      prediction_data[[feature_name]] <- 0
    }
  }

  # Same quantile request for either fitted model
  predict_quantiles <- function(fitted_model) {
    predict(
      fitted_model,
      data = prediction_data,
      type = "quantiles",
      quantiles = quantile_levels
    )$predictions
  }

  k_hat <- predict_quantiles(model$k_model)
  G_hat <- predict_quantiles(model$G_model)

  list(
    k_quantiles = k_hat,
    G_quantiles = G_hat,
    # Derived from the mathematical relationship L = 4 * G / k
    L_quantiles = 4 * G_hat / k_hat,
    quantile_levels = quantile_levels
  )
}

#' Project technology diffusion trajectory
#'
#' Predicts quantiles of `k` and `G`, crosses every `k` quantile with every
#' `G` quantile, keeps the combinations whose saturation level
#' `L = 4 * G / k` is finite and above current deployment, and generates one
#' logistic curve per combination anchored at the current observation. The
#' curves are summarised year by year, and for each summary quantile the
#' single curve with the smallest summed absolute difference is reported as
#' its representative.
#'
#' @param national_stats Single observation providing `global_deployment`,
#'   `Year` and the predictors needed by `predict_parameters()`
#' @param projection_model Model from build_unified_projection_model
#' @param last_year Final year for projection (must not precede `Year`)
#' @return List with `summary` (yearly p05/p25/median/p75/p95),
#'   `original_parameters` (predicted parameter quantiles),
#'   `representative_parameters`, `parameter_map` (the valid combinations)
#'   and `representative_trajectories`
project_trajectory <- function(national_stats, projection_model, last_year) {
  # Get parameter predictions from QRF
  params <- predict_parameters(projection_model, national_stats)

  # Initial values: the curve is anchored at (ti, yi)
  yi <- national_stats$global_deployment
  ti <- national_stats$Year

  if (length(yi) != 1 || length(ti) != 1 || is.na(yi) || is.na(ti)) {
    stop("national_stats must provide exactly one non-missing global_deployment and Year")
  }
  if (length(last_year) != 1 || is.na(last_year) || last_year < ti) {
    stop("last_year must be a single year not earlier than national_stats$Year")
  }

  years <- seq(ti, last_year, by = 1)

  # Logistic curve with growth rate k and ceiling L passing through (ti, yi)
  logistic_path <- function(k, L) {
    t0 <- ((1 / k) * log((L / yi) - 1)) + ti
    L / (1 + exp(-k * (years - t0)))
  }

  # Create all possible parameter combinations
  n_levels <- length(params$quantile_levels)
  trajectory_grid <- expand.grid(
    k_idx = seq_len(n_levels),
    G_idx = seq_len(n_levels)
  )
  trajectory_grid$trajectory_id <- seq_len(nrow(trajectory_grid))

  k_values <- as.numeric(params$k_quantiles)[trajectory_grid$k_idx]
  G_values <- as.numeric(params$G_quantiles)[trajectory_grid$G_idx]

  parameter_map <- tibble(
    trajectory_id = trajectory_grid$trajectory_id,
    k = k_values,
    G = G_values,
    L = 4 * G_values / k_values,
    k_quantile = params$quantile_levels[trajectory_grid$k_idx],
    G_quantile = params$quantile_levels[trajectory_grid$G_idx]
  )

  # Skip invalid combinations (L <= yi). which() also discards NA and
  # non-finite L, which a plain logical filter would turn into NA rows.
  valid_sets <- parameter_map[
    which(is.finite(parameter_map$L) & parameter_map$L > yi),
  ]

  if (nrow(valid_sets) == 0) {
    stop("No valid parameter combination: every predicted saturation level L is at or below current deployment")
  }

  # One column per valid trajectory, one row per year
  deployment_matrix <- matrix(
    unlist(lapply(seq_len(nrow(valid_sets)), function(i) {
      logistic_path(valid_sets$k[i], valid_sets$L[i])
    })),
    nrow = length(years)
  )

  combined_trajectories <- tibble(
    year = rep(years, times = nrow(valid_sets)),
    deployment = as.numeric(deployment_matrix),
    trajectory_id = rep(valid_sets$trajectory_id, each = length(years))
  )

  # For each year, calculate trajectory quantiles
  summary_stats <- combined_trajectories %>%
    group_by(year) %>%
    summarise(
      p05 = quantile(deployment, 0.05, na.rm = TRUE),
      p25 = quantile(deployment, 0.25, na.rm = TRUE),
      median = quantile(deployment, 0.5, na.rm = TRUE),
      p75 = quantile(deployment, 0.75, na.rm = TRUE),
      p95 = quantile(deployment, 0.95, na.rm = TRUE)
    )

  quantile_names <- c("p05", "p25", "median", "p75", "p95")

  # Store the original parameter quantiles
  original_parameters <- tibble(
    statistic = quantile_names,
    k = as.numeric(params$k_quantiles),
    G = as.numeric(params$G_quantiles),
    L = as.numeric(params$L_quantiles)
  )

  # Area-based matching: for each summary quantile, pick the trajectory whose
  # summed absolute yearly difference from the quantile line is smallest.
  # summary_stats is ordered by year like the rows of deployment_matrix, so
  # the quantile vector recycles down each trajectory column.
  representative_parameters <- bind_rows(lapply(quantile_names, function(q) {
    quantile_line <- unname(summary_stats[[q]])
    area_diffs <- colSums(abs(deployment_matrix - quantile_line))
    best <- which.min(area_diffs)

    if (length(best) == 0) {
      # Every area is NA; report an empty representative
      return(tibble(
        statistic = q,
        trajectory_id = NA_integer_,
        k = NA_real_,
        G = NA_real_,
        L = NA_real_,
        area_difference = NA_real_
      ))
    }

    tibble(
      statistic = q,
      trajectory_id = valid_sets$trajectory_id[best],
      k = valid_sets$k[best],
      G = valid_sets$G[best],
      L = valid_sets$L[best],
      area_difference = area_diffs[[best]]
    )
  }))

  # Create trajectories using the representative parameters
  combined_representative_trajectories <- bind_rows(
    lapply(seq_len(nrow(representative_parameters)), function(i) {
      tibble(
        year = years,
        deployment = logistic_path(
          representative_parameters$k[i],
          representative_parameters$L[i]
        ),
        statistic = representative_parameters$statistic[i]
      )
    })
  )

  # Return comprehensive results
  list(
    summary = summary_stats,
    original_parameters = original_parameters,
    representative_parameters = representative_parameters,
    parameter_map = valid_sets,
    representative_trajectories = combined_representative_trajectories
  )
}

#' Project a deployment trajectory and plot it against the observed data
#'
#' Verifies that `empirical_data` contains every column listed in
#' `model$empirical_features`, calls `project_trajectory()`, and builds a fan
#' chart (p05-p95 and p25-p75 ribbons plus the median line) from the `summary`
#' element of the projection result.
#'
#' @param model Trained model from build_unified_projection_model; its
#'   `empirical_features` element names the columns required in
#'   `empirical_data`.
#' @param empirical_data Data frame with empirical national statistics. The
#'   plot additionally reads its `Year` and `global_deployment` columns.
#' @param last_year Final year for projection, forwarded to
#'   `project_trajectory()`.
#' @return List with `projections` (the `project_trajectory()` result) and
#'   `projection_plot` (a ggplot object).
make_projections <- function(model, empirical_data, last_year = 2050) {
  # Fail early when any feature column the model expects is absent
  absent_cols <- setdiff(model$empirical_features, names(empirical_data))
  if (length(absent_cols) > 0) {
    stop(paste("Empirical data is missing required columns:", 
               paste(absent_cols, collapse = ", ")))
  }
  
  message("Generating projections...")
  projected_trajectory <- project_trajectory(
    empirical_data,
    model,
    last_year = last_year
  )
  
  # Observed deployment, drawn as red points on top of the projection fan
  observed_points <- tibble(
    year = empirical_data$Year,
    y = empirical_data$global_deployment
  )
  
  plot_labels <- labs(
    title = "Technology Diffusion Projection",
    subtitle = "Based on unified model trained on diverse parameter distributions",
    x = "Year",
    y = "Global Deployment (share)"
  )
  
  # Layer order matters: wide ribbon, narrow ribbon, median, markers, points
  projection_plot <- ggplot(projected_trajectory$summary, aes(x = year)) +
    geom_ribbon(aes(ymin = p05, ymax = p95), fill = "lightblue", alpha = 0.3) +
    geom_ribbon(aes(ymin = p25, ymax = p75), fill = "lightblue", alpha = 0.5) +
    geom_line(aes(y = median), color = "darkblue", size = 1) +
    geom_vline(xintercept = empirical_data$Year, linetype = "dashed") +
    geom_point(data = observed_points,
               aes(x = year, y = y), color = "red", size = 3) +
    plot_labels +
    theme_minimal()
  
  list(
    projections = projected_trajectory,
    projection_plot = projection_plot
  )
}

#--------------------
# 7. Utility Functions
#--------------------

#' Summarise global and national market data
#'
#' Converts each national `Value` into its share of the national total, orders
#' the shares from largest to smallest, and reports them together with the sum
#' of the global `Value` column.
#'
#' @param global_data Global market data; must have a `Value` column
#' @param national_data National market data; must have a `Value` column
#' @return List with `global_total` (sum of `global_data$Value`) and
#'   `market_sizes` (national shares in descending order)
process_market_data <- function(global_data, national_data) {
  # Share of the national total, largest market first
  ranked_shares <- arrange(
    mutate(national_data, share = Value / sum(Value)),
    desc(share)
  )
  
  list(
    global_total = sum(global_data$Value),
    market_sizes = ranked_shares$share
  )
}

#' Build per-year empirical statistics for projection
#'
#' Aggregates national parameters by `Year` into medians and quartiles of
#' `K`, `L` and `G`, a country count and the summed `market_size`, then
#' attaches `global_deployment` for the matching `Year` from `global_data`.
#'
#' @param national_params Data frame with columns `Year`, `K`, `L`, `G` and
#'   `market_size`
#' @param global_data Data frame with columns `Year` and `global_deployment`
#' @return Data frame with one row per `Year`
prepare_empirical_data <- function(national_params, global_data) {
  # Per-year summary; column names follow the same format as curtailed_params
  yearly_stats <- summarise(
    group_by(national_params, Year),
    k_median = median(K),
    k_q1 = quantile(K, 0.25),
    k_q3 = quantile(K, 0.75),
    L_median = median(L),
    L_q1 = quantile(L, 0.25),
    L_q3 = quantile(L, 0.75),
    G_median = median(G),
    G_q1 = quantile(G, 0.25),
    G_q3 = quantile(G, 0.75),
    n_mature_countries = n(),
    mature_market_share = sum(market_size),
    .groups = "drop"
  )
  
  # Only the deployment series is taken from the global table
  global_series <- select(global_data, Year, global_deployment)
  
  left_join(yearly_stats, global_series, by = "Year")
}

#' Count mature countries and total their market size
#'
#' Keeps the rows whose `Maturity` is at or above the threshold and reports
#' how many there are and the sum of their `market_size` (missing sizes are
#' ignored).
#'
#' @param countries_data Data frame with `Maturity` and `market_size` columns
#' @param maturity_threshold Minimum `Maturity` for a country to count as
#'   mature
#' @return One-row tibble with `n_mature_countries` and `mature_market_share`
calculate_mature_country_stats <- function(countries_data, maturity_threshold = 0.5) {
  mature <- dplyr::filter(countries_data, Maturity >= maturity_threshold)
  
  tibble::tibble(
    n_mature_countries = nrow(mature),
    mature_market_share = sum(mature$market_size, na.rm = TRUE)
  )
}


#' Prepare training data from a diverse subset of simulation runs
#'
#' For every configuration found in `multi_config_results$global_results`, a
#' diverse subset of runs is chosen with `select_diverse_subset_simple()`. The
#' global and country results are then restricted to the chosen runs and passed
#' to `prepare_unified_training_data()`.
#'
#' Rows are kept by matching the (config_id, run) pair, so a run ID selected
#' for one configuration never pulls in a run with the same ID from another
#' configuration.
#'
#' @param multi_config_results Results from multi-configuration simulations:
#'   a list with `global_results` and `country_results` (both carrying
#'   `config_id` and `run` columns) and `config_features`
#' @param cutoff_years Years to use as cutoff points (forwarded unchanged)
#' @param runs_per_config Number of runs to select per configuration
#' @param debug Enable detailed logging
#' @return List with `training_data` and `diverse_selection`; the latter holds
#'   `selected_runs` (all selected run IDs, flattened) and `by_config`
#'   (selected run IDs keyed by configuration)
prepare_efficient_training_data_simplified <- function(multi_config_results, cutoff_years, 
                                                       runs_per_config = 75,
                                                       debug = TRUE) {
  # Setup timing to measure performance
  start_time <- Sys.time()
  
  # Step 1: Select diverse subset of runs
  if (debug) message("Starting diversity selection process")
  
  # Get all configurations
  all_configs <- unique(multi_config_results$global_results$config_id)
  
  # Create container for selected runs
  selected_runs_all <- list()
  
  # Process each configuration separately
  for (config_id in all_configs) {
    if (debug) message(sprintf("\nProcessing configuration %s", config_id))
    
    # Extract data for this configuration; `!!` forces the loop variable so it
    # is not confused with the column of the same name
    config_data <- list(
      global_results = multi_config_results$global_results %>% 
        filter(config_id == !!config_id),
      country_results = multi_config_results$country_results %>% 
        filter(config_id == !!config_id)
    )
    
    # Count runs for this configuration
    config_runs <- unique(config_data$global_results$run)
    n_config_runs <- length(config_runs)
    
    if (debug) message(sprintf("Configuration %s has %d total runs", 
                               config_id, n_config_runs))
    
    # Never ask for more runs than the configuration has
    n_clusters <- min(runs_per_config, n_config_runs)
    
    if (debug) message(sprintf("Selecting %d diverse runs from configuration %s", 
                               n_clusters, config_id))
    
    selected_runs <- select_diverse_subset_simple(
      all_runs = config_data$country_results,
      n_clusters = n_clusters,
      debug = debug
    )
    
    # Store results
    selected_runs_all[[as.character(config_id)]] <- selected_runs
    
    # Force garbage collection to free memory
    gc(verbose = FALSE)
  }
  
  # Flatten results for easier use
  all_selected_runs <- unlist(selected_runs_all)
  
  if (debug) {
    message(sprintf("\nSelected %d total runs across %d configurations", 
                    length(all_selected_runs), length(all_configs)))
  }
  
  # Step 2: Create filtered version of results with only selected runs
  if (debug) message("\nFiltering results to include only selected runs")
  
  # Composite "config<sep>run" keys: filtering on run alone would also keep
  # unselected runs from other configurations that reuse the same run ID.
  key_sep <- "\r"
  selected_keys <- unlist(lapply(names(selected_runs_all), function(cfg) {
    cfg_runs <- selected_runs_all[[cfg]]
    if (length(cfg_runs) == 0) {
      character(0)
    } else {
      paste(cfg, cfg_runs, sep = key_sep)
    }
  }))
  
  keep_selected <- function(results) {
    results %>%
      filter(paste(config_id, run, sep = !!key_sep) %in% !!selected_keys)
  }
  
  filtered_results <- list(
    global_results = keep_selected(multi_config_results$global_results),
    country_results = keep_selected(multi_config_results$country_results),
    config_features = multi_config_results$config_features # Pass through unchanged
  )
  
  if (debug) {
    message(sprintf("Filtered from original dataset:"))
    message(sprintf("Global results: %d -> %d rows (%.1f%%)", 
                    nrow(multi_config_results$global_results), 
                    nrow(filtered_results$global_results),
                    100 * nrow(filtered_results$global_results) / 
                      nrow(multi_config_results$global_results)))
    message(sprintf("Country results: %d -> %d rows (%.1f%%)", 
                    nrow(multi_config_results$country_results), 
                    nrow(filtered_results$country_results),
                    100 * nrow(filtered_results$country_results) / 
                      nrow(multi_config_results$country_results)))
  }
  
  # Step 3: Run the regular training data preparation on filtered subset
  if (debug) message("\nPreparing training data from filtered results")
  
  training_data <- prepare_unified_training_data(
    multi_config_results = filtered_results,
    cutoff_years = cutoff_years,
    debug_mode = debug
  )
  
  # Report total time
  end_time <- Sys.time()
  if (debug) {
    message(sprintf("\nTotal processing time: %.1f seconds", 
                    as.numeric(difftime(end_time, start_time, units = "secs"))))
  }
  
  list(
    training_data = training_data,
    diverse_selection = list(
      selected_runs = all_selected_runs,
      by_config = selected_runs_all
    )
  )
}

#' Select diverse subset with simplified approach
#'
#' Builds one feature vector per run with `create_run_feature_vector_simple()`,
#' standardises the features, projects them onto at most `max_features`
#' principal components, clusters the runs with k-means and returns, for each
#' cluster, the run closest to the cluster centre.
#'
#' Falls back to a random subset (with a warning) when fewer than 10 runs have
#' complete features, when no feature has any variance, or when k-means fails.
#' Note that k-means is asked for at most `nrow - 1` centres, so the clustered
#' path can return fewer than `n_clusters` runs.
#'
#' @param all_runs Data frame with run data; must have a `run` column
#' @param n_clusters Number of clusters to create
#' @param max_features Maximum number of features to use (for dimensionality)
#' @param debug Enable detailed logging
#' @return Vector of selected run IDs. The clustered path returns them as
#'   character; the random fallbacks return them in the type of
#'   `all_runs$run`.
select_diverse_subset_simple <- function(all_runs, n_clusters = 50, 
                                         max_features = 10,
                                         debug = FALSE) {
  # Random fallback. Indexing through sample.int() avoids the sample() trap
  # where a single numeric ID `x` is silently treated as the range 1:x.
  pick_random <- function(ids, size) {
    ids[sample.int(length(ids), min(size, length(ids)))]
  }
  
  # Get unique run IDs
  run_ids <- unique(all_runs$run)
  n_total_runs <- length(run_ids)
  
  # Adjust n_clusters if fewer runs than requested clusters
  if (n_total_runs < n_clusters) {
    if (debug) message(sprintf("Only %d runs available, reducing clusters to match", 
                               n_total_runs))
    n_clusters <- n_total_runs
  }
  
  if (debug) message(sprintf("Processing %d runs for diversity analysis", n_total_runs))
  
  # Extract feature vector for each run; the real width is known only after
  # the first run has been processed
  feature_matrix <- matrix(NA, nrow = n_total_runs, ncol = 0)
  feature_names <- NULL
  
  # Report progress roughly every 10% to avoid message overload
  report_step <- max(1, floor(n_total_runs / 10))
  
  for (i in seq_along(run_ids)) {
    if (debug && i %% report_step == 0) {
      message(sprintf("Processing run %d of %d (%.1f%%)", 
                      i, n_total_runs, 100 * i / n_total_runs))
    }
    
    run_data <- all_runs %>% filter(run == run_ids[i])
    feature_vector <- create_run_feature_vector_simple(run_data)
    
    # Store feature names from first run
    if (i == 1) {
      feature_names <- names(feature_vector)
      feature_matrix <- matrix(NA, nrow = n_total_runs, ncol = length(feature_vector))
      colnames(feature_matrix) <- feature_names
    }
    
    feature_matrix[i, ] <- as.numeric(feature_vector)
  }
  
  # Treat infinite values as missing
  feature_matrix[is.infinite(feature_matrix)] <- NA
  
  # Check if enough complete cases to proceed
  complete_rows <- complete.cases(feature_matrix)
  if (sum(complete_rows) < 10) {
    warning("Too few complete cases for clustering, returning random subset")
    return(pick_random(run_ids, n_clusters))
  }
  
  # Use only complete cases for clustering
  reduced_feature_matrix <- feature_matrix[complete_rows, , drop = FALSE]
  reduced_run_ids <- run_ids[complete_rows]
  
  if (debug) message(sprintf("Using %d complete runs for clustering", 
                             nrow(reduced_feature_matrix)))
  
  # Standardise features
  feature_matrix_scaled <- scale(reduced_feature_matrix)
  
  # Handle columns with no variance (all NaN after scaling)
  na_cols <- colSums(is.na(feature_matrix_scaled)) > 0
  if (any(na_cols)) {
    if (debug) message(sprintf("Removing %d columns with no variance", sum(na_cols)))
    feature_matrix_scaled <- feature_matrix_scaled[, !na_cols, drop = FALSE]
    
    # If all columns were removed, return random subset
    if (ncol(feature_matrix_scaled) == 0) {
      warning("No valid features for clustering, returning random subset")
      return(pick_random(reduced_run_ids, n_clusters))
    }
  }
  
  # Apply PCA but limit features to avoid excessive dimensionality
  pca_result <- tryCatch({
    prcomp(feature_matrix_scaled)
  }, error = function(e) {
    warning("PCA failed, using original features")
    NULL
  })
  
  if (!is.null(pca_result)) {
    # Bound by the score columns actually produced, which can be fewer than
    # the number of input features
    components_to_keep <- min(max_features, ncol(pca_result$x))
    reduced_features <- pca_result$x[, seq_len(components_to_keep), drop = FALSE]
    
    if (debug) message(sprintf("Using %d principal components", components_to_keep))
  } else {
    # Use all features but cap at max_features
    n_keep <- min(ncol(feature_matrix_scaled), max_features)
    reduced_features <- feature_matrix_scaled[, seq_len(n_keep), drop = FALSE]
  }
  
  # Apply k-means clustering with simplified parameters. The handler yields
  # NULL: a return() inside it would only leave the handler, not this function.
  k_means_result <- tryCatch({
    kmeans(reduced_features, 
           centers = min(n_clusters, nrow(reduced_features) - 1),
           iter.max = 20,  # Limit iterations for speed
           nstart = 5)     # Fewer starts for speed
  }, error = function(e) {
    warning("K-means clustering failed, returning random subset")
    NULL
  })
  
  if (is.null(k_means_result)) {
    return(pick_random(reduced_run_ids, n_clusters))
  }
  
  # Select the member closest to the centre of each cluster
  selected_runs <- character(n_clusters)
  valid_count <- 0
  
  for (i in seq_len(max(k_means_result$cluster))) {
    cluster_members <- which(k_means_result$cluster == i)
    
    if (length(cluster_members) == 0) {
      next  # Skip empty clusters
    }
    
    if (length(cluster_members) == 1) {
      # Only one member - select it
      closest_idx <- cluster_members
    } else {
      # Multiple members - select the one closest to centroid
      cluster_center <- k_means_result$centers[i, ]
      distances <- apply(reduced_features[cluster_members, , drop = FALSE], 1, 
                         function(x) sum((x - cluster_center)^2))
      closest_idx <- cluster_members[which.min(distances)]
    }
    
    valid_count <- valid_count + 1
    selected_runs[valid_count] <- reduced_run_ids[closest_idx]
  }
  
  # Remove unused slots
  if (valid_count < length(selected_runs)) {
    selected_runs <- selected_runs[seq_len(valid_count)]
  }
  
  if (debug) message(sprintf("Selected %d diverse runs", length(selected_runs)))
  
  selected_runs
}

#' Create run feature vector with simplified calculations
#'
#' Collapses a run to one row per country and derives 13 summary features:
#' market-size-weighted parameter averages for the largest third of countries,
#' Spearman correlations of market size with `k`, `L` and `takeoff_year`,
#' large-vs-small ratios of mean `k` and `L`, the spread of takeoff years, the
#' market share of the three earliest-takeoff countries, coefficients of
#' variation of `k` and `L`, and a global maximum deployment figure.
#'
#' When the run has fewer than 5 countries the same 13 names are returned with
#' every value `NA`.
#'
#' @param run_data Data frame with run data; must have columns `country`,
#'   `market_size`, `k`, `L`, `takeoff_year`, `t0` and `deployment`. If a
#'   column whose name contains "global_deployment" exists, its maximum is
#'   used for `global_max_deployment`.
#' @return Named numeric vector of features (length 13)
create_run_feature_vector_simple <- function(run_data) {
  feature_names <- c(
    "large_country_avg_k", "large_country_avg_L", "large_country_avg_takeoff",
    "corr_size_k", "corr_size_L", "corr_size_takeoff",
    "large_vs_small_k_ratio", "large_vs_small_L_ratio",
    "takeoff_spread", "leading_market_share",
    "cv_k", "cv_L", "global_max_deployment"
  )
  
  # One row per country, largest market first
  country_features <- run_data %>%
    group_by(country) %>%
    summarize(
      market_size = first(market_size),
      k = first(k),
      L = first(L),
      takeoff_year = first(takeoff_year),
      t0 = first(t0),
      max_deployment = max(deployment),
      .groups = "drop"
    ) %>%
    arrange(desc(market_size))
  
  # Too few countries: return NAs that still carry the feature names
  if (nrow(country_features) < 5) {
    return(stats::setNames(rep(NA_real_, length(feature_names)), feature_names))
  }
  
  # Size terciles. Explicit levels keep the labels aligned even when ntile()
  # yields fewer than three distinct buckets (e.g. missing market sizes).
  country_features <- country_features %>%
    mutate(
      size_category = factor(ntile(market_size, 3),
                             levels = 1:3,
                             labels = c("Small", "Medium", "Large"))
    )
  
  # Spearman correlation of market size with a parameter; NA on error
  spearman_with_size <- function(values) {
    tryCatch(
      cor(country_features$market_size, values, method = "spearman"),
      error = function(e) NA_real_
    )
  }
  
  # Mean of a column within a size group; NA when the group is empty
  group_mean <- function(group, column) {
    if (nrow(group) > 0) mean(group[[column]], na.rm = TRUE) else NA_real_
  }
  
  # Ratio that is NA when either side is missing or the denominator is zero
  safe_ratio <- function(numerator, denominator) {
    if (!is.na(numerator) && !is.na(denominator) && denominator != 0) {
      numerator / denominator
    } else {
      NA_real_
    }
  }
  
  # Coefficient of variation; NA when the mean is missing or zero
  coef_variation <- function(values) {
    tryCatch({
      centre <- mean(values, na.rm = TRUE)
      if (is.na(centre) || centre == 0) NA_real_ else sd(values, na.rm = TRUE) / centre
    }, error = function(e) NA_real_)
  }
  
  # Size-parameter correlations
  corr_size_k <- spearman_with_size(country_features$k)
  corr_size_L <- spearman_with_size(country_features$L)
  corr_size_takeoff <- spearman_with_size(country_features$takeoff_year)
  
  # Growth pattern by size category
  large_countries <- filter(country_features, size_category == "Large")
  small_countries <- filter(country_features, size_category == "Small")
  
  large_vs_small_k_ratio <- safe_ratio(group_mean(large_countries, "k"),
                                       group_mean(small_countries, "k"))
  large_vs_small_L_ratio <- safe_ratio(group_mean(large_countries, "L"),
                                       group_mean(small_countries, "L"))
  
  # Dispersion metrics
  takeoff_range <- tryCatch(
    diff(range(country_features$takeoff_year, na.rm = TRUE)),
    error = function(e) NA_real_
  )
  cv_k <- coef_variation(country_features$k)
  cv_L <- coef_variation(country_features$L)
  
  # Market share held by the three countries that took off first
  leading_countries <- country_features %>%
    arrange(takeoff_year) %>%
    head(3)
  
  leading_market_share <- sum(leading_countries$market_size, na.rm = TRUE) / 
    sum(country_features$market_size, na.rm = TRUE)
  
  # Market-size-weighted averages for large countries
  if (nrow(large_countries) > 0) {
    large_country_avg_k <- weighted.mean(large_countries$k, large_countries$market_size, na.rm = TRUE)
    large_country_avg_L <- weighted.mean(large_countries$L, large_countries$market_size, na.rm = TRUE)
    large_country_avg_takeoff <- weighted.mean(large_countries$takeoff_year, large_countries$market_size, na.rm = TRUE)
  } else {
    large_country_avg_k <- NA_real_
    large_country_avg_L <- NA_real_
    large_country_avg_takeoff <- NA_real_
  }
  
  # Global deployment: use a global_deployment column when one is present,
  # otherwise fall back to size-weighted per-country maxima
  global_cols <- grep("global_deployment", names(run_data), value = TRUE)
  
  if (length(global_cols) > 0) {
    global_max_deployment <- max(run_data[[global_cols[1]]], na.rm = TRUE)
  } else {
    global_max_deployment <- sum(country_features$max_deployment * country_features$market_size, na.rm = TRUE)
  }
  
  # Assemble in the order given by feature_names
  c(
    large_country_avg_k = large_country_avg_k,
    large_country_avg_L = large_country_avg_L,
    large_country_avg_takeoff = large_country_avg_takeoff,
    corr_size_k = corr_size_k,
    corr_size_L = corr_size_L,
    corr_size_takeoff = corr_size_takeoff,
    large_vs_small_k_ratio = large_vs_small_k_ratio,
    large_vs_small_L_ratio = large_vs_small_L_ratio,
    takeoff_spread = takeoff_range,
    leading_market_share = leading_market_share,
    cv_k = cv_k,
    cv_L = cv_L,
    global_max_deployment = global_max_deployment
  )
}