Good practices for analysing MCMC results using arviz
Assessing the convergence of a set of MCMC chains is not an easy task in general. jaxspec provides a convenient way to analyse the results of a fit using the arviz library. This library provides powerful tools to explore Bayesian inference results, such as trace plots, pair plots, and summary statistics. Let's run some MCMCs!
%%capture
# Hide the output of this cell
import numpyro
numpyro.enable_x64()
numpyro.set_platform("cpu")
numpyro.set_host_device_count(4)
import numpyro.distributions as dist
from jaxspec.fit import MCMCFitter
from jaxspec.data.util import load_example_obsconf
from jaxspec.model.additive import Blackbodyrad, Powerlaw
from jaxspec.model.multiplicative import Tbabs
spectral_model = Tbabs()*(Powerlaw() + Blackbodyrad())
obsconf = load_example_obsconf("NGC7793_ULX4_PN")
prior = {
    "powerlaw_1_alpha": dist.Uniform(1, 3),
    "powerlaw_1_norm": dist.LogUniform(1e-5, 1e-3),
    "blackbodyrad_1_kT": dist.Uniform(0, 2),
    "blackbodyrad_1_norm": dist.LogUniform(1e-2, 1),
    "tbabs_1_nh": dist.Uniform(0, 1)
}
fitter = MCMCFitter(spectral_model, prior, obsconf)
result_nuts = fitter.fit(
    num_chains=4,
    num_warmup=1000,
    num_samples=1000,
    sampler="nuts",
    mcmc_kwargs={"progress_bar": True}
)
From the result object, you can access the inference_data attribute, which is an arviz.InferenceData object. This lets you use any arviz function to analyse the results of the fit.
inference_data = result_nuts.inference_data
inference_data
arviz.InferenceData with the following groups:
    > posterior             (chain: 4, draw: 1000): the five spectral parameters
    > posterior_predictive  (chain: 4, draw: 1000, obs_data_dim_0: 102)
    > log_likelihood        (chain: 4, draw: 1000, obs_data_dim_0: 102)
    > prior                 (chain: 1, draw: 4000)
    > observed_data         (obs_data_dim_0: 102)
This object carries all the information we need about our fit: the posterior samples, the associated likelihood, the posterior predictive draws, and so on. These groups can be a bit tricky to manipulate at first, as they are xarray Datasets, but they provide a lot of flexibility to analyse the results of the fit.
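For instance, here is a minimal sketch of how you might manipulate the posterior group directly with the xarray API (using one of the parameter names from the model defined above):
posterior = inference_data.posterior

# Each parameter is an xarray DataArray indexed by (chain, draw)
alpha = posterior["powerlaw_1_alpha"]

# Mean and standard deviation over all chains and draws
print(alpha.mean(dim=("chain", "draw")).values, alpha.std(dim=("chain", "draw")).values)

# Stack chains and draws into a single dimension to get a flat array of samples
flat_alpha = alpha.stack(sample=("chain", "draw")).values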
Trace plot
This visualization is useful to see the evolution of the parameters during the sampling process and can be used to diagnose convergence issues. The ideal situation is when the chains are well mixed and randomly scattered around the target distribution. If instead chains are stuck in some region of the parameter space, or show trends, this may indicate that the sampler did not explore the full parameter space.
import arviz as az
import matplotlib.pyplot as plt
with az.style.context("arviz-darkgrid", after_reset=True):
    az.plot_trace(inference_data, compact=False)
    plt.show()
However, this kind of plot can become messy pretty quickly if you work with numerous chains or walkers, which is what you must do when using ensemble samplers such as ESS or AIES. A general trick proposed by Vehtari et al. (2019) is to plot the rank of each sample within the global run instead of its value. These ranks should be evenly spread across the steps if the chains are well mixed. This can be done using the kind="rank_vlines" argument of the plot_trace function.
with az.style.context("arviz-darkgrid", after_reset=True):
    az.plot_trace(result_nuts.inference_data, compact=False, kind="rank_vlines")
    plt.show()
A more quantitative way to assess the convergence of the chains is to use the summary function. This function provides a summary of the posterior distribution of the parameters, including the mean, the standard deviation, and the highest density interval (94% by default).
az.summary(result_nuts.inference_data.posterior)
|                     | mean  | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---------------------|-------|-------|--------|---------|-----------|---------|----------|----------|-------|
| blackbodyrad_1_kT   | 0.743 | 0.035 | 0.679  | 0.808   | 0.001     | 0.001   | 631.0    | 1056.0   | 1.01  |
| blackbodyrad_1_norm | 0.200 | 0.037 | 0.134  | 0.269   | 0.002     | 0.001   | 596.0    | 1040.0   | 1.01  |
| powerlaw_1_alpha    | 2.029 | 0.100 | 1.836  | 2.205   | 0.004     | 0.003   | 630.0    | 1050.0   | 1.01  |
| powerlaw_1_norm     | 0.000 | 0.000 | 0.000  | 0.000   | 0.000     | 0.000   | 604.0    | 967.0    | 1.01  |
| tbabs_1_nh          | 0.095 | 0.028 | 0.043  | 0.146   | 0.001     | 0.001   | 592.0    | 796.0    | 1.01  |
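Note that the width of the reported highest density interval can be changed through the hdi_prob argument of summary, for example to get a 95% interval:
az.summary(result_nuts.inference_data.posterior, hdi_prob=0.95)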
The r_hat column provides the rank-normalised split Gelman-Rubin statistic. The closer this value is to 1, the better; a value larger than 1.01 points to convergence issues. This statistic can be computed directly using the r_hat function, see Vehtari et al. (2019). The ess columns denote the Effective Sample Size of the chains, which is a measure of the quality of the samples. The larger the better; in general, we want this value to be larger than 400 for a reliable estimate of the posterior distribution. Here, the NUTS sampler was run with 1000 warmup and 1000 sampling steps per chain.
rhat = az.rhat(result_nuts.inference_data.posterior)
rhat
<xarray.Dataset>
Dimensions:  ()
Data variables:
    blackbodyrad_1_kT    float64 1.009
    blackbodyrad_1_norm  float64 1.007
    powerlaw_1_alpha     float64 1.008
    powerlaw_1_norm      float64 1.008
    tbabs_1_nh           float64 1.008
ess = az.ess(result_nuts.inference_data.posterior)
ess
<xarray.Dataset>
Dimensions:  ()
Data variables:
    blackbodyrad_1_kT    float64 631.1
    blackbodyrad_1_norm  float64 595.9
    powerlaw_1_alpha     float64 630.1
    powerlaw_1_norm      float64 603.5
    tbabs_1_nh           float64 591.8
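As a quick sanity check, these diagnostics can be compared to the thresholds mentioned above programmatically. This is a minimal sketch, assuming the rhat and ess variables computed above:
# Flag parameters whose diagnostics fall outside the recommended thresholds
bad_rhat = {name: float(value) for name, value in rhat.items() if value > 1.01}
bad_ess = {name: float(value) for name, value in ess.items() if value < 400}

if bad_rhat or bad_ess:
    print("Potential convergence issues:", bad_rhat, bad_ess)
else:
    print("All r_hat <= 1.01 and ESS >= 400")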
Pair plot
This visualization is useful to see the correlation between the parameters. The ideal situation is when the parameters show little correlation and the posterior distribution looks close to a multivariate Gaussian distribution.
with az.style.context("arviz-darkgrid", after_reset=True):
    az.plot_pair(result_nuts.inference_data)
    plt.show()
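plot_pair accepts several options; for instance, kind="kde" and marginals=True overlay density contours and the one-dimensional marginals (a sketch using standard arviz options, not anything jaxspec-specific):
with az.style.context("arviz-darkgrid", after_reset=True):
    # Same pair plot, but with KDE contours and marginal distributions on the diagonal
    az.plot_pair(result_nuts.inference_data, kind="kde", marginals=True)
    plt.show()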
Take a look at arviz's documentation to see what else you can do with this library.