Exemplo n.º 1
0
    def build_kernel(logpdf: Callable, parameters: HMCParameters) -> Callable:
        """Build the kernel that moves the chain from one point to the next.

        Parameters
        ----------
        logpdf
            Log-probability function of the target. It is used unchanged as the
            ``potential`` handed to the integrator and to ``hmc_kernel``.
            NOTE(review): HMC potentials are often the *negative* log-density;
            confirm the downstream helpers expect the log-density itself.
        parameters
            Object exposing ``inverse_mass_matrix``, ``num_integration_steps``
            and ``step_size`` attributes.

        Returns
        -------
        Callable
            The transition kernel built by ``hmc_kernel``.

        Raises
        ------
        AttributeError
            If ``parameters`` is missing any of the three attributes above.
        """
        potential = logpdf

        try:
            inverse_mass_matrix = parameters.inverse_mass_matrix
            num_integration_steps = parameters.num_integration_steps
            step_size = parameters.step_size
        except AttributeError as err:
            # Fail here with a message listing every required parameter;
            # chaining keeps the name of the actually-missing attribute visible.
            raise AttributeError(
                "The Hamiltonian Monte Carlo algorithm requires the following "
                "parameters: inverse mass matrix, number of integration steps "
                "and step size."
            ) from err

        momentum_generator, kinetic_energy = gaussian_euclidean_metric(
            inverse_mass_matrix
        )
        integrator_step = integrator(potential, kinetic_energy)
        proposal = hmc_proposal(integrator_step, step_size, num_integration_steps)
        kernel = hmc_kernel(proposal, momentum_generator, kinetic_energy, potential)

        return kernel
Exemplo n.º 2
0
 def build_kernel(num_integration_steps, step_size, inverse_mass_matrix):
     """Assemble an HMC transition kernel for the given tuning values.

     Relies on ``self.integrator`` and ``potential`` from the enclosing scope.
     The metric comes from ``gaussian_euclidean_metric``; its kinetic energy
     feeds the integrator, whose step is wrapped by ``hmc_proposal`` and then
     by ``hmc_kernel``.
     """
     metric = gaussian_euclidean_metric(inverse_mass_matrix)
     momentum_generator, kinetic_energy = metric
     step_fn = self.integrator(potential, kinetic_energy)
     return hmc_kernel(
         hmc_proposal(step_fn, step_size, num_integration_steps),
         momentum_generator,
         kinetic_energy,
         potential,
     )
Exemplo n.º 3
0
 def kernel_generator(step_size, inv_mass_matrix):
     """Return an HMC kernel that uses the velocity Verlet integrator.

     Uses ``potential_fn`` from the enclosing scope. The last argument given
     to ``hmc_proposal`` is hard-coded to ``1`` (presumably a single
     integration step per proposal; confirm against ``hmc_proposal``).
     """
     momentum_generator, kinetic_energy = gaussian_euclidean_metric(inv_mass_matrix)
     verlet_step = velocity_verlet(potential_fn, kinetic_energy)
     single_step_proposal = hmc_proposal(verlet_step, step_size, 1)
     return hmc_kernel(
         single_step_proposal, momentum_generator, kinetic_energy, potential_fn
     )
Exemplo n.º 4
0
def test_find_reasonable_step_size():
    """find_reasonable_step_size must move away from its initial guess of 1.0
    and produce different results for target acceptance rates 0.95 and 0.05."""

    def potential_fn(x):
        return np.sum(0.5 * np.square(x))

    rng_key = jax.random.PRNGKey(0)
    initial_guess = 1.0

    init_state = hmc_init(np.array([3.0]), potential_fn)
    momentum_generator, kinetic_energy = gaussian_euclidean_metric(np.array([1.0]))
    integrator_step = velocity_verlet(potential_fn, kinetic_energy)

    def search(target_acceptance):
        # The same key is reused on purpose so only the target rate differs.
        return find_reasonable_step_size(
            rng_key,
            momentum_generator,
            kinetic_energy,
            potential_fn,
            integrator_step,
            init_state,
            initial_guess,
            target_acceptance,
        )

    # The search must actually change the initial guess.
    high_target_step = search(0.95)
    assert high_target_step != initial_guess

    # A different target acceptance rate must give a different step size.
    low_target_step = search(0.05)
    assert low_target_step != high_target_step
Exemplo n.º 5
0
def stan_hmc_warmup(
    rng_key: jax.random.PRNGKey,
    logpdf: Callable,
    initial_state: HMCState,
    euclidean_metric: Callable,
    integrator_step: Callable,
    inital_step_size: float,
    path_length: float,
    num_steps: int,
    is_mass_matrix_diagonal=True,
) -> Tuple[HMCState, DualAveragingState, MassMatrixAdaptationState]:
    """Warmup scheme for sampling procedures based on euclidean manifold HMC.

    The schedule and algorithms used match Stan's [1]_ as closely as possible.

    Unlike several other libraries, we separate the warmup and sampling phases
    explicitly. This ensures better modularity: a change in the warmup does
    not affect the sampling. It also allows users to run their own warmup
    should they want to.

    Stan's warmup consists of the three following phases:

    1. A fast adaptation window where only the step size is adapted using
    Nesterov's dual averaging scheme to match a target acceptance rate.
    2. A succession of slow adaptation windows (where the size of a window
    is double that of the previous window) where both the mass matrix and the
    step size are adapted. The mass matrix is recomputed at the end of each
    window; the step size is re-initialized to a "reasonable" value.
    3. A last fast adaptation window where only the step size is adapted.

    Arguments
    ---------
    rng_key
        JAX PRNG key. It is split before every random operation so that no
        key is consumed twice.
    logpdf
        Passed unchanged to ``hmc_kernel`` as its potential argument.
    initial_state
        State the warmup starts from; its ``position`` is read as a 1D array
        whose last dimension gives the number of dimensions.
    euclidean_metric
        Factory mapping an inverse mass matrix to a
        ``(momentum_generator, kinetic_energy)`` pair. Used at initialization
        and every time the mass matrix is re-estimated.
    integrator_step
        Integrator used to build proposals and to search for a reasonable
        step size.
    inital_step_size
        Starting guess for the first step-size search. (The spelling is kept
        for compatibility with keyword callers.)
    path_length
        Forwarded to ``hmc_proposal``.
    num_steps
        Total number of warmup steps, split into windows by ``warmup_schedule``.
    is_mass_matrix_diagonal
        Forwarded to ``mass_matrix_adaptation``.

    Returns
    -------
    Tuple
        The current state of the chain, of the dual averaging scheme and mass
        matrix adaptation scheme.

    References
    ----------
    .. [1] Stan Reference Manual, "HMC algorithm parameters" (automatic
       parameter tuning / adaptation windows).
    """

    def _kernel_for(step, momentum_gen, kinetic):
        """Build an HMC kernel for the given step size and metric."""
        # NOTE(review): hmc_proposal is called here as
        # (integrator_step, path_length, step_size); other call sites in this
        # codebase use (integrator_step, step_size, num_integration_steps).
        # Confirm the argument order of the hmc_proposal this module targets.
        proposal = hmc_proposal(integrator_step, path_length, step)
        return hmc_kernel(proposal, momentum_gen, kinetic, logpdf)

    n_dims = np.shape(initial_state.position)[-1]  # `position` is a 1D array

    # Initialize the mass matrix adaptation
    mm_init, mm_update, mm_final = mass_matrix_adaptation(is_mass_matrix_diagonal)
    mm_state = mm_init(n_dims)

    # Initialize the metric from the initial inverse mass matrix
    momentum_generator, kinetic_energy = euclidean_metric(mm_state.inverse_mass_matrix)

    # Find a first reasonable step size and initialize dual averaging.
    # A dedicated subkey is used so rng_key is never consumed twice.
    rng_key, search_key = jax.random.split(rng_key)
    step_size = find_reasonable_step_size(
        search_key,
        momentum_generator,
        kinetic_energy,
        integrator_step,
        initial_state,
        inital_step_size,
    )
    da_init, da_update = dual_averaging()
    da_state = da_init(step_size)

    kernel = _kernel_for(step_size, momentum_generator, kinetic_energy)

    # Get warmup schedule
    schedule = warmup_schedule(num_steps)
    last_window = len(schedule) - 1

    state = initial_state
    for i, window in enumerate(schedule):
        # Only the slow windows (neither first nor last) adapt the mass matrix.
        is_middle_window = 0 < i < last_window

        for _ in range(window):
            rng_key, kernel_key = jax.random.split(rng_key)
            state, info = kernel(kernel_key, state)

            da_state = da_update(info.acceptance_probability, da_state)
            step_size = np.exp(da_state.log_step_size)

            if is_middle_window:
                mm_state = mm_update(mm_state, state.position)

            kernel = _kernel_for(step_size, momentum_generator, kinetic_energy)

        if is_middle_window:
            inverse_mass_matrix = mm_final(mm_state)
            # Use the caller-supplied metric factory (not a hard-coded one) so
            # the re-estimated matrix is applied with the same metric as at init.
            momentum_generator, kinetic_energy = euclidean_metric(inverse_mass_matrix)
            # NOTE(review): mm_state is reset here, so the mm_state returned at
            # the end comes from mm_init(n_dims) and does not hold the adapted
            # inverse mass matrix; confirm callers do not rely on it for that.
            mm_state = mm_init(n_dims)
            rng_key, search_key = jax.random.split(rng_key)
            step_size = find_reasonable_step_size(
                search_key,
                momentum_generator,
                kinetic_energy,
                integrator_step,
                state,
                step_size,
            )
            da_state = da_init(step_size)
            kernel = _kernel_for(step_size, momentum_generator, kinetic_energy)

    return state, da_state, mm_state