Code Example #1
def run_inference(model, args, rng_key, X, Y):
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )

    mcmc.run(rng_key, X, Y)
    mcmc.print_summary(exclude_deterministic=False)

    samples = mcmc.get_samples()
    summary_dict = summary(samples, group_by_chain=False)

    print("\nMCMC elapsed time:", time.time() - start)

    return summary_dict
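For context, here is a minimal sketch of how this helper might be driven. The toy linreg model and the argparse.Namespace stand-in for args are illustrative assumptions, and the imports used inside run_inference (os, time, MCMC, NUTS, and numpyro.diagnostics.summary) are assumed to be in scope.

import argparse

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random

def linreg(X, Y):
    # toy linear regression: Y ~ Normal(w * X, 1)
    w = numpyro.sample("w", dist.Normal(0.0, 1.0))
    numpyro.sample("Y", dist.Normal(w * X, 1.0), obs=Y)

args = argparse.Namespace(num_warmup=500, num_samples=1000, num_chains=1)
X = jnp.linspace(-1.0, 1.0, 50)
Y = 2.0 * X
summary_dict = run_inference(linreg, args, random.PRNGKey(0), X, Y)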
Code Example #2
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    true_mean, true_std = 1., 0.5
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * np.sum(((z - true_mean) / true_std)**2)

    init_params = np.array(0.)
    kernel = kernel_cls(potential_fn=potential_fn,
                        trajectory_length=8,
                        dense_mass=dense_mass)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    mcmc.print_summary()
    hmc_states = mcmc.get_samples()
    assert_allclose(np.mean(hmc_states), true_mean, rtol=0.05)
    assert_allclose(np.std(hmc_states), true_std, rtol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert hmc_states.dtype == np.float64
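The dtype assertion above only fires when double precision is active. As a hedged aside (not part of the original test), these are the two usual ways to enable it:

import os
os.environ["JAX_ENABLE_X64"] = "1"  # must be set before jax is first imported

# ... or, equivalently, after importing numpyro:
import numpyro
numpyro.enable_x64()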
Code Example #3
def test_predictive_with_improper():
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with handlers.reparam(config={"loc": TransformReparam()}):
            loc = numpyro.sample(
                "loc",
                dist.TransformedDistribution(
                    dist.Uniform(0, 1).mask(False), AffineTransform(0, alpha)),
            )
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()
    obs_pred = Predictive(model, samples)(random.PRNGKey(1), data=None)["obs"]
    assert_allclose(jnp.mean(obs_pred), true_coef, atol=0.05)
Code Example #4
def test_improper_normal(max_tree_depth):
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            loc = numpyro.sample(
                "loc",
                dist.TransformedDistribution(
                    dist.Uniform(0, 1).mask(False), AffineTransform(0, alpha)
                ),
            )
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000,))
    kernel = NUTS(model=model, max_tree_depth=max_tree_depth)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["loc"], 0), true_coef, atol=0.05)
Code Example #5
File: test_mcmc.py Project: pyro-ppl/numpyro
def test_reuse_mcmc_run(jit_args, shape):
    y1 = np.random.normal(3, 0.1, (100, ))
    y2 = np.random.normal(-3, 0.1, (shape, ))

    def model(y_obs):
        mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(3.0))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)

    # Run MCMC on the first dataset.
    kernel = NUTS(model)
    mcmc = MCMC(kernel,
                num_warmup=300,
                num_samples=500,
                jit_model_args=jit_args)
    mcmc.run(random.PRNGKey(32), y1)

    # Re-run on new data - should be much faster.
    mcmc.run(random.PRNGKey(32), y2)
    assert_allclose(mcmc.get_samples()["mu"].mean(), -3.0, atol=0.1)
Code Example #6
def test_predictive(parallel):
    model, data, true_probs = beta_bernoulli()
    mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()
    predictive = Predictive(model, samples, parallel=parallel)
    predictive_samples = predictive(random.PRNGKey(1))
    assert predictive_samples.keys() == {"beta_sq", "obs"}

    predictive.return_sites = ["beta", "beta_sq", "obs"]
    predictive_samples = predictive(random.PRNGKey(1))
    # check shapes
    assert predictive_samples["beta"].shape == (100, ) + true_probs.shape
    assert predictive_samples["beta_sq"].shape == (100, ) + true_probs.shape
    assert predictive_samples["obs"].shape == (100, ) + data.shape
    # check sample mean
    assert_allclose(
        predictive_samples["obs"].reshape((-1, ) + true_probs.shape).mean(0),
        true_probs,
        rtol=0.1)
Code Example #7
File: gp.py Project: mjbajwa/numpyro
def run_inference(model, args, rng_key, X, Y):
    start = time.time()
    # demonstrate how to use different HMC initialization strategies
    if args.init_strategy == "value":
        init_strategy = init_to_value(
            values={"kernel_var": 1.0, "kernel_noise": 0.05, "kernel_length": 0.5}
        )
    elif args.init_strategy == "median":
        init_strategy = init_to_median(num_samples=10)
    elif args.init_strategy == "feasible":
        init_strategy = init_to_feasible()
    elif args.init_strategy == "sample":
        init_strategy = init_to_sample()
    elif args.init_strategy == "uniform":
        init_strategy = init_to_uniform(radius=1)
    else:
        raise ValueError(f"Unknown init strategy: {args.init_strategy}")
    kernel = NUTS(model, init_strategy=init_strategy)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples,
                num_chains=args.num_chains, thinning=args.thinning,
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
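The branches above assume the initialization strategies have already been imported; a sketch of the import this helper relies on (mirroring the upstream gp.py example):

from numpyro.infer import (MCMC, NUTS, init_to_feasible, init_to_median,
                           init_to_sample, init_to_uniform, init_to_value)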
Code Example #8
File: my_numpyro.py Project: yongduek/DBDA-python
class Sampler:
    def __init__(self, model, data=None):
        self.data = data
        self.num_warmup = 1000
        self.num_samples = 2000
        self.num_chains = 4
        self.mcmc = MCMC(NUTS(model),
                         num_warmup=self.num_warmup,
                         num_samples=self.num_samples,
                         num_chains=self.num_chains)

    def fit(self, data):
        self.data = data
        self.mcmc.run(random.PRNGKey(0), **data)
        self.post = self.mcmc.get_samples()
        return self.post  # posterior samples

    def predict(self, data):
        pass  # posterior prediction is not implemented in this wrapper
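A hypothetical usage sketch for this wrapper; the toy model and the data dict are illustrative assumptions, and np, random, MCMC, and NUTS are assumed to be imported as in the original file. Note that fit expects a dict whose keys match the model's argument names, since it is unpacked into mcmc.run:

import numpy as np
import numpyro
import numpyro.distributions as dist

def model(y=None):
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    numpyro.sample("y", dist.Normal(mu, 1.0), obs=y)

sampler = Sampler(model)
posterior = sampler.fit({"y": np.random.normal(2.0, 1.0, size=100)})
print(posterior["mu"].mean())  # should be close to 2.0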
Code Example #9
File: test_hmc_gibbs.py Project: hessammehr/numpyro
def test_discrete_gibbs_gmm_1d(modified, kernel, inner_kernel, kwargs):
    def model(probs, locs):
        c = numpyro.sample("c", dist.Categorical(probs))
        numpyro.sample("x", dist.Normal(locs[c], 0.5))

    probs = jnp.array([0.15, 0.3, 0.3, 0.25])
    locs = jnp.array([-2, 0, 2, 4])
    sampler = kernel(inner_kernel(model, trajectory_length=1.2),
                     modified=modified,
                     **kwargs)
    mcmc = MCMC(sampler,
                num_warmup=1000,
                num_samples=200000,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0), probs, locs)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["x"]), 1.3, atol=0.1)
    assert_allclose(jnp.var(samples["x"]), 4.36, atol=0.4)
    assert_allclose(jnp.mean(samples["c"]), 1.65, atol=0.1)
    assert_allclose(jnp.var(samples["c"]), 1.03, atol=0.1)
Code Example #10
def test_beta_bernoulli():
    from numpyro.contrib.tfp import distributions as dist

    warmup_steps, num_samples = (500, 2000)

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs)(rng_key=random.PRNGKey(1), sample_shape=(1000, 2))
    kernel = NUTS(model=model, trajectory_length=0.1)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples['p_latent'], 0), true_probs, atol=0.05)
Code Example #11
def test_bernoulli_latent_model():
    @config_enumerate
    def model(data):
        y_prob = numpyro.sample("y_prob", dist.Beta(1.0, 1.0))
        with numpyro.plate("data", data.shape[0]):
            y = numpyro.sample("y", dist.Bernoulli(y_prob))
            z = numpyro.sample("z", dist.Bernoulli(0.65 * y + 0.1))
            numpyro.sample("obs", dist.Normal(2.0 * z, 1.0), obs=data)

    N = 2000
    y_prob = 0.3
    y = dist.Bernoulli(y_prob).sample(random.PRNGKey(0), (N, ))
    z = dist.Bernoulli(0.65 * y + 0.1).sample(random.PRNGKey(1))
    data = dist.Normal(2.0 * z, 1.0).sample(random.PRNGKey(2))

    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(3), data)
    samples = mcmc.get_samples()
    assert_allclose(samples["y_prob"].mean(0), y_prob, atol=0.05)
Code Example #12
def test_uniform_normal():
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000,))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.warmup(random.PRNGKey(2), data, collect_warmup=True)
    warmup_samples = mcmc.get_samples()
    mcmc.run(random.PRNGKey(3), data)
    samples = mcmc.get_samples()
    assert len(warmup_samples['loc']) == num_warmup
    assert len(samples['loc']) == num_samples
    assert_allclose(jnp.mean(samples['loc'], 0), true_coef, atol=0.05)
Code Example #13
def test_beta_bernoulli_x64(kernel_cls):
    warmup_steps, num_samples = 500, 20000

    def model(data):
        alpha = np.array([1.1, 1.1])
        beta = np.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    kernel = kernel_cls(model=model, trajectory_length=1.)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
Code Example #14
def sample(
    model,
    num_samples,
    num_warmup,
    num_chains=2,
    seed=0,
    chain_method="parallel",
    summary=True,
    **kwargs,
):
    """Run the No-U-Turn sampler

    Args:
        model: an NumPyro model function
        num_samples: number of samples to draw in each chain
        num_warmup: number of samples to use for tuning in each chain
        num_chains: number of chains to draw (default: {2})
        **kwargs: other arguments to be passed to the model function
        seed: random seed (default: {0})
        chain_method: one of NumPyro's sampling methods — "parallel" / "sequential" /
            "vectorized" (default: {"parallel"})
        summary: print diagnostics, including the Effective sample size and the
            Gelman-Rubin test (default: {True})

    Returns:
        mcmc: A fitted MCMC object
    """
    rng_key = random.PRNGKey(seed)
    kernel = NUTS(model)
    # Note: sampling more than one chain doesn't show a progress bar
    mcmc = MCMC(kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                num_chains=num_chains,
                chain_method=chain_method)
    mcmc.run(rng_key, **kwargs)

    if summary:
        mcmc.print_summary()

    return mcmc
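A minimal usage sketch for this wrapper; the coin-flip model is an illustrative assumption, and random, MCMC, and NUTS are assumed to be imported as in the surrounding module:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def coin_model(flips=None):
    p = numpyro.sample("p", dist.Beta(1.0, 1.0))
    numpyro.sample("flips", dist.Bernoulli(p), obs=flips)

mcmc = sample(coin_model, num_samples=1000, num_warmup=500, num_chains=1,
              flips=jnp.array([1, 0, 1, 1, 0]))
posterior_p = mcmc.get_samples()["p"]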
Code Example #15
def test_scan():
    def model(T=10, q=1, r=1, phi=0., beta=0.):

        def transition(state, i):
            x0, mu0 = state
            x1 = numpyro.sample('x', dist.Normal(phi * x0, q))
            mu1 = beta * mu0 + x1
            y1 = numpyro.sample('y', dist.Normal(mu1, r))
            numpyro.deterministic('y2', y1 * 2)
            return (x1, mu1), (x1, y1)

        mu0 = x0 = numpyro.sample('x_0', dist.Normal(0, q))
        y0 = numpyro.sample('y_0', dist.Normal(mu0, r))

        _, xy = scan(transition, (x0, mu0), jnp.arange(T))
        x, y = xy

        return jnp.append(x0, x), jnp.append(y0, y)

    T = 10
    num_samples = 100
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples)
    mcmc.run(jax.random.PRNGKey(0), T=T)
    assert set(mcmc.get_samples()) == {'x', 'y', 'y2', 'x_0', 'y_0'}
    mcmc.print_summary()

    samples = mcmc.get_samples()
    x = samples.pop('x')[0]  # take 1 sample of x
    # this tests for the composition of condition and substitute
    # this also tests if we can use `vmap` for predictive.
    future = 5
    predictive = Predictive(numpyro.handlers.condition(model, {'x': x}),
                            samples, return_sites=['x', 'y', 'y2'], parallel=True)
    result = predictive(jax.random.PRNGKey(1), T=T + future)
    expected_shape = (num_samples, T + future)
    assert result['x'].shape == expected_shape
    assert result['y'].shape == expected_shape
    assert result['y2'].shape == expected_shape
    assert_allclose(result['x'][:, :T], jnp.broadcast_to(x, (num_samples, T)))
    assert_allclose(result['y'][:, :T], samples['y'])
Code Example #16
def test_binomial_stable_x64(with_logits):
    # Ref: https://github.com/pyro-ppl/pyro/issues/1706
    warmup_steps, num_samples = 200, 200

    def model(data):
        p = numpyro.sample('p', dist.Beta(1., 1.))
        if with_logits:
            logits = logit(p)
            numpyro.sample('obs', dist.Binomial(data['n'], logits=logits), obs=data['x'])
        else:
            numpyro.sample('obs', dist.Binomial(data['n'], probs=p), obs=data['x'])

    data = {'n': 5000000, 'x': 3849}
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples['p'], 0), data['x'] / data['n'], rtol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p'].dtype == jnp.float64
Code Example #17
def main(args):
    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words,
     unsupervised_words) = simulate_data(
         random.PRNGKey(1),
         num_categories=args.num_categories,
         num_words=args.num_words,
         num_supervised_data=args.num_supervised,
         num_unsupervised_data=args.num_unsupervised,
     )
    print('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel = NUTS(semi_supervised_hmm)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(rng_key, transition_prior, emission_prior, supervised_categories,
             supervised_words, unsupervised_words)
    samples = mcmc.get_samples()
    print('\nMCMC elapsed time:', time.time() - start)
    print_results(samples, transition_prob, emission_prob)
Code Example #18
def test_correlated_mvn():
    # This requires dense mass matrix estimation.
    D = 5

    warmup_steps, num_samples = 5000, 8000

    true_mean = 0.
    a = jnp.tril(0.5 * jnp.fliplr(jnp.eye(D))
                 + 0.1 * jnp.exp(random.normal(random.PRNGKey(0), shape=(D, D))))
    true_cov = jnp.dot(a, a.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    init_params = jnp.zeros(D)
    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2 < 0.02
Code Example #19
File: test_tfp.py Project: pyro-ppl/numpyro
def test_beta_bernoulli():
    from tensorflow_probability.substrates.jax import distributions as tfd

    num_warmup, num_samples = (500, 2000)

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample("p_latent", tfd.Beta(alpha, beta))
        numpyro.sample("obs", tfd.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = tfd.Bernoulli(true_probs).sample(seed=random.PRNGKey(1),
                                            sample_shape=(1000, 2))
    kernel = NUTS(model=model, trajectory_length=0.1)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["p_latent"], 0), true_probs, atol=0.05)
Code Example #20
File: hmcecs.py Project: mjbajwa/numpyro
def run_hmcecs(hmcecs_key, args, data, obs, inner_kernel):
    svi_key, mcmc_key = random.split(hmcecs_key)

    # find reference parameters for second order taylor expansion to estimate likelihood (taylor_proxy)
    optimizer = numpyro.optim.Adam(step_size=1e-3)
    guide = autoguide.AutoDelta(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    params, losses = svi.run(svi_key, args.num_svi_steps, data, obs, args.subsample_size)
    ref_params = {'theta': params['theta_auto_loc']}

    # The Taylor proxy estimates the log likelihood (ll) around ref_params as
    #   taylor_expansion(ll, theta_curr)
    #     + sum_{i in subsample} [ll_i(theta_curr) - taylor_expansion(ll_i, theta_curr)]
    proxy = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(inner_kernel, num_blocks=args.num_blocks, proxy=proxy)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)

    mcmc.run(mcmc_key, data, obs, args.subsample_size)
    mcmc.print_summary()
    return losses, mcmc.get_samples()
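A hedged sketch of how run_hmcecs might be driven end to end. The subsampled logistic-regression model, synthetic data, and argparse.Namespace fields are assumptions modeled on the upstream hmcecs.py example; note that run_hmcecs also references a global model, so one must be defined:

import argparse

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import NUTS

def model(data, obs, subsample_size):
    theta = numpyro.sample("theta", dist.Normal(jnp.zeros(2), jnp.ones(2)))
    with numpyro.plate("N", data.shape[0], subsample_size=subsample_size):
        batch = numpyro.subsample(data, event_dim=1)
        batch_obs = numpyro.subsample(obs, event_dim=0)
        numpyro.sample("obs", dist.Bernoulli(logits=batch @ theta), obs=batch_obs)

data = random.normal(random.PRNGKey(0), (10000, 2))
obs = dist.Bernoulli(logits=data @ jnp.array([1.0, -0.5])).sample(random.PRNGKey(1))
args = argparse.Namespace(num_svi_steps=5000, subsample_size=100, num_blocks=10,
                          num_warmup=500, num_samples=500)
losses, samples = run_hmcecs(random.PRNGKey(2), args, data, obs, NUTS(model))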
Code Example #21
def test_logistic_regression_x64(kernel_cls):
    N, dim = 3000, 3
    warmup_steps, num_samples = 1000, 8000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = kernel_cls(model=model, trajectory_length=8)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.22)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['coefs'].dtype == np.float64
Code Example #22
def test_mcmc_progbar():
    true_mean, true_std = 1., 2.
    num_warmup, num_samples = 10, 10

    def model(data):
        mean = numpyro.param('mean', 0.)
        std = numpyro.param('std', 1., constraint=constraints.positive)
        return numpyro.sample('obs', dist.Normal(mean, std), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000, ))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.warmup(random.PRNGKey(2), data)
    mcmc.run(random.PRNGKey(3), data)
    mcmc1 = MCMC(kernel, num_warmup, num_samples, progress_bar=False)
    mcmc1.run(random.PRNGKey(2), data)

    with pytest.raises(AssertionError):
        check_close(mcmc1.get_samples(),
                    mcmc.get_samples(),
                    atol=1e-4,
                    rtol=1e-4)
    mcmc1.warmup(random.PRNGKey(2), data)
    mcmc1.run(random.PRNGKey(3), data)
    check_close(mcmc1.get_samples(), mcmc.get_samples(), atol=1e-4, rtol=1e-4)
    check_close(mcmc1._warmup_state, mcmc._warmup_state, atol=1e-4, rtol=1e-4)
Code Example #23
    def test_inference_data_constant_data(self):
        import numpyro
        import numpyro.distributions as dist
        from jax.random import PRNGKey
        from numpyro.infer import MCMC, NUTS, Predictive

        x1 = 10
        x2 = 12
        y1 = np.random.randn(10)

        def model_constant_data(x, y1=None):
            _x = numpyro.sample("x", dist.Normal(1, 3))
            numpyro.sample("y1", dist.Normal(x * _x, 1), obs=y1)

        nuts_kernel = NUTS(model_constant_data)
        mcmc = MCMC(nuts_kernel, num_samples=10, num_warmup=2)
        mcmc.run(PRNGKey(0), x=x1, y1=y1)
        posterior = mcmc.get_samples()
        posterior_predictive = Predictive(model_constant_data,
                                          posterior)(PRNGKey(1), x1)
        predictions = Predictive(model_constant_data, posterior)(PRNGKey(2),
                                                                 x2)
        inference_data = from_numpyro(
            mcmc,
            posterior_predictive=posterior_predictive,
            predictions=predictions,
            constant_data={"x1": x1},
            predictions_constant_data={"x2": x2},
        )
        test_dict = {
            "posterior": ["x"],
            "posterior_predictive": ["y1"],
            "sample_stats": ["diverging"],
            "log_likelihood": ["y1"],
            "predictions": ["y1"],
            "observed_data": ["y1"],
            "constant_data": ["x1"],
            "predictions_constant_data": ["x2"],
        }
        fails = check_multiple_attrs(test_dict, inference_data)
        assert not fails
Code Example #24
def test_structured_mass():
    def model(cov):
        w = numpyro.sample("w", dist.Normal(0, 1000).expand([2]).to_event(1))
        x = numpyro.sample("x", dist.Normal(0, 1000).expand([1]).to_event(1))
        y = numpyro.sample("y", dist.Normal(0, 1000).expand([1]).to_event(1))
        z = numpyro.sample("z", dist.Normal(0, 1000).expand([1]).to_event(1))
        wxyz = jnp.concatenate([w, x, y, z])
        numpyro.sample("obs",
                       dist.MultivariateNormal(jnp.zeros(5), cov),
                       obs=wxyz)

    w_cov = np.array([[1.5, 0.5], [0.5, 1.5]])
    xy_cov = np.array([[2.0, 1.0], [1.0, 3.0]])
    z_var = np.array([2.5])
    cov = np.zeros((5, 5))
    cov[:2, :2] = w_cov
    cov[2:4, 2:4] = xy_cov
    cov[4, 4] = z_var[0]  # z_var has shape (1,); assign its scalar value

    kernel = NUTS(model, dense_mass=[("w", ), ("x", "y")])
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1)
    mcmc.run(random.PRNGKey(1), cov)
    inverse_mass_matrix = mcmc.last_state.adapt_state.inverse_mass_matrix
    assert_allclose(inverse_mass_matrix[("w", )], w_cov, atol=0.5, rtol=0.5)
    assert_allclose(inverse_mass_matrix[("x", "y")],
                    xy_cov,
                    atol=0.5,
                    rtol=0.5)
    assert_allclose(inverse_mass_matrix[("z", )], z_var, atol=0.5, rtol=0.5)

    kernel = NUTS(model, dense_mass=[("w", ), ("y", "x")])
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1)
    mcmc.run(random.PRNGKey(1), cov)
    inverse_mass_matrix = mcmc.last_state.adapt_state.inverse_mass_matrix
    assert_allclose(inverse_mass_matrix[("w", )], w_cov, atol=0.5, rtol=0.5)
    assert_allclose(inverse_mass_matrix[("y", "x")],
                    xy_cov[::-1, ::-1],
                    atol=0.5,
                    rtol=0.5)
    assert_allclose(inverse_mass_matrix[("z", )], z_var, atol=0.5, rtol=0.5)
Code Example #25
def test_model_with_multiple_exec_paths(jit_args):
    def model(a=None, b=None, z=None):
        int_term = numpyro.sample("a", dist.Normal(0.0, 0.2))
        x_term, y_term = 0.0, 0.0
        if a is not None:
            x = numpyro.sample("x", dist.HalfNormal(0.5))
            x_term = a * x
        if b is not None:
            y = numpyro.sample("y", dist.HalfNormal(0.5))
            y_term = b * y
        sigma = numpyro.sample("sigma", dist.Exponential(1.0))
        mu = int_term + x_term + y_term
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=z)

    a = jnp.exp(np.random.randn(10))
    b = jnp.exp(np.random.randn(10))
    z = np.random.randn(10)

    # Run MCMC with different argument combinations to exercise each execution path.
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=20, num_samples=10, jit_model_args=jit_args)
    mcmc.run(random.PRNGKey(1), a, b=None, z=z)
    assert set(mcmc.get_samples()) == {"a", "x", "sigma"}
    mcmc.run(random.PRNGKey(2), a=None, b=b, z=z)
    assert set(mcmc.get_samples()) == {"a", "y", "sigma"}
    mcmc.run(random.PRNGKey(3), a=a, b=b, z=z)
    assert set(mcmc.get_samples()) == {"a", "x", "y", "sigma"}
Code Example #26
def test_mcmc_progbar():
    true_mean, true_std = 1.0, 2.0
    num_warmup, num_samples = 10, 10

    def model(data):
        mean = numpyro.sample("mean", dist.Normal(0, 1).mask(False))
        std = numpyro.sample("std", dist.LogNormal(0, 1).mask(False))
        return numpyro.sample("obs", dist.Normal(mean, std), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000, ))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.warmup(random.PRNGKey(2), data)
    mcmc.run(random.PRNGKey(3), data)
    mcmc1 = MCMC(kernel,
                 num_warmup=num_warmup,
                 num_samples=num_samples,
                 progress_bar=False)
    mcmc1.run(random.PRNGKey(2), data)

    with pytest.raises(AssertionError):
        check_close(mcmc1.get_samples(),
                    mcmc.get_samples(),
                    atol=1e-4,
                    rtol=1e-4)
    mcmc1.warmup(random.PRNGKey(2), data)
    mcmc1.run(random.PRNGKey(3), data)
    check_close(mcmc1.get_samples(), mcmc.get_samples(), atol=1e-4, rtol=1e-4)
    check_close(mcmc1.post_warmup_state,
                mcmc.post_warmup_state,
                atol=1e-4,
                rtol=1e-4)
Code Example #27
def test_compile_warmup_run(num_chains, chain_method, progress_bar):
    def model():
        numpyro.sample("x", dist.Normal(0, 1))

    if num_chains == 1 and chain_method in ['sequential', 'vectorized']:
        pytest.skip('duplicated test')
    if num_chains > 1 and chain_method == 'parallel':
        pytest.skip('duplicated test')

    rng_key = random.PRNGKey(0)
    num_samples = 10
    mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=num_samples,
                num_chains=num_chains, chain_method=chain_method,
                progress_bar=progress_bar)

    mcmc.run(rng_key)
    expected_samples = mcmc.get_samples()["x"]

    mcmc._compile(rng_key)
    # no delay after compiling
    mcmc.warmup(rng_key)
    mcmc.run(mcmc._warmup_state.rng_key)
    actual_samples = mcmc.get_samples()["x"]

    assert_allclose(actual_samples, expected_samples)

    # test for reproducibility
    if num_chains > 1:
        mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=num_samples,
                    num_chains=1, progress_bar=progress_bar)
        rng_key = random.split(rng_key)[0]
        mcmc.run(rng_key)
        first_chain_samples = mcmc.get_samples()["x"]
        assert_allclose(actual_samples[:num_samples], first_chain_samples, atol=1e-5)
Code Example #28
def main(args):
    _, fetch = load_dataset(LYNXHARE, shuffle=False)
    year, data = fetch()  # data is in hare -> lynx order

    # use dense_mass for better mixing rate
    mcmc = MCMC(
        NUTS(model, dense_mass=True),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(PRNGKey(1), N=data.shape[0], y=data)
    mcmc.print_summary()

    # predict populations
    pop_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2),
                                                     data.shape[0])["y"]
    mu = jnp.mean(pop_pred, 0)
    pi = jnp.percentile(pop_pred, jnp.array([10, 90]), 0)
    plt.figure(figsize=(8, 6), constrained_layout=True)
    plt.plot(year,
             data[:, 0],
             "ko",
             mfc="none",
             ms=4,
             label="true hare",
             alpha=0.67)
    plt.plot(year, data[:, 1], "bx", label="true lynx")
    plt.plot(year, mu[:, 0], "k-.", label="pred hare", lw=1, alpha=0.67)
    plt.plot(year, mu[:, 1], "b--", label="pred lynx")
    plt.fill_between(year, pi[0, :, 0], pi[1, :, 0], color="k", alpha=0.2)
    plt.fill_between(year, pi[0, :, 1], pi[1, :, 1], color="b", alpha=0.3)
    plt.gca().set(ylim=(0, 160),
                  xlabel="year",
                  ylabel="population (in thousands)")
    plt.title("Posterior predictive (80% CI) with predator-prey pattern.")
    plt.legend()

    plt.savefig("ode_plot.pdf")
Code Example #29
def test_gaussian_mixture_model():
    K, N = 3, 1000

    def gmm(data):
        mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
        with numpyro.plate("num_clusters", K, dim=-1):
            cluster_means = numpyro.sample("cluster_means", dist.Normal(jnp.arange(K), 1.))
        with numpyro.plate("data", data.shape[0], dim=-1):
            assignments = numpyro.sample("assignments", dist.Categorical(mix_proportions))
            numpyro.sample("obs", dist.Normal(cluster_means[assignments], 1.), obs=data)

    true_cluster_means = jnp.array([1., 5., 10.])
    true_mix_proportions = jnp.array([0.1, 0.3, 0.6])
    cluster_assignments = dist.Categorical(true_mix_proportions).sample(random.PRNGKey(0), (N,))
    data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample(random.PRNGKey(1))

    nuts_kernel = NUTS(gmm)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(samples["phi"].mean(0).sort(), true_mix_proportions, atol=0.05)
    assert_allclose(samples["cluster_means"].mean(0).sort(), true_cluster_means, atol=0.2)
Code Example #30
File: test_tfp.py Project: gully/numpyro
def test_logistic_regression():
    from numpyro.contrib.tfp import distributions as dist

    N, dim = 3000, 3
    num_warmup, num_samples = (1000, 1000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits)(rng_key=random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic('logits', jnp.sum(coefs * data, axis=-1))
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples['logits'].shape == (num_samples, N)
    assert_allclose(jnp.mean(samples['coefs'], 0), true_coefs, atol=0.22)