# Assumed imports for this excerpt; the original module's header may differ.
import pytest
import torch
from torch.autograd import grad
from torch.distributions import kl_divergence

import funsor
from funsor import ops

from pyroapi import distributions as dist
from pyroapi import handlers, infer, pyro, pyro_backend

# Hypothetical parametrization grid for this excerpt; the original suite may
# exercise a different set of backends.
BACKENDS = ["pyro", "funsor"]


def _check_loss_and_grads(expected_loss, actual_loss, atol=1e-4, rtol=1e-4):
    # Copied from Pyro. First compare the losses themselves, then the
    # gradients with respect to every unconstrained parameter in the store.
    expected_loss, actual_loss = (
        funsor.to_data(expected_loss),
        funsor.to_data(actual_loss),
    )
    assert ops.allclose(actual_loss, expected_loss, atol=atol, rtol=rtol)

    names = pyro.get_param_store().keys()
    params = [funsor.to_data(pyro.param(name)).unconstrained() for name in names]
    actual_grads = grad(actual_loss, params, allow_unused=True, retain_graph=True)
    expected_grads = grad(expected_loss, params, allow_unused=True, retain_graph=True)
    for name, actual_grad, expected_grad in zip(names, actual_grads, expected_grads):
        # Parameters that do not influence a loss yield None gradients;
        # skip those rather than failing the comparison.
        if actual_grad is None or expected_grad is None:
            continue
        assert ops.allclose(actual_grad, expected_grad, atol=atol, rtol=rtol)
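# A minimal, hypothetical usage sketch for the helper above (names and values
# here are illustrative, not from the original suite): both losses must carry
# autograd graphs over the same registered parameter.
def _example_check_loss_and_grads():
    with pyro_backend("pyro"):
        pyro.get_param_store().clear()
        q = pyro.param("q", torch.tensor(0.5, requires_grad=True))
        expected_loss = (q - 1.0) ** 2
        actual_loss = q * q - 2.0 * q + 1.0  # same function, different graph
        _check_loss_and_grads(expected_loss, actual_loss)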
@pytest.mark.parametrize("backend", BACKENDS)
@pytest.mark.parametrize("outer_dim", [1, 2])  # grids are illustrative
@pytest.mark.parametrize("inner_dim", [1, 3])
def test_elbo_plate_plate(backend, outer_dim, inner_dim):
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        num_particles = 1
        q = pyro.param("q", torch.tensor([0.75, 0.25], requires_grad=True))
        p = 0.2693204236205713  # for which kl(Categorical(q), Categorical(p)) = 0.5
        p = torch.tensor([p, 1 - p])

        def model():
            d = dist.Categorical(p)
            context1 = pyro.plate("outer", outer_dim, dim=-1)
            context2 = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d)
            with context1:
                pyro.sample("x", d)
            with context2:
                pyro.sample("y", d)
            with context1, context2:
                pyro.sample("z", d)

        def guide():
            d = dist.Categorical(pyro.param("q"))
            context1 = pyro.plate("outer", outer_dim, dim=-1)
            context2 = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d, infer={"enumerate": "parallel"})
            with context1:
                pyro.sample("x", d, infer={"enumerate": "parallel"})
            with context2:
                pyro.sample("y", d, infer={"enumerate": "parallel"})
            with context1, context2:
                pyro.sample("z", d, infer={"enumerate": "parallel"})

        # With full enumeration the ELBO is exact: one KL term per sample site,
        # i.e. 1 (w) + outer_dim (x) + inner_dim (y) + outer_dim * inner_dim (z)
        # copies of kl(Categorical(q), Categorical(p)).
        kl_node = kl_divergence(
            torch.distributions.Categorical(funsor.to_data(q)),
            torch.distributions.Categorical(funsor.to_data(p)),
        )
        kl = (1 + outer_dim + inner_dim + outer_dim * inner_dim) * kl_node
        expected_loss = kl
        expected_grad = grad(kl, [funsor.to_data(q)])[0]

        elbo = infer.TraceEnum_ELBO(
            num_particles=num_particles,
            vectorize_particles=True,
            strict_enumeration_warning=True,
        )
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        actual_loss = funsor.to_data(elbo(model, guide))
        actual_loss.backward()
        actual_grad = funsor.to_data(pyro.param("q")).grad

        assert ops.allclose(actual_loss, expected_loss, atol=1e-5)
        assert ops.allclose(actual_grad, expected_grad, atol=1e-5)
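# Sanity check for the magic constant above, as a standalone sketch (not part
# of the original test): kl(Categorical([0.75, 0.25]), Categorical([p, 1 - p]))
# should come out to 0.5, so expected_loss counts 0.5 nats per sample site.
def _verify_kl_constant():
    q = torch.tensor([0.75, 0.25], dtype=torch.float64)
    p = 0.2693204236205713
    kl = kl_divergence(
        torch.distributions.Categorical(q),
        torch.distributions.Categorical(torch.tensor([p, 1 - p], dtype=torch.float64)),
    )
    assert abs(kl.item() - 0.5) < 1e-9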
@pytest.mark.parametrize("backend", BACKENDS)
def test_rng_seed(backend):
    def model():
        return pyro.sample("x", dist.Normal(0, 1))

    with pyro_backend(backend):
        # Re-seeding with the same value must reproduce the same draw.
        with handlers.seed(rng_seed=0):
            expected = model()
        with handlers.seed(rng_seed=0):
            actual = model()
        assert ops.allclose(actual, expected)
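# A small companion sketch (not in the original suite): handlers.seed sets the
# backend RNG state on entry, so distinct seeds should (almost surely) yield
# distinct draws, complementing the reproducibility check above.
def _seed_demo():
    with pyro_backend("pyro"):
        with handlers.seed(rng_seed=0):
            a = pyro.sample("x", dist.Normal(0, 1))
        with handlers.seed(rng_seed=1):
            b = pyro.sample("x", dist.Normal(0, 1))
        return a, b  # almost surely a != b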
@pytest.mark.parametrize("backend", BACKENDS)
@pytest.mark.parametrize("jit", [False, True])
def test_local_param_ok(backend, jit):
    data = torch.randn(10)

    def model():
        locs = pyro.param("locs", torch.tensor([-1.0, 0.0, 1.0]))
        with pyro.plate("plate", len(data), dim=-1):
            x = pyro.sample("x", dist.Categorical(torch.ones(3) / 3))
            pyro.sample("obs", dist.Normal(locs[x], 1.0), obs=data)

    def guide():
        with pyro.plate("plate", len(data), dim=-1):
            # One local distribution per datum; event_dim=1 marks the
            # rightmost dimension as the event (the probability vector).
            p = pyro.param("p", torch.ones(len(data), 3) / 3, event_dim=1)
            pyro.sample("x", dist.Categorical(p))
        return p

    with pyro_backend(backend):
        Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        elbo = Elbo(ignore_jit_warnings=True)
        assert_ok(model, guide, elbo)  # helper defined elsewhere in this module

        # Check that pyro.param() can be called without init_value.
        expected = guide()
        actual = pyro.param("p")
        assert ops.allclose(actual, expected)
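# `assert_ok` is defined elsewhere in the original module. A minimal sketch of
# what such a helper might do, assuming pyroapi's optim/SVI surface
# (hypothetical; not the original implementation): run a couple of SVI steps
# and let any error or warning surface as a test failure.
def _assert_ok_sketch(model, guide, elbo, *args, **kwargs):
    from pyroapi import optim

    pyro.get_param_store().clear()
    svi = infer.SVI(model, guide, optim.Adam({"lr": 1e-6}), elbo)
    for _ in range(2):
        svi.step(*args, **kwargs)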