Example #1
import os

from numpyro.infer import HMC, MCMC, NUTS, SA


def run_inference(model, at_bats, hits, rng_key, args):
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    elif args.algo == "SA":
        kernel = SA(model)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples,
                num_chains=args.num_chains,
                progress_bar=not ("NUMPYRO_SPHINXBUILD" in os.environ
                                  or args.disable_progbar))
    mcmc.run(rng_key, at_bats, hits)
    return mcmc.get_samples()
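A minimal sketch of how this helper might be driven. The toy batting-average model, the data, and the argparse-style Namespace below are illustrative assumptions, not part of the original example.

# Illustrative usage; the model, data, and args here are assumptions.
from argparse import Namespace

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist


def model(at_bats, hits):
    # assumed toy model: one pooled success probability for all players
    phi = numpyro.sample("phi", dist.Uniform(0.0, 1.0))
    numpyro.sample("obs", dist.Binomial(at_bats, probs=phi), obs=hits)


args = Namespace(algo="NUTS", num_warmup=500, num_samples=1000,
                 num_chains=1, disable_progbar=True)
at_bats = jnp.array([45, 50, 55])
hits = jnp.array([18, 20, 16])
samples = run_inference(model, at_bats, hits, random.PRNGKey(0), args)
print({name: value.shape for name, value in samples.items()})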
Example #2
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    true_mean, true_std = 1., 0.5
    warmup_steps, num_samples = (100000, 100000) if kernel_cls is SA else (1000, 8000)

    def potential_fn(z):
        return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)

    init_params = jnp.array(0.)
    if kernel_cls is SA:
        kernel = SA(potential_fn=potential_fn, dense_mass=dense_mass)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(potential_fn=potential_fn, dense_mass=dense_mass)
    else:
        kernel = kernel_cls(potential_fn=potential_fn, trajectory_length=8, dense_mass=dense_mass)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    mcmc.print_summary()
    hmc_states = mcmc.get_samples()
    assert_allclose(jnp.mean(hmc_states), true_mean, rtol=0.07)
    assert_allclose(jnp.std(hmc_states), true_std, rtol=0.07)

    if 'JAX_ENABLE_X64' in os.environ:
        assert hmc_states.dtype == jnp.float64
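The potential function above is, up to an additive constant, the negative log-density of Normal(true_mean, true_std), which is why the sampled mean and standard deviation are compared against those values. A quick numerical check of that identity, as a sketch using the same constants (not part of the original test):

import jax.numpy as jnp
from jax.scipy.stats import norm

true_mean, true_std = 1.0, 0.5
z = jnp.linspace(-2.0, 4.0, 7)
potential = 0.5 * ((z - true_mean) / true_std) ** 2
log_pdf = norm.logpdf(z, loc=true_mean, scale=true_std)
# potential + log_pdf equals -log(true_std * sqrt(2 * pi)) for every z
assert jnp.allclose(potential + log_pdf,
                    -jnp.log(true_std * jnp.sqrt(2 * jnp.pi)))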
Example #3
def test_logistic_regression_x64(kernel_cls):
    N, dim = 3000, 3
    if kernel_cls is SA:
        num_warmup, num_samples = (100000, 100000)
    elif kernel_cls is BarkerMH:
        num_warmup, num_samples = (2000, 12000)
    else:
        num_warmup, num_samples = (1000, 8000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample("coefs",
                               dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic("logits", jnp.sum(coefs * data,
                                                         axis=-1))
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    if kernel_cls is SA:
        kernel = SA(model=model, adapt_state_size=9)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model)
    else:
        kernel = kernel_cls(model=model,
                            trajectory_length=8,
                            find_heuristic_step_size=True)
    mcmc = MCMC(kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples["logits"].shape == (num_samples, N)
    # these coefficients come from MAP inference with AutoDelta
    # (see the sketch after this example)
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(samples["coefs"], 0), expected_coefs, atol=0.1)

    if "JAX_ENABLE_X64" in os.environ:
        assert samples["coefs"].dtype == jnp.float64
Example #4
from typing import Callable, Dict

import jax.numpy as jnp
from numpyro.infer import HMC, MCMC, NUTS, SA


def run_inference(
    model: Callable,
    at_bats: jnp.ndarray,
    hits: jnp.ndarray,
    rng_key: jnp.ndarray,
    *,
    num_warmup: int = 1500,
    num_samples: int = 3000,
    num_chains: int = 1,
    algo_name: str = "NUTS",
) -> Dict[str, jnp.ndarray]:

    if algo_name == "NUTS":
        kernel = NUTS(model)
    elif algo_name == "HMC":
        kernel = HMC(model)
    elif algo_name == "SA":
        kernel = SA(model)
    else:
        raise ValueError("Unknown algorithm name")

    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples,
                num_chains=num_chains)
    mcmc.run(rng_key, at_bats, hits)
    return mcmc.get_samples()
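An illustrative call, reusing the toy model and data from the sketch after Example #1 (assumptions, not part of the original):

samples = run_inference(model, at_bats, hits, random.PRNGKey(0),
                        num_warmup=1000, num_samples=2000, algo_name="NUTS")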
Example #5
def test_beta_bernoulli_x64(kernel_cls):
    warmup_steps, num_samples = (100000, 100000) if kernel_cls is SA else (500, 20000)

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    if kernel_cls is SA:
        kernel = SA(model=model)
    else:
        kernel = kernel_cls(model=model, trajectory_length=0.1)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples['p_latent'], 0), true_probs, atol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == jnp.float64
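Because the Beta prior is conjugate to the Bernoulli likelihood, the exact posterior mean is available in closed form, which makes the atol=0.05 check above reasonable. A quick closed-form check, as a sketch mirroring the test's setup:

import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist

alpha = beta = jnp.array([1.1, 1.1])
true_probs = jnp.array([0.9, 0.1])
data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
successes = data.sum(axis=0)
# the posterior is Beta(alpha + successes, beta + n - successes)
posterior_mean = (alpha + successes) / (alpha + beta + data.shape[0])
print(posterior_mean)  # close to true_probs and to the MCMC estimate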
Example #6
def benchmark_hmc(args, features, labels):
    rng_key = random.PRNGKey(1)
    start = time.time()
    # ref_params is a MAP estimate taken from the following source:
    # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
    ref_params = {
        "coefs":
        jnp.array([
            +2.03420663e00,
            -3.53567265e-02,
            -1.49223924e-01,
            -3.07049364e-01,
            -1.00028366e-01,
            -1.46827862e-01,
            -1.64167881e-01,
            -4.20344204e-01,
            +9.47479829e-02,
            -1.12681836e-02,
            +2.64442056e-01,
            -1.22087866e-01,
            -6.00568838e-02,
            -3.79419506e-01,
            -1.06668741e-01,
            -2.97053963e-01,
            -2.05253899e-01,
            -4.69537191e-02,
            -2.78072730e-02,
            -1.43250525e-01,
            -6.77954629e-02,
            -4.34899796e-03,
            +5.90927452e-02,
            +7.23133609e-02,
            +1.38526391e-02,
            -1.24497898e-01,
            -1.50733739e-02,
            -2.68872194e-02,
            -1.80925727e-02,
            +3.47936489e-02,
            +4.03552800e-02,
            -9.98773426e-03,
            +6.20188080e-02,
            +1.15002751e-01,
            +1.32145107e-01,
            +2.69109547e-01,
            +2.45785132e-01,
            +1.19035013e-01,
            -2.59744357e-02,
            +9.94279515e-04,
            +3.39266285e-02,
            -1.44057125e-02,
            -6.95222765e-02,
            -7.52013028e-02,
            +1.21171586e-01,
            +2.29205526e-02,
            +1.47308692e-01,
            -8.34354162e-02,
            -9.34122875e-02,
            -2.97472421e-02,
            -3.03937674e-01,
            -1.70958012e-01,
            -1.59496680e-01,
            -1.88516974e-01,
            -1.20889175e00,
        ])
    }
    if args.algo == "HMC":
        step_size = jnp.sqrt(0.5 / features.shape[0])
        trajectory_length = step_size * args.num_steps
        kernel = HMC(
            model,
            step_size=step_size,
            trajectory_length=trajectory_length,
            adapt_step_size=False,
            dense_mass=args.dense_mass,
        )
        subsample_size = None
    elif args.algo == "NUTS":
        kernel = NUTS(model, dense_mass=args.dense_mass)
        subsample_size = None
    elif args.algo == "HMCECS":
        subsample_size = 1000
        inner_kernel = NUTS(
            model,
            init_strategy=init_to_value(values=ref_params),
            dense_mass=args.dense_mass,
        )
        # note: with num_blocks=100 and subsample_size=1000, we update 10 indices
        # at each MCMC step, so it takes 50000 MCMC steps to iterate over the
        # whole dataset
        kernel = HMCECS(inner_kernel,
                        num_blocks=100,
                        proxy=HMCECS.taylor_proxy(ref_params))
    elif args.algo == "SA":
        # NB: this kernel requires a large num_warmup and num_samples,
        # and it runs much faster on GPU than on CPU
        kernel = SA(model,
                    adapt_state_size=1000,
                    init_strategy=init_to_value(values=ref_params))
        subsample_size = None
    elif args.algo == "FlowHMCECS":
        subsample_size = 1000
        guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
        svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
        svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
        params, losses = svi_result.params, svi_result.losses
        plt.plot(losses)
        plt.show()

        neutra = NeuTraReparam(guide, params)
        neutra_model = neutra.reparam(model)
        neutra_ref_params = {"auto_shared_latent": jnp.zeros(55)}
        # no need to adapt mass matrix if the flow does a good job
        inner_kernel = NUTS(
            neutra_model,
            init_strategy=init_to_value(values=neutra_ref_params),
            adapt_mass_matrix=False,
        )
        kernel = HMCECS(inner_kernel,
                        num_blocks=100,
                        proxy=HMCECS.taylor_proxy(neutra_ref_params))
    else:
        raise ValueError(
            "Invalid algorithm; use 'HMC', 'NUTS', 'HMCECS', 'SA', or 'FlowHMCECS'.")
    mcmc = MCMC(kernel,
                num_warmup=args.num_warmup,
                num_samples=args.num_samples)
    mcmc.run(rng_key,
             features,
             labels,
             subsample_size,
             extra_fields=("accept_prob", ))
    print("Mean accept prob:",
          jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
    mcmc.print_summary(exclude_deterministic=False)
    print("\nMCMC elapsed time:", time.time() - start)