#!/usr/bin/env Rscript
# Libs -----
source(file.path(usethis::proj_path(), 'src', 'R', 'utils_slim.R'));

# Global ---------
# root: project root, also the base of the default paths below.
# rcmd: TRUE when the RSTUDIO environment variable is empty (used below to switch to CLI mode).
# perf: a PkgPerf instance.
root <- usethis::proj_path()
rcmd <- stringi::stri_isempty(Sys.getenv('RSTUDIO'))
perf <- PkgPerf$new()

# NOTE: OVERWRITE IF IN RSTUDIO
# Default input/output locations under the project root.
path <- local({
  dir_meta <- file.path(root, 'data', 'meta')
  dir_mrna <- file.path(root, 'data', 'omic', 'mrna')
  dir_out  <- file.path(root, 'out', 'omic', 'mrna')
  list(
    input = list(
      meta = file.path(dir_meta, 'meta.tsv.gz'),
      anno = file.path(dir_meta, 'anno.tsv.gz'),
      mrna = file.path(dir_mrna, 'rsem.tsv.gz')
    ),
    output = list(
      meta  = file.path(dir_out, 'meta.tsv'),
      lcpm  = file.path(dir_out, 'lcpm_adj.tsv'),
      contr = file.path(dir_out, 'contr.tsv'),
      rout  = file.path(dir_out, 'rout.rds')
    )
  )
})


# CLI Mode
# When rcmd is TRUE, all seven paths come from the command line and replace the defaults.
# Positional order: 3 inputs (meta, anno, mrna) followed by 4 outputs (meta, lcpm, contr, rout);
# the usage text lists the outputs in exactly the order they are assigned below.
if (rcmd) {
  args <- commandArgs(trailingOnly = TRUE)
  stopifnot(
  '# args != 7:
  Usage: ./mrna_pipeline.R /in/meta.tsv /in/anno.tsv /in/rsem.tsv /out/meta.tsv /out/lcpm_adj.tsv /out/contr.tsv /out/rout.rds' = length(args) == 7
  )
  path <- list(
    input  = args[c(1, 2, 3)]    %>% as.list() %>% set_names(c('meta', 'anno', 'mrna')),
    output = args[c(4, 5, 6, 7)] %>% as.list() %>% set_names(c('meta', 'lcpm', 'contr', 'rout'))
  )
}


#Part 1: Preprocess ------
# Read Data: sample metadata (keyed by Seq), gene annotation (keyed by Gene), and rna expression (keyed by Seq, Gene)
dm <- dt.read(path$input$meta, key = c('Seq'))
da <- dt.read(path$input$anno, key = c('Gene'), col.names = c('Gene', 'Name', 'Type'))
dt <- dt.read(
  path$input$mrna,
  key       = c('Seq', 'Gene'),
  col.names = c('Seq', 'Gene', 'Cnt'),
  select    = c('seq', 'gene_id', 'expected_count')
)

# Subset: metadata
# Coerce Size to integer only when a column named exactly 'Size' exists; names() is matched
# exactly, whereas `dm$Size` can partially match a differently named column.
if ('Size' %in% names(dm) && !is.integer(dm[['Size']])) {
  dm[, Size := as.integer(Size)]
}
# Drop the rows with Pt 'A015' at Stage 'A', then refactor
dm <- dm[Pt != 'A015' | Stage != 'A', ]
dt.refac(dm)


# Merge: sample metadata, gene annotation, and (integer) mrna expression
# Cnt is cast to integer in place, then dt is inner-joined (nomatch = NULL) to the metadata on Seq
# and to the annotation on Gene.
# NOTE(review): Cnt is read from 'expected_count'; as.integer() truncates any fractional part rather than rounding — confirm this is intended.
dt <- dt[, Cnt:= as.integer(Cnt)][dm, on = .(Seq), nomatch = NULL][da, on = .(Gene), nomatch = NULL];
dt.refac(dt);

# Filter: collect the Gene IDs and Names to drop, remove the matching rows, and refactor after.
#   gene: (1) IDs that do not start with 'ENSG',
#         (2) Genes whose Name is paired with more than one Gene,
#         (3) Genes whose max Cnt is 0 in more than (number of Samples - 2) Samples.
#   name: Names of any Gene that is paired with more than one Name.
max_sample_nexpr <- uniqueN(dt$Sample) - 2;
dt.filter <- list(
 gene = unique(c(unique(dt$Gene) %>% str_subset(pattern = '^ENSG', negate = TRUE),
                 unique(dt[,.(Gene,Name)])[,.(NumGene = .N, Gene), by = Name][NumGene > 1, as.character(Gene)],
                 dt[,.(MaxCnt = max(Cnt)), by = .(Gene, Sample)][ MaxCnt == 0, .(N=.N), by=Gene][N > max_sample_nexpr, as.character(Gene)])),
 name = unique(dt[,.(Gene,Name)])[,.(NumName = .N, Name), by = Gene][NumName > 1, as.character(Name)]
);
dt <- dt[!Gene %in% dt.filter$gene & !Name %in% dt.filter$name,] %>% dt.refac();
# Collapse: replace Cnt with its sum within each (Gene, Spec), drop Seq, de-duplicate the rows, key by (Spec, Name), and refactor.
dt <- unique(dt[,SumCnt:=sum(Cnt), by=.(Gene,Spec)][,`:=`(Cnt = SumCnt, Seq = NULL, SumCnt = NULL)]) %>% setkey(Spec,Name) %>% dt.refac();


# Part 2: Create object to hold expression matrices, principal component results, factor levels for making design matrices
ds <- local({
  # Placeholder pair: a matrix slot (dx) and its PCA slot (pca), both filled in later
  empty <- function() list(dx = NULL, pca = NULL)
  # Wide count table: one row per Gene + Name, one column per Spec, values from Cnt
  cnt_wide <- dcast.data.table(dt[, .(Spec, Gene, Name, Cnt)], Gene + Name ~ Spec, value.var = 'Cnt')
  # Merged table without the Gene, Name, Type and Cnt columns, de-duplicated
  meta_spec <- copy(unique(dt[, !c('Gene', 'Name', 'Type', 'Cnt')]))

  list(
    mrna = list(
      cnt = list(dt.x = cnt_wide, dt.y = dt, dx = NULL, pca = NULL),
      vst = empty(),
      cpm = empty()
    ),
    adj = list(
      cbat = list(cnt = empty(), cpm = empty(), seq = empty()),
      lmm  = empty(),
      sva  = list(cnt = NULL, cpm = NULL, seq = NULL)
    ),
    de   = list(dt = NULL),
    meta = list(dt = dm, dx = meta_spec, lvls = NULL),
    anno = list(dt = da)
  )
})
# The working tables now live inside ds
rm(dt, dm, da)

# Expression Matrices: integer read counts and transformed integer read counts by: vst and log of cpm.
# The count matrix takes its rownames from Name and its columns from every dt.x column except Gene and Name.
ds$mrna$cnt$dx <- copy(ds$mrna$cnt$dt.x[, !c('Gene', 'Name')]) %>%
  setDF(rownames = as.character(ds$mrna$cnt$dt.x$Name)) %>%
  as.matrix()
ds$mrna$vst$dx <- vst(ds$mrna$cnt$dx)
# TRUE is spelled out: `T` is an ordinary variable that can be reassigned
ds$mrna$cpm$dx <- cpm(ds$mrna$cnt$dx, log = TRUE, prior.count = 5)

# Expression PCs: PCA on expression matrices with PC1 -> PC_N sorted in decreasing explained variance (N = # of Spec values = # columns in expression data.frame)
# all.dt:  data.table of [N, N + 1] consisting of projections of each Spec transcriptome onto each and every PC
# dt:      data.table of [N, 3] consisting of columns: Spec, PC1, PC2
# var.dt:  data.table of [3, N] consisting of rows: Var = Variance Explained by PC, PctVar = Pct of Total Variance Explained by PC, CumVar = Cum of PctVar
# pcs.dt:  data.table of [G, N + 1] which are the loadings for each principal component by Gene with the additional column Gene being the name
ds$mrna$cnt$pca <- run_pca(ds$mrna$cnt$dx)
ds$mrna$vst$pca <- run_pca(ds$mrna$vst$dx)
ds$mrna$cpm$pca <- run_pca(ds$mrna$cpm$dx)

# GLM Factors: Augment metadata with two factor groups for the GLM: Xt (an extraneous group) and Xp (an explanatory group) with levels explicit in meta$lvls.
# lvls$Xt: one dsv(Pt,Loc,Batch) value per (Batch,Pt,Loc) group, with rows ordered by Batch, Pt, Loc.
# lvls$Xp: dsv(Stage,Dys) values for the Stage == 'N' groups first, then for the other Stages ordered by descending Stage and Dys.
# NOTE(review): dsv() is defined outside this file; presumably it joins its arguments with '.' — confirm.
ds$meta$lvls <- list(
  Xt = ds$meta$dx[order(Batch,Pt,Loc),                 .(X = dsv(Pt,Loc,Batch)), by = .(Batch,Pt,Loc)]$X, # Xt = Pt.Loc.Batch, sorted by Batch, Pt, Loc
  Xp = c(ds$meta$dx[Stage == 'N',][,                   .(X = dsv(Stage, Dys)),   by = .(Stage,Dys)]$X,    # Xp = N.No, P.No, P.Dys, A.Dys
         ds$meta$dx[Stage != 'N',][order(-Stage,-Dys), .(X = dsv(Stage, Dys)),   by = .(Stage,Dys)]$X)
);
# Add Xt and Xp to ds$meta$dx by reference, as factors using the level order computed above.
ds$meta$dx[, `:=`(Xt = factor(dsv(Pt,Loc,Batch), levels = ds$meta$lvls$Xt), Xp = factor(dsv(Stage,Dys), levels = ds$meta$lvls$Xp))];

# Part 3: Batch Correction -------
# Combat/Combat-Seq: Run combat on cnt, cpm and combatseq on cnt and run PCA on the output of each.
ds$adj$cbat$cnt$dx  <- run_combat(ds$mrna$cnt$dx, ds$meta$dx$Batch)
# Third argument TRUE: presumably selects the combat-seq variant inside run_combat (defined outside this file)
ds$adj$cbat$seq$dx  <- run_combat(ds$mrna$cnt$dx, ds$meta$dx$Batch, TRUE)
# The cpm slot is corrected from the log-cpm matrix, not from the raw counts a second time
ds$adj$cbat$cpm$dx  <- run_combat(ds$mrna$cpm$dx, ds$meta$dx$Batch)
ds$adj$cbat$cnt$pca <- run_pca(ds$adj$cbat$cnt$dx)
ds$adj$cbat$seq$pca <- run_pca(ds$adj$cbat$seq$dx)
ds$adj$cbat$cpm$pca <- run_pca(ds$adj$cbat$cpm$dx)

# SVA/SVASeq: run run_sva on cnt and on cpm, and once more on cnt with use_seq = TRUE.
# A 'Spec' ID column is inserted into the SV data.tables of the cnt and cpm runs,
# but only the SVs from the cnt run are merged into ds$meta$dx.
ds$adj$sva$cnt <- run_sva(
  ds$mrna$cnt$dx, ds$meta$dx,
  '~ Xt + Xp', ' ~ Xp'
)
ds$adj$sva$cpm <- run_sva(
  ds$mrna$cpm$dx, ds$meta$dx,
  '~ Xt + Xp', ' ~ Xp'
)
ds$adj$sva$seq <- run_sva(
  ds$mrna$cnt$dx, ds$meta$dx,
  '~ Xt + Xp', ' ~ Xp',
  use_seq = TRUE
)
# Spec is assigned positionally, so dt.sv rows must be in the same order as ds$meta$dx
ds$adj$sva$cnt$dt.sv[, Spec := ds$meta$dx$Spec]
ds$adj$sva$cpm$dt.sv[, Spec := ds$meta$dx$Spec]
ds$meta$dx <- ds$meta$dx[ds$adj$sva$cnt$dt.sv, on = 'Spec']

# Limma
# mm:  make.mm() output for Pt, Loc, Xp plus every metadata column whose name contains 'SV_'
# dx:  log-cpm matrix passed through rm_batch_eff() with Batch and mm as the design
# pca: PCA of the adjusted matrix
ds$adj$lmm$mm <- make.mm(
  eqn.str(
    qc(Pt, Loc, Xp),
    keep(names(ds$meta$dx), str_detect, pattern = 'SV_')
  ),
  data   = ds$meta$dx,
  rm.pre = qc(Pt, Loc, Xp)
)
ds$adj$lmm$dx <- rm_batch_eff(
  ds$mrna$cpm$dx,
  ds$meta$dx$Batch,
  design = ds$adj$lmm$mm
)
ds$adj$lmm$pca <- run_pca(ds$adj$lmm$dx)
# -------------------


# Part 4: Differential Expression and Contrasts -----------
# DESeq2: Control for Xt = {Patient, Location, Batch, SV_1, SV_2} and test Xp = {Stage, Dys}
# The run_deseq2() output is appended to ds$de; the dds and rows entries used below presumably come from it.
ds$de <- c(
  ds$de,
  run_deseq2(
    ds$mrna$cnt$dx,
    ds$meta$dx,
    design  = eqn.str(qc(Xt, Xp), keep(names(ds$meta$dx), str_detect, pattern = 'SV_')),
    run_vst = TRUE
  )
)

# Contrasts: six pairwise comparisons of Xp levels, each given as (factor, numerator, denominator)
# NOTE(review): the levels are written P.Dy / A.Dy here while the Xp level comment elsewhere in this file says P.Dys / A.Dys — confirm the actual level names.
ds$de$ctrs <- list(
  'M-B' = results(ds$de$dds, contrast = qc(Xp, P.No, N.No), parallel = TRUE)[ds$de$rows, ],
  'M-D' = results(ds$de$dds, contrast = qc(Xp, P.Dy, N.No), parallel = TRUE)[ds$de$rows, ],
  'M-A' = results(ds$de$dds, contrast = qc(Xp, A.Dy, N.No), parallel = TRUE)[ds$de$rows, ],
  'B-D' = results(ds$de$dds, contrast = qc(Xp, P.Dy, P.No), parallel = TRUE)[ds$de$rows, ],
  'B-A' = results(ds$de$dds, contrast = qc(Xp, A.Dy, P.No), parallel = TRUE)[ds$de$rows, ],
  'D-A' = results(ds$de$dds, contrast = qc(Xp, A.Dy, P.Dy), parallel = TRUE)[ds$de$rows, ]
)
# Shrunken (ashr) version of every contrast
ds$de$ctrsh <- map(
  ds$de$ctrs,
  ~ lfcShrink(ds$de$dds, res = .x, type = 'ashr', parallel = TRUE)
)

# Create data.tables for regular contrasts and ashr contrasts
ds$de$dt <- local({
  # Turn every result into a tibble with a Name column, rename its columns, and insert
  # the Ashr flag (column 2) and the contrast label (column 3)
  label <- function(res_list, col_names, ashr_flag) {
    res_list %>%
      map(as_tibble, rownames = 'Name') %>%
      map(set_names, nm = col_names) %>%
      imap(function(tbl, contr) add_column(tbl, Ashr = ashr_flag, Contr = contr, .before = 2))
  }

  regular <- label(ds$de$ctrs, qc(Name, Mean, Log2FC, Log2FC.SE, Stat, Pval, FDR), 'N') %>%
    map(setDT) %>%
    rbindlist()

  # The ashr results are renamed without a Stat column; add an NA one at position 7
  # so both tables share the same layout
  shrunk <- label(ds$de$ctrsh, qc(Name, Mean, Log2FC, Log2FC.SE, Pval, FDR), 'Y') %>%
    map(function(tbl) add_column(tbl, Stat = NA, .before = 7)) %>%
    map(setDT) %>%
    rbindlist()

  rbindlist(list(regular, shrunk))
})

# Make Contr a Factor and Create Contr.Lvl to indicate the integer level for each Contr
ds$de$dt[, Contr := factor(Contr, levels = c('M-B', 'M-D', 'M-A', 'B-D', 'B-A', 'D-A'))]
ds$de$dt[, Contr.Lvl := as.integer(Contr)]

# Merge Regular and AshR Log2FC
# Key2 (Name.Contr.Lvl) pairs each regular row with the ashr row of the same gene and contrast
ds$de$dt[, Key2 := paste0(Name, '.', Contr.Lvl)]
setkey(ds$de$dt, Key2)
ds$de$dt <- local({
  regular <- ds$de$dt[Ashr == 'N', ]
  shrunk  <- ds$de$dt[Ashr == 'Y', ][
    , .(Ash.Log2FC = Log2FC, Ash.Log2FC.SE = Log2FC.SE), by = Key2
  ]
  regular[shrunk, on = 'Key2']
})
# The helper columns are no longer needed once the two halves are side by side
ds$de$dt[, `:=`(Ashr = NULL, Key2 = NULL)]
setcolorder(
  ds$de$dt,
  qc(Name, Contr.Lvl, Contr, Mean, Stat, Log2FC, Log2FC.SE, Ash.Log2FC, Ash.Log2FC.SE, Pval, FDR)
)

# Merge GeneID
# Inner join on Name against the Name/Gene pairs of the wide count table, then move
# Name, Gene, Contr.Lvl to the front and refactor
ds$de$dt <- ds$de$dt[ds$mrna$cnt$dt.x[, Gene, by = Name], on = 'Name', nomatch = NULL] %>%
  setcolorder(qc(Name, Gene, Contr.Lvl)) %>%
  dt.refac()

# Part 5: Save ------
# Tables: the contrast results, the sample metadata held in ds$meta$dt, and the limma-adjusted matrix
dt.write(ds$de$dt, path$output$contr)
dt.write(ds$meta$dt, path$output$meta)
dt.write(ds$adj$lmm$dx, path$output$lcpm)
# Whole results object
saveRDS(ds, path$output$rout)
