def test_elbo_dynamic_support():
    x_prior = dist.TransformedDistribution(
        dist.Normal(), [AffineTransform(0, 2), SigmoidTransform(), AffineTransform(0, 3)])
    x_guide = dist.Uniform(0, 3)

    def model():
        numpyro.sample('x', x_prior)

    def guide():
        numpyro.sample('x', x_guide)

    adam = optim.Adam(0.01)
    # set the base value of x_guide to 0.9
    x_base = 0.9
    guide = substitute(guide, base_param_map={'x': x_base})
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)

    x, _ = x_guide.transform_with_intermediates(x_base)
    expected_loss = x_guide.log_prob(x) - x_prior.log_prob(x)
    assert_allclose(actual_loss, expected_loss)

def fit_svi(model, n_draws=1000, autoguide=AutoLaplaceApproximation,
            loss=Trace_ELBO(), optim=optim.Adam(step_size=1e-5),
            num_warmup=2000, use_gpu=False, num_chains=1,
            progress_bar=False, sampler=None, **kwargs):
    select_device(use_gpu, num_chains)
    guide = autoguide(model)
    svi = SVI(model=model, guide=guide, loss=loss, optim=optim, **kwargs)
    # Experimental interface:
    svi_result = svi.run(jax.random.PRNGKey(0), num_steps=num_warmup,
                         stable_update=True, progress_bar=progress_bar)
    # Old approach (still active): sample directly from the fitted guide.
    post = guide.sample_posterior(jax.random.PRNGKey(1), params=svi_result.params,
                                  sample_shape=(1, n_draws))
    # New approach (not yet enabled): draw via Predictive.
    # predictive = Predictive(guide, params=svi_result.params, num_samples=n_draws)
    # post = predictive(jax.random.PRNGKey(1), **kwargs)
    # Old interface:
    # init_state = svi.init(jax.random.PRNGKey(0))
    # state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(n_draws))
    # svi_params = svi.get_params(state)
    # post = guide.sample_posterior(jax.random.PRNGKey(1), svi_params, (1, n_draws))
    trace = az.from_dict(post)
    return trace, post

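# A minimal usage sketch for fit_svi above. `my_model` is a hypothetical
# NumPyro model; `trace` is an arviz InferenceData built from the raw
# posterior dict `post`.
trace, post = fit_svi(my_model, n_draws=2000, num_warmup=5000)
print(az.summary(trace))
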
def fit(self, X, Y, rng_key, n_step):
    self.X_train = X

    # store moments of training y (to normalize)
    self.y_mean = jnp.mean(Y)
    self.y_std = jnp.std(Y)

    # normalize y
    Y = (Y - self.y_mean) / self.y_std

    # set up optimizer and SVI
    optim = numpyro.optim.Adam(step_size=0.005, b1=0.5)
    svi = SVI(
        model,
        guide=AutoDelta(model),
        optim=optim,
        loss=Trace_ELBO(),
        X=X,
        Y=Y,
    )
    params, _ = svi.run(rng_key, n_step)

    # get kernel parameters from guide with proper names
    self.kernel_params = svi.guide.median(params)

    # store Cholesky factor of the prior covariance
    self.L = linalg.cho_factor(self.kernel(X, X, **self.kernel_params))

    # store inverted prior covariance multiplied by y
    self.alpha = linalg.cho_solve(self.L, Y)

    return self.kernel_params

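# A hedged sketch of the matching predict step for the fit method above,
# using the cached Cholesky factor `self.L` and weights `self.alpha`.
# This is standard GP regression algebra (mean = K_* K^{-1} y); the method
# name `predict` and argument `X_test` are assumed, not from the source.
def predict(self, X_test):
    # cross-covariance between test and training inputs
    K_cross = self.kernel(X_test, self.X_train, **self.kernel_params)
    # posterior mean in the normalized space, mapped back to the data scale
    mean = K_cross @ self.alpha
    return mean * self.y_std + self.y_mean
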
def run_hmcecs(hmcecs_key, args, data, obs, inner_kernel):
    svi_key, mcmc_key = random.split(hmcecs_key)

    # find reference parameters for the second-order Taylor expansion used to
    # estimate the likelihood (taylor_proxy)
    optimizer = numpyro.optim.Adam(step_size=1e-3)
    guide = autoguide.AutoDelta(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    svi_result = svi.run(svi_key, args.num_svi_steps, data, obs, args.subsample_size)
    params, losses = svi_result.params, svi_result.losses
    ref_params = {"theta": params["theta_auto_loc"]}

    # the Taylor proxy estimates the log likelihood (ll) as
    #   taylor_expansion(ll, theta_curr)
    #     + sum_{i in subsample} [ll_i(theta_curr) - taylor_expansion(ll_i, theta_curr)]
    # expanded around ref_params
    proxy = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(inner_kernel, num_blocks=args.num_blocks, proxy=proxy)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(mcmc_key, data, obs, args.subsample_size)
    mcmc.print_summary()
    return losses, mcmc.get_samples()

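# Usage sketch for run_hmcecs (mirrors NumPyro's HMCECS example, where
# `model`, `args`, `data`, and `obs` come from the surrounding script;
# NUTS as the inner kernel is one common choice):
inner_kernel = NUTS(model)
losses, samples = run_hmcecs(random.PRNGKey(2), args, data, obs, inner_kernel)
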
def run_svi(rng_key, X, Y, guide_family="AutoDiagonalNormal", K=8):
    assert guide_family in ["AutoDiagonalNormal", "AutoDAIS"]

    if guide_family == "AutoDAIS":
        guide = autoguide.AutoDAIS(model, K=K, eta_init=0.02, eta_max=0.5)
        step_size = 5e-4
    elif guide_family == "AutoDiagonalNormal":
        guide = autoguide.AutoDiagonalNormal(model)
        step_size = 3e-3

    optimizer = numpyro.optim.Adam(step_size=step_size)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    svi_result = svi.run(rng_key, args.num_svi_steps, X, Y)
    params = svi_result.params

    final_elbo = -Trace_ELBO(num_particles=1000).loss(
        rng_key, params, model, guide, X, Y
    )

    guide_name = guide_family
    if guide_family == "AutoDAIS":
        guide_name += "-{}".format(K)
    print("[{}] final elbo: {:.2f}".format(guide_name, final_elbo))

    return guide.sample_posterior(
        random.PRNGKey(1), params, sample_shape=(args.num_samples,)
    )

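# Usage sketch for run_svi (assumes the surrounding script defines a global
# `args` with num_svi_steps and num_samples, plus arrays X and Y):
dais_samples = run_svi(random.PRNGKey(0), X, Y, guide_family="AutoDAIS", K=8)
meanfield_samples = run_svi(random.PRNGKey(0), X, Y, guide_family="AutoDiagonalNormal")
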
def test_obs_mask_ok(Elbo, mask, num_particles):
    data = np.array([7.0, 7.0, 7.0])

    def model():
        x = numpyro.sample("x", dist.Normal(0.0, 1.0))
        with numpyro.plate("plate", len(data)):
            y = numpyro.sample("y", dist.Normal(x, 1.0), obs=data, obs_mask=mask)
            if not_jax_tracer(y):
                assert ((y == data) == mask).all()

    def guide():
        loc = numpyro.param("loc", np.zeros(()))
        scale = numpyro.param("scale", np.ones(()), constraint=constraints.positive)
        x = numpyro.sample("x", dist.Normal(loc, scale))
        with numpyro.plate("plate", len(data)):
            with handlers.mask(mask=np.invert(mask)):
                numpyro.sample("y_unobserved", dist.Normal(x, 1.0))

    elbo = Elbo(num_particles=num_particles)
    svi = SVI(model, guide, numpyro.optim.Adam(1), elbo)
    svi_state = svi.init(random.PRNGKey(0))
    svi.update(svi_state)

def test_obs_mask_multivariate_ok(Elbo, mask, num_particles):
    data = np.full((4, 3), 7.0)

    def model():
        x = numpyro.sample("x", dist.MultivariateNormal(np.zeros(3), np.eye(3)))
        with numpyro.plate("plate", len(data)):
            y = numpyro.sample(
                "y", dist.MultivariateNormal(x, np.eye(3)), obs=data, obs_mask=mask
            )
            if not_jax_tracer(y):
                assert ((y == data).all(-1) == mask).all()

    def guide():
        loc = numpyro.param("loc", np.zeros(3))
        cov = numpyro.param("cov", np.eye(3), constraint=constraints.positive_definite)
        x = numpyro.sample("x", dist.MultivariateNormal(loc, cov))
        with numpyro.plate("plate", len(data)):
            with handlers.mask(mask=np.invert(mask)):
                numpyro.sample("y_unobserved", dist.MultivariateNormal(x, np.eye(3)))

    elbo = Elbo(num_particles=num_particles)
    svi = SVI(model, guide, numpyro.optim.Adam(1), elbo)
    svi_state = svi.init(random.PRNGKey(0))
    svi.update(svi_state)

def find_map(
    self,
    num_steps: int = 10000,
    handlers: Optional[list] = None,
    reparam: Union[str, hdl.reparam] = "auto",
    svi_kwargs: Optional[dict] = None,
):
    """EXPERIMENTAL: find the MAP estimate.

    Args:
        num_steps (int): number of SVI optimization steps. Defaults to 10000.
        handlers (list, optional): effect handlers applied to the model.
            Defaults to None.
        reparam (str, or numpyro.handlers.reparam): reparameterization
            strategy for the model. Defaults to 'auto'.
        svi_kwargs (dict, optional): extra keyword arguments passed to SVI;
            the keys 'optim' and 'loss' override the defaults. Defaults to None.
    """
    # copy to avoid mutating a caller's (or a shared default) dict via .pop below
    svi_kwargs = {} if svi_kwargs is None else dict(svi_kwargs)
    model = self._add_handlers_to_model(handlers=handlers, reparam=reparam)
    guide = numpyro.infer.autoguide.AutoDelta(model)
    optim = svi_kwargs.pop("optim", numpyro.optim.Minimize())
    loss = svi_kwargs.pop("loss", numpyro.infer.Trace_ELBO())
    map_svi = SVI(model, guide, optim, loss=loss, **svi_kwargs)
    rng_key, self._rng_key = random.split(self._rng_key)
    map_result = map_svi.run(rng_key, num_steps, self.n, nu=self.nu, nu_err=self.nu_err)
    self._map_loss = map_result.losses
    self._map_guide = map_svi.guide
    self._map_params = map_result.params

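# Usage sketch for find_map (hypothetical: `obj` is an instance of the
# enclosing class, with self.n, self.nu, self.nu_err already populated):
obj.find_map(num_steps=5000, svi_kwargs={"optim": numpyro.optim.Adam(step_size=1e-2)})
map_params = obj._map_params
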
def test_init_to_scalar_value():
    def model():
        numpyro.sample("x", dist.Normal(0, 1))

    guide = AutoDiagonalNormal(model, init_loc_fn=init_to_value(values={"x": 1.0}))
    svi = SVI(model, guide, optim.Adam(1.0), Trace_ELBO())
    svi.init(random.PRNGKey(0))

def test_beta_bernoulli(elbo):
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    adam = optax.adam(0.05)
    svi = SVI(model, guide, adam, elbo)
    svi_state = svi.init(random.PRNGKey(1), data)
    # the unconstrained value stored by the optimizer is log(1.0) = 0.0
    assert_allclose(svi.optim.get_params(svi_state.optim_state)["alpha_q"], 0.0)

    def body_fn(i, val):
        svi_state, _ = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    assert_allclose(
        params["alpha_q"] / (params["alpha_q"] + params["beta_q"]),
        0.8,
        atol=0.05,
        rtol=0.05,
    )

def main(args):
    encoder_nn = encoder(args.hidden_dim, args.z_dim)
    decoder_nn = decoder(args.hidden_dim, 28 * 28)
    adam = optim.Adam(args.learning_rate)
    svi = SVI(model, guide, adam, ELBO(), hidden_dim=args.hidden_dim, z_dim=args.z_dim)
    rng_key = PRNGKey(0)

    train_init, train_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='train')
    test_init, test_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='test')
    num_train, train_idx = train_init()

    rng_key, rng_key_binarize, rng_key_init = random.split(rng_key, 3)
    sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])
    svi_state = svi.init(rng_key_init, sample_batch)

    @jit
    def epoch_train(svi_state, rng_key):
        def body_fn(i, val):
            loss_sum, svi_state = val
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
            svi_state, loss = svi.update(svi_state, batch)
            loss_sum += loss
            return loss_sum, svi_state

        return lax.fori_loop(0, num_train, body_fn, (0., svi_state))

    @jit
    def eval_test(svi_state, rng_key):
        def body_fun(i, loss_sum):
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
            # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
            loss = svi.evaluate(svi_state, batch) / len(batch)
            loss_sum += loss
            return loss_sum

        loss = lax.fori_loop(0, num_test, body_fun, 0.)
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch, rng_key):
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR, 'original_epoch={}.png'.format(epoch)), img, cmap='gray')
        rng_key_binarize, rng_key_sample = random.split(rng_key)
        test_sample = binarize(rng_key_binarize, img)
        params = svi.get_params(svi_state)
        z_mean, z_var = encoder_nn[1](params['encoder$params'], test_sample.reshape([1, -1]))
        z = dist.Normal(z_mean, z_var).sample(rng_key_sample)
        img_loc = decoder_nn[1](params['decoder$params'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch)), img_loc, cmap='gray')

    for i in range(args.num_epochs):
        rng_key, rng_key_train = random.split(rng_key)
        t_start = time.time()
        num_train, train_idx = train_init()
        _, svi_state = epoch_train(svi_state, rng_key_train)
        rng_key, rng_key_test, rng_key_reconstruct = random.split(rng_key, 3)
        num_test, test_idx = test_init()
        test_loss = eval_test(svi_state, rng_key_test)
        reconstruct_img(i, rng_key_reconstruct)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss, time.time() - t_start))

def fit_advi(model, num_iter, learning_rate=0.01, seed=0):
    """Automatic Differentiation Variational Inference using a Normal
    variational distribution with a diagonal covariance matrix.

    Args:
        model: a NumPyro model function
        num_iter: number of iterations of gradient descent (Adam)
        learning_rate: the step size for the Adam algorithm (default: {0.01})
        seed: random seed (default: {0})

    Returns:
        a set of results of type ADVIResults
    """
    rng_key = random.PRNGKey(seed)
    adam = Adam(learning_rate)
    # Automatically create a variational distribution (aka "guide" in Pyro's terminology)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(rng_key)

    # Run optimization
    last_state, losses = lax.scan(lambda state, i: svi.update(state), svi_state, np.zeros(num_iter))
    results = ADVIResults(svi=svi, guide=guide, state=last_state, losses=losses)
    return results

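# Usage sketch for fit_advi above (`my_model` is a hypothetical NumPyro
# model; sampling from the fitted guide uses the final state and params):
results = fit_advi(my_model, num_iter=2000, learning_rate=0.01, seed=0)
params = results.svi.get_params(results.state)
posterior = results.guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
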
def test_beta_bernoulli(auto_class):
    data = np.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = numpyro.sample('beta', dist.Beta(np.ones(2), np.ones(2)))
        numpyro.sample('obs', dist.Bernoulli(f), obs=data)

    adam = optim.Adam(0.01)
    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 3000, body_fn, svi_state)
    params = svi.get_params(svi_state)

    true_coefs = (np.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    assert_allclose(np.mean(posterior_samples['beta'], 0), true_coefs, atol=0.04)

def test_logistic_regression(auto_class):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data, labels)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data, labels)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    if auto_class not in (AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median['coefs'], true_coefs, rtol=0.1)
        # test .quantiles method
        median = guide.quantiles(params, [0.2, 0.5])
        assert_allclose(median['coefs'][1], true_coefs, rtol=0.1)

    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    assert_allclose(jnp.mean(posterior_samples['coefs'], 0), true_coefs, rtol=0.1)

def test_predictive_with_guide():
    data = jnp.array([1] * 8 + [0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        with numpyro.plate("plate", 10):
            numpyro.deterministic("beta_sq", f**2)
            numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, _ = svi.update(val, data)
        return svi_state

    svi_state = lax.fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    predictive = Predictive(model, guide=guide, params=params, num_samples=1000)(
        random.PRNGKey(2), data=None)
    assert predictive["beta_sq"].shape == (1000,)
    obs_pred = predictive["obs"].astype(np.float32)
    assert_allclose(jnp.mean(obs_pred), 0.8, atol=0.05)

def main(args):
    # Generate some data.
    data = random.normal(PRNGKey(0), shape=(100,)) + 3.0

    # Construct an SVI object so we can do variational inference on our
    # model/guide pair.
    adam = optim.Adam(args.learning_rate)
    svi = SVI(model, guide, adam, ELBO(num_particles=100))
    svi_state = svi.init(PRNGKey(0), data)

    # Training loop
    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, args.num_steps, body_fn, svi_state)

    # Report the final values of the variational parameters
    # in the guide after training.
    params = svi.get_params(svi_state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # For this simple (conjugate) model we know the exact posterior. In
    # particular we know that the variational distribution should be
    # centered near 3.0. So let's check this explicitly.
    assert np.abs(params["guide_loc"] - 3.0) < 0.1

def test_autoguide_deterministic(auto_class):
    def model(y=None):
        n = y.size if y is not None else 1
        mu = numpyro.sample("mu", dist.Normal(0, 5))
        sigma = numpyro.param("sigma", 1, constraint=constraints.positive)
        y = numpyro.sample("y", dist.Normal(mu, sigma).expand((n,)), obs=y)
        numpyro.deterministic("z", (y - mu) / sigma)

    mu, sigma = 2, 3
    y = mu + sigma * random.normal(random.PRNGKey(0), shape=(300,))
    y_train = y[:200]
    y_test = y[200:]

    guide = auto_class(model)
    optimiser = numpyro.optim.Adam(step_size=0.01)
    svi = SVI(model, guide, optimiser, Trace_ELBO())
    params, losses = svi.run(random.PRNGKey(0), num_steps=500, y=y_train)

    posterior_samples = guide.sample_posterior(
        random.PRNGKey(0), params, sample_shape=(1000,)
    )
    predictive = Predictive(model, posterior_samples, params=params)
    predictive_samples = predictive(random.PRNGKey(0), y_test)

    assert predictive_samples["y"].shape == (1000, 100)
    assert predictive_samples["z"].shape == (1000, 100)
    assert_allclose(
        (predictive_samples["y"] - posterior_samples["mu"][..., None]) / params["sigma"],
        predictive_samples["z"],
        atol=0.05,
    )

def test_beta_bernoulli(auto_class):
    data = jnp.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = numpyro.sample("beta", dist.Beta(jnp.ones(2), jnp.ones(2)))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    adam = optim.Adam(0.01)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 3000, body_fn, svi_state)
    params = svi.get_params(svi_state)

    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(
        random.PRNGKey(1), params, sample_shape=(1000,)
    )
    assert_allclose(jnp.mean(posterior_samples["beta"], 0), true_coefs, atol=0.05)

    # Predictive can be instantiated from posterior samples...
    predictive = Predictive(model, posterior_samples=posterior_samples)
    predictive_samples = predictive(random.PRNGKey(1), None)
    assert predictive_samples["obs"].shape == (1000, 2)

    # ... or from the guide + params
    predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
    predictive_samples = predictive(random.PRNGKey(1), None)
    assert predictive_samples["obs"].shape == (1000, 2)

def test_run(progress_bar):
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param(
            "alpha_q", lambda key: random.normal(key), constraint=constraints.positive
        )
        beta_q = numpyro.param(
            "beta_q",
            lambda key: random.exponential(key),
            constraint=constraints.positive,
        )
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.05), Trace_ELBO())
    params, losses = svi.run(random.PRNGKey(1), 1000, data, progress_bar=progress_bar)
    assert losses.shape == (1000,)
    assert_allclose(
        params["alpha_q"] / (params["alpha_q"] + params["beta_q"]),
        0.8,
        atol=0.05,
        rtol=0.05,
    )

def test_uniform_normal():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            loc = numpyro.sample(
                "loc",
                dist.TransformedDistribution(
                    dist.Uniform(0, 1), transforms.AffineTransform(0, alpha)
                ),
            )
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    median = guide.median(params)
    assert_allclose(median["loc"], true_coef, rtol=0.05)
    # test .quantiles method
    median = guide.quantiles(params, [0.2, 0.5])
    assert_allclose(median["loc"][1], true_coef, rtol=0.1)

def __init__(
    self,
    model: Model,
    guide: Guide,
    loss: Trace_ELBO = Trace_ELBO(num_particles=1),
    optimizer: optim.optimizers.optimizer = optim.ClippedAdam,
    lr: float = 0.001,
    lrd: float = 1.0,
    rng_key: int = 254,
    num_epochs: int = 30000,
    num_samples: int = 1000,
    log_func=_print_consumer,
    log_freq=1000,
    to_numpy: bool = True,
):
    self.model = model
    self.guide = guide
    self.loss = loss
    # exponentially decayed learning rate: lr * lrd**step
    self.optimizer = optimizer(step_size=lambda x: lr * lrd**x)
    self.rng_key = random.PRNGKey(rng_key)
    self.svi = SVI(self.model, self.guide, self.optimizer, loss=self.loss)
    self.init_state = None
    self.log_func = log_func
    self.log_freq = log_freq
    self.num_epochs = num_epochs
    self.num_samples = num_samples
    # note: self.loss is reset here; it presumably holds the loss history after fitting
    self.loss = None
    self.to_numpy = to_numpy

def test_elbo_dynamic_support():
    x_prior = dist.Uniform(0, 5)
    x_unconstrained = 2.

    def model():
        numpyro.sample('x', x_prior)

    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            return substitute(
                super(_AutoGuide, self).__call__,
                {'_auto_latent': x_unconstrained})(*args, **kwargs)

    adam = optim.Adam(0.01)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)

    guide_log_prob = dist.Normal(
        guide._init_latent, guide._init_scale).log_prob(x_unconstrained).sum()
    transform = transforms.biject_to(constraints.interval(0, 5))
    x = transform(x_unconstrained)
    logdet = transform.log_abs_det_jacobian(x_unconstrained, x)
    model_log_prob = x_prior.log_prob(x) + logdet
    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)

def __init__(
    self,
    model: Model,
    guide: Guide,
    loss: Trace_ELBO = Trace_ELBO(num_particles=1),
    optimizer: optim.optimizers.optimizer = optim.Adam,
    lr: float = 0.001,
    rng_key: int = 254,
    num_epochs: int = 100000,
    num_samples: int = 5000,
    log_func=print,
    log_freq=0,
):
    self.model = model
    self.guide = guide
    self.loss = loss
    self.optimizer = optimizer(step_size=lr)
    self.rng_key = random.PRNGKey(rng_key)
    self.svi = SVI(self.model, self.guide, self.optimizer, loss=self.loss)
    self.init_state = None
    self.log_func = log_func
    self.log_freq = log_freq
    self.num_epochs = num_epochs
    self.num_samples = num_samples
    # note: self.loss is reset here; it presumably holds the loss history after fitting
    self.loss = None

def test_uniform_normal():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(rng_key_init, data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    median = guide.median(params)
    assert_allclose(median['loc'], true_coef, rtol=0.05)
    # test .quantiles method
    median = guide.quantiles(params, [0.2, 0.5])
    assert_allclose(median['loc'][1], true_coef, rtol=0.1)

def test_mutable_state(stable_update, num_particles, elbo):
    def model():
        x = numpyro.sample("x", dist.Normal(-1, 1))
        numpyro_mutable("x1p", x + 1)

    def guide():
        loc = numpyro.param("loc", 0.0)
        p = numpyro_mutable("loc1p", {"value": None})
        # we can modify the content of `p` if it is a dict
        p["value"] = loc + 2
        numpyro.sample("x", dist.Normal(loc, 0.1))

    svi = SVI(model, guide, optim.Adam(0.1), elbo(num_particles=num_particles))
    if num_particles > 1:
        with pytest.raises(ValueError, match="mutable state"):
            svi_result = svi.run(random.PRNGKey(0), 1000, stable_update=stable_update)
        return

    svi_result = svi.run(random.PRNGKey(0), 1000, stable_update=stable_update)
    params = svi_result.params
    mutable_state = svi_result.state.mutable_state
    assert set(mutable_state) == {"x1p", "loc1p"}
    assert_allclose(mutable_state["loc1p"]["value"], params["loc"] + 2, atol=0.1)
    # here, the initial loc has value 0., hence x1p will have an init value near 1;
    # it won't be updated during the SVI run because it is not a mutable state
    assert_allclose(mutable_state["x1p"], 1.0, atol=0.2)

def test_predictive_with_guide():
    data = jnp.array([1] * 8 + [0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        with numpyro.plate("plate", 10):
            numpyro.deterministic("beta_sq", f**2)
            numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(1), 3000, data)
    params = svi_result.params
    predictive = Predictive(model, guide=guide, params=params, num_samples=1000)(
        random.PRNGKey(2), data=None)
    assert predictive["beta_sq"].shape == (1000,)
    obs_pred = predictive["obs"].astype(np.float32)
    assert_allclose(jnp.mean(obs_pred), 0.8, atol=0.05)

def run_inference(model, inputs, method=None):
    if method is None:
        # NUTS
        num_samples = 5000
        logger.info('NUTS sampling')
        kernel = NUTS(model)
        mcmc = MCMC(kernel, num_warmup=300, num_samples=num_samples)
        rng_key = random.PRNGKey(0)
        mcmc.run(rng_key, **inputs, extra_fields=('potential_energy',))
        logger.info('MCMC summary for: {}'.format(model.__name__))
        mcmc.print_summary(exclude_deterministic=False)
        samples = mcmc.get_samples()
    else:
        # SVI
        logger.info('Guide generation...')
        rng_key = random.PRNGKey(0)
        guide = AutoDiagonalNormal(model=model)
        logger.info('Optimizer generation...')
        optim = Adam(0.05)
        logger.info('SVI generation...')
        svi = SVI(model, guide, optim, AutoContinuousELBO(), **inputs)
        init_state = svi.init(rng_key)
        logger.info('Scan...')
        state, loss = lax.scan(lambda x, i: svi.update(x), init_state, np.zeros(2000))
        params = svi.get_params(state)
        samples = guide.sample_posterior(random.PRNGKey(1), params, (1000,))
        logger.info('SVI summary for: {}'.format(model.__name__))
        numpyro.diagnostics.print_summary(samples, prob=0.90, group_by_chain=False)
    return samples

def test_run_with_small_num_steps(num_steps):
    def model():
        pass

    def guide():
        pass

    svi = SVI(model, guide, optim.Adam(1), Trace_ELBO())
    svi.run(random.PRNGKey(0), num_steps)

def test_autocontinuous_local_error():
    def model():
        with numpyro.plate("N", 10, subsample_size=4):
            numpyro.sample("x", dist.Normal(0, 1))

    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim.Adam(1.0), Trace_ELBO())
    with pytest.raises(ValueError, match="local latent variables"):
        svi.init(random.PRNGKey(0))

def test_subsample_guide(auto_class):
    # The model is adapted from tutorial/source/easyguide.ipynb
    def model(batch, subsample, full_size):
        drift = numpyro.sample("drift", dist.LogNormal(-1, 0.5))
        with handlers.substitute(data={"data": subsample}):
            plate = numpyro.plate("data", full_size, subsample_size=len(subsample))
        assert plate.size == 50

        def transition_fn(z_prev, y_curr):
            with plate:
                z_curr = numpyro.sample("state", dist.Normal(z_prev, drift))
                y_curr = numpyro.sample("obs", dist.Bernoulli(logits=z_curr), obs=y_curr)
            return z_curr, y_curr

        _, result = scan(transition_fn, jnp.zeros(len(subsample)), batch, length=num_time_steps)
        return result

    def create_plates(batch, subsample, full_size):
        with handlers.substitute(data={"data": subsample}):
            return numpyro.plate("data", full_size, subsample_size=subsample.shape[0])

    guide = auto_class(model, create_plates=create_plates)

    full_size = 50
    batch_size = 20
    num_time_steps = 8
    with handlers.seed(rng_seed=0):
        data = model(None, jnp.arange(full_size), full_size)
    assert data.shape == (num_time_steps, full_size)

    svi = SVI(model, guide, optim.Adam(0.02), Trace_ELBO())
    svi_state = svi.init(
        random.PRNGKey(0),
        data[:, :batch_size],
        jnp.arange(batch_size),
        full_size=full_size,
    )
    update_fn = jit(svi.update, static_argnums=(3,))

    for epoch in range(2):
        beg = 0
        while beg < full_size:
            end = min(full_size, beg + batch_size)
            subsample = jnp.arange(beg, end)
            batch = data[:, beg:end]
            beg = end
            svi_state, loss = update_fn(svi_state, batch, subsample, full_size)
