Example #1
0
def test_warmup_adapter(jitted):
    """Drive ``warmup_adapter`` through the first three adaptation windows.

    After initialisation and after each window, the step size, the inverse
    mass matrix and the window index stored in the adapter state are checked.

    :param jitted: when truthy, the update function is wrapped in ``jit``.
    """
    def find_reasonable_step_size(m_inv, z, rng, step_size):
        # Deterministic stand-in: x4 when step_size < 1, otherwise /4.
        return np.where(step_size < 1, step_size * 4, step_size / 4)

    def unpack(state):
        # Only the step size, inverse mass matrix and window index are checked.
        step, m_inv, _, _, _, idx, _ = state
        return step, m_inv, idx

    num_steps = 150
    schedule = build_adaptation_schedule(num_steps)
    init_step_size = 1.
    mass_matrix_size = 3

    wa_init, wa_update = warmup_adapter(num_steps, find_reasonable_step_size)
    if jitted:
        wa_update = jit(wa_update)

    def run_window(win, state, accept_prob_at, z_at):
        # Feed every step of the window (both bounds inclusive) to the adapter.
        for step in range(win.start, win.end + 1):
            state = wa_update(step, accept_prob_at(step), z_at(step), state)
        return state

    rng = random.PRNGKey(0)
    z = np.ones(3)

    # --- initial state ---
    wa_state = wa_init(z, rng, init_step_size, mass_matrix_size=mass_matrix_size)
    step_size, inverse_mass_matrix, window_idx = unpack(wa_state)
    assert step_size == find_reasonable_step_size(inverse_mass_matrix, z, rng, init_step_size)
    assert_allclose(inverse_mass_matrix, np.ones(mass_matrix_size))
    assert window_idx == 0

    # --- first window: accept_prob ramps up from 0.7, z is held constant ---
    first = schedule[0]
    wa_state = run_window(
        first,
        wa_state,
        lambda t: 0.7 + 0.1 * t / (first.end - first.start),
        lambda t: z,
    )
    last_step_size = step_size
    step_size, inverse_mass_matrix, window_idx = unpack(wa_state)
    assert window_idx == 1
    # accept_prob stayed below target_accept_prob, so the step size shrinks
    assert step_size < last_step_size
    # the inverse mass matrix is untouched at the end of the first window
    assert_allclose(inverse_mass_matrix, np.ones(mass_matrix_size))

    # --- second window: accept_prob ramps up from 0.8, z is held at 2 * z ---
    second = schedule[1]
    window_len = second.end - second.start
    wa_state = run_window(
        second,
        wa_state,
        lambda t: 0.8 + 0.1 * (t - second.start) / window_len,
        lambda t: 2 * z,
    )
    last_step_size = step_size
    step_size, inverse_mass_matrix, window_idx = unpack(wa_state)
    assert window_idx == 2
    # accept_prob stayed above target_accept_prob, so the step size grows
    assert step_size > last_step_size
    # The inverse mass matrix must change at the end of the second window.
    # z_flat was constant throughout it, so the covariance is 0 and only the
    # regularize_term of the welford scheme contributes. This also shows that
    # the z_flat values seen in the first window do not leak into the second.
    num_window_steps = second.end + 1 - second.start
    welford_regularize_term = 1e-3 * (5 / (num_window_steps + 5))
    expected_inverse_mass_matrix = np.full((mass_matrix_size,), welford_regularize_term)
    assert_allclose(inverse_mass_matrix, expected_inverse_mass_matrix, atol=1e-7)

    # --- third (last) window: accept_prob fixed at 0.8, z scales with t ---
    third = schedule[2]
    wa_state = run_window(third, wa_state, lambda t: 0.8, lambda t: t * z)
    last_step_size = step_size
    step_size, final_inverse_mass_matrix, window_idx = unpack(wa_state)
    assert window_idx == 3
    # with accept_prob equal to target_accept_prob=0.8 in the last window,
    # log_step_size equals the constant prox_center=log(10*last_step_size)
    assert_allclose(step_size, last_step_size * 10)
    # the inverse mass matrix stays fixed during the last window even though
    # z_flat varies with t
    assert_allclose(final_inverse_mass_matrix, inverse_mass_matrix)
Example #2
0
def test_build_adaptation_schedule(num_steps, expected):
    """Check that the schedule built for ``num_steps`` matches ``expected``.

    :param num_steps: value passed to ``build_adaptation_schedule``.
    :param expected: iterable of two-item pairs, each turned into an
        ``AdaptWindow`` for the comparison.
    """
    actual_windows = build_adaptation_schedule(num_steps)
    wanted_windows = [AdaptWindow(first, last) for first, last in expected]
    assert actual_windows == wanted_windows