Example #1
0
def test_hmc(inv_mass_matrix):
    """Sanity-check that the HMC kernel recovers a simple linear regression.

    Data are generated as ``preds = 3 * x + noise`` with standard-normal
    noise. The test expects the post-burn-in sample means of ``coefs`` and
    ``scale`` to be within 10% (relative) of 3 and 1 respectively.

    Args:
        inv_mass_matrix: Inverse mass matrix forwarded to
            ``hmc.HMCParameters``; presumably supplied by a pytest
            fixture or parametrization — TODO confirm.
    """
    # Seed the synthetic data so the test is reproducible. The sampler below
    # already uses a fixed JAX key, but unseeded NumPy data made the
    # statistical assertions nondeterministic (flaky).
    data_rng = np.random.default_rng(0)
    x_data = data_rng.normal(0, 1, size=(1000, 1))
    y_data = 3 * x_data + data_rng.normal(size=x_data.shape)
    observations = {"x": x_data, "preds": y_data}

    conditioned_potential = ft.partial(potential_fn, **observations)

    def potential(position):
        # Positions are dicts (see ``initial_position``); unpack them into
        # keyword arguments for the data-conditioned potential.
        return conditioned_potential(**position)

    initial_position = {"scale": 0.5, "coefs": 2.0}
    initial_state = hmc.new_state(initial_position, potential)

    params = hmc.HMCParameters(
        num_integration_steps=90,
        step_size=1e-3,
        inv_mass_matrix=inv_mass_matrix,
    )
    kernel = hmc.kernel(potential, params)

    rng_key = jax.random.PRNGKey(19)
    states = inference_loop(rng_key, kernel, initial_state, 20_000)

    # Discard the first draws as burn-in before computing the means.
    burn_in = 5_000
    coefs_samples = states.position["coefs"][burn_in:]
    scale_samples = states.position["scale"][burn_in:]

    assert np.mean(scale_samples) == pytest.approx(1, rel=1e-1)
    assert np.mean(coefs_samples) == pytest.approx(3, rel=1e-1)
Example #2
0
def test_nuts():
    """Check how many times the jitted NUTS kernel traces ``potential``.

    ``GLOBAL["count"]`` is presumably incremented as a Python side effect
    inside ``potential`` (defined elsewhere — TODO confirm). Under
    ``jax.jit`` such side effects run only while tracing, so the count
    reflects traces rather than the 10 kernel calls made below.
    """
    rng_key = jax.random.PRNGKey(0)
    state = hmc.new_state(1.0, potential)
    params = hmc.HMCParameters(inv_mass_matrix=jnp.array([1.0]))

    # Reset after ``new_state`` so only calls made through the jitted kernel
    # are counted.
    GLOBAL["count"] = 0
    kernel = jax.jit(nuts.kernel(potential, params))

    for _ in range(10):
        # Keep one half of the split as the carry and hand the other to the
        # kernel, so no key is both consumed and split again.
        rng_key, step_key = jax.random.split(rng_key)
        state, _ = kernel(step_key, state)

    assert GLOBAL["count"] == 1
Example #3
0
def test_hmc():
    """Check that the jitted HMC kernel traces ``potential`` only once.

    ``jax.jit`` traces the wrapped Python function on its first call and
    reuses the compiled computation for later calls with the same input
    shapes and dtypes. Python side effects therefore execute only at trace
    time. ``GLOBAL["count"]`` is presumably incremented by such a side
    effect inside ``potential`` (defined elsewhere — TODO confirm), which is
    why a single trace is expected after 10 kernel calls.
    """
    rng_key = jax.random.PRNGKey(0)
    state = hmc.new_state(1.0, potential)
    params = hmc.HMCParameters(inv_mass_matrix=jnp.array([1.0]))

    # Reset after ``new_state`` so only calls made through the jitted kernel
    # are counted.
    GLOBAL["count"] = 0
    kernel = jax.jit(hmc.kernel(potential, params))

    for _ in range(10):
        # Keep one half of the split as the carry and hand the other to the
        # kernel, so no key is both consumed and split again.
        rng_key, step_key = jax.random.split(rng_key)
        state, _ = kernel(step_key, state)

    assert GLOBAL["count"] == 1
Example #4
0
def test_nuts():
    """Check how many times the jitted NUTS kernel traces ``potential``.

    ``GLOBAL["count"]`` is presumably incremented as a Python side effect
    inside ``potential`` (defined elsewhere — TODO confirm). Under
    ``jax.jit`` such side effects run only while tracing, so the count
    reflects traces rather than the 10 kernel calls made below.
    """
    rng_key = jax.random.PRNGKey(0)
    state = hmc.new_state(1.0, potential)
    params = hmc.HMCParameters(inv_mass_matrix=jnp.array([1.0]))

    # Reset after ``new_state`` so only calls made through the jitted kernel
    # are counted.
    GLOBAL["count"] = 0
    kernel = jax.jit(nuts.kernel(potential, params))

    for _ in range(10):
        # Keep one half of the split as the carry and hand the other to the
        # kernel, so no key is both consumed and split again.
        rng_key, step_key = jax.random.split(rng_key)
        state, _ = kernel(step_key, state)

    # Potential function was traced twice as we call potential function
    # at Step 0 when building a new trajectory in tree doubling.
    assert GLOBAL["count"] == 2