def test_dynamic_progressive_integration_divergence(case):
    rng_key = jax.random.PRNGKey(0)

    def potential_fn(x):
        return jax.scipy.stats.norm.logpdf(x)

    step_size, should_diverge = case
    position = 1.0
    inverse_mass_matrix = jnp.array([1.0])

    momentum_generator, kinetic_energy_fn, uturn_check_fn = metrics.gaussian_euclidean(
        inverse_mass_matrix
    )
    integrator = integrators.velocity_verlet(potential_fn, kinetic_energy_fn)
    (
        new_criterion_state,
        update_criterion_state,
        is_criterion_met,
    ) = termination.iterative_uturn_numpyro(uturn_check_fn)
    trajectory_integrator = trajectory.dynamic_progressive_integration(
        integrator,
        kinetic_energy_fn,
        update_criterion_state,
        is_criterion_met,
        divergence_threshold,
    )

    # Initialize
    direction = 1
    initial_state = integrators.new_integrator_state(
        potential_fn, position, momentum_generator(rng_key, position)
    )
    initial_energy = initial_state.potential_energy + kinetic_energy_fn(
        initial_state.position, initial_state.momentum
    )
    termination_state = new_criterion_state(initial_state, 10)
    max_num_steps = 100

    _, _, _, is_diverging, _, _ = trajectory_integrator(
        rng_key,
        initial_state,
        direction,
        termination_state,
        max_num_steps,
        step_size,
        initial_energy,
    )

    assert is_diverging.item() is should_diverge
def test_dynamic_progressive_expansion(case):
    rng_key = jax.random.PRNGKey(0)

    def potential_fn(x):
        return 0.5 * x**2

    step_size, should_diverge, should_turn, expected_doublings = case
    position = 0.0
    inverse_mass_matrix = jnp.array([1.0])

    momentum_generator, kinetic_energy_fn, uturn_check_fn = metrics.gaussian_euclidean(
        inverse_mass_matrix
    )
    integrator = integrators.velocity_verlet(potential_fn, kinetic_energy_fn)
    (
        new_criterion_state,
        update_criterion_state,
        is_criterion_met,
    ) = termination.iterative_uturn_numpyro(uturn_check_fn)
    trajectory_integrator = trajectory.dynamic_progressive_integration(
        integrator,
        kinetic_energy_fn,
        update_criterion_state,
        is_criterion_met,
        divergence_threshold,
    )
    expand = trajectory.dynamic_multiplicative_expansion(
        trajectory_integrator, uturn_check_fn, step_size
    )

    state = integrators.new_integrator_state(
        potential_fn, position, momentum_generator(rng_key, position)
    )
    energy = state.potential_energy + kinetic_energy_fn(state.position, state.momentum)
    initial_proposal = proposal.Proposal(state, energy, 0.0)
    initial_termination_state = new_criterion_state(state, 10)

    _, _, step, is_diverging, has_terminated, is_turning = expand(
        rng_key, initial_proposal, initial_termination_state
    )

    assert is_diverging == should_diverge
    assert step == expected_doublings
    assert is_turning == should_turn
def test_is_iterative_turning(checkpoint_idxs, expected_turning):
    inverse_mass_matrix = jnp.ones(1)
    _, _, is_turning = gaussian_euclidean(inverse_mass_matrix)
    _, _, is_iterative_turning = iterative_uturn_numpyro(is_turning)

    momentum = 1.0
    momentum_sum = 3.0

    idx_min, idx_max = checkpoint_idxs
    momentum_ckpts = jnp.array([1.0, 2.0, 3.0, -2.0])
    momentum_sum_ckpts = jnp.array([2.0, 4.0, 4.0, -1.0])
    checkpoints = IterativeUTurnState(
        momentum_ckpts,
        momentum_sum_ckpts,
        idx_min,
        idx_max,
    )

    actual_turning = is_iterative_turning(checkpoints, momentum_sum, momentum)
    assert expected_turning == actual_turning
def test_dynamic_progressive_equal_recursive():
    rng_key = jax.random.PRNGKey(23132)

    def potential_fn(x):
        return (1.0 - x[0]) ** 2 + 1.5 * (x[1] - x[0] ** 2) ** 2

    inverse_mass_matrix = jnp.asarray([[1.0, 0.5], [0.5, 1.25]])
    momentum_generator, kinetic_energy_fn, uturn_check_fn = metrics.gaussian_euclidean(
        inverse_mass_matrix
    )
    integrator = integrators.velocity_verlet(potential_fn, kinetic_energy_fn)
    (
        new_criterion_state,
        update_criterion_state,
        is_criterion_met,
    ) = termination.iterative_uturn_numpyro(uturn_check_fn)

    (
        integrator,
        kinetic_energy_fn,
        update_criterion_state,
        is_criterion_met,
        uturn_check_fn,
    ) = [
        jax.jit(x)
        for x in (
            integrator,
            kinetic_energy_fn,
            update_criterion_state,
            is_criterion_met,
            uturn_check_fn,
        )
    ]

    trajectory_integrator = trajectory.dynamic_progressive_integration(
        integrator,
        kinetic_energy_fn,
        update_criterion_state,
        is_criterion_met,
        divergence_threshold,
    )

    buildtree_integrator = trajectory.dynamic_recursive_integration(
        integrator,
        kinetic_energy_fn,
        uturn_check_fn,
        divergence_threshold,
    )

    for _ in range(50):
        (
            rng_key,
            rng_direction,
            rng_tree_depth,
            rng_step_size,
            rng_position,
            rng_momentum,
        ) = jax.random.split(rng_key, 6)
        direction = jax.random.choice(rng_direction, jnp.array([-1, 1]))
        tree_depth = jax.random.choice(rng_tree_depth, np.arange(2, 5))
        initial_state = integrators.new_integrator_state(
            potential_fn,
            jax.random.normal(rng_position, [2]),
            jax.random.normal(rng_momentum, [2]),
        )
        step_size = jnp.abs(jax.random.normal(rng_step_size, [])) * 0.1
        initial_energy = initial_state.potential_energy + kinetic_energy_fn(
            initial_state.position, initial_state.momentum
        )
        termination_state = new_criterion_state(initial_state, tree_depth)

        (
            proposal0,
            trajectory0,
            _,
            is_diverging0,
            has_terminated0,
            _,
        ) = trajectory_integrator(
            rng_key,
            initial_state,
            direction,
            termination_state,
            2**tree_depth,
            step_size,
            initial_energy,
        )
        (
            _,
            proposal1,
            trajectory1,
            is_diverging1,
            has_terminated1,
        ) = buildtree_integrator(
            rng_key,
            initial_state,
            direction,
            tree_depth,
            step_size,
            initial_energy,
        )

        # Assert that the trajectories being built are the same.
        jax.tree_multimap(
            functools.partial(np.testing.assert_allclose, rtol=1e-5),
            trajectory0,
            trajectory1,
        )
        assert is_diverging0 == is_diverging1
        assert has_terminated0 == has_terminated1

        # We don't expect the proposals to be identical (even with the same
        # PRNGKey, since the order of selection is different), but the
        # properties associated with the full trajectory should be the same.
        np.testing.assert_allclose(proposal0.weight, proposal1.weight, rtol=1e-5)
        np.testing.assert_allclose(
            proposal0.sum_log_p_accept, proposal1.sum_log_p_accept, rtol=1e-5
        )
def iterative_nuts_proposal(
    integrator: Callable,
    kinetic_energy: Callable,
    uturn_check_fn: Callable,
    step_size: float,
    max_num_expansions: int = 10,
    divergence_threshold: float = 1000,
) -> Callable:
    """Iterative NUTS algorithm.

    This algorithm is an iteration on the original NUTS algorithm [1]_ with two
    major differences:

    - We do not use slice sampling but multinomial sampling for the proposal [2]_;
    - The trajectory expansion is not recursive but iterative [3,4]_.

    The implementation can seem unusual to those familiar with similar
    algorithms. Indeed, we do not conceptualize the trajectory construction as
    building a tree. We feel that the tree lingo, inherited from the recursive
    version, is unnecessarily complicated and hides the more general concepts
    on which the NUTS algorithm is built.

    NUTS, in essence, consists in sampling a trajectory by iteratively choosing
    a direction at random and integrating in this direction a number of times
    that doubles at every step. From this trajectory we continuously sample a
    proposal. When the trajectory turns on itself, or when we have reached the
    maximum trajectory length, we return the current proposal.

    Parameters
    ----------
    integrator
        Symplectic integrator used to build the trajectory step by step.
    kinetic_energy
        Function that computes the kinetic energy.
    uturn_check_fn
        Function that determines whether the trajectory is turning on itself
        (metric-dependent).
    step_size
        Size of the integration step.
    max_num_expansions
        The number of sub-trajectory samples we take to build the trajectory.
    divergence_threshold
        Threshold above which we say that there is a divergence.

    Returns
    -------
    A kernel that generates a new chain state and information about the
    transition.

    References
    ----------
    .. [1]: Hoffman, Matthew D., and Andrew Gelman. "The No-U-Turn sampler:
            adaptively setting path lengths in Hamiltonian Monte Carlo."
            J. Mach. Learn. Res. 15.1 (2014): 1593-1623.
    .. [2]: Betancourt, Michael. "A conceptual introduction to Hamiltonian
            Monte Carlo." arXiv preprint arXiv:1701.02434 (2017).
    .. [3]: Phan, Du, Neeraj Pradhan, and Martin Jankowiak. "Composable effects
            for flexible and accelerated probabilistic programming in NumPyro."
            arXiv preprint arXiv:1912.11554 (2019).
    .. [4]: Lao, Junpeng, et al. "tfp.mcmc: Modern Markov chain Monte Carlo
            tools built for modern hardware." arXiv preprint arXiv:2002.01184
            (2020).
""" ( new_criterion_state, update_criterion_state, is_criterion_met, ) = termination.iterative_uturn_numpyro(uturn_check_fn) trajectory_integrator = trajectory.dynamic_progressive_integration( integrator, kinetic_energy, update_criterion_state, is_criterion_met, divergence_threshold, ) expand = trajectory.dynamic_multiplicative_expansion( trajectory_integrator, uturn_check_fn, step_size, max_num_expansions, ) def _compute_energy(state: integrators.IntegratorState) -> float: energy = state.potential_energy + kinetic_energy( state.position, state.momentum) return energy def propose(rng_key, initial_state: integrators.IntegratorState): criterion_state = new_criterion_state(initial_state, max_num_expansions) initial_proposal = proposal.Proposal(initial_state, _compute_energy(initial_state), 0.0) sampled_proposal, *info = expand( rng_key, initial_proposal, criterion_state, ) trajectory, num_doublings, is_diverging, has_terminated, is_turning = info info = NUTSInfo( initial_state.momentum, is_diverging, has_terminated | is_turning, sampled_proposal.energy, trajectory.leftmost_state, trajectory.rightmost_state, num_doublings, trajectory.num_states, ) return sampled_proposal.state, info return propose