def body_fn(i, val):
    idx = idxs[i]
    support_size = support_sizes_flat[idx]
    rng_key, z, pe = val
    rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
        rng_key, z, pe, potential_fn=potential_fn, idx=idx, support_size=support_size)
    rng_key, rng_accept = random.split(rng_key)
    # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
    # and -log(u) ~ exponential(1)
    z, pe = cond(random.exponential(rng_accept) > -log_accept_ratio,
                 (z_new, pe_new), identity,
                 (z, pe), identity)
    return rng_key, z, pe
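
# A minimal sketch (separate from `body_fn` above) of the exponential-variate
# acceptance trick: for u ~ Uniform(0, 1), the test u < accept_ratio is
# equivalent to -log(u) > -log_accept_ratio, and -log(u) ~ Exponential(1),
# so the whole test can be done in log space without forming accept_ratio.
import jax.numpy as jnp
from jax import random, vmap

def mh_accept(rng_key, log_accept_ratio):
    # the Exponential(1) draw plays the role of -log(u)
    return random.exponential(rng_key) > -log_accept_ratio

keys = random.split(random.PRNGKey(0), 10_000)
accept_rate = vmap(lambda k: mh_accept(k, jnp.log(0.3)))(keys).mean()
# accept_rate is close to 0.3, matching the classic u < accept_ratio test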
def _discrete_gibbs_proposal_body_fn(z_init_flat, unravel_fn, pe_init, potential_fn,
                                     idx, i, val):
    rng_key, z, pe, log_weight_sum = val
    rng_key, rng_transition = random.split(rng_key)
    proposal = jnp.where(i >= z_init_flat[idx], i + 1, i)
    z_new_flat = ops.index_update(z_init_flat, idx, proposal)
    z_new = unravel_fn(z_new_flat)
    pe_new = potential_fn(z_new)
    log_weight_new = pe_init - pe_new
    # Handles the NaN case...
    log_weight_new = jnp.where(jnp.isfinite(log_weight_new), log_weight_new, -jnp.inf)
    # transition_prob = e^weight_new / (e^weight_logsumexp + e^weight_new)
    transition_prob = expit(log_weight_new - log_weight_sum)
    z, pe = cond(random.bernoulli(rng_transition, transition_prob),
                 (z_new, pe_new), identity,
                 (z, pe), identity)
    log_weight_sum = jnp.logaddexp(log_weight_new, log_weight_sum)
    return rng_key, z, pe, log_weight_sum
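
# A minimal standalone sketch (with hypothetical names) of the streaming
# categorical draw that the body function above implements: keep a running
# logsumexp of the weights seen so far and switch to candidate i with
# probability w_i / (w_1 + ... + w_i) = expit(log w_i - log(w_1 + ... + w_{i-1})).
# One pass leaves the kept index distributed proportionally to the weights.
import jax.numpy as jnp
from jax import lax, random, vmap
from jax.scipy.special import expit

def streaming_categorical(rng_key, log_weights):
    def body(i, val):
        rng_key, chosen, log_weight_sum = val
        rng_key, rng_switch = random.split(rng_key)
        switch_prob = expit(log_weights[i] - log_weight_sum)
        chosen = jnp.where(random.bernoulli(rng_switch, switch_prob), i, chosen)
        return rng_key, chosen, jnp.logaddexp(log_weights[i], log_weight_sum)

    init = (rng_key, jnp.array(0), log_weights[0])
    return lax.fori_loop(1, log_weights.shape[0], body, init)[1]

log_w = jnp.log(jnp.array([0.1, 0.2, 0.7]))
keys = random.split(random.PRNGKey(0), 10_000)
samples = vmap(lambda k: streaming_categorical(k, log_w))(keys)
# empirical frequencies of 0, 1, 2 approach 0.1, 0.2, 0.7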
def sample(self, state, model_args, model_kwargs):
    model_kwargs = {} if model_kwargs is None else model_kwargs.copy()
    rng_key, rng_gibbs = random.split(state.rng_key)

    def potential_fn(z_gibbs, gibbs_state, z_hmc):
        return self.inner_kernel._potential_fn_gen(
            *model_args,
            _gibbs_sites=z_gibbs,
            _gibbs_state=gibbs_state,
            **model_kwargs,
        )(z_hmc)

    z_gibbs = {k: v for k, v in state.z.items() if k not in state.hmc_state.z}
    z_gibbs_new, gibbs_state_new = self._gibbs_update(rng_key, z_gibbs,
                                                      state.gibbs_state)

    # given a fixed hmc_sites, pe_new - pe_curr = loglik_new - loglik_curr
    pe = state.hmc_state.potential_energy
    pe_new = potential_fn(z_gibbs_new, gibbs_state_new, state.hmc_state.z)
    accept_prob = jnp.clip(jnp.exp(pe - pe_new), a_max=1.0)
    transition = random.bernoulli(rng_key, accept_prob)
    grad_ = jacfwd if self.inner_kernel._forward_mode_differentiation else grad
    z_gibbs, gibbs_state, pe, z_grad = cond(
        transition,
        (z_gibbs_new, gibbs_state_new, pe_new),
        lambda vals: vals + (grad_(partial(potential_fn, vals[0], vals[1]))
                             (state.hmc_state.z),),
        (z_gibbs, state.gibbs_state, pe, state.hmc_state.z_grad),
        identity,
    )

    hmc_state = state.hmc_state._replace(z_grad=z_grad, potential_energy=pe)
    model_kwargs["_gibbs_sites"] = z_gibbs
    model_kwargs["_gibbs_state"] = gibbs_state
    hmc_state = self.inner_kernel.sample(hmc_state, model_args, model_kwargs)

    z = {**z_gibbs, **hmc_state.z}
    return HMCECSState(z, hmc_state, rng_key, gibbs_state, accept_prob)
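
# A minimal standalone sketch (hypothetical helper, not NumPyro API) of the
# accept/reject pattern above: with potential energy pe = -log p(z), the clipped
# ratio exp(pe - pe_new) = p(z_new) / p(z) is the MH acceptance probability, and
# the gradient is recomputed only on acceptance. Simplified in two ways: the
# gradient is taken w.r.t. the proposal itself (the `sample` above differentiates
# w.r.t. the fixed HMC sites), and the current jax.lax.cond signature is used.
import jax.numpy as jnp
from jax import grad, lax, random

def mh_step(rng_key, potential_fn, z, pe, z_grad, z_new):
    pe_new = potential_fn(z_new)
    accept_prob = jnp.clip(jnp.exp(pe - pe_new), 0.0, 1.0)

    def accept(_):
        return z_new, pe_new, grad(potential_fn)(z_new)

    def reject(_):
        return z, pe, z_grad

    return lax.cond(random.bernoulli(rng_key, accept_prob), accept, reject, None)

quad = lambda z: 0.5 * jnp.sum(z ** 2)  # grad is z itself for this potential
z = jnp.ones(2)
z_out, pe_out, grad_out = mh_step(random.PRNGKey(0), quad, z, quad(z), z, z + 0.1)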
def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng_key):
    num_steps = _get_num_steps(step_size, trajectory_len)
    vv_state_new = fori_loop(0, num_steps,
                             lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
                             vv_state)
    energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix,
                                                        vv_state.r)
    energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix,
                                                            vv_state_new.r)
    delta_energy = energy_new - energy_old
    delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
    accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
    diverging = delta_energy > max_delta_energy
    transition = random.bernoulli(rng_key, accept_prob)
    vv_state, energy = cond(transition,
                            (vv_state_new, energy_new), lambda args: args,
                            (vv_state, energy_old), lambda args: args)
    return vv_state, energy, num_steps, accept_prob, diverging
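
# A minimal sketch of the energy-difference bookkeeping above: a NaN energy
# difference (e.g. from a numerically unstable trajectory) is mapped to +inf,
# so the proposal is rejected with probability 1 and flagged as diverging.
import jax.numpy as jnp

def accept_prob_and_divergence(energy_old, energy_new, max_delta_energy=1000.0):
    delta_energy = energy_new - energy_old
    delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
    accept_prob = jnp.clip(jnp.exp(-delta_energy), 0.0, 1.0)
    return accept_prob, delta_energy > max_delta_energy

# accept_prob_and_divergence(10.0, jnp.nan) evaluates to (0.0, True)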
def sample(self, state, model_args, model_kwargs):
    """
    Given the current `state`, return the next `state` using the given
    transition kernel.

    :param state: A `pytree <https://jax.readthedocs.io/en/latest/pytrees.html>`_
        class representing the state for the kernel. For HMC, this is given by
        :data:`~numpyro.infer.hmc.HMCState`. In general, this could be any
        class that supports `getattr`.
    :param model_args: Arguments provided to the model.
    :param model_kwargs: Keyword arguments provided to the model.
    :return: Next `state`.
    """
    return cond(state.itr % 2 == 0,
                state, lambda s: self._kernel0.sample(s, model_args, model_kwargs),
                state, lambda s: self._kernel1.sample(s, model_args, model_kwargs))
def sample_kernel(hmc_state, model_args=(), model_kwargs=None):
    """
    Given an existing :data:`~numpyro.infer.mcmc.HMCState`, run HMC with fixed
    (possibly adapted) step size and return a new :data:`~numpyro.infer.mcmc.HMCState`.

    :param hmc_state: Current sample (and associated state).
    :param tuple model_args: Model arguments if `potential_fn_gen` is specified.
    :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified.
    :return: new proposed :data:`~numpyro.infer.mcmc.HMCState` from simulating
        Hamiltonian dynamics given existing state.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    rng_key, rng_key_momentum, rng_key_transition = random.split(
        hmc_state.rng_key, 3)
    r = momentum_generator(hmc_state.z, hmc_state.adapt_state.mass_matrix_sqrt,
                           rng_key_momentum)
    vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy,
                               hmc_state.z_grad)
    vv_state, energy, num_steps, accept_prob, diverging = _next(
        hmc_state.adapt_state.step_size,
        hmc_state.adapt_state.inverse_mass_matrix,
        vv_state,
        model_args,
        model_kwargs,
        rng_key_transition)
    # do not update adapt_state after the warmup phase
    adapt_state = cond(hmc_state.i < wa_steps,
                       (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state),
                       lambda args: wa_update(*args),
                       hmc_state.adapt_state,
                       identity)
    itr = hmc_state.i + 1
    n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
    mean_accept_prob = hmc_state.mean_accept_prob + (
        accept_prob - hmc_state.mean_accept_prob) / n
    return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy,
                    energy, num_steps, accept_prob, mean_accept_prob, diverging,
                    adapt_state, rng_key)
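
# A minimal usage sketch, assuming this `sample_kernel` is the one returned
# (together with `init_kernel`) by the `numpyro.infer.hmc.hmc` factory, and
# using a plain potential function so `model_args`/`model_kwargs` stay empty.
import jax.numpy as jnp
from numpyro.infer.hmc import hmc

def potential_fn(z):
    return 0.5 * jnp.sum(z ** 2)  # -log p(z) of a standard normal, up to a constant

init_kernel, sample_kernel = hmc(potential_fn, algo="HMC")
hmc_state = init_kernel(jnp.zeros(3), num_warmup=50)
for _ in range(200):
    # the first `num_warmup` transitions adapt step size and mass matrix,
    # per the `hmc_state.i < wa_steps` branch in `sample_kernel` above
    hmc_state = sample_kernel(hmc_state)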
def _hmc_next(step_size, inverse_mass_matrix, vv_state,
              model_args, model_kwargs, rng_key):
    if potential_fn_gen:
        nonlocal vv_update
        pe_fn = potential_fn_gen(*model_args, **model_kwargs)
        _, vv_update = velocity_verlet(pe_fn, kinetic_fn)

    num_steps = _get_num_steps(step_size, trajectory_len)
    vv_state_new = fori_loop(0, num_steps,
                             lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
                             vv_state)
    energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix,
                                                        vv_state.r)
    energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix,
                                                            vv_state_new.r)
    delta_energy = energy_new - energy_old
    delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
    accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
    diverging = delta_energy > max_delta_energy
    transition = random.bernoulli(rng_key, accept_prob)
    vv_state, energy = cond(transition,
                            (vv_state_new, energy_new), identity,
                            (vv_state, energy_old), identity)
    return vv_state, energy, num_steps, accept_prob, diverging
def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    assert set(gibbs_sites) == set(plate_sizes)
    u_new = {}
    for name in gibbs_sites:
        size, subsample_size = plate_sizes[name]
        rng_key, subkey, block_key = random.split(rng_key, 3)
        block_size = subsample_size // num_blocks
        chosen_block = random.randint(block_key, shape=(), minval=0,
                                      maxval=num_blocks)
        new_idx = random.randint(subkey, minval=0, maxval=size,
                                 shape=(subsample_size,))
        block_mask = jnp.arange(subsample_size) // block_size == chosen_block
        u_new[name] = jnp.where(block_mask, new_idx, gibbs_sites[name])

    u_loglik = log_likelihood(_wrap_model(model), hmc_sites, *model_args,
                              batch_ndims=0, **model_kwargs,
                              _gibbs_sites=gibbs_sites)
    u_loglik = sum(v.sum() for v in u_loglik.values())
    u_new_loglik = log_likelihood(_wrap_model(model), hmc_sites, *model_args,
                                  batch_ndims=0, **model_kwargs,
                                  _gibbs_sites=u_new)
    u_new_loglik = sum(v.sum() for v in u_new_loglik.values())
    accept_prob = jnp.clip(jnp.exp(u_new_loglik - u_loglik), a_max=1.0)
    return cond(random.bernoulli(rng_key, accept_prob),
                u_new, identity,
                gibbs_sites, identity)
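
# A minimal standalone sketch (hypothetical helper) of the block update above:
# only a randomly chosen block of the subsample indices is refreshed, which
# keeps the proposal local and its acceptance probability reasonably high.
import jax.numpy as jnp
from jax import random

def block_update(rng_key, idx, size, num_blocks):
    subsample_size = idx.shape[0]
    block_size = subsample_size // num_blocks
    block_key, idx_key = random.split(rng_key)
    chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks)
    new_idx = random.randint(idx_key, shape=(subsample_size,), minval=0, maxval=size)
    block_mask = jnp.arange(subsample_size) // block_size == chosen_block
    return jnp.where(block_mask, new_idx, idx)

idx = jnp.arange(10)  # current subsample indices into a size-1000 data set
new_idx = block_update(random.PRNGKey(0), idx, size=1000, num_blocks=5)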
def _body_fn(state):
    current_tree, _, r_ckpts, r_sum_ckpts, rng = state
    rng, transition_rng = random.split(rng)
    z, r, z_grad = _get_leaf(current_tree, going_right)
    new_leaf = _build_basetree(vv_update, kinetic_fn, z, r, z_grad,
                               inverse_mass_matrix, step_size, going_right,
                               energy_current, max_delta_energy)
    new_tree = _combine_tree(current_tree, new_leaf, inverse_mass_matrix,
                             going_right, transition_rng, False)

    leaf_idx = current_tree.num_proposals
    ckpt_idx_min, ckpt_idx_max = _leaf_idx_to_ckpt_idxs(leaf_idx)
    r, _ = ravel_pytree(new_leaf.r_right)
    r_sum, _ = ravel_pytree(new_tree.r_sum)
    # we update checkpoints when leaf_idx is even
    r_ckpts, r_sum_ckpts = cond(leaf_idx % 2 == 0,
                                (r_ckpts, r_sum_ckpts),
                                lambda x: (index_update(x[0], ckpt_idx_max, r),
                                           index_update(x[1], ckpt_idx_max, r_sum)),
                                (r_ckpts, r_sum_ckpts),
                                lambda x: x)

    turning = _is_iterative_turning(inverse_mass_matrix, r, r_sum, r_ckpts,
                                    r_sum_ckpts, ckpt_idx_min, ckpt_idx_max)
    return new_tree, turning, r_ckpts, r_sum_ckpts, rng
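
# A pure-Python sketch of the checkpoint indexing that `_leaf_idx_to_ckpt_idxs`
# is assumed to perform (its traced implementation is not shown in this
# section): idx_max counts the set bits of leaf_idx above the lowest bit, and
# idx_min backs off over the run of trailing set bits, i.e. over the completed
# subtrees that the new leaf closes off.
def leaf_idx_to_ckpt_idxs(n):
    idx_max = bin(n >> 1).count("1")  # set bits except the lowest one
    num_trailing_ones = len(bin(n)) - len(bin(n).rstrip("1"))
    return idx_max - num_trailing_ones + 1, idx_max

# e.g. leaf_idx_to_ckpt_idxs(6) == (3, 2), an empty range (no turning checks),
# while leaf_idx_to_ckpt_idxs(7) == (0, 2) checks against three checkpoints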
def _discrete_modified_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx,
                                      support_size, stay_prob=0.):
    assert isinstance(stay_prob, float) and stay_prob >= 0. and stay_prob < 1
    z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
    body_fn = partial(_discrete_gibbs_proposal_body_fn, z_discrete_flat,
                      unravel_fn, pe, potential_fn, idx)
    # like gibbs_step but here, weight of the current value is 0
    init_val = (rng_key, z_discrete, pe, jnp.array(-jnp.inf))
    rng_key, z_new, pe_new, log_weight_sum = fori_loop(0, support_size - 1,
                                                       body_fn, init_val)
    rng_key, rng_stay = random.split(rng_key)
    z_new, pe_new = cond(random.bernoulli(rng_stay, stay_prob),
                         (z_discrete, pe), identity,
                         (z_new, pe_new), identity)
    # here we calculate the MH correction: (1 - P(z)) / (1 - P(z_new))
    # where 1 - P(z) ~ weight_sum
    # and 1 - P(z_new) ~ 1 + weight_sum - z_new_weight
    log_accept_ratio = log_weight_sum - jnp.log(
        jnp.exp(log_weight_sum) - jnp.expm1(pe - pe_new))
    return rng_key, z_new, pe_new, log_accept_ratio
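
# A minimal sketch of the Metropolised-Gibbs correction above, written with an
# explicit sum over an enumerated support for readability: `pes` holds the
# potential energies of every support value, and `cur` / `new` index the
# current and proposed values. The acceptance ratio (1 - P(cur)) / (1 - P(new))
# is computed with unnormalized weights relative to the current value.
import jax.numpy as jnp

def enumerated_log_accept_ratio(pes, cur, new):
    w = jnp.exp(pes[cur] - pes)      # weights relative to the current value (w[cur] = 1)
    w_sum_wo_cur = w.sum() - w[cur]  # proportional to 1 - P(cur)
    w_sum_wo_new = w.sum() - w[new]  # proportional to 1 - P(new)
    return jnp.log(w_sum_wo_cur) - jnp.log(w_sum_wo_new)

pes = jnp.array([0.5, 1.0, 2.0])
ratio = enumerated_log_accept_ratio(pes, cur=0, new=2)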
def _get_leaf(tree, going_right):
    return cond(going_right,
                tree, lambda tree: (tree.z_right, tree.r_right, tree.z_right_grad),
                tree, lambda tree: (tree.z_left, tree.r_left, tree.z_left_grad))
def sample(self, state, model_args, model_kwargs):
    model_kwargs = {} if model_kwargs is None else model_kwargs
    num_discretes = self._support_sizes_flat.shape[0]

    def potential_fn(z_gibbs, z_hmc):
        return self.inner_kernel._potential_fn_gen(
            *model_args, _gibbs_sites=z_gibbs, **model_kwargs)(z_hmc)

    def update_discrete(idx, rng_key, hmc_state, z_discrete, ke_discrete,
                        delta_pe_sum):
        # Algo 1, line 19: get a new discrete proposal
        rng_key, z_discrete_new, pe_new, log_accept_ratio = self._discrete_proposal_fn(
            rng_key,
            z_discrete,
            hmc_state.potential_energy,
            partial(potential_fn, z_hmc=hmc_state.z),
            idx,
            self._support_sizes_flat[idx],
        )
        # Algo 1, line 20: depending on reject or refract, we will update
        # the discrete variable and its corresponding kinetic energy. In case of
        # refract, we will need to update the potential energy and its grad w.r.t. hmc_state.z
        ke_discrete_i_new = ke_discrete[idx] + log_accept_ratio
        grad_ = jacfwd if self.inner_kernel._forward_mode_differentiation else grad
        z_discrete, pe, ke_discrete_i, z_grad = lax.cond(
            ke_discrete_i_new > 0,
            (z_discrete_new, pe_new, ke_discrete_i_new),
            lambda vals: vals + (grad_(partial(potential_fn, vals[0]))
                                 (hmc_state.z),),
            (z_discrete, hmc_state.potential_energy, ke_discrete[idx],
             hmc_state.z_grad),
            identity,
        )

        delta_pe_sum = delta_pe_sum + pe - hmc_state.potential_energy
        ke_discrete = ops.index_update(ke_discrete, idx, ke_discrete_i)
        hmc_state = hmc_state._replace(potential_energy=pe, z_grad=z_grad)
        return rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum

    def update_continuous(hmc_state, z_discrete):
        model_kwargs_ = model_kwargs.copy()
        model_kwargs_["_gibbs_sites"] = z_discrete
        hmc_state_new = self.inner_kernel.sample(hmc_state, model_args,
                                                 model_kwargs_)

        # each time a sub-trajectory is performed, we need to reset i and adapt_state
        # (we will only update them at the end of HMCGibbs step)
        # For `num_steps`, we will record its cumulative sum for diagnostics
        hmc_state = hmc_state_new._replace(
            i=hmc_state.i,
            adapt_state=hmc_state.adapt_state,
            num_steps=hmc_state.num_steps + hmc_state_new.num_steps,
        )
        return hmc_state

    def body_fn(i, vals):
        rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum, arrival_times = vals
        idx = jnp.argmin(arrival_times)
        # NB: length of each sub-trajectory is scaled from the current min(arrival_times)
        # (see the note at total_time below)
        trajectory_length = arrival_times[idx] * time_unit
        arrival_times = arrival_times - arrival_times[idx]
        arrival_times = ops.index_update(arrival_times, idx, 1.0)

        # this is a trick, so that in a sub-trajectory of HMC, we always accept the new proposal
        pe = jnp.inf
        hmc_state = hmc_state._replace(trajectory_length=trajectory_length,
                                       potential_energy=pe)
        # Algo 1, line 7: perform a sub-trajectory
        hmc_state = update_continuous(hmc_state, z_discrete)
        # Algo 1, line 8: perform a discrete update
        rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum = update_discrete(
            idx, rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum)
        return (rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum,
                arrival_times)

    z_discrete = {k: v for k, v in state.z.items() if k not in state.hmc_state.z}
    rng_key, rng_ke, rng_time, rng_r, rng_accept = random.split(state.rng_key, 5)
    # Algo 1, line 2: sample discrete kinetic energy
    ke_discrete = random.exponential(rng_ke, (num_discretes,))
    # Algo 1, line 4 and 5: sample the initial amount of time that each discrete site visits
    # the point 0/1. The logic in GetStepSizesNSteps(...) is more complicated but does
    # the same job: the sub-trajectory length eta_t * M_t is the lag between two arrival times.
    arrival_times = random.uniform(rng_time, (num_discretes,))
    # compute the amount of time to make `num_discrete_updates` discrete updates
    total_time = (self._num_discrete_updates - 1) // num_discretes + jnp.sort(
        arrival_times)[(self._num_discrete_updates - 1) % num_discretes]
    # NB: total_time can be different from the HMC trajectory length, so we need to scale
    # the time unit so that total_time * time_unit = hmc_trajectory_length
    time_unit = state.hmc_state.trajectory_length / total_time

    # Algo 1, line 2: sample hmc momentum
    r = momentum_generator(state.hmc_state.r,
                           state.hmc_state.adapt_state.mass_matrix_sqrt, rng_r)
    hmc_state = state.hmc_state._replace(r=r, num_steps=0)
    hmc_ke = euclidean_kinetic_energy(hmc_state.adapt_state.inverse_mass_matrix, r)
    # Algo 1, line 10: compute the initial energy
    energy_old = hmc_ke + hmc_state.potential_energy

    # Algo 1, line 3: set initial values
    delta_pe_sum = 0.0
    init_val = (rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum,
                arrival_times)

    # Algo 1, line 6-9: perform the update loop
    rng_key, hmc_state_new, z_discrete_new, _, delta_pe_sum, _ = fori_loop(
        0, self._num_discrete_updates, body_fn, init_val)

    # Algo 1, line 10: compute the proposal energy
    hmc_ke = euclidean_kinetic_energy(hmc_state.adapt_state.inverse_mass_matrix,
                                      hmc_state_new.r)
    energy_new = hmc_ke + hmc_state_new.potential_energy
    # Algo 1, line 11: perform MH correction
    delta_energy = energy_new - energy_old - delta_pe_sum
    delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
    accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)

    # record the correct new num_steps
    hmc_state = hmc_state._replace(num_steps=hmc_state_new.num_steps)
    # reset the trajectory length
    hmc_state_new = hmc_state_new._replace(
        trajectory_length=hmc_state.trajectory_length)
    hmc_state, z_discrete = cond(
        random.bernoulli(rng_key, accept_prob),
        (hmc_state_new, z_discrete_new),
        identity,
        (hmc_state, z_discrete),
        identity,
    )

    # perform hmc adaptation (similar to the implementation in hmc)
    adapt_state = cond(
        hmc_state.i < self._num_warmup,
        (hmc_state.i, accept_prob, (hmc_state.z,), hmc_state.adapt_state),
        lambda args: self._wa_update(*args),
        hmc_state.adapt_state,
        identity,
    )

    itr = hmc_state.i + 1
    n = jnp.where(hmc_state.i < self._num_warmup, itr, itr - self._num_warmup)
    mean_accept_prob_prev = state.hmc_state.mean_accept_prob
    mean_accept_prob = (mean_accept_prob_prev +
                        (accept_prob - mean_accept_prob_prev) / n)
    hmc_state = hmc_state._replace(
        i=itr,
        accept_prob=accept_prob,
        mean_accept_prob=mean_accept_prob,
        adapt_state=adapt_state,
    )

    z = {**z_discrete, **hmc_state.z}
    return MixedHMCState(z, hmc_state, rng_key, accept_prob)
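
# A minimal sketch of the arrival-time arithmetic above: with n discrete sites
# and i.i.d. Uniform(0, 1) initial arrival times, making k discrete updates
# takes (k - 1) // n full rounds plus the ((k - 1) % n)-th order statistic of
# the arrival times, exactly the `total_time` expression in `sample`.
import jax.numpy as jnp
from jax import random

def updates_total_time(arrival_times, num_discrete_updates):
    n = arrival_times.shape[0]
    k = num_discrete_updates
    return (k - 1) // n + jnp.sort(arrival_times)[(k - 1) % n]

times = random.uniform(random.PRNGKey(0), (3,))
t = updates_total_time(times, num_discrete_updates=7)  # 2 full rounds + first arrival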
def _combine_tree(current_tree, new_tree, inverse_mass_matrix, going_right,
                  rng_key, biased_transition):
    # Now we combine the current tree and the new tree. Note that outside
    # leaves of the combined tree are determined by the direction.
    z_left, r_left, z_left_grad, z_right, r_right, z_right_grad = cond(
        going_right,
        (current_tree, new_tree),
        lambda trees: (trees[0].z_left,
                       trees[0].r_left,
                       trees[0].z_left_grad,
                       trees[1].z_right,
                       trees[1].r_right,
                       trees[1].z_right_grad),
        (new_tree, current_tree),
        lambda trees: (trees[0].z_left,
                       trees[0].r_left,
                       trees[0].z_left_grad,
                       trees[1].z_right,
                       trees[1].r_right,
                       trees[1].z_right_grad),
    )
    r_sum = tree_multimap(jnp.add, current_tree.r_sum, new_tree.r_sum)

    if biased_transition:
        transition_prob = _biased_transition_kernel(current_tree, new_tree)
        turning = new_tree.turning | _is_turning(inverse_mass_matrix, r_left,
                                                 r_right, r_sum)
    else:
        transition_prob = _uniform_transition_kernel(current_tree, new_tree)
        turning = current_tree.turning

    transition = random.bernoulli(rng_key, transition_prob)
    z_proposal, z_proposal_pe, z_proposal_grad, z_proposal_energy = cond(
        transition,
        new_tree,
        lambda tree: (tree.z_proposal, tree.z_proposal_pe,
                      tree.z_proposal_grad, tree.z_proposal_energy),
        current_tree,
        lambda tree: (tree.z_proposal, tree.z_proposal_pe,
                      tree.z_proposal_grad, tree.z_proposal_energy),
    )

    tree_depth = current_tree.depth + 1
    tree_weight = jnp.logaddexp(current_tree.weight, new_tree.weight)
    diverging = new_tree.diverging
    sum_accept_probs = current_tree.sum_accept_probs + new_tree.sum_accept_probs
    num_proposals = current_tree.num_proposals + new_tree.num_proposals

    return TreeInfo(z_left, r_left, z_left_grad, z_right, r_right, z_right_grad,
                    z_proposal, z_proposal_pe, z_proposal_grad, z_proposal_energy,
                    tree_depth, tree_weight, r_sum, turning, diverging,
                    sum_accept_probs, num_proposals)
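
# A minimal sketch of plausible forms for the two transition kernels referenced
# above (`_uniform_transition_kernel` and `_biased_transition_kernel` are not
# shown in this section; these follow the multinomial-NUTS scheme). `weight` is
# the log of a subtree's total probability mass, as in `tree_weight` above.
import jax.numpy as jnp

def uniform_transition_prob(current_weight, new_weight):
    # move to the new subtree in proportion to its share of the combined mass
    return jnp.exp(new_weight - jnp.logaddexp(current_weight, new_weight))

def biased_transition_prob(current_weight, new_weight, new_turning, new_diverging):
    # favor the new (outermost) subtree, but never move into a bad extension
    prob = jnp.clip(jnp.exp(new_weight - current_weight), 0.0, 1.0)
    return jnp.where(new_turning | new_diverging, 0.0, prob)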
def update_fn(t, accept_prob, z_info, state):
    """
    :param int t: The current time step.
    :param float accept_prob: Acceptance probability of the current trajectory.
    :param IntegratorState z_info: The new integrator state.
    :param state: Current state of the adapt scheme.
    :return: new state of the adapt scheme.
    """
    (step_size, inverse_mass_matrix, mass_matrix_sqrt, mass_matrix_sqrt_inv,
     ss_state, mm_state, window_idx, rng_key) = state
    if rng_key is not None:
        rng_key, rng_key_ss = random.split(rng_key)
    else:
        rng_key_ss = None

    # update step size state
    if adapt_step_size:
        ss_state = ss_update(target_accept_prob - accept_prob, ss_state)
        # note: at the end of warmup phase, use average of log step_size
        log_step_size, log_step_size_avg, *_ = ss_state
        step_size = jnp.where(t == (num_adapt_steps - 1),
                              jnp.exp(log_step_size_avg),
                              jnp.exp(log_step_size))
        # account for the case that log_step_size is an extreme number
        finfo = jnp.finfo(jnp.result_type(step_size))
        step_size = jnp.clip(step_size, a_min=finfo.tiny, a_max=finfo.max)

    # update mass matrix state
    is_middle_window = (0 < window_idx) & (window_idx < (num_windows - 1))
    if adapt_mass_matrix:
        z = z_info[0]
        mm_state = cond(is_middle_window,
                        (z, mm_state), lambda args: mm_update(*args),
                        mm_state, identity)

    t_at_window_end = t == adaptation_schedule[window_idx, 1]
    window_idx = jnp.where(t_at_window_end, window_idx + 1, window_idx)
    state = HMCAdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt,
                          mass_matrix_sqrt_inv, ss_state, mm_state, window_idx,
                          rng_key)
    state = cond(t_at_window_end & is_middle_window,
                 (z_info, rng_key_ss, state),
                 lambda args: _update_at_window_end(*args),
                 state, identity)
    return state
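
# A minimal sketch, assuming a Stan-style warmup layout, of the window
# structure that `adaptation_schedule` above is indexed by: an initial buffer
# (step-size adaptation only), doubling middle windows (mass-matrix
# adaptation), and a terminal buffer; edge cases for very short warmups are
# ignored here. Column 1 holds the end step compared against `t_at_window_end`.
import numpy as np

def build_schedule(num_steps, init_buffer=75, window=25, term_buffer=50):
    windows = [(0, init_buffer - 1)]
    start, size = init_buffer, window
    while start + size < num_steps - term_buffer:
        windows.append((start, start + size - 1))
        start, size = start + size, 2 * size
    windows.append((start, num_steps - term_buffer - 1))  # last middle window
    windows.append((num_steps - term_buffer, num_steps - 1))
    return np.array(windows)

schedule = build_schedule(1000)  # rows of (start, end) steps per window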