def sample(self, key, sample_shape=()):
    """Draw exponential variates with rate ``self.rate``.

    Standard exponential draws with shape ``sample_shape + self.batch_shape``
    are generated from ``key`` and then divided by ``self.rate``.

    :param key: JAX PRNG key.
    :param sample_shape: leading sample dimensions (a tuple), prepended to
        ``self.batch_shape``.
    :return: array of shape ``sample_shape + self.batch_shape``.
    """
    full_shape = sample_shape + self.batch_shape
    unit_draws = random.exponential(key, shape=full_shape)
    return unit_draws / self.rate
def sample(self, state, model_args, model_kwargs):
    """Run one Mixed HMC transition starting from ``state``.

    The continuous sites are moved by HMC sub-trajectories, and one discrete
    site is updated after each sub-trajectory.  A single Metropolis-Hastings
    correction is then applied to the whole proposal.  Comments tagged
    ``Algo 1, line N`` map the code to the steps of the algorithm this kernel
    follows.

    :param state: current ``MixedHMCState``.  ``state.z`` holds every site;
        the entries also present in ``state.hmc_state.z`` are the continuous
        ones, and the rest are treated as discrete.
    :param model_args: positional arguments forwarded to the model.
    :param model_kwargs: keyword arguments forwarded to the model
        (``None`` is treated as ``{}``).
    :return: the next ``MixedHMCState``.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    num_discretes = self._support_sizes_flat.shape[0]

    def potential_fn(z_gibbs, z_hmc):
        # Evaluate the inner kernel's potential at `z_hmc`, with `z_gibbs`
        # passed through as the model's `_gibbs_sites`.
        return self.inner_kernel._potential_fn_gen(
            *model_args, _gibbs_sites=z_gibbs, **model_kwargs)(z_hmc)

    def update_discrete(idx, rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum):
        # Algo 1, line 19: get a new discrete proposal
        rng_key, z_discrete_new, pe_new, log_accept_ratio = self._discrete_proposal_fn(
            rng_key, z_discrete, hmc_state.potential_energy,
            partial(potential_fn, z_hmc=hmc_state.z),
            idx, self._support_sizes_flat[idx])
        # Algo 1, line 20: depending on reject or refract, we will update
        # the discrete variable and its corresponding kinetic energy. In case of
        # refract, we will need to update the potential energy and its grad w.r.t. hmc_state.z
        ke_discrete_i_new = ke_discrete[idx] + log_accept_ratio
        grad_ = jacfwd if self.inner_kernel._forward_mode_differentiation else grad
        z_discrete, pe, ke_discrete_i, z_grad = lax.cond(
            ke_discrete_i_new > 0,
            (z_discrete_new, pe_new, ke_discrete_i_new),
            lambda vals: vals + (grad_(partial(potential_fn, vals[0]))(hmc_state.z),),
            (z_discrete, hmc_state.potential_energy, ke_discrete[idx], hmc_state.z_grad),
            identity)

        # accumulate the potential-energy jumps caused by discrete moves; they
        # are subtracted from the energy difference in the final MH correction
        delta_pe_sum = delta_pe_sum + pe - hmc_state.potential_energy
        ke_discrete = ops.index_update(ke_discrete, idx, ke_discrete_i)
        hmc_state = hmc_state._replace(potential_energy=pe, z_grad=z_grad)
        return rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum

    def update_continuous(hmc_state, z_discrete):
        model_kwargs_ = model_kwargs.copy()
        model_kwargs_["_gibbs_sites"] = z_discrete
        hmc_state_new = self.inner_kernel.sample(hmc_state, model_args, model_kwargs_)

        # each time a sub-trajectory is performed, we need to reset i and adapt_state
        # (we will only update them at the end of HMCGibbs step)
        # For `num_steps`, we will record its cumulative sum for diagnostics
        hmc_state = hmc_state_new._replace(
            i=hmc_state.i,
            adapt_state=hmc_state.adapt_state,
            num_steps=hmc_state.num_steps + hmc_state_new.num_steps)
        return hmc_state

    def body_fn(i, vals):
        rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum, arrival_times = vals
        idx = jnp.argmin(arrival_times)
        # NB: length of each sub-trajectory is scaled from the current min(arrival_times)
        # (see the note at total_time below)
        trajectory_length = arrival_times[idx] * time_unit
        arrival_times = arrival_times - arrival_times[idx]
        arrival_times = ops.index_update(arrival_times, idx, 1.)

        # this is a trick, so that in a sub-trajectory of HMC, we always accept the new proposal
        pe = jnp.inf
        hmc_state = hmc_state._replace(trajectory_length=trajectory_length, potential_energy=pe)
        # Algo 1, line 7: perform a sub-trajectory
        hmc_state = update_continuous(hmc_state, z_discrete)
        # Algo 1, line 8: perform a discrete update
        rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum = update_discrete(
            idx, rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum)
        return rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum, arrival_times

    z_discrete = {k: v for k, v in state.z.items() if k not in state.hmc_state.z}
    rng_key, rng_ke, rng_time, rng_r, rng_accept = random.split(state.rng_key, 5)
    # Algo 1, line 2: sample discrete kinetic energy
    ke_discrete = random.exponential(rng_ke, (num_discretes,))
    # Algo 1, line 4 and 5: sample the initial amount of time that each discrete site visits
    # the point 0/1. The logic in GetStepSizesNSteps(...) is more complicated but does
    # the same job: the sub-trajectory length eta_t * M_t is the lag between two arrival time.
    arrival_times = random.uniform(rng_time, (num_discretes,))
    # compute the amount of time to make `num_discrete_updates` discrete updates
    total_time = (self._num_discrete_updates - 1) // num_discretes \
        + jnp.sort(arrival_times)[(self._num_discrete_updates - 1) % num_discretes]
    # NB: total_time can be different from the HMC trajectory length, so we need to scale
    # the time unit so that total_time * time_unit = hmc_trajectory_length
    time_unit = state.hmc_state.trajectory_length / total_time

    # Algo 1, line 2: sample hmc momentum
    r = momentum_generator(state.hmc_state.r, state.hmc_state.adapt_state.mass_matrix_sqrt, rng_r)
    hmc_state = state.hmc_state._replace(r=r, num_steps=0)
    hmc_ke = euclidean_kinetic_energy(hmc_state.adapt_state.inverse_mass_matrix, r)
    # Algo 1, line 10: compute the initial energy
    energy_old = hmc_ke + hmc_state.potential_energy

    # Algo 1, line 3: set initial values
    delta_pe_sum = 0.
    init_val = (rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum, arrival_times)

    # Algo 1, line 6-9: perform the update loop
    rng_key, hmc_state_new, z_discrete_new, _, delta_pe_sum, _ = fori_loop(
        0, self._num_discrete_updates, body_fn, init_val)
    # Algo 1, line 10: compute the proposal energy
    hmc_ke = euclidean_kinetic_energy(hmc_state.adapt_state.inverse_mass_matrix, hmc_state_new.r)
    energy_new = hmc_ke + hmc_state_new.potential_energy
    # Algo 1, line 11: perform MH correction
    delta_energy = energy_new - energy_old - delta_pe_sum
    # a NaN energy difference is treated as an infinitely bad proposal (accept_prob = 0)
    delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
    accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)

    # record the correct new num_steps
    hmc_state = hmc_state._replace(num_steps=hmc_state_new.num_steps)
    # reset the trajectory length
    hmc_state_new = hmc_state_new._replace(trajectory_length=hmc_state.trajectory_length)
    # Use the dedicated `rng_accept` key for the accept/reject draw: `rng_key` is
    # stored in the returned state and split again on the next call, so it must
    # not also be consumed here (JAX keys must not be reused).
    hmc_state, z_discrete = cond(random.bernoulli(rng_accept, accept_prob),
                                 (hmc_state_new, z_discrete_new), identity,
                                 (hmc_state, z_discrete), identity)

    # perform hmc adapting (similar to the implementation in hmc)
    adapt_state = cond(hmc_state.i < self._num_warmup,
                       (hmc_state.i, accept_prob, (hmc_state.z,), hmc_state.adapt_state),
                       lambda args: self._wa_update(*args),
                       hmc_state.adapt_state,
                       identity)

    itr = hmc_state.i + 1
    # running mean of accept_prob; the count restarts once warmup is over
    n = jnp.where(hmc_state.i < self._num_warmup, itr, itr - self._num_warmup)
    mean_accept_prob_prev = state.hmc_state.mean_accept_prob
    mean_accept_prob = mean_accept_prob_prev + (accept_prob - mean_accept_prob_prev) / n
    hmc_state = hmc_state._replace(i=itr,
                                   accept_prob=accept_prob,
                                   mean_accept_prob=mean_accept_prob,
                                   adapt_state=adapt_state)

    z = {**z_discrete, **hmc_state.z}
    return MixedHMCState(z, hmc_state, rng_key, accept_prob)