def test_guide_enumerated_elbo(model, guide, data, history):
    """Compare TraceMarkovEnum_ELBO with TraceEnum_ELBO for the given guide.

    The whole comparison runs inside ``pytest.raises``: the test passes only
    when a ``NotImplementedError`` matching the "guide side Markov
    enumeration" message is raised somewhere in the block (presumably by the
    ``TraceMarkovEnum_ELBO`` call -- confirm against the implementation).
    Statements after the raising call are therefore not executed today; they
    are kept correct so the comparison is meaningful once the expectation is
    removed.

    Args:
        model: Model callable, invoked as ``model(data, history, vectorized)``
            through ``loss_and_grads``.
        guide: Guide callable, invoked with the same arguments.
        data: Observations forwarded to model and guide.
        history: Markov history length; values above 1 are marked xfail.
    """
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        with pytest.raises(
                NotImplementedError,
                match="TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration"):
            if history > 1:
                pytest.xfail(
                    reason="TraceMarkovEnum_ELBO does not yet support history > 1")

            # Baseline: sequential (non-vectorized) enumeration.
            elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
            expected_loss = elbo.loss_and_grads(model, guide, data, history, False)

            # Snapshot the baseline gradients eagerly. A lazy generator would
            # only read ``.grad`` after the second backward pass, so the test
            # would end up comparing every gradient with itself.
            expected_grads = {}
            for name, value in pyro.get_param_store().named_parameters():
                grad = value.grad
                expected_grads[name] = None if grad is None else grad.detach().clone()
                # Reset so the second run's gradients are not mixed with
                # (accumulated onto) the baseline ones.
                value.grad = None

            # Candidate: vectorized Markov enumeration.
            vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
            actual_loss = vectorized_elbo.loss_and_grads(model, guide, data, history, True)
            actual_grads = {
                name: value.grad
                for name, value in pyro.get_param_store().named_parameters()
            }

            assert_close(actual_loss, expected_loss)
            for name, expected_grad in expected_grads.items():
                actual_grad = actual_grads[name]
                if expected_grad is None or actual_grad is None:
                    # A parameter without a gradient must lack it in both runs.
                    assert expected_grad is None and actual_grad is None, name
                    continue
                assert_close(actual_grad, expected_grad)
def test_model_enumerated_elbo(model, guide, data, history):
    """Check that TraceMarkovEnum_ELBO matches TraceEnum_ELBO for the model.

    The model is wrapped with ``infer.config_enumerate(default="parallel")``,
    then the loss and the parameter gradients produced by the sequential
    ``TraceEnum_ELBO`` run are compared with those produced by the vectorized
    ``TraceMarkovEnum_ELBO`` run.

    Args:
        model: Model callable, invoked as ``model(data, history, vectorized)``
            through ``loss_and_grads``.
        guide: Guide callable, invoked with the same arguments.
        data: Observations forwarded to model and guide.
        history: Markov history length; values above 1 are marked xfail.
    """
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        if history > 1:
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")

        model = infer.config_enumerate(model, default="parallel")

        # Baseline: sequential (non-vectorized) enumeration.
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        expected_loss = elbo.loss_and_grads(model, guide, data, history, False)

        # Snapshot the baseline gradients eagerly. A lazy generator would only
        # read ``.grad`` after the second backward pass, so the test would end
        # up comparing every gradient with itself.
        expected_grads = {}
        for name, value in pyro.get_param_store().named_parameters():
            grad = value.grad
            expected_grads[name] = None if grad is None else grad.detach().clone()
            # Reset so the second run's gradients are not mixed with
            # (accumulated onto) the baseline ones.
            value.grad = None

        # Candidate: vectorized Markov enumeration.
        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
        actual_loss = vectorized_elbo.loss_and_grads(model, guide, data, history, True)
        actual_grads = {
            name: value.grad
            for name, value in pyro.get_param_store().named_parameters()
        }

        assert_close(actual_loss, expected_loss)
        for name, expected_grad in expected_grads.items():
            actual_grad = actual_grads[name]
            if expected_grad is None or actual_grad is None:
                # A parameter without a gradient must lack it in both runs.
                assert expected_grad is None and actual_grad is None, name
                continue
            assert_close(actual_grad, expected_grad)