def test_init_to_scalar_value():
    """Initialize SVI with an autoguide whose init value is a plain float.

    The site ``x`` is given the Python scalar ``1.0`` through
    ``init_to_value``. The test makes no assertions: it passes as long as
    ``svi.init`` completes without raising.
    """

    def scalar_model():
        numpyro.sample("x", dist.Normal(0, 1))

    # A bare Python float (not an array) as the initial value for "x".
    scalar_init = init_to_value(values={"x": 1.0})
    auto_guide = AutoDiagonalNormal(scalar_model, init_loc_fn=scalar_init)
    adam = optim.Adam(1.0)
    inference = SVI(scalar_model, auto_guide, adam, Trace_ELBO())
    inference.init(random.PRNGKey(0))
x = numpyro.sample("x", dist.Normal()) with numpyro.handlers.mask(mask=False): numpyro.sample("y", dist.Normal(x), obs=1) kernel = NUTS(model) mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=1) mcmc.run(random.PRNGKey(1)) assert_allclose(mcmc.get_samples()['x'].mean(), 0., atol=0.1) @pytest.mark.parametrize('init_strategy', [ init_to_feasible(), init_to_median(num_samples=2), init_to_sample(), init_to_uniform(radius=3), init_to_value(values={'tau': 0.7}), init_to_feasible, init_to_median, init_to_sample, init_to_uniform, init_to_value, ]) def test_initialize_model_change_point(init_strategy): def model(data): alpha = 1 / jnp.mean(data.astype(np.float32)) lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha)) lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha)) tau = numpyro.sample('tau', dist.Uniform(0, 1)) lambda12 = jnp.where( jnp.arange(len(data)) < tau * len(data), lambda1, lambda2) numpyro.sample('obs', dist.Poisson(lambda12), obs=data)
numpyro.sample("y", dist.Normal(x), obs=1) kernel = NUTS(model) mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=1) mcmc.run(random.PRNGKey(1)) assert_allclose(mcmc.get_samples()["x"].mean(), 0.0, atol=0.15) @pytest.mark.parametrize( "init_strategy", [ init_to_feasible(), init_to_median(num_samples=2), init_to_sample(), init_to_uniform(radius=3), init_to_value(values={"tau": 0.7}), init_to_feasible, init_to_median, init_to_sample, init_to_uniform, init_to_value, ], ) def test_initialize_model_change_point(init_strategy): def model(data): alpha = 1 / jnp.mean(data.astype(np.float32)) lambda1 = numpyro.sample("lambda1", dist.Exponential(alpha)) lambda2 = numpyro.sample("lambda2", dist.Exponential(alpha)) tau = numpyro.sample("tau", dist.Uniform(0, 1)) lambda12 = jnp.where( jnp.arange(len(data)) < tau * len(data), lambda1, lambda2)