def test_hmc(inv_mass_matrix):
    """Sanity-check the HMC kernel on a small synthetic data set.

    The data is generated as ``y = 3 * x + noise`` with standard-normal noise.
    After sampling 20,000 states and discarding the first 5,000, the mean of
    the ``"coefs"`` samples must be within 10% of 3 and the mean of the
    ``"scale"`` samples within 10% of 1.

    ``potential_fn`` and ``inference_loop`` are defined elsewhere in the file;
    presumably ``potential_fn`` is the potential of a linear regression model
    taking ``x``, ``preds``, ``scale`` and ``coefs`` keyword arguments --
    confirm against its definition.

    Parameters
    ----------
    inv_mass_matrix
        Forwarded unchanged to ``hmc.HMCParameters``. Presumably supplied by a
        pytest fixture or parametrization -- confirm at the decorator/fixture.
    """
    # NOTE(review): another ``test_hmc`` is defined further down in this
    # source; if both live in the same module the later one shadows this one
    # and pytest will never collect it -- confirm and rename if so.

    # Use a locally seeded generator so the synthetic data (and therefore the
    # test outcome) is reproducible, matching the fixed PRNGKey used below.
    data_rng = np.random.default_rng(0)
    x_data = data_rng.normal(0, 1, size=(1000, 1))
    y_data = 3 * x_data + data_rng.normal(size=x_data.shape)
    observations = {"x": x_data, "preds": y_data}

    # Bind the observed data so only the sampled variables remain free.
    conditioned_potential = ft.partial(potential_fn, **observations)

    def potential(position):
        # The position is a dict; unpack it into keyword arguments.
        return conditioned_potential(**position)

    initial_position = {"scale": 0.5, "coefs": 2.0}
    initial_state = hmc.new_state(initial_position, potential)

    params = hmc.HMCParameters(
        num_integration_steps=90,
        step_size=1e-3,
        inv_mass_matrix=inv_mass_matrix,
    )
    kernel = hmc.kernel(potential, params)

    rng_key = jax.random.PRNGKey(19)
    states = inference_loop(rng_key, kernel, initial_state, 20_000)

    # Drop the first 5,000 samples before computing the summary statistics.
    coefs_samples = states.position["coefs"][5000:]
    scale_samples = states.position["scale"][5000:]

    assert np.mean(scale_samples) == pytest.approx(1, rel=1e-1)
    assert np.mean(coefs_samples) == pytest.approx(3, rel=1e-1)
def test_nuts():
    """Count potential evaluations made through a jitted NUTS kernel.

    Ten transitions are run through the ``jax.jit``-compiled kernel, after
    which ``GLOBAL["count"]`` is expected to equal 1.

    ``potential`` and ``GLOBAL`` are defined elsewhere in the file; presumably
    ``potential`` increments ``GLOBAL["count"]`` each time its Python body
    executes -- confirm against its definition.
    """
    # NOTE(review): another ``test_nuts`` is defined further down in this
    # source and asserts a different count; if both live in the same module
    # the later one shadows this one -- confirm which is intended.
    key = jax.random.PRNGKey(0)
    chain_state = hmc.new_state(1.0, potential)
    parameters = hmc.HMCParameters(inv_mass_matrix=jnp.array([1.0]))

    # Reset the counter only after the initial state has been built, so that
    # nothing recorded before this point contributes to the assertion.
    GLOBAL["count"] = 0
    nuts_kernel = nuts.kernel(potential, parameters)
    step = jax.jit(nuts_kernel)

    num_steps = 10
    for _ in range(num_steps):
        # Keep the second half of the split as the key for this step.
        _, key = jax.random.split(key)
        chain_state, _ = step(key, chain_state)

    assert GLOBAL["count"] == 1
def test_hmc():
    """Count potential evaluations made through a jitted HMC kernel.

    Ten transitions are run through the ``jax.jit``-compiled kernel, after
    which ``GLOBAL["count"]`` is expected to equal 1. Presumably this holds
    because JAX executes the Python body of the potential only once, while
    tracing the kernel for compilation -- this explanation was left as an
    open question by the original author.

    ``potential`` and ``GLOBAL`` are defined elsewhere in the file; presumably
    ``potential`` increments ``GLOBAL["count"]`` each time its Python body
    executes -- confirm against its definition.
    """
    key = jax.random.PRNGKey(0)
    chain_state = hmc.new_state(1.0, potential)
    parameters = hmc.HMCParameters(inv_mass_matrix=jnp.array([1.0]))

    # Reset the counter only after the initial state has been built, so that
    # nothing recorded before this point contributes to the assertion.
    GLOBAL["count"] = 0
    hmc_kernel = hmc.kernel(potential, parameters)
    step = jax.jit(hmc_kernel)

    num_steps = 10
    for _ in range(num_steps):
        # Keep the second half of the split as the key for this step.
        _, key = jax.random.split(key)
        chain_state, _ = step(key, chain_state)

    assert GLOBAL["count"] == 1
def test_nuts():
    """Count potential evaluations made through a jitted NUTS kernel.

    Ten transitions are run through the ``jax.jit``-compiled kernel, after
    which ``GLOBAL["count"]`` is expected to equal 2.

    ``potential`` and ``GLOBAL`` are defined elsewhere in the file; presumably
    ``potential`` increments ``GLOBAL["count"]`` each time its Python body
    executes -- confirm against its definition.
    """
    key = jax.random.PRNGKey(0)
    chain_state = hmc.new_state(1.0, potential)
    parameters = hmc.HMCParameters(inv_mass_matrix=jnp.array([1.0]))

    # Reset the counter only after the initial state has been built, so that
    # nothing recorded before this point contributes to the assertion.
    GLOBAL["count"] = 0
    nuts_kernel = nuts.kernel(potential, parameters)
    step = jax.jit(nuts_kernel)

    num_steps = 10
    for _ in range(num_steps):
        # Keep the second half of the split as the key for this step.
        _, key = jax.random.split(key)
        chain_state, _ = step(key, chain_state)

    # Two traces of the potential are expected rather than one: it is also
    # called at step 0 when a new trajectory is built during tree doubling.
    assert GLOBAL["count"] == 2