def sample(self, key, sample_shape=()):
    """
    Draw samples by transforming uniform noise with ``self.probs``.

    :param key: PRNG key; must satisfy ``is_prng_key``.
    :param tuple sample_shape: Leading shape prepended to ``self.batch_shape``.
    :return: ``floor(log1p(-u) / log1p(-probs))`` where ``u`` is uniform noise
        of shape ``sample_shape + self.batch_shape``.
    """
    assert is_prng_key(key)
    success_probs = self.probs
    noise_dtype = get_dtype(success_probs)
    out_shape = sample_shape + self.batch_shape
    noise = random.uniform(key, out_shape, noise_dtype)
    # Ratio of log(1 - u) to log(1 - probs), rounded down to a whole number.
    numerator = jnp.log1p(-noise)
    denominator = jnp.log1p(-success_probs)
    return jnp.floor(numerator / denominator)
def sample(self, key, sample_shape=()):
    """
    Draw samples by transforming uniform noise with ``self.logits``.

    :param key: PRNG key; must satisfy ``is_prng_key``.
    :param tuple sample_shape: Leading shape prepended to ``self.batch_shape``.
    :return: ``floor(log1p(-u) / -softplus(logits))`` where ``u`` is uniform
        noise of shape ``sample_shape + self.batch_shape``.
    """
    assert is_prng_key(key)
    raw_logits = self.logits
    noise_dtype = get_dtype(raw_logits)
    out_shape = sample_shape + self.batch_shape
    noise = random.uniform(key, out_shape, noise_dtype)
    # The denominator is the negated softplus of the logits.
    numerator = jnp.log1p(-noise)
    denominator = -softplus(raw_logits)
    return jnp.floor(numerator / denominator)
def find_reasonable_step_size(potential_fn, kinetic_fn, momentum_generator, inverse_mass_matrix,
                              position, rng, init_step_size):
    """
    Finds a reasonable step size by tuning `init_step_size`. This function is used
    to avoid working with a too large or too small step size in HMC.

    The step size is repeatedly doubled or halved. The search stops as soon as
    the direction of change flips, or when the step size leaves the open range
    ``(finfo.tiny, finfo.max)`` of its dtype while still moving outward.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman

    :param potential_fn: A callable to compute potential energy.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param momentum_generator: A generator to get a random momentum variable.
    :param inverse_mass_matrix: Inverse of mass matrix.
    :param position: Current position of the particle.
    :param jax.random.PRNGKey rng: Random key to be used as the source of randomness.
    :param float init_step_size: Initial step size to be tuned.
    :return: a reasonable value for step size.
    :rtype: float
    """
    # We are going to find a step_size which make accept_prob (Metropolis correction)
    # near the target accept prob. If accept_prob:=exp(-delta_energy) is small,
    # then we have to decrease step_size; otherwise, increase step_size.
    # The comparison is done in log space, hence the log of the 0.8 target.
    log_target_accept_prob = np.log(0.8)

    _, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    z = position
    potential_energy, z_grad = value_and_grad(potential_fn)(z)
    finfo = np.finfo(get_dtype(init_step_size))

    def _body_fn(state):
        step_size, _, direction, rng = state
        rng, rng_momentum = random.split(rng)
        # scale step_size: increase 2x or decrease 2x depends on direction;
        # direction=1 means keep increasing step_size, otherwise decreasing step_size.
        # Note that the direction is -1 if delta_energy is `NaN`, which may be the
        # case for a diverging trajectory (e.g. in the case of evaluating log prob
        # of a value simulated using a large step size for a constrained sample site).
        step_size = (2.0 ** direction) * step_size
        r = momentum_generator(inverse_mass_matrix, rng_momentum)
        _, r_new, potential_energy_new, _ = vv_update(step_size,
                                                      inverse_mass_matrix,
                                                      (z, r, potential_energy, z_grad))
        energy_current = kinetic_fn(inverse_mass_matrix, r) + potential_energy
        energy_new = kinetic_fn(inverse_mass_matrix, r_new) + potential_energy_new
        delta_energy = energy_new - energy_current
        direction_new = np.where(log_target_accept_prob < -delta_energy, 1, -1)
        return step_size, direction, direction_new, rng

    def _cond_fn(state):
        step_size, last_direction, direction, _ = state
        # condition to run only if step_size is not too small or we are not decreasing step_size
        not_small_step_size_cond = (step_size > finfo.tiny) | (direction >= 0)
        # condition to run only if step_size is not too large or we are not increasing step_size;
        # without it, a trajectory that is always accepted would keep doubling step_size forever
        not_large_step_size_cond = (step_size < finfo.max) | (direction <= 0)
        not_extreme_cond = not_small_step_size_cond & not_large_step_size_cond
        return not_extreme_cond & ((last_direction == 0) | (direction == last_direction))

    step_size, _, _, _ = while_loop(_cond_fn, _body_fn, (init_step_size, 0, 0, rng))
    return step_size
def update_fn(t, accept_prob, z_info, state):
    """
    Advance the adaptation state by one step.

    Relies on names defined outside this function (``adapt_step_size``,
    ``adapt_mass_matrix``, ``ss_update``, ``mm_update``, ``target_accept_prob``,
    ``num_adapt_steps``, ``num_windows``, ``adaptation_schedule``,
    ``_update_at_window_end``).

    :param int t: The current time step.
    :param float accept_prob: Acceptance probability of the current trajectory.
    :param IntegratorState z_info: The new integrator state.
    :param state: Current state of the adapt scheme.
    :return: new state of the adapt scheme.
    """
    step_size, inverse_mass_matrix, mass_matrix_sqrt, mass_matrix_sqrt_inv, \
        ss_state, mm_state, window_idx, rng_key = state
    # The random key is optional; only split it when one was provided.
    if rng_key is not None:
        rng_key, rng_key_ss = random.split(rng_key)
    else:
        rng_key_ss = None

    # update step size state
    if adapt_step_size:
        ss_state = ss_update(target_accept_prob - accept_prob, ss_state)
        # note: at the end of warmup phase, use average of log step_size
        log_step_size, log_step_size_avg, *_ = ss_state
        step_size = jnp.where(t == (num_adapt_steps - 1),
                              jnp.exp(log_step_size_avg),
                              jnp.exp(log_step_size))
        # account for the case log_step_size is an extreme number.
        # Bounds are passed positionally: the `a_min`/`a_max` keyword names of
        # `jnp.clip` are deprecated in recent JAX releases.
        finfo = jnp.finfo(get_dtype(step_size))
        step_size = jnp.clip(step_size, finfo.tiny, finfo.max)

    # update mass matrix state
    is_middle_window = (0 < window_idx) & (window_idx < (num_windows - 1))
    if adapt_mass_matrix:
        z = z_info[0]
        z_flat, _ = ravel_pytree(z)
        mm_state = cond(is_middle_window,
                        (z_flat, mm_state), lambda args: mm_update(*args),
                        mm_state, identity)

    t_at_window_end = t == adaptation_schedule[window_idx, 1]
    window_idx = jnp.where(t_at_window_end, window_idx + 1, window_idx)
    state = HMCAdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt, mass_matrix_sqrt_inv,
                          ss_state, mm_state, window_idx, rng_key)
    # The end-of-window update runs only when a middle window has just finished.
    state = cond(t_at_window_end & is_middle_window,
                 (z_info, rng_key_ss, state), lambda args: _update_at_window_end(*args),
                 state, identity)
    return state
def update_fn(t, accept_prob, z, state):
    """
    Advance the adaptation state by one step.

    Relies on names defined outside this function (``adapt_step_size``,
    ``adapt_mass_matrix``, ``ss_update``, ``mm_update``, ``target_accept_prob``,
    ``num_adapt_steps``, ``num_windows``, ``adaptation_schedule``,
    ``_update_at_window_end``).

    :param int t: The current time step.
    :param float accept_prob: Acceptance probability of the current trajectory.
    :param z: New position drawn at the end of the current trajectory.
    :param state: Current state of the adapt scheme.
    :return: new state of the adapt scheme.
    """
    step_size, inverse_mass_matrix, mass_matrix_sqrt, ss_state, mm_state, window_idx, rng = state
    rng, rng_ss = random.split(rng)

    # update step size state
    if adapt_step_size:
        ss_state = ss_update(target_accept_prob - accept_prob, ss_state)
        # note: at the end of warmup phase, use average of log step_size
        log_step_size, log_step_size_avg, *_ = ss_state
        step_size = np.where(t == (num_adapt_steps - 1),
                             np.exp(log_step_size_avg),
                             np.exp(log_step_size))
        # account for the case log_step_size is an extreme number: exp() may
        # underflow to 0 or overflow to inf, so clamp both ends. Bounds are
        # passed positionally so the call does not depend on the keyword names
        # accepted by the `clip` implementation bound to `np`.
        finfo = np.finfo(get_dtype(step_size))
        step_size = np.clip(step_size, finfo.tiny, finfo.max)

    # update mass matrix state
    is_middle_window = (0 < window_idx) & (window_idx < (num_windows - 1))
    if adapt_mass_matrix:
        z_flat, _ = ravel_pytree(z)
        mm_state = cond(is_middle_window,
                        (z_flat, mm_state), lambda args: mm_update(*args),
                        mm_state, lambda x: x)

    t_at_window_end = t == adaptation_schedule[window_idx, 1]
    window_idx = np.where(t_at_window_end, window_idx + 1, window_idx)
    state = AdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt,
                       ss_state, mm_state, window_idx, rng)
    # The end-of-window update runs only when a middle window has just finished.
    state = cond(t_at_window_end & is_middle_window,
                 (z, rng_ss, state), lambda args: _update_at_window_end(*args),
                 state, lambda x: x)
    return state
def _clipped_expit(x):
    """
    Apply ``expit`` to ``x`` and clamp the result away from 0 and 1.

    :param x: Input value(s); its dtype (via ``get_dtype``) selects the bounds.
    :return: ``expit(x)`` clipped to ``[finfo.tiny, 1 - finfo.eps]``.
    """
    finfo = jnp.finfo(get_dtype(x))
    # Bounds are passed positionally: the `a_min`/`a_max` keyword names of
    # `jnp.clip` are deprecated in recent JAX releases.
    return jnp.clip(expit(x), finfo.tiny, 1. - finfo.eps)
def _to_logits_multinom(probs):
    """
    Convert probabilities to logits by taking their log, bounded from below.

    :param probs: Probability value(s); its dtype (via ``get_dtype``) selects the bound.
    :return: ``log(probs)`` with values below the dtype's ``finfo.min``
        (e.g. ``-inf`` from ``log(0)``) raised to ``finfo.min``.
    """
    minval = jnp.finfo(get_dtype(probs)).min
    # The lower bound is passed positionally: the `a_min` keyword name of
    # `jnp.clip` is deprecated in recent JAX releases.
    return jnp.clip(jnp.log(probs), minval)
def variance(self):
    """
    Report the variance as NaN for every batch element.

    :return: array of shape ``self.batch_shape`` filled with NaN, using the
        dtype of ``self.logits``.
    """
    fill_dtype = get_dtype(self.logits)
    return jnp.full(self.batch_shape, jnp.nan, dtype=fill_dtype)
def mean(self):
    """
    Report the mean as NaN for every batch element.

    :return: array of shape ``self.batch_shape`` filled with NaN, using the
        dtype of ``self.probs``.
    """
    fill_dtype = get_dtype(self.probs)
    return jnp.full(self.batch_shape, jnp.nan, dtype=fill_dtype)
def mean(self):
    """
    Report the mean as NaN for every batch element.

    :return: array of shape ``self.batch_shape`` filled with NaN, using the
        dtype of ``self.logits``.
    """
    fill_dtype = get_dtype(self.logits)
    return np.full(self.batch_shape, np.nan, dtype=fill_dtype)
def variance(self):
    """
    Report the variance as NaN for every batch element.

    :return: array of shape ``self.batch_shape`` filled with NaN, using the
        dtype of ``self.probs``.
    """
    fill_dtype = get_dtype(self.probs)
    return np.full(self.batch_shape, np.nan, dtype=fill_dtype)