# Combine the product population with the breeding population for the rotation.
# This way we compare the product populations with reference to the parents 

# Stack the six input tables row-wise, in this fixed order, into one table
# that the rotation below is computed on.
df_pd <- do.call(
  rbind,
  list(df_mean, df_dcross, df_dcross25, df_index, df_index25, df_pi)
)

#---- Get environment-specific GVs (env_gvs) and derive rotated slopes based on
#---- a rotation of the parents and product development populations combined ----

# Environment-specific GVs: the first k columns of df_pd multiplied by the
# transpose of covs_tpe, giving one row per row of df_pd and one column per
# row of covs_tpe.
env_gv_pd <- as.matrix(df_pd[, c(1:k)]) %*% t(covs_tpe)
# Centre every column on its mean over ALL rows of df_pd (the combined
# populations), not on the mean of a single sub-population.
env_gv_pd_cent <- scale(env_gv_pd, scale = FALSE)
# (Co)variance of the centred GVs between columns (environments).
env_gv_var <- var(env_gv_pd_cent)

# --- SVD of variance-covariance matrix ---
svd_decomp <- svd(env_gv_var)
U_pd <- svd_decomp$u[, 1:k, drop = FALSE]
# nrow = k is required: diag() of a length-one vector builds an identity
# matrix of that size rather than a 1 x 1 matrix, which breaks k == 1.
D_pd <- diag(svd_decomp$d[1:k], nrow = k)

# Get latent covariates (H) and slopes (S) so that S %*% t(H) reproduces the
# centred GVs (exactly when the first k singular values carry all variance).
H_pd <- U_pd %*% sqrt(D_pd)                             # Covariates (PCA rotation)
S_pd <- env_gv_pd_cent %*% U_pd %*% solve(sqrt(D_pd))   # Slopes (PCA rotation)
# Check for equivalence. NOTE(review): this value is only displayed when the
# script is run interactively or echoed; it is not enforced.
range(S_pd %*% t(H_pd) - env_gv_pd_cent)

# ---- Rotate the slopes and covariates around the TPE mean with drotate machinery ----
# First axis: unit vector along the column means of H_pd.
m <- colMeans(H_pd)
r1 <- as.numeric(m / sqrt(sum(m^2)))
# Remaining axes: an orthonormal basis of the complement of r1, taken as the
# eigenvectors of the projector (I - r1 r1') with non-zero eigenvalue.
P_perp <- diag(ncol(H_pd)) - r1 %*% t(r1)
ee <- eigen(P_perp, symmetric = TRUE)
R_rest <- ee$vectors[, ee$values > 1e-8, drop = FALSE]
R <- cbind(r1, R_rest)
H_rot <- H_pd %*% R
range(H_rot %*% t(H_rot) - env_gv_var) # check for equivalence

# Now gather rotated covariates and slopes: D_pd_dr holds the squared column
# norms of H_rot, so H_pd_dr has unit-norm columns and the slopes absorb the
# scale. nrow is given explicitly for the same k == 1 reason as D_pd above.
D_pd_dr <- diag(diag(t(H_rot) %*% H_rot), nrow = ncol(H_rot))
H_pd_dr <- H_rot %*% solve(sqrt(D_pd_dr))  # new covariates
S_pd_dr <- S_pd %*% R %*% sqrt(D_pd_dr)    # new slopes
range(S_pd_dr %*% t(H_pd_dr) - env_gv_pd_cent) # check for equivalence


#---- Calculate metrics and get best release for each strategy within each rmsd window ----

# Per-row metrics for every row of df_pd:
#   main_effect  mean GV across environments (uncentred GVs)
#   rot_rmsd     root mean square, across environments, of the GV part that is
#                reconstructed from every rotated axis except the first
#   var_raw      variance across environments of the centred GVs
#   var_std      same, after additionally scaling each environment column
#   strat        strategy label carried over from df_pd
# followed by the rotated slopes themselves (one column per axis).
gv_df <- data.frame(
  main_effect = rowMeans(env_gv_pd),
  rot_rmsd = sqrt(rowMeans((S_pd_dr[, -1] %*% t(H_pd_dr[, -1]))^2)),
  var_raw = apply(env_gv_pd_cent, 1, var),
  var_std = apply(scale(env_gv_pd_cent), 1, var),
  strat = df_pd$strat,
  S_pd_dr
)

# Drop the rows labelled "popImp"; only the remaining strategies are compared.
gv_df <- gv_df[gv_df$strat != "popImp", , drop = FALSE]




#---------------------------------------- record performance of strategies


#=====================================================================================#
# 1. Mean of the top five genotypes per strategy × stability percentile (record values)
#======================================================================================#

# Strategy labels present in gv_df, in order of first appearance.
strats <- unique(gv_df[["strat"]])

# RMSD cut-offs: quantiles of pop_df$rot_rmsd at the requested percentiles.
# rmsd_percentiles is divided by 100 to obtain the probabilities quantile()
# expects. (pop_df is presumably the parent population -- confirm upstream.)
rmsd_windows <- quantile(
  pop_df$rot_rmsd,
  probs = rmsd_percentiles / 100,
  na.rm = TRUE
)


# Average the `top` best rows of `subdf`, ranked by descending main_effect.
#
# Args:
#   subdf: data frame with a numeric `main_effect` column; may have zero rows.
#   top:   maximum number of rows to average (fewer are used if subdf is
#          shorter).
#
# Returns:
#   A one-row data frame with the columns of `subdf` in their original order
#   plus `n_top`, the number of rows actually averaged. Numeric columns hold
#   the mean over the selected rows (NAs removed); non-numeric columns are NA
#   and are expected to be filled in by the caller. For a zero-row `subdf`
#   every column is NA and `n_top` is 0.
#
# The column template is taken from `subdf` itself rather than from a global
# data frame, so the function works on any input and can be used in isolation.
make_top_mean_row <- function(subdf, top = 5) {
  all_cols <- names(subdf)

  if (nrow(subdf) == 0) {
    # One all-NA row with the same columns as the input
    na_row <- as.data.frame(lapply(subdf[0, , drop = FALSE], function(x) NA))
    na_row$n_top <- 0L
    return(na_row[1, , drop = FALSE])
  }

  # Rank by main_effect (best first) and keep at most `top` rows
  ord <- order(subdf$main_effect, decreasing = TRUE)
  topn <- subdf[ord[seq_len(min(length(ord), top))], , drop = FALSE]

  # Average the numeric columns only
  num_cols <- vapply(topn, is.numeric, logical(1))
  mean_vec <- colMeans(topn[, num_cols, drop = FALSE], na.rm = TRUE)
  out <- as.data.frame(t(mean_vec), stringsAsFactors = FALSE)

  # Re-add the non-numeric columns as NA placeholders, then restore the
  # original column order
  for (mc in setdiff(all_cols, names(out))) {
    out[[mc]] <- NA
  }
  out <- out[, all_cols, drop = FALSE]

  # Record how many rows went into the average
  out$n_top <- nrow(topn)
  out
}


# Build release_top: for every RMSD window and every strategy, the averaged
# top-N row produced by make_top_mean_row().
results_top <- list()

for (i in seq_along(rmsd_windows)) {
  r <- rmsd_windows[i]
  percentile_label <- paste0(rmsd_percentiles[i])

  # Candidates whose rotated RMSD falls strictly below this window's cut-off
  filtered <- gv_df[gv_df$rot_rmsd < r, , drop = FALSE]

  # One averaged row per strategy, tagged with the strategy and the window
  strat_rows <- lapply(strats, function(s) {
    row_mean <- make_top_mean_row(
      filtered[filtered$strat == s, , drop = FALSE],
      top = 5  # <--- set your top value here
    )
    row_mean$strat <- s
    row_mean$rmsd_percentile <- percentile_label
    row_mean$rmsd_threshold <- r
    row_mean
  })
  top_means_by_strat <- do.call(rbind, strat_rows)

  results_top[[percentile_label]] <- top_means_by_strat
}

# Stack all windows and discard the row names rbind derives from list names
release_top <- do.call(rbind, results_top)
rownames(release_top) <- NULL

# Move the annotation columns to the end
ann_cols <- c("n_top", "rmsd_percentile", "rmsd_threshold")
release_top <- release_top[
  ,
  c(setdiff(names(release_top), ann_cols), ann_cols)
]

# Tag the top-5 release table with the current cycle and replicate
release_top$cycle <- cycle
release_top$Rep <- Rep

# Append to the global release collector
release_top_full <- rbind(release_top_full, release_top)



#=====================================================================================#
# 2. Calculate average correlation between environments in PD and PI populations
#======================================================================================#

# The same six tables that were stacked into df_pd, keyed by a short label
df_list <- list(
  mean = df_mean,
  dcross = df_dcross,
  dcross25 = df_dcross25,
  index = df_index,
  index25 = df_index25,
  pi = df_pi
)

# For each table: project onto covs_tpe, turn the between-environment
# covariance into a correlation matrix and average its upper triangle.
set_names <- names(df_list)
cor_rows <- vector("list", length(set_names))

for (j in seq_along(set_names)) {
  name <- set_names[j]
  mat <- as.matrix(df_list[[name]][, 1:k]) %*% t(covs_tpe)
  cor_mat <- cov2cor(var(mat))
  mean_cor <- mean(cor_mat[upper.tri(cor_mat)])

  cor_rows[[j]] <- data.frame(
    strat = name,
    mean_cor = mean_cor,
    cycle = cycle,
    Rep = Rep
  )
}

# Bind once, starting from the same typed empty frame as before
cor_means <- rbind(
  data.frame(
    strat = character(),
    mean_cor = numeric(),
    cycle = numeric(),
    Rep = numeric(),
    stringsAsFactors = FALSE
  ),
  do.call(rbind, cor_rows)
)

# Append to the global correlation collector
cor_stats <- rbind(cor_stats, cor_means)



#=====================================================================================#
# 3. Collect general summaries on PD and PI populations
#======================================================================================#


#---- Collect data on population improvement (general summary) ----#

# One-row summary of pop_df for the current cycle and Rep: the means of its
# metric columns plus the variances of main_effect and rot_rmsd.
pop_sum <- data.frame(
  cycle = cycle,
  Rep = Rep,
  main_effect = mean(pop_df$main_effect),
  main_effect_var  = var(pop_df$main_effect),
  # NOTE(review): pca_op repeats mean(pop_df$main_effect), so it always equals
  # the main_effect field above. This looks like a copy-paste slip --
  # presumably it should read a pca_op column of pop_df; confirm that column
  # exists before changing it.
  pca_op = mean(pop_df$main_effect),
  var_raw = mean(pop_df$var_raw),
  var_scale = mean(pop_df$var_scale),
  pca_rmsd = mean(pop_df$pca_rmsd),
  rot_rmsd = mean(pop_df$rot_rmsd),
  rot_rmsd_var = var(pop_df$rot_rmsd)
)

# Append this one-row summary to the global pop_stats collector
pop_stats <- rbind(pop_stats, pop_sum)



#----- Collect data on PD populations (general summary) -----#

# Percentile (as a proportion) and RMSD cut-off for each window. The
# percentiles are taken directly from rmsd_percentiles instead of being parsed
# back out of the formatted names of the quantile() result, which is lossy
# (names are rounded) and depends on number formatting.
rmsd_table <- data.frame(
  rmsd_percentile = as.numeric(rmsd_percentiles) / 100,
  threshold = as.numeric(rmsd_windows),
  stringsAsFactors = FALSE
)

# Collect one summary row per strategy x window in a pre-sized list and bind
# them to the collector once, rather than re-copying pd_stats on every pass.
pd_strats <- unique(gv_df$strat)
pd_rows <- vector("list", length(pd_strats) * nrow(rmsd_table))
row_id <- 0L

# Loop over strategies
for (s in pd_strats) {

  sub_df <- gv_df[gv_df$strat == s, ]

  # Loop over rmsd thresholds
  for (i in seq_len(nrow(rmsd_table))) {

    p_val <- rmsd_table$rmsd_percentile[i]
    threshold <- rmsd_table$threshold[i]

    # Individuals of this strategy strictly below the cut-off
    filtered <- sub_df[sub_df$rot_rmsd < threshold, ]
    n_sel <- nrow(filtered)

    # If no individuals pass, record NA rather than the NaN mean() would give
    if (n_sel == 0) {
      mean_main <- NA_real_
      mean_rmsd <- NA_real_
      mean_var <- NA_real_
    } else {
      mean_main <- mean(filtered$main_effect, na.rm = TRUE)
      mean_rmsd <- mean(filtered$rot_rmsd, na.rm = TRUE)
      mean_var <- mean(filtered$var_std, na.rm = TRUE)
    }

    row_id <- row_id + 1L
    pd_rows[[row_id]] <- data.frame(
      cycle = cycle,
      Rep = Rep,
      strat = s,
      rmsd_percentile = p_val,
      n_selected = n_sel,
      mean_main_effect = mean_main,
      mean_rot_rmsd = mean_rmsd,
      mean_var_std = mean_var
    )
  }
}

# Append to the global PD stats collector (left untouched if nothing was built)
if (row_id > 0L) {
  pd_stats <- rbind(pd_stats, do.call(rbind, pd_rows))
}


