Code example #1
File: test_hmc_gibbs.py (Project: mjbajwa/numpyro)
def test_hmcecs_normal_normal(kernel_cls, num_block, subsample_size):
    true_loc = jnp.array([0.3, 0.1, 0.9])
    num_warmup, num_samples = 200, 200
    data = true_loc + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        random.PRNGKey(1), (10000,))

    def model(data, subsample_size):
        mean = numpyro.sample('mean', dist.Normal().expand((3,)).to_event(1))
        with numpyro.plate('batch',
                           data.shape[0],
                           dim=-2,
                           subsample_size=subsample_size):
            sub_data = numpyro.subsample(data, 0)
            numpyro.sample("obs", dist.Normal(mean, 1), obs=sub_data)

    ref_params = {
        'mean':
        true_loc + dist.Normal(true_loc, 5e-2).sample(random.PRNGKey(0))
    }
    proxy_fn = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(kernel_cls(model), proxy=proxy_fn)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), data, subsample_size)

    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['mean'], axis=0), true_loc, atol=0.1)
    assert len(samples['mean']) == num_samples
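These snippets are collected from NumPyro's test suite and examples and omit their import blocks. A preamble along the following lines should make them runnable against a recent NumPyro release (all names below are public APIs of NumPyro, JAX, or the standard libraries):

import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
from numpy.testing import assert_allclose

import jax.numpy as jnp
from jax import random, vmap

import numpyro
import numpyro.distributions as dist
from numpyro import plate
from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SA, SVI, Trace_ELBO, autoguide
from numpyro.infer.autoguide import AutoBNAFNormal
from numpyro.infer.initialization import init_to_value
from numpyro.infer.reparam import NeuTraReparam
from numpyro.infer.util import log_density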
Code example #2
def test_subsample_gibbs_partitioning(kernel_cls, num_blocks):
    def model(obs):
        with plate('N', obs.shape[0], subsample_size=100) as idx:
            numpyro.sample('x', dist.Normal(0, 1), obs=obs[idx])

    obs = random.normal(random.PRNGKey(0), (10000,)) / 100
    kernel = HMCECS(kernel_cls(model), num_blocks=num_blocks)
    state = kernel.init(random.PRNGKey(1),
                        10,
                        None,
                        model_args=(obs, ),
                        model_kwargs=None)
    gibbs_sites = {'N': jnp.arange(100)}

    def potential_fn(z_gibbs, z_hmc):
        return kernel.inner_kernel._potential_fn_gen(
            obs, _gibbs_sites=z_gibbs)(z_hmc)

    gibbs_fn = numpyro.infer.hmc_gibbs._subsample_gibbs_fn(
        potential_fn, kernel._plate_sizes, num_blocks)
    new_gibbs_sites, _ = gibbs_fn(
        random.PRNGKey(2), gibbs_sites, state.hmc_state.z,
        state.hmc_state.potential_energy)  # accept_prob > .999
    block_size = 100 // num_blocks
    for name in gibbs_sites:
        assert block_size == jnp.not_equal(gibbs_sites[name],
                                           new_gibbs_sites[name]).sum()
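The final assertion encodes HMCECS's block-Gibbs update of the subsample: the current index set is partitioned into num_blocks blocks and exactly one block is refreshed per Gibbs step, so 100 // num_blocks entries change. A minimal sketch of that idea, using a hypothetical helper rather than NumPyro's internal update function:

import jax.numpy as jnp
from jax import lax, random

def block_update_sketch(rng_key, idx, data_size, num_blocks):
    # idx: current subsample indices, shape (subsample_size,).
    subsample_size = idx.shape[0]
    block_size = subsample_size // num_blocks
    key_block, key_new = random.split(rng_key)
    # Choose one block uniformly at random ...
    block = random.randint(key_block, (), 0, num_blocks)
    # ... and refresh only that block with fresh draws over the full dataset.
    new_idx = random.randint(key_new, (block_size,), 0, data_size)
    return lax.dynamic_update_slice(idx, new_idx, (block * block_size,))

Refreshing only part of the subsample keeps consecutive potential-energy estimates correlated, which is what keeps the pseudo-marginal acceptance probability high (the accept_prob > .999 noted in the comment above).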
Code example #3
def test_estimate_likelihood(kernel_cls):
    data_key, tr_key, sub_key, rng_key = random.split(random.PRNGKey(0), 4)
    ref_params = jnp.array([0.1, 0.5, -0.2])
    sigma = 0.1
    data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        data_key, (10_000,)
    )
    n, _ = data.shape
    num_warmup = 200
    num_samples = 200
    num_blocks = 20

    def model(data):
        mean = numpyro.sample(
            "mean", dist.Normal(ref_params, jnp.ones_like(ref_params))
        )
        with numpyro.plate("N", data.shape[0], subsample_size=100, dim=-2) as idx:
            numpyro.sample("obs", dist.Normal(mean, sigma), obs=data[idx])

    proxy_fn = HMCECS.taylor_proxy({"mean": ref_params})
    kernel = HMCECS(kernel_cls(model), proxy=proxy_fn, num_blocks=num_blocks)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)

    mcmc.run(random.PRNGKey(0), data, extra_fields=["hmc_state.potential_energy"])

    pes = mcmc.get_extra_fields()["hmc_state.potential_energy"]
    samples = mcmc.get_samples()
    pes_full = vmap(
        lambda sample: log_density(
            model, (data,), {}, {**sample, **{"N": jnp.arange(n)}}
        )[0]
    )(samples)

    assert jnp.var(jnp.exp(-pes - pes_full)) < 1.0
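To read the final assertion: pes is the subsample-estimated potential energy recorded by HMCECS (the negative of the estimated log joint density), while pes_full is the exact log density with the plate forced to the full index set, so

$$\exp(-\mathrm{pes} - \mathrm{pes\_full}) = \hat p(\theta, \mathrm{data}) \,/\, p(\theta, \mathrm{data}),$$

and the test asserts that this density-ratio estimator has variance below 1, i.e. that the Taylor proxy tracks the full-data likelihood closely along the chain.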
Code example #4
File: hmcecs.py (Project: pyro-ppl/numpyro)
def run_hmcecs(hmcecs_key, args, data, obs, inner_kernel):
    svi_key, mcmc_key = random.split(hmcecs_key)

    # Find reference parameters for the second-order Taylor expansion used to
    # estimate the likelihood (taylor_proxy).
    optimizer = numpyro.optim.Adam(step_size=1e-3)
    guide = autoguide.AutoDelta(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    svi_result = svi.run(svi_key, args.num_svi_steps, data, obs,
                         args.subsample_size)
    params, losses = svi_result.params, svi_result.losses
    ref_params = {"theta": params["theta_auto_loc"]}

    # The Taylor proxy estimates the log likelihood (ll) as
    #   taylor_expansion(ll, theta_curr)
    #     + sum_{i in subsample} [ll_i(theta_curr) - taylor_expansion(ll_i, theta_curr)],
    # with all Taylor expansions taken around ref_params.
    proxy = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(inner_kernel, num_blocks=args.num_blocks, proxy=proxy)
    mcmc = MCMC(kernel,
                num_warmup=args.num_warmup,
                num_samples=args.num_samples)

    mcmc.run(mcmc_key, data, obs, args.subsample_size)
    mcmc.print_summary()
    return losses, mcmc.get_samples()
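Spelled out, the comment above is the control-variate (difference) estimator used by HMCECS: with ll_i the log likelihood of datum i, q_i its second-order Taylor expansion around ref_params, N the dataset size, and S the subsample of size m,

$$\hat\ell(\theta) = \sum_{i=1}^{N} q_i(\theta) + \frac{N}{m} \sum_{i \in S} \bigl( \ell_i(\theta) - q_i(\theta) \bigr).$$

The first sum has a closed form in terms of sums of gradients and Hessians precomputed at the reference point, which is why a good reference (here, a MAP estimate obtained via AutoDelta SVI) keeps the residual term, and hence the estimator's variance, small.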
Code example #5
File: test_hmc_gibbs.py (Project: mjbajwa/numpyro)
def test_enum_subsample_smoke():
    def model(data):
        x = numpyro.sample("x", dist.Bernoulli(0.5))
        with numpyro.plate("N", data.shape[0], subsample_size=100, dim=-1):
            batch = numpyro.subsample(data, event_dim=0)
            numpyro.sample("obs", dist.Normal(x, 1), obs=batch)

    data = random.normal(random.PRNGKey(0), (10000,)) + 1
    kernel = HMCECS(NUTS(model), num_blocks=10)
    mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
    mcmc.run(random.PRNGKey(0), data)
Code example #6
def test_subsample_gibbs_partitioning(kernel_cls, num_blocks):
    def model(obs):
        with plate('N', obs.shape[0], subsample_size=100) as idx:
            numpyro.sample('x', dist.Normal(0, 1), obs=obs[idx])

    obs = random.normal(random.PRNGKey(0), (10000,)) / 100
    kernel = HMCECS(kernel_cls(model), num_blocks=num_blocks)
    hmc_state = kernel.init(random.PRNGKey(1), 10, None, model_args=(obs,), model_kwargs=None)
    gibbs_sites = {'N': jnp.arange(100)}

    gibbs_fn = kernel._gibbs_fn
    new_gibbs_sites = gibbs_fn(random.PRNGKey(2), gibbs_sites, hmc_state.z)  # accept_prob > .999
    block_size = 100 // num_blocks
    for name in gibbs_sites:
        assert block_size == jnp.not_equal(gibbs_sites[name], new_gibbs_sites[name]).sum()
Code example #7
def test_hmcecs_multiple_plates():
    true_loc = jnp.array([0.3, 0.1, 0.9])
    num_warmup, num_samples = 2, 2
    data = true_loc + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        random.PRNGKey(1), (1000,)
    )

    def model(data):
        mean = numpyro.sample("mean", dist.Normal().expand((3,)).to_event(1))
        with numpyro.plate("batch", data.shape[0], dim=-2, subsample_size=10):
            sub_data = numpyro.subsample(data, 0)
            with numpyro.plate("dim", 3):
                numpyro.sample("obs", dist.Normal(mean, 1), obs=sub_data)

    ref_params = {
        "mean": true_loc + dist.Normal(true_loc, 5e-2).sample(random.PRNGKey(0))
    }
    proxy_fn = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(NUTS(model), proxy=proxy_fn)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), data)
Code example #8
def benchmark_hmc(args, features, labels):
    rng_key = random.PRNGKey(1)
    start = time.time()
    # A MAP estimate taken from the following source:
    # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
    ref_params = {
        "coefs":
        jnp.array([
            +2.03420663e00,
            -3.53567265e-02,
            -1.49223924e-01,
            -3.07049364e-01,
            -1.00028366e-01,
            -1.46827862e-01,
            -1.64167881e-01,
            -4.20344204e-01,
            +9.47479829e-02,
            -1.12681836e-02,
            +2.64442056e-01,
            -1.22087866e-01,
            -6.00568838e-02,
            -3.79419506e-01,
            -1.06668741e-01,
            -2.97053963e-01,
            -2.05253899e-01,
            -4.69537191e-02,
            -2.78072730e-02,
            -1.43250525e-01,
            -6.77954629e-02,
            -4.34899796e-03,
            +5.90927452e-02,
            +7.23133609e-02,
            +1.38526391e-02,
            -1.24497898e-01,
            -1.50733739e-02,
            -2.68872194e-02,
            -1.80925727e-02,
            +3.47936489e-02,
            +4.03552800e-02,
            -9.98773426e-03,
            +6.20188080e-02,
            +1.15002751e-01,
            +1.32145107e-01,
            +2.69109547e-01,
            +2.45785132e-01,
            +1.19035013e-01,
            -2.59744357e-02,
            +9.94279515e-04,
            +3.39266285e-02,
            -1.44057125e-02,
            -6.95222765e-02,
            -7.52013028e-02,
            +1.21171586e-01,
            +2.29205526e-02,
            +1.47308692e-01,
            -8.34354162e-02,
            -9.34122875e-02,
            -2.97472421e-02,
            -3.03937674e-01,
            -1.70958012e-01,
            -1.59496680e-01,
            -1.88516974e-01,
            -1.20889175e00,
        ])
    }
    if args.algo == "HMC":
        step_size = jnp.sqrt(0.5 / features.shape[0])
        trajectory_length = step_size * args.num_steps
        kernel = HMC(
            model,
            step_size=step_size,
            trajectory_length=trajectory_length,
            adapt_step_size=False,
            dense_mass=args.dense_mass,
        )
        subsample_size = None
    elif args.algo == "NUTS":
        kernel = NUTS(model, dense_mass=args.dense_mass)
        subsample_size = None
    elif args.algo == "HMCECS":
        subsample_size = 1000
        inner_kernel = NUTS(
            model,
            init_strategy=init_to_value(values=ref_params),
            dense_mass=args.dense_mass,
        )
        # note: with num_blocks=100 and subsample_size=1000, we update 10
        # indices at each MCMC step, so it takes roughly 50,000 MCMC steps
        # to sweep the whole dataset
        kernel = HMCECS(inner_kernel,
                        num_blocks=100,
                        proxy=HMCECS.taylor_proxy(ref_params))
    elif args.algo == "SA":
        # NB: this kernel requires large num_warmup and num_samples;
        # running on GPU is much faster than on CPU
        kernel = SA(model,
                    adapt_state_size=1000,
                    init_strategy=init_to_value(values=ref_params))
        subsample_size = None
    elif args.algo == "FlowHMCECS":
        subsample_size = 1000
        guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
        svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
        svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
        params, losses = svi_result.params, svi_result.losses
        plt.plot(losses)
        plt.show()

        neutra = NeuTraReparam(guide, params)
        neutra_model = neutra.reparam(model)
        neutra_ref_params = {"auto_shared_latent": jnp.zeros(55)}
        # no need to adapt mass matrix if the flow does a good job
        inner_kernel = NUTS(
            neutra_model,
            init_strategy=init_to_value(values=neutra_ref_params),
            adapt_mass_matrix=False,
        )
        kernel = HMCECS(inner_kernel,
                        num_blocks=100,
                        proxy=HMCECS.taylor_proxy(neutra_ref_params))
    else:
        raise ValueError(
            "Invalid algorithm; expected one of 'HMC', 'NUTS', 'HMCECS', "
            "'SA', or 'FlowHMCECS'.")
    mcmc = MCMC(kernel,
                num_warmup=args.num_warmup,
                num_samples=args.num_samples)
    mcmc.run(rng_key,
             features,
             labels,
             subsample_size,
             extra_fields=("accept_prob", ))
    print("Mean accept prob:",
          jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
    mcmc.print_summary(exclude_deterministic=False)
    print("\nMCMC elapsed time:", time.time() - start)
Code example #9
def test_pickle_hmcecs():
    mcmc = MCMC(HMCECS(NUTS(logistic_regression)), num_warmup=10, num_samples=10)
    mcmc.run(random.PRNGKey(0))
    pickled_mcmc = pickle.loads(pickle.dumps(mcmc))
    test_util.check_close(mcmc.get_samples(), pickled_mcmc.get_samples())