Example No. 1
def build_kernel(logpdf: Callable, parameters: HMCParameters) -> Callable:
    """Builds the kernel that moves the chain from one point
    to the next.
    """

    potential = logpdf

    try:
        inverse_mass_matrix = parameters.inverse_mass_matrix
        num_integration_steps = parameters.num_integration_steps
        step_size = parameters.step_size
    except AttributeError:
        raise AttributeError(
            "The Hamiltonian Monte Carlo algorithm requires the following "
            "parameters: inverse mass matrix, number of integration steps "
            "and step size."
        )

    momentum_generator, kinetic_energy = gaussian_euclidean_metric(
        inverse_mass_matrix,
    )
    integrator_step = integrator(potential, kinetic_energy)
    proposal = hmc_proposal(integrator_step, step_size, num_integration_steps)
    kernel = hmc_kernel(proposal, momentum_generator, kinetic_energy, potential)

    return kernel
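
A minimal usage sketch for the factory above. It assumes `HMCParameters` is a named tuple exposing the three fields read in the `try` block, that the returned kernel follows the `kernel(rng_key, state)` convention used in the warmup examples below, and that `logpdf` and an initial `HMCState` named `initial_state` are already defined; the values chosen here are illustrative only.

import jax
import numpy as np

# Assumed context: `build_kernel`, `HMCParameters`, `logpdf`, `initial_state`
# (an HMCState) and `n_dims` (the dimension of the target distribution) are
# defined elsewhere, as in the examples on this page.
parameters = HMCParameters(
    step_size=1e-3,
    num_integration_steps=30,
    inverse_mass_matrix=np.ones(n_dims),
)
kernel = build_kernel(logpdf, parameters)

rng_key = jax.random.PRNGKey(0)
state = initial_state
positions = []
for _ in range(1_000):
    _, rng_key = jax.random.split(rng_key)
    state, info = kernel(rng_key, state)
    positions.append(state.position)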
Example No. 2
def build_kernel(num_integration_steps, step_size, inverse_mass_matrix):
    # `potential` and `self` are not defined here: they are captured from the
    # enclosing method's scope.
    momentum_generator, kinetic_energy = gaussian_euclidean_metric(
        inverse_mass_matrix,
    )
    integrator_step = self.integrator(potential, kinetic_energy)
    proposal = hmc_proposal(integrator_step, step_size, num_integration_steps)
    kernel = hmc_kernel(proposal, momentum_generator, kinetic_energy, potential)
    return kernel
Example No. 3
def kernel_generator(step_size, inv_mass_matrix):
    momentum_generator, kinetic_energy = gaussian_euclidean_metric(
        inv_mass_matrix)
    integrator_step = velocity_verlet(potential_fn, kinetic_energy)
    proposal = hmc_proposal(integrator_step, step_size, 1)
    kernel = hmc_kernel(proposal, momentum_generator, kinetic_energy,
                        potential_fn)
    return kernel
Example No. 4
def _new_hmc_kernel(step_size: float) -> Callable:
    """Return an HMC kernel that operates with the provided step size."""
    proposal = hmc_proposal(integrator_step, step_size, 1)
    kernel = hmc_kernel(proposal, momentum_generator, kinetic_energy,
                        potential_fn)
    return kernel
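
This factory pattern pairs naturally with step size adaptation: each time the dual averaging scheme updates the step size, a fresh kernel can be built with the new value, which is exactly what the warmup loop in Example No. 5 below does after every transition.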
Example No. 5
def stan_hmc_warmup(
    rng_key: jax.random.PRNGKey,
    logpdf: Callable,
    initial_state: HMCState,
    euclidean_metric: Callable,
    integrator_step: Callable,
    initial_step_size: float,
    path_length: float,
    num_steps: int,
    is_mass_matrix_diagonal=True,
) -> Tuple[HMCState, DualAveragingState, MassMatrixAdaptationState]:
    """Warmup scheme for sampling procedures based on Euclidean manifold HMC.
    The schedule and algorithms used match Stan's [1]_ as closely as possible.

    Unlike several other libraries, we separate the warmup and sampling phases
    explicitly. This ensures better modularity: a change in the warmup does
    not affect the sampling. It also allows users to run their own warmup
    should they want to.

    Stan's warmup consists of the following three phases:

    1. A fast adaptation window where only the step size is adapted using
    Nesterov's dual averaging scheme to match a target acceptance rate.
    2. A succession of slow adaptation windows (where the size of a window
    is double that of the previous window) where both the mass matrix and the step size
    are adapted. The mass matrix is recomputed at the end of each window; the step
    size is re-initialized to a "reasonable" value.
    3. A last fast adaptation window where only the step size is adapted.

    Arguments
    ---------
    rng_key
        Key for JAX's pseudo-random number generator.
    logpdf
        Log probability density function of the model to sample from.
    initial_state
        Initial state of the HMC chain.
    euclidean_metric
        Function that maps an inverse mass matrix to a momentum generator and
        a kinetic energy function.
    integrator_step
        Symplectic integrator step used to simulate a trajectory.
    initial_step_size
        Step size used to initialize the dual averaging scheme.
    path_length
        Length of the integration path.
    num_steps
        Total number of warmup steps.
    is_mass_matrix_diagonal
        Whether the adapted mass matrix is restricted to its diagonal.

    Returns
    -------
    Tuple
        The current state of the chain, of the dual averaging scheme and mass matrix
        adaptation scheme.
    """

    n_dims = np.shape(initial_state.position)[-1]  # `position` is a 1D array

    # Initialize the mass matrix adaptation
    mm_init, mm_update, mm_final = mass_matrix_adaptation(
        is_mass_matrix_diagonal)
    mm_state = mm_init(n_dims)

    # Initialize the HMC transition kernel
    momentum_generator, kinetic_energy = euclidean_metric(
        mm_state.inverse_mass_matrix)

    # Find a first reasonable step size and initialize dual averaging
    step_size = find_reasonable_step_size(
        rng_key,
        momentum_generator,
        kinetic_energy,
        integrator_step,
        initial_state,
        initial_step_size,
    )
    da_init, da_update = dual_averaging()
    da_state = da_init(step_size)

    # initial kernel
    proposal = hmc_proposal(integrator_step, path_length, step_size)
    kernel = hmc_kernel(proposal, momentum_generator, kinetic_energy, logpdf)

    # Get warmup schedule
    schedule = warmup_schedule(num_steps)

    state = initial_state
    for i, window in enumerate(schedule):
        is_middle_window = (0 < i) & (i < (len(schedule) - 1))

        for step in range(window):
            _, rng_key = jax.random.split(rng_key)
            state, info = kernel(rng_key, state)

            da_state = da_update(info.acceptance_probability, da_state)
            step_size = np.exp(da_state.log_step_size)
            proposal = hmc_proposal(integrator_step, path_length, step_size)

            if is_middle_window:
                mm_state = mm_update(mm_state, state.position)

            kernel = hmc_kernel(proposal, momentum_generator, kinetic_energy,
                                logpdf)

        if is_middle_window:
            inverse_mass_matrix = mm_final(mm_state)
            momentum_generator, kinetic_energy = euclidean_metric(
                inverse_mass_matrix)
            mm_state = mm_init(n_dims)
            step_size = find_reasonable_step_size(
                rng_key,
                momentum_generator,
                kinetic_energy,
                integrator_step,
                state,
                step_size,
            )
            da_state = da_init(step_size)
            proposal = hmc_proposal(integrator_step, path_length, step_size)
            kernel = hmc_kernel(proposal, momentum_generator, kinetic_energy,
                                logpdf)

    return state, da_state, mm_state
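
To make the three-phase schedule described in the docstring concrete, here is a small sketch of a schedule generator. It is not the `warmup_schedule` used above (whose implementation is not shown here); it only illustrates one way to produce the list of window lengths the loop iterates over, using Stan's documented defaults of a 75-iteration initial fast window, slow windows starting at 25 iterations and doubling in size, and a 50-iteration final fast window.

def stan_warmup_schedule(num_steps, initial_window=75, first_slow_window=25,
                         final_window=50):
    """Sketch of a Stan-like schedule: one fast window, doubling slow windows,
    and a final fast window, e.g. [75, 25, 50, 100, ..., 50]."""
    # Degenerate case: too few steps to split into windows.
    if num_steps < 20:
        return [num_steps]
    # Scale the fast windows down if the requested total is small.
    if initial_window + first_slow_window + final_window > num_steps:
        initial_window = int(0.15 * num_steps)
        final_window = int(0.1 * num_steps)
        first_slow_window = num_steps - initial_window - final_window
    schedule = [initial_window]
    # Slow windows double in size; the last one absorbs the remainder.
    slow_budget = num_steps - initial_window - final_window
    window, used = first_slow_window, 0
    while used + window < slow_budget:
        schedule.append(window)
        used += window
        window *= 2
    schedule.append(slow_budget - used)
    schedule.append(final_window)
    return schedule

For instance, `stan_warmup_schedule(1000)` returns `[75, 25, 50, 100, 200, 400, 100, 50]`, whose entries sum to the requested 1000 warmup steps.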
Example No. 6
def ehmc_warmup(
    rng_key: jax.random.PRNGKey,
    logpdf: Callable,
    initial_state: HMCState,
    momentum_generator: Callable,
    kinetic_energy: Callable,
    integrator_step: Callable,
    step_size: float,
    path_length: float,
    num_longest_batch: int,
):
    """Warmup scheme for empirical Hamiltonian Monte Carlo.

    We build a list of longest batches (path lengths) from which we will draw
    path lengths for the integration step during inference. This warmup is
    typically run after the hmc warmup.

    For that purpose, we build a custom integrator that implements Algorithm 2
    in [1]_.

    References
    ----------
    .. [1]: Wu, Changye, Julien Stoehr, and Christian P. Robert. "Faster
            Hamiltonian Monte Carlo by learning leapfrog scale." arXiv preprint
            arXiv:1810.04449 (2018).
    """

    # Build the warmup kernel

    step = integrator_step(logpdf)
    longest_batch_step = longest_batch_before_turn(step)

    def longest_batch_integrator(rng_key: jax.random.PRNGKey,
                                 integrator_state):
        """The integrator state that is iterated over is the standard
        IntegratorState plus the longest batch length. Here we bump in a
        limitation of python < 3.7: it is not possible to subclass the
        IntegratorState named tuple to add a field. We need to find a solution
        as we may often need to subclass the base named tuple.
        """
        position, momentum, log_prob, log_prob_grad, batch_length = longest_batch_step(
            integrator_state.position,
            integrator_state.momentum,
            integrator_state.step_size,
            integrator_state.path_length,
        )
        if batch_length < path_length:
            position, momentum, log_prob, log_prob_grad = step(
                position,
                momentum,
                log_prob,
                log_prob_grad,
                step_size,
                path_length - batch_length,
            )
        return position, momentum, log_prob, log_prob_grad, batch_length

    ehmc_warmup_kernel = hmc_kernel(longest_batch_integrator,
                                    momentum_generator, kinetic_energy, logpdf)

    # Run the kernel and return an array of longest batch lengths

    def warmup_update(state, key):
        hmc_state, _ = state
        new_hmc_state, new_hmc_info = ehmc_warmup_kernel(key, hmc_state)
        _, _, _, _, batch_length = new_hmc_info.integrator_step
        return (new_hmc_state, new_hmc_info), batch_length

    keys = jax.random.split(rng_key, num_longest_batch)
    state, batch_lengths = jax.lax.scan(
        warmup_update,
        (initial_state, None),
        keys,
    )
    hmc_warmup_state, _ = state

    return hmc_warmup_state, batch_lengths
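
A minimal sketch of how the recorded batch lengths might feed the sampling phase, following the docstring's description of drawing path lengths for the integration step during inference. Here `make_kernel` is a hypothetical factory that builds an HMC kernel for a given path length, and the remaining names (`rng_key`, `logpdf`, `initial_state`, `momentum_generator`, `kinetic_energy`, `integrator_step`, `step_size`, `path_length`) are assumed to be defined as in the examples above.

import jax

state, batch_lengths = ehmc_warmup(
    rng_key, logpdf, initial_state, momentum_generator,
    kinetic_energy, integrator_step, step_size, path_length,
    num_longest_batch=2_000,
)

samples = []
for _ in range(1_000):
    rng_key, choice_key, kernel_key = jax.random.split(rng_key, 3)
    # Draw this transition's path length from the empirical distribution
    # recorded during warmup.
    drawn_length = jax.random.choice(choice_key, batch_lengths)
    kernel = make_kernel(drawn_length)  # hypothetical kernel factory
    state, info = kernel(kernel_key, state)
    samples.append(state.position)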