In this notebook we inspect how we can use the BNPR fits from phylodyn to model annual growth.

0.0.0.1 Setup

source("Scripts/vaf_dynamics_functions.R")
## 
## Attaching package: 'greta'
## The following objects are masked from 'package:stats':
## 
##     binomial, cov2cor, poisson
## The following objects are masked from 'package:base':
## 
##     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
##     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
##     tapply
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.0 ──
## ✔ tibble  3.0.4     ✔ dplyr   1.0.2
## ✔ tidyr   1.1.2     ✔ stringr 1.4.0
## ✔ readr   1.4.0     ✔ forcats 0.5.0
## ✔ purrr   0.3.4
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ✖ dplyr::slice()  masks greta::slice()
## This is bayesplot version 1.7.2
## - Online documentation and vignettes at mc-stan.org/bayesplot
## - bayesplot theme set to bayesplot::theme_default()
##    * Does _not_ affect other ggplot2 plots
##    * See ?bayesplot_theme_set for details on theme setting
## 
## Attaching package: 'cowplot'
## The following object is masked from 'package:ggpubr':
## 
##     get_legend
## 
## Attaching package: 'extraDistr'
## The following objects are masked from 'package:gtools':
## 
##     ddirichlet, rdirichlet
## The following object is masked from 'package:purrr':
## 
##     rdunif
## 
## ---------------------
## Welcome to dendextend version 1.14.0
## Type citation('dendextend') for how to cite the package.
## 
## Type browseVignettes(package = 'dendextend') for the package vignette.
## The github page is: https://github.com/talgalili/dendextend/
## 
## Suggestions and bug-reports can be submitted at: https://github.com/talgalili/dendextend/issues
## Or contact: <tal.galili@gmail.com>
## 
##  To suppress this message use:  suppressPackageStartupMessages(library(dendextend))
## ---------------------
## 
## Attaching package: 'dendextend'
## The following object is masked from 'package:ggpubr':
## 
##     rotate
## The following object is masked from 'package:stats':
## 
##     cutree
## 
## Attaching package: 'ape'
## The following objects are masked from 'package:dendextend':
## 
##     ladderize, rotate
## The following object is masked from 'package:ggpubr':
## 
##     rotate
## Registered S3 method overwritten by 'treeio':
##   method     from
##   root.phylo ape
## ggtree v2.0.4  For help: https://yulab-smu.github.io/treedata-book/
## 
## If you use ggtree in published research, please cite the most appropriate paper(s):
## 
## - Guangchuang Yu, Tommy Tsan-Yuk Lam, Huachen Zhu, Yi Guan. Two methods for mapping and visualizing associated data on phylogeny using ggtree. Molecular Biology and Evolution 2018, 35(12):3041-3043. doi: 10.1093/molbev/msy194
## - Guangchuang Yu, David Smith, Huachen Zhu, Yi Guan, Tommy Tsan-Yuk Lam. ggtree: an R package for visualization and annotation of phylogenetic trees with their covariates and other associated data. Methods in Ecology and Evolution 2017, 8(1):28-36, doi:10.1111/2041-210X.12628
## 
## Attaching package: 'ggtree'
## The following object is masked from 'package:ape':
## 
##     rotate
## The following object is masked from 'package:dendextend':
## 
##     rotate
## The following object is masked from 'package:tidyr':
## 
##     expand
## The following object is masked from 'package:ggpubr':
## 
##     rotate
## 
## Attaching package: 'reghelper'
## The following object is masked from 'package:greta':
## 
##     beta
## The following object is masked from 'package:base':
## 
##     beta
source("Scripts/make_ultrametric.R")

library(parallel)
library(Matrix)
## 
## Attaching package: 'Matrix'
## The following object is masked from 'package:ggtree':
## 
##     expand
## The following objects are masked from 'package:tidyr':
## 
##     expand, pack, unpack
library(castor)
## Loading required package: Rcpp
library(phangorn)
library(phylodyn)
library(minpack.lm)
INLA:::inla.dynload.workaround()

# Piecewise-linear ("change point") curve.
#
# Evaluates b0 + m1 * x and, for every x beyond delta, adds m2 * (x - delta),
# so the slope is m1 up to delta and m1 + m2 afterwards.
#
# x:     numeric vector of positions at which to evaluate the curve
# b0:    intercept
# m1:    slope applied over the whole range
# m2:    additional slope applied only where x > delta
# delta: position of the change point
# Returns a numeric vector with one value per element of x.
change_point <- function(x, b0, m1, m2, delta) {
  # pmax() is the vectorised form of max(0, t); unlike sapply() it yields
  # numeric(0) rather than list() for empty input, so empty x no longer errors.
  b0 + (x * m1) + (pmax(0, x - delta) * m2)
}

# Plot the ultrametric tree stored in `obj`, colouring tips by driver.
#
# obj: list holding `tree_ultra` (the tree that is drawn), `tree$S` (used as
#      tip labels) and `driver_list` (named list; each element lists the tip
#      labels carrying that driver).
# Only drivers with more than 5 entries get a colour; every other tip stays
# black. The driver/colour pairing is printed and the tree is drawn as side
# effects.
plot_tree <- function(obj) {
  tree_ultra <- obj$tree_ultra
  # infinite branch lengths cannot be drawn, so they are flattened to 0
  tree_ultra$edge.length[is.infinite(tree_ultra$edge.length)] <- 0
  tree_ultra$tip.label <- obj$tree$S
  driver_list <- Filter(function(x) length(x) > 5, obj$driver_list)
  n_drivers <- length(driver_list)
  if (n_drivers > 1) {
    # "Set3" only provides between 3 and 12 colours: ask for a size inside
    # that range (avoids the brewer.pal warnings) and recycle the palette when
    # there are more drivers than colours, so no driver ends up with NA.
    colours <- RColorBrewer::brewer.pal(min(max(n_drivers, 3), 12), "Set3")
    colours <- rep_len(colours, max(n_drivers, length(colours)))
  } else {
    colours <- c("black")
  }
  colour_code <- rep("black", length(tree_ultra$tip.label))
  # seq_along() is empty for an empty driver list, so no special case is needed
  for (i in seq_along(driver_list)) {
    D <- driver_list[[i]]
    colour_code[tree_ultra$tip.label %in% unique(D)] <- colours[i]
    print(c(names(driver_list)[i], colours[i]))
  }
  print(names(driver_list))
  print(colours)
  plot(tree_ultra, tip.color = colour_code)
}

# Collect the tips found below the node(s) given in `edge_id`.
#
# Despite its name, `edge_id` is matched against the parent column of
# `tree$edge`, i.e. it holds node numbers. The tree is walked one generation
# of children at a time until no further children are found; of all nodes
# visited (the starting ones included) only those numbered up to the number of
# tip labels are returned.
expand_edge <- function(tree,edge_id) {
  tip_count <- length(tree$tip.label)
  parents <- tree$edge[,1]
  children <- tree$edge[,2]
  visited <- c(edge_id)
  frontier <- edge_id
  repeat {
    # children of everything reached in the previous round
    frontier <- children[parents %in% frontier]
    visited <- c(visited,frontier)
    if (length(frontier) == 0) {
      break
    }
  }
  visited <- unique(visited)
  visited[visited <= tip_count]
}

# Shorten every terminal branch of `tree` by `l` and drop the tips that end up
# with a negative branch length.
#
# tree: tree with `edge`, `edge.length` and `tip.label` components
# l:    amount subtracted from each terminal branch
# Returns the tree after drop.tip().
trim_tree <- function(tree,l) {
  tip_count <- length(tree$tip.label)
  # an edge is terminal when its child node is one of the tips
  is_terminal <- tree$edge[,2] <= tip_count
  tree$edge.length[is_terminal] <- tree$edge.length[is_terminal] - l
  # NOTE(review): the logical vector has one entry per edge but indexes the
  # two-column edge matrix as a plain vector, so it is recycled over both
  # columns and picks the parent node as well as the child of each negative
  # edge. expand_edge() then returns every tip below those parents - confirm
  # that dropping the sibling tips too is intended.
  negative_nodes <- tree$edge[tree$edge.length < 0]
  doomed_tips <- unlist(
    sapply(negative_nodes,function(node) expand_edge(tree,node)))
  drop.tip(tree,doomed_tips)
}

# Evaluate a BNPR fit at its coalescence times.
#
# bnpr_estimate: list with `coal_times` and a `summary` table holding the
#                columns `time`, `quant0.5`, `quant0.025` and `quant0.975`.
# Returns a data frame with one row per coalescence time that lies inside the
# range of `summary$time`:
#   X     coalescence time
#   Y     log of the interpolated median
#   Y_exp interpolated median
#   Y025  interpolated 2.5% quantile
#   Y975  interpolated 97.5% quantile
bnpr_at_coalescence <- function(bnpr_estimate) {
  time_at_coal <- bnpr_estimate$coal_times
  grid_time <- bnpr_estimate$summary$time
  # Linear interpolation onto the coalescence times; points outside the grid
  # come back as NA and are removed by the na.omit() below.
  interpolate <- function(values) {
    approx(grid_time, values, xout = time_at_coal)$y
  }
  # The median is interpolated once and reused for both Y and Y_exp. The
  # previous version interpolated it a second time with rule = 2 for Y only;
  # those extrapolated values were discarded anyway because Y_exp was NA on
  # the same rows, and the extra na.omit() on the bare vector risked a length
  # mismatch inside data.frame().
  median_at_coal <- interpolate(bnpr_estimate$summary$quant0.5)
  values_at_coalescence <- data.frame(
    X = time_at_coal,
    Y = log(median_at_coal),
    Y_exp = median_at_coal,
    Y025 = interpolate(bnpr_estimate$summary$quant0.025),
    Y975 = interpolate(bnpr_estimate$summary$quant0.975)
  )
  na.omit(values_at_coalescence)
}

# Tabulate a BNPR fit on its own time grid.
#
# bnpr_estimate: list whose `summary` table holds the columns `time`,
#                `quant0.5`, `quant0.025` and `quant0.975`.
# Returns a data frame with the columns X (time), Y (log median), Y_exp
# (median), Y025 and Y975; rows containing any NA are removed.
bnpr_at_all <- function(bnpr_estimate) {
  grid <- bnpr_estimate$summary
  trajectory <- data.frame(
    X = grid$time,
    Y = log(grid$quant0.5),
    Y_exp = grid$quant0.5,
    Y025 = grid$quant0.025,
    Y975 = grid$quant0.975
  )
  na.omit(trajectory)
}

# Row indices (into `tree$edge`) of every edge belonging to the clade rooted
# at node `mrca`.
#
# The result starts with the edge leading into `mrca` (if there is one),
# followed by the edges below it, one generation at a time.
clade_from_mrca <- function(tree,mrca) {
  parents <- tree$edge[,1]
  children <- tree$edge[,2]
  collected <- c(which(children == mrca))
  frontier <- which(parents == mrca)
  while (length(frontier) > 0) {
    collected <- c(collected,frontier)
    # edges that start at the child node of each edge in the current frontier
    frontier <- unlist(lapply(
      frontier,
      function(edge) which(parents == children[edge])))
  }
  collected
}

# Build a neighbour-joining tree from a random subsample of clones.
#
# sub_data:            table with the columns CloneID, MutID and NClones, one
#                      row per (clone, mutation) pair
# subsample_size:      number of clones drawn for the tree
# detection_threshold: not used anywhere in the body
#
# Returns a list with
#   tree          the rooted tree with the artificial "wt" outgroup removed
#   S             the sampled clone IDs
#   af            the sampling weight (max NClones) of each sampled clone
#   closest_clone for every clone ID, the row index of the best matching
#                 sampled clone
build_tree <- function(sub_data,subsample_size=100,
                       detection_threshold=0) {
  # one weight per clone: the largest NClones recorded for it
  sub_af <- sub_data %>% 
    group_by(CloneID) %>% 
    summarise(N = max(NClones),.groups = "drop") %>%
    arrange(CloneID)
  MaxClone <- max(sub_data$CloneID)
  MaxMut <- max(sub_data$MutID)
  # weight vector indexed by clone ID; IDs that never occur keep weight 0
  sub_af_sparse <- rep(0,MaxClone)
  sub_af_sparse[sub_af$CloneID] <- sub_af$N 
  
  # clone x mutation presence matrix; Presence_ keeps the rows of all clones
  Presence <- sparseMatrix(i = sub_data$CloneID,
                           j = sub_data$MutID,
                           dims = c(MaxClone,MaxMut))
  Presence[cbind(sub_data$CloneID,sub_data$MutID)] <- 1
  Presence_ <- Presence
  
  # with fewer non-zero weights than requested samples, fall back to uniform
  # weights over all clone IDs so that sampling without replacement can succeed
  if (sum(as.logical(sub_af_sparse)) < subsample_size) {
    sub_af_sparse <- rep(1,length(sub_af_sparse))
  }
  
  S <- sample(MaxClone,subsample_size,replace = F,prob = sub_af_sparse)
  af <- sub_af_sparse[S]
  Presence <- Presence[S,]
  # append an all-zero "wt" row that later serves as the outgroup
  Presence <- as.matrix(rbind(Presence,wt=0))
  # keep only the mutations present in at least one sampled clone
  r_idx <- colSums(Presence)
  Presence <- Presence[,r_idx > 0]
  Presence_ <- Presence_[,r_idx > 0]
  
  # pairwise distance = number of mutations in which two rows differ
  dst <- as.matrix(dist(Presence > 0,method = "manhattan"))
  tree <- root(njs(dst),
               outgroup = "wt",
               resolve.root = T,
               edgelabel = T) 
  tree <- drop.tip(tree,"wt")
  # assign clones not in the tree to the closest clone in the tree.
  # The score computed here is the number of shared mutations (product of the
  # presence matrices), not a Hamming distance; which.max returns the first
  # row with the highest score, so ties go to the earliest row.
  closest_clone <- (Presence %*% t(Presence_)) %>% apply(2,which.max)
  return(list(tree = tree,S = S,af = af,closest_clone = closest_clone))
}

# Read a tab-separated clone table without a header row.
#
# path: file to read
# d:    mutations with MutID <= d are flagged as drivers
# Returns a tibble with the integer columns Gen, NClones, CloneID and MutID
# plus the logical column Driver.
read_clonex <- function(path,d = 1000) {
  read_tsv(path,
           col_names = c("Gen","NClones","CloneID","MutID"),
           # c(col_integer(), ...) concatenated four empty collector lists
           # into list(), so the integer specification was silently dropped
           # and the types were guessed; cols() passes it on as intended.
           col_types = cols(
             Gen = col_integer(),
             NClones = col_integer(),
             CloneID = col_integer(),
             MutID = col_integer()),
           progress = FALSE) %>%
    mutate(Driver = MutID <= d)
}

1 BNPR

In this section we calculate/load all BNPR trajectories for the trees inferred from Wright-Fisher (WF) simulations. These simulations are done for 6 different fitness levels - 0.005, 0.010, 0.015, 0.020, 0.025 and 0.030 - over 800 generations and with a fixed population size (200,000).

1.1 Loading data and inferring trees

The first step is to build the trees, which we then display. One can immediately see that expansions become increasingly prevalent for higher fitness effects, with quite a few cases of a single clone sweeping all of the population.

# Population size, and the highest MutID that counts as a driver mutation.
N <- 2e5
N_DRIVERS <- 50

all_file_paths <- list.files(
  path = "hsc_output_bnpr_complete",
  pattern = "last_generation",
  full.names = TRUE,
  recursive = TRUE
)
all_driver_file_paths <- list.files(
  path = "hsc_output_bnpr_complete",
  pattern = "driver",
  full.names = TRUE,
  recursive = TRUE
)

# Build everything needed for one simulation output file: the sampled tree,
# its ultrametric version and the lookup driver mutation -> clone IDs.
build_tree_record <- function(path) {
  # the fitness effect is encoded in the path as "hsc_<number>"
  s <- as.numeric(gsub(
    pattern = "hsc_",
    replacement = "",
    x = str_match(path,"hsc_[0-9.]+")
  ))
  system(sprintf('echo "%s"', path)) # prints during mclapply by using bash
  # keep generation 800 only and discard rows with MutID 0
  x <- subset(read_clonex(path,d = N_DRIVERS),Gen == 800 & MutID != 0)
  x$R <- as.numeric(str_match(path,"[0-9.]+$"))
  x$s <- s
  x$drift_threshold <- 1 / (N * (s))

  driver_rows <- x %>%
    subset(Driver == TRUE) %>%
    select(CloneID,MutID)
  # one entry per driver mutation, holding the IDs of the clones carrying it
  driver_list_out <- list()
  for (driver_id in unique(driver_rows$MutID)) {
    carriers <- driver_rows %>%
      subset(MutID == driver_id) %>%
      select(CloneID) %>%
      unlist
    driver_list_out[[as.character(driver_id)]] <- carriers
  }

  tree <- build_tree(x)
  ultrametric_tree <- make.ultrametric.tree(tree$tree)
  gc()
  list(
    tree = tree,
    tree_ultra = ultrametric_tree,
    driver_list = driver_list_out,
    path = path
  )
}

# Trees and driver trajectories are cached on disk and only rebuilt when the
# cache file is missing.
file_name <- "data_output/simulated_trees_trajectories.RDS"
if (!file.exists(file_name)) {
  all_driver_trajectories_raw <- lapply(
    all_driver_file_paths,
    function(x) {
      read.csv(x,header = FALSE) %>%
        arrange(V1,V2) %>%
        select(Gen = V1,Count = V2,Clone = V3,MutID = V4)
    }
  )

  trees <- mclapply(
    all_file_paths,
    build_tree_record,
    mc.cores = 2,
    mc.cleanup = TRUE
  )

  names(trees) <- gsub('/last_generation','',all_file_paths)
  names(all_driver_trajectories_raw) <- gsub('/driver_trajectory','',all_driver_file_paths)
  saveRDS(object = list(trees,all_driver_trajectories_raw),file = file_name)
} else {
  tree_traj <- readRDS(file_name)
  trees <- tree_traj[[1]]
  all_driver_trajectories_raw <- tree_traj[[2]]
}

# Multiply the ultrametric branch lengths by 800 (the cached trees are stored
# before this scaling).
for (tree_name in names(trees)) {
  scaled_lengths <- trees[[tree_name]]$tree_ultra$edge.length * 800
  trees[[tree_name]]$tree_ultra$edge.length <- scaled_lengths
}

# One panel per tree, titled with the first number found in its name.
par(mfrow = c(4,5),mar = c(2,0.5,1,0.5))
for (tree_name in names(trees)) {
  tree <- trees[[tree_name]]
  fitness <- str_match(tree_name,"[0-9.]+")
  plot(tree$tree_ultra,main = fitness)
}

1.2 Calculating BNPR trajectories for the whole trees

Here we use BNPR to calculate the effective population size (EPS) trajectory of each whole tree and display them.

# Whole-tree BNPR fits, cached on disk: they are computed (one fit per tree,
# names preserved) only when the cache file does not exist yet.
file_name <- "data_output/simulated_bnpr_trees.RDS"
if (!file.exists(file_name)) {
  run_bnpr_on_tree <- function(x) {
    # echo the path through the shell as a progress indicator
    system(sprintf('echo "%s"', x$path))
    BNPR(x$tree_ultra)
  }
  all_estimates_trees <- lapply(trees,run_bnpr_on_tree)
  saveRDS(all_estimates_trees,file_name)
} else {
  all_estimates_trees <- readRDS(file_name)
}

1.3 Calculating BNPR trajectories for each clade

# Run BNPR on every driver-carrying clade of every tree in `tree_list`.
#
# tree_list:   named list of tree records (as stored in `trees`), each with
#              `tree_ultra`, `tree$S` and `driver_list`
# trim_length: amount removed from every terminal branch with trim_tree()
#              before the clades are cut; 0 leaves the tree untouched
#
# Returns a named list (one entry per tree) of lists, one per clade, holding
# the full tree, the BNPR fit, the clade tip indices, the clade subtree and
# the driver name.
estimate_clade_bnpr <- function(tree_list, trim_length = 0) {
  estimates <- list()
  for (obj_name in names(tree_list)) {
    print(obj_name)
    obj <- tree_list[[obj_name]]
    full_tree <- obj$tree_ultra
    full_tree$tip.label <- obj$tree$S
    if (trim_length > 0) {
      full_tree <- trim_tree(full_tree, l = trim_length)
    }
    full_tree$edge.length[is.infinite(full_tree$edge.length)] <- 0

    # cut_tree() receives a copy whose branch lengths are divided by 800
    scaled_tree <- full_tree
    scaled_tree$edge.length <- scaled_tree$edge.length / 800
    candidate_clades <- cut_tree(tree = scaled_tree, depth = 0.1) %>%
      Filter(f = function(x) length(x) >= 5)

    # Keep only clades whose tips all carry the same driver; drivers with 5 or
    # fewer entries are ignored.
    driver_id_list <- Filter(function(x) length(x) > 5, obj$driver_list)
    clades <- list()
    for (clade in candidate_clades) {
      clade_tips <- full_tree$tip.label[clade]
      for (driver in names(driver_id_list)) {
        if (all(clade_tips %in% driver_id_list[[driver]])) {
          clades[[length(clades) + 1]] <- list(clade = clade, driver = driver)
        }
      }
    }

    if (!(obj_name %in% names(estimates))) {
      estimates[[obj_name]] <- lapply(
        clades,
        function(clad) {
          tip_in_clade <- clad$clade
          if (length(tip_in_clade) <= 4) {
            return(NULL)
          }
          sub_tree <- keep.tip(full_tree, tip_in_clade) %>%
            multi2di()
          if (is.null(sub_tree)) {
            return(NULL)
          }
          bnpr_estimate <- BNPR(sub_tree)
          if (is.null(bnpr_estimate)) {
            return(NULL)
          }
          # NOTE: `tree` is deliberately listed twice, as in the original
          # code. `$tree` resolves to the first entry (the full tree), which
          # is what the mutation timing step combines with `$clade`; the
          # clade subtree is only reachable by position (4th element).
          list(
            tree = full_tree,
            bnpr = bnpr_estimate,
            clade = clad$clade,
            tree = sub_tree,
            driver = clad$driver
          )
        }
      ) %>%
        Filter(f = function(x) !is.null(x))
    }
    gc()
  }
  estimates
}

# Clade-level BNPR fits are cached on disk and only recomputed when the cache
# file is missing.
# NOTE: a cache written by the previous version of this chunk holds "10 year"
# results that were computed with the 5-year trim length (see below); delete
# the file to regenerate them.
file_name <- "data_output/simulated_bnpr_clades.RDS"
if (file.exists(file_name)) {
  all_estimates_full <- readRDS(file_name)
  all_estimates <- all_estimates_full[[1]]
  all_estimates_trimmed_5_years <- all_estimates_full[[2]]
  all_estimates_trimmed_10_years <- all_estimates_full[[3]]
} else {
  all_estimates <- estimate_clade_bnpr(trees)
  # trimmed 5 years
  all_estimates_trimmed_5_years <- estimate_clade_bnpr(trees, trim_length = 50)
  # trimmed 10 years: 5 years are trimmed with l = 50, so 10 years correspond
  # to l = 100. The original code reused l = 50 here, which made the 10-year
  # results a duplicate of the 5-year ones.
  all_estimates_trimmed_10_years <- estimate_clade_bnpr(trees, trim_length = 100)
  saveRDS(
    list(all_estimates,
         all_estimates_trimmed_5_years,
         all_estimates_trimmed_10_years),
    file_name)
}

1.4 Timing clones from trees

# Attach a mutation timing to every clade fit: find the MRCA of the clade in
# the tree stored with the fit and hand it to time_mutation().
for (x in names(all_estimates)) {
  estimate <- all_estimates[[x]]
  for (y in seq_along(estimate)) {
    fit <- estimate[[y]]
    m <- MRCA(fit$tree,fit$clade)
    timing <- time_mutation(fit$tree,m)
    estimate[[y]]$mutation_timing <- timing
    all_estimates[[x]][[y]]$mutation_timing <- timing
  }
}

2 Modelling BNPR trajectories

Here we model the EPS trajectories using different types of models, namely:

  • A log-linear model
  • A scaled and shifted sigmoidal curve
  • A log-linear model with one changepoint (captures the earlier and later phases of growth)

We also fit the same models to the original WF trajectories of each individual driver. This lets us check more directly whether the BNPR trajectories recapitulate the clonal growth that actually took place in the simulations.

2.1 Fitting different models to BNPR trajectory

# Fit growth models to every clade-level BNPR trajectory of the untrimmed
# trees. For each clade the following are stored next to the BNPR fit:
#   linear_estimate            weighted log-linear fit
#   linear_estimate_no_weights unweighted log-linear fit
#   change_point_regression    optim() result for the change_point() model
#   non_linear_estimate        weighted logistic fit (NULL when skipped)
#   non_linear_estimate_vaf    logistic fit with the asymptote fixed at 1
#   earliest_growth / latest_growth
#                              log-linear fits to the first / last 150 time
#                              units of the trajectory
#   data                       the cleaned data used for all fits
all_fits <- list()
par(mfrow = c(4,5),mar = c(2,2,1,1))
for (obj_name in names(all_estimates)) {
  # inherits() gives a single TRUE/FALSE even for objects with several classes
  bnpr_estimates <- all_estimates[[obj_name]] %>%
    Filter(f = function(x) !inherits(x$bnpr,"try-error"))
  new_bnpr_estimates <- list()
  for (x in bnpr_estimates) {
    bnpr_estimate <- x$bnpr
    # clade size divided by 100 - presumably the number of sampled tips
    # (the default subsample_size of build_tree()); TODO confirm
    vaf_from_clade <- length(x$clade) / 100
    Y <- log(bnpr_estimate$summary$mean)
    Y_ <- Y - min(Y,na.rm = TRUE) + 1e-8
    Y025 <- log(bnpr_estimate$summary$quant0.025)
    Y975 <- log(bnpr_estimate$summary$quant0.975)
    X <- (800-bnpr_estimate$summary$time)
    # W is the squared quarter-width of the log-scale 95% interval; its
    # inverse is used as the weight of each point in the fits below
    tmp_df <- data.frame(
      X = X,Y = Y,Y_ = Y_,
      Y025 = Y025,Y975 = Y975,
      W = (Y975 - Y025)^2/16
    ) %>%
      na.omit()
    R <- range(tmp_df$X)
    tmp_df <- tmp_df[!is.infinite(tmp_df$Y),]
    tmp_df <- tmp_df[!is.infinite(tmp_df$Y025),]
    tmp_df <- tmp_df[!is.infinite(tmp_df$Y975),]
    tmp_df$Y_exp <- exp(tmp_df$Y)

    tmp_df$W[is.infinite(tmp_df$W)] <- max(tmp_df$W[!is.infinite(tmp_df$W)])
    # `weights` is spelled out instead of relying on partial matching of `w`
    linear_estimate <- lm(Y ~ X,weights = 1 / tmp_df$W,data = tmp_df)
    m <- linear_estimate$coefficients[2]
    linear_estimate_no_weights <- lm(Y ~ X,data = tmp_df)

    # scalar condition, hence && rather than the elementwise &
    if (nrow(tmp_df) >= 10 && max(tmp_df$Y_exp) < 1e100) {
      # Wrapped in try() like the trimmed analyses: without it an nlsLM error
      # aborted the whole loop and the retry below could never be reached.
      non_linear_estimate <- try(nlsLM(
        Y_exp ~ SSlogis(X, Asym, b2, b3),
        start = list(Asym = max(tmp_df$Y_exp),b2 = mean(tmp_df$X),b3 = 10),
        control = nls.control(maxiter = 1000,warnOnly = TRUE),
        data = tmp_df,
        algorithm = "port",
        weights = 1/tmp_df$W))
      non_linear_estimate_vaf <- nls(
        Y_exp / Y_exp[1] * vaf_from_clade ~ SSlogis(X,1,theta2,theta3),
        weights = 1/W,data = tmp_df,
        control = nls.control(warnOnly = TRUE,
                              maxiter = 1000,
                              minFactor = 1 / (1024^16)),
        start = c(theta2 = max(X),theta3 = 5))

      # log-linear growth over the first and last 150 time units, each
      # evaluated on 20 linearly interpolated points
      x_earliest <- seq(min(tmp_df$X),min(tmp_df$X) + 150,length.out = 20)
      y_earliest <- approx(tmp_df$X,tmp_df$Y,xout = x_earliest)$y

      x_latest <- seq(max(tmp_df$X) - 150,max(tmp_df$X),length.out = 20)
      y_latest <- approx(tmp_df$X,tmp_df$Y,xout = x_latest)$y

      latest <- lm(y_latest ~ x_latest)
      earliest <- lm(y_earliest ~ x_earliest)

      # second attempt with the midpoint started at the right-hand end
      if (inherits(non_linear_estimate,"try-error")) {
        non_linear_estimate <- try(nlsLM(
          Y_exp ~ SSlogis(X, Asym, b2, b3),
          start = list(Asym = max(tmp_df$Y_exp),b2 = max(tmp_df$X),b3 = 10),
          control = nls.control(maxiter = 1000,warnOnly = TRUE),
          data = tmp_df,
          algorithm = "port",
          weights = 1/tmp_df$W))
      }
    } else {
      non_linear_estimate <- NULL
      non_linear_estimate_vaf <- NULL
      earliest <- NULL
      latest <- NULL
    }

    # weighted mean squared error of the change-point model
    opt_fn <- function(par) {
      se <- (tmp_df$Y - change_point(tmp_df$X,par[1],par[2],par[3],par[4]))^2
      return(mean(se / tmp_df$W))
    }
    # Parameters: intercept, first slope (>= 0), slope change, change point.
    # The change point is kept inside the central half of the X range;
    # -Inf/Inf mark the unbounded parameters (previously written as NA).
    change_point_regression <- optim(
      par = c(linear_estimate$coefficients[1],
              linear_estimate$coefficients[2],
              0,
              mean(X)),
      method = "L-BFGS-B",
      lower = c(-Inf,0,-Inf,min(tmp_df$X) + diff(R) * 0.25),
      upper = c(Inf,Inf,Inf,max(tmp_df$X) - diff(R) * 0.25),
      fn = opt_fn)
    x$linear_estimate <- linear_estimate
    x$linear_estimate_no_weights <- linear_estimate_no_weights
    x$change_point_regression <- change_point_regression
    x$non_linear_estimate <- non_linear_estimate
    x$non_linear_estimate_vaf <- non_linear_estimate_vaf
    x$data <- tmp_df
    x$earliest_growth <- earliest
    x$latest_growth <- latest
    new_bnpr_estimates[[length(new_bnpr_estimates) + 1]] <- x
  }

  bnpr_estimates <- new_bnpr_estimates %>%
    Filter(f = function(x) !is.null(x))

  all_estimates[[obj_name]] <- bnpr_estimates
  # collect the fitted data of every clade, numbered consecutively
  for (bnpr in bnpr_estimates) {
    cpr <- bnpr$change_point_regression
    all_fits[[length(all_fits)+1]] <- data.frame(
      x = bnpr$data$X,y = exp(bnpr$data$Y),
      w = bnpr$data$W,
      obj_name,n = length(all_fits)+1,
      clade_size = length(bnpr$clade))

    # if ((cpr$par[2] + cpr$par[3]) < 100) {
    #     plot_BNPR(bnpr$bnpr,main = sprintf("%s",length(bnpr$clade)))
    #     # if (!is.null(bnpr$non_linear_estimate)) {
    #     #   lines(bnpr$data$X,predict(
    #     #     bnpr$non_linear_estimate,newdata = data.frame(X = bnpr$data$X)),
    #     #     col = "red")
    #     # }
    #     lines(800-bnpr$data$X,
    #           exp(change_point(bnpr$data$X,cpr$par[1],cpr$par[2],cpr$par[3],cpr$par[4])),
    #           col="red")
    #   }
    }
}
## Warning in nls(Y_exp/Y_exp[1] * vaf_from_clade ~ SSlogis(X, 1, theta2,
## theta3), : singular gradient

## Warning in nls(Y_exp/Y_exp[1] * vaf_from_clade ~ SSlogis(X, 1, theta2,
## theta3), : singular gradient

2.2 Fitting different models to BNPR trajectory (5 year trimmed)

# Fit growth models to every clade-level BNPR trajectory of the trees trimmed
# by 5 years: weighted and unweighted log-linear fits, a weighted logistic
# fit, log-linear fits to the first / last 150 time units and the
# change_point() regression. Results are stored next to each BNPR fit.
par(mfrow = c(4,5),mar = c(3,1,1,1))
for (obj_name in names(all_estimates_trimmed_5_years)) {
  # inherits() gives a single TRUE/FALSE even for objects with several classes
  bnpr_estimates <- all_estimates_trimmed_5_years[[obj_name]] %>%
    Filter(f = function(x) !inherits(x$bnpr,"try-error"))
  new_bnpr_estimates <- list()
  for (x in bnpr_estimates) {
    bnpr_estimate <- x$bnpr
    Y <- log(bnpr_estimate$summary$mean)
    Y_ <- Y - min(Y,na.rm = TRUE) + 1e-8
    Y025 <- log(bnpr_estimate$summary$quant0.025)
    Y975 <- log(bnpr_estimate$summary$quant0.975)
    X <- (800-bnpr_estimate$summary$time)
    # W is the squared quarter-width of the log-scale 95% interval; its
    # inverse is used as the weight of each point in the fits below
    tmp_df <- data.frame(
      X = X,Y = Y,Y_ = Y_,
      Y025 = Y025,Y975 = Y975,
      W = (Y975 - Y025)^2/16
    ) %>%
      na.omit()
    R <- range(tmp_df$X)
    tmp_df <- tmp_df[!is.infinite(tmp_df$Y),]
    tmp_df <- tmp_df[!is.infinite(tmp_df$Y025),]
    tmp_df <- tmp_df[!is.infinite(tmp_df$Y975),]
    tmp_df$Y_exp <- exp(tmp_df$Y)

    tmp_df$W[is.infinite(tmp_df$W)] <- max(tmp_df$W[!is.infinite(tmp_df$W)])
    # `weights` is spelled out instead of relying on partial matching of `w`
    linear_estimate <- lm(Y ~ X,weights = 1 / tmp_df$W,data = tmp_df)
    m <- linear_estimate$coefficients[2]
    linear_estimate_no_weights <- lm(Y ~ X,data = tmp_df)

    # scalar condition, hence && rather than the elementwise &
    if (nrow(tmp_df) >= 10 && max(tmp_df$Y_exp) < 1e100) {
      non_linear_estimate <- try(nlsLM(
        Y_exp ~ SSlogis(X, Asym, b2, b3),
        start = list(Asym = max(tmp_df$Y_exp),b2 = mean(tmp_df$X),b3 = 10),
        control = nls.control(maxiter = 1000,warnOnly = TRUE),
        data = tmp_df,
        algorithm = "port",
        weights = 1/tmp_df$W))
      # log-linear growth over the first and last 150 time units, each
      # evaluated on 20 linearly interpolated points
      x_earliest <- seq(min(tmp_df$X),min(tmp_df$X) + 150,length.out = 20)
      y_earliest <- approx(tmp_df$X,tmp_df$Y,xout = x_earliest)$y

      x_latest <- seq(max(tmp_df$X) - 150,max(tmp_df$X),length.out = 20)
      y_latest <- approx(tmp_df$X,tmp_df$Y,xout = x_latest)$y

      latest <- lm(y_latest ~ x_latest)
      earliest <- lm(y_earliest ~ x_earliest)

      # second attempt with the midpoint started at the right-hand end
      if (inherits(non_linear_estimate,"try-error")) {
        non_linear_estimate <- try(nlsLM(
          Y_exp ~ SSlogis(X, Asym, b2, b3),
          start = list(Asym = max(tmp_df$Y_exp),b2 = max(tmp_df$X),b3 = 10),
          control = nls.control(maxiter = 1000,warnOnly = TRUE),
          data = tmp_df,
          algorithm = "port",
          weights = 1/tmp_df$W))
      }
    } else {
      non_linear_estimate <- NULL
      earliest <- NULL
      latest <- NULL
    }

    # weighted mean squared error of the change-point model
    opt_fn <- function(par) {
      se <- (tmp_df$Y - change_point(tmp_df$X,par[1],par[2],par[3],par[4]))^2
      return(mean(se / tmp_df$W))
    }
    # Parameters: intercept, first slope (>= 0), slope change, change point.
    # The change point is kept inside the central half of the X range;
    # -Inf/Inf mark the unbounded parameters (previously written as NA).
    change_point_regression <- optim(
      par = c(linear_estimate$coefficients[1],
              linear_estimate$coefficients[2],
              0,
              mean(X)),
      method = "L-BFGS-B",
      lower = c(-Inf,0,-Inf,min(tmp_df$X) + diff(R) * 0.25),
      upper = c(Inf,Inf,Inf,max(tmp_df$X) - diff(R) * 0.25),
      fn = opt_fn)
    x$linear_estimate <- linear_estimate
    x$linear_estimate_no_weights <- linear_estimate_no_weights
    x$change_point_regression <- change_point_regression
    x$non_linear_estimate <- non_linear_estimate
    x$data <- tmp_df
    x$earliest_growth <- earliest
    x$latest_growth <- latest
    new_bnpr_estimates[[length(new_bnpr_estimates) + 1]] <- x
  }

  bnpr_estimates <- new_bnpr_estimates %>%
    Filter(f = function(x) !is.null(x))

  all_estimates_trimmed_5_years[[obj_name]] <- bnpr_estimates
}
## Error in qr.default(.swts * attr(rhs, "gradient")) : 
##   NA/NaN/Inf in foreign function call (arg 1)

2.3 Fitting different models to BNPR trajectory (10 year trimmed)

# Fit growth models to every clade-level BNPR trajectory of the trees trimmed
# by 10 years: weighted and unweighted log-linear fits, a weighted logistic
# fit, log-linear fits to the first / last 150 time units and the
# change_point() regression. Results are stored next to each BNPR fit.
par(mfrow = c(4,5),mar = c(3,1,1,1))
for (obj_name in names(all_estimates_trimmed_10_years)) {
  # inherits() gives a single TRUE/FALSE even for objects with several classes
  bnpr_estimates <- all_estimates_trimmed_10_years[[obj_name]] %>%
    Filter(f = function(x) !inherits(x$bnpr,"try-error"))
  new_bnpr_estimates <- list()
  for (x in bnpr_estimates) {
    bnpr_estimate <- x$bnpr
    Y <- log(bnpr_estimate$summary$mean)
    Y_ <- Y - min(Y,na.rm = TRUE) + 1e-8
    Y025 <- log(bnpr_estimate$summary$quant0.025)
    Y975 <- log(bnpr_estimate$summary$quant0.975)
    X <- (800-bnpr_estimate$summary$time)
    # W is the squared quarter-width of the log-scale 95% interval; its
    # inverse is used as the weight of each point in the fits below
    tmp_df <- data.frame(
      X = X,Y = Y,Y_ = Y_,
      Y025 = Y025,Y975 = Y975,
      W = (Y975 - Y025)^2/16
    ) %>%
      na.omit()
    R <- range(tmp_df$X)
    tmp_df <- tmp_df[!is.infinite(tmp_df$Y),]
    tmp_df <- tmp_df[!is.infinite(tmp_df$Y025),]
    tmp_df <- tmp_df[!is.infinite(tmp_df$Y975),]
    tmp_df$Y_exp <- exp(tmp_df$Y)

    tmp_df$W[is.infinite(tmp_df$W)] <- max(tmp_df$W[!is.infinite(tmp_df$W)])
    # `weights` is spelled out instead of relying on partial matching of `w`
    linear_estimate <- lm(Y ~ X,weights = 1 / tmp_df$W,data = tmp_df)
    m <- linear_estimate$coefficients[2]
    linear_estimate_no_weights <- lm(Y ~ X,data = tmp_df)

    # scalar condition, hence && rather than the elementwise &
    if (nrow(tmp_df) >= 10 && max(tmp_df$Y_exp) < 1e100) {
      non_linear_estimate <- try(nlsLM(
        Y_exp ~ SSlogis(X, Asym, b2, b3),
        start = list(Asym = max(tmp_df$Y_exp),b2 = mean(tmp_df$X),b3 = 10),
        control = nls.control(maxiter = 1000,warnOnly = TRUE),
        data = tmp_df,
        algorithm = "port",
        weights = 1/tmp_df$W))
      # log-linear growth over the first and last 150 time units, each
      # evaluated on 20 linearly interpolated points
      x_earliest <- seq(min(tmp_df$X),min(tmp_df$X) + 150,length.out = 20)
      y_earliest <- approx(tmp_df$X,tmp_df$Y,xout = x_earliest)$y

      x_latest <- seq(max(tmp_df$X) - 150,max(tmp_df$X),length.out = 20)
      y_latest <- approx(tmp_df$X,tmp_df$Y,xout = x_latest)$y

      latest <- lm(y_latest ~ x_latest)
      earliest <- lm(y_earliest ~ x_earliest)

      # second attempt with the midpoint started at the right-hand end
      if (inherits(non_linear_estimate,"try-error")) {
        non_linear_estimate <- try(nlsLM(
          Y_exp ~ SSlogis(X, Asym, b2, b3),
          start = list(Asym = max(tmp_df$Y_exp),b2 = max(tmp_df$X),b3 = 10),
          control = nls.control(maxiter = 1000,warnOnly = TRUE),
          data = tmp_df,
          algorithm = "port",
          weights = 1/tmp_df$W))
      }
    } else {
      non_linear_estimate <- NULL
      earliest <- NULL
      latest <- NULL
    }

    # weighted mean squared error of the change-point model
    opt_fn <- function(par) {
      se <- (tmp_df$Y - change_point(tmp_df$X,par[1],par[2],par[3],par[4]))^2
      return(mean(se / tmp_df$W))
    }
    # Parameters: intercept, first slope (>= 0), slope change, change point.
    # The change point is kept inside the central half of the X range;
    # -Inf/Inf mark the unbounded parameters (previously written as NA).
    change_point_regression <- optim(
      par = c(linear_estimate$coefficients[1],
              linear_estimate$coefficients[2],
              0,
              mean(X)),
      method = "L-BFGS-B",
      lower = c(-Inf,0,-Inf,min(tmp_df$X) + diff(R) * 0.25),
      upper = c(Inf,Inf,Inf,max(tmp_df$X) - diff(R) * 0.25),
      fn = opt_fn)
    x$linear_estimate <- linear_estimate
    x$linear_estimate_no_weights <- linear_estimate_no_weights
    x$change_point_regression <- change_point_regression
    x$non_linear_estimate <- non_linear_estimate
    x$data <- tmp_df
    x$earliest_growth <- earliest
    x$latest_growth <- latest
    new_bnpr_estimates[[length(new_bnpr_estimates) + 1]] <- x
  }

  bnpr_estimates <- new_bnpr_estimates %>%
    Filter(f = function(x) !is.null(x))

  all_estimates_trimmed_10_years[[obj_name]] <- bnpr_estimates
}
## Error in qr.default(.swts * attr(rhs, "gradient")) : 
##   NA/NaN/Inf in foreign function call (arg 1)

2.3.1 Plotting estimates with distinct trimming

# Stack the per-clade trajectory tables (`$data`) of every simulation in
# `estimates` into a single data frame.
#
# estimates: named list (one entry per simulation) of lists of clade fits;
#            each clade fit is read for its `$data` and `$clade` elements.
# trimmed:   value stored in the `Trimmed` column of every row.
#
# Each row is annotated with the simulation name (`id`), the clade members
# collapsed into one comma-separated string (`clade`) and the number of clade
# members (`clade_size`).
collect_trajectories <- function(estimates, trimmed) {
  names(estimates) %>%
    lapply(function(sim_name) {
      do.call(
        rbind,
        lapply(estimates[[sim_name]], function(fit) {
          data.frame(fit$data,
                     id = sim_name,
                     clade = paste(fit$clade, collapse = ','),
                     clade_size = length(fit$clade))
        }))
    }) %>%
    do.call(what = rbind) %>%
    mutate(Trimmed = trimmed)
}

trajectories_trimmed_0 <- collect_trajectories(all_estimates, 0)

trajectories_trimmed_5 <- collect_trajectories(all_estimates_trimmed_5_years, 5)

trajectories_trimmed_10 <- collect_trajectories(all_estimates_trimmed_10_years, 10)

# Compare the trajectories obtained with no trimming, 5-year trimming and
# 10-year trimming, one facet per (simulation, clade) pair.

# Keep only the clades for which all three trimming levels are available.
trimming_comparison <- rbind(trajectories_trimmed_0,
                             trajectories_trimmed_5,
                             trajectories_trimmed_10) %>%
  group_by(clade,id) %>%
  filter(all(c(0,5,10) %in% Trimmed))

# Draw a random subset of at most 40 clades to display; capping the sample
# size avoids the error `sample()` raises when fewer than 40 are available.
clade_keys <- unique(paste(trimming_comparison$id,trimming_comparison$clade))
shown_keys <- sample(clade_keys,min(40,length(clade_keys)),replace = FALSE)

trimmed_representation_plot <- trimming_comparison %>%
  subset(paste(id,clade) %in% shown_keys) %>% 
  group_by(clade,id,Trimmed) %>%
  mutate(w_sum = mean(W)) %>% 
  # Shift each trajectory so that it starts at X = 0.
  mutate(X = X - min(X)) %>% 
  mutate(Trimmed = c(`0`="Original",`5`="Trimmed by 5 years",`10`="Trimmed by 10 years")[as.character(Trimmed)]) %>%
  mutate(Trimmed = factor(Trimmed,
                          levels = c("Original","Trimmed by 5 years","Trimmed by 10 years"))) %>% 
  ungroup %>% 
  mutate(unique_id = as.character(as.numeric(as.factor(paste(id,clade))))) %>%
  group_by(clade,id) %>% 
  # Drop clades where any trimming level has a mean W of 10 or more.
  filter(all(w_sum < 10)) %>% 
  ggplot(aes(x = X,y = exp(Y),group = paste(id,clade,Trimmed),colour = Trimmed,size = Trimmed)) + 
  geom_line(alpha = 0.3) + 
  scale_y_continuous(trans = 'log10') + 
  coord_cartesian(ylim = c(10,1e10)) + 
  facet_wrap(~ reorder(sprintf("id=%s; N=%s",unique_id,clade_size),clade_size)) + 
  theme_gerstung(base_size = 6) +
  theme(strip.text = element_text(margin = margin())) + 
  scale_color_aaas(name = NULL) + 
  scale_size_manual(values = c(0.5,1.0,1.5),name = NULL) +
  xlab("Time since first coalescence (generations)") + 
  ylab("Neff") + 
  theme(legend.position = "bottom",legend.box.spacing = unit(0,"cm"))

# `ggsave()` is not a plot component: adding it with `+` saves whatever
# `last_plot()` happens to be and fails in recent ggplot2 releases, so the
# plot is passed explicitly instead.
ggsave("figures/simulations/trimmed_representation.pdf",
       plot = trimmed_representation_plot,height = 5,width = 6)
print(trimmed_representation_plot)

2.4 Fitting different models to BNPR trajectory (only at coalescence)

# Refit the linear, logistic and change-point models for every clade, this
# time on the table returned by `bnpr_at_coalescence()` for the clade's BNPR
# fit. Results are stored per simulation in `all_estimates_coalescence`;
# clades with fewer than 5 usable rows are dropped.
all_estimates_coalescence <- all_estimates
par(mfrow = c(4,5),mar = c(3,1,1,1))
for (obj_name in names(all_estimates_coalescence)) {
  # Skip clades whose BNPR fit failed (stored as a "try-error" object).
  bnpr_estimates <- all_estimates_coalescence[[obj_name]] %>%
    Filter(f = function(x) !inherits(x$bnpr,"try-error"))
  new_bnpr_estimates <- list()
  for (x in bnpr_estimates) {
    # Discard the fits copied over from `all_estimates` before refitting.
    x$linear_estimate <- NULL
    x$linear_estimate_no_weights <- NULL
    x$change_point_regression <- NULL
    x$earliest_growth <- NULL
    x$latest_growth <- NULL
    x$data <- NULL
    # W is the squared quarter-width of the log-scale interval between Y025
    # and Y975; its inverse is used as the regression weight below.
    # NOTE(review): `X = 800 - X` presumably converts time before sampling
    # into generation number (800 being the last generation) - confirm.
    tmp_df <- bnpr_at_coalescence(x$bnpr) %>%
      mutate(W = (log(Y975) - log(Y025))^2/16,
             Y_ = Y - min(Y,na.rm = TRUE) + 1e-8,
             X = 800 - X) %>%
      na.omit
    tmp_df <- tmp_df[!is.infinite(tmp_df$Y),]
    # Replace infinite weights' denominators by the largest finite W.
    tmp_df$W[is.infinite(tmp_df$W)] <- max(tmp_df$W[!is.infinite(tmp_df$W)])
    
    if (nrow(tmp_df) >= 5) {
      linear_estimate <- lm(Y ~ X,w = 1 / tmp_df$W,data = tmp_df)
      linear_estimate_no_weights <- lm(Y ~ X,data = tmp_df)
      
      # NOTE(review): `Y_exp` is not created in the mutate() above, so it has
      # to be supplied by `bnpr_at_coalescence()` - confirm it is.
      if (nrow(tmp_df) >= 10 && max(tmp_df$Y_exp) < 1e100) {
        non_linear_estimate <- try({nlsLM(
          Y_exp ~ SSlogis(X, Asym, b2, b3),
          start = list(Asym = max(tmp_df$Y_exp),b2 = mean(tmp_df$X),b3 = 10),
          control = nls.control(maxiter = 1000,warnOnly = TRUE),
          data = tmp_df,
          algorithm = "port",
          weights = 1/tmp_df$W)})
  
        # Second attempt with the midpoint started at the latest time point.
        if (inherits(non_linear_estimate,"try-error")) {
          non_linear_estimate <- try(nlsLM(
            Y_exp ~ SSlogis(X, Asym, b2, b3),
            start = list(Asym = max(tmp_df$Y_exp),b2 = max(tmp_df$X),b3 = 10),
            control = nls.control(maxiter = 1000,warnOnly = TRUE),
            data = tmp_df,
            algorithm = "port",
            weights = 1/tmp_df$W))
        }
      } else {
        non_linear_estimate <- NULL
      }
      
      # Weighted mean squared error of the change-point model.
      opt_fn <- function(par) {
        se <- (tmp_df$Y - change_point(tmp_df$X,par[1],par[2],par[3],par[4]))^2
        return(mean(se / tmp_df$W))
      }
      # Start from the weighted linear fit with no slope change at mean(X);
      # the second parameter is bounded below by 0 and the fourth is confined
      # to the observed X range.
      change_point_regression <- optim(
        par = c(linear_estimate$coefficients[1],
                linear_estimate$coefficients[2],
                0,
                mean(tmp_df$X)),
        method = "L-BFGS-B",
        lower = c(NA,0,NA,min(tmp_df$X)),
        upper = c(NA,NA,NA,max(tmp_df$X)),
        fn = opt_fn)
      x$linear_estimate <- linear_estimate
      x$linear_estimate_no_weights <- linear_estimate_no_weights
      x$change_point_regression <- change_point_regression
      x$non_linear_estimate <- non_linear_estimate
      x$data <- tmp_df
      new_bnpr_estimates[[length(new_bnpr_estimates) + 1]] <- x
    }
  }
  
  bnpr_estimates <- new_bnpr_estimates %>%
    Filter(f = function(x) !is.null(x))
  
  all_estimates_coalescence[[obj_name]] <- bnpr_estimates
}
# Draw every clade's BNPR trajectory with its fitted change-point curve
# overlaid in red, on a 4 x 5 grid of panels.
# NOTE(review): this section fits models "only at coalescence", yet the loop
# iterates over `all_estimates` rather than `all_estimates_coalescence` -
# confirm which set of fits is meant to be shown.
par(mfrow = c(4,5),mar = c(2,2,1,1))
for (estimate in all_estimates) {
  for (clade in estimate) {
    # Only clades that carry a linear fit are plotted.
    if (length(clade) == 0 || is.null(clade$linear_estimate)) {
      next
    }
    plot_BNPR(clade$bnpr)
    #points(800 - clade$data$X,clade$data$Y_exp,col="red")
    P <- clade$change_point_regression$par
    # The model is evaluated at 800 - time and exponentiated before plotting.
    model_times <- 800 - clade$bnpr$summary$time
    fitted_curve <- exp(change_point(model_times,P[1],P[2],P[3],P[4]))
    lines(clade$bnpr$summary$time,fitted_curve,col = "red")
  }
}

## Warning in plot.window(...): nonfinite axis limits [GScale(-inf,173.085,2, .);
## log=1]

2.5 Fitting different models to driver trajectories from the original WF simulations

# Fit growth models to the driver trajectories, aggregated per clade.
# For every simulation `x`, each trajectory row is assigned to the clade (index
# into `all_estimates[[x]]`) that contains its closest tip; counts are summed
# per clade and generation, and each sufficiently large clade gets a
# log-linear fit, a change-point fit, a logistic fit and early/late linear
# fits. Results are stored in `all_driver_fits[[x]]`, keyed by the clade index
# as a character string.
all_driver_fits <- list()
par(mfrow = c(4,5),mar = c(1.5,2.5,2.5,1))
for (x in names(all_driver_trajectories_raw)) {
  driver_trajectories <- all_driver_trajectories_raw[[x]]
  driver_trajectories$ClosestTip <- trees[[x]]$tree$closest_clone[driver_trajectories$Clone]
  driver_trajectories$Clade <- NA
  # seq_along() yields an empty loop when there are no estimates.
  for (i in seq_along(all_estimates[[x]])) {
    estimate <- all_estimates[[x]][[i]]
    driver_trajectories$Clade[driver_trajectories$ClosestTip %in% estimate$clade] <- i
  }
  
  driver_trajectories <- driver_trajectories %>%
    group_by(Clade,Gen,MutID) %>%
    summarise(Count = sum(Count))
  clades <- unique(driver_trajectories$Clade)
  clades <- clades[!is.na(clades)]
  all_driver_fits[[x]] <- list()
  for (clade in clades) {
    tmp_df <- driver_trajectories %>%
      subset(Clade == clade)
    drivers <- unique(tmp_df$MutID)
    # Collapse all mutations of the clade into one count per generation.
    tmp_df <- tmp_df %>% 
      group_by(Clade,Gen) %>%
      summarise(Count = sum(Count))
    # Only clades observed at generation 800 with more than 200 counts there
    # are fitted.
    count_at_last_gen <- tmp_df$Count[tmp_df$Gen == 800]
    has_last_gen <- any(tmp_df$Gen == 800)
    large_at_last_gen <- has_last_gen && count_at_last_gen[1] > 200
    
    if (nrow(tmp_df) >= 5 && has_last_gen && large_at_last_gen) {
      # Linear fits to the interpolated counts over the first and the last
      # 50 generations of the trajectory.
      x_earliest <- seq(min(tmp_df$Gen),min(tmp_df$Gen) + 50,length.out = 20)
      y_earliest <- approx(tmp_df$Gen,tmp_df$Count,xout = x_earliest)$y
      
      x_latest <- seq(max(tmp_df$Gen) - 50,max(tmp_df$Gen),length.out = 20)
      y_latest <- approx(tmp_df$Gen,tmp_df$Count,xout = x_latest)$y
      
      latest <- lm(y_latest ~ x_latest)
      earliest <- lm(y_earliest ~ x_earliest)

      linear_sol <- lm(log(Count) ~ Gen,data = tmp_df)
      # Unweighted mean squared error of the change-point model on log counts.
      opt_fn <- function(par) {
        se <- (log(tmp_df$Count) - change_point(tmp_df$Gen,par[1],par[2],par[3],par[4]))^2
        return(mean(se))
      }
      # The fourth parameter is confined to the middle half of the observed
      # generation range.
      change_point_regression <- optim(
        par = c(linear_sol$coefficients[1],linear_sol$coefficients[2],0,mean(tmp_df$Gen)),
        method = "L-BFGS-B",
        lower = c(NA,0,NA,min(tmp_df$Gen) + diff(range(tmp_df$Gen)) * 0.25),
        upper = c(NA,NA,NA,max(tmp_df$Gen) - diff(range(tmp_df$Gen)) * 0.25),
        fn = opt_fn)
      # Logistic fit with the asymptote fixed at 2e5, so only b2 and b3 are
      # estimated.
      # NOTE(review): 2e5 is presumably the simulated population size - confirm.
      non_linear_estimate <- nlsLM(
        Count ~ SSlogis(Gen, 2e5, b2, b3),
        start = list(b2 = mean(tmp_df$Gen),b3 = 10),
        control = nls.control(maxiter = 1000,warnOnly = TRUE),
        data = tmp_df,
        upper = c(b2 = 3000,b3 = 1e5),
        algorithm = "port")
      last_vaf <- count_at_last_gen/2e5
      all_driver_fits[[x]][[as.character(clade)]] <- list(
        data = tmp_df,
        non_linear_estimate = non_linear_estimate,
        change_point_regression = change_point_regression,
        earliest_driver = earliest,
        latest_driver = latest,
        last_vaf = last_vaf,
        drivers = drivers,
        gen_at_onset = min(tmp_df$Gen))
    }
  }
}

2.6 Extracting coefficients from BNPR fits

# Build one row of fitted coefficients per clade from `all_estimates`.
# Clades without a usable logistic fit (missing, or stored as a "try-error")
# are skipped. Columns `fitness` (parsed from the simulation name), `N`
# (number of clades per simulation) and `cp_late` (cp_1 + cp_2) are appended;
# the result stays grouped by `name`.
all_bnpr_fits <- names(all_estimates) %>%
  lapply(
    function(x) {
      if (length(all_estimates[[x]]) > 0) {
        tree_eps <- unlist(all_estimates_trees[[x]]$summary[1,c(4,2,6)])
        lapply(seq_along(all_estimates[[x]]),function(y) {
          y <- as.numeric(y)
          est <- all_estimates[[x]][[y]]
          nle <- est$non_linear_estimate
          if (!is.null(nle) && !inherits(nle,"try-error")) {
            pars <- nle$m$getAllPars()
            tao <- est$mutation_timing
            cp <- est$change_point_regression$par
            pop_size_at_tao_cp <- exp((800 - tao)*cp[2])
            earliest <- est$earliest_growth$coefficients[2]
            latest <- est$latest_growth$coefficients[2]
            pars_vaf <- est$non_linear_estimate_vaf$m$getPars()
            # First position in 1..800 at which the logistic prediction
            # exceeds 90% of the fitted asymptote; Inf when it never does.
            # The empty case is handled explicitly so that min() does not
            # emit a warning for every such clade.
            saturated <- which(
              c(
                predict(
                  nle,
                  newdata = data.frame(X = seq(1,800)))) > (0.9 * pars[1]))
            point_of_saturation <- if (length(saturated) > 0) min(saturated) else Inf
            O <- data.frame(
              l = est$linear_estimate$coefficients[2],
              nl = 1 / pars[3],
              nl_asym = pars[1],
              nl_xmid = pars[2],
              l_nw = est$linear_estimate_no_weights$coefficients[2],
              cp_1 = cp[2],
              cp_2 = cp[3],
              earliest = earliest,
              latest = latest,
              expected = 1/pars_vaf[2],
              changepoint_bnpr = cp[4],
              time_at_onset_low = tao[1],
              time_at_onset_high = tao[2],
              w_sum = mean(est$data$W,na.rm=TRUE),
              pop_at_detection_low = pop_size_at_tao_cp[2],
              pop_at_detection_high = pop_size_at_tao_cp[1],
              tree_eps_low = tree_eps[1],
              tree_eps_mean = tree_eps[2],
              tree_eps_high = tree_eps[3],
              point_of_saturation = point_of_saturation,
              name = x,
              clade_size = length(est$clade),
              driver = est$driver,
              clade = y)
            return(O)
          }
        }) %>%
          do.call(what = rbind)
      } else {
        return(NULL)
      }
    }
    ) %>%
  do.call(what = rbind) %>%
  mutate(fitness = as.numeric(gsub('_',"",str_match(name,"_[0-9.]+_")))) %>%
  group_by(name) %>%
  mutate(N = length(unique(clade))) %>%
  mutate(cp_late = cp_1 + cp_2)
## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf
# Build one row of linear and change-point coefficients per clade from the
# coalescence-only fits. Clades without a change-point fit are skipped.
# `fitness`, `N` and `cp_late` are appended as for the other coefficient
# tables; the result stays grouped by `name`.
all_bnpr_fits_coalescence <- names(all_estimates_coalescence) %>%
  lapply(
    function(x) {
      clade_fits <- all_estimates_coalescence[[x]]
      if (length(clade_fits) == 0) {
        return(NULL)
      }
      clade_rows <- lapply(1:length(clade_fits),function(y) {
        y <- as.numeric(y)
        fit <- clade_fits[[y]]
        if (is.null(fit$change_point_regression)) {
          return(NULL)
        }
        cp <- fit$change_point_regression$par
        data.frame(
          l = fit$linear_estimate$coefficients[2],
          l_nw = fit$linear_estimate_no_weights$coefficients[2],
          cp_1 = cp[2],
          cp_2 = cp[3],
          changepoint_bnpr = cp[4],
          w_sum = mean(fit$data$W,na.rm=T),
          name = x,
          clade_size = length(fit$clade),
          driver = fit$driver,
          clade = y)
      })
      do.call(rbind,clade_rows)
    }
    ) %>%
  do.call(what = rbind) %>%
  mutate(fitness = as.numeric(gsub('_',"",str_match(name,"_[0-9.]+_")))) %>%
  group_by(name) %>%
  mutate(N = length(unique(clade))) %>%
  mutate(cp_late = cp_1 + cp_2)

# Build one row of fitted coefficients per clade from the fits in
# `all_estimates_trimmed_5_years`. Clades without a usable logistic fit
# (missing, or stored as a "try-error") are skipped. `fitness`, `N` and
# `cp_late` are appended; the result stays grouped by `name`.
all_bnpr_fits_trimmed_5 <- names(all_estimates_trimmed_5_years) %>%
  lapply(
    function(x) {
      if (length(all_estimates_trimmed_5_years[[x]]) > 0) {
        lapply(seq_along(all_estimates_trimmed_5_years[[x]]),function(y) {
          y <- as.numeric(y)
          est <- all_estimates_trimmed_5_years[[x]][[y]]
          nle <- est$non_linear_estimate
          if (!is.null(nle) && !inherits(nle,"try-error")) {
            pars <- nle$m$getAllPars()
            cp <- est$change_point_regression$par
            earliest <- est$earliest_growth$coefficients[2]
            latest <- est$latest_growth$coefficients[2]
            # First position in 1..800 at which the logistic prediction
            # exceeds 90% of the fitted asymptote; Inf when it never does.
            # The empty case is handled explicitly so that min() does not
            # emit a warning for every such clade.
            saturated <- which(
              c(
                predict(
                  nle,
                  newdata = data.frame(X = seq(1,800)))) > (0.9 * pars[1]))
            point_of_saturation <- if (length(saturated) > 0) min(saturated) else Inf
            O <- data.frame(
              l = est$linear_estimate$coefficients[2],
              nl = 1 / pars[3],
              nl_asym = pars[1],
              nl_xmid = pars[2],
              l_nw = est$linear_estimate_no_weights$coefficients[2],
              cp_1 = cp[2],
              cp_2 = cp[3],
              earliest = earliest,
              latest = latest,
              changepoint_bnpr = cp[4],
              w_sum = mean(est$data$W,na.rm=TRUE),
              point_of_saturation = point_of_saturation,
              name = x,
              clade_size = length(est$clade),
              driver = est$driver,
              clade = y)
            return(O)
          }
        }) %>%
          do.call(what = rbind)
      } else {
        return(NULL)
      }
    }
    ) %>%
  do.call(what = rbind) %>%
  mutate(fitness = as.numeric(gsub('_',"",str_match(name,"_[0-9.]+_")))) %>%
  group_by(name) %>%
  mutate(N = length(unique(clade))) %>%
  mutate(cp_late = cp_1 + cp_2)
## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf
# Build one row of fitted coefficients per clade from the fits in
# `all_estimates_trimmed_10_years`. Clades without a usable logistic fit
# (missing, or stored as a "try-error") are skipped. `fitness`, `N` and
# `cp_late` are appended; the result stays grouped by `name`.
all_bnpr_fits_trimmed <- names(all_estimates_trimmed_10_years) %>%
  lapply(
    function(x) {
      if (length(all_estimates_trimmed_10_years[[x]]) > 0) {
        lapply(seq_along(all_estimates_trimmed_10_years[[x]]),function(y) {
          y <- as.numeric(y)
          est <- all_estimates_trimmed_10_years[[x]][[y]]
          nle <- est$non_linear_estimate
          if (!is.null(nle) && !inherits(nle,"try-error")) {
            pars <- nle$m$getAllPars()
            cp <- est$change_point_regression$par
            earliest <- est$earliest_growth$coefficients[2]
            latest <- est$latest_growth$coefficients[2]
            # First position in 1..800 at which the logistic prediction
            # exceeds 90% of the fitted asymptote; Inf when it never does.
            # The empty case is handled explicitly so that min() does not
            # emit a warning for every such clade.
            saturated <- which(
              c(
                predict(
                  nle,
                  newdata = data.frame(X = seq(1,800)))) > (0.9 * pars[1]))
            point_of_saturation <- if (length(saturated) > 0) min(saturated) else Inf
            O <- data.frame(
              l = est$linear_estimate$coefficients[2],
              nl = 1 / pars[3],
              nl_asym = pars[1],
              nl_xmid = pars[2],
              l_nw = est$linear_estimate_no_weights$coefficients[2],
              cp_1 = cp[2],
              cp_2 = cp[3],
              earliest = earliest,
              latest = latest,
              changepoint_bnpr = cp[4],
              w_sum = mean(est$data$W,na.rm=TRUE),
              point_of_saturation = point_of_saturation,
              name = x,
              clade_size = length(est$clade),
              driver = est$driver,
              clade = y)
            return(O)
          }
        }) %>%
          do.call(what = rbind)
      } else {
        return(NULL)
      }
    }
    ) %>%
  do.call(what = rbind) %>%
  mutate(fitness = as.numeric(gsub('_',"",str_match(name,"_[0-9.]+_")))) %>%
  group_by(name) %>%
  mutate(N = length(unique(clade))) %>%
  mutate(cp_late = cp_1 + cp_2)
## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

## Warning in min(.): no non-missing arguments to min; returning Inf

2.7 Extracting coefficients from driver fits from the original WF simulations

# Build one row of fitted coefficients per clade from `all_driver_fits`.
# After binding, `late_growth` is replaced by early_growth + late_growth.
all_driver_fits_df <- names(all_driver_fits) %>% 
  lapply(function(x) {
    if (length(all_driver_fits[[x]]) > 0) {
      lapply(names(all_driver_fits[[x]]),function(y) {
        fit <- all_driver_fits[[x]][[y]]
        # The logistic fit estimates only b2 (midpoint) and b3 (scale); its
        # asymptote is fixed at 2e5 in the model formula.
        pars <- fit$non_linear_estimate$m$getAllPars()
        earliest <- fit$earliest_driver$coefficients[2]
        latest <- fit$latest_driver$coefficients[2]
        
        early_cp <- fit$change_point_regression$par[2]
        
        # Extrapolate the frequency at generation 800 on the logit scale:
        # start from the frequency at the first recorded generation and
        # apply the early change-point slope throughout.
        onset <- fit$data %>%
          subset(Gen == min(Gen))
        v <- onset$Count / 2e5
        time <- onset$Gen
        intercept <- -log((1 - v)/v) - early_cp * time
        expected_last_vaf <- 1 / (1 + exp(-(intercept + early_cp * 800)))

        # First generation in 1..800 at which the logistic prediction exceeds
        # 90% of the asymptote; Inf when it never does. The asymptote is the
        # fixed 2e5, not pars[1]: pars[1] is the midpoint b2 here, so using it
        # as the threshold compared counts against a generation number.
        saturated <- which(
          c(
            predict(
              fit$non_linear_estimate,
              newdata = data.frame(Gen = seq(1,800)))) > (0.9 * 2e5))
        point_of_saturation_wf <- if (length(saturated) > 0) min(saturated) else Inf
        return(
          data.frame(
            xmid = pars[1],
            growth_rate = 1/pars[2],
            early_growth = fit$change_point_regression$par[2],
            late_growth = fit$change_point_regression$par[3],
            earliest_driver = earliest,
            latest_driver = latest,
            changepoint_wf = fit$change_point_regression$par[4],
            point_of_saturation_wf = point_of_saturation_wf,
            name = x,
            clade = y,
            last_vaf = fit$last_vaf,
            expected_last_vaf = expected_last_vaf,
            gen_at_onset = fit$gen_at_onset
          )
        )
      }) %>%
        do.call(what = rbind)
    }
  }) %>%
  do.call(what = rbind) %>%
  mutate(late_growth = early_growth + late_growth)
## Warning in min(.): no non-missing arguments to min; returning Inf

3 Comparing BNPR trajectory models

3.1 Plotting BNPR trajectories

Here, we plot all of the inferred BNPR trajectories for each clade.

# Plot every BNPR trajectory in `all_fits`, one panel per fitness value, with
# a horizontal reference line at 2e5.
bnpr_trajectories_plot <- all_fits %>%
  do.call(what = rbind) %>%
  mutate(fitness = str_match(obj_name,"[0-9.]+")) %>% 
  group_by(n,obj_name) %>%
  mutate(V = mean(w)) %>%
  # Shift each trajectory so that it starts at x = 0.
  mutate(x = x - min(x)) %>% 
  # Keep only trajectories whose mean `w` is at most 5.
  subset(V <= 5) %>% 
  mutate(fitness = factor(
    fitness,
    levels = seq(0.005,0.03,by = 0.005),
    labels = sprintf("s=%s",seq(0.005,0.03,by = 0.005)))) %>% 
  ggplot(aes(x = x,y = y)) + 
  geom_line(aes(group = n),alpha = 0.25,size = 0.25) + 
  geom_hline(yintercept = 2e5,size = 0.25) +
  facet_wrap(~ fitness) + 
  scale_y_continuous(trans = 'log10',expand = c(0,0),breaks = c(1e-2,1,1e2,1e4,1e6)) +
  scale_x_continuous(expand = c(0,0)) +
  theme_gerstung(base_size = 6) + 
  theme(strip.text = element_text(margin = margin(b = 0.5))) +
  ylab("Effective population size (EPS)") + 
  xlab("Time since first coal. (generations)") + 
  coord_cartesian(ylim = c(0.5,1e7))

# `ggsave()` is not a plot component: adding it with `+` saves whatever
# `last_plot()` happens to be and fails in recent ggplot2 releases, so the
# plot is passed explicitly instead.
ggsave("figures/simulations/trajectories_BNPR.pdf",
       plot = bnpr_trajectories_plot,height = 1.5,width = 2.3)
print(bnpr_trajectories_plot)

3.2 Comparing BNPR and WF trajectories

Here we scale the BNPR trajectories so that they match as closely as possible those obtained from the WF trajectories. A fairly good fit between BNPR and WF trajectories is observable for a large fraction of trajectories.

# Collect, for every simulation id and clade index, the simulated driver
# trajectory and the BNPR summary at the same position in `all_estimates`,
# tagging both tables with `id` and `clade` so they can be stacked later.
simulated_trajectories <- list()
bnpr_estimates_ <- list()
for (x in names(all_driver_fits)) {
  # seq_along() is empty for ids without any fitted clade, so no guard needed
  for (i in seq_along(all_driver_fits[[x]])) {
    driver_fit <- all_driver_fits[[x]][[i]]$data
    bnpr_sum <- all_estimates[[x]][[i]]$bnpr$summary
    # Flip the BNPR time axis onto the generation axis.
    # NOTE(review): assumes BNPR time runs backwards from generation 800 -
    # confirm against the simulation length.
    bnpr_sum$time <- (800 - bnpr_sum$time)
    driver_fit$id <- x
    driver_fit$clade <- i
    bnpr_sum$id <- x
    bnpr_sum$clade <- i
    simulated_trajectories[[length(simulated_trajectories)+1]] <- driver_fit
    bnpr_estimates_[[length(bnpr_estimates_)+1]] <- bnpr_sum
  }
}

# for (x in names(all_driver_fits)) {
#   if (x %in% names(all_estimates)) {
#     for (y in names(all_driver_fits[[x]])) {
#       clade_no <- 1
#       bnpr_sums <- list()
#       largest_clade <- which.max(unlist(lapply(all_estimates[[x]],function(x) length(x$clade))))[1]
#       for (estimate in all_estimates[[x]]) {
#         if (y == estimate$driver) {
#           driver_fit <- all_driver_fits[[x]][[y]]$data
#           bnpr_sum <- estimate$bnpr$summary
#           W <- (log(bnpr_sum$quant0.975) - log(bnpr_sum$quant0.025))^2/16
#           bnpr_sum$time <- (800 - bnpr_sum$time)
#           if (clade_no == largest_clade) {
#             opt_data <- data.frame(
#               a = approx(bnpr_sum$time,y = bnpr_sum$mean,xout = driver_fit$Gen)$y,
#               b = driver_fit$Count) %>%
#               na.omit
#             opt_fn <- function(pars) {
#               l <- (opt_data$a * pars[1] - opt_data$b)^2
#               l <- l
#               return(sum(l))
#             }
#             opt_solution <- optim(c(m = 1),
#                                   opt_fn,
#                                   control = list(maxit = 10000),
#                                   method = "Nelder-Mead")
#             P <- opt_solution$par
#           }
#           driver_fit$id <- x
#           driver_fit$clade <- y
#           bnpr_sum$id <- x
#           bnpr_sum$clade <- y
#           bnpr_sum$clade <- clade_no
#           simulated_trajectories[[length(simulated_trajectories)+1]] <- driver_fit
#           bnpr_sums[[length(bnpr_sums)+1]] <- bnpr_sum
#           clade_no <- clade_no + 1
#         }
#       }
#       for (bnpr_sum in bnpr_sums) {
#         bnpr_sum$transformed_y <- bnpr_sum$mean * P[1]
#         bnpr_sum$transformed_y_low <- bnpr_sum$quant0.025 * P[1]
#         bnpr_sum$transformed_y_high <- bnpr_sum$quant0.975 * P[1]
#         bnpr_estimates_[[length(bnpr_estimates_)+1]] <- bnpr_sum
#       }
#     }
#   }
# }

# Stack the per-clade tables into one long data frame each.
simulated_trajectories_df <- simulated_trajectories %>%
  do.call(what = rbind)

# For the BNPR summaries, derive `w` from the log-scale distance between the
# quant0.025 and quant0.975 columns, average it per (id, clade) and keep only
# the trajectories whose average is below 5.
bnpr_estimates_df <- bnpr_estimates_ %>%
  do.call(what = rbind) %>%
  mutate(w = (log(quant0.975) - log(quant0.025))^2/16) %>%
  group_by(id, clade) %>%
  mutate(w_sum = mean(w)) %>%
  filter(w_sum < 5)

# Overlay the simulated clade trajectories (grey) and the BNPR estimates
# (orchid), one free-scaled panel per simulation id, with a reference line
# at 2e5.
trajectories_wf_bnpr_plot <- ggplot(data = NULL) +
  geom_hline(yintercept = 2e5,size = 0.25,colour = "goldenrod",alpha = 0.7) +
  geom_line(data = simulated_trajectories_df,
            aes(x = Gen/100,y = Count,colour = "Simulated",
                group = paste(clade)),
            size = 0.25,
            alpha = 0.7) + 
  geom_line(data = bnpr_estimates_df,aes(x = time/100,y = mean,
                                         group = paste(clade),colour = "BNPR estimate"),
            size = 0.25,
            alpha = 0.5) + 
  facet_wrap(~ id,scales = "free") + 
  theme_gerstung(base_size = 6) + 
  theme(strip.text = element_blank(),legend.position = "bottom") + 
  scale_y_continuous(trans = 'log10') + 
  # x limits were previously passed positionally; name them for clarity
  coord_cartesian(xlim = c(0,8),ylim = c(NA,2e8)) +
  scale_colour_manual(values = c(Simulated = "grey50",`BNPR estimate` = "orchid"),name = NULL) + 
  xlab("Generation (x100)")

# Save with an explicit `plot` argument: adding ggsave() to a plot with `+`
# only works while ggsave() returns NULL and fails on newer ggplot2 releases.
ggsave("figures/simulations/trajectories_wf_bnpr.pdf",plot = trajectories_wf_bnpr_plot,height = 6,width = 7)

# Print the plot so it is still displayed in the rendered notebook.
trajectories_wf_bnpr_plot

# Draw a reproducible random subset of up to 23 (id, clade) trajectories for
# the smaller comparison figure. The sample size is capped at the number of
# available trajectories so sample() cannot fail when fewer than 23 remain.
set.seed(42)
trajectory_keys <- unique(paste(bnpr_estimates_df$id,bnpr_estimates_df$clade))
S <- sample(trajectory_keys,
            size = min(23,length(trajectory_keys)),replace = FALSE)

# Restrict both tables to the sampled trajectories.
bnpr_ss <- subset(bnpr_estimates_df,paste(id,clade) %in% S)
sim_ss <- subset(simulated_trajectories_df,paste(id,clade) %in% S)

# For each sampled trajectory, average the BNPR mean and the simulated count
# within bins of 100 generations, then compute the mean absolute error between
# the two binned series. Trajectories are finally ordered by that error and
# given the "id clade" key `b`.
similarity <- S %>%
  lapply(function(key) {
    bnpr_binned <- bnpr_ss %>%
      subset(paste(id,clade) == key) %>%
      mutate(age_group = floor(time / 100)) %>%
      group_by(age_group) %>%
      summarise(av_neff = mean(mean))
    wf_binned <- sim_ss %>%
      subset(paste(id,clade) == key) %>%
      mutate(age_group = floor(Gen / 100)) %>%
      group_by(age_group,id,clade) %>%
      summarise(av_n = mean(Count))
    # Only bins present in both series contribute to the error
    binned <- merge(bnpr_binned,wf_binned,by = "age_group")
    binned %>%
      group_by(id,clade) %>%
      summarise(mae = mean(abs(av_neff - av_n)))
  }) %>%
  do.call(what = rbind) %>%
  arrange(mae) %>%
  mutate(b = paste(id,clade))

# Subset figure: simulated trajectories (black) against the BNPR mean and its
# quant0.025-quant0.975 ribbon (orchid), one panel per sampled trajectory,
# with panels ordered by `similarity$b` (increasing mean absolute error).
# NOTE(review): ggplot2 warns that geom_line ignores `fill`. The mappings are
# kept as they were, presumably so colour and fill share one legend - confirm
# before removing them.
trajectories_wf_bnpr_subset_plot <- ggplot(data = NULL) +
  geom_hline(yintercept = 2e5,size = 0.25,colour = "goldenrod",alpha = 0.7) +
  geom_line(data = sim_ss,
            aes(x = Gen/100,y = Count,colour = "Simulated",fill = "Simulated",
                group = paste(clade)),
            size = 0.25,
            alpha = 0.9) + 
  geom_ribbon(data = bnpr_ss,
              aes(x = time/100,ymin = quant0.025,ymax = quant0.975,
                  group = paste(clade),fill = "BNPR estimate"),
              alpha = 0.2,
              size = 0) +
  geom_line(data = bnpr_ss,
            aes(x = time/100,y = mean,
                group = paste(clade),colour = "BNPR estimate",fill = "BNPR estimate"),
            size = 0.25,
            alpha = 0.7) + 
  facet_wrap(~ factor(paste(id,clade),levels = similarity$b),scales = "free") + 
  theme_gerstung(base_size = 6) + 
  theme(strip.text = element_blank(),legend.position = "bottom",
        legend.key.size = unit(0.2,"cm"),legend.box.spacing = unit(0,"cm")) + 
  scale_y_continuous(trans = 'log10') + 
  coord_cartesian(ylim = c(1,2e8),xlim = c(0,8)) +
  scale_colour_manual(values = c(Simulated = "black",`BNPR estimate` = "orchid"),name = NULL) + 
  scale_fill_manual(values = c(Simulated = "black",`BNPR estimate` = "orchid"),name = NULL) + 
  xlab("Generation (x100)")

# Save with an explicit `plot` argument: adding ggsave() to a plot with `+`
# only works while ggsave() returns NULL and fails on newer ggplot2 releases.
ggsave("figures/simulations/trajectories_wf_bnpr_subset.pdf",plot = trajectories_wf_bnpr_subset_plot,height = 2.2,width = 3)

# Print the plot so it is still displayed in the rendered notebook.
trajectories_wf_bnpr_subset_plot
## Warning: Ignoring unknown aesthetics: fill

## Warning: Ignoring unknown aesthetics: fill

3.3 Sensitivity analysis - normal vs. trimmed

Here we check whether trimming the trajectories leads to very different estimates. Trimming the data by 5 or even 10 years does not substantially change the late growth rate estimates that we propose here. Indeed, the average difference we can expect between trimmed-by-10-years and non-trimmed data is 0.8% growth per generation for late growth estimates and 0.4% growth per generation for early growth estimates.

# Long table of all BNPR fits: parse the fitness from `obj_name`, shift each
# trajectory (n, obj_name) so its time starts at 0, keep trajectories whose
# mean `w` is at most 10, and turn the fitness into a labelled factor.
all_fits_df <- do.call(rbind,all_fits) %>%
  mutate(fitness = str_match(obj_name,"[0-9.]+")) %>%
  group_by(n,obj_name) %>%
  mutate(V = mean(w),
         x = x - min(x)) %>%
  subset(V <= 10) %>%
  mutate(fitness = factor(
    fitness,
    levels = seq(0.005,0.03,by = 0.005),
    labels = sprintf("s=%s",seq(0.005,0.03,by = 0.005))))

# The three sets of BNPR fits to compare, keyed by their preprocessing label.
# The order of this list also fixes the factor level order below.
fits_by_preprocessing <- list(
  `Original trajectory` = all_bnpr_fits,
  `Trimmed by 5 years` = all_bnpr_fits_trimmed_5,
  `Trimmed by 10 years` = all_bnpr_fits_trimmed
)

# Reduce each set to the coefficients of interest, label it, stack the sets,
# and keep only the (clade, name) pairs whose `w_sum` is below 5 in every
# preprocessing variant.
diff_fits <- names(fits_by_preprocessing) %>%
  lapply(function(label) {
    fits_by_preprocessing[[label]] %>%
      select(name,clade,early = cp_1,late = cp_late,earliest,latest,w_sum) %>%
      mutate(Preprocessing = label)
  }) %>%
  do.call(what = rbind) %>%
  mutate(Preprocessing = factor(Preprocessing,
                                levels = names(fits_by_preprocessing))) %>%
  group_by(clade,name) %>%
  filter(all(w_sum < 5))

# Pairwise differences between preprocessing variants for one coefficient.
#
# column: name of the `diff_fits` column to compare (e.g. "late").
# label:  value stored in the `id` column of the result.
#
# Returns one row per (name, clade) with the three pairwise differences and
# `id`. transmute() is used rather than summarise() because the result has
# one row per input row, which summarise() no longer supports without a
# deprecation warning in recent dplyr releases.
trimming_differences <- function(column,label) {
  diff_fits %>%
    select(value = all_of(column),Preprocessing,name,clade) %>%
    # One column per preprocessing variant
    spread(value = "value",key = "Preprocessing") %>%
    ungroup %>%
    transmute(
      `Original-5 years` = (`Original trajectory` - `Trimmed by 5 years`),
      `Original-10 years` = (`Original trajectory` - `Trimmed by 10 years`),
      `5 years-10 years` = (`Trimmed by 5 years` - `Trimmed by 10 years`)) %>%
    mutate(id = label)
}

# Stack the differences for the four growth estimates under comparison.
df_benchmark <- rbind(
  trimming_differences("late","Changepoint late"),
  trimming_differences("early","Changepoint early"),
  trimming_differences("latest","Last 150 generations"),
  trimming_differences("earliest","First 150 generations"))

# Mean absolute difference per estimate type, ignoring missing values.
df_benchmark %>%
  group_by(id) %>%
  summarise(across(
    c(`Original-5 years`,`Original-10 years`,`5 years-10 years`),
    ~ mean(abs(.x),na.rm = TRUE)))
# Distribution of the differences between original and trimmed estimates for
# the early and late changepoint growth rates (comparisons involving the
# original trajectory only).
# NOTE(review): `limits = c(NA,0)` discards every positive difference before
# the boxplot statistics are computed (see the "Removed ... rows" warnings) -
# confirm this is intended.
benchmark_trim_plot <- df_benchmark %>% 
  gather(key = "key",value = "value",-id) %>% 
  subset(id %in% c("Changepoint late","Changepoint early")) %>%
  subset(grepl("Original",key)) %>% 
  ggplot(aes(x = key, y = value,colour = id)) + 
  geom_jitter(size = 0.25,position = position_dodge(width = 1),alpha = 0.25) +
  geom_boxplot(outlier.alpha = 0,size = 0.25,position = position_dodge(width = 1)) + 
  theme_gerstung(base_size = 6) + 
  coord_flip() + 
  theme(legend.position = "bottom") + 
  scale_colour_aaas(name = NULL) +
  xlab("") + 
  ylab("Growth rate estimate") + 
  guides(colour = guide_legend(nrow = 2)) + 
  theme(legend.key.height = unit(0,"cm"),
        legend.box.spacing = unit(0,"cm")) + 
  scale_y_continuous(limits = c(NA,0))

# Save with an explicit `plot` argument: adding ggsave() to a plot with `+`
# only works while ggsave() returns NULL and fails on newer ggplot2 releases.
ggsave("figures/simulations/benchmark_trim.pdf",plot = benchmark_trim_plot,height = 1,width = 2)

# Print the plot so it is still displayed in the rendered notebook.
benchmark_trim_plot
## Warning: Removed 332 rows containing non-finite values (stat_boxplot).
## Warning: Removed 332 rows containing missing values (geom_point).
## Warning: Removed 332 rows containing non-finite values (stat_boxplot).
## Warning: Removed 332 rows containing missing values (geom_point).

3.4 Comparing inference of early and late growth between BNPR fits and driver fits from WF simulations

We define here a measure of excessive observed variance in the BNPR trajectory - whenever the average log-EPS variance is 5 or larger we label the trajectory as "high variance" (this is the `w_sum < 5` threshold used in the code below), noting that these are trajectories which were obtained from fairly small clades (fewer than 10 tips).

# Join the BNPR fits with the driver fits on (name, clade) and derive the
# quantities used in the comparisons below. Columns are created in the same
# order as before so the exported table keeps its layout.
bnpr_wf_merge_df <- all_bnpr_fits %>%
  merge(all_driver_fits_df,by = c("name","clade")) %>%
  mutate(
    # Population sizes at detection expressed as a fraction of 2e5, capped at 1
    vaf_at_detection_low = pop_at_detection_low/2e5,
    vaf_at_detection_high = pop_at_detection_high/2e5,
    vaf_at_detection_low = ifelse(vaf_at_detection_low > 1,1,vaf_at_detection_low),
    vaf_at_detection_high = ifelse(vaf_at_detection_high > 1,1,vaf_at_detection_high),
    # Does the last VAF fall strictly inside that interval, and if not, on
    # which side does it lie?
    vaf_contained = last_vaf > vaf_at_detection_low & last_vaf < vaf_at_detection_high,
    vaf_contained_higher_lower = ifelse(
      vaf_contained,
      "contained",
      ifelse(last_vaf > vaf_at_detection_high,"higher","lower")),
    # Trajectory quality label based on the average variance
    col = ifelse(w_sum < 5,"Low variance","High variance"),
    # Late/early growth ratios for both sets of fits
    wf_ratio = late_growth/early_growth,
    bnpr_ratio = cp_late/cp_1,
    # Infinite ratios (division by zero) are treated as missing
    bnpr_ratio = ifelse(is.infinite(bnpr_ratio),NA,bnpr_ratio),
    wf_ratio = ifelse(is.infinite(wf_ratio),NA,wf_ratio))

# Density of the average variance (`w_sum`) on a log10 axis.
ggplot(bnpr_wf_merge_df,aes(x = w_sum)) +
  geom_density() +
  theme_gerstung(base_size = 6) +
  scale_x_continuous(trans = "log10") +
  labs(x = "Average variance")

# Clade size by trajectory type (low vs. high variance).
ggplot(bnpr_wf_merge_df,aes(x = col,y = clade_size)) +
  geom_boxplot(outlier.size = 0.5) +
  theme_gerstung(base_size = 6) +
  scale_y_continuous(breaks = seq(0,100,by = 10)) +
  labs(x = "Trajectory type",y = "Clade size")

3.5 Comparing the expected and observed late growth

We also want to have a measure of growth deceleration that is more comparable with our inferences from the time series data - based on a sigmoidal fit with a fixed carrying capacity (\(VAF = 0.5\)) that is then extrapolated to its onset of origin. As such, we fit a sigmoidal to the BNPR that assumes a total carrying capacity of \(1\) and that the final EPS corresponds to a proxy for the VAF that is parametrised as \(\hat{VAF} = \frac{Number\ Of\ Tips\ In\ The\ Clade}{Total\ Tips\ In\ The\ Tree}\).

# Fit a logistic curve with asymptote fixed at 1 to one clade's trajectory
# and return its fitted scale parameter.
#
# The trajectory (`Y_exp` against `X`) is read from
# `all_estimates_coalescence[[id]][[clade]]$data` and rescaled so that its
# maximum equals `vaf` before fitting.
#
# vaf:   value the rescaled trajectory peaks at.
# id:    name of the entry in `all_estimates_coalescence`.
# clade: position of the clade within that entry (coerced with as.numeric()).
#
# Returns the second fitted parameter (`b`, the SSlogis scale) as a named
# numeric, or NA when `clade` is missing or no such clade exists for `id`.
expected_late_growth <- function(vaf,id,clade) {
  clade <- as.numeric(clade)
  clade_estimates <- all_estimates_coalescence[[id]]
  # Guard against a missing index as well as an out-of-range one; an NA
  # condition would otherwise make `if` fail.
  if (is.na(clade) || clade > length(clade_estimates)) {
    return(NA)
  }

  tmp_df <- clade_estimates[[clade]]$data %>%
    arrange(X)
  # Normalise by the largest value in the trajectory
  last_eps <- max(tmp_df$Y_exp)
  tmp_df$normalized_Y <- tmp_df$Y_exp / last_eps * vaf

  # warnOnly = TRUE: return the last iterate instead of stopping when the
  # fit does not converge
  nl_fit <- nlsLM(formula = normalized_Y ~ SSlogis(X,1,Xmid,b),
                  start = list(Xmid = mean(tmp_df$X),b = 100),
                  control = nls.control(maxiter = 1000,minFactor = 1/(1024^32),
                                        warnOnly = TRUE),
                  data = tmp_df)
  nl_fit$m$getAllPars()[2]
}

# Expected growth per row: the reciprocal of the value returned by
# expected_late_growth() for that row's last VAF, name and clade.
# The column is initialised as numeric NA, and seq_len() keeps the loop from
# running on a zero-row table.
bnpr_wf_merge_df$expected_nl <- rep(NA_real_,nrow(bnpr_wf_merge_df))

for (i in seq_len(nrow(bnpr_wf_merge_df))) {
  sub_bnpr_wf_merge_df <- bnpr_wf_merge_df[i,]
  last_vaf <- sub_bnpr_wf_merge_df$last_vaf
  name <- as.character(sub_bnpr_wf_merge_df$name)
  clade <- sub_bnpr_wf_merge_df$clade
  bnpr_wf_merge_df$expected_nl[i] <- 1/expected_late_growth(last_vaf,name,clade)
}

# Low-variance trajectories only: ratio of the BNPR late growth to the
# expected growth, split by whether the driver fit's late growth falls below
# 0.8 of its growth rate ("Saturating") or not ("Near constant").
# Ratios below -1 are clamped to -1 and shown under the "<-1" tick.
bnpr_wf_merge_df %>%
  subset(col == "Low variance") %>%
  mutate(
    expected_late_ratio_wf = late_growth/growth_rate,
    expected_late_ratio = cp_late / expected_nl,
    expected_late_ratio = ifelse(expected_late_ratio < -1,-1,expected_late_ratio),
    trajectory_shape = ifelse(expected_late_ratio_wf < 0.8,"Saturating","Near constant")) %>%
  ggplot(aes(x = trajectory_shape,y = expected_late_ratio)) +
  geom_hline(yintercept = 0.5) +
  geom_jitter(width = 0.3,size = 0.5,alpha = 0.8) +
  geom_boxplot(outlier.alpha = 0,alpha = 0.8) +
  scale_y_continuous(breaks = c(-1,0,1,2),
                     labels = c("<-1","0","1","2")) +
  theme_gerstung(base_size = 6) +
  labs(x = "",y = "BNPR late/expected growth ratio")
## Warning: Removed 24 rows containing non-finite values (stat_boxplot).
## Warning: Removed 24 rows containing missing values (geom_point).

# Single boxplot of the observed/expected late growth ratio for low-variance
# trajectories; ratios below -1 are clamped to -1 ("<-1" tick) and the dashed
# line marks a ratio of 1.
expected_late_growth_ratio_plot <- bnpr_wf_merge_df %>% 
  subset(col == "Low variance") %>% 
  mutate(expected_late_ratio = cp_late / expected_nl) %>% 
  mutate(expected_late_ratio = ifelse(expected_late_ratio < -1,-1,expected_late_ratio)) %>% 
  ggplot(aes(y = expected_late_ratio,x=0)) + 
  geom_jitter(size = 0.5,alpha = 0.5,width = 0.05) + 
  geom_boxplot(size = 0.25,width = 0.1,outlier.size = 0,alpha = 0.8) + 
  theme_gerstung(base_size = 6) + 
  geom_hline(yintercept = 1,linetype = 2) +
  #coord_cartesian(ylim = c(0,5)) + 
  scale_y_continuous(breaks = c(-1,0,1,2),
                     labels = c("<-1","0","1","2")) + 
  ylab("Observed/expected late growth ratio") +
  xlab("") +
  theme(axis.text.x = element_blank(),
        axis.ticks.x = element_blank())

# Save with an explicit `plot` argument: adding ggsave() to a plot with `+`
# only works while ggsave() returns NULL and fails on newer ggplot2 releases.
ggsave("figures/simulations/expected_late_growth_ratio.pdf",plot = expected_late_growth_ratio_plot,height = 1.5,width = 0.7)

# Print the plot so it is still displayed in the rendered notebook.
expected_late_growth_ratio_plot
## Warning: Removed 24 rows containing non-finite values (stat_boxplot).
## Warning: Removed 24 rows containing missing values (geom_point).
## Warning: Removed 24 rows containing non-finite values (stat_boxplot).
## Warning: Removed 24 rows containing missing values (geom_point).

3.6 Benchmarking different models for BNPR fits

Finally we compare all of the methods specified above and make statements regarding their usability in predicting the true fitness of each clade. We note that log-linear fits have a tendency to underestimate the growth rate of clones with large fitness values, whereas sigmoidal curves have a tendency to do the opposite - overestimate the growth rate of clones with small fitness values. The early growth rate stands as the most adequate solution for our problem, with \(R^2 > 0.3\) for fits with low variance. We must note that we still observe an \(RMSE = 0.01\) which would, assuming 10 generations per year, correspond to a variation of \(\pm 10\%\). As such, these inferences should account for this considerable level of uncertainty that is associated with each estimate.

# BNPR fits labelled by trajectory quality (average variance below 5 or not).
fits <- all_bnpr_fits %>%
  mutate(col = ifelse(w_sum < 5,"Low variance","High variance"))

# Benchmark each BNPR growth estimate against the simulated fitness.
# For every estimate column ("l", "cp_1", "cp_late", "nl") this stores a
# scatter/boxplot of inferred vs. simulated growth in `plot_list`, and the
# correlation and mean squared error in `metric_list`, computed both on all
# trajectories and on the low-variance ones only.
limits <- c(0,0.05)
plot_list <- list()
metric_list <- list()

for (col_name in c("l","cp_1","cp_late","nl")) {
  # all_of() makes explicit that `col_name` is an external character vector
  # and not a column of `fits`
  sub_data <- fits %>%
    ungroup() %>%
    select(fitness,R = all_of(col_name)) %>%
    na.omit()
  sub_data_low <- fits %>%
    subset(col == "Low variance") %>%
    ungroup() %>%
    select(fitness,R = all_of(col_name)) %>%
    na.omit() 
  
  # ungroup() first so select() does not silently re-add grouping columns
  plot_list[[col_name]] <- fits %>%
    ungroup() %>%
    select(fitness,R = all_of(col_name),col = col) %>% 
    ggplot(aes(x = fitness,y = R,group = paste(fitness),colour = col)) +
    geom_jitter(width = 0.001,size = 0.5) +
    geom_boxplot(outlier.alpha = 0,size = 0.25,alpha = 0.7,colour = "black") +
    theme_gerstung(base_size = 6) +
    coord_cartesian(xlim = limits,
                    ylim = limits) +
    # Identity line: inferred growth equal to simulated growth
    geom_abline(slope = 1) +
    ylab("Inferred growth") +
    xlab("Simulated growth") + 
    scale_colour_lancet(guide = FALSE)
  
  metric_list[[col_name]] <- data.frame(
    method = col_name,
    s = c("All trajectories","Low variance"),
    R = c(cor(sub_data$fitness,sub_data$R),
          cor(sub_data_low$fitness,sub_data_low$R)),
    mse = c(mean((sub_data$fitness - sub_data$R)^2),
            mean((sub_data_low$fitness - sub_data_low$R)^2))
  )
}
## Note: Using an external vector in selections is ambiguous.
## ℹ Use `all_of(col_name)` instead of `col_name` to silence this message.
## ℹ See <https://tidyselect.r-lib.org/reference/faq-external-vector.html>.
## This message is displayed once per session.
## Adding missing grouping variables: `name`
## Adding missing grouping variables: `name`
## Adding missing grouping variables: `name`
## Adding missing grouping variables: `name`
# Stack the per-method metrics and attach a two-line display label (`met`)
# looked up from the method code.
metric_df <- metric_list %>%
  do.call(what = rbind) %>%
  mutate(
    met = c(l = "Linear\ngrowth",
            cp_1 = "Early linear\ngrowth",
            cp_late = "Late linear\ngrowth",
            nl = "Sigmoidal\ngrowth")[method]
  )

# Horizontal bar chart of the squared correlation per method, ordered by
# decreasing R^2 and split by trajectory set.
r2_plot <- metric_df %>%
  ggplot(aes(x = reorder(met,-R^2),y = R^2,fill = s)) +
  geom_bar(stat = "identity",position = "dodge") +
  labs(x = "",y = "R2") +
  scale_y_continuous(expand = c(0,0)) +
  scale_fill_lancet(guide = FALSE) +
  theme_gerstung(base_size = 6) +
  coord_flip()

# Same layout for the root mean squared error, ordered by decreasing MSE.
mse_plot <- metric_df %>%
  ggplot(aes(x = reorder(met,-mse),y = sqrt(mse),fill = s)) +
  geom_bar(stat = "identity",position = "dodge") +
  labs(x = "",y = "RMSE") +
  scale_y_continuous(expand = c(0,0)) +
  scale_fill_lancet(guide = FALSE) +
  theme_gerstung(base_size = 6) +
  coord_flip()

# Assemble the four inferred-vs-simulated panels and the two metric bar
# charts into a 2-column grid (the last row is shorter).
validation_bnpr_grid <- plot_grid(
  plot_list$l + ggtitle("Linear growth"),
  plot_list$cp_1 + ggtitle("Early linear growth"),
  plot_list$cp_late + ggtitle("Late linear growth"),
  plot_list$nl + ggtitle("Sigmoidal growth"),
  r2_plot,
  mse_plot,
  rel_heights = c(1,1,0.6),
  ncol = 2)

# Save with an explicit `plot` argument: adding ggsave() to a plot with `+`
# only works while ggsave() returns NULL and fails on newer ggplot2 releases.
ggsave("figures/simulations/validation_BNPR.pdf",plot = validation_bnpr_grid,height = 5,width = 4.7,useDingbats=FALSE)

# Print the grid so it is still displayed in the rendered notebook.
validation_bnpr_grid

# Export the merged BNPR/driver-fit table without row names.
write.csv(x = bnpr_wf_merge_df,
          file = "data_output/simulated_bnpr_coefficients.csv",
          row.names = FALSE)