######################################################################################################
# Code for manuscript: Chia et al. Enclosed bird nests driven by predation and thermoregulation
# 2024 November
# This script reads fitted model files and makes the result plots

## INPUTS
# Species trait/environmental data with climatic PCA "data/all_traits_pca.csv"
# Fitted models with consensus tree (main analysis) "output/model_*.rds"
# Fitted coefficients with 1,000 trees (supplementary analysis) "output/multitrees_coef_*.rds"

## OUTPUTS
# Coefficients plots: Figure 3, S6
# Partial effect plots: Figure S3-5, S8-10
######################################################################################################

# required library
library(dplyr)
library(ggplot2)
library(ggpubr) # ggarrange

#-------------------------------------------------
# Function: Coefficient plots
#-------------------------------------------------
# Summarise coefficient draws (one row per draw, one column per model term) into a
# data frame for plot_coef().
#   coef  : matrix/data frame of coefficient draws with term names as column names
#   group : "Passerines", "Non-passerines" or "All"
#   nest  : nest-type label stored in the `nest` column
#   ci    : width of the percentile interval in percent (e.g. 95)
#   terms : model terms to keep, in plotting order (top to bottom)
# Returns one row per term with columns mean, lower, upper, y (term as an ordered factor),
# sig (1 if the interval excludes zero, else 0), group (factor) and nest.
df_for_coef_plot <- function(coef, group, nest, ci,
                             terms = c("npp", "npp:Ground", "npp:Cooperative", "npp:Clutch",
                                       "PC1", "PC1:Egg", "I(PC2^2)", "PC2", "I(PC2^2):Egg", "PC2:Egg",
                                       "Ground", "Cooperative", "Clutch", "Egg", "Migration")) {
  group_levels <- c("Passerines", "Non-passerines", "All")
  stopifnot(is.numeric(ci), length(ci) == 1, ci > 0, ci <= 100)
  if (!(length(group) == 1 && group %in% group_levels)) {
    stop("`group` must be one of: ", paste(group_levels, collapse = ", "), call. = FALSE)
  }
  # Fail loudly instead of silently producing NA rows for terms that are not in `coef`
  missing_terms <- setdiff(terms, colnames(coef))
  if (length(missing_terms) > 0) {
    stop("coefficient columns not found: ", paste(missing_terms, collapse = ", "), call. = FALSE)
  }

  x <- (1 - 0.01 * ci) / 2  # lower-tail probability, e.g. 0.025 for ci = 95
  coef <- coef[, terms, drop = FALSE]
  dt <- data.frame(mean = apply(coef, 2, mean),
                   lower = apply(coef, 2, quantile, x),
                   upper = apply(coef, 2, quantile, 1 - x))
  dt$y <- factor(rownames(dt), levels = rownames(dt))
  # Significant when the interval excludes zero; a bound lying exactly on 0 counts as
  # "includes zero" (the product test lower * upper < 0 got this edge case wrong)
  dt$sig <- as.numeric(dt$lower > 0 | dt$upper < 0)
  dt$group <- factor(group, levels = group_levels)
  dt$nest <- nest
  dt
}

# Coefficient plot: one row per model term, with group-coloured percentile intervals
# dodged vertically and the mean drawn as a point (filled black when sig == 0,
# white when sig == 1).
#   dt.plot : output of df_for_coef_plot() (possibly several groups rbind-ed together)
#   xlim    : x-axis range passed to coord_cartesian()
#   ylab    : if FALSE, the term labels on the y axis are hidden
# Returns a ggplot object.
plot_coef <- function(dt.plot, xlim, ylab = TRUE) {
  # Colours keyed by group name so each group keeps its colour even when some
  # groups are absent from dt.plot
  group_cols <- c("Passerines" = "#50a9b3", "Non-passerines" = "#b37350", "All" = "gray40")
  dodge <- position_dodge(width = .5)
  p <- ggplot(dt.plot) +
    # constant reference line: drawn once rather than once per data row
    geom_vline(xintercept = 0, color = "black", linetype = 1, linewidth = .2) +
    geom_linerange(aes(xmin = lower, xmax = upper, y = y, color = group),
                   position = dodge, linewidth = 2.3, lineend = "round") +
    scale_color_manual(values = group_cols) +
    geom_point(aes(x = mean, y = y, color = group, fill = factor(sig)),
               position = dodge, shape = 21, size = 1.8, stroke = 0) +
    scale_fill_manual(values = c("0" = "black", "1" = "white")) +
    scale_y_discrete(limits = rev) +
    coord_cartesian(xlim = xlim) +
    theme_classic() +
    theme(panel.grid = element_blank(),
          text = element_text(size = 12),
          axis.line.y = element_blank(),
          axis.ticks.y = element_blank(),
          axis.title = element_blank(),
          legend.position = "none",
          panel.background = element_rect(fill = 'transparent'),
          plot.background = element_rect(fill = 'transparent', color = NA))
  if (isFALSE(ylab)) p <- p + theme(axis.text.y = element_blank())
  p
}

#-------------------------------------------------
# Function: Partial effect plots
#-------------------------------------------------
# Keep species with a defined response (Y or N is TRUE), retain only the predictor
# columns, and z-score each predictor within the retained species.
#   data : data frame with logical columns Y and N plus the predictor columns below
# Returns a data frame containing only the standardized predictors.
process_dataset <- function(data) {
  vars <- c("Migration", "Egg", "Ground", "Cooperative", "Clutch", "npp", "PC1", "PC2")
  data %>%
    filter(Y + N > 0) %>%
    select(all_of(vars)) %>%
    # scale() returns a one-column matrix; as.vector() turns it back into a plain column
    mutate(across(all_of(vars), function(x) as.vector(scale(x))))
}

# Build the binary response for one nest type and return the standardized predictors.
#   data : species trait table with columns Enclosed, Dome, Cavity, Open and the predictors
#   nest : "enclosed", "dome" or "cavity"
# Y marks species whose count for the focal nest column is > 0; N marks species with
# Open > 0 and Enclosed == 0. Species with neither are dropped by process_dataset().
subset_data <- function(data, nest) {
  stopifnot(is.character(nest), length(nest) == 1)
  focal <- switch(nest,
                  enclosed = "Enclosed",
                  dome = "Dome",
                  cavity = "Cavity",
                  stop("`nest` must be \"enclosed\", \"dome\" or \"cavity\", not \"",
                       nest, "\"", call. = FALSE))
  data %>%
    mutate(Y = .data[[focal]] > 0,
           N = Open > 0 & Enclosed == 0) %>%
    process_dataset()
}

# Predicted probability (inverse-logit of the linear predictor) for every coefficient draw,
# summarised across draws.
#   coef : matrix/data frame of coefficient draws (rows = draws, 16 columns), or a single
#          numeric vector of 16 coefficients
#   vm   : one-row data frame of reference values; predictors not supplied explicitly
#          are held at these values
#   npp ... clutch : predictor values (scalars or equal-length vectors; scalars are recycled)
#   ci   : width of the percentile interval in percent (default 95 -> 2.5% / 97.5%)
# Returns a matrix with columns mean, upper, lower and one row per prediction point.
pred_partial <- function(coef, vm, npp = vm$npp, PC1 = vm$PC1, PC2 = vm$PC2, ground = vm$Ground, coop = vm$Cooperative, 
                         migrate = vm$Migration, mass = vm$Egg, clutch = vm$Clutch, ci = 95) {
  # Design matrix. Its column order must match the column order of `coef`
  # (NOTE(review): assumed to be intercept, npp, PC1, PC2, I(PC2^2), Ground, Cooperative,
  # Migration, Egg, Clutch, npp:Ground, npp:Cooperative, npp:Clutch, PC1:Egg, PC2:Egg,
  # I(PC2^2):Egg, i.e. the fitted-model term order; confirm against the model objects).
  # cbind() is used instead of data.frame(): data.frame() warns ("row names were found
  # from a short variable") when a named scalar such as quantile(x, 0.1) is recycled.
  X <- cbind(intercept = 1, npp = npp, PC1 = PC1, PC2 = PC2, PC2sq = PC2^2,
             ground = ground, coop = coop, migrate = migrate, mass = mass, clutch = clutch,
             npp_ground = npp * ground, npp_coop = npp * coop, npp_clutch = npp * clutch,
             PC1_mass = PC1 * mass, PC2_mass = PC2 * mass, PC2sq_mass = PC2^2 * mass)
  B <- if (is.null(dim(coef))) matrix(coef, nrow = 1) else as.matrix(coef)
  if (ncol(B) != ncol(X)) {
    stop("`coef` has ", ncol(B), " columns but the design matrix has ", ncol(X), call. = FALSE)
  }
  # rows = prediction points, columns = coefficient draws
  prob <- stats::plogis(tcrossprod(X, B))
  alpha <- (100 - ci) / 200
  fit <- apply(prob, 1, mean)
  hi <- apply(prob, 1, quantile, 1 - alpha)
  lo <- apply(prob, 1, quantile, alpha)
  cbind(mean = fit, upper = hi, lower = lo)
}

# Plot two predicted-probability curves with their intervals against one predictor.
#   df     : data frame with column `x` and matrix columns `y1`, `y2`, each laid out as
#            returned by pred_partial() (column 1 = mean, 2 = upper, 3 = lower)
#   x      : name of the predictor column
#   y1, y2 : names of the two prediction columns (drawn blue and orange respectively)
# Returns a ggplot object with the y axis fixed to [0, 1].
plot_curve_single <- function(df, x, y1, y2) {
  mycol <- c("a" = "dodgerblue", "b" = "darkorange")
  pred1 <- df[[y1]]
  pred2 <- df[[y2]]
  dfplot <- data.frame(x = df[[x]], y1 = pred1[, 1], y2 = pred2[, 1],
                       y1max = pred1[, 2], y1min = pred1[, 3],
                       y2max = pred2[, 2], y2min = pred2[, 3])
  ggplot(dfplot, mapping = aes(x = x)) + 
    geom_ribbon(mapping = aes(ymax = y1max, ymin = y1min, fill = "a"), alpha = 0.15, show.legend = FALSE) +
    geom_ribbon(mapping = aes(ymax = y2max, ymin = y2min, fill = "b"), alpha = 0.15, show.legend = FALSE) +
    geom_line(mapping = aes(y = y1, color = "a"), linewidth = 1) +
    geom_line(mapping = aes(y = y2, color = "b"), linewidth = 1) +
    scale_color_manual(labels = c(y1, y2), values = mycol) +
    scale_fill_manual(values = mycol) +
    coord_cartesian(ylim = c(0, 1), expand = FALSE) +
    theme_bw() +
    theme(panel.grid = element_blank(),
          axis.title = element_blank(),
          axis.text.y = element_blank(),
          legend.position = "none",
          legend.background = element_blank(),
          legend.title = element_blank(),
          text = element_text(size = 12))
}

# Partial-effect panel for one nest type: one row of five plots per species group.
#   nest  : "enclosed", "dome" or "cavity" (passed to subset_data())
#   ds    : list of trait data frames, one per group
#   coefs : list of coefficient-draw matrices in the same order as `ds`
# Returns the ggarrange() figure (5 columns, one row per group).
plot_curve_all <- function(nest, ds, coefs) {
  stopifnot(length(ds) > 0, length(ds) == length(coefs))
  rows <- vector("list", length(ds))
  for (i in seq_along(ds)) {
    data <- subset_data(ds[[i]], nest)
    coef <- coefs[[i]]
    vm <- data.frame(t(colMeans(data))) # mean values of the (standardized) predictors
    # Prediction along one predictor; all others held at `vm` unless overridden
    pp <- function(...) pred_partial(coef = coef, vm = vm, ...)

    # Contrasting moderator levels; unname() drops the "10%"/"90%" labels from quantile()
    ground_lo <- min(data$Ground)
    ground_hi <- max(data$Ground)
    coop_lo <- min(data$Cooperative)
    coop_hi <- max(data$Cooperative)
    clutch_lo <- unname(quantile(data$Clutch, 0.1))
    clutch_hi <- unname(quantile(data$Clutch, 0.9))
    mass_lo <- unname(quantile(data$Egg, 0.1))
    mass_hi <- unname(quantile(data$Egg, 0.9))

    df_npp <- df_pc1 <- df_pc2 <- data[, c("npp", "PC1", "PC2")]
    df_npp$`Off-ground` <-      pp(npp = df_npp$npp, ground = ground_lo)
    df_npp$`Ground` <-          pp(npp = df_npp$npp, ground = ground_hi)
    df_npp$`Non-cooperative` <- pp(npp = df_npp$npp, coop = coop_lo)
    df_npp$`Cooperative` <-     pp(npp = df_npp$npp, coop = coop_hi)
    df_npp$`Small clutch` <-    pp(npp = df_npp$npp, clutch = clutch_lo)
    df_npp$`Large clutch` <-    pp(npp = df_npp$npp, clutch = clutch_hi)
    df_pc1$`Small mass` <-      pp(PC1 = df_pc1$PC1, mass = mass_lo)
    df_pc1$`Large mass` <-      pp(PC1 = df_pc1$PC1, mass = mass_hi)
    df_pc2$`Small mass` <-      pp(PC2 = df_pc2$PC2, mass = mass_lo)
    df_pc2$`Large mass` <-      pp(PC2 = df_pc2$PC2, mass = mass_hi)

    rows[[i]] <- list(
      plot_curve_single(df_npp, 'npp', 'Ground', 'Off-ground'),
      plot_curve_single(df_npp, 'npp', 'Non-cooperative', 'Cooperative'),
      plot_curve_single(df_npp, 'npp', 'Large clutch', 'Small clutch'),
      plot_curve_single(df_pc1, 'PC1', 'Small mass', 'Large mass'),
      plot_curve_single(df_pc2, 'PC2', 'Small mass', 'Large mass')
    )
  }
  # Flatten the per-group lists into one plot list, row by row
  plot <- ggarrange(plotlist = do.call(c, rows), ncol = 5, nrow = length(ds))
  # ggsave(file, plot, width = 8, height = 5, units = "in")
  plot
}

#-------------------------------------------------
# MAIN
#-------------------------------------------------
# Import bird trait data (first CSV column used as row names) and split it by order
dt.all <- read.csv("data/all_traits_pca.csv", row.names = 1) # all species
dt.np  <- filter(dt.all, Order != "Passeriformes")            # non-passerines
dt.psr <- filter(dt.all, Order == "Passeriformes")            # passerines
# Group order (all, non-passerines, passerines) must match the coefficient lists given to plot_curve_all()
ds <- list(dt.all, dt.np, dt.psr)

# Select the analysis whose fitted coefficients are plotted:
#   "main"          - models fitted with the consensus tree ("output/model_*.rds"):
#                     Figure 3 and partial effect Figures S3-S5
#   "supplementary" - coefficients fitted with 1,000 trees ("output/multitrees_coef_*.rds"):
#                     Figure S6 and partial effect Figures S8-S10
# The coefficient-plot x-axis range below is chosen from the same switch, so the loaded
# data and the axis limits cannot get out of sync.
analysis <- "main"

# Read the coefficient draws for one group ("all", "np", "psr") and one nest type.
# For the main analysis, column 17 of the `bootstrap` matrix is dropped
# (NOTE(review): presumably a parameter that is not a regression coefficient; confirm).
read_coef <- function(group, nest, analysis) {
  switch(analysis,
         main = readRDS(sprintf("output/model_%s_%s.rds", group, nest))$bootstrap[, -17],
         supplementary = readRDS(sprintf("output/multitrees_coef_%s_%s.rds", group, nest)),
         stop("`analysis` must be \"main\" or \"supplementary\", not \"", analysis, "\"",
              call. = FALSE))
}

coef.Ae <- read_coef("all", "enclosed", analysis)
coef.Ad <- read_coef("all", "dome", analysis)
coef.Ac <- read_coef("all", "cavity", analysis)
coef.Ne <- read_coef("np", "enclosed", analysis)
coef.Nd <- read_coef("np", "dome", analysis)
coef.Nc <- read_coef("np", "cavity", analysis)
coef.Pe <- read_coef("psr", "enclosed", analysis)
coef.Pd <- read_coef("psr", "dome", analysis)
coef.Pc <- read_coef("psr", "cavity", analysis)

#### Make coefficient plots ####
# Set confidence interval for coefficient plots
ci <- 95

# Stack the coefficient summaries of the three groups for one nest type
coef_plot_data <- function(coef_all, coef_np, coef_psr, nest, ci) {
  rbind(df_for_coef_plot(coef_all, "All", nest, ci),
        df_for_coef_plot(coef_np, "Non-passerines", nest, ci),
        df_for_coef_plot(coef_psr, "Passerines", nest, ci))
}
dt.enclosed.all <- coef_plot_data(coef.Ae, coef.Ne, coef.Pe, "Enclosed", ci)
dt.dome.all     <- coef_plot_data(coef.Ad, coef.Nd, coef.Pd, "Dome", ci)
dt.cavity.all   <- coef_plot_data(coef.Ac, coef.Nc, coef.Pc, "Cavity", ci)
dt.enclosed.all
dt.dome.all
dt.cavity.all

# Correction of round-end line ranges for output into an 8 x 6.6 in PDF image:
# both interval ends are pulled inwards by `adj` (x-axis units) to offset the round
# line caps (lineend = "round") that plot_coef() draws past each end
adj <- 0.011
shrink_ends <- function(dt, adj) {
  dt$lower <- dt$lower + adj
  dt$upper <- dt$upper - adj
  dt
}
dt.enclosed.all <- shrink_ends(dt.enclosed.all, adj)
dt.dome.all     <- shrink_ends(dt.dome.all, adj)
dt.cavity.all   <- shrink_ends(dt.cavity.all, adj)

# Coefficient plots (6.6 x 8 in); x-axis range follows the selected analysis
xlim <- switch(analysis,
               main = c(-.3, .45),          # Figure 3
               supplementary = c(-.3, .3))  # Figure S6
p1 <- plot_coef(dt.enclosed.all, xlim, FALSE)
p2 <- plot_coef(dt.dome.all, xlim, FALSE)
p3 <- plot_coef(dt.cavity.all, xlim, FALSE)
ggarrange(p1, p2, p3, ncol = 3, nrow = 1)
# ggsave(filename = "plot50.pdf", width = 8, height = 6.6)

#### Make partial effect plots ####
# Coefficient lists per nest type, ordered like `ds`: all species, non-passerines, passerines
coefs.enc <- list(coef.Ae,
                  coef.Ne,
                  coef.Pe)
coefs.dom <- list(coef.Ad,
                  coef.Nd,
                  coef.Pd)
coefs.cav <- list(coef.Ac,
                  coef.Nc,
                  coef.Pc)

# Partial effect plots (6 x 10 in), one figure per nest type
plot_curve_all(nest = "enclosed", ds = ds, coefs = coefs.enc) # Figure S3 (main), S8 (supplementary)
plot_curve_all(nest = "dome",     ds = ds, coefs = coefs.dom) # Figure S4 (main), S9 (supplementary)
plot_curve_all(nest = "cavity",   ds = ds, coefs = coefs.cav) # Figure S5 (main), S10 (supplementary)

