def test_ravel_pytree(pytree):
    """Round-trip `pytree` through ravel_pytree and check values and dtypes survive."""
    flat, unravel_fn = ravel_pytree(pytree)
    roundtripped = unravel_fn(flat)

    # leaf values must match; assert_allclose raises on any mismatch, and
    # tree_flatten forces evaluation over every leaf pair
    tree_flatten(tree_multimap(lambda x, y: assert_allclose(x, y), roundtripped, pytree))

    # leaf dtypes must also match after canonicalization
    same_dtype = lambda x, y: canonicalize_dtype(lax.dtype(x)) == canonicalize_dtype(lax.dtype(y))
    dtype_flags, _ = tree_flatten(tree_multimap(same_dtype, roundtripped, pytree))
    assert all(dtype_flags)
def gibbs_fn(rng_key, gibbs_sites, hmc_sites, pe):
    """Run one random-scan sweep of discrete Gibbs/Metropolis updates.

    :param rng_key: JAX PRNG key.
    :param gibbs_sites: dict of current discrete latent values.
    :param hmc_sites: dict of current continuous (HMC) latent values, held fixed.
    :param pe: potential energy at the current state.
    :return: tuple ``(gibbs_sites, pe)`` of updated discrete values and energy.
    """
    # get support_sizes of gibbs_sites
    support_sizes_flat, _ = ravel_pytree({k: support_sizes[k] for k in gibbs_sites})
    num_discretes = support_sizes_flat.shape[0]

    rng_key, rng_permute = random.split(rng_key)
    # BUG FIX: permute with the dedicated subkey `rng_permute`. The original
    # reused `rng_key`, which is also threaded through the loop below, so the
    # permutation was correlated with the per-site proposal/accept draws.
    idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

    def body_fn(i, val):
        # update the idxs[i]-th flat discrete coordinate
        idx = idxs[i]
        support_size = support_sizes_flat[idx]
        rng_key, z, pe = val
        rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
            rng_key, z, pe,
            potential_fn=partial(potential_fn, z_hmc=hmc_sites),
            idx=idx, support_size=support_size)
        rng_key, rng_accept = random.split(rng_key)
        # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
        # and -log(u) ~ exponential(1)
        z, pe = cond(random.exponential(rng_accept) > -log_accept_ratio,
                     (z_new, pe_new), identity,
                     (z, pe), identity)
        return rng_key, z, pe

    init_val = (rng_key, gibbs_sites, pe)
    _, gibbs_sites, pe = fori_loop(0, num_discretes, body_fn, init_val)
    return gibbs_sites, pe
def proxy_fn(params, subsample_lik_sites, gibbs_state):
    """Second-order Taylor proxies of the log-likelihoods around the reference point.

    Returns ``(proxy_sum, proxy_subsample)`` dicts keyed by subsample site name.
    """
    params_flat, _ = ravel_pytree(params)
    delta = params_flat - ref_params_flat

    def _second_order(value, grad, hess):
        # value + g . delta + 0.5 * delta^T H delta
        return value + jnp.dot(grad, delta) + 0.5 * jnp.dot(jnp.dot(hess, delta), delta)

    proxy_sum = defaultdict(float)
    proxy_subsample = defaultdict(float)
    for name in subsample_lik_sites:
        # expansion cached on the current subsample
        proxy_subsample[name] = _second_order(
            gibbs_state.ref_subsample_log_liks[name],
            gibbs_state.ref_subsample_log_lik_grads[name],
            gibbs_state.ref_subsample_log_lik_hessians[name],
        )
        # expansion of the full-data sums
        proxy_sum[name] = _second_order(
            ref_log_likelihoods_sum[name],
            ref_log_likelihood_grads_sum[name],
            ref_log_likelihood_hessians_sum[name],
        )
    return proxy_sum, proxy_subsample
def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
    """Gibbs proposal for one discrete coordinate; always accepted (log ratio 0).

    idx: current index of `z_discrete_flat` to update.
    support_size: support size of z_discrete at the index idx.
    """
    flat, unravel_fn = ravel_pytree(z_discrete)
    # Here we loop over the support of z_flat[idx] to get z_new
    # XXX: we can't vmap potential_fn over all proposals and sample from the conditional
    # categorical distribution because support_size is a traced value, i.e. its value
    # might change across different discrete variables;
    # so here we will loop over all proposals and use an online scheme to sample from
    # the conditional categorical distribution
    step = partial(_discrete_gibbs_proposal_body_fn, flat, unravel_fn, pe, potential_fn, idx)
    carry = (rng_key, z_discrete, pe, jnp.array(0.0))
    rng_key, z_new, pe_new, _ = fori_loop(0, support_size - 1, step, carry)
    # exact conditional sampling => Metropolis correction is always zero
    return rng_key, z_new, pe_new, jnp.array(0.0)
def _discrete_modified_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx,
                                      support_size, stay_prob=0.0):
    """Modified Gibbs proposal that forbids re-selecting the current value,
    optionally staying put with probability ``stay_prob``; returns the MH
    log-acceptance ratio needed to correct for the modification.

    :param rng_key: JAX PRNG key.
    :param z_discrete: pytree of current discrete values.
    :param pe: potential energy at the current state.
    :param potential_fn: callable mapping a discrete pytree to potential energy.
    :param idx: flat index of the coordinate being updated.
    :param support_size: support size of the coordinate at ``idx``.
    :param stay_prob: probability of keeping the current state unchanged.
    :return: ``(rng_key, z_new, pe_new, log_accept_ratio)``.
    """
    assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
    z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
    body_fn = partial(
        _discrete_gibbs_proposal_body_fn,
        z_discrete_flat,
        unravel_fn,
        pe,
        potential_fn,
        idx,
    )
    # like gibbs_step but here, weight of the current value is 0
    # (initial accumulator -inf <=> zero weight in log space)
    init_val = (rng_key, z_discrete, pe, jnp.array(-jnp.inf))
    rng_key, z_new, pe_new, log_weight_sum = fori_loop(0, support_size - 1,
                                                       body_fn, init_val)
    rng_key, rng_stay = random.split(rng_key)
    # with probability `stay_prob`, discard the proposal and keep (z, pe)
    z_new, pe_new = cond(
        random.bernoulli(rng_stay, stay_prob),
        (z_discrete, pe),
        identity,
        (z_new, pe_new),
        identity,
    )
    # here we calculate the MH correction: (1 - P(z)) / (1 - P(z_new))
    # where 1 - P(z) ~ weight_sum
    # and 1 - P(z_new) ~ 1 + weight_sum - z_new_weight
    log_accept_ratio = log_weight_sum - jnp.log(
        jnp.exp(log_weight_sum) - jnp.expm1(pe - pe_new))
    return rng_key, z_new, pe_new, log_accept_ratio
def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
    """Initialize the MixedHMC sampler state.

    Delegates to the parent kernel's ``init``, then caches flat discrete
    support sizes, sets up the warmup adapter, and pre-draws a momentum so
    it can be carried across sub-trajectories.

    :param rng_key: JAX PRNG key (split for momentum generation).
    :param num_warmup: number of warmup steps.
    :param init_params: initial parameter values forwarded to the parent kernel.
    :param model_args: positional model arguments.
    :param model_kwargs: keyword model arguments.
    :return: an initial ``MixedHMCState``.
    """
    rng_key, rng_r = random.split(rng_key)
    state = super().init(rng_key, num_warmup, init_params, model_args, model_kwargs)
    # flatten the support sizes of the discrete (gibbs) sites once, up front
    self._support_sizes_flat, _ = ravel_pytree(
        {k: self._support_sizes[k] for k in self._gibbs_sites})
    if self._num_discrete_updates is None:
        # default: one update per discrete coordinate per MixedHMC step
        self._num_discrete_updates = self._support_sizes_flat.shape[0]
    self._num_warmup = num_warmup
    # NB: the warmup adaptation can not be performed in sub-trajectories (i.e. the hmc trajectory
    # between two discrete updates), so we will do it here, at the end of each MixedHMC step.
    _, self._wa_update = warmup_adapter(
        num_warmup,
        adapt_step_size=self.inner_kernel._adapt_step_size,
        adapt_mass_matrix=self.inner_kernel._adapt_mass_matrix,
        dense_mass=self.inner_kernel._dense_mass,
        target_accept_prob=self.inner_kernel._target_accept_prob,
        find_reasonable_step_size=None)
    # In HMC, when `hmc_state.r` is not None, we will skip drawing a random momentum at the
    # beginning of an HMC step. The reason is we need to maintain `r` between sub-trajectories.
    r = momentum_generator(state.hmc_state.z,
                           state.hmc_state.adapt_state.mass_matrix_sqrt, rng_r)
    return MixedHMCState(state.z, state.hmc_state._replace(r=r), state.rng_key,
                         jnp.zeros(()))
def _discrete_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
    """Uniform random-walk proposal for one discrete coordinate.

    Draws a value uniformly over the whole support (current value included);
    the proposal is symmetric, so the MH log-ratio is just the energy drop.
    """
    rng_key, rng_proposal = random.split(rng_key, 2)
    flat, unravel_fn = ravel_pytree(z_discrete)

    draw = random.randint(rng_proposal, (), minval=0, maxval=support_size)
    new_flat = ops.index_update(flat, idx, draw)
    z_new = unravel_fn(new_flat)
    pe_new = potential_fn(z_new)
    return rng_key, z_new, pe_new, pe - pe_new
def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    """Random-scan discrete Gibbs sweep conditioned on the HMC sites.

    Builds an (optionally enumerated) potential over the discrete sites with
    the continuous sites fixed at their unconstrained values, then updates
    each flat discrete coordinate once in random order.

    :param rng_key: JAX PRNG key.
    :param gibbs_sites: dict of current discrete latent values.
    :param hmc_sites: dict of current continuous latent values (constrained space).
    :return: dict of updated discrete latent values.
    """
    # convert to unconstrained values
    z_hmc = {
        k: biject_to(prototype_trace[k]["fn"].support).inv(v)
        for k, v in hmc_sites.items()
        if k in prototype_trace and prototype_trace[k]["type"] == "sample"
    }
    # enumerate only when some discrete sites are not being Gibbs-sampled
    use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
    wrapped_model = _wrap_model(model)
    if use_enum:
        from numpyro.contrib.funsor import config_enumerate, enum
        wrapped_model = enum(config_enumerate(wrapped_model), -max_plate_nesting - 1)

    def potential_fn(z_discrete):
        model_kwargs_ = model_kwargs.copy()
        model_kwargs_["_gibbs_sites"] = z_discrete
        return potential_energy(wrapped_model, model_args, model_kwargs_, z_hmc,
                                enum=use_enum)

    # get support_sizes of gibbs_sites
    support_sizes_flat, _ = ravel_pytree(
        {k: support_sizes[k] for k in gibbs_sites})
    num_discretes = support_sizes_flat.shape[0]

    rng_key, rng_permute = random.split(rng_key)
    # BUG FIX: permute with the freshly-split subkey; the original passed
    # `rng_key`, which is also threaded through the loop below (key reuse).
    idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

    def body_fn(i, val):
        idx = idxs[i]
        support_size = support_sizes_flat[idx]
        rng_key, z, pe = val
        rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
            rng_key, z, pe, potential_fn=potential_fn,
            idx=idx, support_size=support_size)
        rng_key, rng_accept = random.split(rng_key)
        # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
        # and -log(u) ~ exponential(1)
        z, pe = cond(
            random.exponential(rng_accept) > -log_accept_ratio,
            (z_new, pe_new), identity,
            (z, pe), identity)
        return rng_key, z, pe

    init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
    _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
    return gibbs_sites
def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    """Random-scan discrete Gibbs sweep using the pre-wrapped model.

    :param rng_key: JAX PRNG key.
    :param gibbs_sites: dict of current discrete latent values.
    :param hmc_sites: dict of current continuous latent values (used as-is).
    :return: dict of updated discrete latent values.
    """
    z_hmc = hmc_sites
    # enumerate only when some discrete sites are not being Gibbs-sampled
    use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
    if use_enum:
        from numpyro.contrib.funsor import config_enumerate, enum
        wrapped_model_ = enum(config_enumerate(wrapped_model), -max_plate_nesting - 1)
    else:
        wrapped_model_ = wrapped_model

    def potential_fn(z_discrete):
        model_kwargs_ = model_kwargs.copy()
        model_kwargs_["_gibbs_sites"] = z_discrete
        return potential_energy(wrapped_model_, model_args, model_kwargs_, z_hmc,
                                enum=use_enum)

    # get support_sizes of gibbs_sites
    support_sizes_flat, _ = ravel_pytree(
        {k: support_sizes[k] for k in gibbs_sites})
    num_discretes = support_sizes_flat.shape[0]

    rng_key, rng_permute = random.split(rng_key)
    # BUG FIX: permute with the freshly-split subkey; the original passed
    # `rng_key`, which is also threaded through the loop below (key reuse).
    idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

    def body_fn(i, val):
        idx = idxs[i]
        support_size = support_sizes_flat[idx]
        rng_key, z, pe = val
        rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
            rng_key, z, pe, potential_fn=potential_fn,
            idx=idx, support_size=support_size)
        rng_key, rng_accept = random.split(rng_key)
        # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
        # and -log(u) ~ exponential(1)
        z, pe = cond(
            random.exponential(rng_accept) > -log_accept_ratio,
            (z_new, pe_new), identity,
            (z, pe), identity)
        return rng_key, z, pe

    init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
    _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
    return gibbs_sites
def _discrete_modified_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx,
                                   support_size, stay_prob=0.):
    """Random-walk proposal that excludes the current value, with optional stay.

    Draws uniformly from the support minus the current value; with probability
    ``stay_prob`` the current value is kept instead. The proposal is symmetric,
    so the MH log-ratio is the energy difference.

    :return: ``(rng_key, z_new, pe_new, log_accept_ratio)``.
    """
    assert isinstance(stay_prob, float) and stay_prob >= 0. and stay_prob < 1
    rng_key, rng_proposal, rng_stay = random.split(rng_key, 3)
    z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)

    # sample from {0, ..., support_size - 1} \ {current value}: draw over a
    # range one smaller and shift draws past the current value up by one
    i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
    proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
    # BUG FIX: staying must keep the current *value* of the coordinate; the
    # original wrote `idx` (the flat position) into the state instead of
    # `z_discrete_flat[idx]`, corrupting z whenever the stay branch fired.
    proposal = jnp.where(random.bernoulli(rng_stay, stay_prob),
                         z_discrete_flat[idx], proposal)
    z_new_flat = ops.index_update(z_discrete_flat, idx, proposal)
    z_new = unravel_fn(z_new_flat)
    pe_new = potential_fn(z_new)
    log_accept_ratio = pe - pe_new
    return rng_key, z_new, pe_new, log_accept_ratio
def single_particle_grad(particle, att_forces, rep_forces):
    """Map attractive + repulsive forces on one particle back through the
    particle transforms (chain rule with the inverse-transform Jacobian) and
    return the result as a flat gradient vector.

    NOTE(review): assumes `att_forces` and `rep_forces` unravel to the same
    pytree structure as `particle` — confirm at the call site.
    """
    # Jacobian of each transform's inverse, evaluated at the particle's leaves.
    # The lambda closes over `k`, but jax.tree_map invokes it eagerly within
    # this comprehension iteration, so the binding is correct here.
    reparam_jac = {
        k: jax.tree_map(
            lambda variable: jax.jacfwd(self.particle_transforms[k].inv
                                        )(variable),
            variables,
        )
        for k, variables in unravel_pytree(particle).items()
    }
    # Contract the summed forces with the Jacobian: the Jacobian of an
    # array-valued function has shape (out_shape + in_shape), so the first
    # half of rjac's axes is collapsed to match the flattened forces, and the
    # product is reshaped back to the parameter's shape.
    jac_params = jax.tree_multimap(
        lambda af, rf, rjac: ((af.reshape(-1) + rf.reshape(-1)) @ rjac.reshape(
            (_numel(rjac.shape[:len(rjac.shape) // 2]), -1))).reshape(
                rf.shape),
        unravel_pytree(att_forces),
        unravel_pytree(rep_forces),
        reparam_jac,
    )
    # flatten back to a single vector aligned with `particle`
    jac_particle, _ = ravel_pytree(jac_params)
    return jac_particle
def log_likelihood(params, subsample_indices=None):
    """Trace the model at `params` and accumulate, per subsample plate, the
    log-likelihood of observed sites (summed over all non-plate dimensions).

    :param params: pytree of unconstrained parameter values.
    :param subsample_indices: dict mapping plate name to index array; defaults
        to every index of each subsample plate.
    :return: dict mapping plate name to its accumulated log-likelihood array.
    """
    params_flat, unravel_fn = ravel_pytree(params)
    if subsample_indices is None:
        subsample_indices = {
            k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()
        }
    params = unravel_fn(params_flat)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with block(), trace(
        ) as tr, substitute(data=subsample_indices), substitute(
                substitute_fn=partial(_unconstrain_reparam, params)):
            model(*model_args, **model_kwargs)

    log_lik = defaultdict(float)
    for site in tr.values():
        # only observed sample sites contribute
        if site["type"] != "sample" or not site["is_observed"]:
            continue
        for frame in site["cond_indep_stack"]:
            if frame.name not in subsample_plate_sizes:
                continue
            log_lik[frame.name] += _sum_all_except_at_dim(
                site["fn"].log_prob(site["value"]), frame.dim)
    return log_lik
def construct_proxy_fn(prototype_trace, subsample_plate_sizes, model, model_args,
                       model_kwargs, num_blocks=1):
    """Build the Taylor-proxy machinery for subsample-based Gibbs/HMC.

    Precomputes full-data log-likelihood sums, gradients, and Hessians at a
    reference point, then returns three closures:
    ``(proxy_fn, gibbs_init, gibbs_update)``.

    :param prototype_trace: model trace used to look up site supports.
    :param subsample_plate_sizes: dict mapping plate name to (size, subsample_size).
    :param model: the model callable.
    :param model_args: positional model arguments.
    :param model_kwargs: keyword model arguments.
    :param num_blocks: number of blocks for block-wise subsample updates.
    """
    # reference point mapped to unconstrained space
    # NOTE(review): `reference_params` is a closure variable not visible in
    # this chunk — presumably user-supplied constrained values; confirm.
    ref_params = {name: biject_to(prototype_trace[name]["fn"].support).inv(value)
                  for name, value in reference_params.items()}
    ref_params_flat, unravel_fn = ravel_pytree(ref_params)

    def log_likelihood(params_flat, subsample_indices=None):
        # Per-plate observed log-likelihoods at a flat unconstrained point.
        if subsample_indices is None:
            subsample_indices = {k: jnp.arange(v[0])
                                 for k, v in subsample_plate_sizes.items()}
        params = unravel_fn(params_flat)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # map back to constrained space before substituting into the model
            params = {name: biject_to(prototype_trace[name]["fn"].support)(value)
                      for name, value in params.items()}
            with block(), trace() as tr, substitute(
                    data=subsample_indices), substitute(data=params):
                model(*model_args, **model_kwargs)
        log_lik = {}
        for site in tr.values():
            if site["type"] == "sample" and site["is_observed"]:
                for frame in site["cond_indep_stack"]:
                    # NOTE(review): frames are not filtered by
                    # `subsample_plate_sizes` here, unlike the sibling
                    # log_likelihood helper — confirm every cond_indep frame
                    # of an observed site is a subsample plate.
                    if frame.name in log_lik:
                        log_lik[frame.name] += _sum_all_except_at_dim(
                            site["fn"].log_prob(site["value"]), frame.dim)
                    else:
                        log_lik[frame.name] = _sum_all_except_at_dim(
                            site["fn"].log_prob(site["value"]), frame.dim)
        return log_lik

    def log_likelihood_sum(params_flat, subsample_indices=None):
        # scalar (per plate) total log-likelihood; differentiation target below
        return {k: v.sum()
                for k, v in log_likelihood(params_flat, subsample_indices).items()}

    # those stats are dict keyed by subsample names
    ref_log_likelihoods_sum = log_likelihood_sum(ref_params_flat)
    ref_log_likelihood_grads_sum = jacobian(log_likelihood_sum)(ref_params_flat)
    ref_log_likelihood_hessians_sum = hessian(log_likelihood_sum)(ref_params_flat)

    def gibbs_init(rng_key, gibbs_sites):
        # cache reference stats evaluated on the initial subsample indices
        ref_subsample_log_liks = log_likelihood(ref_params_flat, gibbs_sites)
        ref_subsample_log_lik_grads = jacfwd(log_likelihood)(
            ref_params_flat, gibbs_sites)
        ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(
            ref_params_flat, gibbs_sites)
        return TaylorProxyState(ref_subsample_log_liks,
                                ref_subsample_log_lik_grads,
                                ref_subsample_log_lik_hessians)

    def gibbs_update(rng_key, gibbs_sites, gibbs_state):
        # resample a block of subsample indices, then splice the freshly
        # computed block stats into the cached per-subsample stats
        u_new, pads, new_idxs, starts = _block_update_proxy(
            num_blocks, rng_key, gibbs_sites, subsample_plate_sizes)
        new_states = defaultdict(dict)
        ref_subsample_log_liks = log_likelihood(ref_params_flat, new_idxs)
        ref_subsample_log_lik_grads = jacfwd(log_likelihood)(
            ref_params_flat, new_idxs)
        ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(
            ref_params_flat, new_idxs)
        for stat, new_block_values, last_values in zip(
                ["log_liks", "grads", "hessians"],
                [ref_subsample_log_liks,
                 ref_subsample_log_lik_grads,
                 ref_subsample_log_lik_hessians],
                [gibbs_state.ref_subsample_log_liks,
                 gibbs_state.ref_subsample_log_lik_grads,
                 gibbs_state.ref_subsample_log_lik_hessians]):
            for name, subsample_idx in gibbs_sites.items():
                size, subsample_size = subsample_plate_sizes[name]
                pad, start = pads[name], starts[name]
                # pad along the subsample axis so the dynamic slice fits
                new_value = jnp.pad(
                    last_values[name],
                    [(0, pad)] + [(0, 0)] * (jnp.ndim(last_values[name]) - 1))
                new_value = lax.dynamic_update_slice_in_dim(
                    new_value, new_block_values[name], start, 0)
                new_states[stat][name] = new_value[:subsample_size]
        gibbs_state = TaylorProxyState(new_states["log_liks"],
                                       new_states["grads"],
                                       new_states["hessians"])
        return u_new, gibbs_state

    def proxy_fn(params, subsample_lik_sites, gibbs_state):
        # second-order Taylor expansion around the reference parameters
        params_flat, _ = ravel_pytree(params)
        params_diff = params_flat - ref_params_flat
        ref_subsample_log_liks = gibbs_state.ref_subsample_log_liks
        ref_subsample_log_lik_grads = gibbs_state.ref_subsample_log_lik_grads
        ref_subsample_log_lik_hessians = gibbs_state.ref_subsample_log_lik_hessians
        proxy_sum = defaultdict(float)
        proxy_subsample = defaultdict(float)
        for name in subsample_lik_sites:
            proxy_subsample[name] = (
                ref_subsample_log_liks[name]
                + jnp.dot(ref_subsample_log_lik_grads[name], params_diff)
                + 0.5 * jnp.dot(
                    jnp.dot(ref_subsample_log_lik_hessians[name], params_diff),
                    params_diff))
            proxy_sum[name] = (
                ref_log_likelihoods_sum[name]
                + jnp.dot(ref_log_likelihood_grads_sum[name], params_diff)
                + 0.5 * jnp.dot(
                    jnp.dot(ref_log_likelihood_hessians_sum[name], params_diff),
                    params_diff))
        return proxy_sum, proxy_subsample

    return proxy_fn, gibbs_init, gibbs_update
def particle_transform_fn(particle):
    """Unravel a flat particle, apply the configured transforms, and re-flatten."""
    unpacked = unravel_pytree(particle)
    transformed = self.particle_transform_fn(unpacked)
    flat_transformed, _ = ravel_pytree(transformed)
    return flat_transformed