Example #1
 def body_fn(i, val):
     idx = idxs[i]
     support_size = support_sizes_flat[idx]
     rng_key, z, pe = val
     rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
         rng_key,
         z,
         pe,
         potential_fn=potential_fn,
         idx=idx,
         support_size=support_size)
     rng_key, rng_accept = random.split(rng_key)
     # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
     # and -log(u) ~ exponential(1)
     z, pe = cond(
         random.exponential(rng_accept) > -log_accept_ratio,
         (z_new, pe_new), identity, (z, pe), identity)
     return rng_key, z, pe
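
A quick sanity check, not part of the original snippet, for the exponential-sampling identity in the comment above: for the same draw u ~ Uniform(0, 1), the test u < accept_ratio holds exactly when -log(u) > -log_accept_ratio, and -log(u) follows an Exponential(1) distribution.

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
u = random.uniform(key)              # u ~ Uniform(0, 1)
log_accept_ratio = jnp.log(0.3)      # hypothetical acceptance ratio of 0.3
# the two acceptance tests are equivalent for the same underlying draw
assert (u < jnp.exp(log_accept_ratio)) == (-jnp.log(u) > -log_accept_ratio)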
Example #2
def _discrete_gibbs_proposal_body_fn(z_init_flat, unravel_fn, pe_init,
                                     potential_fn, idx, i, val):
    rng_key, z, pe, log_weight_sum = val
    rng_key, rng_transition = random.split(rng_key)
    proposal = jnp.where(i >= z_init_flat[idx], i + 1, i)
    z_new_flat = ops.index_update(z_init_flat, idx, proposal)
    z_new = unravel_fn(z_new_flat)
    pe_new = potential_fn(z_new)
    log_weight_new = pe_init - pe_new
    # Handles the NaN case...
    log_weight_new = jnp.where(jnp.isfinite(log_weight_new), log_weight_new,
                               -jnp.inf)
    # transition_prob = e^log_weight_new / (e^log_weight_sum + e^log_weight_new)
    transition_prob = expit(log_weight_new - log_weight_sum)
    z, pe = cond(random.bernoulli(rng_transition, transition_prob),
                 (z_new, pe_new), identity, (z, pe), identity)
    log_weight_sum = jnp.logaddexp(log_weight_new, log_weight_sum)
    return rng_key, z, pe, log_weight_sum
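
A small check, not from the original source, of the expit identity used in the transition_prob comment: expit(a - b) = e^a / (e^a + e^b), which is why accumulating log_weight_sum with jnp.logaddexp yields a sequential draw proportional to the weights.

import jax.numpy as jnp
from jax.scipy.special import expit

a, b = jnp.log(2.0), jnp.log(3.0)   # hypothetical log weights
lhs = expit(a - b)                  # streaming form used above
rhs = jnp.exp(a) / (jnp.exp(a) + jnp.exp(b))
assert jnp.allclose(lhs, rhs)       # both equal 2 / (2 + 3) = 0.4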
Example #3
    def sample(self, state, model_args, model_kwargs):
        model_kwargs = {} if model_kwargs is None else model_kwargs.copy()
        rng_key, rng_gibbs = random.split(state.rng_key)

        def potential_fn(z_gibbs, gibbs_state, z_hmc):
            return self.inner_kernel._potential_fn_gen(
                *model_args,
                _gibbs_sites=z_gibbs,
                _gibbs_state=gibbs_state,
                **model_kwargs,
            )(z_hmc)

        z_gibbs = {
            k: v
            for k, v in state.z.items() if k not in state.hmc_state.z
        }
        z_gibbs_new, gibbs_state_new = self._gibbs_update(
            rng_key, z_gibbs, state.gibbs_state)

        # given a fixed hmc_sites, pe - pe_new = loglik_new - loglik_curr
        pe = state.hmc_state.potential_energy
        pe_new = potential_fn(z_gibbs_new, gibbs_state_new, state.hmc_state.z)
        accept_prob = jnp.clip(jnp.exp(pe - pe_new), a_max=1.0)
        transition = random.bernoulli(rng_key, accept_prob)
        grad_ = jacfwd if self.inner_kernel._forward_mode_differentiation else grad
        z_gibbs, gibbs_state, pe, z_grad = cond(
            transition,
            (z_gibbs_new, gibbs_state_new, pe_new),
            lambda vals: vals + (grad_(partial(potential_fn, vals[0], vals[1]))
                                 (state.hmc_state.z), ),
            (z_gibbs, state.gibbs_state, pe, state.hmc_state.z_grad),
            identity,
        )

        hmc_state = state.hmc_state._replace(z_grad=z_grad,
                                             potential_energy=pe)

        model_kwargs["_gibbs_sites"] = z_gibbs
        model_kwargs["_gibbs_state"] = gibbs_state
        hmc_state = self.inner_kernel.sample(hmc_state, model_args,
                                             model_kwargs)

        z = {**z_gibbs, **hmc_state.z}
        return HMCECSState(z, hmc_state, rng_key, gibbs_state, accept_prob)
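
A minimal sketch (values hypothetical, not from the source) of why exp(pe - pe_new) is the Metropolis-Hastings ratio here: potential energy is the negative log density, so exp(pe - pe_new) = p_new / p_curr, clipped at 1.

import jax.numpy as jnp

p_curr, p_new = 0.2, 0.5                          # hypothetical densities
pe, pe_new = -jnp.log(p_curr), -jnp.log(p_new)    # potential energy = -log density
accept_prob = jnp.clip(jnp.exp(pe - pe_new), a_max=1.0)
assert jnp.allclose(accept_prob, min(p_new / p_curr, 1.0))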
Example #4
 def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng_key):
     num_steps = _get_num_steps(step_size, trajectory_len)
     vv_state_new = fori_loop(
         0, num_steps,
         lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
         vv_state)
     energy_old = vv_state.potential_energy + kinetic_fn(
         inverse_mass_matrix, vv_state.r)
     energy_new = vv_state_new.potential_energy + kinetic_fn(
         inverse_mass_matrix, vv_state_new.r)
     delta_energy = energy_new - energy_old
     delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
     accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
     diverging = delta_energy > max_delta_energy
     transition = random.bernoulli(rng_key, accept_prob)
     vv_state, energy = cond(transition, (vv_state_new, energy_new),
                             lambda args: args, (vv_state, energy_old),
                             lambda args: args)
     return vv_state, energy, num_steps, accept_prob, diverging
Example #5
    def sample(self, state, model_args, model_kwargs):
        """
        Given the current `state`, return the next `state` using the given
        transition kernel.

        :param state: A `pytree <https://jax.readthedocs.io/en/latest/pytrees.html>`_
            class representing the state for the kernel. For HMC, this is given
            by :data:`~numpyro.infer.hmc.HMCState`. In general, this could be any
            class that supports `getattr`.
        :param model_args: Arguments provided to the model.
        :param model_kwargs: Keyword arguments provided to the model.
        :return: Next `state`.
        """
        return cond(state.itr % 2 == 0,
                    state,
                    lambda s: self._kernel0.sample(
                        s, model_args, model_kwargs),
                    state,
                    lambda s: self._kernel1.sample(s, model_args, model_kwargs))
Example #6
    def sample_kernel(hmc_state, model_args=(), model_kwargs=None):
        """
        Given an existing :data:`~numpyro.infer.mcmc.HMCState`, run HMC with fixed (possibly adapted)
        step size and return a new :data:`~numpyro.infer.mcmc.HMCState`.

        :param hmc_state: Current sample (and associated state).
        :param tuple model_args: Model arguments if `potential_fn_gen` is specified.
        :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified.
        :return: new proposed :data:`~numpyro.infer.mcmc.HMCState` from simulating
            Hamiltonian dynamics given existing state.

        """
        model_kwargs = {} if model_kwargs is None else model_kwargs
        rng_key, rng_key_momentum, rng_key_transition = random.split(
            hmc_state.rng_key, 3)
        r = momentum_generator(hmc_state.z,
                               hmc_state.adapt_state.mass_matrix_sqrt,
                               rng_key_momentum)
        vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy,
                                   hmc_state.z_grad)
        vv_state, energy, num_steps, accept_prob, diverging = _next(
            hmc_state.adapt_state.step_size,
            hmc_state.adapt_state.inverse_mass_matrix, vv_state, model_args,
            model_kwargs, rng_key_transition)
        # do not update adapt_state after the warmup phase
        adapt_state = cond(
            hmc_state.i < wa_steps,
            (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state),
            lambda args: wa_update(*args), hmc_state.adapt_state, identity)

        itr = hmc_state.i + 1
        n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
        mean_accept_prob = hmc_state.mean_accept_prob + (
            accept_prob - hmc_state.mean_accept_prob) / n

        return HMCState(itr, vv_state.z, vv_state.z_grad,
                        vv_state.potential_energy, energy, num_steps,
                        accept_prob, mean_accept_prob, diverging, adapt_state,
                        rng_key)
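
A short check, outside the original code, that the incremental update for mean_accept_prob above is the standard streaming mean: m_n = m_{n-1} + (x_n - m_{n-1}) / n.

import jax.numpy as jnp

xs = jnp.array([0.9, 0.7, 0.8, 1.0])    # hypothetical accept_prob values
mean = 0.0
for n, x in enumerate(xs, start=1):
    mean = mean + (x - mean) / n        # same update as mean_accept_prob above
assert jnp.allclose(mean, xs.mean())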
Example #7
    def _hmc_next(step_size, inverse_mass_matrix, vv_state,
                  model_args, model_kwargs, rng_key):
        if potential_fn_gen:
            nonlocal vv_update
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
            _, vv_update = velocity_verlet(pe_fn, kinetic_fn)

        num_steps = _get_num_steps(step_size, trajectory_len)
        vv_state_new = fori_loop(0, num_steps,
                                 lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
                                 vv_state)
        energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
        energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
        delta_energy = energy_new - energy_old
        delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
        accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
        diverging = delta_energy > max_delta_energy
        transition = random.bernoulli(rng_key, accept_prob)
        vv_state, energy = cond(transition,
                                (vv_state_new, energy_new), identity,
                                (vv_state, energy_old), identity)
        return vv_state, energy, num_steps, accept_prob, diverging
Example #8
    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        assert set(gibbs_sites) == set(plate_sizes)
        u_new = {}
        for name in gibbs_sites:
            size, subsample_size = plate_sizes[name]
            rng_key, subkey, block_key = random.split(rng_key, 3)
            block_size = subsample_size // num_blocks

            chosen_block = random.randint(block_key,
                                          shape=(),
                                          minval=0,
                                          maxval=num_blocks)
            new_idx = random.randint(subkey,
                                     minval=0,
                                     maxval=size,
                                     shape=(subsample_size, ))
            block_mask = jnp.arange(
                subsample_size) // block_size == chosen_block

            u_new[name] = jnp.where(block_mask, new_idx, gibbs_sites[name])

        u_loglik = log_likelihood(_wrap_model(model),
                                  hmc_sites,
                                  *model_args,
                                  batch_ndims=0,
                                  **model_kwargs,
                                  _gibbs_sites=gibbs_sites)
        u_loglik = sum(v.sum() for v in u_loglik.values())
        u_new_loglik = log_likelihood(_wrap_model(model),
                                      hmc_sites,
                                      *model_args,
                                      batch_ndims=0,
                                      **model_kwargs,
                                      _gibbs_sites=u_new)
        u_new_loglik = sum(v.sum() for v in u_new_loglik.values())
        accept_prob = jnp.clip(jnp.exp(u_new_loglik - u_loglik), a_max=1.0)
        return cond(random.bernoulli(rng_key, accept_prob), u_new, identity,
                    gibbs_sites, identity)
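
A concrete illustration, with made-up sizes, of the block mask construction above: indices are grouped into contiguous blocks of block_size, and only the chosen block receives fresh subsample indices.

import jax.numpy as jnp

subsample_size, num_blocks = 8, 4
block_size = subsample_size // num_blocks            # 2
chosen_block = 1                                     # hypothetical draw
block_mask = jnp.arange(subsample_size) // block_size == chosen_block
# only positions 2 and 3 (the second block) are refreshed
assert (block_mask == jnp.array(
    [False, False, True, True, False, False, False, False])).all()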
Example #9
    def _body_fn(state):
        current_tree, _, r_ckpts, r_sum_ckpts, rng = state
        rng, transition_rng = random.split(rng)
        z, r, z_grad = _get_leaf(current_tree, going_right)
        new_leaf = _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size,
                                   going_right, energy_current, max_delta_energy)
        new_tree = _combine_tree(current_tree, new_leaf, inverse_mass_matrix, going_right,
                                 transition_rng, False)

        leaf_idx = current_tree.num_proposals
        ckpt_idx_min, ckpt_idx_max = _leaf_idx_to_ckpt_idxs(leaf_idx)
        r, _ = ravel_pytree(new_leaf.r_right)
        r_sum, _ = ravel_pytree(new_tree.r_sum)
        # we update checkpoints when leaf_idx is even
        r_ckpts, r_sum_ckpts = cond(leaf_idx % 2 == 0,
                                    (r_ckpts, r_sum_ckpts),
                                    lambda x: (index_update(x[0], ckpt_idx_max, r),
                                               index_update(x[1], ckpt_idx_max, r_sum)),
                                    (r_ckpts, r_sum_ckpts),
                                    lambda x: x)

        turning = _is_iterative_turning(inverse_mass_matrix, r, r_sum, r_ckpts, r_sum_ckpts,
                                        ckpt_idx_min, ckpt_idx_max)
        return new_tree, turning, r_ckpts, r_sum_ckpts, rng
Example #10
def _discrete_modified_gibbs_proposal(rng_key,
                                      z_discrete,
                                      pe,
                                      potential_fn,
                                      idx,
                                      support_size,
                                      stay_prob=0.):
    assert isinstance(stay_prob, float) and stay_prob >= 0. and stay_prob < 1
    z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
    body_fn = partial(_discrete_gibbs_proposal_body_fn, z_discrete_flat,
                      unravel_fn, pe, potential_fn, idx)
    # like gibbs_step, but here the weight of the current value is 0
    init_val = (rng_key, z_discrete, pe, jnp.array(-jnp.inf))
    rng_key, z_new, pe_new, log_weight_sum = fori_loop(0, support_size - 1,
                                                       body_fn, init_val)
    rng_key, rng_stay = random.split(rng_key)
    z_new, pe_new = cond(random.bernoulli(rng_stay, stay_prob),
                         (z_discrete, pe), identity, (z_new, pe_new), identity)
    # here we calculate the MH correction: (1 - P(z)) / (1 - P(z_new))
    # where 1 - P(z) ~ weight_sum
    # and 1 - P(z_new) ~ 1 + weight_sum - z_new_weight
    log_accept_ratio = log_weight_sum - jnp.log(
        jnp.exp(log_weight_sum) - jnp.expm1(pe - pe_new))
    return rng_key, z_new, pe_new, log_accept_ratio
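
A small identity check, not in the source, for the expm1 form of the denominator above: with the selected value's weight w_new = exp(pe - pe_new), the quantity exp(log_weight_sum) - expm1(pe - pe_new) equals 1 + weight_sum - z_new_weight, matching the comment.

import jax.numpy as jnp

pe, pe_new = 1.3, 0.9                        # hypothetical potential energies
weight_sum = 2.5                             # hypothetical sum of proposal weights
w_new = jnp.exp(pe - pe_new)                 # weight of the selected value
lhs = jnp.exp(jnp.log(weight_sum)) - jnp.expm1(pe - pe_new)
assert jnp.allclose(lhs, 1.0 + weight_sum - w_new)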
Example #11
def _get_leaf(tree, going_right):
    return cond(going_right, tree, lambda tree:
                (tree.z_right, tree.r_right, tree.z_right_grad), tree,
                lambda tree: (tree.z_left, tree.r_left, tree.z_left_grad))
Example #12
    def sample(self, state, model_args, model_kwargs):
        model_kwargs = {} if model_kwargs is None else model_kwargs
        num_discretes = self._support_sizes_flat.shape[0]

        def potential_fn(z_gibbs, z_hmc):
            return self.inner_kernel._potential_fn_gen(*model_args,
                                                       _gibbs_sites=z_gibbs,
                                                       **model_kwargs)(z_hmc)

        def update_discrete(idx, rng_key, hmc_state, z_discrete, ke_discrete,
                            delta_pe_sum):
            # Algo 1, line 19: get a new discrete proposal
            (
                rng_key,
                z_discrete_new,
                pe_new,
                log_accept_ratio,
            ) = self._discrete_proposal_fn(
                rng_key,
                z_discrete,
                hmc_state.potential_energy,
                partial(potential_fn, z_hmc=hmc_state.z),
                idx,
                self._support_sizes_flat[idx],
            )
            # Algo 1, line 20: depending on whether we reject or refract, we update
            # the discrete variable and its corresponding kinetic energy. In case of
            # refraction, we also update the potential energy and its grad w.r.t. hmc_state.z
            ke_discrete_i_new = ke_discrete[idx] + log_accept_ratio
            grad_ = jacfwd if self.inner_kernel._forward_mode_differentiation else grad
            z_discrete, pe, ke_discrete_i, z_grad = lax.cond(
                ke_discrete_i_new > 0,
                (z_discrete_new, pe_new, ke_discrete_i_new),
                lambda vals: vals + (grad_(partial(potential_fn, vals[0]))
                                     (hmc_state.z), ),
                (
                    z_discrete,
                    hmc_state.potential_energy,
                    ke_discrete[idx],
                    hmc_state.z_grad,
                ),
                identity,
            )

            delta_pe_sum = delta_pe_sum + pe - hmc_state.potential_energy
            ke_discrete = ops.index_update(ke_discrete, idx, ke_discrete_i)
            hmc_state = hmc_state._replace(potential_energy=pe, z_grad=z_grad)
            return rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum

        def update_continuous(hmc_state, z_discrete):
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            hmc_state_new = self.inner_kernel.sample(hmc_state, model_args,
                                                     model_kwargs_)

            # each time a sub-trajectory is performed, we need to reset i and adapt_state
            # (we will only update them at the end of HMCGibbs step)
            # For `num_steps`, we will record its cumulative sum for diagnostics
            hmc_state = hmc_state_new._replace(
                i=hmc_state.i,
                adapt_state=hmc_state.adapt_state,
                num_steps=hmc_state.num_steps + hmc_state_new.num_steps,
            )
            return hmc_state

        def body_fn(i, vals):
            (
                rng_key,
                hmc_state,
                z_discrete,
                ke_discrete,
                delta_pe_sum,
                arrival_times,
            ) = vals
            idx = jnp.argmin(arrival_times)
            # NB: length of each sub-trajectory is scaled from the current min(arrival_times)
            # (see the note at total_time below)
            trajectory_length = arrival_times[idx] * time_unit
            arrival_times = arrival_times - arrival_times[idx]
            arrival_times = ops.index_update(arrival_times, idx, 1.0)

            # this is a trick so that, in a sub-trajectory of HMC, we always accept the new proposal
            pe = jnp.inf
            hmc_state = hmc_state._replace(trajectory_length=trajectory_length,
                                           potential_energy=pe)
            # Algo 1, line 7: perform a sub-trajectory
            hmc_state = update_continuous(hmc_state, z_discrete)
            # Algo 1, line 8: perform a discrete update
            rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum = update_discrete(
                idx, rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum)
            return (
                rng_key,
                hmc_state,
                z_discrete,
                ke_discrete,
                delta_pe_sum,
                arrival_times,
            )

        z_discrete = {
            k: v
            for k, v in state.z.items() if k not in state.hmc_state.z
        }
        rng_key, rng_ke, rng_time, rng_r, rng_accept = random.split(
            state.rng_key, 5)
        # Algo 1, line 2: sample discrete kinetic energy
        ke_discrete = random.exponential(rng_ke, (num_discretes, ))
        # Algo 1, lines 4 and 5: sample the initial amount of time that each discrete site visits
        # the point 0/1. The logic in GetStepSizesNSteps(...) is more complicated but does
        # the same job: the sub-trajectory length eta_t * M_t is the lag between two arrival times.
        arrival_times = random.uniform(rng_time, (num_discretes, ))
        # compute the amount of time to make `num_discrete_updates` discrete updates
        total_time = (self._num_discrete_updates -
                      1) // num_discretes + jnp.sort(arrival_times)[
                          (self._num_discrete_updates - 1) % num_discretes]
        # NB: total_time can be different from the HMC trajectory length, so we need to scale
        # the time unit so that total_time * time_unit = hmc_trajectory_length
        time_unit = state.hmc_state.trajectory_length / total_time

        # Algo 1, line 2: sample hmc momentum
        r = momentum_generator(state.hmc_state.r,
                               state.hmc_state.adapt_state.mass_matrix_sqrt,
                               rng_r)
        hmc_state = state.hmc_state._replace(r=r, num_steps=0)
        hmc_ke = euclidean_kinetic_energy(
            hmc_state.adapt_state.inverse_mass_matrix, r)
        # Algo 1, line 10: compute the initial energy
        energy_old = hmc_ke + hmc_state.potential_energy

        # Algo 1, line 3: set initial values
        delta_pe_sum = 0.0
        init_val = (
            rng_key,
            hmc_state,
            z_discrete,
            ke_discrete,
            delta_pe_sum,
            arrival_times,
        )
        # Algo 1, line 6-9: perform the update loop
        rng_key, hmc_state_new, z_discrete_new, _, delta_pe_sum, _ = fori_loop(
            0, self._num_discrete_updates, body_fn, init_val)
        # Algo 1, line 10: compute the proposal energy
        hmc_ke = euclidean_kinetic_energy(
            hmc_state.adapt_state.inverse_mass_matrix, hmc_state_new.r)
        energy_new = hmc_ke + hmc_state_new.potential_energy
        # Algo 1, line 11: perform MH correction
        delta_energy = energy_new - energy_old - delta_pe_sum
        delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf,
                                 delta_energy)
        accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)

        # record the correct new num_steps
        hmc_state = hmc_state._replace(num_steps=hmc_state_new.num_steps)
        # reset the trajectory length
        hmc_state_new = hmc_state_new._replace(
            trajectory_length=hmc_state.trajectory_length)
        hmc_state, z_discrete = cond(
            random.bernoulli(rng_key, accept_prob),
            (hmc_state_new, z_discrete_new),
            identity,
            (hmc_state, z_discrete),
            identity,
        )

        # perform HMC adaptation (similar to the implementation in hmc)
        adapt_state = cond(
            hmc_state.i < self._num_warmup,
            (hmc_state.i, accept_prob, (hmc_state.z, ), hmc_state.adapt_state),
            lambda args: self._wa_update(*args),
            hmc_state.adapt_state,
            identity,
        )

        itr = hmc_state.i + 1
        n = jnp.where(hmc_state.i < self._num_warmup, itr,
                      itr - self._num_warmup)
        mean_accept_prob_prev = state.hmc_state.mean_accept_prob
        mean_accept_prob = (mean_accept_prob_prev +
                            (accept_prob - mean_accept_prob_prev) / n)
        hmc_state = hmc_state._replace(
            i=itr,
            accept_prob=accept_prob,
            mean_accept_prob=mean_accept_prob,
            adapt_state=adapt_state,
        )

        z = {**z_discrete, **hmc_state.z}
        return MixedHMCState(z, hmc_state, rng_key, accept_prob)
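
A sanity check, with hypothetical values, that the total_time formula above marks the time of the num_discrete_updates-th arrival when each site re-arrives exactly one time unit after its previous arrival (as body_fn enforces by resetting arrival_times[idx] to 1.0):

import jax.numpy as jnp

num_discretes, num_discrete_updates = 3, 7
arrival_times = jnp.array([0.2, 0.5, 0.9])   # hypothetical initial arrivals
total_time = (num_discrete_updates - 1) // num_discretes + jnp.sort(arrival_times)[
    (num_discrete_updates - 1) % num_discretes]
# site i arrives at arrival_times[i] + k for k = 0, 1, 2, ...
all_arrivals = jnp.sort((arrival_times[None, :] + jnp.arange(10)[:, None]).ravel())
assert jnp.allclose(all_arrivals[num_discrete_updates - 1], total_time)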
Example #13
def _combine_tree(current_tree, new_tree, inverse_mass_matrix, going_right,
                  rng_key, biased_transition):
    # Now we combine the current tree and the new tree. Note that the outer
    # leaves of the combined tree are determined by the direction.
    z_left, r_left, z_left_grad, z_right, r_right, r_right_grad = cond(
        going_right,
        (current_tree, new_tree),
        lambda trees: (
            trees[0].z_left,
            trees[0].r_left,
            trees[0].z_left_grad,
            trees[1].z_right,
            trees[1].r_right,
            trees[1].z_right_grad,
        ),
        (new_tree, current_tree),
        lambda trees: (
            trees[0].z_left,
            trees[0].r_left,
            trees[0].z_left_grad,
            trees[1].z_right,
            trees[1].r_right,
            trees[1].z_right_grad,
        ),
    )
    r_sum = tree_multimap(jnp.add, current_tree.r_sum, new_tree.r_sum)

    if biased_transition:
        transition_prob = _biased_transition_kernel(current_tree, new_tree)
        turning = new_tree.turning | _is_turning(inverse_mass_matrix, r_left,
                                                 r_right, r_sum)
    else:
        transition_prob = _uniform_transition_kernel(current_tree, new_tree)
        turning = current_tree.turning

    transition = random.bernoulli(rng_key, transition_prob)
    z_proposal, z_proposal_pe, z_proposal_grad, z_proposal_energy = cond(
        transition,
        new_tree,
        lambda tree: (
            tree.z_proposal,
            tree.z_proposal_pe,
            tree.z_proposal_grad,
            tree.z_proposal_energy,
        ),
        current_tree,
        lambda tree: (
            tree.z_proposal,
            tree.z_proposal_pe,
            tree.z_proposal_grad,
            tree.z_proposal_energy,
        ),
    )

    tree_depth = current_tree.depth + 1
    tree_weight = jnp.logaddexp(current_tree.weight, new_tree.weight)
    diverging = new_tree.diverging

    sum_accept_probs = current_tree.sum_accept_probs + new_tree.sum_accept_probs
    num_proposals = current_tree.num_proposals + new_tree.num_proposals

    return TreeInfo(
        z_left,
        r_left,
        z_left_grad,
        z_right,
        r_right,
        r_right_grad,
        z_proposal,
        z_proposal_pe,
        z_proposal_grad,
        z_proposal_energy,
        tree_depth,
        tree_weight,
        r_sum,
        turning,
        diverging,
        sum_accept_probs,
        num_proposals,
    )
Example #14
    def update_fn(t, accept_prob, z_info, state):
        """
        :param int t: The current time step.
        :param float accept_prob: Acceptance probability of the current trajectory.
        :param IntegratorState z_info: The new integrator state.
        :param state: Current state of the adapt scheme.
        :return: new state of the adapt scheme.
        """
        (
            step_size,
            inverse_mass_matrix,
            mass_matrix_sqrt,
            mass_matrix_sqrt_inv,
            ss_state,
            mm_state,
            window_idx,
            rng_key,
        ) = state
        if rng_key is not None:
            rng_key, rng_key_ss = random.split(rng_key)
        else:
            rng_key_ss = None

        # update step size state
        if adapt_step_size:
            ss_state = ss_update(target_accept_prob - accept_prob, ss_state)
            # note: at the end of the warmup phase, use the average of log step_size
            log_step_size, log_step_size_avg, *_ = ss_state
            step_size = jnp.where(
                t == (num_adapt_steps - 1),
                jnp.exp(log_step_size_avg),
                jnp.exp(log_step_size),
            )
            # account for the case where log_step_size is an extreme number
            finfo = jnp.finfo(jnp.result_type(step_size))
            step_size = jnp.clip(step_size, a_min=finfo.tiny, a_max=finfo.max)

        # update mass matrix state
        is_middle_window = (0 < window_idx) & (window_idx < (num_windows - 1))
        if adapt_mass_matrix:
            z = z_info[0]
            mm_state = cond(
                is_middle_window,
                (z, mm_state),
                lambda args: mm_update(*args),
                mm_state,
                identity,
            )

        t_at_window_end = t == adaptation_schedule[window_idx, 1]
        window_idx = jnp.where(t_at_window_end, window_idx + 1, window_idx)
        state = HMCAdaptState(
            step_size,
            inverse_mass_matrix,
            mass_matrix_sqrt,
            mass_matrix_sqrt_inv,
            ss_state,
            mm_state,
            window_idx,
            rng_key,
        )
        state = cond(
            t_at_window_end & is_middle_window,
            (z_info, rng_key_ss, state),
            lambda args: _update_at_window_end(*args),
            state,
            identity,
        )
        return state
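
A tiny check, not from the source, of the finfo clipping above: an overflowing exp(log_step_size) is pulled back into the finite range of the dtype instead of propagating inf.

import jax.numpy as jnp

finfo = jnp.finfo(jnp.result_type(1.0))
step_size = jnp.clip(jnp.exp(jnp.array(1000.0)), a_min=finfo.tiny, a_max=finfo.max)
assert jnp.isfinite(step_size) and step_size == finfo.max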