Example #1
def test_dynamic_progressive_integration_divergence(case):
    rng_key = jax.random.PRNGKey(0)

    def potential_fn(x):
        return jax.scipy.stats.norm.logpdf(x)

    step_size, should_diverge = case
    position = 1.0
    inverse_mass_matrix = jnp.array([1.0])

    momentum_generator, kinetic_energy_fn, uturn_check_fn = metrics.gaussian_euclidean(
        inverse_mass_matrix)

    integrator = integrators.velocity_verlet(potential_fn, kinetic_energy_fn)
    (
        new_criterion_state,
        update_criterion_state,
        is_criterion_met,
    ) = termination.iterative_uturn_numpyro(uturn_check_fn)

    trajectory_integrator = trajectory.dynamic_progressive_integration(
        integrator,
        kinetic_energy_fn,
        update_criterion_state,
        is_criterion_met,
        divergence_threshold,
    )

    # Initialize
    direction = 1
    initial_state = integrators.new_integrator_state(
        potential_fn, position, momentum_generator(rng_key, position))
    initial_energy = initial_state.potential_energy + kinetic_energy_fn(
        initial_state.position, initial_state.momentum)
    termination_state = new_criterion_state(initial_state, 10)
    max_num_steps = 100

    _, _, _, is_diverging, _, _ = trajectory_integrator(
        rng_key,
        initial_state,
        direction,
        termination_state,
        max_num_steps,
        step_size,
        initial_energy,
    )

    assert is_diverging.item() == should_diverge
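The `case` argument is presumably supplied by a pytest parametrization defined elsewhere in the test module; a minimal sketch of such a parametrization, with purely illustrative (step_size, should_diverge) values:

import pytest

# Hypothetical cases (illustrative only): a tiny step size should integrate
# stably, while a very large step size should trigger a divergence.
@pytest.mark.parametrize("case", [(1e-4, False), (1e4, True)])
def test_dynamic_progressive_integration_divergence(case):
    ...  # body as in Example #1 above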
Example #2
def test_dynamic_progressive_expansion(case):
    rng_key = jax.random.PRNGKey(0)

    def potential_fn(x):
        return 0.5 * x ** 2

    step_size, should_diverge, should_turn, expected_doublings = case
    position = 0.0
    inverse_mass_matrix = jnp.array([1.0])

    momentum_generator, kinetic_energy_fn, uturn_check_fn = metrics.gaussian_euclidean(
        inverse_mass_matrix
    )

    integrator = integrators.velocity_verlet(potential_fn, kinetic_energy_fn)
    (
        new_criterion_state,
        update_criterion_state,
        is_criterion_met,
    ) = termination.iterative_uturn_numpyro(uturn_check_fn)

    trajectory_integrator = trajectory.dynamic_progressive_integration(
        integrator,
        kinetic_energy_fn,
        update_criterion_state,
        is_criterion_met,
        divergence_threshold,
    )

    expand = trajectory.dynamic_multiplicative_expansion(
        trajectory_integrator, uturn_check_fn, step_size
    )

    state = integrators.new_integrator_state(
        potential_fn, position, momentum_generator(rng_key, position)
    )
    energy = state.potential_energy + kinetic_energy_fn(state.position, state.momentum)
    initial_proposal = proposal.Proposal(state, energy, 0.0)
    initial_termination_state = new_criterion_state(state, 10)

    _, _, step, is_diverging, has_terminated, is_turning = expand(
        rng_key, initial_proposal, initial_termination_state
    )

    assert is_diverging == should_diverge
    assert step == expected_doublings
    assert is_turning == should_turn
Example #3
def test_is_iterative_turning(checkpoint_idxs, expected_turning):
    inverse_mass_matrix = jnp.ones(1)
    _, _, is_turning = gaussian_euclidean(inverse_mass_matrix)
    _, _, is_iterative_turning = iterative_uturn_numpyro(is_turning)

    momentum = 1.0
    momentum_sum = 3.0

    idx_min, idx_max = checkpoint_idxs
    momentum_ckpts = jnp.array([1.0, 2.0, 3.0, -2.0])
    momentum_sum_ckpts = jnp.array([2.0, 4.0, 4.0, -1.0])
    checkpoints = IterativeUTurnState(
        momentum_ckpts,
        momentum_sum_ckpts,
        idx_min,
        idx_max,
    )

    actual_turning = is_iterative_turning(checkpoints, momentum_sum, momentum)

    assert expected_turning == actual_turning
Example #4
def test_dynamic_progressive_equal_recursive():
    rng_key = jax.random.PRNGKey(23132)

    def potential_fn(x):
        return (1.0 - x[0])**2 + 1.5 * (x[1] - x[0]**2)**2

    inverse_mass_matrix = jnp.asarray([[1.0, 0.5], [0.5, 1.25]])
    momentum_generator, kinetic_energy_fn, uturn_check_fn = metrics.gaussian_euclidean(
        inverse_mass_matrix)

    integrator = integrators.velocity_verlet(potential_fn, kinetic_energy_fn)
    (
        new_criterion_state,
        update_criterion_state,
        is_criterion_met,
    ) = termination.iterative_uturn_numpyro(uturn_check_fn)
    (
        integrator,
        kinetic_energy_fn,
        update_criterion_state,
        is_criterion_met,
        uturn_check_fn,
    ) = [
        jax.jit(x) for x in (
            integrator,
            kinetic_energy_fn,
            update_criterion_state,
            is_criterion_met,
            uturn_check_fn,
        )
    ]

    trajectory_integrator = trajectory.dynamic_progressive_integration(
        integrator,
        kinetic_energy_fn,
        update_criterion_state,
        is_criterion_met,
        divergence_threshold,
    )
    buildtree_integrator = trajectory.dynamic_recursive_integration(
        integrator,
        kinetic_energy_fn,
        uturn_check_fn,
        divergence_threshold,
    )

    for _ in range(50):
        (
            rng_key,
            rng_direction,
            rng_tree_depth,
            rng_step_size,
            rng_position,
            rng_momentum,
        ) = jax.random.split(rng_key, 6)
        direction = jax.random.choice(rng_direction, jnp.array([-1, 1]))
        tree_depth = jax.random.choice(rng_tree_depth, np.arange(2, 5))
        initial_state = integrators.new_integrator_state(
            potential_fn,
            jax.random.normal(rng_position, [2]),
            jax.random.normal(rng_momentum, [2]),
        )
        step_size = jnp.abs(jax.random.normal(rng_step_size, [])) * 0.1
        initial_energy = initial_state.potential_energy + kinetic_energy_fn(
            initial_state.position, initial_state.momentum)

        termination_state = new_criterion_state(initial_state, tree_depth)
        (
            proposal0,
            trajectory0,
            _,
            is_diverging0,
            has_terminated0,
            _,
        ) = trajectory_integrator(
            rng_key,
            initial_state,
            direction,
            termination_state,
            2**tree_depth,
            step_size,
            initial_energy,
        )

        (
            _,
            proposal1,
            trajectory1,
            is_diverging1,
            has_terminated1,
        ) = buildtree_integrator(
            rng_key,
            initial_state,
            direction,
            tree_depth,
            step_size,
            initial_energy,
        )
        # Assert that the trajectory being built is the same
        jax.tree_util.tree_map(
            functools.partial(np.testing.assert_allclose, rtol=1e-5),
            trajectory0,
            trajectory1,
        )
        assert is_diverging0 == is_diverging1
        assert has_terminated0 == has_terminated1
        # We don't expect the proposals to be the same (even with the same
        # PRNGKey, since the order of selection differs), but the properties
        # associated with the full trajectory should be the same.
        np.testing.assert_allclose(proposal0.weight,
                                   proposal1.weight,
                                   rtol=1e-5)
        np.testing.assert_allclose(proposal0.sum_log_p_accept,
                                   proposal1.sum_log_p_accept,
                                   rtol=1e-5)
Example #5
def iterative_nuts_proposal(
    integrator: Callable,
    kinetic_energy: Callable,
    uturn_check_fn: Callable,
    step_size: float,
    max_num_expansions: int = 10,
    divergence_threshold: float = 1000,
) -> Callable:
    """Iterative NUTS algorithm.

    This algorithm is an iterative variant of the original NUTS algorithm [1]_ with two major differences:
    - We use multinomial sampling instead of slice sampling for the proposal [2]_;
    - The trajectory is expanded iteratively rather than recursively [3]_, [4]_.

    The implementation can seem unusual for those familiar with similar
    algorithms. Indeed, we do not conceptualize the trajectory construction as
    building a tree. We feel that the tree lingo, inherited from the recursive
    version, is unnecessarily complicated and hides the more general concepts
    on which the NUTS algorithm is built.

    NUTS, in essence, consists of sampling a trajectory by iteratively choosing
    a direction at random and integrating in that direction a number of times
    that doubles at every step. From this trajectory we continuously sample a
    proposal. When the trajectory turns on itself, or when we have reached the
    maximum trajectory length, we return the current proposal.

    Parameters
    ----------
    integrator
        Symplectic integrator used to build the trajectory step by step.
    kinetic_energy
        Function that computes the kinetic energy.
    uturn_check_fn
        Function that determines whether the trajectory is turning on itself (metric-dependent).
    step_size
        Size of the integration step.
    max_num_expansions
        Maximum number of times the trajectory is expanded; each expansion doubles the number of integration steps.
    divergence_threshold
        Threshold above which we say that there is a divergence.

    Returns
    -------
    A kernel that generates a new chain state and information about the transition.

    References
    ----------
    .. [1]: Hoffman, Matthew D., and Andrew Gelman. "The No-U-Turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo." J. Mach. Learn. Res. 15.1 (2014): 1593-1623.
    .. [2]: Betancourt, Michael. "A conceptual introduction to Hamiltonian Monte Carlo." arXiv preprint arXiv:1701.02434 (2017).
    .. [3]: Phan, Du, Neeraj Pradhan, and Martin Jankowiak. "Composable effects for flexible and accelerated probabilistic programming in NumPyro." arXiv preprint arXiv:1912.11554 (2019).
    .. [4]: Lao, Junpeng, et al. "tfp.mcmc: Modern Markov chain Monte Carlo tools built for modern hardware." arXiv preprint arXiv:2002.01184 (2020).
    """
    (
        new_criterion_state,
        update_criterion_state,
        is_criterion_met,
    ) = termination.iterative_uturn_numpyro(uturn_check_fn)

    trajectory_integrator = trajectory.dynamic_progressive_integration(
        integrator,
        kinetic_energy,
        update_criterion_state,
        is_criterion_met,
        divergence_threshold,
    )

    expand = trajectory.dynamic_multiplicative_expansion(
        trajectory_integrator,
        uturn_check_fn,
        step_size,
        max_num_expansions,
    )

    def _compute_energy(state: integrators.IntegratorState) -> float:
        energy = state.potential_energy + kinetic_energy(
            state.position, state.momentum)
        return energy

    def propose(rng_key, initial_state: integrators.IntegratorState):
        criterion_state = new_criterion_state(initial_state,
                                              max_num_expansions)
        initial_proposal = proposal.Proposal(initial_state,
                                             _compute_energy(initial_state),
                                             0.0)

        sampled_proposal, *info = expand(
            rng_key,
            initial_proposal,
            criterion_state,
        )
        trajectory, num_doublings, is_diverging, has_terminated, is_turning = info

        info = NUTSInfo(
            initial_state.momentum,
            is_diverging,
            has_terminated | is_turning,
            sampled_proposal.energy,
            trajectory.leftmost_state,
            trajectory.rightmost_state,
            num_doublings,
            trajectory.num_states,
        )

        return sampled_proposal.state, info

    return propose
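
A minimal sketch of how the kernel returned by `iterative_nuts_proposal` might be wired up, assuming the same `metrics` and `integrators` modules used in the examples above; the target density and step size are purely illustrative:

import jax
import jax.numpy as jnp

def potential_fn(x):
    # Standard Gaussian potential, used purely for illustration.
    return 0.5 * jnp.sum(x ** 2)

inverse_mass_matrix = jnp.ones(2)
momentum_generator, kinetic_energy_fn, uturn_check_fn = metrics.gaussian_euclidean(
    inverse_mass_matrix
)
integrator = integrators.velocity_verlet(potential_fn, kinetic_energy_fn)

# Build the proposal kernel (step_size chosen arbitrarily for this sketch).
propose = iterative_nuts_proposal(
    integrator, kinetic_energy_fn, uturn_check_fn, step_size=0.1
)

rng_key = jax.random.PRNGKey(0)
position = jnp.zeros(2)
initial_state = integrators.new_integrator_state(
    potential_fn, position, momentum_generator(rng_key, position)
)
new_state, info = propose(rng_key, initial_state)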