Generating Samples of Gene Expression Data with Variational Autoencoders



Introduction

A fundamental problem in biomedical research is the small number of observations, mostly due to a lack of available biosamples, prohibitive costs, or ethical constraints. Augmenting a few real observations with artificially generated samples could make downstream analyses more robust and more reproducible. One possible solution is the use of generative models, which are statistical models that attempt to capture the full probability distribution underlying the observations. Using the variational autoencoder (VAE), a well-known deep generative model, this package aims to generate samples of gene expression data, especially single-cell RNA-seq data. Furthermore, the VAE can use conditioning to produce specific cell types or subpopulations: the conditional VAE (CVAE) creates targeted samples rather than completely random ones.

Autoencoders are unsupervised neural networks that compress multidimensional data into a representation of a chosen, typically lower, dimensionality and then reconstruct the input from that representation. The basic idea is to have an output layer with the same dimensionality as the input and to train the network to reproduce each input dimension as closely as possible after passing it through the network. It is common, but not necessary, for an autoencoder to have a symmetric architecture between input and output, and the middle layer typically has fewer units than the input or output layers. After training, it is not necessary to use both the encoder and decoder portions. For dimensionality reduction, for example, one can use only the encoder to create the reduced representations of the data, and the decoder's reconstructions might not be needed at all. The objective function of the network is a reconstruction loss: the sum of squared differences between the input and the output forces the output to be as similar as possible to the input. Alternatively, the cross-entropy, which quantifies the difference between two probability distributions, can be used as the loss function.
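To make the two reconstruction losses concrete, the following minimal base R sketch computes them for a single input vector and a made-up reconstruction. The helpers sse_loss and bce_loss are hypothetical illustrations, not VAExprs functions, and binary cross-entropy assumes values scaled to [0, 1].

```r
# Sum of squared differences between input x and reconstruction x_hat
sse_loss <- function(x, x_hat) {
  sum((x - x_hat)^2)
}

# Binary cross-entropy; assumes x and x_hat lie in [0, 1].
# eps clamps x_hat away from 0 and 1 so that log() stays finite.
bce_loss <- function(x, x_hat, eps = 1e-7) {
  x_hat <- pmin(pmax(x_hat, eps), 1 - eps)
  -sum(x * log(x_hat) + (1 - x) * log(1 - x_hat))
}

# Made-up input and reconstruction (not model output)
x     <- c(0.0, 0.5, 1.0)
x_hat <- c(0.1, 0.4, 0.9)

sse_loss(x, x_hat)  # 0.03
bce_loss(x, x_hat)
```

Both losses are summed over dimensions here; averaging over dimensions is an equally common convention and only rescales the loss.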

Another interesting application of the autoencoder uses only the decoder portion of the network. Variational autoencoders are based on Bayesian inference: the compressed representation is constrained to follow a probability distribution. This constraint differentiates the VAE from a standard autoencoder and allows it to generate new data, which a conventional autoencoder is not designed to do. For example, one can add a term to the loss function that enforces the hidden variables to be drawn from a Gaussian distribution. One can then repeatedly draw samples from this Gaussian distribution and pass them through the decoder alone to generate samples resembling the original data. In this autoencoder, the bottleneck vector (latent vector) is replaced by two vectors: a mean vector and a standard deviation vector. The overall loss function J = L + λR of the VAE is a weighted sum of the reconstruction loss L and the regularization loss R, where λ > 0 is the regularization parameter. The term “variational” comes from the close relationship between this regularization and the variational inference method in statistics. A variety of choices is possible for the reconstruction error; here we use the binary cross-entropy loss between the input and the output. The regularization loss is the Kullback-Leibler divergence between the conditional distribution of each point's hidden representation and the standard multivariate Gaussian distribution. Small values of λ favor exact reconstruction, in which case the approach behaves like a traditional autoencoder.
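The regularization term and the sampling step can be sketched in base R under a common assumption: the encoder outputs a log-variance vector log σ² instead of the standard deviation itself, so that σ = exp(0.5 · log σ²). With this parameterization, the Kullback-Leibler divergence from N(μ, diag(σ²)) to N(0, I) has a closed form. The helpers kl_to_standard_normal, sample_latent and vae_objective are hypothetical illustrations, not VAExprs functions.

```r
# KL divergence between N(mu, diag(exp(log_var))) and N(0, I), in closed form:
#   KL = -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
kl_to_standard_normal <- function(mu, log_var) {
  -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
}

# Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
# where sigma = exp(0.5 * log_var) (log-variance parameterization assumed)
sample_latent <- function(mu, log_var) {
  mu + exp(0.5 * log_var) * rnorm(length(mu))
}

# Overall objective J = L + lambda * R
vae_objective <- function(recon_loss, mu, log_var, lambda = 1) {
  recon_loss + lambda * kl_to_standard_normal(mu, log_var)
}

kl_to_standard_normal(c(0, 0), c(0, 0))  # 0: the distribution equals the prior
kl_to_standard_normal(c(1, 0), c(0, 0))  # 0.5
set.seed(1)
z <- sample_latent(c(1, 0), c(0, 0))     # one latent draw of length 2
```

For μ = (1, 0) and log σ² = (0, 0) the KL term is 0.5, so with λ = 2 the objective adds 1 to the reconstruction loss; larger λ pulls the latent distribution more strongly toward the prior.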

One can apply conditioning to variational autoencoders to obtain more targeted results. The basic idea of the conditional variational autoencoder is to add an additional conditional input. From an implementation perspective, category information can be encoded as a one-hot representation that tells the model which class the input belongs to. Autoencoders can also be used to embed multimodal data, that is, data whose input features are heterogeneous, in a joint latent space. In addition, separating the samples into classes makes data points within the same category more similar to each other, which can improve the modeling capacity and sample quality of the CVAE.
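The one-hot conditioning described above can be sketched in base R. The cell labels, the toy expression values and the 2-dimensional latent vectors below are made up, and appending the condition to both the encoder input and the latent vector is the usual CVAE construction, assumed here rather than taken from the internals of fit_vae.

```r
# Cell-type labels for five hypothetical cells
cell_type <- factor(c("A", "B", "A", "C", "B"))

# One-hot encoding: one indicator column per level ("- 1" drops the intercept)
one_hot <- model.matrix(~ cell_type - 1)
colnames(one_hot) <- levels(cell_type)

# Toy expression matrix (cells x genes) with made-up values in [0, 1]
set.seed(1)
x <- matrix(runif(5 * 4), nrow = 5,
            dimnames = list(NULL, paste0("gene", 1:4)))

# Encoder input: expression features concatenated with the condition
encoder_input <- cbind(x, one_hot)

# Decoder input: a latent vector (here 2-dimensional, random) plus the same condition
z <- matrix(rnorm(5 * 2), nrow = 5)
decoder_input <- cbind(z, one_hot)

dim(encoder_input)  # 5 7
dim(decoder_input)  # 5 5
```

At generation time, fixing the one-hot part of the decoder input to a single class is what lets a CVAE produce samples of a chosen cell type.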



Example

VAE

Consider simulated data. The data consist of 1000 genes and three groups of 100 samples (cells). Each group has its own set of 100 differentially expressed genes.

if (keras::is_keras_available() && reticulate::py_available()) {
    library(VAExprs)
    library(keras)  # provides layer_input() and layer_dense() used to build the model
    
    ### simulate differentially expressed genes
    set.seed(1)
    g <- 3
    n <- 100
    m <- 1000
    mu <- 5
    sigma <- 5
    mat <- matrix(rnorm(n*m*g, mu, sigma), m, n*g)
    rownames(mat) <- paste0("gene", seq_len(m))
    colnames(mat) <- paste0("cell", seq_len(n*g))
    # n cells per group: group1, ..., group3
    group <- factor(rep(paste0("group", seq_len(g)), each = n))
    names(group) <- colnames(mat)
    mu_upreg <- 6
    sigma_upreg <- 10
    deg <- 100
    for (i in seq_len(g)) {
        deg_rows <- (deg*(i-1) + 1):(deg*i)
        in_group <- group == paste0("group", i)
        # shift each group-specific gene by a gene-level random amount
        mat[deg_rows, in_group] <-
            mat[deg_rows, in_group] + rnorm(deg, mu_upreg, sigma_upreg)
    }
    # positive expression only
    mat[mat < 0] <- 0
    x_train <- as.matrix(t(mat))
    
    # heatmap
    heatmap(mat, Rowv = NA, Colv = NA, 
            col = colorRampPalette(c('green', 'red'))(100), 
            scale = "none")
}

The VAE model can be built with the function “fit_vae”. Here the simulated expression matrix “x_train” (cells in rows, genes in columns) is passed directly; the CVAE example below instead passes a SingleCellExperiment object “sce” that carries the cell annotation. The overall loss function of the VAE is a weighted sum of the reconstruction loss, here the binary cross-entropy between the input and the output, and the regularization loss, the Kullback-Leibler divergence. Note that, for simplicity, the same dataset is passed as both training data (x_train) and validation data (x_val), so the validation loss monitored by early stopping is not an independent estimate of generalization.
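If an independent validation set is wanted, the cells can be split before training. The following base R sketch is one way to do this; split_train_val is a hypothetical helper, and the 20% validation fraction is an arbitrary choice.

```r
# Random hold-out split of the rows (cells) of a cells x genes matrix.
# Note: set.seed() here resets the global random number generator.
split_train_val <- function(x, val_fraction = 0.2, seed = 1) {
  n_val <- round(val_fraction * nrow(x))
  if (n_val < 1 || n_val >= nrow(x)) stop("val_fraction leaves an empty split")
  set.seed(seed)
  val_idx <- sample(nrow(x), size = n_val)
  list(train = x[-val_idx, , drop = FALSE],
       val   = x[val_idx, , drop = FALSE])
}

# Toy data with the same number of cells as the simulation (300)
x_demo <- matrix(rnorm(300 * 10), nrow = 300)
splits <- split_train_val(x_demo)
dim(splits$train)  # 240 10
dim(splits$val)    # 60 10
```

With the simulated data, splits <- split_train_val(x_train) followed by fit_vae(x_train = splits$train, x_val = splits$val, ...) would keep the validation cells out of training. A simple random split is reasonable here because the groups are balanced; for small or unbalanced groups, splitting within each group may be preferable.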

if (keras::is_keras_available() && reticulate::py_available()) {
    # model parameters
    batch_size <- 32
    original_dim <- 1000
    intermediate_dim <- 512
    epochs <- 100
    
    # VAE
    vae_result <- fit_vae(x_train = x_train, x_val = x_train,
                        encoder_layers = list(layer_input(shape = c(original_dim)),
                                            layer_dense(units = intermediate_dim,
                                                        activation = "relu")),
                        decoder_layers = list(layer_dense(units = intermediate_dim,
                                                        activation = "relu"),
                                            layer_dense(units = original_dim,
                                                        activation = "sigmoid")),
                        epochs = epochs, batch_size = batch_size,
                        use_generator = FALSE,
                        callbacks = keras::callback_early_stopping(
                            monitor = "val_loss",
                            patience = 10,
                            restore_best_weights = TRUE))
}

The function “plot_vae” draws a diagram of the model architecture.

if (keras::is_keras_available() && reticulate::py_available()) {
    # model architecture
    plot_vae(vae_result$model)
}

The function “gen_exprs” generates new samples of expression data from the trained model. Here 100 samples are generated and shown in a heatmap next to the training data.
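As a rough illustration of what generation from a trained VAE involves, the sketch below draws latent vectors from the standard Gaussian prior and passes them through a toy decoder with the same layer pattern as the example (a dense ReLU layer followed by a dense sigmoid output layer). The weights are random placeholders rather than trained parameters, the dimensions are made up, and this is not how “gen_exprs” is implemented internally; it only shows the sampling-and-decoding idea.

```r
set.seed(1)
latent_dim <- 2
hidden_dim <- 8
output_dim <- 5

# Random placeholder weights and biases (a trained model would supply these)
W1 <- matrix(rnorm(latent_dim * hidden_dim), nrow = latent_dim)
b1 <- rnorm(hidden_dim)
W2 <- matrix(rnorm(hidden_dim * output_dim), nrow = hidden_dim)
b2 <- rnorm(output_dim)

relu    <- function(a) pmax(a, 0)
sigmoid <- function(a) 1 / (1 + exp(-a))

# Dense layer: inputs %*% weights, with the bias added to every row
dense <- function(inputs, W, b) sweep(inputs %*% W, 2, b, "+")

# Decoder: latent -> hidden (ReLU) -> output (sigmoid)
decode <- function(z) sigmoid(dense(relu(dense(z, W1, b1)), W2, b2))

# Generation: draw z from the standard Gaussian prior and decode it
num_samples <- 100
z <- matrix(rnorm(num_samples * latent_dim), nrow = num_samples)
x_gen <- decode(z)
dim(x_gen)  # 100 5
```

Because the output layer is a sigmoid, generated values lie in [0, 1]; if the training data were rescaled to that range before fitting, the generated values would need the inverse transformation to return to the original expression scale.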

if (keras::is_keras_available() && reticulate::py_available()) {
    # sample generation
    set.seed(1)
    gen_sample_result <- gen_exprs(vae_result, num_samples = 100)
    
    # heatmap
    heatmap(cbind(t(x_train), t(gen_sample_result$x_gen)),
            col = colorRampPalette(c('green', 'red'))(100),
            Rowv=NA)
}

The function “plot_aug” visualizes the augmented data, that is, the real and the generated samples together, in a reduced-dimension plot (here PCA).
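A comparable reduced-dimension view can be sketched with base R's prcomp on made-up real and generated matrices, colored by origin. This only illustrates the idea; the preprocessing and plotting that “plot_aug” actually performs may differ.

```r
# Made-up real and generated matrices (cells x genes)
set.seed(1)
x_real <- matrix(rnorm(50 * 20, mean = 5), nrow = 50)
x_gen  <- matrix(rnorm(30 * 20, mean = 5), nrow = 30)

# Stack the cells and record where each one came from
x_aug  <- rbind(x_real, x_gen)
origin <- factor(rep(c("real", "generated"),
                     times = c(nrow(x_real), nrow(x_gen))))

# PCA on the combined data: genes are centred but not scaled
pca <- prcomp(x_aug, center = TRUE, scale. = FALSE)

# First two principal components, one color per origin
plot(pca$x[, 1:2], col = as.integer(origin), pch = 19,
     xlab = "PC1", ylab = "PC2", main = "Real and generated samples")
legend("topright", legend = levels(origin),
       col = seq_along(levels(origin)), pch = 19)
```

Fitting the PCA on real and generated cells together places both in the same coordinate system, so generated samples that resemble the real ones should overlap with them in the plot.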

if (keras::is_keras_available() && reticulate::py_available()) {
    # plot for augmented data
    plot_aug(gen_sample_result, "PCA")
}



CVAE

The “yan” data set is single-cell RNA sequencing data with 20214 genes and 90 cells from human preimplantation embryos and embryonic stem cells at different passages. The rows of the dataset correspond to genes and the columns to cells. The “SingleCellExperiment” class can be used to store and manipulate single-cell genomics data; it extends the “RangedSummarizedExperiment” class and follows similar conventions. The object “sce” is created from the data “yan” and the cell type annotation “ann”, both provided by the SC3 package.

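if (keras::is_keras_available() && reticulate::py_available()) {
    library(VAExprs)
    library(keras)  # provides layer_input() and layer_dense() used to build the model
    library(SC3)  # provides the "yan" data and the "ann" annotation
    library(SingleCellExperiment)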
    
    # create a SingleCellExperiment object
    sce <- SingleCellExperiment::SingleCellExperiment(
        assays = list(counts = as.matrix(yan)),
        colData = ann
    )
    
    # define feature names in feature_symbol column
    rowData(sce)$feature_symbol <- rownames(sce)
    # remove features with duplicated names
    sce <- sce[!duplicated(rowData(sce)$feature_symbol), ]
    # remove genes that are not expressed in any samples
    sce <- sce[which(rowMeans(assay(sce)) > 0),]
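    # print explicitly: values inside the if block are not auto-printed
    print(dim(assay(sce)))  # genes x cells after filtering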
    
    # model parameters
    batch_size <- 32
    original_dim <- nrow(sce)  # number of genes retained after filtering
    intermediate_dim <- 256
    epochs <- 100
    
    # model
    cvae_result <- fit_vae(object = sce,
                        encoder_layers = list(layer_input(shape = c(original_dim)),
                                            layer_dense(units = intermediate_dim,
                                                        activation = "relu")),
                        decoder_layers = list(layer_dense(units = intermediate_dim,
                                                        activation = "relu"),
                                            layer_dense(units = original_dim,
                                                        activation = "sigmoid")),
                        epochs = epochs, batch_size = batch_size,
                        use_generator = TRUE,
                        callbacks = keras::callback_early_stopping(
                            monitor = "loss",
                            patience = 20,
                            restore_best_weights = TRUE))
    
    # model architecture
    plot_vae(cvae_result$model)
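}

Samples are then generated from the trained CVAE with “gen_exprs”, again using the data generator, and the augmented data are visualized with “plot_aug”.

if (keras::is_keras_available() && reticulate::py_available()) {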
    # sample generation
    set.seed(1)
    gen_sample_result <- gen_exprs(cvae_result, 100,
                                batch_size, use_generator = TRUE)
    
    # plot for augmented data
    plot_aug(gen_sample_result, "PCA")
}



Session information

sessionInfo()
## R version 4.4.2 (2024-10-31)
## Platform: x86_64-pc-linux-gnu
## Running under: Ubuntu 24.04.1 LTS
## 
## Matrix products: default
## BLAS:   /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 
## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so;  LAPACK version 3.12.0
## 
## locale:
##  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=en_US.UTF-8        LC_COLLATE=C              
##  [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
##  [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
## 
## time zone: Etc/UTC
## tzcode source: system (glibc)
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] rmarkdown_2.29
## 
## loaded via a namespace (and not attached):
##  [1] vctrs_0.6.5       cli_3.6.3         knitr_1.49        rlang_1.1.4      
##  [5] zeallot_0.1.0     xfun_0.50         png_0.1-8         generics_0.1.3   
##  [9] jsonlite_1.8.9    glue_1.8.0        buildtools_1.0.0  keras_2.15.0     
## [13] rprojroot_2.0.4   htmltools_0.5.8.1 maketools_1.3.1   sys_3.4.3        
## [17] sass_0.4.9        rappdirs_0.3.3    grid_4.4.2        tfruns_1.5.3     
## [21] evaluate_1.0.3    jquerylib_0.1.4   base64enc_0.1-3   fastmap_1.2.0    
## [25] yaml_2.3.10       lifecycle_1.0.4   whisker_0.4.1     compiler_4.4.2   
## [29] Rcpp_1.0.14       here_1.0.1        lattice_0.22-6    digest_0.6.37    
## [33] R6_2.5.1          reticulate_1.40.0 pillar_1.10.1     tensorflow_2.16.0
## [37] magrittr_2.0.3    bslib_0.8.0       Matrix_1.7-1      tools_4.4.2      
## [41] cachem_1.1.0


