Beispiel #1
0
def test_build_tree(step_size):
    """
    Check invariants of the tree returned by ``build_tree`` for a 1-D system
    with quadratic potential ``0.5 * q**2`` and kinetic energy
    ``0.5 * sum(m_inv * p**2)``.

    :param step_size: integrator step size handed to ``build_tree``
        (presumably supplied by a pytest parametrization outside this view).
    """
    def potential_fn(q):
        return 0.5 * q**2

    def kinetic_fn(m_inv, p):
        return 0.5 * np.sum(m_inv * p**2)

    inv_mass = np.array([1.])
    key = random.PRNGKey(0)
    vv_init, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    initial_state = vv_init(0.0, 1.0)

    # Everything except the integrator state is closed over, so only the
    # state is traced as an argument of the jitted function.
    run_build_tree = jit(lambda state: build_tree(
        vv_update, kinetic_fn, state, inv_mass, step_size, key))
    tree = run_build_tree(initial_state)

    # lower bound on the number of proposals for a tree of this depth
    assert tree.num_proposals >= 2**(tree.depth - 1)

    # every per-proposal acceptance probability is at most 1
    assert tree.sum_accept_probs <= tree.num_proposals

    # a tree that stopped before depth 10 must have stopped for a reason
    if tree.depth < 10:
        assert tree.turning | tree.diverging

    # a very large step size is expected to diverge after a single step
    if step_size > 10:
        assert tree.diverging
        assert tree.num_proposals == 1

    # a small step size should need many proposals before terminating
    if step_size < 0.1:
        assert tree.num_proposals > 10
Beispiel #2
0
 def get_final_state(model, step_size, num_steps, q_i, p_i):
     """
     Integrate ``model`` for ``num_steps`` velocity-verlet updates starting at
     position ``q_i`` and momentum ``p_i``.

     :return: tuple ``(q_f, p_f)`` holding the first two fields of the final
         integrator state.
     """
     vv_init, vv_update = velocity_verlet(model.potential_fn, model.kinetic_fn)

     def _advance(_, state):
         # NOTE: `args` is not a parameter here; it is looked up in the
         # enclosing scope when the loop body runs.
         return vv_update(step_size, args.m_inv, state)

     final_state = fori_loop(0, num_steps, _advance, vv_init(q_i, p_i))
     q_f, p_f, _, _ = final_state
     return (q_f, p_f)
Beispiel #3
0
    def _hmc_next(step_size, inverse_mass_matrix, vv_state, model_args,
                  model_kwargs, rng_key):
        if potential_fn_gen:
            nonlocal vv_update, forward_mode_ad
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
            _, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)

        num_steps = _get_num_steps(step_size, trajectory_len)
        vv_state_new = fori_loop(
            0, num_steps,
            lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
            vv_state)
        energy_old = vv_state.potential_energy + kinetic_fn(
            inverse_mass_matrix, vv_state.r)
        energy_new = vv_state_new.potential_energy + kinetic_fn(
            inverse_mass_matrix, vv_state_new.r)
        delta_energy = energy_new - energy_old
        delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf,
                                 delta_energy)
        accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
        diverging = delta_energy > max_delta_energy
        transition = random.bernoulli(rng_key, accept_prob)
        vv_state, energy = cond(transition, (vv_state_new, energy_new),
                                identity, (vv_state, energy_old), identity)
        return vv_state, energy, num_steps, accept_prob, diverging
Beispiel #4
0
    def _hmc_next(step_size, inverse_mass_matrix, vv_state, model_args,
                  model_kwargs, rng_key, trajectory_length):
        if potential_fn_gen:
            nonlocal vv_update, forward_mode_ad
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
            _, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)

        # no need to spend too many steps if the state z has 0 size (i.e. z is empty)
        if len(inverse_mass_matrix) == 0:
            num_steps = 1
        else:
            num_steps = _get_num_steps(step_size, trajectory_length)
        # makes sure trajectory length is constant, rather than step_size * num_steps
        step_size = trajectory_length / num_steps
        vv_state_new = fori_loop(
            0, num_steps,
            lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
            vv_state)
        energy_old = vv_state.potential_energy + kinetic_fn(
            inverse_mass_matrix, vv_state.r)
        energy_new = vv_state_new.potential_energy + kinetic_fn(
            inverse_mass_matrix, vv_state_new.r)
        delta_energy = energy_new - energy_old
        delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf,
                                 delta_energy)
        accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
        diverging = delta_energy > max_delta_energy
        transition = random.bernoulli(rng_key, accept_prob)
        vv_state, energy = cond(transition, (vv_state_new, energy_new),
                                identity, (vv_state, energy_old), identity)
        return vv_state, energy, num_steps, accept_prob, diverging
Beispiel #5
0
    def _nuts_next(step_size, inverse_mass_matrix, vv_state,
                   model_args, model_kwargs, rng_key):
        if potential_fn_gen:
            nonlocal vv_update
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
            _, vv_update = velocity_verlet(pe_fn, kinetic_fn)

        binary_tree = build_tree(vv_update, kinetic_fn, vv_state,
                                 inverse_mass_matrix, step_size, rng_key,
                                 max_delta_energy=max_delta_energy,
                                 max_tree_depth=max_treedepth)
        accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals
        num_steps = binary_tree.num_proposals
        vv_state = IntegratorState(z=binary_tree.z_proposal,
                                   r=vv_state.r,
                                   potential_energy=binary_tree.z_proposal_pe,
                                   z_grad=binary_tree.z_proposal_grad)
        return vv_state, binary_tree.z_proposal_energy, num_steps, accept_prob, binary_tree.diverging
Beispiel #6
0
    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    inverse_mass_matrix=None,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2 * math.pi,
                    max_tree_depth=10,
                    find_heuristic_step_size=False,
                    model_args=(),
                    model_kwargs=None,
                    rng_key=random.PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param numpy.ndarray inverse_mass_matrix: Initial value for inverse mass matrix.
            This may be adapted during warmup if adapt_mass_matrix = True.
            If no value is specified, then it is initialized to the identity matrix.
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if mass matrix is dense or
            diagonal (default when ``dense_mass=False``)
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Default to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10.
        :param bool find_heuristic_step_size: whether to a heuristic function to adjust the
            step size at the beginning of each adaptation window. Defaults to False.
        :param tuple model_args: Model arguments if `potential_fn_gen` is specified.
        :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified.
        :param jax.random.PRNGKey rng_key: random key to be used as the source of
            randomness.

        """
        step_size = lax.convert_element_type(step_size,
                                             canonicalize_dtype(jnp.float64))
        nonlocal wa_update, trajectory_len, max_treedepth, vv_update, wa_steps
        wa_steps = num_warmup
        trajectory_len = trajectory_length
        max_treedepth = max_tree_depth
        if isinstance(init_params, ParamInfo):
            z, pe, z_grad = init_params
        else:
            z, pe, z_grad = init_params, None, None
        pe_fn = potential_fn
        if potential_fn_gen:
            if pe_fn is not None:
                raise ValueError(
                    'Only one of `potential_fn` or `potential_fn_gen` must be provided.'
                )
            else:
                kwargs = {} if model_kwargs is None else model_kwargs
                pe_fn = potential_fn_gen(*model_args, **kwargs)

        find_reasonable_ss = None
        if find_heuristic_step_size:
            find_reasonable_ss = partial(find_reasonable_step_size, pe_fn,
                                         kinetic_fn, momentum_generator)

        wa_init, wa_update = warmup_adapter(
            num_warmup,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            dense_mass=dense_mass,
            target_accept_prob=target_accept_prob,
            find_reasonable_step_size=find_reasonable_ss)

        rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3)
        z_info = IntegratorState(z=z, potential_energy=pe, z_grad=z_grad)
        wa_state = wa_init(z_info,
                           rng_key_wa,
                           step_size,
                           inverse_mass_matrix=inverse_mass_matrix,
                           mass_matrix_size=jnp.size(ravel_pytree(z)[0]))
        r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
        vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn)
        vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
        energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad,
                             vv_state.potential_energy, energy, 0, 0., 0.,
                             False, wa_state, rng_key_hmc)
        return device_put(hmc_state)
Beispiel #7
0
    def init_kernel(
            init_params,
            num_warmup,
            *,
            step_size=1.0,
            inverse_mass_matrix=None,
            adapt_step_size=True,
            adapt_mass_matrix=True,
            dense_mass=False,
            target_accept_prob=0.8,
            trajectory_length=2 * math.pi,
            max_tree_depth=10,
            find_heuristic_step_size=False,
            forward_mode_differentiation=False,
            regularize_mass_matrix=True,
            model_args=(),
            model_kwargs=None,
            rng_key=random.PRNGKey(0),
    ):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param inverse_mass_matrix: Initial value for inverse mass matrix.
            This may be adapted during warmup if adapt_mass_matrix = True.
            If no value is specified, then it is initialized to the identity matrix.
            For a potential_fn with general JAX pytree parameters, the order of entries
            of the mass matrix is the order of the flattened version of pytree parameters
            obtained with `jax.tree_flatten`, which is a bit ambiguous (see more at
            https://jax.readthedocs.io/en/latest/pytrees.html). If `model` is not None,
            here we can specify a structured block mass matrix as a dictionary, where
            keys are tuple of site names and values are the corresponding block of the
            mass matrix.
            For more information about structured mass matrix, see `dense_mass` argument.
        :type inverse_mass_matrix: numpy.ndarray or dict
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param dense_mass:  This flag controls whether mass matrix is dense (i.e. full-rank) or
            diagonal (defaults to ``dense_mass=False``). To specify a structured mass matrix,
            users can provide a list of tuples of site names. Each tuple represents
            a block in the joint mass matrix. For example, assuming that the model
            has latent variables "x", "y", "z" (where each variable can be multi-dimensional),
            possible specifications and corresponding mass matrix structures are as follows:

                + dense_mass=[("x", "y")]: use a dense mass matrix for the joint
                  (x, y) and a diagonal mass matrix for z
                + dense_mass=[] (equivalent to dense_mass=False): use a diagonal mass
                  matrix for the joint (x, y, z)
                + dense_mass=[("x", "y", "z")] (equivalent to full_mass=True):
                  use a dense mass matrix for the joint (x, y, z)
                + dense_mass=[("x",), ("y",), ("z")]: use dense mass matrices for
                  each of x, y, and z (i.e. block-diagonal with 3 blocks)

        :type dense_mass: bool or list
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Defaults to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10. This argument also accepts a tuple of
            integers `(d1, d2)`, where `d1` is the max tree depth during warmup phase and
            `d2` is the max tree depth during post warmup phase.
        :param bool find_heuristic_step_size: whether to a heuristic function to adjust the
            step size at the beginning of each adaptation window. Defaults to False.
        :param bool regularize_mass_matrix: whether or not to regularize the estimated mass
            matrix for numerical stability during warmup phase. Defaults to True. This flag
            does not take effect if ``adapt_mass_matrix == False``.
        :param tuple model_args: Model arguments if `potential_fn_gen` is specified.
        :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified.
        :param jax.random.PRNGKey rng_key: random key to be used as the source of
            randomness.

        """
        step_size = lax.convert_element_type(step_size, jnp.result_type(float))
        if trajectory_length is not None:
            trajectory_length = lax.convert_element_type(
                trajectory_length, jnp.result_type(float))
        nonlocal wa_update, max_treedepth, vv_update, wa_steps, forward_mode_ad
        forward_mode_ad = forward_mode_differentiation
        wa_steps = num_warmup
        max_treedepth = (max_tree_depth if isinstance(max_tree_depth, tuple)
                         else (max_tree_depth, max_tree_depth))
        if isinstance(init_params, ParamInfo):
            z, pe, z_grad = init_params
        else:
            z, pe, z_grad = init_params, None, None
        pe_fn = potential_fn
        if potential_fn_gen:
            if pe_fn is not None:
                raise ValueError(
                    "Only one of `potential_fn` or `potential_fn_gen` must be provided."
                )
            else:
                kwargs = {} if model_kwargs is None else model_kwargs
                pe_fn = potential_fn_gen(*model_args, **kwargs)

        find_reasonable_ss = None
        if find_heuristic_step_size:
            find_reasonable_ss = partial(find_reasonable_step_size, pe_fn,
                                         kinetic_fn, momentum_generator)

        wa_init, wa_update = warmup_adapter(
            num_warmup,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            dense_mass=dense_mass,
            target_accept_prob=target_accept_prob,
            find_reasonable_step_size=find_reasonable_ss,
            regularize_mass_matrix=regularize_mass_matrix,
        )

        rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3)
        z_info = IntegratorState(z=z, potential_energy=pe, z_grad=z_grad)
        wa_state = wa_init(z_info,
                           rng_key_wa,
                           step_size,
                           inverse_mass_matrix=inverse_mass_matrix)
        r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
        vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn,
                                             forward_mode_ad)
        vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
        energy = vv_state.potential_energy + kinetic_fn(
            wa_state.inverse_mass_matrix, vv_state.r)
        zero_int = jnp.array(0, dtype=jnp.result_type(int))
        hmc_state = HMCState(
            zero_int,
            vv_state.z,
            vv_state.z_grad,
            vv_state.potential_energy,
            energy,
            None,
            trajectory_length,
            zero_int,
            jnp.zeros(()),
            jnp.zeros(()),
            jnp.array(False),
            wa_state,
            rng_key_hmc,
        )
        return device_put(hmc_state)
Beispiel #8
0
def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
    r"""
    Hamiltonian Monte Carlo inference, using either fixed number of
    steps or the No U-Turn Sampler (NUTS) with adaptive path length.

    **References:**

    1. *MCMC Using Hamiltonian Dynamics*,
       Radford M. Neal
    2. *The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, and Andrew Gelman.
    3. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
       Michael Betancourt

    :param potential_fn: Python callable that computes the potential energy
        given input parameters. The input parameters to `potential_fn` can be
        any python collection type, provided that `init_params` argument to
        `init_kernel` has the same type.
    :param kinetic_fn: Python callable that returns the kinetic energy given
        inverse mass matrix and momentum. If not provided, the default is
        euclidean kinetic energy.
    :param str algo: Whether to run ``HMC`` with fixed number of steps or ``NUTS``
        with adaptive path length. Default is ``NUTS``.
    :return: a tuple of callables (`init_kernel`, `sample_kernel`), the first
        one to initialize the sampler, and the second one to generate samples
        given an existing one.
    :raises ValueError: if `algo` is neither ``HMC`` nor ``NUTS``.

    .. warning::
        Instead of using this interface directly, we would highly recommend you
        to use the higher level :class:`numpyro.infer.MCMC` API instead.

    **Example**

    .. testsetup::

        import jax
        from jax import random
        import jax.numpy as np
        import numpyro
        import numpyro.distributions as dist
        from numpyro.infer.mcmc import hmc
        from numpyro.infer.util import initialize_model
        from numpyro.util import fori_collect

    .. doctest::

        >>> true_coefs = np.array([1., 2., 3.])
        >>> data = random.normal(random.PRNGKey(2), (2000, 3))
        >>> dim = 3
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
        >>>
        >>> def model(data, labels):
        ...     coefs_mean = np.zeros(dim)
        ...     coefs = numpyro.sample('beta', dist.Normal(coefs_mean, np.ones(3)))
        ...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
        ...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
        >>>
        >>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0),
        ...                                                            model, data, labels)
        >>> init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
        >>> hmc_state = init_kernel(init_params,
        ...                         trajectory_length=10,
        ...                         num_warmup=300)
        >>> samples = fori_collect(0, 500, sample_kernel, hmc_state,
        ...                        transform=lambda state: constrain_fn(state.z))
        >>> print(np.mean(samples['beta'], axis=0))  # doctest: +SKIP
        [0.9153987 2.0754058 2.9621222]
    """
    if kinetic_fn is None:
        kinetic_fn = euclidean_kinetic_energy
    vv_init, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    # Sampler configuration shared by the closures below; populated by
    # `init_kernel` and read by `sample_kernel` and the transition functions.
    trajectory_len = None
    max_treedepth = None
    momentum_generator = None
    wa_update = None
    wa_steps = None
    # Energy change above which a transition is flagged as diverging.
    max_delta_energy = 1000.
    if algo not in {'HMC', 'NUTS'}:
        raise ValueError('`algo` must be one of `HMC` or `NUTS`.')

    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2 * math.pi,
                    max_tree_depth=10,
                    run_warmup=True,
                    progbar=True,
                    rng_key=PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if mass matrix is dense or
            diagonal (default when ``dense_mass=False``)
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Default to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10.
        :param bool run_warmup: Flag to decide whether warmup is run. If ``True``
            (and ``num_warmup > 0``), `sample_kernel` is applied `num_warmup`
            times before the state is returned. Otherwise the freshly
            initialized :data:`~numpyro.infer.mcmc.HMCState` is returned and the
            warmup steps happen during the first `num_warmup` calls to
            `sample_kernel`.
        :param bool progbar: Whether to enable progress bar updates. Defaults to
            ``True``.
        :param jax.random.PRNGKey rng_key: random key to be used as the source of
            randomness.
        :return: an :data:`~numpyro.infer.mcmc.HMCState` to start sampling from.
        """
        step_size = lax.convert_element_type(
            step_size, xla_bridge.canonicalize_dtype(np.float64))
        nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps
        wa_steps = num_warmup
        trajectory_len = trajectory_length
        max_treedepth = max_tree_depth
        z = init_params
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)

        find_reasonable_ss = partial(find_reasonable_step_size, potential_fn,
                                     kinetic_fn, momentum_generator)

        wa_init, wa_update = warmup_adapter(
            num_warmup,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            dense_mass=dense_mass,
            target_accept_prob=target_accept_prob,
            find_reasonable_step_size=find_reasonable_ss)

        # A key must not be used again once it has been split, so the initial
        # momentum draw gets its own subkey instead of the parent `rng_key`.
        rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3)
        wa_state = wa_init(z,
                           rng_key_wa,
                           step_size,
                           mass_matrix_size=np.size(z_flat))
        r = momentum_generator(wa_state.mass_matrix_sqrt, rng_key_momentum)
        vv_state = vv_init(z, r)
        # Total energy (potential + kinetic), matching what `_hmc_next`
        # records for every later transition.
        energy = vv_state.potential_energy + kinetic_fn(
            wa_state.inverse_mass_matrix, vv_state.r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad,
                             vv_state.potential_energy, energy, 0, 0., 0.,
                             False, wa_state, rng_key_hmc)

        # TODO: Remove; this should be the responsibility of the MCMC class.
        if run_warmup and num_warmup > 0:
            # JIT if progress bar updates not required
            if not progbar:
                hmc_state = fori_loop(0, num_warmup,
                                      lambda *args: sample_kernel(args[1]),
                                      hmc_state)
            else:
                # Wrap once, outside the loop, rather than on every iteration.
                jitted_kernel = jit(sample_kernel)
                with tqdm.trange(num_warmup, desc='warmup') as t:
                    for i in t:
                        hmc_state = jitted_kernel(hmc_state)
                        t.set_postfix_str(get_diagnostics_str(hmc_state),
                                          refresh=False)
        return hmc_state

    def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng_key):
        """
        Run a fixed-length trajectory from ``vv_state`` and accept or reject
        the end point.

        :return: tuple ``(vv_state, energy, num_steps, accept_prob, diverging)``.
        """
        num_steps = _get_num_steps(step_size, trajectory_len)
        vv_state_new = fori_loop(
            0, num_steps,
            lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
            vv_state)
        energy_old = vv_state.potential_energy + kinetic_fn(
            inverse_mass_matrix, vv_state.r)
        energy_new = vv_state_new.potential_energy + kinetic_fn(
            inverse_mass_matrix, vv_state_new.r)
        delta_energy = energy_new - energy_old
        # NaN energy change -> +inf, i.e. acceptance probability 0.
        delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
        accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
        diverging = delta_energy > max_delta_energy
        transition = random.bernoulli(rng_key, accept_prob)
        vv_state, energy = cond(transition, (vv_state_new, energy_new),
                                lambda args: args, (vv_state, energy_old),
                                lambda args: args)
        return vv_state, energy, num_steps, accept_prob, diverging

    def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng_key):
        """
        Build a tree with ``build_tree`` from ``vv_state`` and return its
        proposal as the next integrator state.

        :return: tuple ``(vv_state, energy, num_steps, accept_prob, diverging)``.
        """
        binary_tree = build_tree(vv_update,
                                 kinetic_fn,
                                 vv_state,
                                 inverse_mass_matrix,
                                 step_size,
                                 rng_key,
                                 max_delta_energy=max_delta_energy,
                                 max_tree_depth=max_treedepth)
        accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals
        num_steps = binary_tree.num_proposals
        vv_state = IntegratorState(z=binary_tree.z_proposal,
                                   r=vv_state.r,
                                   potential_energy=binary_tree.z_proposal_pe,
                                   z_grad=binary_tree.z_proposal_grad)
        return vv_state, binary_tree.z_proposal_energy, num_steps, accept_prob, binary_tree.diverging

    _next = _nuts_next if algo == 'NUTS' else _hmc_next

    def sample_kernel(hmc_state):
        """
        Given an existing :data:`~numpyro.infer.mcmc.HMCState`, run HMC with fixed (possibly adapted)
        step size and return a new :data:`~numpyro.infer.mcmc.HMCState`.

        :param hmc_state: Current sample (and associated state).
        :return: new proposed :data:`~numpyro.infer.mcmc.HMCState` from simulating
            Hamiltonian dynamics given existing state.
        """
        rng_key, rng_key_momentum, rng_key_transition = random.split(
            hmc_state.rng_key, 3)
        r = momentum_generator(hmc_state.adapt_state.mass_matrix_sqrt,
                               rng_key_momentum)
        vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy,
                                   hmc_state.z_grad)
        vv_state, energy, num_steps, accept_prob, diverging = _next(
            hmc_state.adapt_state.step_size,
            hmc_state.adapt_state.inverse_mass_matrix, vv_state,
            rng_key_transition)
        # not update adapt_state after warmup phase
        adapt_state = cond(
            hmc_state.i < wa_steps,
            (hmc_state.i, accept_prob, vv_state.z, hmc_state.adapt_state),
            lambda args: wa_update(*args), hmc_state.adapt_state, lambda x: x)

        itr = hmc_state.i + 1
        # The running mean restarts once warmup is over: `n` counts iterations
        # within the current phase (warmup or sampling).
        n = np.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
        mean_accept_prob = hmc_state.mean_accept_prob + (
            accept_prob - hmc_state.mean_accept_prob) / n

        return HMCState(itr, vv_state.z, vv_state.z_grad,
                        vv_state.potential_energy, energy, num_steps,
                        accept_prob, mean_accept_prob, diverging, adapt_state,
                        rng_key)

    # Make `init_kernel` and `sample_kernel` visible from the global scope once
    # `hmc` is called for sphinx doc generation.
    if 'SPHINX_BUILD' in os.environ:
        hmc.init_kernel = init_kernel
        hmc.sample_kernel = sample_kernel

    return init_kernel, sample_kernel