# Imports assumed by the code below. The module paths follow the pyro-api and
# funsor test suites; adjust them to the local layout if they differ (in
# particular, `interpretation`/`monte_carlo` moved in later funsor releases).
import warnings

import pytest
import torch
from torch.autograd import grad
from torch.distributions import constraints, kl_divergence

import funsor
from funsor.interpreter import interpretation
from funsor.montecarlo import monte_carlo
from funsor.testing import assert_close
from pyroapi import distributions as dist
from pyroapi import infer, optim, pyro, pyro_backend


def assert_ok(model, guide, elbo, *args, **kwargs):
    """
    Assert that inference works without warnings or errors.
    """
    pyro.get_param_store().clear()
    adam = optim.Adam({"lr": 1e-6})
    inference = infer.SVI(model, guide, adam, elbo)
    for i in range(2):
        inference.step(*args, **kwargs)


def assert_error(model, guide, elbo, match=None):
    """
    Assert that inference fails with an error.
    """
    pyro.get_param_store().clear()
    adam = optim.Adam({"lr": 1e-6})
    inference = infer.SVI(model, guide, adam, elbo)
    with pytest.raises(
        (NotImplementedError, UserWarning, KeyError, ValueError, RuntimeError),
        match=match,
    ):
        inference.step()


def assert_warning(model, guide, elbo):
    """
    Assert that inference works but with a warning.
    """
    pyro.get_param_store().clear()
    adam = optim.Adam({"lr": 1e-6})
    inference = infer.SVI(model, guide, adam, elbo)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        inference.step()
        assert len(w), "No warnings were raised"
        for warning in w:
            print(warning)
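

# A minimal usage sketch for the helpers above, with a hypothetical
# model/guide pair (not part of the suite): assert_ok runs two SVI steps and
# fails on any exception, while assert_error expects step() itself to raise.
def _example_assert_usage():
    def model():
        pyro.sample("x", dist.Normal(0.0, 1.0))

    def guide():
        loc = pyro.param("loc", torch.tensor(0.0))
        pyro.sample("x", dist.Normal(loc, 1.0))

    with pyro_backend("pyro"):
        assert_ok(model, guide, infer.Trace_ELBO())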


def main(args):
    # Define a basic model with a single Normal latent random variable `loc`
    # and a batch of Normally distributed observations.
    def model(data):
        loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
        with pyro.plate("data", len(data), dim=-1):
            pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    # Define a guide (i.e. variational distribution) with a Normal
    # distribution over the latent random variable `loc`.
    def guide(data):
        guide_loc = pyro.param("guide_loc", torch.tensor(0.0))
        guide_scale = pyro.param(
            "guide_scale", torch.tensor(1.0), constraint=constraints.positive
        )
        pyro.sample("loc", dist.Normal(guide_loc, guide_scale))

    # Generate some data.
    torch.manual_seed(0)
    data = torch.randn(100) + 3.0

    # Because the API in minipyro matches that of Pyro proper, the same
    # training code works with any backend that implements the generic
    # Pyro interface.
    with pyro_backend(args.backend), interpretation(monte_carlo):
        # Construct an SVI object so we can do variational inference on our
        # model/guide pair.
        Elbo = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO
        elbo = Elbo()
        adam = optim.Adam({"lr": args.learning_rate})
        svi = infer.SVI(model, guide, adam, elbo)

        # Basic training loop.
        pyro.get_param_store().clear()
        for step in range(args.num_steps):
            loss = svi.step(data)
            if args.verbose and step % 100 == 0:
                print("step {} loss = {}".format(step, loss))

        # Report the final values of the variational parameters
        # in the guide after training.
        if args.verbose:
            for name in pyro.get_param_store():
                value = pyro.param(name).data
                print("{} = {}".format(name, value.detach().cpu().numpy()))

        # For this simple (conjugate) model we know the exact posterior. In
        # particular we know that the variational distribution should be
        # centered near 3.0. So let's check this explicitly.
        assert (pyro.param("guide_loc") - 3.0).abs() < 0.1
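

# A sketch of the command-line wrapper that `main(args)` assumes. The
# attribute names (backend, jit, learning_rate, num_steps, verbose) are the
# ones read above; the flag spellings and defaults here are assumptions.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Minipyro demo")
    parser.add_argument("-b", "--backend", default="funsor")
    parser.add_argument("-n", "--num-steps", default=1001, type=int)
    parser.add_argument("-lr", "--learning-rate", default=0.02, type=float)
    parser.add_argument("--jit", action="store_true")
    parser.add_argument("-v", "--verbose", action="store_true")
    main(parser.parse_args())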


def test_elbo_plate_plate(backend, outer_dim, inner_dim):
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        num_particles = 1
        q = pyro.param("q", torch.tensor([0.75, 0.25], requires_grad=True))
        p = 0.2693204236205713  # for which kl(Categorical(q), Categorical(p)) = 0.5
        p = torch.tensor([p, 1 - p])

        def model():
            d = dist.Categorical(p)
            context1 = pyro.plate("outer", outer_dim, dim=-1)
            context2 = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d)
            with context1:
                pyro.sample("x", d)
            with context2:
                pyro.sample("y", d)
            with context1, context2:
                pyro.sample("z", d)

        def guide():
            d = dist.Categorical(pyro.param("q"))
            context1 = pyro.plate("outer", outer_dim, dim=-1)
            context2 = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d, infer={"enumerate": "parallel"})
            with context1:
                pyro.sample("x", d, infer={"enumerate": "parallel"})
            with context2:
                pyro.sample("y", d, infer={"enumerate": "parallel"})
            with context1, context2:
                pyro.sample("z", d, infer={"enumerate": "parallel"})

        kl_node = kl_divergence(
            torch.distributions.Categorical(funsor.to_data(q)),
            torch.distributions.Categorical(funsor.to_data(p)),
        )
        kl = (1 + outer_dim + inner_dim + outer_dim * inner_dim) * kl_node
        expected_loss = kl
        expected_grad = grad(kl, [funsor.to_data(q)])[0]

        elbo = infer.TraceEnum_ELBO(
            num_particles=num_particles,
            vectorize_particles=True,
            strict_enumeration_warning=True,
        )
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        actual_loss = funsor.to_data(elbo(model, guide))
        actual_loss.backward()
        actual_grad = funsor.to_data(pyro.param("q")).grad

        assert_close(actual_loss, expected_loss, atol=1e-5)
        assert_close(actual_grad, expected_grad, atol=1e-5)
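

# Worked example of the expected loss in test_elbo_plate_plate: the per-site
# KL appears once for "w", outer_dim times for "x", inner_dim times for "y",
# and outer_dim * inner_dim times for "z". With outer_dim=2 and inner_dim=3
# that is 1 + 2 + 3 + 6 = 12 copies, and since kl(Categorical(q),
# Categorical(p)) = 0.5 by construction, expected_loss = 12 * 0.5 = 6.0.
# A hypothetical direct driver (the real suite presumably parametrizes
# backends and plate sizes via pytest):
#
#     test_elbo_plate_plate("pyro", outer_dim=2, inner_dim=3)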


def test_optimizer(backend, optim_name, jit):
    def model(data):
        p = pyro.param("p", torch.tensor(0.5))
        pyro.sample("x", dist.Bernoulli(p), obs=data)

    def guide(data):
        pass

    data = torch.tensor(0.0)
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        elbo = Elbo(ignore_jit_warnings=True)
        optimizer = getattr(optim, optim_name)({"lr": 1e-6})
        inference = infer.SVI(model, guide, optimizer, elbo)
        for i in range(2):
            inference.step(data)
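

# A hedged sketch of driving test_optimizer directly. "Adam" exists in every
# pyroapi backend's optim module; other optimizer names, and whether jit is
# supported, vary by backend.
def _example_test_optimizer():
    test_optimizer("pyro", "Adam", jit=False)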


def _check_loss_and_grads(expected_loss, actual_loss, atol=1e-4, rtol=1e-4):
    # Copied from Pyro's test utilities, adapted to unwrap funsor values.
    expected_loss, actual_loss = (
        funsor.to_data(expected_loss),
        funsor.to_data(actual_loss),
    )
    assert_close(actual_loss, expected_loss, atol=atol, rtol=rtol)

    names = pyro.get_param_store().keys()
    params = []
    for name in names:
        params.append(funsor.to_data(pyro.param(name)).unconstrained())
    actual_grads = grad(actual_loss, params, allow_unused=True, retain_graph=True)
    expected_grads = grad(expected_loss, params, allow_unused=True, retain_graph=True)
    for name, actual_grad, expected_grad in zip(names, actual_grads, expected_grads):
        if actual_grad is None or expected_grad is None:
            continue
        assert_close(actual_grad, expected_grad, atol=atol, rtol=rtol)
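

# A minimal usage sketch for _check_loss_and_grads (hypothetical losses, not
# part of the suite): both losses must be differentiable with respect to the
# unconstrained parameters in the param store, since gradients are taken with
# retain_graph=True.
def _example_check_loss_and_grads():
    with pyro_backend("pyro"):
        pyro.get_param_store().clear()
        p = pyro.param("p", torch.tensor(0.5))
        # Two algebraically identical losses, so values and grads must match.
        _check_loss_and_grads((p - 1.0) ** 2, p * p - 2.0 * p + 1.0)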


def build_svi(model, guide, elbo):
    pyro.get_param_store().clear()
    adam = optim.Adam({"lr": 1e-6})
    return infer.SVI(model, guide, adam, elbo)
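

# Usage sketch: build_svi factors out the boilerplate shared by the assert_*
# helpers above. Callers supply a model/guide pair and an ELBO instance, then
# drive the returned SVI object themselves, e.g.
#
#     svi = build_svi(model, guide, infer.Trace_ELBO())
#     for _ in range(2):
#         svi.step(data)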