How Efficient is Stan Compared to JAGS?

Conjugacy, pooling, centering, and posterior correlations

For a good while JAGS was your best bet if you wanted to do MCMC on a distribution of your own choosing. Then Stan came along, potentially replacing JAGS as the black-box sampler of choice for many Bayesians. But how do they compare in terms of performance? The obvious answer is: It depends. In fact, the question is nearly impossible to answer properly, as any comparison will be conditional on the data, model specifications, test criteria, and more. Nevertheless, this post offers a small simulation study.

The approach taken here is to compare the samplers in terms of the seconds required to produce one effective simulation draw for the parameter with the fewest effective draws. Simulation time excludes compilation time but includes adaptation and enough warm-up for the chains to reach the typical set. The samplers are tested in a hierarchical setting, using six different models that differ in terms of pooling, conjugacy, and centered parameter specifications. The results are roughly as expected: JAGS exhibits both blazing fast and incredibly slow sampling, while Stan delivers somewhat more stable performance, remaining relatively efficient in harder scenarios too.

Data

This mini-study generates hierarchical datasets, containing \(m\) noisy observations for each of \(n\) individuals. For a given individual \(i \in \{1, \ldots, n\}\) and trial number \(j \in \{1, \ldots, m\}\), we have the observed outcome \(y_{ij}\). In addition to \(y\), the data contain the covariate \(x\), which also varies by \(i\) and \(j\) and is drawn from a normal distribution with unit standard deviation (its mean is set by the xmean argument below). The outcome \(y\) has normally distributed, homoskedastic errors and depends linearly on \(x\), according to the following process: \(y_{ij} \sim N(\alpha_i + \beta_i x_{ij},\sigma^2)\), with \(\sigma = 2\) in the simulations. Each individual has its own true intercept, \(\alpha_i\), as well as its own coefficient on \(x\), \(\beta_i\). These individual-specific parameters are in turn drawn from their own distributions, \(\alpha \sim N(2,1)\) and \(\beta \sim N(-2,1)\) (independently, unless the covar argument is set to a nonzero value). The code to define the data-generating process (and load the necessary packages) is:

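library(rstan)
library(rjags)
library(MASS)   # To sample from multivariate normal
library(dclone) # To run JAGS in parallel
set.seed(1)
create.data <- function(n = 500, m = 5, covar = 0, xmean = 0) {
  # Covariance matrix of the individual intercepts and slopes (unit variances)
  Sigma <- matrix(c(1, covar, covar, 1), 2, 2)
  # Row i holds (alpha_i, beta_i), drawn around the means (2, -2)
  coefs <- mvrnorm(n, mu = c(2, -2), Sigma)
  # Covariate x_ij ~ N(xmean, 1); moving xmean away from 0 creates correlations
  x <- matrix(rnorm(n * m, xmean, 1), ncol = m)
  y <- matrix(NA, n, m)
  for (i in 1:n) {
    for (j in 1:m) {
      # Outcome with normal errors, sigma = 2
      y[i, j] <- rnorm(1, coefs[i, 1] + coefs[i, 2] * x[i, j], 2)
    }
  }
  dat <- list(n = n, m = m, x = x, y = y)
  return(dat)
}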

Posterior parameter correlations

A key issue for MCMC samplers is how strongly the parameters are correlated in the posterior, which in turn depends on both the data and the model specification. Strong correlations mean the parameters are hard to separate, posing a more difficult task for the samplers. JAGS, which relies on Gibbs sampling (drawing from conjugate full conditionals where it can, and otherwise falling back on slice sampling or Metropolis-Hastings updates), is likely to suffer under such circumstances, while Stan might be expected to perform better. With the present data, we can minimize posterior correlations by centering \(x\). Conversely, we can create correlations by moving its mean away from zero: the intercept \(\alpha_i\) is the expected outcome at \(x = 0\), so when the observed \(x\) values lie far from zero, any change in the slope has to be offset by a change in the intercept. The larger the mean relative to the standard deviation, the harder it will be to separate the constants from the coefficients. The tests below entail two scenarios that differ in this respect: one with “weak correlations”, \(x_{ij} \sim N(0,1)\), and one with “strong correlations”, \(x_{ij} \sim N(2,1)\).
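
To make the role of the mean of \(x\) concrete, consider a single individual under simplifying assumptions that are not part of the models below: flat priors on \(\alpha_i\) and \(\beta_i\), and a known \(\sigma\). The posterior covariance of \((\alpha_i, \beta_i)\) is then proportional to \((X^\top X)^{-1}\), so their correlation works out to \(-\bar{x}/\sqrt{\overline{x^2}}\), which depends only on where the covariate sits. A minimal R sketch (the helper `intercept.slope.cor` is hypothetical):

```r
# Posterior correlation between alpha_i and beta_i for one individual,
# ASSUMING flat priors and a known sigma, so that the posterior covariance is
# sigma^2 * solve(t(X) %*% X). Hypothetical helper for illustration only.
intercept.slope.cor <- function(x) {
  X <- cbind(1, x)            # design matrix: intercept column and covariate
  V <- solve(crossprod(X))    # (X'X)^{-1}; sigma^2 cancels in the correlation
  V[1, 2] / sqrt(V[1, 1] * V[2, 2])
}

set.seed(1)
x.weak   <- rnorm(5, mean = 0, sd = 1)  # m = 5 trials, centered covariate
x.strong <- rnorm(5, mean = 2, sd = 1)  # m = 5 trials, mean shifted to 2
intercept.slope.cor(x.weak)
intercept.slope.cor(x.strong)
# Closed form of the same quantity
-mean(x.strong) / sqrt(mean(x.strong^2))
```

When \(x \sim N(2,1)\), the expected value of \(x^2\) is 5, so the correlation is roughly \(-2/\sqrt{5} \approx -0.89\); when \(x \sim N(0,1)\) it is close to zero.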

Models

We will use a likelihood function that accurately reflects the data-generating process. In other words, we have:

\[p(y|x,\alpha,\beta,\sigma) = \prod^n_{i=1}\prod^m_{j=1} N(y_{ij}|\alpha_i + \beta_i x_{ij}, \sigma^2)\]
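
The likelihood translates directly into R. The sketch below assumes the data list returned by `create.data()` (with `x` and `y` stored as \(n \times m\) matrices) and parameter vectors `alpha` and `beta` of length \(n\); the function name `loglik` is hypothetical. R recycles the length-\(n\) vectors down each column, so row \(i\) of the matrices is paired with \(\alpha_i\) and \(\beta_i\).

```r
# Log of p(y | x, alpha, beta, sigma): the sum of normal log-densities over all
# i and j. dat$x and dat$y are n x m matrices; alpha and beta are length-n
# vectors, recycled down each column so that row i uses alpha[i] and beta[i].
# Note that dnorm() takes the standard deviation sigma, not the variance.
loglik <- function(dat, alpha, beta, sigma) {
  mu <- alpha + beta * dat$x
  sum(dnorm(dat$y, mean = mu, sd = sigma, log = TRUE))
}
```

Working on the log scale avoids the numerical underflow that multiplying \(n \times m\) small densities would cause.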

We will test six different models to see how the results differ across key specification choices. The main choices are: whether to do some pooling of individual coefficients, whether to use conjugate priors, and whether to use non-centered parameter specifications.

Hierarchical, weakly informative, fully conjugate (centered and non-centered)

We will start with a model doing partial pooling, assuming the \(\alpha\)’s and \(\beta\)’s are drawn from common population distributions: \[\alpha_i \sim N(\mu_\alpha,\sigma_\alpha^2), ~~~ \beta_i \sim N(\mu_\beta,\sigma_\beta^2)\] This model is fully conjugate, with normal priors on the means and inverse Gamma priors on the variances. (Note that the last parameter in the inverse gamma here is a scale and not a rate parameter.) The hyperpriors and the prior on \(\sigma\) are set to be weakly informative: \[\mu_\alpha,\mu_\beta \sim N(0,3^2), ~~~ \sigma_\alpha^2,\sigma_\beta^2,\sigma^2 \sim \text{Inv-Gamma}(1.5,2)\]
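
JAGS parameterizes the normal distribution by its precision \(\tau = 1/\sigma^2\), and if \(\tau \sim \text{Gamma}(a, b)\) with \(b\) as a rate, then \(\sigma^2 = 1/\tau \sim \text{Inv-Gamma}(a, b)\) with \(b\) as a scale. An Inv-Gamma(1.5, 2) prior on a variance would therefore correspond to a Gamma(1.5, 2) prior on the precision (JAGS's `dgamma` takes shape and rate). The R sketch below checks this change of variables numerically; `dinvgamma` is a hand-written helper, not a base R function. It also computes the implied prior mean of each variance, \(b/(a-1) = 4\), which matches the error variance \(\sigma^2 = 2^2\) used in `create.data()`.

```r
# Inverse-gamma density with shape a and SCALE b (hand-written helper;
# base R has no dinvgamma).
dinvgamma <- function(v, a, b) {
  b^a / gamma(a) * v^(-a - 1) * exp(-b / v)
}

# Change of variables: if tau ~ Gamma(a, rate = b), then v = 1/tau has density
# dgamma(1/v, a, rate = b) / v^2, which should match dinvgamma(v, a, b).
v <- c(0.5, 1, 2, 4, 8)
cbind(invgamma      = dinvgamma(v, 1.5, 2),
      via.precision = dgamma(1 / v, shape = 1.5, rate = 2) / v^2)

# Prior mean of the variance: b / (a - 1), finite because a > 1
2 / (1.5 - 1)  # 4
```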

In some situations (with little data and/or correlated parameters), one can achieve more efficient and accurate sampling by using a non-centered parameterization (see Papaspiliopoulos et al. 2007, Betancourt and Girolami 2013). For the \(\alpha\)-parameters, this can be done by specifying \(\alpha_i = \mu_\alpha + \sigma_\alpha \alpha^*_i\), where \(\alpha^*_i \sim N(0,1)\). The tests below entail both a centered and non-centered version of the conjugate model. (It is tricky to do non-centering without breaking JAGS’ ability to exploit the conjugacy, so the non-centered JAGS version only separates out the common mean).
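
In generative terms, the parameterizations discussed here define the same distribution for \(\alpha_i\); they differ only in which quantities the sampler treats as its unknowns. The R sketch below draws from each version (the function `draw.alpha` and its `type` labels are hypothetical): the centered form samples \(\alpha_i\) directly, the "mean-only" form separates out the common mean as in the non-centered JAGS version described above, and the fully non-centered form scales standard-normal \(\alpha^*_i\) by \(\sigma_\alpha\).

```r
# Three ways of writing alpha_i ~ N(mu, sigma^2). The marginal distribution of
# alpha is identical; what changes is which quantities a sampler updates.
draw.alpha <- function(n, mu, sigma,
                       type = c("centered", "mean-only", "non-centered")) {
  type <- match.arg(type)
  switch(type,
    "centered"     = rnorm(n, mean = mu, sd = sigma),
    # alpha_i = mu + alpha.tilde_i, with alpha.tilde_i ~ N(0, sigma^2)
    "mean-only"    = mu + rnorm(n, mean = 0, sd = sigma),
    # alpha_i = mu + sigma * alpha.star_i, with alpha.star_i ~ N(0, 1)
    "non-centered" = mu + sigma * rnorm(n, mean = 0, sd = 1))
}

set.seed(1)
sapply(c("centered", "mean-only", "non-centered"), function(tp) {
  a <- draw.alpha(1e5, mu = 2, sigma = 1, type = tp)
  c(mean = mean(a), sd = sd(a))
})
```

Because the three versions agree in distribution, any difference in sampling speed comes from the posterior geometry each one presents to the sampler, not from a change in the model.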

Hierarchical, weakly informative, partly conjugate (centered and non-centered)

The next model is very similar, but only partly conjugate, as the inverse-gamma priors on the variances are replaced with gamma priors on the standard deviations. The new priors are set to be approximately as informative as the inverse-gamma priors (although they are necessarily different). The model is then:

\[\alpha_i \sim N(\mu_\alpha,\sigma_\alpha^2), ~~~ \beta_i \sim N(\mu_\beta,\sigma_\beta^2)\]

\[\mu_\alpha,\mu_\beta \sim N(0,3^2), ~~~ \sigma_\alpha,\sigma_\beta,\sigma \sim \text{Gamma}(2,.5)\]

(Note that the last parameter in the gamma is a rate parameter, i.e., an inverse scale.)
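
To make the rate convention concrete: with shape 2 and rate 0.5 (equivalently, scale 2), the prior on each standard deviation has mean \(2/0.5 = 4\) and mode \((2-1)/0.5 = 2\). R's `dgamma()`, JAGS's `dgamma` and Stan's `gamma` all take the shape first and a rate (inverse scale) second, so the same two numbers carry over between them. A quick check in R:

```r
# Gamma(2, 0.5) with 0.5 as a RATE: R's dgamma(x, shape, rate) agrees with
# the scale form using scale = 1 / rate = 2.
s <- seq(0.5, 10, by = 0.5)
all.equal(dgamma(s, shape = 2, rate = 0.5), dgamma(s, shape = 2, scale = 2))  # TRUE

# Prior mean (shape / rate) and mode ((shape - 1) / rate) of each SD
c(mean = 2 / 0.5, mode = (2 - 1) / 0.5)  # mean 4, mode 2

# Numerical check of the mean
integrate(function(s) s * dgamma(s, shape = 2, rate = 0.5), 0, Inf)$value
```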

Hierarchical, weakly informative, non-conjugate, centered

The next model is completely non-conjugate, replacing the normal distributions with Student-t distributions with 100 degrees of freedom. This is practically the same as the normal, but it should prevent JAGS from exploiting conjugacy in its Gibbs updates.

\[\alpha_i \sim \text{Student-t}(100,\mu_\alpha,\sigma_\alpha^2), ~~~ \beta_i \sim \text{Student-t}(100,\mu_\beta,\sigma_\beta^2)\]

\[\mu_\alpha,\mu_\beta \sim \text{Student-t}(100,0,3^2), ~~~ \sigma_\alpha,\sigma_\beta,\sigma \sim \text{Gamma}(2,.5)\]
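
How close is a Student-t with 100 degrees of freedom to the normal? A quick base-R comparison of densities and 97.5% quantiles on the standardized scale:

```r
# Standardized Student-t (100 df) versus standard normal
z <- seq(-4, 4, by = 0.01)
max(abs(dt(z, df = 100) - dnorm(z)))                  # largest density gap on [-4, 4]
c(t100 = qt(0.975, df = 100), normal = qnorm(0.975))  # 97.5% quantiles
```

The gaps are small, which is the point of the substitution: the model barely changes, while the conjugate updates JAGS could otherwise use are no longer available.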

Unpooled with uniform priors

The final model is an unpooled model with uniform priors on all parameters. This model does not make much sense here (as we know the individual parameters are drawn from a common distribution), but it may serve to illustrate performance in more difficult circumstances. The model is potentially challenging in that there is no conjugacy, no pooling of information across units, and no other help from the priors, which results in larger posterior variances (and potentially covariances): \[\alpha_i,\beta_i \sim U(-100,100), ~~~ \sigma \sim U(0,100)\]

This model only takes a few lines of Stan code, as we get uniform priors by not specifying anything (except relevant limits).
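
For illustration, here is a sketch of what such an unpooled program might look like, stored as a model string for rstan. It is an assumption, not necessarily the code behind the results: the data names (`n`, `m`, `x`, `y`) follow the list returned by `create.data()`, and the declared bounds supply the uniform priors, since Stan adds nothing to the log density for a parameter without a sampling statement.

```r
# Hypothetical Stan program for the unpooled model, stored as an R string.
# With no prior statements, the declared bounds imply
# alpha_i, beta_i ~ U(-100, 100) and sigma ~ U(0, 100).
unpooled.code <- "
data {
  int<lower=1> n;
  int<lower=1> m;
  matrix[n, m] x;
  matrix[n, m] y;
}
parameters {
  vector<lower=-100, upper=100>[n] alpha;
  vector<lower=-100, upper=100>[n] beta;
  real<lower=0, upper=100> sigma;
}
model {
  // Likelihood only; column j holds trial j for all individuals
  for (j in 1:m)
    col(y, j) ~ normal(alpha + beta .* col(x, j), sigma);
}
"
```

With rstan loaded, it could then be run along the lines of `stan(model_code = unpooled.code, data = create.data(n = 250, m = 5), chains = 3, cores = 3)`.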

Approach

A key question is how to compare the samplers in a way that is both fair and relevant for actual applications. As MCMC draws tend to be autocorrelated, the nominal number of simulation draws is less interesting than the number of effective draws (adjusting for autocorrelation). Furthermore, if we decide on a lower limit for how many effective draws we require, we will most likely want to apply this requirement to all parameters. In other words, the key question is how efficiently we are sampling the parameter with the fewest effective draws (after the chains have found the typical set). Finally, the most relevant measure of efficiency is probably time per effective draw.

However, it is not clear what the most relevant measure of time consumption is either. Stan needs to compile models before running, while JAGS does not. Fortunately, Stan models only need to be compiled once, which usually takes a minute or two. If you are using an already-compiled model, there is no loss of time. Similarly, if your model is sufficiently complicated to require a long run, compilation time is essentially irrelevant. There is really only one situation in which compilation time counts: when you are developing a new model and testing it on a small dataset to make things faster, but still have to recompile for every little change you make. Given the points above, however, I think it is most reasonable to compare run times without including compilation time.

Another issue is whether to include adaptation and warm-up time. I think it is reasonable to include this, as these phases are typically required for any analysis. If a sampler takes very long to converge to the target distribution, this is a relevant cost of using it. Including these phases may, however, introduce some inaccuracy: As we take more post-warm-up draws, the share of time spent on warm-up decreases. This makes the results depend on the length of the post-warm-up run, which is clearly not ideal, but I am still sticking to this approach for now. (I suppose the best approach might be to set the run-length for each model aiming for a specific minimum number of effective draws for all parameters, but this might take extremely long in the slowest instances.)
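
The efficiency measure and the warm-up issue can be written out in a few lines of R. The sketch below is hypothetical (the function and all timings are made up for illustration): it divides total run time, excluding compilation but including warm-up, by the effective draws of the worst-sampled parameter, and shows how a fixed warm-up cost is spread over more draws as the post-warm-up run gets longer.

```r
# Seconds per effective draw for the worst-sampled parameter.
# ess: effective draws for every parameter; warm-up time is included,
# compilation time is not.
secs.per.eff.draw <- function(ess, warmup.secs, sampling.secs) {
  (warmup.secs + sampling.secs) / min(ess)
}

# Hypothetical run: 10 s of warm-up, 0.01 s per post-warm-up draw, and
# three parameters yielding 0.5, 0.8 and 1.0 effective draws per draw.
for (n.post in c(1000, 10000)) {
  ess <- c(0.5, 0.8, 1.0) * n.post
  print(secs.per.eff.draw(ess, warmup.secs = 10, sampling.secs = 0.01 * n.post))
}
# prints 0.04, then 0.022: the warm-up share shrinks as the run gets longer
```

As the run grows, the figure approaches the pure sampling cost of \(0.01/0.5 = 0.02\) seconds per effective draw, which is why the results depend on the chosen run length.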

The models have been fit to 200 datasets for each of the two posterior-correlation scenarios. The datasets have been kept fairly small: \(n = 250\) and \(m=5\). The patterns reported below seem to remain when I scale up the data, but further testing may be called for (as the speed of JAGS may suffer more as the scale increases). For each program, I use three of the four available processor cores and run one chain on each core in parallel. To calculate split \(\hat{R}\) values and the number of effective draws, I use the specifications from Gelman et al. 2014 (BDA3). Setting the length of the warm-up phase is a bit tricky, as a longer warm-up is required for the more challenging combinations of data and models. By trial and error, I have set specific warm-up lengths for each model in each scenario so that the models typically converge. In the most challenging situations, the models will sometimes still not fully converge in time, and in these instances the results are not reported. (This may slightly favor models that are performing poorly, but it should not change the general patterns.)
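
For reference, here is a base-R sketch of the BDA3 split-\(\hat{R}\) and effective-sample-size calculations (Gelman et al. 2014). It is written from the textbook formulas and is not necessarily identical to the code used for this post; the function name `bda3.diagnostics` is hypothetical, and the input is assumed to be a matrix of post-warm-up draws for one parameter, with one column per chain.

```r
# Split R-hat and effective number of draws following the BDA3 formulas.
# draws: matrix of post-warm-up draws for ONE parameter
# (rows = iterations, columns = chains).
bda3.diagnostics <- function(draws) {
  # Split each chain into a first and a second half
  half <- floor(nrow(draws) / 2)
  psi <- cbind(draws[1:half, , drop = FALSE],
               draws[(nrow(draws) - half + 1):nrow(draws), , drop = FALSE])
  n <- nrow(psi)                       # draws per split chain
  m <- ncol(psi)                       # number of split chains
  B <- n * var(colMeans(psi))          # between-chain variance
  W <- mean(apply(psi, 2, var))        # within-chain variance
  var.plus <- (n - 1) / n * W + B / n  # marginal posterior variance estimate
  rhat <- sqrt(var.plus / W)
  # Autocorrelation at lag t, estimated from the variogram V_t
  rho <- sapply(1:(n - 1), function(t) {
    V.t <- mean((psi[(t + 1):n, , drop = FALSE] - psi[1:(n - t), , drop = FALSE])^2)
    1 - V.t / (2 * var.plus)
  })
  # Sum up to the first odd T for which rho[T + 1] + rho[T + 2] is negative
  T.max <- 1
  while (T.max + 2 <= length(rho) && rho[T.max + 1] + rho[T.max + 2] >= 0) {
    T.max <- T.max + 2
  }
  n.eff <- m * n / (1 + 2 * sum(rho[1:T.max]))
  c(rhat = rhat, n.eff = n.eff)
}

set.seed(1)
iid <- matrix(rnorm(4000), nrow = 1000, ncol = 4)  # 4 chains, independent draws
ar1 <- sapply(1:4, function(k)                     # 4 autocorrelated AR(1) chains
  as.numeric(stats::filter(rnorm(1000), 0.9, method = "recursive")))
bda3.diagnostics(iid)  # n.eff should be in the vicinity of 4000
bda3.diagnostics(ar1)  # far fewer effective draws
```

In the comparison above, the quantity of interest is the smallest `n.eff` across all parameters, once every \(\hat{R}\) is close to 1.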

Results

The results are shown on a log scale in the boxplot below. The scenario with weak posterior correlations offers ideal conditions, and when JAGS is given a fully conjugate hierarchical model, it samples very fast here, showing the fastest performance in this test. Stan is also fast in this situation, but takes about three times as long as JAGS. The non-centered parameterization appears to slow down both samplers slightly, suggesting it adds to the computational load without much gain in this scenario. When we move to the partly conjugate models, JAGS gets notably slower and becomes about as fast as Stan, whether we use the centered or non-centered specification.

Moving from a partly conjugate to a completely non-conjugate model (replacing normals with t-distributions) makes JAGS about three times slower still, while Stan retains its speed and is now several times faster than JAGS. Turning to the unpooled models with uniform priors, Stan gets slightly faster, while JAGS gets a bit slower and is now nearly five times slower than Stan.

As the posterior correlations increase, the patterns are similar, but with some notable changes. JAGS gets comparatively slower for the fully conjugate models, but still delivers some of the fastest performance; it is now on par with Stan’s performance for the non-centered specification. For the partly conjugate models, non-centering speeds up Stan seven-fold in this scenario, while JAGS sees no such benefit. Finally, for the unpooled model with uniform priors, Stan is pretty fast, while JAGS is very slow: in fact, JAGS is 65 times slower than Stan here.

In sum, JAGS’s performance varies a lot: it is about 765 times faster with a fully conjugate, centered model in the easy scenario than with an unpooled model in the harder scenario. For Stan, the equivalent factor is about 4.

Another question is how well the samplers are exploring the target distribution. The paper on NUTS by Hoffman and Gelman (2014) shows that NUTS can perform much better than Gibbs sampling and random-walk Metropolis in exploring difficult posteriors with correlated parameters. At the most extreme, the authors use a highly correlated 250-dimensional distribution, whereas the data used here mainly entail pairwise parameter correlations (between pairs of \(\alpha\)’s and \(\beta\)’s, and between their hyperparameters in the partially pooled models). The plots below show 1,000 randomly selected draws of \(\alpha_1\) and \(\beta_1\) from a single trial using unpooled models with uniform priors. Visually, the samples produced by JAGS and Stan are very similar here. The difference is that JAGS takes much longer to converge and to produce effective draws.

Final notes

A key lesson to draw from these results is that model specification matters. For instance, a non-centered specification can be of great help to Stan, and when this is the case, there is no reason not to go this route. For JAGS, it is obviously a good idea to use conjugate priors wherever possible, and thus avoid turning JAGS into Just Another Metropolis Sampler. Another point to note is that centering the predictor \(x\) would have turned the scenario with high posterior correlations into one with low correlations: we could largely have avoided the posterior correlations in the first place.
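
Centering works because it only re-expresses the intercept. Writing \(\alpha_i + \beta_i x_{ij} = (\alpha_i + \beta_i \bar{x}) + \beta_i (x_{ij} - \bar{x})\), a model fit to centered \(x\) estimates the intercept \(\alpha^c_i = \alpha_i + \beta_i \bar{x}\) and the same slope, and the original intercept can be recovered afterwards as \(\alpha_i = \alpha^c_i - \beta_i \bar{x}\). A small base-R check on simulated data for one individual (the values are made up, with \(\alpha_i\) and \(\beta_i\) fixed at their population means, and `lm()` stands in for a full Bayesian fit):

```r
set.seed(1)
x  <- rnorm(5, mean = 2, sd = 1)  # one individual, m = 5, uncentered covariate
y  <- rnorm(5, 2 - 2 * x, 2)      # alpha = 2, beta = -2, sigma = 2
xc <- x - mean(x)                 # centered covariate

fit.raw <- lm(y ~ x)
fit.cen <- lm(y ~ xc)

# Same slope; intercepts related by alpha = alpha.c - beta * mean(x)
coef(fit.raw)
coef(fit.cen)[1] - coef(fit.cen)[2] * mean(x)

# Correlation between intercept and slope estimates: strong vs (numerically) zero
cov2cor(vcov(fit.raw))[1, 2]
cov2cor(vcov(fit.cen))[1, 2]
```

In the hierarchical models, the same shift applies to the population mean of the intercepts, which becomes \(\mu_\alpha + \mu_\beta \bar{x}\) on the centered scale.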

In addition to efficiency, there are other considerations that may be relevant for choosing between JAGS and Stan. Stan has a large development team and a very active user group, which makes it easy to solicit expert advice. Stan also does not rely on a specific type of prior (such as conjugate priors) and can handle a wide range of models. Furthermore, Stan is likely to explore complicated posterior distributions better, and thus to give more accurate results (see the references below). While JAGS can be very fast for certain well-suited problems, Stan may offer the best all-round performance, especially if you are analyzing large datasets using complicated models.

How to cite

This material can be cited as follows:

Bølstad, Jørgen. 2019. “How Efficient is Stan Compared to JAGS? Conjugacy, Pooling, Centering, and Posterior Correlations”. Playing with Numbers: Notes on Bayesian Statistics. www.boelstad.net/post/stan_vs_jags_speed/.

Here is a BibTeX-formatted reference, which should work when using the natbib and hyperref packages together with a suitable LaTeX style:

@Misc{boelstad:2019a,
  Author = {B{\o}lstad, J{\o}rgen},
  Title = {How Efficient is Stan Compared to JAGS? Conjugacy, Pooling, Centering, and Posterior Correlations},
  Howpublished = {Playing with Numbers: Notes on Bayesian Statistics},
  Url = {\url{http://www.boelstad.net/post/stan_vs_jags_speed/}},
  Year = {2019},
  Month = {January 2}}

Further reading

Betancourt, Michael, and Mark Girolami. 2013. “Hamiltonian Monte Carlo for hierarchical models”. arXiv:1312.0906.

Betancourt, Michael. 2017. “A conceptual introduction to Hamiltonian Monte Carlo”. arXiv:1701.02434v2.

Gelman, Andrew, John B. Carlin, Hal S. Stern, David B. Dunson, Aki Vehtari, and Donald B. Rubin. 2014. Bayesian Data Analysis. 3rd ed. London: Chapman & Hall/CRC Press.

Hoffman, Matthew D., and Andrew Gelman. 2014. “The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo”. Journal of Machine Learning Research, 15(1), pp. 1593-1623.

Neal, Radford M. 2011. “MCMC Using Hamiltonian Dynamics”. In Handbook of Markov Chain Monte Carlo, edited by Steve Brooks, Andrew Gelman, Galin L. Jones, and Xiao-Li Meng. London: Chapman & Hall/CRC Press.

Papaspiliopoulos, Omiros, Gareth O. Roberts, and Martin Sköld. 2007. “A general framework for the parametrization of hierarchical models”. Statistical Science, pp. 59-73.

