Exemplo n.º 1
0
def test_subsample_gibbs_partitioning(kernel_cls, num_blocks):
    """One subsample-Gibbs step must refresh exactly one block of plate indices.

    A model with a single subsampled plate ``'N'`` is wrapped in ``HMCECS``
    with ``num_blocks`` blocks. Starting from the subsample
    ``arange(subsample_size)``, a single call to ``_subsample_gibbs_fn``
    is expected to change exactly ``subsample_size // num_blocks`` of the
    indices. This relies on the proposal being accepted; the setup is chosen
    so that ``accept_prob > .999``.
    """
    subsample_size = 100

    def model(obs):
        with plate('N', obs.shape[0], subsample_size=subsample_size) as idx:
            numpyro.sample('x', dist.Normal(0, 1), obs=obs[idx])

    # Small-scale data keeps the potential nearly flat across subsamples.
    data = random.normal(random.PRNGKey(0), (10000,)) / 100
    kernel = HMCECS(kernel_cls(model), num_blocks=num_blocks)
    state = kernel.init(random.PRNGKey(1), 10, None,
                        model_args=(data,), model_kwargs=None)
    old_sites = {'N': jnp.arange(subsample_size)}

    def potential_fn(z_gibbs, z_hmc):
        # Build the inner kernel's potential conditioned on the given subsample.
        potential = kernel.inner_kernel._potential_fn_gen(data, _gibbs_sites=z_gibbs)
        return potential(z_hmc)

    gibbs_step = numpyro.infer.hmc_gibbs._subsample_gibbs_fn(
        potential_fn, kernel._plate_sizes, num_blocks)
    inner = state.hmc_state
    # accept_prob > .999, so the proposed block is effectively always accepted.
    new_sites, _ = gibbs_step(random.PRNGKey(2), old_sites, inner.z,
                              inner.potential_energy)

    expected_changes = subsample_size // num_blocks
    for site, old_idx in old_sites.items():
        num_changed = jnp.not_equal(old_idx, new_sites[site]).sum()
        assert num_changed == expected_changes
Exemplo n.º 2
0
def test_subsample_gibbs_partitioning(kernel_cls, num_blocks):
    """One call to ``HMCECS._gibbs_fn`` must refresh exactly one index block.

    The model has a single subsampled plate ``'N'``. Starting from the
    subsample ``arange(subsample_size)``, one Gibbs update through the
    kernel's ``_gibbs_fn`` is expected to change exactly
    ``subsample_size // num_blocks`` indices. This relies on the proposal
    being accepted; the setup is chosen so that ``accept_prob > .999``.
    """
    subsample_size = 100

    def model(obs):
        with plate('N', obs.shape[0], subsample_size=subsample_size) as idx:
            numpyro.sample('x', dist.Normal(0, 1), obs=obs[idx])

    # Small-scale data keeps the potential nearly flat across subsamples.
    data = random.normal(random.PRNGKey(0), (10000,)) / 100
    kernel = HMCECS(kernel_cls(model), num_blocks=num_blocks)
    init_state = kernel.init(random.PRNGKey(1), 10, None,
                             model_args=(data,), model_kwargs=None)
    old_sites = {'N': jnp.arange(subsample_size)}

    # _gibbs_fn is read only after init(), matching the original ordering.
    gibbs_step = kernel._gibbs_fn
    # accept_prob > .999
    new_sites = gibbs_step(random.PRNGKey(2), old_sites, init_state.z)

    expected_changes = subsample_size // num_blocks
    for site, old_idx in old_sites.items():
        num_changed = jnp.not_equal(old_idx, new_sites[site]).sum()
        assert num_changed == expected_changes