Example no. 1
def test_beta_bernoulli():
    from numpyro.contrib.tfp import distributions as dist

    warmup_steps, num_samples = (500, 2000)

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs)(rng_key=random.PRNGKey(1),
                                      sample_shape=(1000, 2))
    kernel = NUTS(model=model, trajectory_length=0.1)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples['p_latent'], 0), true_probs, atol=0.05)
Example no. 2
def test_logistic_regression_x64(kernel_cls):
    N, dim = 3000, 3
    warmup_steps, num_samples = 1000, 8000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = kernel_cls(model=model, trajectory_length=8)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.22)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['coefs'].dtype == np.float64
Example no. 3
def run_hmcecs(hmcecs_key, args, data, obs, inner_kernel):
    svi_key, mcmc_key = random.split(hmcecs_key)

    # Find reference parameters for the second-order Taylor expansion used to estimate the likelihood (taylor_proxy).
    optimizer = numpyro.optim.Adam(step_size=1e-3)
    guide = autoguide.AutoDelta(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    params, losses = svi.run(svi_key, args.num_svi_steps, data, obs, args.subsample_size)
    ref_params = {'theta': params['theta_auto_loc']}

    # The Taylor proxy estimates the log likelihood (ll) around ref_params by
    # taylor_expansion(ll, theta_curr) +
    #     sum_{i in subsample} [ll_i(theta_curr) - taylor_expansion(ll_i, theta_curr)]
    proxy = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(inner_kernel, num_blocks=args.num_blocks, proxy=proxy)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)

    mcmc.run(mcmc_key, data, obs, args.subsample_size)
    mcmc.print_summary()
    return losses, mcmc.get_samples()
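The snippet above leaves the driver code to the surrounding script. A minimal sketch of how it might be invoked is below; `model`, `load_data`, and the argument values are assumptions for illustration, not part of the original example.

from types import SimpleNamespace

from jax import random
from numpyro.infer import NUTS

def sketch_driver():
    data, obs = load_data()  # hypothetical loader for the regression data
    args = SimpleNamespace(num_svi_steps=5000, subsample_size=100,
                           num_blocks=10, num_warmup=500, num_samples=500)
    inner_kernel = NUTS(model)  # gradient-based kernel wrapped by HMCECS
    losses, samples = run_hmcecs(random.PRNGKey(0), args, data, obs, inner_kernel)
    return losses, samples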
Example no. 4
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    true_mean, true_std = 1., 0.5
    warmup_steps, num_samples = (100000, 100000) if kernel_cls is SA else (1000, 8000)

    def potential_fn(z):
        return 0.5 * np.sum(((z - true_mean) / true_std) ** 2)

    init_params = np.array(0.)
    if kernel_cls is SA:
        kernel = SA(potential_fn=potential_fn, dense_mass=dense_mass)
    else:
        kernel = kernel_cls(potential_fn=potential_fn, trajectory_length=8, dense_mass=dense_mass)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    mcmc.print_summary()
    hmc_states = mcmc.get_samples()
    assert_allclose(np.mean(hmc_states), true_mean, rtol=0.05)
    assert_allclose(np.std(hmc_states), true_std, rtol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert hmc_states.dtype == np.float64
Example no. 5
def test_correlated_mvn():
    # This requires dense mass matrix estimation.
    D = 5

    warmup_steps, num_samples = 5000, 8000

    true_mean = 0.
    a = np.tril(0.5 * np.fliplr(np.eye(D)) + 0.1 * np.exp(random.normal(random.PRNGKey(0), shape=(D, D))))
    true_cov = np.dot(a, a.T)
    true_prec = np.linalg.inv(true_cov)

    def potential_fn(z):
        return 0.5 * np.dot(z.T, np.dot(true_prec, z))

    init_params = np.zeros(D)
    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples), true_mean, atol=0.02)
    assert onp.sum(onp.abs(onp.cov(samples.T) - true_cov)) / D**2 < 0.02
Example no. 6
def test_binomial_stable_x64(with_logits):
    # Ref: https://github.com/pyro-ppl/pyro/issues/1706
    warmup_steps, num_samples = 200, 200

    def model(data):
        p = numpyro.sample('p', dist.Beta(1., 1.))
        if with_logits:
            logits = logit(p)
            numpyro.sample('obs', dist.Binomial(data['n'], logits=logits), obs=data['x'])
        else:
            numpyro.sample('obs', dist.Binomial(data['n'], probs=p), obs=data['x'])

    data = {'n': 5000000, 'x': 3849}
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples['p'], 0), data['x'] / data['n'], rtol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p'].dtype == jnp.float64
Example no. 7
    def test_inference_data_constant_data(self):
        import numpyro
        import numpyro.distributions as dist
        from numpyro.infer import MCMC, NUTS

        x1 = 10
        x2 = 12
        y1 = np.random.randn(10)

        def model_constant_data(x, y1=None):
            _x = numpyro.sample("x", dist.Normal(1, 3))
            numpyro.sample("y1", dist.Normal(x * _x, 1), obs=y1)

        nuts_kernel = NUTS(model_constant_data)
        mcmc = MCMC(nuts_kernel, num_samples=10, num_warmup=2)
        mcmc.run(PRNGKey(0), x=x1, y1=y1)
        posterior = mcmc.get_samples()
        posterior_predictive = Predictive(model_constant_data,
                                          posterior)(PRNGKey(1), x1)
        predictions = Predictive(model_constant_data, posterior)(PRNGKey(2),
                                                                 x2)
        inference_data = from_numpyro(
            mcmc,
            posterior_predictive=posterior_predictive,
            predictions=predictions,
            constant_data={"x1": x1},
            predictions_constant_data={"x2": x2},
        )
        test_dict = {
            "posterior": ["x"],
            "posterior_predictive": ["y1"],
            "sample_stats": ["diverging"],
            "log_likelihood": ["y1"],
            "predictions": ["y1"],
            "observed_data": ["y1"],
            "constant_data": ["x1"],
            "predictions_constant_data": ["x2"],
        }
        fails = check_multiple_attrs(test_dict, inference_data)
        assert not fails
Example no. 8
def main(args):
    _, fetch = load_dataset(LYNXHARE, shuffle=False)
    year, data = fetch()  # data is in hare -> lynx order

    # use dense_mass for better mixing rate
    mcmc = MCMC(
        NUTS(model, dense_mass=True),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(PRNGKey(1), N=data.shape[0], y=data)
    mcmc.print_summary()

    # predict populations
    pop_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2),
                                                     data.shape[0])["y"]
    mu = jnp.mean(pop_pred, 0)
    pi = jnp.percentile(pop_pred, jnp.array([10, 90]), 0)
    plt.figure(figsize=(8, 6), constrained_layout=True)
    plt.plot(year,
             data[:, 0],
             "ko",
             mfc="none",
             ms=4,
             label="true hare",
             alpha=0.67)
    plt.plot(year, data[:, 1], "bx", label="true lynx")
    plt.plot(year, mu[:, 0], "k-.", label="pred hare", lw=1, alpha=0.67)
    plt.plot(year, mu[:, 1], "b--", label="pred lynx")
    plt.fill_between(year, pi[0, :, 0], pi[1, :, 0], color="k", alpha=0.2)
    plt.fill_between(year, pi[0, :, 1], pi[1, :, 1], color="b", alpha=0.3)
    plt.gca().set(ylim=(0, 160),
                  xlabel="year",
                  ylabel="population (in thousands)")
    plt.title("Posterior predictive (80% CI) with predator-prey pattern.")
    plt.legend()

    plt.savefig("ode_plot.pdf")
Example no. 9
def test_beta_bernoulli():
    from tensorflow_probability.substrates.jax import distributions as tfd

    num_warmup, num_samples = (500, 2000)

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample("p_latent", tfd.Beta(alpha, beta))
        numpyro.sample("obs", tfd.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = tfd.Bernoulli(true_probs).sample(
        seed=random.PRNGKey(1), sample_shape=(1000, 2)
    )
    kernel = NUTS(model=model, trajectory_length=0.1)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["p_latent"], 0), true_probs, atol=0.05)
Example no. 10
def test_logistic_regression():
    from numpyro.contrib.tfp import distributions as dist

    N, dim = 3000, 3
    num_warmup, num_samples = (1000, 1000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits)(rng_key=random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic('logits', jnp.sum(coefs * data, axis=-1))
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples['logits'].shape == (num_samples, N)
    assert_allclose(jnp.mean(samples['coefs'], 0), true_coefs, atol=0.22)
Example no. 11
def test_dirichlet_categorical_x64(kernel_cls, dense_mass):
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample("p_latent", dist.Dirichlet(concentration))
        numpyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    if kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model, dense_mass=dense_mass)
    else:
        kernel = kernel_cls(model, trajectory_length=1.0, dense_mass=dense_mass)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["p_latent"], 0), true_probs, atol=0.02)

    if "JAX_ENABLE_X64" in os.environ:
        assert samples["p_latent"].dtype == jnp.float64
Example no. 12
def test_gaussian_mixture_model():
    K, N = 3, 1000

    def gmm(data):
        mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
        with numpyro.plate("num_clusters", K, dim=-1):
            cluster_means = numpyro.sample("cluster_means", dist.Normal(jnp.arange(K), 1.))
        with numpyro.plate("data", data.shape[0], dim=-1):
            assignments = numpyro.sample("assignments", dist.Categorical(mix_proportions))
            numpyro.sample("obs", dist.Normal(cluster_means[assignments], 1.), obs=data)

    true_cluster_means = jnp.array([1., 5., 10.])
    true_mix_proportions = jnp.array([0.1, 0.3, 0.6])
    cluster_assignments = dist.Categorical(true_mix_proportions).sample(random.PRNGKey(0), (N,))
    data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample(random.PRNGKey(1))

    nuts_kernel = NUTS(gmm)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(samples["phi"].mean(0).sort(), true_mix_proportions, atol=0.05)
    assert_allclose(samples["cluster_means"].mean(0).sort(), true_cluster_means, atol=0.2)
Example no. 13
    def run_inference(self,
                      model,
                      rng_key,
                      X,
                      Y,
                      hypers,
                      num_warmup=500,
                      num_chains=1,
                      num_samples=1000):

        start = time.time()
        kernel = NUTS(model)
        mcmc = MCMC(kernel,
                    num_warmup=num_warmup,
                    num_samples=num_samples,
                    num_chains=num_chains,
                    progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
        mcmc.run(rng_key, X, Y, hypers)
        mcmc.print_summary()
        print('\nMCMC elapsed time:', time.time() - start)
        return mcmc.get_samples()
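A hypothetical call site for the method above (the `experiment` instance, `model`, and the `X`, `Y`, `hypers` inputs are assumptions for illustration):

samples = experiment.run_inference(model, random.PRNGKey(0), X, Y, hypers,
                                   num_warmup=500, num_samples=1000)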
Example no. 14
def test_logistic_regression_x64(kernel_cls):
    N, dim = 3000, 3
    if kernel_cls is SA:
        num_warmup, num_samples = (100000, 100000)
    elif kernel_cls is BarkerMH:
        num_warmup, num_samples = (2000, 12000)
    else:
        num_warmup, num_samples = (1000, 8000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    if kernel_cls is SA:
        kernel = SA(model=model, adapt_state_size=9)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model)
    else:
        kernel = kernel_cls(
            model=model, trajectory_length=8, find_heuristic_step_size=True
        )
    mcmc = MCMC(
        kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False
    )
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples["logits"].shape == (num_samples, N)
    # those coefficients are found by doing MAP inference using AutoDelta
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(samples["coefs"], 0), expected_coefs, atol=0.1)

    if "JAX_ENABLE_X64" in os.environ:
        assert samples["coefs"].dtype == jnp.float64
Example no. 15
def main(m, seed):
    rng_key = random.PRNGKey(seed)
    cutoff_up = 800
    cutoff_down = 100
    
    if m <= 10:
        seq, _ = estimate_beliefs(outcomes_data, responses_data, mask=mask_data, nu_max=m)
    else:
        seq, _ = estimate_beliefs(outcomes_data, responses_data, mask=mask_data, nu_max=10, nu_min=m-10)
    
    model = generative_model

    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
    seq = (seq['beliefs'][0][cutoff_down:cutoff_up], 
           seq['beliefs'][1][cutoff_down:cutoff_up])
    
    rng_key, _rng_key = random.split(rng_key)
    mcmc.run(
        _rng_key, 
        seq, 
        y=responses_data[cutoff_down:cutoff_up], 
        mask=mask_data[cutoff_down:cutoff_up].astype(bool), 
        extra_fields=('potential_energy',)
    )
    
    samples = mcmc.get_samples()    
    waic = log_pred_density(
        model, 
        samples, 
        seq, 
        y=responses_data[cutoff_down:cutoff_up], 
        mask=mask_data[cutoff_down:cutoff_up].astype(bool)
    )['waic']

    jnp.savez('fit_waic_sample/dyn_fit_waic_sample_minf{}.npz'.format(m), samples=samples, waic=waic)

    print(mcmc.get_extra_fields()['potential_energy'].mean())
Example no. 16
def test_neals_funnel_smoke():
    dim = 10

    guide = AutoIAFNormal(neals_funnel)
    svi = SVI(neals_funnel, guide, Adam(1e-10), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0), dim)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, dim)
        return svi_state

    svi_state = lax.fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)

    neutra = NeuTraReparam(guide, params)
    model = neutra.reparam(neals_funnel)
    nuts = NUTS(model)
    mcmc = MCMC(nuts, num_warmup=50, num_samples=50)
    mcmc.run(random.PRNGKey(1), dim)
    samples = mcmc.get_samples()
    transformed_samples = neutra.transform_sample(samples['auto_shared_latent'])
    assert 'x' in transformed_samples
    assert 'y' in transformed_samples
Example no. 17
def test_binomial_stable_x64(with_logits):
    # Ref: https://github.com/pyro-ppl/pyro/issues/1706
    num_warmup, num_samples = 200, 200

    def model(data):
        p = numpyro.sample("p", dist.Beta(1.0, 1.0))
        if with_logits:
            logits = logit(p)
            numpyro.sample(
                "obs", dist.Binomial(data["n"], logits=logits), obs=data["x"]
            )
        else:
            numpyro.sample("obs", dist.Binomial(data["n"], probs=p), obs=data["x"])

    data = {"n": 5000000, "x": 3849}
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["p"], 0), data["x"] / data["n"], rtol=0.05)

    if "JAX_ENABLE_X64" in os.environ:
        assert samples["p"].dtype == jnp.float64
Example no. 18
def test_logistic_regression():
    from tensorflow_probability.substrates.jax import distributions as tfd

    N, dim = 3000, 3
    num_warmup, num_samples = (1000, 1000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = tfd.Bernoulli(logits=logits).sample(seed=random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample("coefs", tfd.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
        return numpyro.sample("obs", tfd.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples["logits"].shape == (num_samples, N)
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(samples["coefs"], 0), expected_coefs, atol=0.22)
Example no. 19
def test_beta_bernoulli_x64(kernel_cls):
    warmup_steps, num_samples = 500, 20000

    def model(data):
        alpha = np.array([1.1, 1.1])
        beta = np.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    kernel = kernel_cls(model=model, trajectory_length=1.)
    mcmc = MCMC(kernel,
                num_warmup=warmup_steps,
                num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
Example no. 20
def main(args):
    is_sphinxbuild = "NUMPYRO_SPHINXBUILD" in os.environ
    data = load_data()
    data_dict = make_birthdays_data_dict(data)
    mcmc = MCMC(
        NUTS(birthdays_model, init_strategy=init_to_median),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=(not is_sphinxbuild),
    )
    mcmc.run(jax.random.PRNGKey(0), **data_dict)
    if not is_sphinxbuild:
        mcmc.print_summary()

    if args.save_figure:
        samples = mcmc.get_samples()
        print(f"Saving figure at {args.save_figure}")
        fig = make_figure(data, samples)
        fig.savefig(args.save_figure)
        plt.close()

    return mcmc
Example no. 21
def test_random_module_mcmc(backend):
    if backend == "flax":
        import flax

        linear_module = flax.nn.Dense.partial(features=1)
        bias_name = "bias"
        weight_name = "kernel"
        random_module = random_flax_module
    elif backend == "haiku":
        import haiku as hk

        linear_module = hk.transform(lambda x: hk.Linear(1)(x))
        bias_name = "linear.b"
        weight_name = "linear.w"
        random_module = random_haiku_module

    def model(data, labels):
        nn = random_module("nn", linear_module,
                           prior={bias_name: dist.Cauchy(), weight_name: dist.Normal()},
                           input_shape=(dim,))
        logits = nn(data).squeeze(-1)
        numpyro.sample("y", dist.Bernoulli(logits=logits), obs=labels)

    N, dim = 3000, 3
    warmup_steps, num_samples = (1000, 1000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data, labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert set(samples.keys()) == {"nn/{}".format(bias_name), "nn/{}".format(weight_name)}
    assert_allclose(np.mean(samples["nn/{}".format(weight_name)].squeeze(-1), 0), true_coefs, atol=0.22)
Example no. 22
def main(args):
    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words, unsupervised_words) = simulate_data(
        random.PRNGKey(1),
        num_categories=args.num_categories,
        num_words=args.num_words,
        num_supervised_data=args.num_supervised,
        num_unsupervised_data=args.num_unsupervised,
    )
    print('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel = NUTS(semi_supervised_hmm)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples,
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, transition_prior, emission_prior, supervised_categories,
             supervised_words, unsupervised_words)
    samples = mcmc.get_samples()
    print_results(samples, transition_prob, emission_prob)
    print('\nMCMC elapsed time:', time.time() - start)

    # make plots
    fig, ax = plt.subplots(1, 1)

    x = onp.linspace(0, 1, 101)
    for i in range(transition_prob.shape[0]):
        for j in range(transition_prob.shape[1]):
            ax.plot(x, gaussian_kde(samples['transition_prob'][:, i, j])(x),
                    label="trans_prob[{}, {}], true value = {:.2f}"
                    .format(i, j, transition_prob[i, j]))
    ax.set(xlabel="Probability", ylabel="Frequency",
           title="Transition probability posterior")
    ax.legend()

    plt.tight_layout()
    plt.savefig("hmm_plot.pdf")
Example no. 23
def test_dense_mass(kernel_cls, rho):
    num_warmup, num_samples = 20000, 10000

    true_cov = jnp.array([[10.0, rho], [rho, 0.1]])

    def model():
        numpyro.sample(
            "x",
            dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=true_cov))

    if kernel_cls is HMC or kernel_cls is NUTS:
        kernel = kernel_cls(model, trajectory_length=2.0, dense_mass=True)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model, dense_mass=True)

    mcmc = MCMC(kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0))

    mass_matrix_sqrt = mcmc.last_state.adapt_state.mass_matrix_sqrt
    if kernel_cls is HMC or kernel_cls is NUTS:
        mass_matrix_sqrt = mass_matrix_sqrt[("x", )]
    mass_matrix = jnp.matmul(mass_matrix_sqrt, jnp.transpose(mass_matrix_sqrt))
    estimated_cov = jnp.linalg.inv(mass_matrix)
    assert_allclose(estimated_cov, true_cov, rtol=0.10)

    samples = mcmc.get_samples()["x"]
    assert_allclose(jnp.mean(samples[:, 0]), jnp.array(0.0), atol=0.50)
    assert_allclose(jnp.mean(samples[:, 1]), jnp.array(0.0), atol=0.05)
    assert_allclose(jnp.mean(samples[:, 0] * samples[:, 1]),
                    jnp.array(rho),
                    atol=0.20)
    assert_allclose(jnp.var(samples, axis=0),
                    jnp.array([10.0, 0.1]),
                    rtol=0.20)
Example no. 24
    def inference(belief_sequences, obs, mask, rng_key):
        nuts_kernel = NUTS(model, dense_mass=True)
        mcmc = MCMC(
            nuts_kernel, 
            num_warmup=num_warmup, 
            num_samples=num_samples, 
            num_chains=num_chains,
            chain_method="vectorized",
            progress_bar=False
        )

        mcmc.run(
            rng_key, 
            belief_sequences, 
            obs, 
            mask, 
            extra_fields=('potential_energy',)
        )

        samples = mcmc.get_samples()
        potential_energy = mcmc.get_extra_fields()['potential_energy'].mean()
        # mcmc.print_summary()

        return samples, potential_energy
Example no. 25
class NutsHandler(Handler):
    def __init__(
        self,
        model,
        num_warmup=2000,
        num_samples=10000,
        num_chains=1,
        rng_key=0,
        to_numpy: bool = True,
        *args,
        **kwargs,
    ):
        self.model = model
        self.num_warmup = num_warmup
        self.num_samples = num_samples
        self.num_chains = num_chains
        self.rng_key, self.rng_key_ = random.split(random.PRNGKey(rng_key))
        self.to_numpy = to_numpy
        self.kernel = NUTS(model, **kwargs)
        self.mcmc = MCMC(self.kernel,
                         num_warmup=num_warmup,
                         num_samples=num_samples,
                         num_chains=num_chains)

    def predict(self, *args, **kwargs):
        predictive = Predictive(self.model, self.posterior.data, **kwargs)
        self.predictive = Posterior(predictive(self.rng_key_, *args))

    def fit(self, *args, **kwargs):
        self.num_samples = kwargs.get("num_samples", self.num_samples)

        self.mcmc.run(self.rng_key_, *args, **kwargs)
        self.posterior = Posterior(self.mcmc.get_samples(), self.to_numpy)

    def summary(self, *args, **kwargs):
        self.mcmc.print_summary(*args, **kwargs)
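A hypothetical usage sketch for NutsHandler (the `model` function and the `x`, `y` arrays are assumptions; `Posterior` and `Predictive` come from the handler's own module):

handler = NutsHandler(model, num_warmup=500, num_samples=1000, rng_key=0)
handler.fit(x, y=y)    # runs MCMC and stores the samples as a Posterior
handler.summary()      # delegates to mcmc.print_summary()
handler.predict(x)     # posterior predictive, stored in handler.predictive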
Example no. 26
File: base.py Project: elray1/covid
    def infer(self,
              num_warmup=1000,
              num_samples=1000,
              num_chains=1,
              rng_key=PRNGKey(1),
              **args):
        '''Fit using MCMC'''

        args = dict(self.args, **args)

        kernel = NUTS(self, init_strategy=numpyro.infer.util.init_to_median())

        mcmc = MCMC(kernel,
                    num_warmup=num_warmup,
                    num_samples=num_samples,
                    num_chains=num_chains)

        mcmc.run(rng_key, **self.obs, **args)
        mcmc.print_summary()

        self.mcmc = mcmc
        self.mcmc_samples = mcmc.get_samples()

        return self.mcmc_samples
Example no. 27
def test_estimate_likelihood(kernel_cls):
    data_key, tr_key, sub_key, rng_key = random.split(random.PRNGKey(0), 4)
    ref_params = jnp.array([0.1, 0.5, -0.2])
    sigma = .1
    data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        data_key, (10_000, ))
    n, _ = data.shape
    num_warmup = 200
    num_samples = 200
    num_blocks = 20

    def model(data):
        mean = numpyro.sample(
            'mean', dist.Normal(ref_params, jnp.ones_like(ref_params)))
        with numpyro.plate('N', data.shape[0], subsample_size=100,
                           dim=-2) as idx:
            numpyro.sample('obs', dist.Normal(mean, sigma), obs=data[idx])

    proxy_fn = HMCECS.taylor_proxy({'mean': ref_params})
    kernel = HMCECS(kernel_cls(model), proxy=proxy_fn, num_blocks=num_blocks)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)

    mcmc.run(random.PRNGKey(0),
             data,
             extra_fields=['hmc_state.potential_energy'])

    pes = mcmc.get_extra_fields()['hmc_state.potential_energy']
    samples = mcmc.get_samples()
    # log joint density of the full (non-subsampled) model at each posterior sample
    pes_full = vmap(
        lambda sample: log_density(
            model, (data,), {}, {**sample, 'N': jnp.arange(n)}
        )[0]
    )(samples)

    assert jnp.var(jnp.exp(-pes - pes_full)) < 1.
Example no. 28
def sample_posterior_gibbs(rng_key: random.PRNGKey,
                           model,
                           data: np.ndarray,
                           Nsamples: int = 1000,
                           alpha: float = 1,
                           sigma: float = 0,
                           T: int = 10,
                           gibbs_fn=None,
                           gibbs_sites=None):
    assert gibbs_fn is not None
    assert gibbs_sites is not None

    Npoints = len(data)

    inner_kernel = NUTS(model)
    kernel = HMCGibbs(inner_kernel, gibbs_fn=gibbs_fn, gibbs_sites=gibbs_sites)
    mcmc = MCMC(kernel, num_samples=Nsamples, num_warmup=NUM_WARMUP)
    mcmc.run(rng_key, data=data, alpha=alpha, sigma=sigma, T=T)
    samples = mcmc.get_samples()

    z = samples['z']
    assert z.shape == (Nsamples, Npoints)

    return z
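HMCGibbs expects gibbs_fn(rng_key, gibbs_sites, hmc_sites) to return a dict of new values for the listed Gibbs sites. A minimal, hypothetical resampler for the 'z' site is sketched below; the full-conditional sampler is a placeholder, not part of the original snippet.

def example_gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    # gibbs_sites holds the current values of the Gibbs-updated sites;
    # hmc_sites holds the current HMC-sampled values of the remaining sites
    new_z = sample_z_full_conditional(rng_key, gibbs_sites['z'], hmc_sites)  # hypothetical
    return {'z': new_z}

z = sample_posterior_gibbs(random.PRNGKey(0), model, data, Nsamples=1000,
                           gibbs_fn=example_gibbs_fn, gibbs_sites=['z'])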
Example no. 29
def run_inference(
    model: Callable,
    at_bats: jnp.ndarray,
    hits: jnp.ndarray,
    rng_key: jnp.ndarray,
    *,
    num_warmup: int = 1500,
    num_samples: int = 3000,
    num_chains: int = 1,
    algo_name: str = "NUTS",
) -> Dict[str, jnp.ndarray]:

    if algo_name == "NUTS":
        kernel = NUTS(model)
    elif algo_name == "HMC":
        kernel = HMC(model)
    elif algo_name == "SA":
        kernel = SA(model)
    else:
        raise ValueError("Unknown algorithm name")

    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains)
    mcmc.run(rng_key, at_bats, hits)
    return mcmc.get_samples()
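A hypothetical invocation, e.g. with the classic baseball batting data (the model name and arrays are assumptions for illustration):

samples = run_inference(fully_pooled, at_bats, hits, random.PRNGKey(0),
                        num_warmup=1500, num_samples=3000, algo_name="SA")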
Example no. 30
def main() -> None:

    df = load_dataset()
    test_index = 80
    test_len = len(df) - test_index
    y_train = jnp.array(df.loc[:test_index, "value"], dtype=jnp.float32)

    # Inference
    kernel = NUTS(sgt)
    mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=1)
    mcmc.run(random.PRNGKey(0), y_train, seasonality=38)
    mcmc.print_summary()
    posterior_samples = mcmc.get_samples()

    # Prediction
    predictive = Predictive(sgt,
                            posterior_samples,
                            return_sites=["y_forecast"])
    posterior_predictive = predictive(random.PRNGKey(1),
                                      y_train,
                                      seasonality=38,
                                      future=test_len)

    root = pathlib.Path("./data/time_series")
    root.mkdir(exist_ok=True)

    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)
    plot_results(
        df["time"].values,
        df["value"].values,
        posterior_samples,
        posterior_predictive,
        test_index,
        root,
    )