Ejemplo n.º 1
0
def test_gaussian_euclidean_ndim_invalid(shape):
    """Test that `gaussian_euclidean` rejects a mass matrix with an invalid ndim.

    `shape` comes from the test parametrization (not visible here); it
    presumably describes an array whose number of dimensions the metric does
    not support.
    """
    x = jnp.ones(shape=shape)

    with pytest.raises(ValueError) as e:
        metrics.gaussian_euclidean(x)
    # Check the message of the raised exception itself (`e.value`), not the
    # string form of the `ExceptionInfo` wrapper.
    assert "The mass matrix has the wrong number of dimensions" in str(e.value)
Ejemplo n.º 2
0
def kernel(potential_fn: Callable, parameters: HMCParameters) -> Callable:
    """Build a HMC kernel.

    Parameters
    ----------
    potential_fn
        A function that returns the potential energy of a chain at a given position.
    parameters
        A NamedTuple holding, in order, the step size, the number of
        integration steps, the inverse mass matrix and the divergence
        threshold of the kernel to be built.

    Returns
    -------
    The kernel built by ``base.hmc`` from the HMC proposal, the momentum
    generator, the kinetic energy function and the divergence threshold.

    Raises
    ------
    ValueError
        If ``parameters`` carries no inverse mass matrix.
    """
    (
        step_size,
        num_integration_steps,
        inv_mass_matrix,
        divergence_threshold,
    ) = parameters

    # The Euclidean metric cannot be built without a mass matrix.
    if inv_mass_matrix is None:
        raise ValueError(
            "Expected a value for `inv_mass_matrix`,"
            " got None. Please specify a value when initializing"
            " the parameters or run the window adaptation."
        )

    metric = metrics.gaussian_euclidean(inv_mass_matrix)
    momentum_generator, kinetic_energy_fn = metric

    leapfrog = integrators.velocity_verlet(potential_fn, kinetic_energy_fn)
    hmc_step_proposal = proposals.hmc(leapfrog, step_size, num_integration_steps)

    return base.hmc(
        hmc_step_proposal,
        momentum_generator,
        kinetic_energy_fn,
        divergence_threshold,
    )
Ejemplo n.º 3
0
def kernel(potential_fn: Callable, parameters: NUTSParameters) -> Callable:
    """Build an iterative NUTS kernel.

    Parameters
    ----------
    potential_fn
        A function that returns the potential energy of a chain at a given
        position. The potential energy is defined as minus the log-probability.
    parameters
        A NamedTuple holding, in order, the step size, the maximum tree depth,
        the inverse mass matrix and the divergence threshold of the kernel to
        be built.

    Returns
    -------
    The kernel built by ``base.hmc`` from the momentum generator and the
    iterative NUTS proposal generator.

    Raises
    ------
    ValueError
        If ``parameters`` carries no inverse mass matrix.
    """
    (
        step_size,
        max_tree_depth,
        inv_mass_matrix,
        divergence_threshold,
    ) = parameters

    # The Euclidean metric cannot be built without a mass matrix.
    if inv_mass_matrix is None:
        raise ValueError(
            "Expected a value for `inv_mass_matrix`,"
            " got None. Please specify a value when initializing"
            " the parameters or run the window adaptation."
        )

    # The metric yields the momentum sampler, the kinetic energy and the
    # U-turn check that is handed to the NUTS proposal below.
    momentum_generator, kinetic_energy_fn, uturn_check_fn = metrics.gaussian_euclidean(
        inv_mass_matrix
    )

    proposal_generator = iterative_nuts_proposal(
        integrators.velocity_verlet(potential_fn, kinetic_energy_fn),
        kinetic_energy_fn,
        uturn_check_fn,
        step_size,
        max_tree_depth,
        divergence_threshold,
    )
    return base.hmc(momentum_generator, proposal_generator)
Ejemplo n.º 4
0
def test_dynamic_progressive_integration_divergence(case):
    """Check whether progressive integration reports a divergence.

    `case` is a ``(step_size, should_diverge)`` pair supplied by the test
    parametrization. A single trajectory is integrated from position 1.0 and
    the returned divergence flag must match ``should_diverge``.
    """
    key = jax.random.PRNGKey(0)

    # NOTE(review): this returns the log-density itself rather than its
    # negative; the parametrized divergence cases presumably rely on that.
    def potential_fn(x):
        return jax.scipy.stats.norm.logpdf(x)

    step_size, should_diverge = case
    start = 1.0
    inv_mass = jnp.array([1.0])

    sample_momentum, kinetic_fn, uturn_fn = metrics.gaussian_euclidean(inv_mass)

    step_fn = integrators.velocity_verlet(potential_fn, kinetic_fn)
    new_criterion, update_criterion, criterion_met = termination.iterative_uturn_numpyro(
        uturn_fn
    )

    integrate = trajectory.dynamic_progressive_integration(
        step_fn,
        kinetic_fn,
        update_criterion,
        criterion_met,
        divergence_threshold,
    )

    # Initial state and its total (potential + kinetic) energy.
    forward = 1
    state0 = integrators.new_integrator_state(
        potential_fn, start, sample_momentum(key, start)
    )
    kinetic0 = kinetic_fn(state0.position, state0.momentum)
    energy0 = state0.potential_energy + kinetic0
    criterion0 = new_criterion(state0, 10)
    step_budget = 100

    _, _, _, diverged, _, _ = integrate(
        key,
        state0,
        forward,
        criterion0,
        step_budget,
        step_size,
        energy0,
    )

    assert diverged.item() is should_diverge
Ejemplo n.º 5
0
def test_dynamic_progressive_expansion(case):
    """Check the outcome of dynamic multiplicative trajectory expansion.

    `case` is a ``(step_size, should_diverge, should_turn, expected_doublings)``
    tuple from the test parametrization. Starting at position 0.0 on the
    quadratic potential ``0.5 * x**2``, the trajectory is expanded once and
    the divergence flag, the step count and the U-turn flag returned by
    ``expand`` are compared with the case.
    """
    rng_key = jax.random.PRNGKey(0)

    def potential_fn(x):
        return 0.5 * x ** 2

    step_size, should_diverge, should_turn, expected_doublings = case
    position = 0.0
    inverse_mass_matrix = jnp.array([1.0])

    momentum_generator, kinetic_energy_fn, uturn_check_fn = metrics.gaussian_euclidean(
        inverse_mass_matrix
    )

    integrator = integrators.velocity_verlet(potential_fn, kinetic_energy_fn)
    (
        new_criterion_state,
        update_criterion_state,
        is_criterion_met,
    ) = termination.iterative_uturn_numpyro(uturn_check_fn)

    trajectory_integrator = trajectory.dynamic_progressive_integration(
        integrator,
        kinetic_energy_fn,
        update_criterion_state,
        is_criterion_met,
        divergence_threshold,
    )

    expand = trajectory.dynamic_multiplicative_expansion(
        trajectory_integrator, uturn_check_fn, step_size
    )

    state = integrators.new_integrator_state(
        potential_fn, position, momentum_generator(rng_key, position)
    )
    energy = state.potential_energy + kinetic_energy_fn(state.position, state.momentum)
    initial_proposal = proposal.Proposal(state, energy, 0.0)
    initial_termination_state = new_criterion_state(state, 10)

    # The termination flag (5th output) is not asserted on, so it is discarded.
    _, _, step, is_diverging, _, is_turning = expand(
        rng_key, initial_proposal, initial_termination_state
    )

    assert is_diverging == should_diverge
    assert step == expected_doublings
    assert is_turning == should_turn
Ejemplo n.º 6
0
def test_gaussian_euclidean_dim_2():
    """Test Gaussian Euclidean Function with ndim 2"""
    import numpy as np  # only needed for the array comparison below

    inverse_mass_matrix = jnp.asarray([[1 / 9, 0], [0, 1 / 4]], dtype=DTYPE)
    momentum, kinetic_energy, _ = metrics.gaussian_euclidean(inverse_mass_matrix)

    arbitrary_position = jnp.asarray([12345, 23456], dtype=DTYPE)
    momentum_val = momentum(KEY, arbitrary_position)

    # 3 and 2 are the square roots of the inverses of 1/9 and 1/4, the
    # diagonal of the inverse mass matrix; the second array holds the
    # hard-coded random values returned with KEY.
    expected_momentum_val = jnp.asarray([3, 2]) * jnp.asarray([-0.784766, 0.8564448])

    kinetic_energy_val = kinetic_energy(momentum_val)
    velocity = jnp.dot(inverse_mass_matrix, momentum_val)
    expected_kinetic_energy_val = 0.5 * jnp.matmul(velocity, momentum_val)

    # An explicit comparison is required: a bare `assert pytest.approx(...)`
    # is always truthy and never compares the values. The tolerance allows
    # for float32 arithmetic and the rounding of the hard-coded draws.
    np.testing.assert_allclose(momentum_val, expected_momentum_val, rtol=1e-5)
    assert kinetic_energy_val == expected_kinetic_energy_val
Ejemplo n.º 7
0
def test_gaussian_euclidean_dim_1():
    """Check momentum sampling and kinetic energy for a one-dimensional metric."""
    inv_mass = jnp.asarray([1 / 4], dtype=DTYPE)
    sample_momentum, kinetic_fn, _ = metrics.gaussian_euclidean(inv_mass)

    position = jnp.asarray([12345], dtype=DTYPE)
    sampled = sample_momentum(KEY, position)

    # The draw returned with KEY (-0.20584235, hard-coded) scaled by 2, the
    # square root of the inverse of 1/4.
    expected_sample = 2 * -0.20584235

    kinetic = kinetic_fn(sampled)
    expected_kinetic = 0.5 * (inv_mass * sampled) * sampled

    assert sampled == expected_sample
    assert kinetic == expected_kinetic
Ejemplo n.º 8
0
def test_is_iterative_turning(checkpoint_idxs, expected_turning):
    """Check the iterative U-turn criterion against hand-built checkpoints.

    `checkpoint_idxs` is an ``(idx_min, idx_max)`` pair and `expected_turning`
    the expected result, both supplied by the test parametrization.
    """
    _, _, uturn_fn = gaussian_euclidean(jnp.ones(1))
    _, _, iterative_turning_fn = iterative_uturn_numpyro(uturn_fn)

    lo, hi = checkpoint_idxs
    state = IterativeUTurnState(
        jnp.array([1.0, 2.0, 3.0, -2.0]),  # momentum checkpoints
        jnp.array([2.0, 4.0, 4.0, -1.0]),  # momentum-sum checkpoints
        lo,
        hi,
    )

    current_momentum_sum = 3.0
    current_momentum = 1.0
    actual_turning = iterative_turning_fn(state, current_momentum_sum, current_momentum)

    assert expected_turning == actual_turning
Ejemplo n.º 9
0
def kernel(potential_fn: Callable, parameters: HMCParameters):
    """Build a HMC kernel.

    Parameters
    ----------
    potential_fn
        A function that returns the potential energy of a chain at a given position.
    parameters
        A NamedTuple holding, in order, the step size, the number of
        integration steps, the inverse mass matrix and the divergence
        threshold of the kernel to be built.

    Returns
    -------
    A kernel that takes a rng_key and a Pytree that contains the current state
    of the chain and that returns a new state of the chain along with
    information about the transition.

    Raises
    ------
    ValueError
        If ``parameters`` carries no inverse mass matrix.
    """
    (
        step_size,
        num_integration_steps,
        inv_mass_matrix,
        divergence_threshold,
    ) = parameters

    # The Euclidean metric cannot be built without a mass matrix.
    if inv_mass_matrix is None:
        raise ValueError(
            "Expected a value for `inv_mass_matrix`,"
            " got None. Please specify a value when initializing"
            " the parameters or run the window adaptation."
        )

    # The third element (the U-turn check) is only needed by NUTS.
    momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean(inv_mass_matrix)

    proposal_generator = hmc_proposal(
        integrators.velocity_verlet(potential_fn, kinetic_energy_fn),
        kinetic_energy_fn,
        step_size,
        num_integration_steps,
        divergence_threshold,
    )
    return base.hmc(momentum_generator, proposal_generator)
Ejemplo n.º 10
0
def test_dynamic_progressive_equal_recursive():
    """Check that progressive and recursive trajectory building agree.

    For 50 random draws of direction, tree depth, step size and initial
    state, both integrators are run from the same state with the same key.
    The trajectories, divergence/termination flags, proposal weights and
    summed acceptance log-probabilities must match.
    """
    rng_key = jax.random.PRNGKey(23132)

    def potential_fn(x):
        return (1.0 - x[0])**2 + 1.5 * (x[1] - x[0]**2)**2

    inverse_mass_matrix = jnp.asarray([[1.0, 0.5], [0.5, 1.25]])
    # Momenta are drawn directly below, so the momentum generator is unused.
    _, kinetic_energy_fn, uturn_check_fn = metrics.gaussian_euclidean(
        inverse_mass_matrix)

    integrator = integrators.velocity_verlet(potential_fn, kinetic_energy_fn)
    (
        new_criterion_state,
        update_criterion_state,
        is_criterion_met,
    ) = termination.iterative_uturn_numpyro(uturn_check_fn)
    # JIT-compile the building blocks shared by both integrators, presumably
    # to keep the 50 iterations below fast.
    (
        integrator,
        kinetic_energy_fn,
        update_criterion_state,
        is_criterion_met,
        uturn_check_fn,
    ) = [
        jax.jit(x) for x in (
            integrator,
            kinetic_energy_fn,
            update_criterion_state,
            is_criterion_met,
            uturn_check_fn,
        )
    ]

    trajectory_integrator = trajectory.dynamic_progressive_integration(
        integrator,
        kinetic_energy_fn,
        update_criterion_state,
        is_criterion_met,
        divergence_threshold,
    )
    buildtree_integrator = trajectory.dynamic_recursive_integration(
        integrator,
        kinetic_energy_fn,
        uturn_check_fn,
        divergence_threshold,
    )

    for _ in range(50):
        (
            rng_key,
            rng_direction,
            rng_tree_depth,
            rng_step_size,
            rng_position,
            rng_momentum,
        ) = jax.random.split(rng_key, 6)
        direction = jax.random.choice(rng_direction, jnp.array([-1, 1]))
        tree_depth = jax.random.choice(rng_tree_depth, np.arange(2, 5))
        initial_state = integrators.new_integrator_state(
            potential_fn,
            jax.random.normal(rng_position, [2]),
            jax.random.normal(rng_momentum, [2]),
        )
        step_size = jnp.abs(jax.random.normal(rng_step_size, [])) * 0.1
        initial_energy = initial_state.potential_energy + kinetic_energy_fn(
            initial_state.position, initial_state.momentum)

        termination_state = new_criterion_state(initial_state, tree_depth)
        (
            proposal0,
            trajectory0,
            _,
            is_diverging0,
            has_terminated0,
            _,
        ) = trajectory_integrator(
            rng_key,
            initial_state,
            direction,
            termination_state,
            2**tree_depth,
            step_size,
            initial_energy,
        )

        (
            _,
            proposal1,
            trajectory1,
            is_diverging1,
            has_terminated1,
        ) = buildtree_integrator(
            rng_key,
            initial_state,
            direction,
            tree_depth,
            step_size,
            initial_energy,
        )
        # Assert that the trajectory being built is the same. `jax.tree_multimap`
        # is deprecated and removed from recent JAX releases;
        # `jax.tree_util.tree_map` accepts several trees and replaces it.
        jax.tree_util.tree_map(
            functools.partial(np.testing.assert_allclose, rtol=1e-5),
            trajectory0,
            trajectory1,
        )
        assert is_diverging0 == is_diverging1
        assert has_terminated0 == has_terminated1
        # We don't expect the proposals to be the same (even with the same
        # PRNGKey, as the order of selection is different), but the properties
        # associated with the full trajectory should be the same.
        np.testing.assert_allclose(proposal0.weight,
                                   proposal1.weight,
                                   rtol=1e-5)
        np.testing.assert_allclose(proposal0.sum_log_p_accept,
                                   proposal1.sum_log_p_accept,
                                   rtol=1e-5)