Generate mock data

Basic usage

This tutorial illustrates how to generate mock observed spectra using a fakeit-like interface, similar to the one provided by XSPEC.

import numpyro

numpyro.enable_x64()
numpyro.set_platform("cpu")

Let's build the model we want to fake, and load an observation that provides the instrumental setup to apply.

from jaxspec.model.additive import Powerlaw, Blackbodyrad
from jaxspec.model.multiplicative import Tbabs
from jaxspec.data import ObsConfiguration

obsconf = ObsConfiguration.from_pha_file('obs_1.pha')
model = Tbabs() * (Powerlaw() + Blackbodyrad())

Let's do fakeit for a bunch of parameters

from numpy.random import default_rng

rng = default_rng(42)

num_params = 10000

parameters = {
    "tbabs_1_nh": rng.uniform(0.1, 0.4, size=num_params),
    "powerlaw_1_alpha": rng.uniform(1, 3, size=num_params),
    "powerlaw_1_norm": rng.exponential(10 ** (-0.5), size=num_params),
    "blackbodyrad_1_kT": rng.uniform(0.1, 3.0, size=num_params),
    "blackbodyrad_1_norm": rng.exponential(10 ** (-3), size=num_params)
}
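
Each entry in this dictionary is an array of length num_params, one value per mock spectrum to generate. As a quick sanity check (plain Python, nothing jaxspec-specific):

assert all(len(value) == num_params for value in parameters.values())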

And now we can fakeit!

from jaxspec.data.util import fakeit_for_multiple_parameters

spectra = fakeit_for_multiple_parameters(obsconf, model, parameters)
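
The result stacks one mock spectrum per parameter set; judging from the indexing used in the plotting snippet below, its first axis has length num_params and its second axis runs over the output channels of the observation.

print(spectra.shape)  # expected: (num_params, number of output channels)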

Let's plot some of the resulting spectra

import matplotlib.pyplot as plt

plt.figure(figsize=(5,4))

for i in range(10):
    plt.step(
        obsconf.out_energies[0],
        spectra[i, :],
        where="post"
    )

plt.xlabel("Energy [keV]")
plt.ylabel("Counts")
plt.loglog()

Figure: a sample of the resulting mock spectra.

Using only the instrument

If you don't have an observation to use as a reference, you can still build a mock ObsConfiguration from the instrument you want to use.

from jaxspec.data import ObsConfiguration, Instrument

instrument = Instrument.from_ogip_file(
    "instrument.rmf",
    arf_path="instrument.arf"
)

obsconf = ObsConfiguration.mock_from_instrument(
    instrument,
    exposure=1e5,
)
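
To check that the mock setup looks sensible, you can inspect its output energy grid, which is exposed through out_energies as in the plotting snippet above:

energies = obsconf.out_energies[0]
print(f"Energy grid: {energies.min():.2f} - {energies.max():.2f} keV, {len(energies)} bins")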

Then you can use this ObsConfiguration within fakeit_for_multiple_parameters as before.

spectra = fakeit_for_multiple_parameters(obsconf, model, parameters)

Computing in parallel

Thanks to the amazing PositionalSharding interface, it is fairly easy to run this computation in parallel by sharding the input parameters across devices. To do so, you must first declare multiple devices using one of the following snippets.

import numpyro

n_devices = 8

numpyro.set_platform("cpu")
numpyro.set_host_device_count(n_devices)
numpyro.enable_x64()

or

import os
import jax

n_devices = 8

os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={n_devices}"
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

This must be run before any JAX code is executed in the process; otherwise, the extra devices won't be accessible. To double-check, you can verify that the number of available devices is consistent with n_devices:

assert len(jax.local_devices()) == n_devices

Once it is certain that all the devices are visible, the array can be split using a PositionalSharding and distributed to all the devices.

import jax
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

# Create a Sharding object to distribute a value across devices:
sharding = PositionalSharding(mesh_utils.create_device_mesh((n_devices,)))

# Split the parameters on every device
sharded_parameters = jax.device_put(parameters, sharding)
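
To check how one of the parameter arrays was actually distributed, jax.debug.visualize_array_sharding prints a map of which device holds which slice (shown here on a single entry of the dictionary):

# Display the device placement of one sharded parameter array
jax.debug.visualize_array_sharding(sharded_parameters["tbabs_1_nh"])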

Then we can use these sharded parameters to compute the fakeits in parallel

spectra = fakeit_for_multiple_parameters(obsconf, model, sharded_parameters, apply_stat=False)
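
To see whether the sharding actually pays off on your machine, here is a minimal timing sketch; it assumes nothing beyond the call above, and jax.block_until_ready only forces the asynchronous computation to finish before the clock stops.

import time

start = time.perf_counter()
result = fakeit_for_multiple_parameters(obsconf, model, sharded_parameters, apply_stat=False)
jax.block_until_ready(result)  # wait for the dispatched computation to complete
print(f"Elapsed: {time.perf_counter() - start:.2f} s")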

Info

Since JAX is already pretty good at getting the best out of your physical CPU, the gain from enforcing parallel execution might not be dramatic. However, doing so on several GPUs or TPUs will greatly reduce the execution time.