Beispiel #1
0
def test_warmup_adapter(jitted):
    """Drive ``warmup_adapter`` through the first three schedule windows.

    The initial state returned by ``wa_init`` is checked first. Then each of
    the windows ``adaptation_schedule[0..2]`` is replayed step by step and
    the resulting window index, step-size change and inverse mass matrix are
    asserted.

    :param jitted: when truthy, the update function is wrapped in ``jit``
        before being used.
    """

    def find_reasonable_step_size(step_size, m_inv, z, rng_key):
        # Deterministic stand-in: multiply step sizes below 1 by 4, divide
        # all others by 4.
        return jnp.where(step_size < 1, step_size * 4, step_size / 4)

    def unpack(state):
        # Keep only the three fields of the 7-field state this test inspects.
        step, inv_mass, _, _, _, idx, _ = state
        return step, inv_mass, idx

    total_steps = 150
    mass_matrix_size = 3
    init_step_size = 1.
    schedule = build_adaptation_schedule(total_steps)

    wa_init, wa_update = warmup_adapter(total_steps, find_reasonable_step_size)
    if jitted:
        wa_update = jit(wa_update)

    def run_window(win, state, accept_prob_at, position_at):
        # Feed every step of ``win`` (end inclusive) through the update
        # function, threading the adapter state along.
        for step in range(win.start, win.end + 1):
            state = wa_update(step, accept_prob_at(step), position_at(step),
                              state)
        return state

    rng_key = random.PRNGKey(0)
    z = jnp.ones(3)
    wa_state = wa_init((z, None, None, None),
                       rng_key,
                       init_step_size,
                       mass_matrix_size=mass_matrix_size)
    step_size, inverse_mass_matrix, window_idx = unpack(wa_state)
    expected_initial = find_reasonable_step_size(init_step_size,
                                                 inverse_mass_matrix, z,
                                                 rng_key)
    assert step_size == expected_initial
    assert_allclose(inverse_mass_matrix, jnp.ones(mass_matrix_size))
    assert window_idx == 0

    # --- first window: acceptance probability ramps up from 0.7 ---
    first = schedule[0]
    first_span = first.end - first.start
    previous_step_size = step_size
    wa_state = run_window(first, wa_state,
                          lambda t: 0.7 + 0.1 * t / first_span,
                          lambda t: z)
    step_size, inverse_mass_matrix, window_idx = unpack(wa_state)
    assert window_idx == 1
    # accept_prob < target_accept_prob, so the step size must shrink
    assert step_size < previous_step_size
    # the inverse mass matrix is untouched at the end of the first window
    assert_allclose(inverse_mass_matrix, jnp.ones(mass_matrix_size))

    # --- second window: acceptance probability ramps up from 0.8 ---
    second = schedule[1]
    second_span = second.end - second.start
    previous_step_size = step_size
    wa_state = run_window(
        second, wa_state,
        lambda t: 0.8 + 0.1 * (t - second.start) / second_span,
        lambda t: 2 * z)
    step_size, inverse_mass_matrix, window_idx = unpack(wa_state)
    assert window_idx == 2
    # accept_prob > target_accept_prob, so the step size must grow
    assert step_size > previous_step_size
    # The inverse mass matrix is updated at the end of the second window.
    # z_flat is constant throughout this window, so the covariance is 0 and
    # only the regularization term of the welford scheme remains. This also
    # shows that z_flat values seen in the first window do not leak into the
    # second one.
    samples_in_window = second.end + 1 - second.start
    welford_regularize_term = 1e-3 * (5 / (samples_in_window + 5))
    assert_allclose(inverse_mass_matrix,
                    jnp.full((mass_matrix_size, ), welford_regularize_term),
                    atol=1e-7)

    # --- third window: constant acceptance probability, moving z ---
    third = schedule[2]
    previous_step_size = step_size
    wa_state = run_window(third, wa_state,
                          lambda t: 0.8,
                          lambda t: t * z)
    step_size, final_inverse_mass_matrix, window_idx = unpack(wa_state)
    assert window_idx == 3
    # In the last window, with target_accept_prob=0.8, log_step_size equals
    # the constant prox_center=log(10*last_step_size).
    assert_allclose(step_size, previous_step_size * 10, atol=1e-6)
    # The inverse mass matrix stays fixed during the last window even though
    # z_flat varies with time t.
    assert_allclose(final_inverse_mass_matrix, inverse_mass_matrix)
Beispiel #2
0
def test_build_adaptation_schedule(num_steps, expected):
    """Check the schedule built for ``num_steps`` against ``expected``.

    :param num_steps: value passed to ``build_adaptation_schedule``.
    :param expected: iterable of two-item pairs; each pair is turned into an
        ``AdaptWindow`` and the resulting list must equal the built schedule.
    """
    built = build_adaptation_schedule(num_steps)
    wanted = [AdaptWindow(first, last) for first, last in expected]
    assert built == wanted