def test_warmup_adapter(jitted):
    """Check ``warmup_adapter`` against the windows of the adaptation schedule.

    ``wa_update`` is driven through the first three windows returned by
    ``build_adaptation_schedule(150)``, using crafted acceptance probabilities
    and positions. After each window the test checks the step size, the inverse
    mass matrix and the window index.

    Args:
        jitted: when truthy, ``wa_update`` is wrapped with ``jit`` so the same
            assertions also exercise the compiled update.
    """
    def find_reasonable_step_size(step_size, m_inv, z, rng_key):
        # Deterministic stand-in for the real heuristic: multiply step sizes
        # below 1 by 4 and divide the others by 4; the remaining arguments
        # are ignored.
        return jnp.where(step_size < 1, step_size * 4, step_size / 4)

    num_steps = 150
    adaptation_schedule = build_adaptation_schedule(num_steps)
    init_step_size = 1.
    mass_matrix_size = 3

    wa_init, wa_update = warmup_adapter(num_steps, find_reasonable_step_size)
    wa_update = jit(wa_update) if jitted else wa_update

    rng_key = random.PRNGKey(0)
    z = jnp.ones(3)
    wa_state = wa_init((z, None, None, None), rng_key, init_step_size,
                       mass_matrix_size=mass_matrix_size)
    step_size, inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    # The initial step size must come from find_reasonable_step_size.
    assert step_size == find_reasonable_step_size(init_step_size, inverse_mass_matrix, z, rng_key)
    assert_allclose(inverse_mass_matrix, jnp.ones(mass_matrix_size))
    assert window_idx == 0

    # First window: accept_prob ramps from 0.7 up to 0.8, i.e. it stays at or
    # below the 0.8 target. The ramp is measured from window.start (like the
    # second window) so it does not silently rely on this window starting at 0.
    window = adaptation_schedule[0]
    window_len = window.end - window.start
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(t, 0.7 + 0.1 * (t - window.start) / window_len, z, wa_state)
    last_step_size = step_size
    step_size, inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert window_idx == 1
    # step_size is decreased because accept_prob < target_accept_prob
    assert step_size < last_step_size
    # inverse_mass_matrix does not change at the end of the first window
    assert_allclose(inverse_mass_matrix, jnp.ones(mass_matrix_size))

    # Second window: accept_prob ramps from 0.8 up to 0.9, with a constant
    # position 2 * z.
    window = adaptation_schedule[1]
    window_len = window.end - window.start
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(t, 0.8 + 0.1 * (t - window.start) / window_len, 2 * z, wa_state)
    last_step_size = step_size
    step_size, inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert window_idx == 2
    # step_size is increased because accept_prob > target_accept_prob
    assert step_size > last_step_size
    # Verifies that inverse_mass_matrix changes at the end of the second window.
    # Because z_flat is constant during the second window, covariance will be 0
    # and only regularize_term of welford scheme is involved.
    # This also verifies that z_flat terms in the first window do not affect
    # the second window.
    # window.end + 1 - window.start is the number of updates run in this window.
    welford_regularize_term = 1e-3 * (5 / (window.end + 1 - window.start + 5))
    assert_allclose(inverse_mass_matrix,
                    jnp.full((mass_matrix_size, ), welford_regularize_term),
                    atol=1e-7)

    # Last window: accept_prob is held at the target while z_flat varies with t.
    window = adaptation_schedule[2]
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(t, 0.8, t * z, wa_state)
    last_step_size = step_size
    step_size, final_inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert window_idx == 3
    # during the last window, because target_accept_prob=0.8,
    # log_step_size will be equal to the constant prox_center=log(10*last_step_size)
    assert_allclose(step_size, last_step_size * 10, atol=1e-6)
    # Verifies that inverse_mass_matrix does not change during the last window
    # despite z_flat changing w.r.t. time t.
    assert_allclose(final_inverse_mass_matrix, inverse_mass_matrix)
def test_build_adaptation_schedule(num_steps, expected):
    """The schedule built for ``num_steps`` must equal the expected windows.

    Each two-element entry of ``expected`` supplies the positional arguments of
    one ``AdaptWindow`` (presumably ``(start, end)`` -- confirm against the
    ``AdaptWindow`` definition).
    """
    assert build_adaptation_schedule(num_steps) == [
        AdaptWindow(lo, hi) for lo, hi in expected
    ]