def init_kernel(init_samples, num_warmup_steps, step_size=1.0, num_steps=None, adapt_step_size=True, adapt_mass_matrix=True, diag_mass=True, target_accept_prob=0.8, run_warmup=True, rng=PRNGKey(0)): step_size = float(step_size) nonlocal trajectory_length, momentum_generator, wa_update if num_steps is None: trajectory_length = 2 * math.pi else: trajectory_length = num_steps * step_size z = init_samples z_flat, unravel_fn = ravel_pytree(z) momentum_generator = partial(_sample_momentum, unravel_fn) find_reasonable_ss = partial(find_reasonable_step_size, potential_fn, kinetic_fn, momentum_generator) wa_init, wa_update = warmup_adapter( num_warmup_steps, find_reasonable_step_size=find_reasonable_ss, adapt_step_size=adapt_step_size, adapt_mass_matrix=adapt_mass_matrix, diag_mass=diag_mass, target_accept_prob=target_accept_prob) rng_hmc, rng_wa = random.split(rng) wa_state = wa_init(z, rng_wa, mass_matrix_size=np.size(z_flat)) r = momentum_generator(wa_state.inverse_mass_matrix, rng) vv_state = vv_init(z, r) hmc_state = HMCState(vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0., wa_state.step_size, wa_state.inverse_mass_matrix, rng_hmc) if run_warmup: hmc_state, _ = fori_loop(0, num_warmup_steps, warmup_update, (hmc_state, wa_state)) return hmc_state else: return hmc_state, wa_state, warmup_update
def init_kernel(init_params, num_warmup, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=2*math.pi, max_tree_depth=10, run_warmup=True, progbar=True, rng=PRNGKey(0)): """ Initializes the HMC sampler. :param init_params: Initial parameters to begin sampling. The type must be consistent with the input type to `potential_fn`. :param int num_warmup_steps: Number of warmup steps; samples generated during warmup are discarded. :param float step_size: Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1. :param bool adapt_step_size: A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme. :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme. :param bool dense_mass: A flag to decide if mass matrix is dense or diagonal (default when ``dense_mass=False``) :param float target_accept_prob: Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Default to 0.8. :param float trajectory_length: Length of a MCMC trajectory for HMC. Default value is :math:`2\\pi`. :param int max_tree_depth: Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10. :param bool run_warmup: Flag to decide whether warmup is run. If ``True``, `init_kernel` returns an initial :data:`HMCState` that can be used to generate samples using MCMC. Else, returns the arguments and callable that does the initial adaptation. :param bool progbar: Whether to enable progress bar updates. Defaults to ``True``. 
:param bool heuristic_step_size: If ``True``, a coarse grained adjustment of step size is done at the beginning of each adaptation window to achieve `target_acceptance_prob`. :param jax.random.PRNGKey rng: random key to be used as the source of randomness. """ step_size = float(step_size) nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps wa_steps = num_warmup trajectory_len = float(trajectory_length) max_treedepth = max_tree_depth z = init_params z_flat, unravel_fn = ravel_pytree(z) momentum_generator = partial(_sample_momentum, unravel_fn) find_reasonable_ss = partial(find_reasonable_step_size, potential_fn, kinetic_fn, momentum_generator) wa_init, wa_update = warmup_adapter(num_warmup, adapt_step_size=adapt_step_size, adapt_mass_matrix=adapt_mass_matrix, dense_mass=dense_mass, target_accept_prob=target_accept_prob, find_reasonable_step_size=find_reasonable_ss) rng_hmc, rng_wa = random.split(rng) wa_state = wa_init(z, rng_wa, step_size, mass_matrix_size=np.size(z_flat)) r = momentum_generator(wa_state.mass_matrix_sqrt, rng) vv_state = vv_init(z, r) hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0., 0., wa_state, rng_hmc) if run_warmup and num_warmup > 0: # JIT if progress bar updates not required if not progbar: hmc_state = fori_loop(0, num_warmup, lambda *args: sample_kernel(args[1]), hmc_state) else: with tqdm.trange(num_warmup, desc='warmup') as t: for i in t: hmc_state = sample_kernel(hmc_state) t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False) return hmc_state
def test_warmup_adapter(jitted):
    # Deterministic stand-in for the real heuristic: nudges the step size
    # toward 1 from either side.
    def find_reasonable_step_size(m_inv, z, rng, step_size):
        return np.where(step_size < 1, step_size * 4, step_size / 4)

    total_steps = 150
    schedule = build_adaptation_schedule(total_steps)
    initial_step_size = 1.
    mm_size = 3

    wa_init, wa_update = warmup_adapter(total_steps, find_reasonable_step_size)
    if jitted:
        wa_update = jit(wa_update)

    rng = random.PRNGKey(0)
    z = np.ones(3)
    wa_state = wa_init(z, rng, initial_step_size, mass_matrix_size=mm_size)
    step_size, inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    # The initial step size comes straight from the heuristic; the mass
    # matrix starts out as the identity (diagonal of ones).
    assert step_size == find_reasonable_step_size(inverse_mass_matrix, z, rng,
                                                  initial_step_size)
    assert_allclose(inverse_mass_matrix, np.ones(mm_size))
    assert window_idx == 0

    # --- window 0: feed accept probs below the 0.8 target ---
    window = schedule[0]
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(t, 0.7 + 0.1 * t / (window.end - window.start),
                             z, wa_state)
    last_step_size = step_size
    step_size, inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert window_idx == 1
    # step_size is decreased because accept_prob < target_accept_prob
    assert step_size < last_step_size
    # inverse_mass_matrix does not change at the end of the first window
    assert_allclose(inverse_mass_matrix, np.ones(mm_size))

    # --- window 1: accept probs above target, constant position 2*z ---
    window = schedule[1]
    window_len = window.end - window.start
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(t,
                             0.8 + 0.1 * (t - window.start) / window_len,
                             2 * z, wa_state)
    last_step_size = step_size
    step_size, inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert window_idx == 2
    # step_size is increased because accept_prob > target_accept_prob
    assert step_size > last_step_size
    # Verifies that inverse_mass_matrix changes at the end of the second window.
    # Because z_flat is constant during the second window, covariance will be 0
    # and only regularize_term of welford scheme is involved.
    # This also verifies that z_flat terms in the first window does not affect
    # the second window.
    welford_regularize_term = 1e-3 * (5 / (window.end + 1 - window.start + 5))
    assert_allclose(inverse_mass_matrix,
                    np.full((mm_size,), welford_regularize_term),
                    atol=1e-7)

    # --- window 2: accept prob exactly at target, moving positions ---
    window = schedule[2]
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(t, 0.8, t * z, wa_state)
    last_step_size = step_size
    step_size, final_inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert window_idx == 3
    # during the last window, because target_accept_prob=0.8,
    # log_step_size will be equal to the constant prox_center=log(10*last_step_size)
    assert_allclose(step_size, last_step_size * 10)
    # Verifies that inverse_mass_matrix does not change during the last window
    # despite z_flat changes w.r.t time t,
    assert_allclose(final_inverse_mass_matrix, inverse_mass_matrix)