def _leaf_idx_to_ckpt_idxs(n): # computes the number of non-zero bits except the last bit # e.g. 6 -> 2, 7 -> 2, 13 -> 2 _, idx_max = while_loop(lambda nc: nc[0] > 0, lambda nc: (nc[0] >> 1, nc[1] + (nc[0] & 1)), (n >> 1, 0)) # computes the number of contiguous last non-zero bits # e.g. 6 -> 0, 7 -> 3, 13 -> 1 _, num_subtrees = while_loop(lambda nc: (nc[0] & 1) != 0, lambda nc: (nc[0] >> 1, nc[1] + 1), (n, 0)) idx_min = idx_max - num_subtrees + 1 return idx_min, idx_max
def find_reasonable_step_size(potential_fn, kinetic_fn, momentum_generator, inverse_mass_matrix,
                              position, rng, init_step_size):
    """
    Finds a reasonable step size by tuning `init_step_size`. This function is used
    to avoid working with a too large or too small step size in HMC.

    The step size is repeatedly doubled or halved. The search stops as soon as the
    preferred direction flips, or when the step size reaches the smallest/largest
    representable magnitude of its dtype while still moving in that direction.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman

    :param potential_fn: A callable to compute potential energy.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param momentum_generator: A generator to get a random momentum variable.
    :param inverse_mass_matrix: Inverse of mass matrix.
    :param position: Current position of the particle.
    :param jax.random.PRNGKey rng: Random key to be used as the source of randomness.
    :param float init_step_size: Initial step size to be tuned.
    :return: a reasonable value for step size.
    :rtype: float
    """
    # We are going to find a step_size which make accept_prob (Metropolis correction)
    # near the target accept prob. If accept_prob:=exp(-delta_energy) is small,
    # then we have to decrease step_size; otherwise, increase step_size.
    # The comparison below is done in log space, hence the log of the 0.8 target.
    log_target_accept_prob = np.log(0.8)

    _, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    z = position
    potential_energy, z_grad = value_and_grad(potential_fn)(z)
    finfo = np.finfo(get_dtype(init_step_size))

    def _body_fn(state):
        step_size, _, direction, rng = state
        rng, rng_momentum = random.split(rng)
        # scale step_size: increase 2x or decrease 2x depends on direction;
        # direction=1 means keep increasing step_size, otherwise decreasing step_size.
        # Note that the direction is -1 if delta_energy is `NaN`, which may be the
        # case for a diverging trajectory (e.g. in the case of evaluating log prob
        # of a value simulated using a large step size for a constrained sample site).
        step_size = (2.0 ** direction) * step_size
        r = momentum_generator(inverse_mass_matrix, rng_momentum)
        _, r_new, potential_energy_new, _ = vv_update(step_size,
                                                      inverse_mass_matrix,
                                                      (z, r, potential_energy, z_grad))
        energy_current = kinetic_fn(inverse_mass_matrix, r) + potential_energy
        energy_new = kinetic_fn(inverse_mass_matrix, r_new) + potential_energy_new
        delta_energy = energy_new - energy_current
        direction_new = np.where(log_target_accept_prob < -delta_energy, 1, -1)
        return step_size, direction, direction_new, rng

    def _cond_fn(state):
        step_size, last_direction, direction, _ = state
        # condition to run only if step_size is not so small or we are not decreasing step_size
        not_small_step_size_cond = (step_size > finfo.tiny) | (direction >= 0)
        # symmetric guard: run only if step_size is not so large or we are not increasing it
        not_large_step_size_cond = (step_size < finfo.max) | (direction <= 0)
        not_extreme_cond = not_small_step_size_cond & not_large_step_size_cond
        return not_extreme_cond & ((last_direction == 0) | (direction == last_direction))

    step_size, _, _, _ = while_loop(_cond_fn, _body_fn, (init_step_size, 0, 0, rng))
    return step_size
def _find_valid_params(rng_key_):
    """
    Runs `body_fn` once outside of `while_loop`; if that first attempt is already
    valid (and the flag is a concrete value, not a JAX tracer) its params are
    returned immediately. Otherwise the search continues inside `while_loop`,
    starting from the state produced by the first attempt.

    :param rng_key_: random key placed in the initial loop state.
    :return: tuple of (params, `is_valid`).
    """
    first_state = body_fn((0, rng_key_, None, None))
    _, _, prototype_params, is_valid = first_state

    # Early return if valid params found.
    if not_jax_tracer(is_valid) and device_get(is_valid):
        return prototype_params, is_valid

    _, _, init_params, is_valid = while_loop(cond_fn, body_fn, first_state)
    return init_params, is_valid
def build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng,
               max_delta_energy=1000., max_tree_depth=10):
    """
    Builds a tree starting from `verlet_state` by calling `_double_tree` in a loop.
    The loop stops when the tree depth reaches `max_tree_depth` or the tree is
    flagged as turning or diverging.

    **References:**

    [1] `The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo`,
    Matthew D. Hoffman, Andrew Gelman
    [2] `A Conceptual Introduction to Hamiltonian Monte Carlo`,
    Michael Betancourt

    :param verlet_update: integrator update, forwarded to `_double_tree`.
    :param kinetic_fn: callable invoked as ``kinetic_fn(inverse_mass_matrix, r)``.
    :param verlet_state: tuple ``(z, r, potential_energy, z_grad)``.
    :param inverse_mass_matrix: inverse mass matrix; its last dimension and dtype
        define the checkpoint buffers.
    :param step_size: forwarded to `_double_tree`.
    :param rng: random key; split three ways on every doubling.
    :param max_delta_energy: forwarded to `_double_tree`.
    :param max_tree_depth: upper bound on the tree depth and number of checkpoint rows.
    :return: the final tree.
    """
    z, r, potential_energy, z_grad = verlet_state
    energy_current = potential_energy + kinetic_fn(inverse_mass_matrix, r)

    # One checkpoint row per possible tree level.
    ckpt_shape = (max_tree_depth, inverse_mass_matrix.shape[-1])
    ckpt_dtype = inverse_mass_matrix.dtype
    r_ckpts = np.zeros(ckpt_shape, dtype=ckpt_dtype)
    r_sum_ckpts = np.zeros(ckpt_shape, dtype=ckpt_dtype)

    # Depth-0 tree: left edge, right edge and proposal all sit at the initial state.
    root = _TreeInfo(z, r, z_grad, z, r, z_grad, z, potential_energy, z_grad,
                     depth=0, weight=0., r_sum=r, turning=False, diverging=False,
                     sum_accept_probs=0., num_proposals=0)

    def _keep_doubling(carry):
        current, _ = carry
        return (current.depth < max_tree_depth) & ~current.turning & ~current.diverging

    def _double_once(carry):
        current, key = carry
        key, direction_key, doubling_key = random.split(key, 3)
        going_right = random.bernoulli(direction_key)
        doubled = _double_tree(current, verlet_update, kinetic_fn, inverse_mass_matrix,
                               step_size, going_right, doubling_key, energy_current,
                               max_delta_energy, r_ckpts, r_sum_ckpts)
        return doubled, key

    final_tree, _ = while_loop(_keep_doubling, _double_once, (root, rng))
    return final_tree
def _phi_marginal(shape, rng_key, conc, corr, eig, b0, eigmin, phi_den):
    """
    Accept/reject sampling loop for `phi`.

    Each round draws a normalised Gaussian candidate (the Angular Central Gaussian
    proposal), accepts entries where a uniform draw is below ``exp(lf + lg_inv)``
    and writes accepted candidates into `phi`. The loop ends when every entry has
    been accepted at least once or after
    ``SineBivariateVonMises.max_sample_iter`` rounds. An entry accepted again in a
    later round is overwritten by the later candidate.

    :param shape: shape of the sample array; the indexing used requires it to be 2-D.
    :param rng_key: random key driving all draws.
    :param conc, eig, b0, eigmin, phi_den: arrays broadcast to `shape`.
    :param corr: multiplies the trailing candidate columns in the acceptance ratio.
    :return: a `PhiMarginalState` with fields ``(i, done, phi, key)``.
    """
    conc, eig, b0, eigmin, phi_den = (
        jnp.broadcast_to(arr, shape) for arr in (conc, eig, b0, eigmin, phi_den)
    )

    def update_fn(curr):
        step, finished, phi, key = curr
        round_key, key = random.split(key)
        accept_key, acg_key, round_key = random.split(round_key, 3)

        # Angular Central Gaussian distribution
        x = jnp.sqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape)
        x = x / jnp.linalg.norm(x, axis=1, keepdims=True)

        bessel_arg = jnp.sqrt(conc[:, 1:] ** 2 + (corr * x[:, 1:]) ** 2)
        lf = (
            conc[:, :1] * (x[:, :1] - 1)
            + eigmin
            + log_I1(0, bessel_arg).squeeze(0)
            - phi_den
        )
        assert lf.shape == shape

        quad_form = (eig * x ** 2).sum(1, keepdims=True)
        lg_inv = 1.0 - b0 / 2 + jnp.log(b0 / 2 + quad_form)
        assert lg_inv.shape == lf.shape

        threshold = jnp.exp(lf + lg_inv)
        accepted = random.uniform(accept_key, shape) < threshold

        phi = jnp.where(accepted, x, phi)
        return PhiMarginalState(step + 1, finished | accepted, phi, key)

    def cond_fn(curr):
        under_budget = curr.i < SineBivariateVonMises.max_sample_iter
        some_pending = jnp.logical_not(jnp.all(curr.done))
        return jnp.bitwise_and(under_budget, some_pending)

    start = PhiMarginalState(
        i=jnp.array(0),
        done=jnp.zeros(shape, dtype=bool),
        phi=jnp.empty(shape, dtype=float),
        key=rng_key,
    )
    phi_state = while_loop(cond_fn, update_fn, start)

    return PhiMarginalState(
        phi_state.i, phi_state.done, phi_state.phi, phi_state.key
    )
def _iterative_build_subtree(depth, vv_update, kinetic_fn, z, r, z_grad,
                             inverse_mass_matrix, step_size, going_right, rng,
                             energy_current, max_delta_energy, r_ckpts, r_sum_ckpts):
    """
    Builds a subtree iteratively, one base tree (leaf) at a time.

    Starting from a base tree built at ``(z, r, z_grad)``, the loop keeps attaching
    new leaves while ``num_proposals < 2**depth`` and neither the turning flag nor
    the tree's `diverging` flag is set. The momentum checkpoints `r_ckpts` and
    `r_sum_ckpts` are written for even leaf indices and consulted by
    `_is_iterative_turning` after every new leaf.

    :return: a `_TreeInfo` copied from the built tree, with `depth` set to the
        requested depth and `turning` set to the loop's final turning flag.
    """
    max_num_proposals = 2**depth

    def _should_continue(carry):
        subtree, is_turning, _, _, _ = carry
        return (subtree.num_proposals < max_num_proposals) & ~is_turning & ~subtree.diverging

    def _attach_leaf(carry):
        subtree, _, ckpts, sum_ckpts, key = carry
        key, transition_key = random.split(key)

        edge_z, edge_r, edge_z_grad = _get_leaf(subtree, going_right)
        new_leaf = _build_basetree(vv_update, kinetic_fn, edge_z, edge_r, edge_z_grad,
                                   inverse_mass_matrix, step_size, going_right,
                                   energy_current, max_delta_energy)
        grown = _combine_tree(subtree, new_leaf, inverse_mass_matrix, going_right,
                              transition_key, biased_transition=False)

        leaf_idx = subtree.num_proposals
        ckpt_idx_min, ckpt_idx_max = _leaf_idx_to_ckpt_idxs(leaf_idx)
        flat_r, _ = ravel_pytree(new_leaf.r_right)
        flat_r_sum, _ = ravel_pytree(grown.r_sum)

        def _store(pair):
            return (index_update(pair[0], ckpt_idx_max, flat_r),
                    index_update(pair[1], ckpt_idx_max, flat_r_sum))

        def _keep(pair):
            return pair

        # we update checkpoints when leaf_idx is even
        ckpts, sum_ckpts = cond(leaf_idx % 2 == 0,
                                (ckpts, sum_ckpts), _store,
                                (ckpts, sum_ckpts), _keep)

        is_turning = _is_iterative_turning(inverse_mass_matrix, flat_r, flat_r_sum,
                                           ckpts, sum_ckpts, ckpt_idx_min, ckpt_idx_max)
        return grown, is_turning, ckpts, sum_ckpts, key

    basetree = _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix,
                               step_size, going_right, energy_current, max_delta_energy)

    # Seed slot 0 of both checkpoint buffers with the first leaf's momentum.
    r_init, _ = ravel_pytree(basetree.r_left)
    r_ckpts = index_update(r_ckpts, 0, r_init)
    r_sum_ckpts = index_update(r_sum_ckpts, 0, r_init)

    tree, turning, _, _, _ = while_loop(
        _should_continue, _attach_leaf, (basetree, False, r_ckpts, r_sum_ckpts, rng))

    # update depth and turning condition
    return _TreeInfo(tree.z_left, tree.r_left, tree.z_left_grad,
                     tree.z_right, tree.r_right, tree.z_right_grad,
                     tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad,
                     depth, tree.weight, tree.r_sum, turning, tree.diverging,
                     tree.sum_accept_probs, tree.num_proposals)
def _is_iterative_turning(inverse_mass_matrix, r, r_sum, r_ckpts, r_sum_ckpts, idx_min, idx_max):
    """
    Scans checkpoint slots from `idx_max` down to `idx_min` (inclusive) and
    evaluates `_is_turning` for each, stopping at the first slot that reports
    turning. If ``idx_min > idx_max`` no slot is examined and the result is False.

    For slot ``i`` the momentum sum handed to `_is_turning` is
    ``r_sum - r_sum_ckpts[i] + r_ckpts[i]``, with ``r_ckpts[i]`` and `r` passed as
    the two end momenta.

    :return: the turning flag from the last evaluated slot.
    """
    def _more_to_check(carry):
        idx, found_turning = carry
        return (idx >= idx_min) & ~found_turning

    def _check_slot(carry):
        idx, _ = carry
        ckpt_r = r_ckpts[idx]
        subtree_r_sum = r_sum - r_sum_ckpts[idx] + ckpt_r
        slot_turning = _is_turning(inverse_mass_matrix, ckpt_r, r, subtree_r_sum)
        return idx - 1, slot_turning

    _, turning = while_loop(_more_to_check, _check_slot, (idx_max, False))
    return turning
def _leaf_idx_to_ckpt_idxs(n): # computes the number of non-zero bits except the last bit # e.g. 6 -> 2, 7 -> 2, 13 -> 2 _, idx_max = while_loop(lambda nc: nc[0] > 0, lambda nc: (nc[0] >> 1, nc[1] + (nc[0] & 1)), (n >> 1, 0)) # computes the number of contiguous last non-zero bits # e.g. 6 -> 0, 7 -> 3, 13 -> 1 _, num_subtrees = while_loop(lambda nc: (nc[0] & 1) != 0, lambda nc: (nc[0] >> 1, nc[1] + 1), (n, 0)) # TODO: explore the potential of setting idx_min=0 to allow more turning checks # It will be useful in case: e.g. assume a tree 0 -> 7 is a circle, # subtrees 0 -> 3, 4 -> 7 are half-circles, which two leaves might not # satisfy turning condition; # the full tree 0 -> 7 is a circle, which two leaves might also not satisfy # turning condition; # however, we can check the turning condition of the subtree 0 -> 5, which # likely satisfies turning condition because its trajectory 3/4 of a circle. # XXX: make sure that detailed balance is satisfied if we follow this direction idx_min = idx_max - num_subtrees + 1 return idx_min, idx_max
def _iterative_build_subtree(prototype_tree, vv_update, kinetic_fn,
                             inverse_mass_matrix, step_size, going_right, rng_key,
                             energy_current, max_delta_energy, r_ckpts, r_sum_ckpts):
    """
    Builds a subtree iteratively, one base tree (leaf) at a time, starting from
    `prototype_tree` with its proposal counter reset to zero.

    The loop runs while ``num_proposals < 2 ** prototype_tree.depth`` and neither
    the turning flag nor the tree's `diverging` flag is set. On the first
    iteration the new leaf itself becomes the tree; afterwards each leaf is merged
    with `_combine_tree`. The checkpoints `r_ckpts` and `r_sum_ckpts` are written
    for even leaf indices and consulted by `_is_iterative_turning` after every leaf.

    :return: a `TreeInfo` copied from the built tree, with `depth` taken from
        `prototype_tree` and `turning` set to the loop's final turning flag.
    """
    max_num_proposals = 2 ** prototype_tree.depth

    def _should_continue(carry):
        subtree, is_turning, _, _, _ = carry
        return (subtree.num_proposals < max_num_proposals) & ~is_turning & ~subtree.diverging

    def _merge(args):
        return _combine_tree(*args, False)

    def _attach_leaf(carry):
        subtree, _, ckpts, sum_ckpts, key = carry
        key, transition_key = random.split(key)

        # If we are going to the right, start from the right leaf of the current tree.
        edge_z, edge_r, edge_z_grad = _get_leaf(subtree, going_right)
        new_leaf = _build_basetree(vv_update, kinetic_fn, edge_z, edge_r, edge_z_grad,
                                   inverse_mass_matrix, step_size, going_right,
                                   energy_current, max_delta_energy)

        # The very first leaf replaces the placeholder tree instead of being merged into it.
        merge_args = (subtree, new_leaf, inverse_mass_matrix, going_right, transition_key)
        grown = cond(subtree.num_proposals == 0,
                     new_leaf, identity,
                     merge_args, _merge)

        leaf_idx = subtree.num_proposals
        # NB: in the special case leaf_idx=0, ckpt_idx_min=1 and ckpt_idx_max=0,
        # the following logic is still valid for that case
        ckpt_idx_min, ckpt_idx_max = _leaf_idx_to_ckpt_idxs(leaf_idx)
        flat_r, unravel_fn = ravel_pytree(new_leaf.r_right)
        flat_r_sum, _ = ravel_pytree(grown.r_sum)

        def _store(pair):
            return (index_update(pair[0], ckpt_idx_max, flat_r),
                    index_update(pair[1], ckpt_idx_max, flat_r_sum))

        # we update checkpoints when leaf_idx is even
        ckpts, sum_ckpts = cond(leaf_idx % 2 == 0,
                                (ckpts, sum_ckpts), _store,
                                (ckpts, sum_ckpts), identity)

        is_turning = _is_iterative_turning(inverse_mass_matrix, new_leaf.r_right, flat_r_sum,
                                           ckpts, sum_ckpts,
                                           ckpt_idx_min, ckpt_idx_max, unravel_fn)
        return grown, is_turning, ckpts, sum_ckpts, key

    basetree = prototype_tree._replace(num_proposals=0)
    start = (basetree, False, r_ckpts, r_sum_ckpts, rng_key)
    tree, turning, _, _, _ = while_loop(_should_continue, _attach_leaf, start)

    # update depth and turning condition
    return TreeInfo(tree.z_left, tree.r_left, tree.z_left_grad,
                    tree.z_right, tree.r_right, tree.z_right_grad,
                    tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad,
                    tree.z_proposal_energy,
                    prototype_tree.depth, tree.weight, tree.r_sum, turning,
                    tree.diverging, tree.sum_accept_probs, tree.num_proposals)
def build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng_key,
               max_delta_energy=1000., max_tree_depth=10):
    """
    Builds a binary tree from the `verlet_state`. This is used in NUTS sampler.

    The tree is grown by calling `_double_tree` in a loop that stops when the depth
    reaches `max_tree_depth` or the tree is flagged as turning or diverging.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman
    2. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
       Michael Betancourt

    :param verlet_update: A callable to get a new integrator state given a current
        integrator state.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param verlet_state: Initial integrator state.
    :param inverse_mass_matrix: Inverse of the mass matrix.
    :param float step_size: Step size for the current trajectory.
    :param jax.random.PRNGKey rng_key: random key to be used as the source of
        randomness.
    :param float max_delta_energy: A threshold to decide if the new state diverges
        (based on the energy difference) too much from the initial integrator state.
    :param int max_tree_depth: Upper bound on the tree depth; also the number of
        rows allocated for the momentum checkpoints.
    :return: information of the tree.
    :rtype: :data:`TreeInfo`
    """
    z, r, potential_energy, z_grad = verlet_state
    energy_current = potential_energy + kinetic_fn(inverse_mass_matrix, r)

    # One checkpoint row per possible tree level, sized to the flattened momentum.
    latent_size = jnp.size(ravel_pytree(r)[0])
    ckpt_shape = (max_tree_depth, latent_size)
    r_ckpts = jnp.zeros(ckpt_shape)
    r_sum_ckpts = jnp.zeros(ckpt_shape)

    # Depth-0 tree: left edge, right edge and proposal all sit at the initial state.
    root = TreeInfo(z, r, z_grad, z, r, z_grad, z, potential_energy, z_grad, energy_current,
                    depth=0,
                    weight=jnp.zeros(()),
                    r_sum=r,
                    turning=jnp.array(False),
                    diverging=jnp.array(False),
                    sum_accept_probs=jnp.zeros(()),
                    num_proposals=jnp.array(0, dtype=jnp.result_type(int)))

    def _keep_doubling(carry):
        current, _ = carry
        return (current.depth < max_tree_depth) & ~current.turning & ~current.diverging

    def _double_once(carry):
        current, key = carry
        key, direction_key, doubling_key = random.split(key, 3)
        going_right = random.bernoulli(direction_key)
        doubled = _double_tree(current, verlet_update, kinetic_fn, inverse_mass_matrix,
                               step_size, going_right, doubling_key, energy_current,
                               max_delta_energy, r_ckpts, r_sum_ckpts)
        return doubled, key

    final_tree, _ = while_loop(_keep_doubling, _double_once, (root, rng_key))
    return final_tree
def _find_valid_params(rng_key, exit_early=False):
    """
    Searches for valid initial parameters by running `body_fn` under `while_loop`.

    :param rng_key: random key placed in the initial loop state.
    :param bool exit_early: if True and `rng_key` is a concrete value (not a JAX
        tracer), `body_fn` is first run once outside the loop and its result is
        returned directly when it is already valid.
    :return: tuple ``((init_params, pe, z_grad), is_valid)``.
    """
    state = (0, rng_key, (prototype_params, 0., prototype_params), False)

    if exit_early and not_jax_tracer(rng_key):
        # Early return if valid params found. This is only helpful for single chain,
        # where we can avoid compiling body_fn in while_loop.
        state = body_fn(state)
        _, _, (init_params, pe, z_grad), is_valid = state
        if not_jax_tracer(is_valid) and device_get(is_valid):
            return (init_params, pe, z_grad), is_valid

    # XXX: this requires compiling the model, so for multi-chain, we trace the model 2-times
    # even if the init_state is a valid result
    _, _, (init_params, pe, z_grad), is_valid = while_loop(cond_fn, body_fn, state)
    return (init_params, pe, z_grad), is_valid
def find_reasonable_step_size(potential_fn, kinetic_fn, momentum_generator, inverse_mass_matrix,
                              position, rng, init_step_size):
    """
    Finds a reasonable step size by repeatedly doubling or halving `init_step_size`.

    Each iteration draws a fresh momentum, takes one integrator step from
    `position` and compares the energy change with ``log(0.8)``. The search stops
    as soon as the preferred direction flips, or when the step size reaches the
    smallest/largest representable magnitude of its floating dtype while still
    moving in that direction. The latter guard keeps the loop from running forever
    when the energy change is always `NaN`.

    :param potential_fn: A callable to compute potential energy.
    :param kinetic_fn: A callable to compute kinetic energy, invoked as
        ``kinetic_fn(inverse_mass_matrix, r)``.
    :param momentum_generator: A callable invoked as
        ``momentum_generator(inverse_mass_matrix, rng_key)`` to draw a momentum.
    :param inverse_mass_matrix: Inverse of mass matrix.
    :param position: Current position of the particle.
    :param rng: Random key to be used as the source of randomness.
    :param init_step_size: Initial step size to be tuned.
    :return: the tuned step size.
    """
    # We are going to find a step_size which make accept_prob (Metropolis correction)
    # near the target accept prob. If accept_prob:=exp(-delta_energy) is small,
    # then we have to decrease step_size; otherwise, increase step_size.
    # The comparison below is done in log space, hence the log of the 0.8 target.
    log_target_accept_prob = np.log(0.8)

    _, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    z = position
    potential_energy, z_grad = value_and_grad(potential_fn)(z)
    # Promote with `float` so an integer `init_step_size` still maps to a floating dtype.
    finfo = np.finfo(np.result_type(init_step_size, float))

    def _body_fn(state):
        step_size, _, direction, rng = state
        rng, rng_momentum = random.split(rng)
        # scale step_size: increase 2x or decrease 2x depends on direction;
        # direction=1 means keep increasing step_size, otherwise decreasing step_size.
        # Note that the direction is -1 if delta_energy is `NaN`, which may be the
        # case for a diverging trajectory (e.g. in the case of evaluating log prob
        # of a value simulated using a large step size for a constrained sample site).
        step_size = (2.0 ** direction) * step_size
        r = momentum_generator(inverse_mass_matrix, rng_momentum)
        _, r_new, potential_energy_new, _ = vv_update(step_size,
                                                      inverse_mass_matrix,
                                                      (z, r, potential_energy, z_grad))
        energy_current = kinetic_fn(inverse_mass_matrix, r) + potential_energy
        energy_new = kinetic_fn(inverse_mass_matrix, r_new) + potential_energy_new
        delta_energy = energy_new - energy_current
        direction_new = np.where(log_target_accept_prob < -delta_energy, 1, -1)
        return step_size, direction, direction_new, rng

    def _cond_fn(state):
        step_size, last_direction, direction, _ = state
        # run only if step_size is not so small or we are not decreasing step_size
        not_small_step_size_cond = (step_size > finfo.tiny) | (direction >= 0)
        # run only if step_size is not so large or we are not increasing step_size
        not_large_step_size_cond = (step_size < finfo.max) | (direction <= 0)
        not_extreme_cond = not_small_step_size_cond & not_large_step_size_cond
        # last_direction == 0 marks the first iterations, before two directions can be compared
        return not_extreme_cond & ((last_direction == 0) | (direction == last_direction))

    step_size, _, _, _ = while_loop(_cond_fn, _body_fn, (init_step_size, 0, 0, rng))
    return step_size
def find_valid_initial_params(rng_key, model, *model_args, init_strategy=init_to_uniform(),
                              param_as_improper=False, prototype_params=None, **model_kwargs):
    """
    Given a model with Pyro primitives, returns an initial valid unconstrained
    parameters. This function also returns an `is_valid` flag to say whether the
    initial parameters are valid.

    A candidate is valid when its potential energy and every entry of the
    potential energy gradient are finite. At most 100 candidates are tried inside
    the search loop.

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param `*model_args`: args provided to the model.
    :param callable init_strategy: a per-site initialization function.
    :param bool param_as_improper: a flag to decide whether to consider sites with
        `param` statement as sites with improper priors.
    :param prototype_params: optional parameters used to fill the params slot of
        the initial loop state. When given, the eager first attempt outside the
        loop is skipped.
    :param `**model_kwargs`: kwargs provided to the model.
    :return: tuple of (`init_params`, `is_valid`).
    """
    init_strategy = jax.partial(init_strategy, skip_param=not param_as_improper)

    def cond_fn(state):
        attempt, _, _, valid = state
        return (attempt < 100) & (~valid)

    def body_fn(state):
        attempt, key, _, _ = state
        key, subkey = random.split(key)

        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        # Use `block` to not record sample primitives in `init_loc_fn`.
        seeded_strategy = block(seed(init_strategy, subkey))
        seeded_model = substitute(model, substitute_fn=seeded_strategy)
        model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)

        constrained_values = {}
        inv_transforms = {}
        for name, site in model_trace.items():
            if site['type'] == 'sample' and not site['is_observed']:
                if site['intermediates']:
                    # Use the first recorded intermediate together with the base distribution support.
                    constrained_values[name] = site['intermediates'][0][0]
                    inv_transforms[name] = biject_to(site['fn'].base_dist.support)
                else:
                    constrained_values[name] = site['value']
                    inv_transforms[name] = biject_to(site['fn'].support)
            elif site['type'] == 'param' and param_as_improper:
                constraint = site['kwargs'].pop('constraint', real)
                transform = biject_to(constraint)
                if isinstance(transform, ComposeTransform):
                    # Only the first part of a composed transform is kept as the site's transform.
                    base_transform = transform.parts[0]
                    inv_transforms[name] = base_transform
                    constrained_values[name] = base_transform(transform.inv(site['value']))
                else:
                    inv_transforms[name] = transform
                    constrained_values[name] = site['value']

        params = transform_fn(inv_transforms, dict(constrained_values), invert=True)
        potential_fn = jax.partial(potential_energy, model, model_args, model_kwargs,
                                   inv_transforms)
        pe, param_grads = value_and_grad(potential_fn)(params)
        z_grad = ravel_pytree(param_grads)[0]
        is_valid = np.isfinite(pe) & np.all(np.isfinite(z_grad))
        return attempt + 1, key, params, is_valid

    if prototype_params is not None:
        init_state = (0, rng_key, prototype_params, False)
    else:
        init_state = body_fn((0, rng_key, None, None))
        _, _, prototype_params, is_valid = init_state
        # Return right away when the first attempt is valid and the flag is concrete.
        if not_jax_tracer(is_valid) and device_get(is_valid):
            return prototype_params, is_valid

    _, _, init_params, is_valid = while_loop(cond_fn, body_fn, init_state)
    return init_params, is_valid