library(ggplot2)

## PARAM
# Names of the metric columns that Main() iterates over. For each name,
# Main() drops rows where that column is NA, writes density plots, and
# computes per-group / per-dataset summary statistics.
# "abs_WF" is not necessarily present in the input table: Main() derives it
# as the absolute value of the "WF" column.
METRICS <- c("LRR_SD", "LRR_mean", "LRR_median",
             "BAF_SD", "GCWF",
             "abs_WF", "BAF_DRIFT",
             "n_cnv_raw", "nbp_cnv_raw", 
             "n_del_raw", "nbp_del_raw",
             "n_dup_raw","nbp_dup_raw",
             "n_cnv","nbp_cnv",
             "n_del","nbp_del",
             "n_dup","nbp_dup")

# Main: command-line entry point.
#
# Usage: intensity_data_plotting.R <in.tsv> <outroot>
#
# Reads the tab-separated metrics table <in.tsv> (must contain the columns
# "group" and "dataset" plus the metric columns listed in METRICS) and, for
# every metric, writes:
#   <outroot>.ALL.density.<metric>.pdf      density over all rows
#   <outroot>.dataset.<metric>.density.pdf  per-dataset densities + means
#   <outroot>.group.<metric>.density.pdf    per-group densities + means
#   <outroot>.dataset.<metric>.pdf          per-dataset mean +/- sd
#   <outroot>.group.<metric>.pdf            per-group mean +/- sd
# It then runs the leave-one-out analyses on "n_cnv_raw" and writes
#   <outroot>.leaveoneout_{logistic,linear}.tsv and matching .pdf plots.
Main <- function() {

  # get user input; exit non-zero on bad usage so calling pipelines notice
  ARGS <- commandArgs(trailingOnly=TRUE)
  if (length(ARGS) != 2) {
    cat("intensity_data_plotting.R <in.tsv> <outroot>\n", file=stderr())
    q(save="no", status=1)
  }
  in.tsv <- ARGS[1]
  outroot <- ARGS[2]

  # read metrics
  df.1 <- read.table(in.tsv,
                     header=TRUE, sep="\t")

  # define absolute value of waviness factor whenever WF is available;
  # guarding on the column avoids abs(NULL) failing when the table only
  # carries a precomputed abs_WF ([[ ]] also avoids `$` partial matching)
  if ("WF" %in% colnames(df.1)) {
    df.1$abs_WF <- abs(df.1[["WF"]])
  }

  # theme shared by all mean/estimate point-range plots
  pointrange.theme <- theme(axis.title.x = element_blank(),
                            axis.title.y = element_blank(),
                            axis.text.y = element_text(angle = 45, hjust = 1))

  # init data structs for storing sumstats
  sumstats.grp <- list()
  sumstats.ds <- list()

  # for each metric
  for (metric.i in METRICS) {

    # keep only rows with a value for this metric
    df.1.x <- subset(df.1, !is.na(df.1[[metric.i]]))

    # plot full density (estimate first, so that a failure in density()
    # does not leave a half-written pdf device open)
    dens.i <- density(df.1.x[[metric.i]])
    pdf(paste0(outroot,".ALL.density.",metric.i,".pdf"))
    plot(dens.i, main=metric.i, xlab="", ylab="density")
    dev.off()

    # form sumstats
    sumstats.grp[[metric.i]] <- FormSumstatsDf(df.1.x,
                                               metric.i,
                                               set_col='group')
    sumstats.ds[[metric.i]] <- FormSumstatsDf(df.1.x,
                                              metric.i,
                                              set_col='dataset')

    # for dataset-level sumstats, add in group ID per dataset
    df.x <- unique(df.1.x[,c("group","dataset")])
    colnames(df.x) <- c("group","set")
    sumstats.ds[[metric.i]] <- merge(sumstats.ds[[metric.i]],
                                     df.x,
                                     by='set')

    # sort sumstats by mean (also fixes the plotting order of 'set')
    sumstats.grp[[metric.i]] <- SortSumstatsDf(sumstats.grp[[metric.i]],
                                               col_sortby="mean",
                                               ascending=TRUE)
    sumstats.ds[[metric.i]] <- SortSumstatsDf(sumstats.ds[[metric.i]],
                                              col_sortby="mean",
                                              ascending=TRUE)

    # set group col in dataset metrics to factor
    # sumstats.ds[[metric.i]]$group <- factor(sumstats.ds[[metric.i]]$group,
    #                                        levels=unique(sumstats.ds[[metric.i]]$group))

    # plot density + mean for metric (by dataset)
    gg <- ggplot(df.1.x, aes(x=.data[[metric.i]], color=dataset, group=dataset))
    gg <- gg + xlab(metric.i)
    gg <- gg + geom_density()
    gg <- gg + geom_vline(data=sumstats.ds[[metric.i]], aes(xintercept=mean, color=set),
                          linetype='dashed')
    ggsave(filename=paste0(outroot, ".dataset.",metric.i,".density.pdf"), plot=gg)

    # plot density + mean for metric (by group)
    gg <- ggplot(df.1.x, aes(x=.data[[metric.i]], color=group, group=group))
    gg <- gg + xlab(metric.i)
    gg <- gg + geom_density()
    gg <- gg + geom_vline(data=sumstats.grp[[metric.i]], aes(xintercept=mean, color=set),
                          linetype='dashed')
    ggsave(filename=paste0(outroot, ".group.",metric.i,".density.pdf"), plot=gg)

    # plot mean / sd per dataset for metric
    gg <- ggplot(sumstats.ds[[metric.i]], aes(x=mean, y=set, color=group))
    gg <- gg + geom_pointrange(aes(xmin=lowerbound, xmax=upperbound))
    gg <- gg + ggtitle(metric.i)
    gg <- gg + pointrange.theme
    ggsave(filename=paste0(outroot, ".dataset.",metric.i,".pdf"), plot=gg)

    # plot mean / sd per group for metric
    gg <- ggplot(sumstats.grp[[metric.i]], aes(x=mean, y=set))
    gg <- gg + geom_pointrange(aes(xmin=lowerbound, xmax=upperbound))
    gg <- gg + ggtitle(metric.i)
    gg <- gg + pointrange.theme
    ggsave(filename=paste0(outroot,".group.",metric.i,".pdf"), plot=gg)

  }

  # leave-one-out analyses on raw CNV count, one row per dataset
  loo.logistic <- LogisticLeaveOneOutAnalysis(df.1, "n_cnv_raw", set_col="dataset")
  write.table(loo.logistic,
              file=paste0(outroot,".leaveoneout_logistic.tsv"),
              row.names=FALSE, col.names=TRUE, sep="\t", quote=FALSE)

  loo.linear <- LinearLeaveOneOutAnalysis(df.1, "n_cnv_raw", set_col="dataset")
  write.table(loo.linear,
              file=paste0(outroot,".leaveoneout_linear.tsv"),
              row.names=FALSE, col.names=TRUE, sep="\t", quote=FALSE)

  # loo plot (logistic model)
  # NOTE: file name now carries the "." separator after outroot, matching
  # the .tsv written above and every other output of this script
  gg <- ggplot(loo.logistic, aes(x=or, y=set))
  gg <- gg + geom_pointrange(aes(xmin=or_95ci_l, xmax=or_95ci_u))
  gg <- gg + ggtitle("leave-one-out CNV burden odds ratio")
  gg <- gg + pointrange.theme
  ggsave(filename=paste0(outroot,".leaveoneout_logistic.pdf"), plot=gg)

  # loo plot (linear model)
  gg <- ggplot(loo.linear, aes(x=est, y=set))
  gg <- gg + geom_pointrange(aes(xmin=est_95ci_l, xmax=est_95ci_u))
  gg <- gg + ggtitle("leave-one-out CNV rate difference")
  gg <- gg + pointrange.theme
  ggsave(filename=paste0(outroot,".leaveoneout_linear.pdf"), plot=gg)

}

# SortSumstatsDf: order a summary-statistics data frame by one column and
# turn its 'set' column into a factor whose levels follow that order (so
# that ggplot draws the sets in sorted order).
#
# Args:
#   sumstats:   data frame with a 'set' column and the column named by
#               col_sortby.
#   col_sortby: name of the column to sort on (default "mean").
#   ascending:  TRUE for ascending order, anything else for descending.
#
# Returns: the reordered data frame with 'set' as a factor.
SortSumstatsDf <- function(sumstats,
                           col_sortby="mean",
                           ascending=FALSE) {
  row.order <- order(sumstats[[col_sortby]])
  if (!isTRUE(ascending)) {
    # reversed ascending order (kept as rev(order()) so tie handling is
    # unchanged for existing callers)
    row.order <- rev(row.order)
  }
  sumstats <- sumstats[row.order, , drop=FALSE]

  # levels must be unique: a set can appear on several rows (e.g. after a
  # merge that maps one dataset to more than one group), and factor()
  # rejects duplicated levels
  sumstats$set <- factor(sumstats$set,
                         levels=unique(as.character(sumstats$set)))
  return(sumstats)
}

# FormSumstatsDf: per-set summary statistics of one metric column.
#
# Args:
#   df:         data frame holding the metric and the set-defining column.
#   metric_col: name of the numeric column to summarise.
#   set_col:    name of the column whose distinct (non-NA) values define
#               the sets (default "group").
#
# Returns: data frame with one row per set, sorted by set, and columns
#   set, mean, median, sd, lowerbound (mean - sd), upperbound (mean + sd).
#   The same columns are returned (with zero rows) when there are no sets.
#   sd, and hence both bounds, is NA for a set with a single row.
FormSumstatsDf <- function(df,
                           metric_col,
                           set_col="group") {
  # get unique sets (sort() drops NA); iterate factors by label
  sets <- unique(sort(df[[set_col]]))
  if (is.factor(sets)) {
    sets <- as.character(sets)
  }

  # metric values of each set, extracted once and reused for every statistic;
  # which() ignores rows whose set value is NA
  values <- lapply(sets, function(set.i) {
    df[[metric_col]][which(df[[set_col]] == set.i)]
  })

  # USE.NAMES=FALSE keeps the vectors unnamed so data.frame() does not turn
  # set labels into row names
  set.mean <- vapply(values, mean, numeric(1), USE.NAMES=FALSE)
  set.median <- vapply(values, function(v) as.numeric(median(v)),
                       numeric(1), USE.NAMES=FALSE)
  set.sd <- vapply(values, sd, numeric(1), USE.NAMES=FALSE)

  # assemble all rows at once instead of growing the frame with rbind()
  sumstats <- data.frame(set=sets,
                         mean=set.mean,
                         median=set.median,
                         sd=set.sd,
                         lowerbound=set.mean - set.sd,
                         upperbound=set.mean + set.sd)
  return(sumstats)
}

# LogisticLeaveOneOutAnalysis: for every distinct value of set_col, fit a
# logistic regression of the indicator "row belongs to this set" (1/0) on
# the metric, and report the metric's odds ratio.
#
# Args:
#   df:         data frame holding the metric and the set-defining column.
#   metric_col: name of the numeric predictor column.
#   set_col:    name of the column defining the sets (default "dataset").
#
# Returns: data frame with one row per set (sorted by set) and columns
#   set, or (exp of the metric coefficient), or_95ci_l / or_95ci_u (exp of
#   the confint() bounds for that coefficient), logistic_p (its p-value).
LogisticLeaveOneOutAnalysis <- function(df,
                                        metric_col,
                                        set_col="dataset") {
  stopifnot(metric_col %in% names(df), set_col %in% names(df))

  sets <- unique(sort(df[[set_col]]))
  if (is.factor(sets)) {
    sets <- as.character(sets)
  }

  # one single-row result per set
  rows <- lapply(sets, function(set.i) {
    # fit on a private two-column frame: nothing is written into df, so a
    # set or metric column that happens to be called "loo" is not clobbered
    fit.df <- data.frame(loo=ifelse(df[[set_col]] == set.i, 1, 0),
                         metric=df[[metric_col]])
    stats.i <- glm(loo ~ metric, data=fit.df, family=binomial)
    coefs.i <- summary(stats.i)$coefficients
    confint.i <- confint(stats.i)
    data.frame(set=set.i,
               or=exp(coefs.i["metric", 1]),
               or_95ci_l=exp(confint.i["metric", 1]),
               or_95ci_u=exp(confint.i["metric", 2]),
               logistic_p=coefs.i["metric", 4])
  })

  if (length(rows) == 0) {
    # no sets: return an empty frame with the usual columns
    return(data.frame(set=character(),
                      or=numeric(),
                      or_95ci_l=numeric(),
                      or_95ci_u=numeric(),
                      logistic_p=numeric()))
  }
  res.df <- do.call(rbind, rows)
  return(res.df)
}

# LinearLeaveOneOutAnalysis: for every distinct value of set_col, fit a
# linear regression of the metric on the indicator "row belongs to this
# set" (1/0), i.e. estimate the difference in mean metric between that set
# and all other rows.
#
# Args:
#   df:         data frame holding the metric and the set-defining column.
#   metric_col: name of the numeric response column.
#   set_col:    name of the column defining the sets (default "dataset").
#
# Returns: data frame with one row per set (sorted by set) and columns
#   set, est (indicator coefficient), est_95ci_l / est_95ci_u (confint()
#   bounds for that coefficient), linear_p (its p-value).
LinearLeaveOneOutAnalysis <- function(df,
                                      metric_col,
                                      set_col="dataset") {
  stopifnot(metric_col %in% names(df), set_col %in% names(df))

  sets <- unique(sort(df[[set_col]]))
  if (is.factor(sets)) {
    sets <- as.character(sets)
  }

  # one single-row result per set
  rows <- lapply(sets, function(set.i) {
    # fit on a private two-column frame: nothing is written into df, so a
    # set or metric column that happens to be called "loo" is not clobbered
    fit.df <- data.frame(loo=ifelse(df[[set_col]] == set.i, 1, 0),
                         metric=df[[metric_col]])
    stats.i <- lm(metric ~ loo, data=fit.df)
    coefs.i <- summary(stats.i)$coefficients
    confint.i <- confint(stats.i)
    data.frame(set=set.i,
               est=coefs.i["loo", 1],
               est_95ci_l=confint.i["loo", 1],
               est_95ci_u=confint.i["loo", 2],
               linear_p=coefs.i["loo", 4])
  })

  if (length(rows) == 0) {
    # no sets: return an empty frame with the usual columns
    return(data.frame(set=character(),
                      est=numeric(),
                      est_95ci_l=numeric(),
                      est_95ci_u=numeric(),
                      linear_p=numeric()))
  }
  res.df <- do.call(rbind, rows)
  return(res.df)
}






# run the pipeline only when the file is executed as a script, not when it
# is sourced from an interactive session
if (!interactive()) {
  Main()
}
