def test_gaussian_euclidean_ndim_invalid(shape):
    """Test Gaussian Euclidean Function returns correct function invalid ndim.

    An array of ones with the given ``shape`` is passed to
    ``metrics.gaussian_euclidean`` as the inverse mass matrix. The call must
    raise a ``ValueError`` whose message reports the wrong number of
    dimensions. ``shape`` is presumably supplied by a parametrize decorator
    with ranks the metric does not accept; confirm against the decorator.
    """
    x = jnp.ones(shape=shape)
    with pytest.raises(ValueError) as e:
        metrics.gaussian_euclidean(x)
    # Check the raised exception itself (`e.value`), not the string form of
    # the `ExceptionInfo` wrapper. The assertion is placed after the `with`
    # block so that it runs once the exception has been captured.
    assert "The mass matrix has the wrong number of dimensions" in str(e.value)
def kernel(potential_fn: Callable, parameters: HMCParameters) -> Callable:
    """Assemble a Hamiltonian Monte Carlo transition kernel.

    Parameters
    ----------
    potential_fn
        Callable returning the potential energy of the chain at a given
        position.
    parameters
        NamedTuple holding, in this order, the step size, the number of
        integration steps, the inverse mass matrix and the divergence
        threshold.

    Returns
    -------
    The kernel returned by ``base.hmc``. It is built from the proposal made
    by ``proposals.hmc``, the momentum generator, the kinetic energy function
    and the divergence threshold.

    Raises
    ------
    ValueError
        If the inverse mass matrix in ``parameters`` is ``None``.
    """
    (
        step_size,
        num_integration_steps,
        inv_mass_matrix,
        divergence_threshold,
    ) = parameters

    if inv_mass_matrix is None:
        raise ValueError(
            "Expected a value for `inv_mass_matrix`,"
            " got None. Please specify a value when initializing"
            " the parameters or run the window adaptation."
        )

    # The metric is unpacked here as a (momentum generator, kinetic energy)
    # pair.
    momentum_generator, kinetic_energy_fn = metrics.gaussian_euclidean(
        inv_mass_matrix
    )
    hmc_proposal = proposals.hmc(
        integrators.velocity_verlet(potential_fn, kinetic_energy_fn),
        step_size,
        num_integration_steps,
    )
    return base.hmc(
        hmc_proposal,
        momentum_generator,
        kinetic_energy_fn,
        divergence_threshold,
    )
def kernel(potential_fn: Callable, parameters: NUTSParameters) -> Callable:
    """Assemble an iterative NUTS transition kernel.

    Parameters
    ----------
    potential_fn
        Callable returning the potential energy of the chain at a given
        position, where the potential energy is minus the log-probability.
    parameters
        NamedTuple holding, in this order, the step size, the maximum tree
        depth, the inverse mass matrix and the divergence threshold.

    Returns
    -------
    The kernel returned by ``base.hmc``, combining the momentum generator
    with the proposal generator built by ``iterative_nuts_proposal``.

    Raises
    ------
    ValueError
        If the inverse mass matrix in ``parameters`` is ``None``.
    """
    step_size, max_tree_depth, inv_mass_matrix, divergence_threshold = parameters

    if inv_mass_matrix is None:
        raise ValueError(
            "Expected a value for `inv_mass_matrix`,"
            " got None. Please specify a value when initializing"
            " the parameters or run the window adaptation."
        )

    (
        momentum_generator,
        kinetic_energy_fn,
        uturn_check_fn,
    ) = metrics.gaussian_euclidean(inv_mass_matrix)
    leapfrog = integrators.velocity_verlet(potential_fn, kinetic_energy_fn)
    nuts_proposal = iterative_nuts_proposal(
        leapfrog,
        kinetic_energy_fn,
        uturn_check_fn,
        step_size,
        max_tree_depth,
        divergence_threshold,
    )
    return base.hmc(momentum_generator, nuts_proposal)
def test_dynamic_progressive_integration_divergence(case):
    """Check divergence detection of the progressive trajectory integrator.

    ``case`` unpacks to ``(step_size, should_diverge)``. A single trajectory
    of at most 100 steps is integrated from position 1.0 under a standard
    normal target. The returned divergence flag is compared with
    ``should_diverge``.
    """
    rng_key = jax.random.PRNGKey(0)

    def potential_fn(x):
        # The potential energy is minus the log-density. For a standard normal
        # target this is x**2 / 2 plus a constant. Returning +logpdf would
        # integrate an inverted, unbounded-below potential instead.
        return -jax.scipy.stats.norm.logpdf(x)

    step_size, should_diverge = case
    position = 1.0
    inverse_mass_matrix = jnp.array([1.0])
    momentum_generator, kinetic_energy_fn, uturn_check_fn = metrics.gaussian_euclidean(
        inverse_mass_matrix
    )
    integrator = integrators.velocity_verlet(potential_fn, kinetic_energy_fn)
    (
        new_criterion_state,
        update_criterion_state,
        is_criterion_met,
    ) = termination.iterative_uturn_numpyro(uturn_check_fn)

    # `divergence_threshold` is not defined in this function; it is read from
    # the enclosing module scope.
    trajectory_integrator = trajectory.dynamic_progressive_integration(
        integrator,
        kinetic_energy_fn,
        update_criterion_state,
        is_criterion_met,
        divergence_threshold,
    )

    # Initialize
    direction = 1
    initial_state = integrators.new_integrator_state(
        potential_fn, position, momentum_generator(rng_key, position)
    )
    initial_energy = initial_state.potential_energy + kinetic_energy_fn(
        initial_state.position, initial_state.momentum
    )
    termination_state = new_criterion_state(initial_state, 10)
    max_num_steps = 100

    _, _, _, is_diverging, _, _ = trajectory_integrator(
        rng_key,
        initial_state,
        direction,
        termination_state,
        max_num_steps,
        step_size,
        initial_energy,
    )

    assert is_diverging.item() is should_diverge
def test_dynamic_progressive_expansion(case):
    """Check the multiplicative expansion built on the progressive integrator.

    ``case`` unpacks to ``(step_size, should_diverge, should_turn,
    expected_doublings)``. The expansion starts at position 0.0 under the
    potential ``0.5 * x ** 2``. It returns a divergence flag, a ``step``
    value and a U-turn flag, which are compared with ``should_diverge``,
    ``expected_doublings`` and ``should_turn`` respectively.
    """
    rng_key = jax.random.PRNGKey(0)

    def potential_fn(x):
        return 0.5 * x ** 2

    step_size, should_diverge, should_turn, expected_doublings = case
    position = 0.0
    inverse_mass_matrix = jnp.array([1.0])

    momentum_generator, kinetic_energy_fn, uturn_check_fn = metrics.gaussian_euclidean(
        inverse_mass_matrix
    )
    integrator = integrators.velocity_verlet(potential_fn, kinetic_energy_fn)
    (
        new_criterion_state,
        update_criterion_state,
        is_criterion_met,
    ) = termination.iterative_uturn_numpyro(uturn_check_fn)

    # `divergence_threshold` is not defined in this function; it is read from
    # the enclosing module scope.
    trajectory_integrator = trajectory.dynamic_progressive_integration(
        integrator,
        kinetic_energy_fn,
        update_criterion_state,
        is_criterion_met,
        divergence_threshold,
    )
    expand = trajectory.dynamic_multiplicative_expansion(
        trajectory_integrator, uturn_check_fn, step_size
    )

    state = integrators.new_integrator_state(
        potential_fn, position, momentum_generator(rng_key, position)
    )
    energy = state.potential_energy + kinetic_energy_fn(state.position, state.momentum)
    initial_proposal = proposal.Proposal(state, energy, 0.0)
    initial_termination_state = new_criterion_state(state, 10)

    # The fifth element (termination flag) is not checked by this test.
    _, _, step, is_diverging, _, is_turning = expand(
        rng_key, initial_proposal, initial_termination_state
    )

    assert is_diverging == should_diverge
    assert step == expected_doublings
    assert is_turning == should_turn
def test_gaussian_euclidean_dim_2():
    """Test Gaussian Euclidean Function with ndim 2.

    The inverse mass matrix is diag(1/9, 1/4). The sampled momentum is
    compared with reference values scaled by (3, 2). The kinetic energy is
    compared with 0.5 * (M^{-1} p) . p.
    """
    inverse_mass_matrix = jnp.asarray([[1 / 9, 0], [0, 1 / 4]], dtype=DTYPE)
    momentum, kinetic_energy, _ = metrics.gaussian_euclidean(inverse_mass_matrix)

    arbitrary_position = jnp.asarray([12345, 23456], dtype=DTYPE)
    momentum_val = momentum(KEY, arbitrary_position)

    # 3 and 2 are the inverse square roots of 1/9 and 1/4. The second factor
    # presumably holds the standard-normal values drawn with KEY (recorded
    # from a previous run; confirm if KEY changes).
    expected_momentum_val = jnp.asarray([3, 2]) * jnp.asarray([-0.784766, 0.8564448])

    kinetic_energy_val = kinetic_energy(momentum_val)
    velocity = jnp.dot(inverse_mass_matrix, momentum_val)
    expected_kinetic_energy_val = 0.5 * jnp.matmul(velocity, momentum_val)

    # `assert pytest.approx(a, b)` is always true because the second
    # positional argument is `rel`, so the comparison is written out
    # explicitly. The reference draws are quoted to about 7 significant
    # digits, hence the relative tolerance.
    assert momentum_val == pytest.approx(expected_momentum_val, rel=1e-5)
    assert kinetic_energy_val == expected_kinetic_energy_val
def test_gaussian_euclidean_dim_1():
    """Test Gaussian Euclidean Function with ndim 1.

    The inverse mass matrix is the vector [1/4]. The sampled momentum must
    equal 2 times the reference draw for ``KEY``. The kinetic energy must
    equal 0.5 * (M^{-1} p) * p.
    """
    inv_mass = jnp.asarray([1 / 4], dtype=DTYPE)
    sample_momentum, kinetic_energy_fn, _ = metrics.gaussian_euclidean(inv_mass)

    position = jnp.asarray([12345], dtype=DTYPE)
    p = sample_momentum(KEY, position)

    # 2 is the inverse square root of 1/4; -0.20584235 is the value the
    # original comment attributes to the draw with KEY.
    expected_p = 2 * -0.20584235

    kinetic = kinetic_energy_fn(p)
    expected_kinetic = 0.5 * (inv_mass * p) * p

    assert p == expected_p
    assert kinetic == expected_kinetic
def test_is_iterative_turning(checkpoint_idxs, expected_turning):
    """Check the iterative U-turn criterion against fixed checkpoints.

    Four momentum / momentum-sum checkpoints are stored in an
    ``IterativeUTurnState`` together with the ``(idx_min, idx_max)`` pair from
    ``checkpoint_idxs``. The criterion is evaluated for momentum 1.0 and
    momentum sum 3.0, and the result is compared with ``expected_turning``.
    """
    _, _, uturn_check = gaussian_euclidean(jnp.ones(1))
    _, _, is_iterative_turning = iterative_uturn_numpyro(uturn_check)

    idx_min, idx_max = checkpoint_idxs
    checkpoints = IterativeUTurnState(
        jnp.array([1.0, 2.0, 3.0, -2.0]),
        jnp.array([2.0, 4.0, 4.0, -1.0]),
        idx_min,
        idx_max,
    )

    current_momentum, current_momentum_sum = 1.0, 3.0
    actual_turning = is_iterative_turning(
        checkpoints, current_momentum_sum, current_momentum
    )
    assert expected_turning == actual_turning
def kernel(potential_fn: Callable, parameters: HMCParameters):
    """Assemble a Hamiltonian Monte Carlo transition kernel.

    Parameters
    ----------
    potential_fn
        Callable returning the potential energy of the chain at a given
        position.
    parameters
        NamedTuple holding, in this order, the step size, the number of
        integration steps, the inverse mass matrix and the divergence
        threshold.

    Returns
    -------
    A kernel that takes a rng_key and a Pytree that contains the current state
    of the chain and that returns a new state of the chain along with
    information about the transition. It is obtained from ``base.hmc`` applied
    to the momentum generator and the ``hmc_proposal`` generator.

    Raises
    ------
    ValueError
        If the inverse mass matrix in ``parameters`` is ``None``.
    """
    step_size, num_integration_steps, inv_mass_matrix, divergence_threshold = parameters

    if inv_mass_matrix is None:
        raise ValueError(
            "Expected a value for `inv_mass_matrix`,"
            " got None. Please specify a value when initializing"
            " the parameters or run the window adaptation."
        )

    # The third output of the metric (the U-turn check) is not needed for
    # fixed-trajectory HMC.
    momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean(
        inv_mass_matrix
    )
    proposal_generator = hmc_proposal(
        integrators.velocity_verlet(potential_fn, kinetic_energy_fn),
        kinetic_energy_fn,
        step_size,
        num_integration_steps,
        divergence_threshold,
    )
    return base.hmc(momentum_generator, proposal_generator)
def test_dynamic_progressive_equal_recursive():
    """Check that the progressive and recursive trajectory builders agree.

    The test makes 50 random draws of direction, tree depth, step size,
    initial position and initial momentum on a Rosenbrock-shaped potential.
    For each draw, both integrators are run with the same PRNG key. Their
    trajectories, divergence flags, termination flags, proposal weights and
    summed acceptance log-probabilities are then compared.
    """
    rng_key = jax.random.PRNGKey(23132)

    def potential_fn(x):
        return (1.0 - x[0]) ** 2 + 1.5 * (x[1] - x[0] ** 2) ** 2

    inverse_mass_matrix = jnp.asarray([[1.0, 0.5], [0.5, 1.25]])
    # The momentum generator is not used: momenta are drawn directly below.
    _, kinetic_energy_fn, uturn_check_fn = metrics.gaussian_euclidean(
        inverse_mass_matrix
    )
    integrator = integrators.velocity_verlet(potential_fn, kinetic_energy_fn)
    (
        new_criterion_state,
        update_criterion_state,
        is_criterion_met,
    ) = termination.iterative_uturn_numpyro(uturn_check_fn)

    # JIT-compile the building blocks once, presumably so that the 50
    # iterations below stay fast.
    (
        integrator,
        kinetic_energy_fn,
        update_criterion_state,
        is_criterion_met,
        uturn_check_fn,
    ) = [
        jax.jit(x)
        for x in (
            integrator,
            kinetic_energy_fn,
            update_criterion_state,
            is_criterion_met,
            uturn_check_fn,
        )
    ]

    # `divergence_threshold` is not defined in this function; it is read from
    # the enclosing module scope.
    trajectory_integrator = trajectory.dynamic_progressive_integration(
        integrator,
        kinetic_energy_fn,
        update_criterion_state,
        is_criterion_met,
        divergence_threshold,
    )
    buildtree_integrator = trajectory.dynamic_recursive_integration(
        integrator,
        kinetic_energy_fn,
        uturn_check_fn,
        divergence_threshold,
    )

    for _ in range(50):
        (
            rng_key,
            rng_direction,
            rng_tree_depth,
            rng_step_size,
            rng_position,
            rng_momentum,
        ) = jax.random.split(rng_key, 6)
        direction = jax.random.choice(rng_direction, jnp.array([-1, 1]))
        tree_depth = jax.random.choice(rng_tree_depth, np.arange(2, 5))
        initial_state = integrators.new_integrator_state(
            potential_fn,
            jax.random.normal(rng_position, [2]),
            jax.random.normal(rng_momentum, [2]),
        )
        step_size = jnp.abs(jax.random.normal(rng_step_size, [])) * 0.1
        initial_energy = initial_state.potential_energy + kinetic_energy_fn(
            initial_state.position, initial_state.momentum
        )

        termination_state = new_criterion_state(initial_state, tree_depth)
        (
            proposal0,
            trajectory0,
            _,
            is_diverging0,
            has_terminated0,
            _,
        ) = trajectory_integrator(
            rng_key,
            initial_state,
            direction,
            termination_state,
            2**tree_depth,
            step_size,
            initial_energy,
        )

        (
            _,
            proposal1,
            trajectory1,
            is_diverging1,
            has_terminated1,
        ) = buildtree_integrator(
            rng_key,
            initial_state,
            direction,
            tree_depth,
            step_size,
            initial_energy,
        )

        # Assert that the trajectory being built is the same.
        # `jax.tree_multimap` was removed from JAX; `jax.tree_util.tree_map`
        # accepts several trees with the same structure.
        jax.tree_util.tree_map(
            functools.partial(np.testing.assert_allclose, rtol=1e-5),
            trajectory0,
            trajectory1,
        )
        assert is_diverging0 == is_diverging1
        assert has_terminated0 == has_terminated1

        # We don't expect the proposals to be the same, even with the same
        # PRNGKey, because the order of selection differs. However, the
        # properties associated with the full trajectory should be the same.
        np.testing.assert_allclose(proposal0.weight, proposal1.weight, rtol=1e-5)
        np.testing.assert_allclose(
            proposal0.sum_log_p_accept, proposal1.sum_log_p_accept, rtol=1e-5
        )