Code Example #1: an init_kernel closure (note the nonlocal declarations) that sets up step-size and mass-matrix warmup adaptation for HMC and optionally runs the warmup loop before returning the initial sampler state.
    def init_kernel(init_samples,
                    num_warmup_steps,
                    step_size=1.0,
                    num_steps=None,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    diag_mass=True,
                    target_accept_prob=0.8,
                    run_warmup=True,
                    rng=PRNGKey(0)):
        step_size = float(step_size)
        nonlocal trajectory_length, momentum_generator, wa_update

        if num_steps is None:
            trajectory_length = 2 * math.pi
        else:
            trajectory_length = num_steps * step_size

        z = init_samples
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)

        find_reasonable_ss = partial(find_reasonable_step_size, potential_fn,
                                     kinetic_fn, momentum_generator)

        wa_init, wa_update = warmup_adapter(
            num_warmup_steps,
            find_reasonable_step_size=find_reasonable_ss,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            diag_mass=diag_mass,
            target_accept_prob=target_accept_prob)

        rng_hmc, rng_wa = random.split(rng)
        wa_state = wa_init(z, rng_wa, mass_matrix_size=np.size(z_flat))
        r = momentum_generator(wa_state.inverse_mass_matrix, rng)
        vv_state = vv_init(z, r)
        hmc_state = HMCState(vv_state.z, vv_state.z_grad,
                             vv_state.potential_energy, 0, 0.,
                             wa_state.step_size, wa_state.inverse_mass_matrix,
                             rng_hmc)

        if run_warmup:
            hmc_state, _ = fori_loop(0, num_warmup_steps, warmup_update,
                                     (hmc_state, wa_state))
            return hmc_state
        else:
            return hmc_state, wa_state, warmup_update
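
The warmup loop above threads a (hmc_state, wa_state) pair through fori_loop and keeps only the HMC state afterwards. Below is a minimal, self-contained sketch of that pattern, runnable with JAX alone; it is not NumPyro code, jax.lax.fori_loop is called directly, and the two states are dummy scalar placeholders.

from jax import lax
import jax.numpy as jnp

def warmup_update(i, carry):
    # Dummy stand-ins for the real (HMCState, adaptation state) pair.
    hmc_state, wa_state = carry
    return hmc_state + 0.1, wa_state  # pretend each warmup step nudges the HMC state

hmc_state, wa_state = jnp.float32(0.), jnp.float32(1.)
# Only the first element of the carried pair is kept once warmup finishes.
hmc_state, _ = lax.fori_loop(0, 100, warmup_update, (hmc_state, wa_state))
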
Code Example #2: a later variant of init_kernel with a full docstring, NUTS-specific options such as max_tree_depth, and an optional progress bar during the warmup phase.
    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2*math.pi,
                    max_tree_depth=10,
                    run_warmup=True,
                    progbar=True,
                    rng=PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if the mass matrix is dense
            (full-rank) or diagonal. Defaults to ``False``, i.e. a diagonal mass matrix.
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Default to 0.8.
        :param float trajectory_length: Length of an MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10.
        :param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
            the adaptation phase is executed and `init_kernel` returns a warmed-up
            :data:`HMCState` that can be used to generate samples via MCMC; otherwise
            the returned :data:`HMCState` has not been adapted yet.
        :param bool progbar: Whether to enable progress bar updates. Defaults to
            ``True``.
        :param jax.random.PRNGKey rng: random key to be used as the source of
            randomness.
        """
        step_size = float(step_size)
        nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps
        wa_steps = num_warmup
        trajectory_len = float(trajectory_length)
        max_treedepth = max_tree_depth
        z = init_params
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)

        find_reasonable_ss = partial(find_reasonable_step_size,
                                     potential_fn, kinetic_fn,
                                     momentum_generator)

        wa_init, wa_update = warmup_adapter(num_warmup,
                                            adapt_step_size=adapt_step_size,
                                            adapt_mass_matrix=adapt_mass_matrix,
                                            dense_mass=dense_mass,
                                            target_accept_prob=target_accept_prob,
                                            find_reasonable_step_size=find_reasonable_ss)

        rng_hmc, rng_wa = random.split(rng)
        wa_state = wa_init(z, rng_wa, step_size, mass_matrix_size=np.size(z_flat))
        r = momentum_generator(wa_state.mass_matrix_sqrt, rng)
        vv_state = vv_init(z, r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0., 0.,
                             wa_state, rng_hmc)

        if run_warmup and num_warmup > 0:
            # JIT if progress bar updates not required
            if not progbar:
                hmc_state = fori_loop(0, num_warmup, lambda *args: sample_kernel(args[1]), hmc_state)
            else:
                with tqdm.trange(num_warmup, desc='warmup') as t:
                    for i in t:
                        hmc_state = sample_kernel(hmc_state)
                        t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False)
        return hmc_state
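
Both variants flatten the initial parameters with ravel_pytree and build momentum_generator as a partial over the resulting unravel_fn. The following is a hedged, self-contained sketch of that pattern under the assumption of a diagonal mass matrix and standard-normal base noise; _toy_sample_momentum is a hypothetical stand-in for the library's _sample_momentum.

from functools import partial
from jax import random
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def _toy_sample_momentum(unravel_fn, mass_matrix_sqrt, rng):
    # Draw a flat Gaussian momentum scaled elementwise by the (diagonal)
    # mass-matrix square root, then restore the parameter pytree structure.
    eps = random.normal(rng, jnp.shape(mass_matrix_sqrt))
    return unravel_fn(mass_matrix_sqrt * eps)

z = {'loc': jnp.zeros(3), 'scale': jnp.ones(2)}
z_flat, unravel_fn = ravel_pytree(z)
momentum_generator = partial(_toy_sample_momentum, unravel_fn)
r = momentum_generator(jnp.ones_like(z_flat), random.PRNGKey(0))  # same pytree structure as z
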
Code Example #3: a test of warmup_adapter that steps through the adaptation schedule and checks how the step size and inverse mass matrix change across windows.
# Imports assumed for running this test snippet on its own (the exact module
# path of build_adaptation_schedule / warmup_adapter may differ across
# NumPyro versions):
import pytest
from numpy.testing import assert_allclose
from jax import jit, random
import jax.numpy as np
from numpyro.hmc_util import build_adaptation_schedule, warmup_adapter


@pytest.mark.parametrize('jitted', [True, False])
def test_warmup_adapter(jitted):
    def find_reasonable_step_size(m_inv, z, rng, step_size):
        return np.where(step_size < 1, step_size * 4, step_size / 4)

    num_steps = 150
    adaptation_schedule = build_adaptation_schedule(num_steps)
    init_step_size = 1.
    mass_matrix_size = 3

    wa_init, wa_update = warmup_adapter(num_steps, find_reasonable_step_size)
    wa_update = jit(wa_update) if jitted else wa_update

    rng = random.PRNGKey(0)
    z = np.ones(3)
    wa_state = wa_init(z, rng, init_step_size, mass_matrix_size=mass_matrix_size)
    step_size, inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert step_size == find_reasonable_step_size(inverse_mass_matrix, z, rng, init_step_size)
    assert_allclose(inverse_mass_matrix, np.ones(mass_matrix_size))
    assert window_idx == 0

    window = adaptation_schedule[0]
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(t, 0.7 + 0.1 * t / (window.end - window.start), z, wa_state)
    last_step_size = step_size
    step_size, inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert window_idx == 1
    # step_size is decreased because accept_prob < target_accept_prob
    assert step_size < last_step_size
    # inverse_mass_matrix does not change at the end of the first window
    assert_allclose(inverse_mass_matrix, np.ones(mass_matrix_size))

    window = adaptation_schedule[1]
    window_len = window.end - window.start
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(t, 0.8 + 0.1 * (t - window.start) / window_len, 2 * z, wa_state)
    last_step_size = step_size
    step_size, inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert window_idx == 2
    # step_size is increased because accept_prob > target_accept_prob
    assert step_size > last_step_size
    # Verifies that inverse_mass_matrix changes at the end of the second window.
    # Because z_flat is constant during the second window, the sample covariance
    # is 0 and only the regularization term of the Welford scheme contributes.
    # This also verifies that the z_flat values seen in the first window do not
    # affect the second window.
    welford_regularize_term = 1e-3 * (5 / (window.end + 1 - window.start + 5))
    assert_allclose(inverse_mass_matrix,
                    np.full((mass_matrix_size,), welford_regularize_term),
                    atol=1e-7)

    window = adaptation_schedule[2]
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(t, 0.8, t * z, wa_state)
    last_step_size = step_size
    step_size, final_inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert window_idx == 3
    # during the last window, because target_accept_prob=0.8,
    # log_step_size will be equal to the constant prox_center=log(10*last_step_size)
    assert_allclose(step_size, last_step_size * 10)
    # Verifies that inverse_mass_matrix does not change during the last window,
    # even though z_flat changes with t.
    assert_allclose(final_inverse_mass_matrix, inverse_mass_matrix)
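
The expected value at the end of the second window follows from a Stan-style regularization of the Welford covariance estimate. Assuming the finalization is roughly cov * n / (n + 5) + 1e-3 * 5 / (n + 5) (an assumption about the scheme, not code taken from the library), constant samples leave only the regularizer, which is exactly the welford_regularize_term checked above. A small runnable sketch:

import jax.numpy as jnp

def regularized_diag_cov(samples):
    # Shrink the per-dimension sample variance toward 1e-3 (Stan-style regularization).
    n = samples.shape[0]
    cov = jnp.var(samples, axis=0, ddof=1)
    return cov * n / (n + 5) + 1e-3 * 5 / (n + 5)

window_len = 25                                        # hypothetical window size
samples = jnp.tile(2 * jnp.ones(3), (window_len, 1))   # constant draws, as in the second window
print(regularized_diag_cov(samples))                   # ~= 1e-3 * 5 / (window_len + 5) in every dimension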