def assert_ok(model, guide=None, max_plate_nesting=None, **kwargs):
    """
    Assert that enumeration runs identically under the ``pyro`` and
    ``contrib.funsor`` backends.

    Each backend drives its own LIFO queue of partial guide traces through
    ``handlers.queue`` (with ``iter_discrete_escape``/``iter_discrete_extend``).
    For every queued execution, the model is replayed against the guide trace.
    The pruned model traces from both backends are then compared with
    ``_check_traces``.

    :param model: Model callable, invoked as ``model(**kwargs)``.
    :param guide: Guide callable, invoked as ``guide(**kwargs)``. Defaults to
        a guide that does nothing.
    :param int max_plate_nesting: Number of dims reserved for plates. Enumeration
        dims are allocated starting at ``-max_plate_nesting - 1``. Must be
        given despite the ``None`` default.
    :param kwargs: Keyword arguments forwarded to both model and guide.
    :raises ValueError: if ``max_plate_nesting`` is not given.
    """
    if max_plate_nesting is None:
        # The first enumeration dim is derived from this value. Fail clearly
        # instead of with "bad operand type for unary -: 'NoneType'".
        raise ValueError("assert_ok requires max_plate_nesting to be specified")
    first_available_dim = -max_plate_nesting - 1

    with pyro_backend("pyro"):
        pyro.clear_param_store()

    if guide is None:
        guide = lambda **kwargs: None  # noqa: E731

    q_pyro, q_funsor = LifoQueue(), LifoQueue()
    q_pyro.put(Trace())
    q_funsor.put(Trace())

    while not q_pyro.empty() and not q_funsor.empty():
        with pyro_backend("pyro"):
            with handlers.enum(first_available_dim=first_available_dim):
                guide_tr_pyro = handlers.trace(
                    handlers.queue(
                        guide,
                        q_pyro,
                        escape_fn=iter_discrete_escape,
                        extend_fn=iter_discrete_extend,
                    )).get_trace(**kwargs)
                tr_pyro = handlers.trace(
                    handlers.replay(model, trace=guide_tr_pyro)).get_trace(**kwargs)

        with pyro_backend("contrib.funsor"):
            with handlers.enum(first_available_dim=first_available_dim):
                guide_tr_funsor = handlers.trace(
                    handlers.queue(
                        guide,
                        q_funsor,
                        escape_fn=iter_discrete_escape,
                        extend_fn=iter_discrete_extend,
                    )).get_trace(**kwargs)
                tr_funsor = handlers.trace(
                    handlers.replay(model, trace=guide_tr_funsor)).get_trace(**kwargs)

            # make sure all dimensions were cleaned up
            assert _DIM_STACK.local_frame is _DIM_STACK.global_frame
            assert (not _DIM_STACK.global_frame.name_to_dim
                    and not _DIM_STACK.global_frame.dim_to_name)
            assert _DIM_STACK.outermost is None

        tr_pyro = prune_subsample_sites(tr_pyro.copy())
        tr_funsor = prune_subsample_sites(tr_funsor.copy())
        _check_traces(tr_pyro, tr_funsor)

    # The loop stops as soon as either queue is empty. If the backends
    # enumerated a different number of executions, the leftover traces would
    # otherwise be silently skipped.
    assert q_pyro.empty() and q_funsor.empty(), \
        "pyro and contrib.funsor enumerated a different number of executions"
def test_not_implemented(backend):
    """Core primitives resolve on the backend; unknown ones raise NotImplementedError."""
    pytest.importorskip(PACKAGE_NAME[backend])
    with pyro_backend(backend):
        # Looking these attributes up must succeed: they should be implemented.
        for primitive_name in ("sample", "param"):
            getattr(pyro, primitive_name)
        with pytest.raises(NotImplementedError):
            getattr(pyro, "nonexistent_primitive")
def test_nesting():
    """Nested ``pyro.markov`` contexts should alternate the dim allocated to a fresh input."""

    def testing(depth=1, max_depth=4):
        # Each recursion level opens one more markov context inside the
        # previous one, mirroring four hand-nested ``with`` blocks.
        with pyro.markov():
            value = pyro.to_data(
                Tensor(torch.ones(2),
                       OrderedDict([(str(depth), funsor.Bint[2])]),
                       'real'))
            print(depth, value.shape)  # shapes should alternate
            expected_shape = (2,) if depth % 2 == 1 else (2, 1)
            assert value.shape == expected_shape
            if depth < max_depth:
                testing(depth + 1, max_depth)

    with pyro_backend("contrib.funsor"), NamedMessenger(first_available_dim=-1):
        testing()
def test_trace_handler(model, backend):
    """Tracing each registered example model must work on the given backend."""
    pytest.importorskip(PACKAGE_NAME[backend])
    with pyro_backend(backend), handlers.seed(rng_seed=2):
        spec = MODELS[model]()
        fn = spec['model']
        call_args = spec.get('model_args', ())
        call_kwargs = spec.get('model_kwargs', {})
        # should be implemented
        traced = handlers.trace(fn)
        traced.get_trace(*call_args, **call_kwargs)
def test_model_sample(model, backend):
    """Each registered example model can be executed (sampled) on the given backend."""
    pytest.importorskip(PACKAGE_NAME[backend])
    with pyro_backend(backend), handlers.seed(rng_seed=2):
        spec = MODELS[model]()
        spec['model'](*spec.get('model_args', ()), **spec.get('model_kwargs', {}))
def test_model_enumerated_elbo(model, guide, data, history):
    """
    TraceMarkovEnum_ELBO must match TraceEnum_ELBO on a parallel-enumerated model.

    Both the loss and every parameter gradient are compared. The trailing
    ``False``/``True`` argument is forwarded to ``model``/``guide`` by
    ``loss_and_grads``. It presumably selects the vectorized-markov code path
    (TODO confirm against the model definitions).
    """
    pyro.clear_param_store()

    def _snapshot_grads():
        # Copy each parameter's current gradient, so that a later backward
        # pass cannot change the values recorded here.
        return {
            name: None if value.grad is None else value.grad.detach().clone()
            for name, value in pyro.get_param_store().named_parameters()
        }

    with pyro_backend("contrib.funsor"):
        if history > 1:
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")

        model = infer.config_enumerate(model, default="parallel")
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        expected_loss = elbo.loss_and_grads(model, guide, data, history, False)
        # Materialize now. A lazy generator over ``value.grad`` would be read
        # only after the second run, so it would compare each gradient with itself.
        expected_grads = _snapshot_grads()

        # Clear gradients so the second backward pass does not accumulate
        # on top of the first one.
        for _, value in pyro.get_param_store().named_parameters():
            value.grad = None

        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
        actual_loss = vectorized_elbo.loss_and_grads(model, guide, data, history, True)
        actual_grads = _snapshot_grads()

        assert_close(actual_loss, expected_loss)
        assert set(actual_grads) == set(expected_grads)
        for name, expected_grad in expected_grads.items():
            actual_grad = actual_grads[name]
            if expected_grad is None or actual_grad is None:
                # A parameter must be unused by both ELBOs, or by neither.
                assert expected_grad is None and actual_grad is None, name
                continue
            assert_close(actual_grad, expected_grad)
def test_guide_enumerated_elbo(model, guide, data, history):
    """
    TraceMarkovEnum_ELBO must reject guide-side Markov enumeration.

    The reference TraceEnum_ELBO run sits outside ``pytest.raises``, so that
    only the vectorized ELBO can satisfy the expected NotImplementedError.
    The original compared losses and gradients after the raising call. That
    code could never run while this test passes, so it is omitted here.
    Reinstate the comparison (with gradient snapshots) once guide-side support
    lands and this test is inverted.
    """
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        if history > 1:
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")

        # The reference ELBO must run cleanly on this guide.
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        elbo.loss_and_grads(model, guide, data, history, False)

        with pytest.raises(
                NotImplementedError,
                match="TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration"):
            vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
            vectorized_elbo.loss_and_grads(model, guide, data, history, True)
def _guide_from_model(model):
    """
    Derive a guide from ``model``: enumerate its discrete sites in parallel
    and hide the observed sites.

    Falls back to returning ``model`` unchanged if selecting the
    ``contrib.funsor`` backend raises KeyError (this happens during test
    collection when funsor is unavailable).
    """
    def _is_observed(msg):
        return msg.get("is_observed", False)

    try:
        with pyro_backend("contrib.funsor"):
            enumerated = infer.config_enumerate(model, default="parallel")
            return handlers.block(enumerated, _is_observed)
    except KeyError:  # for test collection without funsor
        return model
def test_mcmc_interface(model, backend):
    """A short NUTS/MCMC run plus summary works for each registered example model."""
    pytest.importorskip(PACKAGE_NAME[backend])
    with pyro_backend(backend), handlers.seed(rng_seed=20):
        spec = MODELS[model]()
        kernel = infer.NUTS(model=spec['model'])
        sampler = infer.MCMC(kernel, num_samples=10, warmup_steps=10)
        sampler.run(*spec.get('model_args', ()), **spec.get('model_kwargs', {}))
        sampler.summary()
def test_fresh_inputs_to_funsor():
    """``pyro.to_funsor`` with ``dim_to_name`` creates named inputs of the right sizes."""

    def testing():
        vector = pyro.to_funsor(torch.tensor([0., 1.]), funsor.Real,
                                dim_to_name={-1: "x"})
        assert set(vector.inputs) == {"x"}

        matrix = pyro.to_funsor(torch.ones(2, 3), funsor.Real,
                                dim_to_name={-2: "x", -1: "y"})
        # Input dtypes are the sizes of the named dims.
        assert (matrix.inputs["x"].dtype, matrix.inputs["y"].dtype) == (2, 3)

    with pyro_backend("contrib.funsor"), NamedMessenger():
        testing()
def test_rng_seed(backend):
    """Re-seeding with the same seed reproduces the same sample."""
    def model():
        return pyro.sample("x", dist.Normal(0, 1))

    with pyro_backend(backend):
        draws = []
        for _ in range(2):
            with handlers.seed(rng_seed=0):
                draws.append(model())
    expected, actual = draws
    assert ops.allclose(actual, expected)
def test_generate_data(backend):
    """Running a model with scalar params yields a scalar sample."""
    def model():
        loc = pyro.param("loc", torch.tensor(2.0))
        scale = pyro.param("scale", torch.tensor(1.0))
        return pyro.sample("x", dist.Normal(loc, scale))

    with pyro_backend(backend):
        sample = model().data
        assert sample.shape == ()
def test_register_backend(model):
    """A custom backend registered under "foo" (mapped to minipyro) can trace models."""
    pytest.importorskip("pyro")
    minipyro = "pyro.contrib.minipyro"
    register_backend("foo", {name: minipyro for name in ("infer", "optim", "pyro")})
    with pyro_backend("foo"):
        spec = MODELS[model]()
        fn = spec['model']
        handlers.trace(fn).get_trace(*spec.get('model_args', ()),
                                     **spec.get('model_kwargs', {}))
def test_staggered_fresh():
    """A fresh input named "a", created on every 4th markov step, keeps the same shape."""

    def testing():
        for step in pyro.markov(range(12)):
            if step % 4:
                continue
            fresh = pyro.to_funsor(torch.zeros(2), funsor.Real, dim_to_name={-1: "a"})
            raw = pyro.to_data(fresh)
            assert raw.shape == (2,)
            print("a", raw.shape)
            print("a", fresh.inputs)

    with pyro_backend("contrib.funsor"), NamedMessenger(first_available_dim=-1):
        testing()
def test_nonempty_model_empty_guide_ok(backend, jit):
    """An observed-only model paired with an empty guide runs under (Jit)Trace_ELBO."""
    def model(data):
        loc = pyro.param("loc", torch.tensor(0.0))
        pyro.sample("x", dist.Normal(loc, 1.), obs=data)

    def guide(data):
        """Empty guide: the model has no latent sample sites."""

    data = torch.tensor(2.)

    with pyro_backend(backend):
        if jit:
            elbo_cls = infer.JitTrace_ELBO
        else:
            elbo_cls = infer.Trace_ELBO
        assert_ok(model, guide, elbo_cls(ignore_jit_warnings=True), data)
def test_staggered():
    """A named input "a", converted on every 4th markov step, keeps the same shape."""

    def testing():
        for step in pyro.markov(range(12)):
            if step % 4:
                continue
            raw = pyro.to_data(
                Tensor(torch.zeros(2), OrderedDict([('a', funsor.Bint[2])]), 'real'))
            roundtrip = pyro.to_funsor(raw, funsor.Real)
            assert raw.shape == (2,)
            print('a', raw.shape)
            print('a', roundtrip.inputs)

    with pyro_backend("contrib.funsor"), NamedMessenger(first_available_dim=-1):
        testing()
def test_mean_field_warn(backend):
    """TraceMeanField_ELBO should warn when the guide is not mean-field."""
    def model():
        x_latent = pyro.sample("x", dist.Normal(0., 1.))
        pyro.sample("y", dist.Normal(x_latent, 1.))

    def guide():
        # Samples "y" first and conditions "x" on it, so the guide's sites
        # depend on each other.
        loc = pyro.param("loc", torch.tensor(0.))
        y_latent = pyro.sample("y", dist.Normal(loc, 1.))
        pyro.sample("x", dist.Normal(y_latent, 1.))

    with pyro_backend(backend):
        mean_field_elbo = infer.TraceMeanField_ELBO()
        assert_warning(model, guide, mean_field_elbo)
def test_gaussian_probit_hmm_smoke(exact, jit):
    """
    Smoke test: one SVI step on a probit HMM with a linear-Gaussian state.

    Only checks that ``svi.step`` runs under the ``"funsor"`` backend. No
    value is asserted.

    :param exact: Forwarded as ``infer={"exact": exact}`` on each latent state site.
    :param jit: Selects ``JitTraceEnum_ELBO`` instead of ``TraceEnum_ELBO``.
    """
    def model(data):
        T, N, D = data.shape  # time steps, individuals, features

        # Gaussian initial distribution.
        init_loc = pyro.param("init_loc", torch.zeros(D))
        init_scale = pyro.param("init_scale", 1e-2 * torch.eye(D),
                                constraint=constraints.lower_cholesky)

        # Linear dynamics with Gaussian noise.
        trans_const = pyro.param("trans_const", torch.zeros(D))
        trans_coeff = pyro.param("trans_coeff", torch.eye(D))
        noise = pyro.param("noise", 1e-2 * torch.eye(D),
                           constraint=constraints.lower_cholesky)

        obs_plate = pyro.plate("channel", D, dim=-1)
        with pyro.plate("data", N, dim=-2):
            state = None
            for t in range(T):
                # Transition.
                if t == 0:
                    loc = init_loc
                    scale_tril = init_scale
                else:
                    loc = trans_const + funsor.torch.torch_tensordot(
                        trans_coeff, state, 1)
                    scale_tril = noise
                state = pyro.sample("state_{}".format(t),
                                    dist.MultivariateNormal(loc, scale_tril),
                                    infer={"exact": exact})

                # Factorial probit likelihood model.
                with obs_plate:
                    # NOTE(review): indexing by the plate name "channel"
                    # presumably relies on ``state`` being a funsor under this
                    # backend. Confirm.
                    pyro.sample("obs_{}".format(t),
                                dist.Bernoulli(logits=state["channel"]),
                                obs=data[t])

    def guide(data):
        # Empty guide: no guide sample sites are defined.
        pass

    # Binary observations of shape (T, N, D) = (3, 4, 2).
    data = torch.distributions.Bernoulli(0.5).sample((3, 4, 2))

    with pyro_backend("funsor"):
        Elbo = infer.JitTraceEnum_ELBO if jit else infer.TraceEnum_ELBO
        elbo = Elbo()
        adam = optim.Adam({"lr": 1e-3})
        svi = infer.SVI(model, guide, adam, elbo)
        svi.step(data)
def main(args): funsor.set_backend("torch") # Define a basic model with a single Normal latent random variable `loc` # and a batch of Normally distributed observations. def model(data): loc = pyro.sample("loc", dist.Normal(0., 1.)) with pyro.plate("data", len(data), dim=-1): pyro.sample("obs", dist.Normal(loc, 1.), obs=data) # Define a guide (i.e. variational distribution) with a Normal # distribution over the latent random variable `loc`. def guide(data): guide_loc = pyro.param("guide_loc", torch.tensor(0.)) guide_scale = pyro.param("guide_scale", torch.tensor(1.), constraint=constraints.positive) pyro.sample("loc", dist.Normal(guide_loc, guide_scale)) # Generate some data. torch.manual_seed(0) data = torch.randn(100) + 3.0 # Because the API in minipyro matches that of Pyro proper, # training code works with generic Pyro implementations. with pyro_backend(args.backend), interpretation(MonteCarlo()): # Construct an SVI object so we can do variational inference on our # model/guide pair. Elbo = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO elbo = Elbo() adam = optim.Adam({"lr": args.learning_rate}) svi = infer.SVI(model, guide, adam, elbo) # Basic training loop pyro.get_param_store().clear() for step in range(args.num_steps): loss = svi.step(data) if args.verbose and step % 100 == 0: print("step {} loss = {}".format(step, loss)) # Report the final values of the variational parameters # in the guide after training. if args.verbose: for name in pyro.get_param_store(): value = pyro.param(name).data print("{} = {}".format(name, value.detach().cpu().numpy())) # For this simple (conjugate) model we know the exact posterior. In # particular we know that the variational distribution should be # centered near 3.0. So let's check this explicitly. assert (pyro.param("guide_loc") - 3.0).abs() < 0.1
def test_elbo_plate_plate(backend, outer_dim, inner_dim):
    """
    Check the enumerated ELBO against an analytic KL over crossed plates.

    The model and guide use the same Categorical family at four sites:
    - "w" sits outside all plates.
    - "x" sits in the outer plate.
    - "y" sits in the inner plate.
    - "z" sits in both plates.

    The guide enumerates every site in parallel. The loss must therefore equal
    ``(1 + outer_dim + inner_dim + outer_dim * inner_dim) * KL(q || p)``, and
    its gradient with respect to ``q`` must match autograd of that expression.
    """
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        num_particles = 1
        q = pyro.param("q", torch.tensor([0.75, 0.25], requires_grad=True))
        p = 0.2693204236205713  # for which kl(Categorical(q), Categorical(p)) = 0.5
        p = torch.tensor([p, 1 - p])

        def model():
            d = dist.Categorical(p)
            context1 = pyro.plate("outer", outer_dim, dim=-1)
            context2 = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d)
            with context1:
                pyro.sample("x", d)
            with context2:
                pyro.sample("y", d)
            with context1, context2:
                pyro.sample("z", d)

        def guide():
            d = dist.Categorical(pyro.param("q"))
            context1 = pyro.plate("outer", outer_dim, dim=-1)
            context2 = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d, infer={"enumerate": "parallel"})
            with context1:
                pyro.sample("x", d, infer={"enumerate": "parallel"})
            with context2:
                pyro.sample("y", d, infer={"enumerate": "parallel"})
            with context1, context2:
                pyro.sample("z", d, infer={"enumerate": "parallel"})

        kl_node = kl_divergence(
            torch.distributions.Categorical(funsor.to_data(q)),
            torch.distributions.Categorical(funsor.to_data(p)))
        # One KL term per site copy: 1 for "w", outer_dim for "x",
        # inner_dim for "y", outer_dim * inner_dim for "z".
        kl = (1 + outer_dim + inner_dim + outer_dim * inner_dim) * kl_node
        expected_loss = kl
        expected_grad = grad(kl, [funsor.to_data(q)])[0]

        elbo = infer.TraceEnum_ELBO(num_particles=num_particles,
                                    vectorize_particles=True,
                                    strict_enumeration_warning=True)
        # For the "pyro" backend use the differentiable_loss method. Other
        # backends call the ELBO object directly.
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        actual_loss = funsor.to_data(elbo(model, guide))
        actual_loss.backward()
        actual_grad = funsor.to_data(pyro.param('q')).grad

        assert ops.allclose(actual_loss, expected_loss, atol=1e-5)
        assert ops.allclose(actual_grad, expected_grad, atol=1e-5)
def test_generate_data_plate(backend):
    """
    Sampling inside a plate yields ``num_points`` draws whose mean is close to ``loc``.
    """
    num_points = 1000

    def model(data=None):
        loc = pyro.param("loc", torch.tensor(2.0))
        scale = pyro.param("scale", torch.tensor(1.0))
        # Size the plate with num_points rather than a duplicated literal, so
        # the plate size and the shape/mean checks below cannot drift apart.
        with pyro.plate("data", num_points, dim=-1):
            x = pyro.sample("x", dist.Normal(loc, scale), obs=data)
        return x

    with pyro_backend(backend):
        data = model().data
        assert data.shape == (num_points, )
        mean = data.sum().item() / num_points
        # The standard error of the mean of 1000 N(2, 1) draws is about 0.032,
        # so a tolerance of +/-0.1 is roughly 3 standard errors.
        assert 1.9 <= mean <= 2.1
def test_optimizer(backend, optim_name, jit):
    """Two SVI steps run with the named optimizer under (Jit)Trace_ELBO."""
    def model(data):
        prob = pyro.param("p", torch.tensor(0.5))
        pyro.sample("x", dist.Bernoulli(prob), obs=data)

    def guide(data):
        """Empty guide: the model has no latent sample sites."""

    data = torch.tensor(0.)

    with pyro_backend(backend):
        pyro.get_param_store().clear()
        if jit:
            elbo_cls = infer.JitTrace_ELBO
        else:
            elbo_cls = infer.Trace_ELBO
        elbo = elbo_cls(ignore_jit_warnings=True)
        optimizer_cls = getattr(optim, optim_name)
        svi = infer.SVI(model, guide, optimizer_cls({"lr": 1e-6}), elbo)
        for _ in range(2):
            svi.step(data)
def test_constraints(backend, jit):
    """Constrained params (real, positive, simplex) work inside an ELBO run."""
    data = torch.tensor(0.5)

    def model():
        locs = pyro.param("locs", torch.randn(3), constraint=constraints.real)
        scales = pyro.param("scales", torch.randn(3).exp(),
                            constraint=constraints.positive)
        prior_probs = torch.tensor([0.5, 0.3, 0.2])
        component = pyro.sample("x", dist.Categorical(prior_probs))
        pyro.sample("obs", dist.Normal(locs[component], scales[component]), obs=data)

    def guide():
        posterior_probs = pyro.param("q", torch.randn(3).exp(),
                                     constraint=constraints.simplex)
        pyro.sample("x", dist.Categorical(posterior_probs))

    with pyro_backend(backend):
        if jit:
            elbo_cls = infer.JitTrace_ELBO
        else:
            elbo_cls = infer.Trace_ELBO
        assert_ok(model, guide, elbo_cls(ignore_jit_warnings=True))
def test_iteration_fresh():
    """
    Under ``pyro.markov`` iteration, a per-step fresh input alternates dims,
    while a shared input "a" keeps its shape.
    """

    def testing():
        for step in pyro.markov(range(5)):
            fresh_step = pyro.to_funsor(torch.zeros(2), funsor.Real,
                                        dim_to_name={-1: str(step)})
            fresh_shared = pyro.to_funsor(torch.ones(2), funsor.Real,
                                          dim_to_name={-1: "a"})
            raw_step = pyro.to_data(fresh_step)
            raw_shared = pyro.to_data(fresh_shared)
            print(step, raw_step.shape)  # shapes should alternate
            expected_step_shape = (2, 1, 1) if step % 2 else (2,)
            assert raw_step.shape == expected_step_shape
            assert raw_shared.shape == (2, 1)
            print(step, fresh_step.inputs)
            print("a", raw_shared.shape)  # shapes should stay the same
            print("a", fresh_shared.inputs)

    with pyro_backend("contrib.funsor"), NamedMessenger(first_available_dim=-1):
        testing()
def test_plate_ok(backend, jit):
    """A plated Categorical-indexed mixture runs under (Jit)Trace_ELBO."""
    data = torch.randn(10)

    def model():
        locs = pyro.param("locs", torch.tensor([0.2, 0.3, 0.5]))
        prior_probs = torch.tensor([0.2, 0.3, 0.5])
        with pyro.plate("plate", len(data), dim=-1):
            assignment = pyro.sample("x", dist.Categorical(prior_probs))
            pyro.sample("obs", dist.Normal(locs[assignment], 1.), obs=data)

    def guide():
        guide_probs = pyro.param("p", torch.tensor([0.5, 0.3, 0.2]))
        with pyro.plate("plate", len(data), dim=-1):
            pyro.sample("x", dist.Categorical(guide_probs))

    with pyro_backend(backend):
        if jit:
            elbo_cls = infer.JitTrace_ELBO
        else:
            elbo_cls = infer.Trace_ELBO
        assert_ok(model, guide, elbo_cls(ignore_jit_warnings=True))
def test_nested_plate_plate_ok(backend, jit):
    """Nested plates over a (2, 3) observation run under (Jit)Trace_ELBO."""
    data = torch.randn(2, 3)
    n_outer, n_inner = data.size(-1), data.size(-2)

    def model():
        prior_loc = torch.tensor(3.0)
        with pyro.plate("plate_outer", n_outer, dim=-1):
            latent = pyro.sample("x", dist.Normal(prior_loc, 1.))
            with pyro.plate("plate_inner", n_inner, dim=-2):
                pyro.sample("y", dist.Normal(latent, 1.), obs=data)

    def guide():
        loc = pyro.param("loc", torch.tensor(0.))
        scale = pyro.param("scale", torch.tensor(1.))
        with pyro.plate("plate_outer", n_outer, dim=-1):
            pyro.sample("x", dist.Normal(loc, scale))

    with pyro_backend(backend):
        if jit:
            elbo_cls = infer.JitTrace_ELBO
        else:
            elbo_cls = infer.Trace_ELBO
        assert_ok(model, guide, elbo_cls(ignore_jit_warnings=True))
def test_iteration():
    """
    Under ``pyro.markov`` iteration, a per-step named input alternates dims,
    while a shared input "a" keeps its shape.
    """

    def testing():
        for step in pyro.markov(range(5)):
            raw_step = pyro.to_data(
                Tensor(torch.ones(2), OrderedDict([(str(step), funsor.Bint[2])]), 'real'))
            raw_shared = pyro.to_data(
                Tensor(torch.zeros(2), OrderedDict([('a', funsor.Bint[2])]), 'real'))
            back_step = pyro.to_funsor(raw_step, funsor.Real)
            back_shared = pyro.to_funsor(raw_shared, funsor.Real)
            print(step, raw_step.shape)  # shapes should alternate
            expected_step_shape = (2, 1, 1) if step % 2 else (2,)
            assert raw_step.shape == expected_step_shape
            assert raw_shared.shape == (2, 1)
            print(step, back_step.inputs)
            print('a', raw_shared.shape)  # shapes should stay the same
            print('a', back_shared.inputs)

    with pyro_backend("contrib.funsor"), NamedMessenger(first_available_dim=-1):
        testing()
def test_local_param_ok(backend, jit):
    """
    A param created inside a plate works under the ELBO, and can be read back
    afterwards with ``pyro.param`` without an init value.
    """
    data = torch.randn(10)

    def model():
        locs = pyro.param("locs", torch.tensor([-1., 0., 1.]))
        with pyro.plate("plate", len(data), dim=-1):
            assignment = pyro.sample("x", dist.Categorical(torch.ones(3) / 3))
            pyro.sample("obs", dist.Normal(locs[assignment], 1.), obs=data)

    def guide():
        with pyro.plate("plate", len(data), dim=-1):
            local_probs = pyro.param("p", torch.ones(len(data), 3) / 3, event_dim=1)
            pyro.sample("x", dist.Categorical(local_probs))
        return local_probs

    with pyro_backend(backend):
        if jit:
            elbo_cls = infer.JitTrace_ELBO
        else:
            elbo_cls = infer.Trace_ELBO
        assert_ok(model, guide, elbo_cls(ignore_jit_warnings=True))

        # Check that pyro.param() can be called without init_value.
        from_guide = guide()
        from_store = pyro.param("p")
        assert ops.allclose(from_store, from_guide)
def backend(request):
    """Activate the parametrized backend name for the duration of one test."""
    selected_backend = request.param
    with pyro_backend(selected_backend):
        yield
def test_tmc_categoricals(depth, max_plate_nesting, num_samples, tmc_strategy):
    """
    Compare TraceTMC_ELBO under the ``pyro`` and ``contrib.funsor`` backends.

    The model is a chain of Categorical variables with observed Bernoulli
    leaves. For both backends the test computes ``exp(-loss)`` and its
    gradients with respect to the unconstrained params ``qs``. Both quantities
    must agree to within ``prec``.
    """
    def model():
        x = pyro.sample("x0", dist.Categorical(pyro.param("q0")))
        with pyro.plate("local", 3):
            for i in range(1, depth):
                # Each link's probabilities are the rows of q{i} selected by
                # the previous value of x.
                x = pyro.sample(
                    "x{}".format(i),
                    dist.Categorical(pyro.param("q{}".format(i))[..., x, :]))
            with pyro.plate("data", 4):
                # ``data`` is bound later in the enclosing function. The
                # closure reads it at call time.
                pyro.sample("y", dist.Bernoulli(pyro.param("qy")[..., x]), obs=data)

    with pyro_backend("pyro"):
        # initialize
        qs = [pyro.param("q0", torch.tensor([0.4, 0.6], requires_grad=True))]
        for i in range(1, depth):
            qs.append(
                pyro.param("q{}".format(i),
                           torch.randn(2, 2).abs().detach().requires_grad_(),
                           constraint=constraints.simplex))
        qs.append(
            pyro.param("qy", torch.tensor([0.75, 0.25], requires_grad=True)))
        # Differentiate with respect to the unconstrained leaf tensors.
        qs = [q.unconstrained() for q in qs]
        # Binary observations of shape (4, 3), matching the "data" x "local" plates.
        data = (torch.rand(4, 3) > 0.5).to(dtype=qs[-1].dtype, device=qs[-1].device)

    # Reference result from the plain pyro backend.
    with pyro_backend("pyro"):
        elbo = infer.TraceTMC_ELBO(max_plate_nesting=max_plate_nesting)
        enum_model = infer.config_enumerate(model, default="parallel", expand=False, num_samples=num_samples, tmc=tmc_strategy)
        expected_loss = (
            -elbo.differentiable_loss(enum_model, lambda: None)).exp()
        expected_grads = grad(expected_loss, qs)

    # Same computation through the contrib.funsor backend.
    with pyro_backend("contrib.funsor"):
        tmc = infer.TraceTMC_ELBO(max_plate_nesting=max_plate_nesting)
        tmc_model = infer.config_enumerate(model, default="parallel", expand=False, num_samples=num_samples, tmc=tmc_strategy)
        actual_loss = (-tmc.differentiable_loss(tmc_model, lambda: None)).exp()
        actual_grads = grad(actual_loss, qs)

    prec = 0.05
    assert_equal(actual_loss, expected_loss, prec=prec, msg="".join([
        "\nexpected loss = {}".format(expected_loss),
        "\n actual loss = {}".format(actual_loss),
    ]))

    for actual_grad, expected_grad in zip(actual_grads, expected_grads):
        assert_equal(actual_grad, expected_grad, prec=prec, msg="".join([
            "\nexpected grad = {}".format(
                expected_grad.detach().cpu().numpy()),
            "\n actual grad = {}".format(
                actual_grad.detach().cpu().numpy()),
        ]))