Skip to content

Integration

This module provides a consistent way to integrate functions in jaxspec. It mainly relies on tanh-sinh (also called double-exponential) quadrature to perform the integration.

integrate_interval(integrand, n=51)

Build a function which can compute the integral of the provided integrand over the interval \([a, b]\) using the tanh-sinh quadrature. Returns a function \(F(a, b, \pmb{\theta})\) which takes the limits of the interval and the parameters of \(f(x,\pmb{\theta})\) as inputs.

\[ F(a, b, \pmb{\theta}) = \int_a^b f(x,\pmb{\theta}) \text{d}x \]

Example usage

pi = 4*integrate_interval(lambda x: 1/(1+x**2))(0, 1)
print(pi) # 3.1415927

Example where the limits of the integral are parameters

def erf(x):

    def integrand(t):
        return 2/jnp.sqrt(jnp.pi) * jnp.exp(-t**2)

    return integrate_interval(integrand)(0, x)

print(erf(1)) # 0.84270084

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `integrand` | `Callable` | The function to integrate | *required* |
| `n` | `int` | The number of points to use for the quadrature | `51` |

Returns:

| Type | Description |
|------|-------------|
| `Callable` | The integral of the provided integrand over the interval \([a, b]\) as a callable |

Source code in src/jaxspec/util/integrate.py
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
def integrate_interval(integrand: Callable, n: int = 51) -> Callable:
    r"""
    Build a function which can compute the integral of the provided integrand over the interval $[a, b]$ using
    the tanh-sinh quadrature. Returns a function $F(a, b, \pmb{\theta})$ which takes the limits of the interval and
    the parameters of $f(x,\pmb{\theta})$ as inputs.

    $$
    F(a, b, \pmb{\theta}) = \int_a^b f(x,\pmb{\theta}) \text{d}x
    $$

    The returned function carries a custom JVP built from the Leibniz integral rule,

    $$
    \text{d}F = f(b,\pmb{\theta})\,\text{d}b - f(a,\pmb{\theta})\,\text{d}a
    + \left(\int_a^b \nabla_{\pmb{\theta}} f(x,\pmb{\theta}) \text{d}x\right) \cdot \text{d}\pmb{\theta},
    $$

    so it can be differentiated with respect to the limits, the parameters, or both. Non-finite values of the
    integrand (at the quadrature nodes or at the limits) are replaced through `jnp.nan_to_num`.

    # Example usage

    ``` python
    pi = 4*integrate_interval(lambda x: 1/(1+x**2))(0, 1)
    print(pi) # 3.1415927
    ```

    # Example where the limits of the integral are parameters
    ``` python
    def erf(x):

        def integrand(t):
            return 2/jnp.sqrt(jnp.pi) * jnp.exp(-t**2)

        return integrate_interval(integrand)(0, x)

    print(erf(1)) # 0.84270084
    ```

    Parameters:
        integrand: The function to integrate, called as `integrand(x, *theta)`. `x` is the array of quadrature
            abscissae when evaluating the integral, and the scalar limit `a` or `b` when evaluating the boundary
            terms of the derivative. The derivative rule stacks one Jacobian per parameter, so each parameter is
            expected to be a scalar.
        n: The number of points to use for the quadrature

    Returns:
        The integral of the provided integrand over the interval $[a, b]$ as a callable
    """

    @jax.custom_jvp
    def f(a, b, *args):
        t, x, dx = interval_weights(a, b, n)

        return trapezoid(jnp.nan_to_num(integrand(x, *args) * dx), x=t)

    def boundary_term(limit, limit_dot, args):
        """Return f(limit, *args) * d(limit), the contribution of one moving limit to the JVP."""
        # An integer-valued limit gets a float0 tangent, on which JAX rejects arithmetic; it cannot move,
        # so it contributes nothing.
        if limit_dot.dtype == jax.dtypes.float0:
            return 0.0
        # nan_to_num touches the primal value only, so the term stays linear in limit_dot (needed for reverse
        # mode). Without it, a fixed limit (tangent instantiated as 0) where the integrand is non-finite, e.g.
        # sin(x)/x at x=0, would produce inf * 0 or nan * 0 = nan.
        return jnp.nan_to_num(integrand(limit, *args)) * limit_dot

    @f.defjvp
    def f_jvp(primals, tangents):
        a, b, *args = primals
        a_dot, b_dot, *args_dot = tangents

        primal_out = f(a, b, *args)

        # Leibniz rule: moving the upper limit adds f(b), moving the lower limit removes f(a)
        tangent_out = jnp.zeros_like(primal_out) + boundary_term(b, b_dot, args) - boundary_term(a, a_dot, args)

        # Partial derivatives along the integrand parameters. Skipped when there are none (e.g. erf(x) above):
        # jax.jacfwd cannot be applied to an empty parameter list.
        if args:
            t, x, dx = interval_weights(a, b, n)
            # For scalar parameters: per-parameter derivative of the integrand at every node, stacked row-wise
            jac = trapezoid(
                jnp.nan_to_num(jnp.asarray(jax.jacfwd(lambda params: integrand(x, *params))(args)) * dx),
                x=t,
                axis=-1,
            )
            tangent_out = tangent_out + jac @ jnp.asarray(args_dot)

        return primal_out, tangent_out

    return f

integrate_positive(integrand, n=51)

Build a function which can compute the integral of the provided integrand over the positive real line using the tanh-sinh quadrature. Returns a function \(F(\pmb{\theta})\) which takes the parameters of the integrand \(f(x,\pmb{\theta})\) as inputs.

\[ F(\pmb{\theta}) = \int_0^\infty f(x,\pmb{\theta}) \text{d}x \]

Example usage

gamma = integrate_positive(lambda t, z: t**(z-1) * jnp.exp(-t))
print(gamma(1/2)) # 1.7716383

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `integrand` | `Callable` | The function to integrate | *required* |
| `n` | `int` | The number of points to use for the quadrature | `51` |

Returns:

| Type | Description |
|------|-------------|
| `Callable` | The integral of the provided integrand over the positive real line as a callable |

Source code in src/jaxspec/util/integrate.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def integrate_positive(integrand: Callable, n: int = 51) -> Callable:
    r"""
    Create a differentiable integral over the positive real line, evaluated with the tanh-sinh quadrature.

    The returned callable $F(\pmb{\theta})$ takes the parameters of the integrand $f(x,\pmb{\theta})$ and returns

    $$
    F(\pmb{\theta}) = \int_0^\infty f(x,\pmb{\theta}) \text{d}x
    $$

    Its derivative comes from a custom JVP that integrates $\nabla_{\pmb{\theta}} f$ with the same quadrature
    (differentiation under the integral sign).

    # Example usage

    ``` python
    gamma = integrate_positive(lambda t, z: t**(z-1) * jnp.exp(-t))
    print(gamma(1/2)) # 1.7716383
    ```

    Parameters:
        integrand: The function to integrate, called as `integrand(x, *theta)` with `x` the array of quadrature
            abscissae. The derivative rule stacks one Jacobian per parameter, so each parameter is expected to
            be a scalar.
        n: The number of points to use for the quadrature

    Returns:
        The integral of the provided integrand over the positive real line as a callable
    """

    def integral(*params):
        nodes, abscissae, jacobian = positive_weights(n)
        # Change of variables x -> t: integrate f(x(t)) * dx/dt over the reference nodes
        weighted = jnp.nan_to_num(integrand(abscissae, *params) * jacobian)
        return trapezoid(weighted, x=nodes)

    integral = jax.custom_jvp(integral)

    def integral_jvp(primals, tangents):
        nodes, abscissae, jacobian = positive_weights(n)

        value = integral(*primals)

        # Derivative of the integrand along each parameter at every node, one row per parameter
        d_integrand = jnp.asarray(jax.jacfwd(lambda params: integrand(abscissae, *params))(primals))
        gradient = trapezoid(jnp.nan_to_num(d_integrand * jacobian), x=nodes, axis=-1)

        return value, gradient @ jnp.asarray(tangents)

    integral.defjvp(integral_jvp)

    return integral

interval_weights(a, b, n)

Return the tanh-sinh quadrature nodes \(t\), abscissae \(x(t)\) and Jacobian weights \(\text{d}x/\text{d}t\) over the interval \([a, b]\).

Source code in src/jaxspec/util/integrate.py
19
20
21
22
23
24
25
26
27
28
29
def interval_weights(a: float, b: float, n: int) -> tuple[Array, Array, Array]:
    """
    Return the tanh-sinh quadrature grid mapped onto the interval [a, b].

    The reference nodes are `n` evenly spaced points `t` in [-3, 3]. They are sent into [-1, 1] by
    phi(t) = tanh(pi/2 * sinh(t)), then affinely onto [a, b].

    Returns:
        A tuple `(t, x, dx)`: the reference nodes, the abscissae x(t), and the Jacobian dx/dt at each node.
        Integrating `f(x) * dx` over `t` approximates the integral of f over [a, b].
    """
    nodes = jnp.linspace(-3, 3, n)
    half_pi = jnp.pi / 2
    stretched = half_pi * jnp.sinh(nodes)

    phi = jnp.tanh(stretched)
    # d/dt tanh(pi/2 sinh t) = pi/2 cosh t * sech^2(pi/2 sinh t)
    dphi = half_pi * jnp.cosh(nodes) * (1 / jnp.cosh(stretched) ** 2)

    half_width = (b - a) / 2
    midpoint = (b + a) / 2
    return nodes, half_width * phi + midpoint, half_width * dphi

positive_weights(n)

Return the tanh-sinh quadrature nodes \(t\), abscissae \(x(t)\) and Jacobian weights \(\text{d}x/\text{d}t\) over the positive real line.

Source code in src/jaxspec/util/integrate.py
32
33
34
35
36
37
38
39
40
def positive_weights(n: int) -> tuple[Array, Array, Array]:
    """
    Return the tanh-sinh style quadrature grid mapped onto the positive real line.

    The reference nodes are `n` evenly spaced points `t` in [-3, 3]. They are mapped to (0, inf) by
    x(t) = exp(pi/2 * sinh(t)).

    Returns:
        A tuple `(t, x, dx)`: the reference nodes, the abscissae x(t), and the Jacobian dx/dt at each node.
    """
    nodes = jnp.linspace(-3, 3, n)
    half_pi = jnp.pi / 2

    abscissae = jnp.exp(half_pi * jnp.sinh(nodes))
    # d/dt exp(pi/2 sinh t) = pi/2 cosh t * exp(pi/2 sinh t)
    jacobian = half_pi * jnp.cosh(nodes) * abscissae

    return nodes, abscissae, jacobian