示例#1
0
def _leaf_idx_to_ckpt_idxs(n):
    # computes the number of non-zero bits except the last bit
    # e.g. 6 -> 2, 7 -> 2, 13 -> 2
    _, idx_max = while_loop(lambda nc: nc[0] > 0, lambda nc:
                            (nc[0] >> 1, nc[1] + (nc[0] & 1)), (n >> 1, 0))
    # computes the number of contiguous last non-zero bits
    # e.g. 6 -> 0, 7 -> 3, 13 -> 1
    _, num_subtrees = while_loop(lambda nc: (nc[0] & 1) != 0, lambda nc:
                                 (nc[0] >> 1, nc[1] + 1), (n, 0))
    idx_min = idx_max - num_subtrees + 1
    return idx_min, idx_max
示例#2
0
def find_reasonable_step_size(potential_fn, kinetic_fn, momentum_generator, inverse_mass_matrix,
                              position, rng, init_step_size):
    """
    Tune `init_step_size` by repeatedly doubling or halving it until the
    Metropolis acceptance probability of a single velocity-Verlet step crosses
    0.8. This keeps HMC from starting with a step size that is far too large
    or far too small.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman

    :param potential_fn: A callable to compute potential energy.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param momentum_generator: A generator to get a random momentum variable.
    :param inverse_mass_matrix: Inverse of mass matrix.
    :param position: Current position of the particle.
    :param jax.random.PRNGKey rng: Random key to be used as the source of randomness.
    :param float init_step_size: Initial step size to be tuned.
    :return: a reasonable value for step size.
    :rtype: float
    """
    # Compare in log space: the trial step is "good enough" when
    # -delta_energy (the log acceptance ratio) exceeds log(0.8).
    log_target = np.log(0.8)

    _, integrate = velocity_verlet(potential_fn, kinetic_fn)
    z0 = position
    pe0, grad0 = value_and_grad(potential_fn)(z0)
    smallest = np.finfo(get_dtype(init_step_size)).tiny

    def search_step(state):
        eps, _, move, key = state
        key, momentum_key = random.split(key)
        # move == 1 doubles eps, move == -1 halves it (move == 0 on the first pass).
        # A NaN delta_energy (e.g. a diverging trajectory) makes the comparison
        # below False, which yields move == -1, i.e. shrink.
        eps = (2.0 ** move) * eps
        r0 = momentum_generator(inverse_mass_matrix, momentum_key)
        _, r1, pe1, _ = integrate(eps, inverse_mass_matrix, (z0, r0, pe0, grad0))
        start_energy = kinetic_fn(inverse_mass_matrix, r0) + pe0
        end_energy = kinetic_fn(inverse_mass_matrix, r1) + pe1
        delta_energy = end_energy - start_energy
        next_move = np.where(log_target < -delta_energy, 1, -1)
        return eps, move, next_move, key

    def keep_searching(state):
        eps, prev_move, move, _ = state
        # Never keep shrinking a step size that is already at or below the
        # smallest positive normal float of its dtype.
        can_move = (eps > smallest) | (move >= 0)
        # Stop as soon as the search direction flips.
        same_direction = (prev_move == 0) | (move == prev_move)
        return can_move & same_direction

    final_eps, _, _, _ = while_loop(keep_searching, search_step, (init_step_size, 0, 0, rng))
    return final_eps
示例#3
0
    def _find_valid_params(rng_key_):
        # One eager initialization attempt. If its validity flag is a concrete
        # (non-traced) value that is True, return immediately without the loop.
        first_state = body_fn((0, rng_key_, None, None))
        _, _, candidate, valid = first_state
        if not_jax_tracer(valid) and device_get(valid):
            return candidate, valid

        # Otherwise keep retrying, resuming from the state of the first attempt.
        _, _, found, valid = while_loop(cond_fn, body_fn, first_state)
        return found, valid
示例#4
0
def build_tree(verlet_update,
               kinetic_fn,
               verlet_state,
               inverse_mass_matrix,
               step_size,
               rng,
               max_delta_energy=1000.,
               max_tree_depth=10):
    """
    Build a NUTS trajectory tree from ``verlet_state``. The tree is repeatedly
    doubled in a randomly chosen direction until it turns, diverges, or
    reaches ``max_tree_depth``.

    **References:**
    [1] `The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo`,
    Matthew D. Hoffman, Andrew Gelman
    [2] `A Conceptual Introduction to Hamiltonian Monte Carlo`,
    Michael Betancourt

    :param verlet_update: callable used by ``_double_tree`` to advance the integrator.
    :param kinetic_fn: callable computing kinetic energy from
        ``(inverse_mass_matrix, r)``.
    :param verlet_state: starting integrator state ``(z, r, potential_energy, z_grad)``.
    :param inverse_mass_matrix: inverse mass matrix; its last dimension and dtype
        size the momentum checkpoint buffers.
    :param step_size: integrator step size.
    :param rng: JAX PRNG key.
    :param max_delta_energy: energy-difference threshold forwarded to ``_double_tree``.
    :param max_tree_depth: maximum number of doublings; also the number of
        checkpoint rows.
    :return: the final ``_TreeInfo``.
    """
    z, r, potential_energy, z_grad = verlet_state
    energy_current = potential_energy + kinetic_fn(inverse_mass_matrix, r)
    ckpt_shape = (max_tree_depth, inverse_mass_matrix.shape[-1])
    ckpt_dtype = inverse_mass_matrix.dtype
    r_ckpts = np.zeros(ckpt_shape, dtype=ckpt_dtype)
    r_sum_ckpts = np.zeros(ckpt_shape, dtype=ckpt_dtype)

    # Depth-0 tree: the left leaf, the right leaf and the proposal are all the
    # starting point.
    initial_tree = _TreeInfo(z, r, z_grad,
                             z, r, z_grad,
                             z, potential_energy, z_grad,
                             depth=0,
                             weight=0.,
                             r_sum=r,
                             turning=False,
                             diverging=False,
                             sum_accept_probs=0.,
                             num_proposals=0)

    def _should_double(carry):
        current, _ = carry
        return (current.depth < max_tree_depth) & ~current.turning & ~current.diverging

    def _double(carry):
        current, key = carry
        key, dir_key, tree_key = random.split(key, 3)
        go_right = random.bernoulli(dir_key)
        current = _double_tree(current, verlet_update, kinetic_fn,
                               inverse_mass_matrix, step_size, go_right,
                               tree_key, energy_current, max_delta_energy,
                               r_ckpts, r_sum_ckpts)
        return current, key

    final_tree, _ = while_loop(_should_double, _double, (initial_tree, rng))
    return final_tree
示例#5
0
    def _phi_marginal(shape, rng_key, conc, corr, eig, b0, eigmin, phi_den):
        """Rejection-sample an array of ``phi`` draws of the given ``shape``.

        ``conc``, ``eig``, ``b0``, ``eigmin`` and ``phi_den`` are broadcast to
        ``shape``. Each loop iteration:

        - draws a proposal ``x`` (a scaled normal draw normalized along axis 1;
          the inline note calls this an Angular Central Gaussian);
        - evaluates ``lf`` and ``lg_inv``;
        - accepts elementwise where ``uniform < exp(lf + lg_inv)``.

        Accepted entries overwrite ``phi`` and are marked in ``done``.

        The loop stops once every entry is ``done`` or after
        ``SineBivariateVonMises.max_sample_iter`` iterations. Entries that were
        never accepted keep their initial placeholder value; the returned
        ``done`` mask tells them apart.

        NOTE(review): the ``[:, :1]`` / ``[:, 1:]`` indexing and ``axis=1``
        normalization suggest axis 1 holds the components of each draw —
        confirm the expected ``shape`` layout against the caller.

        :param shape: shape of the sampled ``phi`` array (and of every
            intermediate used in the acceptance test).
        :param rng_key: JAX PRNG key driving the sampler.
        :return: ``PhiMarginalState(i, done, phi, key)`` holding the number of
            iterations run, the acceptance mask, the samples and the final key.
        """
        conc = jnp.broadcast_to(conc, shape)
        eig = jnp.broadcast_to(eig, shape)
        b0 = jnp.broadcast_to(b0, shape)
        eigmin = jnp.broadcast_to(eigmin, shape)
        phi_den = jnp.broadcast_to(phi_den, shape)

        def update_fn(curr):
            # One rejection-sampling round over all entries at once.
            i, done, phi, key = curr
            phi_key, key = random.split(key)
            # The third key returned here is not used afterwards.
            accept_key, acg_key, phi_key = random.split(phi_key, 3)

            x = jnp.sqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape)
            x /= jnp.linalg.norm(
                x, axis=1, keepdims=True
            )  # Angular Central Gaussian distribution

            # Log target term: the first component along axis 1 enters linearly,
            # the remaining components enter through log_I1.
            lf = (
                conc[:, :1] * (x[:, :1] - 1)
                + eigmin
                + log_I1(
                    0, jnp.sqrt(conc[:, 1:] ** 2 + (corr * x[:, 1:]) ** 2)
                ).squeeze(0)
                - phi_den
            )
            # Python-level shape sanity check (removed under `python -O`).
            assert lf.shape == shape

            # presumably the log of the inverse proposal envelope — TODO confirm
            lg_inv = (
                1.0 - b0 / 2 + jnp.log(b0 / 2 + (eig * x ** 2).sum(1, keepdims=True))
            )
            assert lg_inv.shape == lf.shape

            accepted = random.uniform(accept_key, shape) < jnp.exp(lf + lg_inv)

            # Newly accepted entries take the proposal; all others keep their value.
            phi = jnp.where(accepted, x, phi)
            return PhiMarginalState(i + 1, done | accepted, phi, key)

        def cond_fn(curr):
            # Continue while under the iteration cap and some entry is not yet accepted.
            return jnp.bitwise_and(
                curr.i < SineBivariateVonMises.max_sample_iter,
                jnp.logical_not(jnp.all(curr.done)),
            )

        # `phi` starts as a placeholder buffer; only accepted entries are filled in.
        phi_state = while_loop(
            cond_fn,
            update_fn,
            PhiMarginalState(
                i=jnp.array(0),
                done=jnp.zeros(shape, dtype=bool),
                phi=jnp.empty(shape, dtype=float),
                key=rng_key,
            ),
        )
        # Rebuilds the state from its fields; the values are those of `phi_state`.
        return PhiMarginalState(
            phi_state.i, phi_state.done, phi_state.phi, phi_state.key
        )
示例#6
0
def _iterative_build_subtree(depth, vv_update, kinetic_fn, z, r, z_grad,
                             inverse_mass_matrix, step_size, going_right, rng,
                             energy_current, max_delta_energy, r_ckpts,
                             r_sum_ckpts):
    """Iteratively build a subtree of the given ``depth`` from ``(z, r, z_grad)``.

    A base tree is first built from the starting point. Leaves are then added
    one at a time in the ``going_right`` direction until any of these holds:

    - the subtree reaches ``2**depth`` proposals;
    - a turning condition is detected;
    - the tree diverges.

    Momentum checkpoints (``r_ckpts``) and running momentum-sum checkpoints
    (``r_sum_ckpts``) let the turning condition be checked inside a
    ``while_loop`` instead of by recursion.

    :param int depth: subtree depth; the loop stops at ``2**depth`` proposals.
    :param vv_update: integrator update forwarded to ``_build_basetree``.
    :param kinetic_fn: kinetic energy callable forwarded to ``_build_basetree``.
    :param z: starting position.
    :param r: starting momentum.
    :param z_grad: potential-energy gradient at ``z``.
    :param inverse_mass_matrix: inverse mass matrix.
    :param step_size: integrator step size.
    :param going_right: direction in which the subtree is grown.
    :param rng: JAX PRNG key used for the tree-combination transitions.
    :param energy_current: energy of the initial state of the trajectory.
    :param max_delta_energy: divergence threshold forwarded to ``_build_basetree``.
    :param r_ckpts: buffer of flattened momentum checkpoints (row per slot).
    :param r_sum_ckpts: buffer of flattened momentum-sum checkpoints (row per slot).
    :return: a ``_TreeInfo`` for the subtree, with ``depth`` set to the
        requested depth and ``turning`` set to the flag computed by the loop.
    """
    max_num_proposals = 2**depth

    def _cond_fn(state):
        # Keep growing while under the proposal budget and neither turning nor diverging.
        tree, turning, _, _, _ = state
        return (tree.num_proposals <
                max_num_proposals) & ~turning & ~tree.diverging

    def _body_fn(state):
        current_tree, _, r_ckpts, r_sum_ckpts, rng = state
        rng, transition_rng = random.split(rng)
        # Start the new leaf from the edge of the current tree selected by `going_right`.
        # NB: these names are local to _body_fn and shadow the enclosing arguments.
        z, r, z_grad = _get_leaf(current_tree, going_right)
        new_leaf = _build_basetree(vv_update, kinetic_fn, z, r, z_grad,
                                   inverse_mass_matrix, step_size, going_right,
                                   energy_current, max_delta_energy)
        new_tree = _combine_tree(current_tree,
                                 new_leaf,
                                 inverse_mass_matrix,
                                 going_right,
                                 transition_rng,
                                 biased_transition=False)

        # Index of the leaf just added, counted before it was combined in.
        leaf_idx = current_tree.num_proposals
        ckpt_idx_min, ckpt_idx_max = _leaf_idx_to_ckpt_idxs(leaf_idx)
        # Flattened momentum of the new leaf and flattened running momentum sum.
        r, _ = ravel_pytree(new_leaf.r_right)
        r_sum, _ = ravel_pytree(new_tree.r_sum)
        # we update checkpoints when leaf_idx is even
        # (operand/function-pair form of `cond`: true branch writes slot
        # `ckpt_idx_max` of both buffers, false branch returns them unchanged)
        r_ckpts, r_sum_ckpts = cond(
            leaf_idx % 2 == 0, (r_ckpts, r_sum_ckpts), lambda x:
            (index_update(x[0], ckpt_idx_max, r),
             index_update(x[1], ckpt_idx_max, r_sum)), (r_ckpts, r_sum_ckpts),
            lambda x: x)

        # Check the turning condition against checkpoints ckpt_idx_max down to ckpt_idx_min.
        turning = _is_iterative_turning(inverse_mass_matrix, r, r_sum, r_ckpts,
                                        r_sum_ckpts, ckpt_idx_min,
                                        ckpt_idx_max)
        return new_tree, turning, r_ckpts, r_sum_ckpts, rng

    basetree = _build_basetree(vv_update, kinetic_fn, z, r, z_grad,
                               inverse_mass_matrix, step_size, going_right,
                               energy_current, max_delta_energy)
    # Slot 0 of both checkpoint buffers holds the base tree's flattened left momentum.
    r_init, _ = ravel_pytree(basetree.r_left)
    r_ckpts = index_update(r_ckpts, 0, r_init)
    r_sum_ckpts = index_update(r_sum_ckpts, 0, r_init)

    tree, turning, _, _, _ = while_loop(
        _cond_fn, _body_fn, (basetree, False, r_ckpts, r_sum_ckpts, rng))
    # update depth and turning condition
    return _TreeInfo(tree.z_left, tree.r_left, tree.z_left_grad, tree.z_right,
                     tree.r_right, tree.z_right_grad, tree.z_proposal,
                     tree.z_proposal_pe, tree.z_proposal_grad, depth,
                     tree.weight, tree.r_sum, turning, tree.diverging,
                     tree.sum_accept_probs, tree.num_proposals)
示例#7
0
def _is_iterative_turning(inverse_mass_matrix, r, r_sum, r_ckpts, r_sum_ckpts, idx_min, idx_max):
    """Check the turning condition against checkpoints ``idx_max`` down to ``idx_min``.

    For each checkpoint slot ``i``, the subtree momentum sum is rebuilt as
    ``r_sum - r_sum_ckpts[i] + r_ckpts[i]``. That sum is passed to
    ``_is_turning`` together with the checkpointed momentum ``r_ckpts[i]`` and
    the current momentum ``r``. The scan stops at the first slot reported as
    turning.

    :return: boolean flag; False when no slot is checked (``idx_min > idx_max``)
        or none of the checked slots is turning.
    """
    def _keep_scanning(carry):
        slot, turned = carry
        return (slot >= idx_min) & ~turned

    def _check_slot(carry):
        slot, _ = carry
        ckpt_r = r_ckpts[slot]
        subtree_sum = r_sum - r_sum_ckpts[slot] + ckpt_r
        return slot - 1, _is_turning(inverse_mass_matrix, ckpt_r, r, subtree_sum)

    _, turned = while_loop(_keep_scanning, _check_slot, (idx_max, False))
    return turned
示例#8
0
def _leaf_idx_to_ckpt_idxs(n):
    # computes the number of non-zero bits except the last bit
    # e.g. 6 -> 2, 7 -> 2, 13 -> 2
    _, idx_max = while_loop(lambda nc: nc[0] > 0, lambda nc:
                            (nc[0] >> 1, nc[1] + (nc[0] & 1)), (n >> 1, 0))
    # computes the number of contiguous last non-zero bits
    # e.g. 6 -> 0, 7 -> 3, 13 -> 1
    _, num_subtrees = while_loop(lambda nc: (nc[0] & 1) != 0, lambda nc:
                                 (nc[0] >> 1, nc[1] + 1), (n, 0))
    # TODO: explore the potential of setting idx_min=0 to allow more turning checks
    # It will be useful in case: e.g. assume a tree 0 -> 7 is a circle,
    # subtrees 0 -> 3, 4 -> 7 are half-circles, which two leaves might not
    # satisfy turning condition;
    # the full tree 0 -> 7 is a circle, which two leaves might also not satisfy
    # turning condition;
    # however, we can check the turning condition of the subtree 0 -> 5, which
    # likely satisfies turning condition because its trajectory 3/4 of a circle.
    # XXX: make sure that detailed balance is satisfied if we follow this direction
    idx_min = idx_max - num_subtrees + 1
    return idx_min, idx_max
示例#9
0
def _iterative_build_subtree(prototype_tree, vv_update, kinetic_fn,
                             inverse_mass_matrix, step_size, going_right, rng_key,
                             energy_current, max_delta_energy, r_ckpts, r_sum_ckpts):
    """Iteratively build a subtree with ``prototype_tree.depth`` levels.

    Starting from ``prototype_tree`` with ``num_proposals`` reset to 0, leaves
    are added one at a time in the ``going_right`` direction until any of these
    holds:

    - the subtree reaches ``2 ** prototype_tree.depth`` proposals;
    - a turning condition is detected;
    - the tree diverges.

    Momentum checkpoints (``r_ckpts``) and running momentum-sum checkpoints
    (``r_sum_ckpts``) let the turning condition be checked inside a
    ``while_loop`` instead of by recursion.

    :param prototype_tree: tree whose leaves seed the first new leaf and whose
        ``depth`` fixes the proposal budget.
    :param vv_update: integrator update forwarded to ``_build_basetree``.
    :param kinetic_fn: kinetic energy callable forwarded to ``_build_basetree``.
    :param inverse_mass_matrix: inverse mass matrix.
    :param step_size: integrator step size.
    :param going_right: direction in which the subtree is grown.
    :param rng_key: JAX PRNG key used for the tree-combination transitions.
    :param energy_current: energy of the initial state of the trajectory.
    :param max_delta_energy: divergence threshold forwarded to ``_build_basetree``.
    :param r_ckpts: buffer of flattened momentum checkpoints (row per slot).
    :param r_sum_ckpts: buffer of flattened momentum-sum checkpoints (row per slot).
    :return: a ``TreeInfo`` for the subtree, with ``depth`` taken from
        ``prototype_tree`` and ``turning`` set to the flag computed by the loop.
    """
    max_num_proposals = 2 ** prototype_tree.depth

    def _cond_fn(state):
        # Keep growing while under the proposal budget and neither turning nor diverging.
        tree, turning, _, _, _ = state
        return (tree.num_proposals < max_num_proposals) & ~turning & ~tree.diverging

    def _body_fn(state):
        current_tree, _, r_ckpts, r_sum_ckpts, rng_key = state
        rng_key, transition_rng_key = random.split(rng_key)
        # If we are going to the right, start from the right leaf of the current tree.
        z, r, z_grad = _get_leaf(current_tree, going_right)
        new_leaf = _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size,
                                   going_right, energy_current, max_delta_energy)
        # The first leaf becomes the tree itself; later leaves are merged into it.
        # The trailing positional False is presumably `biased_transition` — confirm
        # against `_combine_tree`.
        new_tree = cond(current_tree.num_proposals == 0,
                        new_leaf,
                        identity,
                        (current_tree, new_leaf, inverse_mass_matrix, going_right, transition_rng_key),
                        lambda x: _combine_tree(*x, False))

        # Index of the leaf just added, counted before it was combined in.
        leaf_idx = current_tree.num_proposals
        # NB: in the special case leaf_idx=0, ckpt_idx_min=1 and ckpt_idx_max=0,
        # the following logic is still valid for that case
        ckpt_idx_min, ckpt_idx_max = _leaf_idx_to_ckpt_idxs(leaf_idx)
        # Flattened momentum of the new leaf (plus its unravel function) and the
        # flattened running momentum sum.
        r, unravel_fn = ravel_pytree(new_leaf.r_right)
        r_sum, _ = ravel_pytree(new_tree.r_sum)
        # we update checkpoints when leaf_idx is even
        # (true branch writes slot `ckpt_idx_max` of both buffers, false branch
        # returns them unchanged)
        r_ckpts, r_sum_ckpts = cond(leaf_idx % 2 == 0,
                                    (r_ckpts, r_sum_ckpts),
                                    lambda x: (index_update(x[0], ckpt_idx_max, r),
                                               index_update(x[1], ckpt_idx_max, r_sum)),
                                    (r_ckpts, r_sum_ckpts),
                                    identity)

        # Check the turning condition against checkpoints ckpt_idx_max down to ckpt_idx_min.
        turning = _is_iterative_turning(inverse_mass_matrix, new_leaf.r_right, r_sum,
                                        r_ckpts, r_sum_ckpts,
                                        ckpt_idx_min, ckpt_idx_max, unravel_fn)
        return new_tree, turning, r_ckpts, r_sum_ckpts, rng_key

    # Reset the proposal counter so the first loop iteration takes the
    # `num_proposals == 0` branch above.
    basetree = prototype_tree._replace(num_proposals=0)

    tree, turning, _, _, _ = while_loop(
        _cond_fn,
        _body_fn,
        (basetree, False, r_ckpts, r_sum_ckpts, rng_key)
    )
    # update depth and turning condition
    return TreeInfo(tree.z_left, tree.r_left, tree.z_left_grad,
                    tree.z_right, tree.r_right, tree.z_right_grad,
                    tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad, tree.z_proposal_energy,
                    prototype_tree.depth, tree.weight, tree.r_sum, turning, tree.diverging,
                    tree.sum_accept_probs, tree.num_proposals)
示例#10
0
def build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng_key,
               max_delta_energy=1000., max_tree_depth=10):
    """
    Build the binary trajectory tree used by the NUTS sampler, starting from
    `verlet_state`. The tree is doubled in a randomly chosen direction until
    it turns, diverges, or reaches `max_tree_depth`.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman
    2. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
       Michael Betancourt

    :param verlet_update: A callable to get a new integrator state given a current
        integrator state.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param verlet_state: Initial integrator state.
    :param inverse_mass_matrix: Inverse of the mass matrix.
    :param float step_size: Step size for the current trajectory.
    :param jax.random.PRNGKey rng_key: random key to be used as the source of
        randomness.
    :param float max_delta_energy: A threshold to decide if the new state diverges
        (based on the energy difference) too much from the initial integrator state.
    :param int max_tree_depth: Maximum number of tree doublings; also the number
        of rows in the momentum checkpoint buffers.
    :return: information of the tree.
    :rtype: :data:`TreeInfo`
    """
    z, r, potential_energy, z_grad = verlet_state
    energy_current = potential_energy + kinetic_fn(inverse_mass_matrix, r)
    flat_r = ravel_pytree(r)[0]
    ckpt_shape = (max_tree_depth, jnp.size(flat_r))
    r_ckpts = jnp.zeros(ckpt_shape)
    r_sum_ckpts = jnp.zeros(ckpt_shape)

    # Depth-0 tree: the left leaf, the right leaf and the proposal all sit at
    # the starting point.
    start_tree = TreeInfo(
        z, r, z_grad,
        z, r, z_grad,
        z, potential_energy, z_grad, energy_current,
        depth=0,
        weight=jnp.zeros(()),
        r_sum=r,
        turning=jnp.array(False),
        diverging=jnp.array(False),
        sum_accept_probs=jnp.zeros(()),
        num_proposals=jnp.array(0, dtype=jnp.result_type(int)),
    )

    def _can_grow(carry):
        tree, _ = carry
        return (tree.depth < max_tree_depth) & ~tree.turning & ~tree.diverging

    def _grow(carry):
        tree, key = carry
        key, dir_key, tree_key = random.split(key, 3)
        go_right = random.bernoulli(dir_key)
        tree = _double_tree(tree, verlet_update, kinetic_fn, inverse_mass_matrix, step_size,
                            go_right, tree_key, energy_current, max_delta_energy,
                            r_ckpts, r_sum_ckpts)
        return tree, key

    final_tree, _ = while_loop(_can_grow, _grow, (start_tree, rng_key))
    return final_tree
示例#11
0
    def _find_valid_params(rng_key, exit_early=False):
        state = (0, rng_key, (prototype_params, 0., prototype_params), False)
        # With a concrete (non-traced) key, one eager attempt may already find valid
        # parameters. This only pays off for a single chain, where it avoids
        # compiling ``body_fn`` inside ``while_loop``.
        if exit_early and not_jax_tracer(rng_key):
            state = body_fn(state)
            _, _, (params, energy, grad), valid = state
            if not_jax_tracer(valid) and device_get(valid):
                return (params, energy, grad), valid

        # XXX: the loop compiles the model. For multiple chains this means the model
        # is traced twice, even when the eager attempt already produced a valid result.
        _, _, (params, energy, grad), valid = while_loop(cond_fn, body_fn, state)
        return (params, energy, grad), valid
示例#12
0
def find_reasonable_step_size(potential_fn, kinetic_fn, momentum_generator,
                              inverse_mass_matrix, position, rng,
                              init_step_size):
    """
    Find a reasonable HMC step size by repeatedly doubling or halving
    `init_step_size`. The search stops when the acceptance probability of one
    velocity-Verlet step crosses 0.8.

    The search also stops once the step size has shrunk to the smallest positive
    normal float of its dtype. Without this guard the loop would never terminate
    when every trial is rejected, e.g. when the energy at `position` is NaN.

    :param potential_fn: A callable to compute potential energy.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param momentum_generator: A callable ``(inverse_mass_matrix, rng)`` returning
        a random momentum.
    :param inverse_mass_matrix: Inverse of mass matrix.
    :param position: Current position of the particle.
    :param jax.random.PRNGKey rng: Random key to be used as the source of randomness.
    :param float init_step_size: Initial step size to be tuned.
    :return: a reasonable value for step size.
    """
    # We are going to find a step_size which makes accept_prob (Metropolis correction)
    # near the target_accept_prob. If accept_prob:=exp(-delta_energy) is small,
    # then we have to decrease step_size; otherwise, increase step_size.
    target_accept_prob = np.log(0.8)

    _, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    z = position
    potential_energy, z_grad = value_and_grad(potential_fn)(z)
    # Smallest positive normal float of the step size's dtype: below it, halving
    # further is pointless.
    tiny = np.finfo(np.result_type(init_step_size)).tiny

    def _body_fn(state):
        step_size, _, direction, rng = state
        rng, rng_momentum = random.split(rng)
        # scale step_size: increase 2x or decrease 2x depends on direction;
        # direction=1 means keep increasing step_size, otherwise decreasing step_size.
        # Note that the direction is -1 if delta_energy is `NaN`, which may be the
        # case for a diverging trajectory (e.g. in the case of evaluating log prob
        # of a value simulated using a large step size for a constrained sample site).
        step_size = (2.0**direction) * step_size
        r = momentum_generator(inverse_mass_matrix, rng_momentum)
        _, r_new, potential_energy_new, _ = vv_update(
            step_size, inverse_mass_matrix, (z, r, potential_energy, z_grad))
        energy_current = kinetic_fn(inverse_mass_matrix, r) + potential_energy
        energy_new = kinetic_fn(inverse_mass_matrix,
                                r_new) + potential_energy_new
        delta_energy = energy_new - energy_current
        direction_new = np.where(target_accept_prob < -delta_energy, 1, -1)
        return step_size, direction, direction_new, rng

    def _cond_fn(state):
        step_size, last_direction, direction, _ = state
        # Only keep shrinking while the step size is still above `tiny`; this
        # guarantees termination when every trial is rejected.
        not_small_step_size_cond = (step_size > tiny) | (direction >= 0)
        # Stop as soon as the search direction flips.
        same_direction_cond = (last_direction == 0) | (direction == last_direction)
        return not_small_step_size_cond & same_direction_cond

    step_size, _, _, _ = while_loop(_cond_fn, _body_fn,
                                    (init_step_size, 0, 0, rng))
    return step_size
示例#13
0
def find_valid_initial_params(rng_key,
                              model,
                              *model_args,
                              init_strategy=init_to_uniform(),
                              param_as_improper=False,
                              prototype_params=None,
                              **model_kwargs):
    """
    Given a model with Pyro primitives, return initial unconstrained parameters
    together with an `is_valid` flag telling whether they are valid.

    Candidate parameters are drawn with `init_strategy`. A candidate is
    accepted once the potential energy and all entries of its gradient are
    finite. At most 100 attempts are made. If none succeeds, the last attempt
    is returned with `is_valid` set to False.

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param `*model_args`: args provided to the model.
    :param callable init_strategy: a per-site initialization function.
    :param bool param_as_improper: a flag to decide whether to consider sites with
        `param` statement as sites with improper priors.
    :param prototype_params: optional initial value for the parameter slot of
        the search-loop state. When given, the eager first attempt is skipped and
        the loop starts from these values with `is_valid=False`.
        NOTE(review): presumably these only need the same structure as the real
        parameters (e.g. from a previous call) — confirm against callers.
    :param `**model_kwargs`: kwargs provided to the model.
    :return: tuple of (`init_params`, `is_valid`).
    """
    init_strategy = jax.partial(init_strategy,
                                skip_param=not param_as_improper)

    def cond_fn(state):
        # Retry until a valid draw is found or 100 attempts have been made.
        i, _, _, is_valid = state
        return (i < 100) & (~is_valid)

    def body_fn(state):
        # One initialization attempt: draw constrained values, map them to the
        # unconstrained space, and test the potential energy and gradient for finiteness.
        i, key, _, _ = state
        key, subkey = random.split(key)

        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        # Use `block` to not record sample primitives in `init_loc_fn`.
        seeded_model = substitute(model,
                                  substitute_fn=block(
                                      seed(init_strategy, subkey)))
        model_trace = trace(seeded_model).get_trace(*model_args,
                                                    **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            if v['type'] == 'sample' and not v['is_observed']:
                if v['intermediates']:
                    constrained_values[k] = v['intermediates'][0][0]
                    inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                else:
                    constrained_values[k] = v['value']
                    inv_transforms[k] = biject_to(v['fn'].support)
            elif v['type'] == 'param' and param_as_improper:
                # Read the constraint without mutating the trace entry's kwargs.
                constraint = v['kwargs'].get('constraint', real)
                transform = biject_to(constraint)
                if isinstance(transform, ComposeTransform):
                    base_transform = transform.parts[0]
                    inv_transforms[k] = base_transform
                    constrained_values[k] = base_transform(
                        transform.inv(v['value']))
                else:
                    inv_transforms[k] = transform
                    constrained_values[k] = v['value']
        params = transform_fn(inv_transforms,
                              dict(constrained_values),
                              invert=True)
        potential_fn = jax.partial(potential_energy, model, model_args,
                                   model_kwargs, inv_transforms)
        pe, param_grads = value_and_grad(potential_fn)(params)
        z_grad = ravel_pytree(param_grads)[0]
        is_valid = np.isfinite(pe) & np.all(np.isfinite(z_grad))
        return i + 1, key, params, is_valid

    if prototype_params is not None:
        init_state = (0, rng_key, prototype_params, False)
    else:
        # Eager first attempt; if its flag is concrete and True, skip the loop.
        _, _, prototype_params, is_valid = init_state = body_fn(
            (0, rng_key, None, None))
        if not_jax_tracer(is_valid) and device_get(is_valid):
            return prototype_params, is_valid

    _, _, init_params, is_valid = while_loop(cond_fn, body_fn, init_state)
    return init_params, is_valid