Cookbook: how do I ...

Fit observations with MCMC

This is the example used in the jaxspec paper. It fits an absorbed power law plus blackbody model to the bundled example observations.

import numpyro

numpyro.enable_x64()
numpyro.set_host_device_count(4)
numpyro.set_platform("cpu")

import numpyro.distributions as dist
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jaxspec.data.util import load_example_obsconf
from jaxspec.fit import MCMCFitter
from jaxspec.model.additive import Powerlaw, Blackbodyrad
from jaxspec.model.multiplicative import Tbabs

spectral_model = Tbabs() * (Powerlaw() + Blackbodyrad())

prior = {
    'powerlaw_1': {
        'alpha': dist.Uniform(0. * jnp.ones((3,)), 5 * jnp.ones((3,))),
        'norm': dist.LogUniform(1e-6, 1e-3)
    },
    'blackbodyrad_1': {
        'kT': dist.Uniform(0.3, 3),
        'norm': dist.LogUniform(1e-2, 1e3)
    },
    'tbabs_1': {
        'N_H': 0.2
    }
}

ulx_observations = load_example_obsconf()
fitter = MCMCFitter(spectral_model, prior, ulx_observations)
result = fitter.fit(num_samples=1_000)

Evaluate the true model

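To evaluate a model for a given set of parameters, use the SpectralModel.photon_flux and SpectralModel.energy_flux methods. Both take the parameter dictionary followed by the lower and upper edges of the energy bins.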

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxspec.model.additive import Blackbodyrad
from jaxspec.model.multiplicative import Tbabs

spectral_model = Tbabs() * Blackbodyrad()

energies = jnp.geomspace(1, 50, 100)

params = {
    'blackbodyrad_1': {
        'kT': 1.,
        'norm': 1.
    },
    'tbabs_1': {
        'N_H': 1.
    }
}

photon_flux = spectral_model.photon_flux(params, energies[:-1], energies[1:], n_points=30)
energy_flux = spectral_model.energy_flux(params, energies[:-1], energies[1:], n_points=30)
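The two slices passed above, energies[:-1] and energies[1:], turn one array of bin edges into matching arrays of lower and upper edges. The following is a minimal sketch of that slicing, written with NumPy so it runs without JAX (jax.numpy offers the same geomspace call). The names edges, widths and centers are illustrative and are not part of jaxspec.

```python
import numpy as np

# Same grid as in the example above: 100 log-spaced edges define 99 bins.
edges = np.geomspace(1, 50, 100)

# Lower and upper edge of each bin, as passed to photon_flux / energy_flux.
e_low, e_high = edges[:-1], edges[1:]

# Quantities that are often handy when plotting per-bin values
# (illustrative names, not jaxspec API).
widths = e_high - e_low            # bin widths, in the units of `edges`
centers = np.sqrt(e_low * e_high)  # geometric centres, natural on a log grid

print(e_low.shape, e_high.shape, centers.shape)
```

With this layout, each bin's upper edge is the next bin's lower edge, so one would expect one flux value per bin, 99 in total for this grid.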

Compute model photon flux, energy flux and luminosity

To compute these quantities from a fit, use the FitResult.photon_flux, FitResult.energy_flux, and FitResult.luminosity methods.
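For orientation, the usual relation between an energy flux and an isotropic luminosity is L = 4π d² F. The sketch below applies it by hand. It is not how FitResult.luminosity is implemented: the function name, the units (erg/cm²/s and Mpc) and the input values are assumptions chosen for illustration, and cosmological corrections are ignored.

```python
import math

MPC_IN_CM = 3.0856775814913673e24  # one megaparsec in centimetres


def isotropic_luminosity(energy_flux, distance_mpc):
    """Return L = 4 * pi * d**2 * F in erg/s.

    Hypothetical helper: `energy_flux` is assumed to be in erg/cm^2/s and
    `distance_mpc` in Mpc. Redshift and K-corrections are not handled.
    """
    distance_cm = distance_mpc * MPC_IN_CM
    return 4.0 * math.pi * distance_cm**2 * energy_flux


# Made-up inputs, for illustration only.
luminosity = isotropic_luminosity(energy_flux=1e-12, distance_mpc=10.0)
print(f"{luminosity:.3e} erg/s")
```

Because the luminosity scales with the square of the distance, the distance uncertainty usually dominates the error on values obtained this way.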

Save and load inference results

You can use the dill package to serialise and deserialise the result of a fit. First, install it with pip:

pip install dill

Then save and load the results as follows:

import dill

# Save the results
with open(r"result.pickle", "wb") as output_file:
    dill.dump(result, output_file)

# Load the results
with open(r"result.pickle", "rb") as input_file:
    result_pickled = dill.load(input_file)
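The save and load steps can be wrapped in two small helpers and checked with a round trip. This is a sketch under stated assumptions: save_result and load_result are hypothetical names, the dictionary stands in for a real fit result with made-up values, and the fallback to the standard-library pickle module is there only so the snippet runs where dill is not installed.

```python
import pickle
import tempfile
from pathlib import Path

try:
    import dill as serializer  # preferred: supports more object types than pickle
except ImportError:
    serializer = pickle        # assumption: stdlib fallback, for this demo only


def save_result(obj, path):
    """Serialise `obj` to `path` (hypothetical helper)."""
    with open(path, "wb") as output_file:
        serializer.dump(obj, output_file)


def load_result(path):
    """Load an object previously written by `save_result`."""
    with open(path, "rb") as input_file:
        return serializer.load(input_file)


# Stand-in for a fit result; the values are made up for illustration.
dummy_result = {"alpha": [1.7, 1.8], "kT": 0.9}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "result.pickle"
    save_result(dummy_result, path)
    restored = load_result(path)

print(restored == dummy_result)
```

Files written this way can execute arbitrary code when loaded, so only load them from sources you trust.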