Пример #1
0
 def sample(self, key, sample_shape=()):
     """
     Draw samples by transforming uniform random numbers.

     :param key: PRNG key used as the source of randomness; it must pass
         ``is_prng_key``.
     :param tuple sample_shape: Leading shape prepended to ``self.batch_shape``.
     :return: ``floor(log1p(-u) / log1p(-probs))`` where ``u`` is drawn with
         ``random.uniform`` at shape ``sample_shape + self.batch_shape``.
     """
     assert is_prng_key(key)
     success_probs = self.probs
     noise_dtype = get_dtype(success_probs)
     out_shape = sample_shape + self.batch_shape
     uniforms = random.uniform(key, out_shape, noise_dtype)
     numerator = jnp.log1p(-uniforms)
     denominator = jnp.log1p(-success_probs)
     return jnp.floor(numerator / denominator)
Пример #2
0
 def sample(self, key, sample_shape=()):
     """
     Draw samples by transforming uniform random numbers, using ``self.logits``.

     :param key: PRNG key used as the source of randomness; it must pass
         ``is_prng_key``.
     :param tuple sample_shape: Leading shape prepended to ``self.batch_shape``.
     :return: ``floor(log1p(-u) / -softplus(logits))`` where ``u`` is drawn with
         ``random.uniform`` at shape ``sample_shape + self.batch_shape``.
     """
     assert is_prng_key(key)
     raw_logits = self.logits
     noise_dtype = get_dtype(raw_logits)
     out_shape = sample_shape + self.batch_shape
     uniforms = random.uniform(key, out_shape, noise_dtype)
     numerator = jnp.log1p(-uniforms)
     denominator = -softplus(raw_logits)
     return jnp.floor(numerator / denominator)
Пример #3
0
def find_reasonable_step_size(potential_fn, kinetic_fn, momentum_generator, inverse_mass_matrix,
                              position, rng, init_step_size):
    """
    Finds a reasonable step size by tuning `init_step_size`. This function is used
    to avoid working with a too large or too small step size in HMC.

    Starting from `init_step_size`, the step size is repeatedly doubled or halved.
    The search stops as soon as the direction of change flips, or when the step
    size reaches the smallest/largest finite magnitude of its dtype while still
    moving towards that extreme.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman

    :param potential_fn: A callable to compute potential energy.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param momentum_generator: A generator to get a random momentum variable.
    :param inverse_mass_matrix: Inverse of mass matrix.
    :param position: Current position of the particle.
    :param jax.random.PRNGKey rng: Random key to be used as the source of randomness.
    :param float init_step_size: Initial step size to be tuned.
    :return: a reasonable value for step size.
    :rtype: float
    """
    # We are going to find a step_size which makes accept_prob (Metropolis correction)
    # near the target accept probability. If accept_prob:=exp(-delta_energy) is small,
    # then we have to decrease step_size; otherwise, increase step_size.
    # The comparison below is done in log space, hence the log of the target.
    log_target_accept_prob = np.log(0.8)

    _, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    z = position
    potential_energy, z_grad = value_and_grad(potential_fn)(z)
    finfo = np.finfo(get_dtype(init_step_size))

    def _body_fn(state):
        step_size, _, direction, rng = state
        rng, rng_momentum = random.split(rng)
        # scale step_size: increase 2x or decrease 2x depends on direction;
        # direction=1 means keep increasing step_size, otherwise decreasing step_size.
        # Note that the direction is -1 if delta_energy is `NaN`, which may be the
        # case for a diverging trajectory (e.g. in the case of evaluating log prob
        # of a value simulated using a large step size for a constrained sample site).
        step_size = (2.0 ** direction) * step_size
        r = momentum_generator(inverse_mass_matrix, rng_momentum)
        _, r_new, potential_energy_new, _ = vv_update(step_size,
                                                      inverse_mass_matrix,
                                                      (z, r, potential_energy, z_grad))
        energy_current = kinetic_fn(inverse_mass_matrix, r) + potential_energy
        energy_new = kinetic_fn(inverse_mass_matrix, r_new) + potential_energy_new
        delta_energy = energy_new - energy_current
        direction_new = np.where(log_target_accept_prob < -delta_energy, 1, -1)
        # The previous direction is carried along so the condition can detect a flip.
        return step_size, direction, direction_new, rng

    def _cond_fn(state):
        step_size, last_direction, direction, _ = state
        # condition to run only if step_size is not too small or we are not decreasing step_size
        not_small_step_size_cond = (step_size > finfo.tiny) | (direction >= 0)
        # condition to run only if step_size is not too large or we are not increasing
        # step_size; without it, a trajectory that is always accepted would keep
        # doubling step_size past the largest finite value and the loop would never stop
        not_large_step_size_cond = (step_size < finfo.max) | (direction <= 0)
        not_extreme_cond = not_small_step_size_cond & not_large_step_size_cond
        return not_extreme_cond & ((last_direction == 0) | (direction == last_direction))

    # Both directions start at 0 so the first iteration always runs with the initial step size.
    step_size, _, _, _ = while_loop(_cond_fn, _body_fn, (init_step_size, 0, 0, rng))
    return step_size
Пример #4
0
    def update_fn(t, accept_prob, z_info, state):
        """
        Advance the adaptation scheme by one step.

        :param int t: The current time step.
        :param float accept_prob: Acceptance probability of the current trajectory.
        :param IntegratorState z_info: The new integrator state; its first field is
            flattened and fed to the mass matrix update.
        :param state: Current state of the adapt scheme.
        :return: new state of the adapt scheme.
        """
        (step_size, inverse_mass_matrix, mass_matrix_sqrt, mass_matrix_sqrt_inv,
         ss_state, mm_state, window_idx, rng_key) = state

        # A sub-key for the window-end update is only split off when a key is carried.
        rng_key_ss = None
        if rng_key is not None:
            rng_key, rng_key_ss = random.split(rng_key)

        # Step size adaptation.
        if adapt_step_size:
            ss_state = ss_update(target_accept_prob - accept_prob, ss_state)
            log_step_size, log_step_size_avg, *_ = ss_state
            # On the last adaptation step the averaged log step size is used instead
            # of the most recent one.
            is_last_adapt_step = t == (num_adapt_steps - 1)
            step_size = jnp.where(is_last_adapt_step,
                                  jnp.exp(log_step_size_avg),
                                  jnp.exp(log_step_size))
            # Keep the result within finite positive range in case the log step size
            # is an extreme number.
            limits = jnp.finfo(get_dtype(step_size))
            step_size = jnp.clip(step_size, a_min=limits.tiny, a_max=limits.max)

        # Mass matrix adaptation only runs in windows that are neither first nor last.
        is_middle_window = (0 < window_idx) & (window_idx < (num_windows - 1))
        if adapt_mass_matrix:
            z_flat, _ = ravel_pytree(z_info[0])
            mm_state = cond(is_middle_window, (z_flat, mm_state),
                            lambda args: mm_update(*args), mm_state, identity)

        # Move on to the next window once `t` hits the end of the current one.
        t_at_window_end = t == adaptation_schedule[window_idx, 1]
        next_window_idx = jnp.where(t_at_window_end, window_idx + 1, window_idx)
        pending_state = HMCAdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt,
                                      mass_matrix_sqrt_inv, ss_state, mm_state,
                                      next_window_idx, rng_key)
        # `is_middle_window` still refers to the window that has just ended.
        return cond(t_at_window_end & is_middle_window,
                    (z_info, rng_key_ss, pending_state),
                    lambda args: _update_at_window_end(*args), pending_state,
                    identity)
Пример #5
0
    def update_fn(t, accept_prob, z, state):
        """
        Advance the adaptation scheme by one step.

        :param int t: The current time step.
        :param float accept_prob: Acceptance probability of the current trajectory.
        :param z: New position drawn at the end of the current trajectory.
        :param state: Current state of the adapt scheme.
        :return: new state of the adapt scheme.
        """
        step_size, inverse_mass_matrix, mass_matrix_sqrt, ss_state, mm_state, window_idx, rng = state
        # `rng_ss` is handed to the window-end update; `rng` is carried in the new state.
        rng, rng_ss = random.split(rng)

        # update step size state
        if adapt_step_size:
            ss_state = ss_update(target_accept_prob - accept_prob, ss_state)
            # note: at the end of warmup phase, use average of log step_size
            log_step_size, log_step_size_avg, *_ = ss_state
            step_size = np.where(t == (num_adapt_steps - 1),
                                 np.exp(log_step_size_avg),
                                 np.exp(log_step_size))
            # account for the case where log_step_size is an extreme number: the
            # exponential may underflow to 0 or overflow to inf, so keep the step
            # size within the finite positive range of its dtype
            finfo = np.finfo(get_dtype(step_size))
            step_size = np.clip(step_size, a_min=finfo.tiny, a_max=finfo.max)

        # update mass matrix state (only in windows that are neither first nor last)
        is_middle_window = (0 < window_idx) & (window_idx < (num_windows - 1))
        if adapt_mass_matrix:
            z_flat, _ = ravel_pytree(z)
            mm_state = cond(is_middle_window,
                            (z_flat, mm_state), lambda args: mm_update(*args),
                            mm_state, lambda x: x)

        # move on to the next window once `t` hits the end of the current one
        t_at_window_end = t == adaptation_schedule[window_idx, 1]
        window_idx = np.where(t_at_window_end, window_idx + 1, window_idx)
        state = AdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt,
                           ss_state, mm_state, window_idx, rng)
        # `is_middle_window` was computed before the increment, so it still refers
        # to the window that has just ended
        state = cond(t_at_window_end & is_middle_window,
                     (z, rng_ss, state), lambda args: _update_at_window_end(*args),
                     state, lambda x: x)
        return state
Пример #6
0
def _clipped_expit(x):
    """
    Apply ``expit`` to ``x`` and clamp the result to ``[tiny, 1 - eps]``.

    ``tiny`` and ``eps`` are taken from the floating-point info of the dtype
    reported by ``get_dtype(x)``.
    """
    limits = jnp.finfo(get_dtype(x))
    lower_bound = limits.tiny
    upper_bound = 1. - limits.eps
    return jnp.clip(expit(x), a_min=lower_bound, a_max=upper_bound)
Пример #7
0
def _to_logits_multinom(probs):
    """
    Return ``log(probs)`` clipped from below at the most negative finite value
    of the dtype reported by ``get_dtype(probs)``.
    """
    log_probs = jnp.log(probs)
    lowest_finite = jnp.finfo(get_dtype(probs)).min
    return jnp.clip(log_probs, a_min=lowest_finite)
Пример #8
0
 def variance(self):
     """Return a NaN-filled array of shape ``self.batch_shape`` with the dtype of ``self.logits``."""
     result_dtype = get_dtype(self.logits)
     return jnp.full(self.batch_shape, jnp.nan, dtype=result_dtype)
Пример #9
0
 def mean(self):
     """Return a NaN-filled array of shape ``self.batch_shape`` with the dtype of ``self.probs``."""
     result_dtype = get_dtype(self.probs)
     return jnp.full(self.batch_shape, jnp.nan, dtype=result_dtype)
Пример #10
0
 def mean(self):
     """Return a NaN-filled array of shape ``self.batch_shape`` with the dtype of ``self.logits``."""
     result_dtype = get_dtype(self.logits)
     return np.full(self.batch_shape, np.nan, dtype=result_dtype)
Пример #11
0
 def variance(self):
     """Return a NaN-filled array of shape ``self.batch_shape`` with the dtype of ``self.probs``."""
     result_dtype = get_dtype(self.probs)
     return np.full(self.batch_shape, np.nan, dtype=result_dtype)