Use jaxspec with external samplers
In this tutorial, we demonstrate how to use the vectorized and compilable log-likelihood and posterior log-probability implemented in jaxspec with external samplers. We will use the zeus sampler. Let's first import all the necessary packages and set up the basic configuration for JAX.
%%capture
# Hide the output of this cell
import numpyro
numpyro.enable_x64()
numpyro.set_host_device_count(4)
numpyro.set_platform("cpu")
import arviz as az
import numpyro.distributions as dist
import matplotlib.pyplot as plt
import zeus
import jax
import numpy as np
from jaxspec.data.util import load_example_obsconf
from jaxspec.fit import BayesianModel, MCMCFitter
from jaxspec.model.additive import Powerlaw, Blackbodyrad
from jaxspec.model.multiplicative import Tbabs
Now, instead of the usual MCMCFitter, we build a BayesianModel, which is the parent class of every fitter implemented in jaxspec.
# Absorbed power law + blackbody spectral model
spectral_model = Tbabs() * (Powerlaw() + Blackbodyrad())

# Prior distribution for each free parameter of the model
prior = {
    "powerlaw_1_alpha": dist.Uniform(0, 5),
    "powerlaw_1_norm": dist.LogUniform(1e-5, 1e-2),
    "blackbodyrad_1_kT": dist.Uniform(0, 5),
    "blackbodyrad_1_norm": dist.LogUniform(1e-2, 1e2),
    "tbabs_1_nh": dist.Uniform(0, 1)
}

# Example observations bundled with jaxspec
ulx_observations = load_example_obsconf("NGC7793_ULX4_ALL")
bayesian_model = BayesianModel(spectral_model, prior, ulx_observations)
The BayesianModel class exposes methods to compute the log-likelihood and the posterior log-probability for any set of parameters. These are pure JAX functions, so they can be freely composed with the jit and vmap operators. In the following cell, we also use the array_to_dict method to build a compiled and vectorized posterior log-probability function that accepts numpy arrays as input. We can then feed it to the zeus sampler and launch an MCMC run. To check the parameter order expected by this function, use the following line:
bayesian_model.parameter_names
['powerlaw_1_alpha', 'powerlaw_1_norm', 'blackbodyrad_1_kT', 'blackbodyrad_1_norm', 'tbabs_1_nh']
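If you simply want to evaluate the posterior at a single point before launching a sampler, a minimal sketch could look like the following, assuming (as in the sampling cell below) that array_to_dict maps a 1D array ordered as parameter_names to the parameter dictionary expected by log_posterior_prob; the values are purely illustrative.
# Sketch: evaluate the posterior log-probability at one parameter vector
single_point = np.array([2.0, 3e-4, 0.7, 0.2, 0.1])
log_prob = bayesian_model.log_posterior_prob(bayesian_model.array_to_dict(single_point))
print(float(log_prob))
We can now set up the zeus ensemble sampler.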
ndim, nwalkers = len(bayesian_model.parameter_names), 40

# Start the walkers in a small ball around plausible parameter values
p0 = np.array([[2, 3e-4, 0.7, 0.2, 0.1]]) * np.random.normal(loc=1, scale=0.1, size=(nwalkers, ndim))

# Compiled and vectorized posterior log-probability, mapping a (nwalkers, ndim) array
# to the log-probability of each walker
@jax.jit
@jax.vmap
def zeus_log_prob(parameters: np.ndarray) -> jax.typing.ArrayLike:
    return bayesian_model.log_posterior_prob(bayesian_model.array_to_dict(parameters))

sampler = zeus.EnsembleSampler(nwalkers, ndim, zeus_log_prob, vectorize=True)
sampler.run_mcmc(p0, 10_000, progress=True);
Initialising ensemble of 40 walkers... Sampling progress : 100%|██████████| 10000/10000 [04:30<00:00, 36.96it/s]
Note how we initialized the chains close to the expected values. Now let's check the convergence of the chains using arviz.
# zeus exposes an emcee-like interface, so az.from_emcee can read its samples;
# the first 1000 draws are discarded as burn-in
inference_data_zeus = az.from_emcee(sampler, var_names=bayesian_model.parameter_names).sel(draw=slice(1_000, None))

with az.style.context("arviz-darkgrid", after_reset=True):
    az.plot_trace(inference_data_zeus, compact=False, kind="rank_vlines")
    plt.show()
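As a complementary check, arviz can also report numerical diagnostics such as the rank-normalized R-hat and the effective sample size; a minimal sketch:
# Sketch: numerical convergence diagnostics for the zeus chains
print(az.summary(inference_data_zeus, var_names=bayesian_model.parameter_names))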
Everything seems fine, but never forget that assessing convergence is not as simple as looking at a trace plot; check the good-practices tutorial on this topic. Let's now see whether these results agree with what we would obtain using the No-U-Turn Sampler.
result = MCMCFitter(spectral_model, prior, ulx_observations).fit()
from chainconsumer import ChainConsumer, Chain, PlotConfig

# zeus mimics the emcee interface, so Chain.from_emcee can read its samples
chain_zeus = Chain.from_emcee(sampler, bayesian_model.parameter_names, "zeus")
chain_nuts = result.to_chain("nuts")

cc = ChainConsumer()
cc.set_plot_config(PlotConfig())
cc.add_chain(chain_zeus)
cc.add_chain(chain_nuts)
cc.plotter.plot_summary(errorbar=True, figsize=0.8)

# Rotate the x-tick labels for readability
for ax in plt.gcf().get_axes():
    ax.tick_params(axis='x', labelrotation=45)

plt.show()
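Beyond the summary plot, ChainConsumer can also draw the full corner plot of the two posteriors; a minimal sketch:
# Sketch: full corner plot comparing the zeus and NUTS posteriors
fig = cc.plotter.plot()
plt.show()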
We obtained equivalent results with zeus using relatively few steps and a proper initialisation. The posterior log-probability function we built accepts both numpy and JAX arrays, and should be easy to plug into your favorite sampling framework.
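For instance, since the emcee constructor mirrors the zeus call above, the very same function could be reused without modification; a minimal sketch, assuming emcee is installed:
import emcee

# Sketch: reuse the same vectorized log-probability with the emcee sampler
emcee_sampler = emcee.EnsembleSampler(nwalkers, ndim, zeus_log_prob, vectorize=True)
emcee_sampler.run_mcmc(p0, 1_000, progress=True)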