We create an example data set with 800 features in 4 groups, split into 400 training and 100 test samples.
library(graper)
set.seed(123)
data <- makeExampleData(n = 500, p = 800, g = 4,
                        pis = c(0.05, 0.1, 0.05, 0.1),
                        gammas = c(0.1, 0.1, 10, 10))
# training data set
Xtrain <- data$X[1:400, ]
ytrain <- data$y[1:400]
# annotations of features to groups
annot <- data$annot
# test data set
Xtest <- data$X[401:500, ]
ytest <- data$y[401:500]
graper is the main function of this package. It fits the proposed Bayesian models with different settings for the prior (spikeslab = TRUE or FALSE) and for the variational approximation (factoriseQ = TRUE or FALSE). By default, the model is fit with a sparsity-promoting spike-and-slab prior and a fully factorised mean-field assumption. The parameter n_rep trains multiple models with different random initializations; the model with the highest evidence lower bound (ELBO) is returned. The parameter th sets the convergence threshold on the ELBO for the variational Bayes (VB) algorithm used for optimization.
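A minimal sketch of the fitting step. The call is reconstructed under assumptions: n_rep = 3 matches the three random initializations reported in the output below, th is left at its default because the value used for that output is not known, and the comparison object fit_dense is our own addition. The data set is recreated so that the block runs on its own.

```r
library(graper)

# Recreate the simulated data and the training split from above
set.seed(123)
data <- makeExampleData(n = 500, p = 800, g = 4,
                        pis = c(0.05, 0.1, 0.05, 0.1),
                        gammas = c(0.1, 0.1, 10, 10))
Xtrain <- data$X[1:400, ]
ytrain <- data$y[1:400]
annot <- data$annot

# Default settings: spike-and-slab prior (spikeslab = TRUE) and a fully
# factorised mean-field approximation (factoriseQ = TRUE).
# Three random initializations; the fit with the highest ELBO is returned.
fit <- graper(Xtrain, ytrain, annot, n_rep = 3)
fit

# For comparison: a dense model without the spike-and-slab component
fit_dense <- graper(Xtrain, ytrain, annot, spikeslab = FALSE)
fit_dense
```

The dense fit uses a single initialization (the default n_rep) to keep the run time short.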
## Fitting a model with 4 groups, 400 samples and 800 features.
## Fitting with random init 1
## ELB converged
## Fitting with random init 2
## ELB converged
## Fitting with random init 3
## ELB converged
## Sparse graper object for a linear regression model with 800 predictors in 4 groups.
## Group-wise shrinkage:
## 1 2 3 4
## 0.22 0.07 8.25 7.13
## Group-wise sparsity (1 = dense, 0 = sparse):
## 1 2 3 4
## 0.06 0.14 0.04 0.1
The variational Bayes (VB) approach directly yields a posterior distribution for each parameter. Note, however, that VB posteriors are often too concentrated, i.e. they underestimate the posterior uncertainty, so they should not be used directly to construct credible intervals or similar uncertainty measures. They can, however, provide good point estimates.
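Why mean-field posteriors tend to be too concentrated can be seen in a textbook case that is not specific to graper: for a Gaussian target with covariance Σ and precision Λ = Σ⁻¹, the optimal fully factorised approximation has variances 1/Λ_ii (the conditional variances), which are never larger than the exact marginal variances Σ_ii, while its means coincide with the exact means. A base-R sketch with two parameters correlated at 0.9:

```r
# Target: a correlated bivariate Gaussian posterior with covariance Sigma
rho <- 0.9
Sigma <- matrix(c(1, rho, rho, 1), nrow = 2)
Lambda <- solve(Sigma)  # precision matrix

# Exact marginal posterior variances
exact_var <- diag(Sigma)

# Variances of the optimal fully factorised approximation q(b1) q(b2):
# 1 / Lambda_ii, i.e. the conditional variances (here 1 - rho^2)
mf_var <- 1 / diag(Lambda)

print(rbind(exact = exact_var, mean_field = mf_var))
```

With a correlation of 0.9 the mean-field variance is 0.19 instead of 1, so intervals built from it would be far too narrow, even though the point estimates are exact.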
The estimated coefficients and the intercept are stored in the fitted object and can be extracted with coef.
# get coefficients (without the intercept)
beta <- coef(fit, include_intercept = FALSE)
# beta <- fit$EW_beta
# plot estimated versus true beta (ggplot() replaces the deprecated qplot())
library(ggplot2)
ggplot(data.frame(estimated = as.numeric(beta), true = data$beta),
       aes(x = estimated, y = true)) +
  geom_point() + coord_fixed() + theme_bw()
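To summarise the agreement between estimated and true coefficients numerically, one could compute their root mean squared error and correlation. The helper coef_accuracy is hypothetical (not part of graper) and is demonstrated on a small toy vector; with the objects above it would be called as coef_accuracy(beta, data$beta).

```r
# Hypothetical helper: RMSE and correlation between estimated and true
# coefficients; as.numeric() also accepts a one-column matrix
coef_accuracy <- function(beta_hat, beta_true) {
  beta_hat <- as.numeric(beta_hat)
  c(rmse = sqrt(mean((beta_hat - beta_true)^2)),
    cor  = cor(beta_hat, beta_true))
}

# Toy illustration with five coefficients, three of them zero
beta_true <- c(0, 0, 1.5, -2, 0)
beta_hat  <- c(0.1, 0, 1.2, -1.8, 0)
acc <- coef_accuracy(beta_hat, beta_true)
acc
```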
The estimated posterior inclusion probabilities (PIPs) per feature are also stored in the fitted object and can be accessed with getPIPs.
# get estimated posterior inclusion probabilities per feature
pips <- getPIPs(fit)
# plot pips for zero versus non-zero features
df <- data.frame(pips = pips,
nonzero = data$beta != 0)
ggplot(df, aes(x=nonzero, y=pips, col=nonzero)) +
geom_jitter(height=0, width=0.2) +
theme_bw() + ylab("Posterior inclusion probability")
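The PIPs can also be turned into a feature selection by keeping features whose PIP exceeds a threshold, for instance 0.5 (the so-called median probability model), and compared with the true support. The helper selection_summary and the threshold are our own illustrative choices, not graper functionality; in this vignette it would be called as selection_summary(pips, data$beta != 0).

```r
# Hypothetical helper: select features with PIP above a threshold and
# report the number selected, the true positive rate and the false
# discovery rate with respect to the true nonzero coefficients
selection_summary <- function(pips, nonzero, threshold = 0.5) {
  selected <- pips > threshold
  tp <- sum(selected & nonzero)
  fp <- sum(selected & !nonzero)
  c(n_selected = sum(selected),
    TPR = tp / sum(nonzero),
    FDR = if (any(selected)) fp / sum(selected) else 0)
}

# Toy illustration: 3 truly nonzero and 3 zero features
pips_toy <- c(0.95, 0.80, 0.30, 0.05, 0.60, 0.01)
nonzero_toy <- c(TRUE, TRUE, TRUE, FALSE, FALSE, FALSE)
sel <- selection_summary(pips_toy, nonzero_toy)
sel
```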
The function plotGroupPenalties can be used to plot the penalty factors and sparsity levels inferred for each feature group.
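Since the data are simulated, the inferred group-wise sparsity can also be checked numerically by comparing, for each group, the true fraction of nonzero coefficients with the average PIP. The helper group_sparsity is hypothetical and shown on toy input; in this vignette it would be called as group_sparsity(pips, data$beta != 0, annot).

```r
# Hypothetical helper: per feature group, the true fraction of nonzero
# coefficients and the mean posterior inclusion probability
group_sparsity <- function(pips, nonzero, annot) {
  data.frame(group = sort(unique(annot)),
             true_fraction = as.numeric(tapply(nonzero, annot, mean)),
             mean_pip = as.numeric(tapply(pips, annot, mean)))
}

# Toy illustration: two groups of four features each
annot_toy <- rep(1:2, each = 4)
nonzero_toy <- c(TRUE, FALSE, FALSE, FALSE, TRUE, TRUE, FALSE, FALSE)
pips_toy <- c(0.9, 0.1, 0.0, 0.2, 0.8, 0.7, 0.1, 0.0)
gs <- group_sparsity(pips_toy, nonzero_toy, annot_toy)
gs
```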
The function predict can be used to make predictions on new data. Here, we illustrate its use by predicting the response on the test data defined above.
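A minimal sketch of prediction on the held-out samples, passing the new feature matrix as the second argument of predict. The evaluation metrics (test RMSE and squared correlation) are our own choice, and the data and model are recreated so that the block runs on its own.

```r
library(graper)

# Recreate the train/test split and the model from above
set.seed(123)
data <- makeExampleData(n = 500, p = 800, g = 4,
                        pis = c(0.05, 0.1, 0.05, 0.1),
                        gammas = c(0.1, 0.1, 10, 10))
Xtrain <- data$X[1:400, ]
ytrain <- data$y[1:400]
Xtest <- data$X[401:500, ]
ytest <- data$y[401:500]
fit <- graper(Xtrain, ytrain, data$annot, n_rep = 3)

# Predicted responses for the 100 test samples
ypred <- as.numeric(predict(fit, Xtest))

# Test-set root mean squared error and squared correlation
rmse <- sqrt(mean((ytest - ypred)^2))
r2 <- cor(ytest, ypred)^2
c(rmse = rmse, r2 = r2)
```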
# SessionInfo
sessionInfo()
## R version 4.4.2 (2024-10-31)
## Platform: x86_64-pc-linux-gnu
## Running under: Ubuntu 24.04.1 LTS
##
## Matrix products: default
## BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so; LAPACK version 3.12.0
##
## locale:
## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=C
## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
## [9] LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
##
## time zone: Etc/UTC
## tzcode source: system (glibc)
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] ggplot2_3.5.1 graper_1.23.0 BiocStyle_2.35.0
##
## loaded via a namespace (and not attached):
## [1] Matrix_1.7-1 gtable_0.3.6 jsonlite_1.8.9
## [4] crayon_1.5.3 compiler_4.4.2 BiocManager_1.30.25
## [7] Rcpp_1.0.13-1 jquerylib_0.1.4 scales_1.3.0
## [10] yaml_2.3.10 fastmap_1.2.0 lattice_0.22-6
## [13] R6_2.5.1 labeling_0.4.3 knitr_1.49
## [16] tibble_3.2.1 maketools_1.3.1 munsell_0.5.1
## [19] bslib_0.8.0 pillar_1.9.0 rlang_1.1.4
## [22] utf8_1.2.4 cachem_1.1.0 xfun_0.49
## [25] sass_0.4.9 sys_3.4.3 cli_3.6.3
## [28] withr_3.0.2 magrittr_2.0.3 digest_0.6.37
## [31] grid_4.4.2 cowplot_1.1.3 lifecycle_1.0.4
## [34] vctrs_0.6.5 evaluate_1.0.1 glue_1.8.0
## [37] farver_2.1.2 buildtools_1.0.0 fansi_1.0.6
## [40] colorspace_2.1-1 rmarkdown_2.29 matrixStats_1.4.1
## [43] tools_4.4.2 pkgconfig_2.0.3 htmltools_0.5.8.1