Beispiel #1
0
def test_gaussian_hmm():
    """Smoke-test NUTS on a Gaussian HMM with discrete hidden states.

    Data are simulated from a random ``dim``-state HMM with unit-scale Gaussian
    emissions. The test only checks that ``mcmc.run`` completes; it makes no
    assertions about the posterior.
    """
    dim = 4
    num_steps = 10

    def model(data):
        with numpyro.plate("states", dim):
            transition = numpyro.sample("transition", dist.Dirichlet(jnp.ones(dim)))
            emission_loc = numpyro.sample("emission_loc", dist.Normal(0, 1))
            emission_scale = numpyro.sample("emission_scale", dist.LogNormal(0, 1))

        trans_prob = numpyro.sample("initialize", dist.Dirichlet(jnp.ones(dim)))
        for t, y in markov(enumerate(data)):
            x = numpyro.sample("x_{}".format(t), dist.Categorical(trans_prob))
            numpyro.sample("y_{}".format(t), dist.Normal(emission_loc[x], emission_scale[x]), obs=y)
            # The next hidden state is drawn from the transition row of the current one.
            trans_prob = transition[x]

    def _generate_data():
        """Simulate ``num_steps`` observations from a random ``dim``-state HMM."""
        transition_probs = np.random.rand(dim, dim)
        transition_probs = transition_probs / transition_probs.sum(-1, keepdims=True)
        emissions_loc = np.arange(dim)
        emissions_scale = 1.
        # The initial state is drawn uniformly from all ``dim`` states.
        state = np.random.choice(dim)
        obs = [np.random.normal(emissions_loc[state], emissions_scale)]
        for _ in range(num_steps - 1):
            state = np.random.choice(dim, p=transition_probs[state])
            obs.append(np.random.normal(emissions_loc[state], emissions_scale))
        return np.stack(obs)

    data = _generate_data()
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(0), data)
Beispiel #2
0
def test_change_point_x64():
    """Check that NUTS recovers the change point of a Poisson count series.

    The switch point is a continuous fraction ``tau`` of the series length. The
    posterior mode of ``int(tau * len(count_data))`` must be 44. When 64-bit
    mode is enabled through the ``JAX_ENABLE_X64`` environment variable, the
    samples must also be float64.
    """
    # Ref: https://forum.pyro.ai/t/i-dont-understand-why-nuts-code-is-not-working-bayesian-hackers-mail/696
    warmup_steps, num_samples = 500, 3000

    def model(data):
        alpha = 1 / np.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        # Indices before tau * len(data) use lambda1, the remaining ones lambda2.
        lambda12 = np.where(np.arange(len(data)) < tau * len(data), lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = np.array([
        13,  24,   8,  24,   7,  35,  14,  11,  15,  11,  22,  22,  11,  57,
        11,  19,  29,   6,  19,  12,  22,  12,  18,  72,  32,   9,   7,  13,
        19,  23,  27,  20,   6,  17,  13,  10,  14,   6,  16,  15,   7,   2,
        15,  15,  19,  70,  49,   7,  53,  22,  21,  31,  19,  11,  18,  20,
        12,  35,  17,  23,  17,   4,   2,  31,  30,  13,  27,   0,  39,  37,
        5,  14,  13,  22,
    ])
    kernel = NUTS(model=model)
    # Keyword arguments: recent numpyro makes num_warmup/num_samples keyword-only.
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(4), count_data)
    samples = mcmc.get_samples()
    tau_posterior = (samples['tau'] * len(count_data)).astype(np.int32)
    tau_values, counts = onp.unique(tau_posterior, return_counts=True)
    mode_ind = np.argmax(counts)
    mode = tau_values[mode_ind]
    assert mode == 44

    # JAX reads the upper-case JAX_ENABLE_X64 variable and treats these values
    # as true; only check dtypes when 64-bit mode was actually requested.
    x64_flag = os.environ.get('JAX_ENABLE_X64', '').lower()
    if x64_flag in ('1', 'true', 'yes', 'y', 't', 'on'):
        assert samples['lambda1'].dtype == np.float64
        assert samples['lambda2'].dtype == np.float64
        assert samples['tau'].dtype == np.float64
Beispiel #3
0
def test_change_point():
    """Fit a Poisson change-point model with a discrete switch index ``tau``.

    Asserts that the posterior means of the two Poisson rates are close to 18
    and 23 respectively.
    """
    def model(count_data):
        num_obs = count_data.shape[0]
        rate_prior = 1 / jnp.mean(count_data)
        lambda_1 = numpyro.sample('lambda_1', dist.Exponential(rate_prior))
        lambda_2 = numpyro.sample('lambda_2', dist.Exponential(rate_prior))
        # Uniform categorical over 70 values, i.e. DiscreteUniform(0, 69).
        tau = numpyro.sample('tau', dist.Categorical(logits=jnp.zeros(70)))
        # Indices strictly below tau use lambda_1; the rest use lambda_2.
        rate = jnp.where(tau > jnp.arange(num_obs), lambda_1, lambda_2)
        with numpyro.plate("data", num_obs):
            numpyro.sample('obs', dist.Poisson(rate), obs=count_data)

    count_data = jnp.array([
        13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57, 11,
        19, 29, 6, 19, 12, 22, 12, 18, 72, 32, 9, 7, 13, 19, 23,
        27, 20, 6, 17, 13, 10, 14, 6, 16, 15, 7, 2, 15, 15, 19,
        70, 49, 7, 53, 22, 21, 31, 19, 11, 1, 20, 12, 35, 17, 23,
        17, 4, 2, 31, 30, 13, 27, 0, 39, 37, 5, 14, 13, 22,
    ])

    mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(0), count_data)
    posterior = mcmc.get_samples()
    assert_allclose(posterior["lambda_1"].mean(0), 18., atol=1.)
    assert_allclose(posterior["lambda_2"].mean(0), 23., atol=1.)
Beispiel #4
0
 def test_spire_model(self):
     """Run a short NUTS chain on ``SPIRE.spire_model`` and check the ``src_f`` shape."""
     mcmc = MCMC(NUTS(SPIRE.spire_model), num_samples=100, num_warmup=100)
     mcmc.run(random.PRNGKey(0), self.priors)
     samples = mcmc.get_samples()
     self.assertIsNotNone(samples)
     # The second axis of src_f must match ``nsrc`` of the first prior.
     self.assertEqual(samples['src_f'].shape[1], self.priors[0].nsrc)
Beispiel #5
0
def run_inference(model, args, rng_key, X, Y, D_H):
    """Run NUTS on ``model`` and return the posterior samples.

    Prints the MCMC summary and the elapsed time of the run.

    Args:
        model: numpyro model; ``X``, ``Y`` and ``D_H`` are forwarded to it.
        args: namespace providing ``num_warmup``, ``num_samples`` and ``num_chains``.
        rng_key: JAX PRNG key for the sampler.
        X, Y, D_H: positional model arguments.

    Returns:
        The result of ``mcmc.get_samples()``.
    """
    # perf_counter is monotonic, so the timing is unaffected by wall-clock jumps.
    start = time.perf_counter()
    kernel = NUTS(model)
    # Keyword arguments: recent numpyro makes num_warmup/num_samples keyword-only.
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples,
                num_chains=args.num_chains)
    mcmc.run(rng_key, X, Y, D_H)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.perf_counter() - start)
    return mcmc.get_samples()
Beispiel #6
0
def run_inference(model, at_bats, hits, rng_key, args):
    """Run NUTS on ``model`` and return the posterior samples.

    Args:
        model: numpyro model; ``at_bats`` and ``hits`` are forwarded to it.
        at_bats, hits: positional model arguments.
        rng_key: JAX PRNG key for the sampler.
        args: namespace providing ``num_warmup``, ``num_samples`` and ``num_chains``.

    Returns:
        The result of ``mcmc.get_samples()``.
    """
    kernel = NUTS(model)
    # Keyword arguments: recent numpyro makes num_warmup/num_samples keyword-only.
    mcmc = MCMC(kernel,
                num_warmup=args.num_warmup,
                num_samples=args.num_samples,
                num_chains=args.num_chains)
    mcmc.run(rng_key, at_bats, hits)
    return mcmc.get_samples()
Beispiel #7
0
def test_compile_warmup_run(num_chains, chain_method, progress_bar):
    """Check that compile, warmup and sampling as separate steps reproduce ``run``.

    The reference samples come from a single ``mcmc.run(rng_key)`` call. They are
    compared with a sequence of ``_compile``, ``warmup`` and ``run`` calls. For
    ``num_chains > 1``, it also checks that the leading samples match a
    single-chain run seeded with the first key of ``random.split(rng_key)``.
    """
    def model():
        numpyro.sample("x", dist.Normal(0, 1))

    # These parametrizations are skipped as duplicates of other cases.
    if num_chains == 1 and chain_method in ["sequential", "vectorized"]:
        pytest.skip("duplicated test")
    if num_chains > 1 and chain_method == "parallel":
        pytest.skip("duplicated test")

    rng_key = random.PRNGKey(0)
    num_samples = 10
    mcmc = MCMC(
        NUTS(model),
        num_warmup=10,
        num_samples=num_samples,
        num_chains=num_chains,
        chain_method=chain_method,
        progress_bar=progress_bar,
    )

    # Reference samples from a single run call.
    mcmc.run(rng_key)
    expected_samples = mcmc.get_samples()["x"]

    # Same work split into steps: compile (private API), warmup, then sample
    # starting from the rng key stored in the state left by warmup.
    mcmc._compile(rng_key)
    # no delay after compiling
    mcmc.warmup(rng_key)
    mcmc.run(mcmc.last_state.rng_key)
    actual_samples = mcmc.get_samples()["x"]

    assert_allclose(actual_samples, expected_samples)

    # test for reproducibility: compare against a single-chain run seeded with
    # the first split key.
    if num_chains > 1:
        mcmc = MCMC(
            NUTS(model),
            num_warmup=10,
            num_samples=num_samples,
            num_chains=1,
            progress_bar=progress_bar,
        )
        rng_key = random.split(rng_key)[0]
        mcmc.run(rng_key)
        first_chain_samples = mcmc.get_samples()["x"]
        # The first num_samples entries of the flattened multi-chain samples
        # (presumably the first chain) should match the single-chain run.
        assert_allclose(actual_samples[:num_samples],
                        first_chain_samples,
                        atol=1e-5)
Beispiel #8
0
def test_empty_model(num_chains, chain_method, progress_bar):
    """A model without any sample sites must produce an empty samples dict."""
    def model():
        """Model with no sites at all."""

    sampler = NUTS(model)
    mcmc = MCMC(
        sampler,
        num_warmup=10,
        num_samples=10,
        num_chains=num_chains,
        chain_method=chain_method,
        progress_bar=progress_bar,
    )
    mcmc.run(random.PRNGKey(0))
    assert mcmc.get_samples() == {}
Beispiel #9
0
def test_forward_mode_differentiation():
    """NUTS with ``forward_mode_differentiation=True`` handles a ``lax.while_loop``.

    The loop is not reverse-mode differentiable in JAX, so this model only runs
    with forward-mode gradients.
    """
    def model():
        x = numpyro.sample("x", dist.Normal(0, 1))
        # Repeatedly add 1 while the value is below 10.
        y = lax.while_loop(lambda v: v < 10, lambda v: v + 1, x)
        numpyro.sample("obs", dist.Normal(y, 1), obs=1.)

    # this fails in reverse mode
    # Keyword arguments: recent numpyro makes num_warmup/num_samples keyword-only.
    mcmc = MCMC(NUTS(model, forward_mode_differentiation=True),
                num_warmup=10, num_samples=10)
    mcmc.run(random.PRNGKey(0))
Beispiel #10
0
def test_loose_warning_for_missing_plate():
    """An observed site with a batch dim not covered by any plate raises a warning."""
    def model():
        loc = numpyro.sample("x", dist.Normal(0, 1))
        # obs has shape (5, 10) but only the size-10 dim is inside a plate.
        with numpyro.plate("N", 10):
            numpyro.sample("obs", dist.Normal(loc, 1), obs=jnp.ones((5, 10)))

    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
    with pytest.warns(UserWarning, match="Missing a plate statement"):
        mcmc.run(random.PRNGKey(1))
Beispiel #11
0
def run_inference(model, args, rng_key, X, Y):
    """Run NUTS on ``model`` and return the posterior samples.

    Prints the MCMC summary and the elapsed time of the run.

    Args:
        model: numpyro model; ``X`` and ``Y`` are forwarded to it.
        args: namespace providing ``num_warmup``, ``num_samples`` and ``num_chains``.
        rng_key: JAX PRNG key for the sampler.
        X, Y: positional model arguments.

    Returns:
        The result of ``mcmc.get_samples()``.
    """
    # perf_counter is monotonic, so the timing is unaffected by wall-clock jumps.
    start = time.perf_counter()
    kernel = NUTS(model)
    # Keyword arguments: recent numpyro makes num_warmup/num_samples keyword-only.
    # The progress bar is disabled when NUMPYRO_SPHINXBUILD is set
    # (presumably during documentation builds).
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples,
                num_chains=args.num_chains,
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.perf_counter() - start)
    return mcmc.get_samples()
Beispiel #12
0
def test_discrete_gibbs_bernoulli(random_walk, modified):
    """DiscreteHMCGibbs recovers the mean of a single Bernoulli(0.8) site."""
    def model():
        numpyro.sample("c", dist.Bernoulli(0.8))

    kernel = DiscreteHMCGibbs(NUTS(model), random_walk=random_walk, modified=modified)
    # Keyword arguments: recent numpyro makes num_warmup/num_samples keyword-only.
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=200000, progress_bar=False)
    mcmc.run(random.PRNGKey(0))
    samples = mcmc.get_samples()["c"]
    assert_allclose(jnp.mean(samples), 0.8, atol=0.05)
Beispiel #13
0
def test_improper_uniform():
    """Smoke-test DiscreteHMCGibbs on a discrete site plus an ImproperUniform site."""
    def model():
        numpyro.sample("c", dist.Bernoulli(0.8))
        unit_interval = dist.constraints.unit_interval
        numpyro.sample("u", dist.ImproperUniform(unit_interval, (), ()))

    mcmc = MCMC(DiscreteHMCGibbs(NUTS(model)), num_warmup=10, num_samples=10,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0))
Beispiel #14
0
def test_model_with_lift_handler():
    """Run NUTS on a model whose ``param`` site ``c`` is lifted to a Gamma prior."""
    def model(data):
        c = numpyro.param("c", jnp.array(1.), constraint=dist.constraints.positive)
        return numpyro.sample("x", dist.LogNormal(c, 1.), obs=data)

    lifted_model = numpyro.handlers.lift(model, prior={"c": dist.Gamma(0.01, 0.01)})
    mcmc = MCMC(NUTS(lifted_model), num_warmup=10, num_samples=10)
    # Observations are exponentials of 1000 standard-normal draws.
    observations = jnp.exp(random.normal(random.PRNGKey(0), (1000,)))
    mcmc.run(random.PRNGKey(1), observations)
Beispiel #15
0
def test_sites_have_unique_names():
    """A deterministic site that reuses a sample site's name must be rejected."""
    def model():
        alpha = numpyro.sample("alpha", dist.Normal())
        numpyro.deterministic("alpha", alpha * 2)

    expected_msg = "all sites must have unique names but got `alpha` duplicated"
    mcmc = MCMC(NUTS(model), num_chains=1, num_samples=10, num_warmup=10)
    with pytest.raises(AssertionError, match=expected_msg):
        mcmc.run(random.PRNGKey(0))
Beispiel #16
0
def test_random_module_mcmc(backend, init):
    """Bayesian logistic regression with a random flax/haiku linear module under NUTS.

    Labels are simulated from coefficients ``1..dim``. The posterior mean of the
    module's weight must be close to them.

    Args:
        backend: ``"flax"`` or ``"haiku"``.
        init: ``"shape"`` to initialise the module from an input shape, or
            ``"kwargs"`` to initialise it from example inputs.

    Raises:
        ValueError: if ``backend`` or ``init`` is not one of the values above.
    """
    if backend == "flax":
        import flax

        linear_module = flax.linen.Dense(features=1)
        bias_name = "bias"
        weight_name = "kernel"
        random_module = random_flax_module
        kwargs_name = "inputs"
    elif backend == "haiku":
        import haiku as hk

        linear_module = hk.transform(lambda x: hk.Linear(1)(x))
        bias_name = "linear.b"
        weight_name = "linear.w"
        random_module = random_haiku_module
        kwargs_name = "x"
    else:
        # Fail clearly here instead of with a NameError further down.
        raise ValueError("unknown backend: {}".format(backend))

    N, dim = 3000, 3
    num_warmup, num_samples = (1000, 1000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1.0, dim + 1.0)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    if init == "shape":
        # Keep the declared input shape in sync with the data's feature dim.
        kwargs = {"input_shape": (dim,)}
    elif init == "kwargs":
        kwargs = {kwargs_name: data}
    else:
        raise ValueError("unknown init: {}".format(init))

    def model(data, labels):
        nn = random_module(
            "nn",
            linear_module,
            {bias_name: dist.Cauchy(), weight_name: dist.Normal()},
            **kwargs
        )
        logits = nn(data).squeeze(-1)
        numpyro.sample("y", dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model=model)
    mcmc = MCMC(
        kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False
    )
    mcmc.run(random.PRNGKey(2), data, labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert set(samples.keys()) == {
        "nn/{}".format(bias_name),
        "nn/{}".format(weight_name),
    }
    assert_allclose(
        np.mean(samples["nn/{}".format(weight_name)].squeeze(-1), 0),
        true_coefs,
        atol=0.22,
    )
Beispiel #17
0
def test_init_strategy_substituted_model():
    """Running NUTS with ``x`` substituted to 10.0 warns that init is skipped."""
    def model():
        for site_name in ("x", "y"):
            numpyro.sample(site_name, dist.Normal(0, 1))

    substituted = numpyro.handlers.substitute(model, data={"x": 10.0})
    kernel = NUTS(substituted)
    mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
    with pytest.warns(UserWarning, match="skipping initialization"):
        mcmc.run(random.PRNGKey(1))
Beispiel #18
0
def benchmark_hmc(args, features, labels):
    """Time a NUTS run with no warmup of the module-level ``model``.

    Args:
        args: namespace providing ``num_steps`` and ``num_samples``.
        features: forwarded to ``model``; ``features.shape[0]`` scales the step size.
        labels: forwarded to ``model``.
    """
    # NOTE(review): step_size is only used to derive trajectory_length and is
    # never passed to NUTS — confirm that is intended.
    step_size = np.sqrt(0.5 / features.shape[0])
    trajectory_length = step_size * args.num_steps
    rng_key = random.PRNGKey(1)
    # perf_counter is monotonic, so the timing is unaffected by wall-clock jumps.
    start = time.perf_counter()
    kernel = NUTS(model, trajectory_length=trajectory_length)
    # Keyword arguments: recent numpyro makes num_warmup/num_samples keyword-only.
    mcmc = MCMC(kernel, num_warmup=0, num_samples=args.num_samples)
    mcmc.run(rng_key, features, labels)
    print('\nMCMC elapsed time:', time.perf_counter() - start)
Beispiel #19
0
 def run_inference(self, args, rng_key, x_train, y_train, num_hidden):
     """Run NUTS on ``self.bnn_model`` and return the posterior samples.

     Args:
         args: mapping with ``'num_warmup'``, ``'num_samples'`` and ``'num_chains'``.
         rng_key: JAX PRNG key; split into one key per chain when more than one
             chain is requested.
         x_train, y_train, num_hidden: forwarded to ``self.bnn_model``.

     Returns:
         The result of ``mcmc.get_samples()``.
     """
     num_chains = args['num_chains']
     if num_chains > 1:
         rng_key = random.split(rng_key, num_chains)
     kernel = NUTS(self.bnn_model)
     # Keyword arguments: recent numpyro makes num_warmup/num_samples keyword-only.
     mcmc = MCMC(kernel,
                 num_warmup=args['num_warmup'],
                 num_samples=args['num_samples'],
                 num_chains=num_chains)
     mcmc.run(rng_key, x_train, y_train, num_hidden)
     return mcmc.get_samples()
Beispiel #20
0
def run_inference(dept, male, applications, admit, rng_key, args):
    """Run NUTS on the module-level ``glmm`` model and return the posterior samples.

    Args:
        dept, male, applications, admit: forwarded to ``glmm``.
        rng_key: JAX PRNG key for the sampler.
        args: namespace providing ``num_warmup``, ``num_samples`` and ``num_chains``.

    Returns:
        The result of ``mcmc.get_samples()``.
    """
    kernel = NUTS(glmm)
    # Every option is passed by keyword: recent numpyro makes them keyword-only,
    # and a bare positional num_chains is easy to misread.
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        # Disabled when NUMPYRO_SPHINXBUILD is set (presumably docs builds).
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, dept, male, applications, admit)
    return mcmc.get_samples()
Beispiel #21
0
def run_inference(model, capture_history, sex, rng_key, args):
    """Run NUTS or HMC on ``model``, print a summary and return the posterior samples.

    Args:
        model: numpyro model; ``capture_history`` and ``sex`` are forwarded to it.
        capture_history, sex: positional model arguments.
        rng_key: JAX PRNG key for the sampler.
        args: namespace providing ``algo`` (``"NUTS"`` or ``"HMC"``),
            ``num_warmup``, ``num_samples`` and ``num_chains``.

    Returns:
        The result of ``mcmc.get_samples()``.

    Raises:
        ValueError: if ``args.algo`` is neither ``"NUTS"`` nor ``"HMC"``.
    """
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    else:
        # Otherwise ``kernel`` would be unbound below (UnboundLocalError).
        raise ValueError("Unsupported algo: {}".format(args.algo))
    # Keyword arguments: recent numpyro makes num_warmup/num_samples keyword-only.
    # The progress bar is disabled when NUMPYRO_SPHINXBUILD is set.
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples,
                num_chains=args.num_chains,
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, capture_history, sex)
    mcmc.print_summary()
    return mcmc.get_samples()
Beispiel #22
0
def test_trivial_dirichlet(batch_shape):
    """A Dirichlet over a single category is degenerate, so every draw must be 1."""
    def model():
        x = numpyro.sample("x", dist.Dirichlet(jnp.ones(1)).expand(batch_shape))
        return numpyro.sample("y", dist.Normal(x, 1), obs=2)

    num_samples = 10
    # Keyword arguments: recent numpyro makes num_warmup/num_samples keyword-only.
    mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0))
    # because event_shape of x is (1,), x should only take value 1
    assert_allclose(mcmc.get_samples()["x"], jnp.ones((num_samples,) + batch_shape + (1,)))
Beispiel #23
0
 def __init__(self, model, data=None, num_warmup=1000, num_samples=2000, num_chains=4):
     """Build a NUTS-based ``MCMC`` runner for ``model``.

     Args:
         model: numpyro model wrapped in ``NUTS``.
         data: optional data, stored as ``self.data``.
         num_warmup: warmup iterations passed to ``MCMC`` (default 1000).
         num_samples: number of samples passed to ``MCMC`` (default 2000).
         num_chains: number of chains passed to ``MCMC`` (default 4).
     """
     self.data = data
     self.num_warmup = num_warmup
     self.num_samples = num_samples
     self.num_chains = num_chains
     self.mcmc = MCMC(NUTS(model),
                      num_warmup=self.num_warmup,
                      num_samples=self.num_samples,
                      num_chains=self.num_chains)
Beispiel #24
0
def run_inference(model, at_bats, hits, rng_key, args):
    """Run NUTS on ``model`` and return the posterior samples.

    Args:
        model: numpyro model; ``at_bats`` and ``hits`` are forwarded to it.
        at_bats, hits: positional model arguments.
        rng_key: JAX PRNG key for the sampler.
        args: namespace providing ``num_warmup``, ``num_samples`` and ``num_chains``.

    Returns:
        The result of ``mcmc.get_samples()``.
    """
    kernel = NUTS(model)
    # Keyword arguments: recent numpyro makes num_warmup/num_samples keyword-only.
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        # Disabled when NUMPYRO_SPHINXBUILD is set (presumably docs builds).
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, at_bats, hits)
    return mcmc.get_samples()
Beispiel #25
0
def test_model_with_mask_false():
    """With its only observation masked out, ``x`` should keep a posterior mean near 0."""
    def model():
        x = numpyro.sample("x", dist.Normal())
        # mask=False masks out the observed site inside this block.
        with numpyro.handlers.mask(mask=False):
            numpyro.sample("y", dist.Normal(x), obs=1)

    mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=1)
    mcmc.run(random.PRNGKey(1))
    x_mean = mcmc.get_samples()['x'].mean()
    assert_allclose(x_mean, 0., atol=0.1)
Beispiel #26
0
def test_mcmc_parallel_chain(deterministic):
    """Check ``GLOBAL["count"]`` after a two-chain NUTS run of the module-level ``model``.

    The counter is reset to 0 first; presumably ``model`` increments it. The
    expected final value is 4 when ``deterministic`` is true and 3 otherwise.
    """
    GLOBAL["count"] = 0
    # Keyword arguments: recent numpyro makes num_warmup/num_samples keyword-only.
    mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100, num_chains=2)
    mcmc.run(random.PRNGKey(0), deterministic=deterministic)
    mcmc.get_samples()

    if deterministic:
        assert GLOBAL["count"] == 4
    else:
        assert GLOBAL["count"] == 3
Beispiel #27
0
def test_discrete_gibbs_multiple_sites():
    """DiscreteHMCGibbs recovers the means of a batched Bernoulli site and a Binomial site."""
    def model():
        numpyro.sample("x", dist.Bernoulli(0.7).expand([3]))
        numpyro.sample("y", dist.Binomial(10, 0.3))

    kernel = DiscreteHMCGibbs(NUTS(model))
    # Keyword arguments: recent numpyro makes num_warmup/num_samples keyword-only.
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000, progress_bar=False)
    mcmc.run(random.PRNGKey(0))
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["x"], 0), 0.7 * jnp.ones(3), atol=0.01)
    assert_allclose(jnp.mean(samples["y"], 0), 0.3 * 10, atol=0.1)
Beispiel #28
0
def run_inference(model, args, rng_key):
    """Run NUTS on ``model``, print a summary including deterministic sites, return samples.

    Args:
        model: numpyro model taking no positional arguments.
        args: namespace providing ``num_warmup``, ``num_samples`` and ``num_chains``.
        rng_key: JAX PRNG key for the sampler.

    Returns:
        The result of ``mcmc.get_samples()``.
    """
    kernel = NUTS(model)
    # Keyword arguments: recent numpyro makes num_warmup/num_samples keyword-only.
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        # Disabled when NUMPYRO_SPHINXBUILD is set (presumably docs builds).
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key)
    mcmc.print_summary(exclude_deterministic=False)
    return mcmc.get_samples()
Beispiel #29
0
def run_hmc(rng_key, model, data, num_mix_comp, args, bvm_init_locs):
    """Run NUTS initialised at ``bvm_init_locs``, print a summary and return the samples.

    The model is called with ``(data, len(data), num_mix_comp)``, and the tree
    depth is capped at 7.
    """
    init_strategy = init_to_value(values=bvm_init_locs)
    kernel = NUTS(model, init_strategy=init_strategy, max_tree_depth=7)
    mcmc = MCMC(kernel, num_samples=args.num_samples, num_warmup=args.num_warmup)
    mcmc.run(rng_key, data, len(data), num_mix_comp)
    mcmc.print_summary()
    return mcmc.get_samples()
Beispiel #30
0
def test_enum_subsample_smoke():
    """Smoke-test HMCECS(NUTS) on a model with a discrete site and a subsampled plate."""
    def model(data):
        x = numpyro.sample("x", dist.Bernoulli(0.5))
        # The plate covers all data points but uses a subsample of size 100.
        with numpyro.plate("N", data.shape[0], subsample_size=100, dim=-1):
            batch = numpyro.subsample(data, event_dim=0)
            numpyro.sample("obs", dist.Normal(x, 1), obs=batch)

    data = random.normal(random.PRNGKey(0), (10000, )) + 1
    kernel = HMCECS(NUTS(model), num_blocks=10)
    # Keyword arguments: recent numpyro makes num_warmup/num_samples keyword-only.
    mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
    mcmc.run(random.PRNGKey(0), data)