Example No. 1
    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        rng_key, rng_r = random.split(rng_key)
        state = super().init(rng_key, num_warmup, init_params, model_args,
                             model_kwargs)
        self._support_sizes_flat, _ = ravel_pytree(
            {k: self._support_sizes[k]
             for k in self._gibbs_sites})
        if self._num_discrete_updates is None:
            self._num_discrete_updates = self._support_sizes_flat.shape[0]
        self._num_warmup = num_warmup

        # NB: the warmup adaptation cannot be performed within sub-trajectories (i.e. the HMC
        # trajectory between two discrete updates), so we do it here, at the end of each MixedHMC step.
        _, self._wa_update = warmup_adapter(
            num_warmup,
            adapt_step_size=self.inner_kernel._adapt_step_size,
            adapt_mass_matrix=self.inner_kernel._adapt_mass_matrix,
            dense_mass=self.inner_kernel._dense_mass,
            target_accept_prob=self.inner_kernel._target_accept_prob,
            find_reasonable_step_size=None,
        )

        # In HMC, when `hmc_state.r` is not None, we skip drawing a random momentum at the
        # beginning of an HMC step. The reason is that we need to maintain `r` across sub-trajectories.
        r = momentum_generator(state.hmc_state.z,
                               state.hmc_state.adapt_state.mass_matrix_sqrt,
                               rng_r)
        return MixedHMCState(state.z, state.hmc_state._replace(r=r),
                             state.rng_key, jnp.zeros(()))
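As a point of reference, a minimal usage sketch of this kernel (the mixture model below is illustrative, not from the source; `MixedHMC`, `HMC`, and `MCMC` are the NumPyro classes the snippet above belongs to):

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, HMC, MixedHMC

# Illustrative model mixing a discrete latent site with a continuous one.
def model():
    probs = jnp.array([0.15, 0.3, 0.3, 0.25])
    locs = jnp.array([-2.0, 0.0, 2.0, 4.0])
    c = numpyro.sample("c", dist.Categorical(probs))
    numpyro.sample("x", dist.Normal(locs[c], 0.5))

# MixedHMC wraps an HMC inner kernel; the `init` above seeds the shared momentum `r`.
kernel = MixedHMC(HMC(model, trajectory_length=1.2), num_discrete_updates=20)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0))
mcmc.print_summary()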
Example No. 2
    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        self._num_warmup = num_warmup
        # TODO (low-priority): support chain_method="vectorized", i.e. rng_key is a batch of keys
        assert rng_key.shape == (2,), ("BarkerMH only supports chain_method='parallel' or chain_method='sequential'."
                                       " Please put in a feature request if you think it would be useful to be able "
                                       "to use BarkerMH in vectorized mode.")
        rng_key, rng_key_init_model, rng_key_wa = random.split(rng_key, 3)
        init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
        if self._potential_fn and init_params is None:
            raise ValueError('Valid value of `init_params` must be provided with'
                             ' `potential_fn`.')

        pe, grad = jax.value_and_grad(self._potential_fn)(init_params)

        wa_init, self._wa_update = warmup_adapter(
            num_warmup,
            adapt_step_size=self._adapt_step_size,
            adapt_mass_matrix=self._adapt_mass_matrix,
            dense_mass=self._dense_mass,
            target_accept_prob=self._target_accept_prob)
        size = len(ravel_pytree(init_params)[0])
        wa_state = wa_init(None, rng_key_wa, self._step_size, mass_matrix_size=size)
        wa_state = wa_state._replace(rng_key=None)
        init_state = BarkerMHState(jnp.array(0), init_params, pe, grad, jnp.array(0.),
                                   jnp.array(0.), wa_state, rng_key)
        return jax.device_put(init_state)
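For orientation, a minimal sketch of driving this kernel end to end, assuming a NumPyro version that exports `BarkerMH` (the toy model is illustrative):

from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, BarkerMH

# Illustrative continuous model; BarkerMH is gradient-based, so it needs a
# differentiable potential.
def model():
    numpyro.sample("x", dist.Normal(0.0, 1.0))

kernel = BarkerMH(model)  # the `init` above runs when MCMC.run starts a chain
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0))
mcmc.print_summary()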
Example No. 3
def test_warmup_adapter(jitted):
    def find_reasonable_step_size(step_size, m_inv, z, rng_key):
        return jnp.where(step_size < 1, step_size * 4, step_size / 4)

    num_steps = 150
    adaptation_schedule = build_adaptation_schedule(num_steps)
    init_step_size = 1.
    mass_matrix_size = 3

    wa_init, wa_update = warmup_adapter(num_steps, find_reasonable_step_size)
    wa_update = jit(wa_update) if jitted else wa_update

    rng_key = random.PRNGKey(0)
    z = jnp.ones(3)
    wa_state = wa_init((z, None, None, None),
                       rng_key,
                       init_step_size,
                       mass_matrix_size=mass_matrix_size)
    step_size, inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert step_size == find_reasonable_step_size(init_step_size,
                                                  inverse_mass_matrix, z,
                                                  rng_key)
    assert_allclose(inverse_mass_matrix, jnp.ones(mass_matrix_size))
    assert window_idx == 0

    window = adaptation_schedule[0]
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(t, 0.7 + 0.1 * t / (window.end - window.start), z,
                             wa_state)
    last_step_size = step_size
    step_size, inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert window_idx == 1
    # step_size is decreased because accept_prob < target_accept_prob
    assert step_size < last_step_size
    # inverse_mass_matrix does not change at the end of the first window
    assert_allclose(inverse_mass_matrix, jnp.ones(mass_matrix_size))

    window = adaptation_schedule[1]
    window_len = window.end - window.start
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(t, 0.8 + 0.1 * (t - window.start) / window_len,
                             2 * z, wa_state)
    last_step_size = step_size
    step_size, inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert window_idx == 2
    # step_size is increased because accept_prob > target_accept_prob
    assert step_size > last_step_size
    # Verifies that inverse_mass_matrix changes at the end of the second window.
    # Because z_flat is constant during the second window, the sample covariance
    # is 0 and only the regularization term of the Welford scheme contributes.
    # This also verifies that z_flat values from the first window do not affect
    # the second window.
    welford_regularize_term = 1e-3 * (5 / (window.end + 1 - window.start + 5))
    assert_allclose(inverse_mass_matrix,
                    jnp.full((mass_matrix_size, ), welford_regularize_term),
                    atol=1e-7)

    window = adaptation_schedule[2]
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(t, 0.8, t * z, wa_state)
    last_step_size = step_size
    step_size, final_inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert window_idx == 3
    # during the last window, because target_accept_prob=0.8,
    # log_step_size will be equal to the constant prox_center=log(10*last_step_size)
    assert_allclose(step_size, last_step_size * 10, atol=1e-6)
    # Verifies that inverse_mass_matrix does not change during the last window,
    # despite z_flat changing with time t.
    assert_allclose(final_inverse_mass_matrix, inverse_mass_matrix)
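The value asserted at the end of the second window can be reproduced by hand; a standalone numeric check, assuming the adapter applies Stan-style shrinkage cov_reg = (n / (n + 5)) * cov + 1e-3 * (5 / (n + 5)) when a window closes and that `build_adaptation_schedule` is importable from `numpyro.infer.hmc_util`:

from numpyro.infer.hmc_util import build_adaptation_schedule

schedule = build_adaptation_schedule(150)
second = schedule[1]
n = second.end + 1 - second.start   # number of samples accumulated in the window
sample_cov = 0.0                    # z_flat was held constant, so the sample covariance is 0
cov_reg = (n / (n + 5)) * sample_cov + 1e-3 * (5 / (n + 5))
print(cov_reg)  # equals welford_regularize_term asserted in the test above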
Example No. 4
    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    inverse_mass_matrix=None,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2 * math.pi,
                    max_tree_depth=10,
                    find_heuristic_step_size=False,
                    model_args=(),
                    model_kwargs=None,
                    rng_key=random.PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param numpy.ndarray inverse_mass_matrix: Initial value for inverse mass matrix.
            This may be adapted during warmup if adapt_mass_matrix = True.
            If no value is specified, then it is initialized to the identity matrix.
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if the mass matrix is dense (full-rank)
            or diagonal (the default, ``dense_mass=False``).
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Defaults to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10.
        :param bool find_heuristic_step_size: whether to use a heuristic function to adjust the
            step size at the beginning of each adaptation window. Defaults to False.
        :param tuple model_args: Model arguments if `potential_fn_gen` is specified.
        :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified.
        :param jax.random.PRNGKey rng_key: random key to be used as the source of
            randomness.

        """
        step_size = lax.convert_element_type(step_size,
                                             canonicalize_dtype(jnp.float64))
        nonlocal wa_update, trajectory_len, max_treedepth, vv_update, wa_steps
        wa_steps = num_warmup
        trajectory_len = trajectory_length
        max_treedepth = max_tree_depth
        if isinstance(init_params, ParamInfo):
            z, pe, z_grad = init_params
        else:
            z, pe, z_grad = init_params, None, None
        pe_fn = potential_fn
        if potential_fn_gen:
            if pe_fn is not None:
                raise ValueError(
                    'Only one of `potential_fn` or `potential_fn_gen` must be provided.'
                )
            else:
                kwargs = {} if model_kwargs is None else model_kwargs
                pe_fn = potential_fn_gen(*model_args, **kwargs)

        find_reasonable_ss = None
        if find_heuristic_step_size:
            find_reasonable_ss = partial(find_reasonable_step_size, pe_fn,
                                         kinetic_fn, momentum_generator)

        wa_init, wa_update = warmup_adapter(
            num_warmup,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            dense_mass=dense_mass,
            target_accept_prob=target_accept_prob,
            find_reasonable_step_size=find_reasonable_ss)

        rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3)
        z_info = IntegratorState(z=z, potential_energy=pe, z_grad=z_grad)
        wa_state = wa_init(z_info,
                           rng_key_wa,
                           step_size,
                           inverse_mass_matrix=inverse_mass_matrix,
                           mass_matrix_size=jnp.size(ravel_pytree(z)[0]))
        r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
        vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn)
        vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
        energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad,
                             vv_state.potential_energy, energy, 0, 0., 0.,
                             False, wa_state, rng_key_hmc)
        return device_put(hmc_state)
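A hedged sketch of how an `init_kernel` / `sample_kernel` pair like this one is typically consumed through NumPyro's functional interface (the standard-normal potential is illustrative; `hmc` and `fori_collect` are assumed importable from `numpyro.infer.hmc` and `numpyro.util`):

import jax.numpy as jnp
from jax import random

from numpyro.infer.hmc import hmc
from numpyro.util import fori_collect

# Illustrative differentiable potential: an isotropic standard normal in 3D.
def potential_fn(z):
    return 0.5 * jnp.sum(z ** 2)

init_kernel, sample_kernel = hmc(potential_fn, algo="NUTS")
hmc_state = init_kernel(jnp.zeros(3), num_warmup=300,
                        rng_key=random.PRNGKey(0))
samples = fori_collect(0, 500, sample_kernel, hmc_state,
                       transform=lambda state: state.z)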
Example No. 5
    def init_kernel(
            init_params,
            num_warmup,
            *,
            step_size=1.0,
            inverse_mass_matrix=None,
            adapt_step_size=True,
            adapt_mass_matrix=True,
            dense_mass=False,
            target_accept_prob=0.8,
            trajectory_length=2 * math.pi,
            max_tree_depth=10,
            find_heuristic_step_size=False,
            forward_mode_differentiation=False,
            regularize_mass_matrix=True,
            model_args=(),
            model_kwargs=None,
            rng_key=random.PRNGKey(0),
    ):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param inverse_mass_matrix: Initial value for inverse mass matrix.
            This may be adapted during warmup if adapt_mass_matrix = True.
            If no value is specified, then it is initialized to the identity matrix.
            For a potential_fn with general JAX pytree parameters, the order of entries
            of the mass matrix is the order of the flattened version of pytree parameters
            obtained with `jax.tree_flatten`, which is a bit ambiguous (see more at
            https://jax.readthedocs.io/en/latest/pytrees.html). If `model` is not None,
            here we can specify a structured block mass matrix as a dictionary, where
            keys are tuple of site names and values are the corresponding block of the
            mass matrix.
            For more information about structured mass matrix, see `dense_mass` argument.
        :type inverse_mass_matrix: numpy.ndarray or dict
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param dense_mass:  This flag controls whether mass matrix is dense (i.e. full-rank) or
            diagonal (defaults to ``dense_mass=False``). To specify a structured mass matrix,
            users can provide a list of tuples of site names. Each tuple represents
            a block in the joint mass matrix. For example, assuming that the model
            has latent variables "x", "y", "z" (where each variable can be multi-dimensional),
            possible specifications and corresponding mass matrix structures are as follows:

                + dense_mass=[("x", "y")]: use a dense mass matrix for the joint
                  (x, y) and a diagonal mass matrix for z
                + dense_mass=[] (equivalent to dense_mass=False): use a diagonal mass
                  matrix for the joint (x, y, z)
                + dense_mass=[("x", "y", "z")] (equivalent to full_mass=True):
                  use a dense mass matrix for the joint (x, y, z)
                + dense_mass=[("x",), ("y",), ("z")]: use dense mass matrices for
                  each of x, y, and z (i.e. block-diagonal with 3 blocks)

        :type dense_mass: bool or list
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Defaults to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10. This argument also accepts a tuple of
            integers `(d1, d2)`, where `d1` is the max tree depth during warmup phase and
            `d2` is the max tree depth during post warmup phase.
        :param bool find_heuristic_step_size: whether to use a heuristic function to adjust the
            step size at the beginning of each adaptation window. Defaults to False.
        :param bool regularize_mass_matrix: whether or not to regularize the estimated mass
            matrix for numerical stability during warmup phase. Defaults to True. This flag
            does not take effect if ``adapt_mass_matrix == False``.
        :param tuple model_args: Model arguments if `potential_fn_gen` is specified.
        :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified.
        :param jax.random.PRNGKey rng_key: random key to be used as the source of
            randomness.

        """
        step_size = lax.convert_element_type(step_size, jnp.result_type(float))
        if trajectory_length is not None:
            trajectory_length = lax.convert_element_type(
                trajectory_length, jnp.result_type(float))
        nonlocal wa_update, max_treedepth, vv_update, wa_steps, forward_mode_ad
        forward_mode_ad = forward_mode_differentiation
        wa_steps = num_warmup
        max_treedepth = (max_tree_depth if isinstance(max_tree_depth, tuple)
                         else (max_tree_depth, max_tree_depth))
        if isinstance(init_params, ParamInfo):
            z, pe, z_grad = init_params
        else:
            z, pe, z_grad = init_params, None, None
        pe_fn = potential_fn
        if potential_fn_gen:
            if pe_fn is not None:
                raise ValueError(
                    "Only one of `potential_fn` or `potential_fn_gen` must be provided."
                )
            else:
                kwargs = {} if model_kwargs is None else model_kwargs
                pe_fn = potential_fn_gen(*model_args, **kwargs)

        find_reasonable_ss = None
        if find_heuristic_step_size:
            find_reasonable_ss = partial(find_reasonable_step_size, pe_fn,
                                         kinetic_fn, momentum_generator)

        wa_init, wa_update = warmup_adapter(
            num_warmup,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            dense_mass=dense_mass,
            target_accept_prob=target_accept_prob,
            find_reasonable_step_size=find_reasonable_ss,
            regularize_mass_matrix=regularize_mass_matrix,
        )

        rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3)
        z_info = IntegratorState(z=z, potential_energy=pe, z_grad=z_grad)
        wa_state = wa_init(z_info,
                           rng_key_wa,
                           step_size,
                           inverse_mass_matrix=inverse_mass_matrix)
        r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
        vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn,
                                             forward_mode_ad)
        vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
        energy = vv_state.potential_energy + kinetic_fn(
            wa_state.inverse_mass_matrix, vv_state.r)
        zero_int = jnp.array(0, dtype=jnp.result_type(int))
        hmc_state = HMCState(
            zero_int,
            vv_state.z,
            vv_state.z_grad,
            vv_state.potential_energy,
            energy,
            None,
            trajectory_length,
            zero_int,
            jnp.zeros(()),
            jnp.zeros(()),
            jnp.array(False),
            wa_state,
            rng_key_hmc,
        )
        return device_put(hmc_state)
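A minimal hedged sketch of the structured `dense_mass` specification documented above, passed through the kernel constructor (the model and site names are illustrative):

from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# Illustrative model with three latent sites x, y, z.
def model():
    numpyro.sample("x", dist.Normal(0.0, 1.0).expand([2]))
    numpyro.sample("y", dist.Normal(0.0, 1.0).expand([3]))
    numpyro.sample("z", dist.Normal(0.0, 1.0))

# One dense block for the joint (x, y); z keeps a diagonal mass matrix entry.
kernel = NUTS(model, dense_mass=[("x", "y")])
mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0))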
Example No. 6
    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2 * math.pi,
                    max_tree_depth=10,
                    run_warmup=True,
                    progbar=True,
                    rng_key=PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if the mass matrix is dense (full-rank)
            or diagonal (the default, ``dense_mass=False``).
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Defaults to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10.
        :param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
            `init_kernel` runs the warmup adaptation and returns an adapted
            :data:`~numpyro.infer.mcmc.HMCState` that can be used to generate
            samples using MCMC; otherwise, the returned state is left unadapted.
        :param bool progbar: Whether to enable progress bar updates. Defaults to
            ``True``.
        :param jax.random.PRNGKey rng_key: random key to be used as the source of
            randomness.
        """
        step_size = lax.convert_element_type(
            step_size, xla_bridge.canonicalize_dtype(np.float64))
        nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps
        wa_steps = num_warmup
        trajectory_len = trajectory_length
        max_treedepth = max_tree_depth
        z = init_params
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)

        find_reasonable_ss = partial(find_reasonable_step_size, potential_fn,
                                     kinetic_fn, momentum_generator)

        wa_init, wa_update = warmup_adapter(
            num_warmup,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            dense_mass=dense_mass,
            target_accept_prob=target_accept_prob,
            find_reasonable_step_size=find_reasonable_ss)

        rng_key_hmc, rng_key_wa = random.split(rng_key)
        wa_state = wa_init(z,
                           rng_key_wa,
                           step_size,
                           mass_matrix_size=np.size(z_flat))
        r = momentum_generator(wa_state.mass_matrix_sqrt, rng_key)
        vv_state = vv_init(z, r)
        energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad,
                             vv_state.potential_energy, energy, 0, 0., 0.,
                             False, wa_state, rng_key_hmc)

        # TODO: Remove; this should be the responsibility of the MCMC class.
        if run_warmup and num_warmup > 0:
            # JIT if progress bar updates not required
            if not progbar:
                hmc_state = fori_loop(0, num_warmup,
                                      lambda *args: sample_kernel(args[1]),
                                      hmc_state)
            else:
                with tqdm.trange(num_warmup, desc='warmup') as t:
                    for i in t:
                        hmc_state = jit(sample_kernel)(hmc_state)
                        t.set_postfix_str(get_diagnostics_str(hmc_state),
                                          refresh=False)
        return hmc_state
Example No. 7
    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    inverse_mass_matrix=None,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2 * math.pi,
                    max_tree_depth=10,
                    rng_key=PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param numpy.ndarray inverse_mass_matrix: Initial value for inverse mass matrix.
            This may be adapted during warmup if adapt_mass_matrix = True.
            If no value is specified, then it is initialized to the identity matrix.
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if the mass matrix is dense (full-rank)
            or diagonal (the default, ``dense_mass=False``).
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Defaults to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10.
        :param jax.random.PRNGKey rng_key: random key to be used as the source of
            randomness.
        """
        step_size = lax.convert_element_type(
            step_size, xla_bridge.canonicalize_dtype(np.float64))
        nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps
        wa_steps = num_warmup
        trajectory_len = trajectory_length
        max_treedepth = max_tree_depth
        z = init_params
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)

        find_reasonable_ss = partial(find_reasonable_step_size, potential_fn,
                                     kinetic_fn, momentum_generator)

        wa_init, wa_update = warmup_adapter(
            num_warmup,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            dense_mass=dense_mass,
            target_accept_prob=target_accept_prob,
            find_reasonable_step_size=find_reasonable_ss)

        rng_key_hmc, rng_key_wa = random.split(rng_key)
        wa_state = wa_init(z,
                           rng_key_wa,
                           step_size,
                           inverse_mass_matrix=inverse_mass_matrix,
                           mass_matrix_size=np.size(z_flat))
        r = momentum_generator(wa_state.mass_matrix_sqrt, rng_key)
        vv_state = vv_init(z, r)
        energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad,
                             vv_state.potential_energy, energy, 0, 0., 0.,
                             False, wa_state, rng_key_hmc)
        return hmc_state