def test_subsample_gibbs_partitioning(kernel_cls, num_blocks):
    """A single HMCECS subsample-Gibbs step should resample exactly one block.

    The 100 subsample indices of plate ``'N'`` are partitioned into
    ``num_blocks`` blocks.  After one call of the subsample Gibbs function,
    exactly ``100 // num_blocks`` index positions are expected to differ
    from the starting indices ``0..99``.
    """
    # NOTE(review): this chunk defines ``test_subsample_gibbs_partitioning``
    # twice; at module level the later definition shadows this one, so pytest
    # would only collect the later version -- confirm which one is intended.

    def model(data):
        with plate('N', data.shape[0], subsample_size=100) as batch_idx:
            numpyro.sample('x', dist.Normal(0, 1), obs=data[batch_idx])

    # 10000 standard-normal draws scaled down by a factor of 100.
    obs = random.normal(random.PRNGKey(0), (10000, )) / 100
    inner = kernel_cls(model)
    kernel = HMCECS(inner, num_blocks=num_blocks)
    state = kernel.init(random.PRNGKey(1), 10, None,
                        model_args=(obs, ), model_kwargs=None)

    # Start from the first 100 indices so that any change is easy to count.
    old_sites = {'N': jnp.arange(100)}

    def potential_fn(z_gibbs, z_hmc):
        # Build the inner kernel's potential for the given Gibbs sites,
        # then evaluate it at the HMC sites.
        make_potential = kernel.inner_kernel._potential_fn_gen(
            obs, _gibbs_sites=z_gibbs)
        return make_potential(z_hmc)

    gibbs_fn = numpyro.infer.hmc_gibbs._subsample_gibbs_fn(
        potential_fn, kernel._plate_sizes, num_blocks)
    hmc_state = state.hmc_state
    new_sites, _ = gibbs_fn(random.PRNGKey(2), old_sites,
                            hmc_state.z, hmc_state.potential_energy)

    # accept_prob > .999 (original note), so the proposed block is expected
    # to be accepted and exactly one block's worth of indices to change.
    expected_changed = 100 // num_blocks
    for site, before in old_sites.items():
        n_changed = jnp.not_equal(before, new_sites[site]).sum()
        assert expected_changed == n_changed
def test_subsample_gibbs_partitioning(kernel_cls, num_blocks):
    """A single HMCECS Gibbs step should resample exactly one index block.

    The 100 subsample indices of plate ``'N'`` are partitioned into
    ``num_blocks`` blocks.  After one call of ``kernel._gibbs_fn``, exactly
    ``100 // num_blocks`` index positions are expected to differ from the
    starting indices ``0..99``.
    """
    # NOTE(review): an earlier definition with this same name appears in this
    # chunk; at module level this one shadows it -- confirm that is intended.

    def model(data):
        with plate('N', data.shape[0], subsample_size=100) as batch_idx:
            numpyro.sample('x', dist.Normal(0, 1), obs=data[batch_idx])

    # 10000 standard-normal draws scaled down by a factor of 100.
    obs = random.normal(random.PRNGKey(0), (10000,)) / 100
    inner = kernel_cls(model)
    kernel = HMCECS(inner, num_blocks=num_blocks)
    init_state = kernel.init(random.PRNGKey(1), 10, None,
                             model_args=(obs,), model_kwargs=None)

    # Start from the first 100 indices so that any change is easy to count.
    old_sites = {'N': jnp.arange(100)}
    new_sites = kernel._gibbs_fn(random.PRNGKey(2), old_sites, init_state.z)

    # accept_prob > .999 (original note), so the proposed block is expected
    # to be accepted and exactly one block's worth of indices to change.
    expected_changed = 100 // num_blocks
    for site, before in old_sites.items():
        n_changed = jnp.not_equal(before, new_sites[site]).sum()
        assert expected_changed == n_changed