Example #1
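Cross-backend consistency check: enumerates all discrete configurations of a guide via a queue, replays the model under both the pyro and contrib.funsor backends, and compares the pruned traces.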
def assert_ok(model, guide=None, max_plate_nesting=None, **kwargs):
    """
    Assert that enumeration runs...
    """
    with pyro_backend("pyro"):
        pyro.clear_param_store()

    if guide is None:
        guide = lambda **kwargs: None  # noqa: E731

    q_pyro, q_funsor = LifoQueue(), LifoQueue()
    q_pyro.put(Trace())
    q_funsor.put(Trace())

    while not q_pyro.empty() and not q_funsor.empty():
        with pyro_backend("pyro"):
            with handlers.enum(first_available_dim=-max_plate_nesting - 1):
                guide_tr_pyro = handlers.trace(
                    handlers.queue(
                        guide,
                        q_pyro,
                        escape_fn=iter_discrete_escape,
                        extend_fn=iter_discrete_extend,
                    )).get_trace(**kwargs)
                tr_pyro = handlers.trace(
                    handlers.replay(model,
                                    trace=guide_tr_pyro)).get_trace(**kwargs)

        with pyro_backend("contrib.funsor"):
            with handlers.enum(first_available_dim=-max_plate_nesting - 1):
                guide_tr_funsor = handlers.trace(
                    handlers.queue(
                        guide,
                        q_funsor,
                        escape_fn=iter_discrete_escape,
                        extend_fn=iter_discrete_extend,
                    )).get_trace(**kwargs)
                tr_funsor = handlers.trace(
                    handlers.replay(model,
                                    trace=guide_tr_funsor)).get_trace(**kwargs)

        # make sure all dimensions were cleaned up
        assert _DIM_STACK.local_frame is _DIM_STACK.global_frame
        assert (not _DIM_STACK.global_frame.name_to_dim
                and not _DIM_STACK.global_frame.dim_to_name)
        assert _DIM_STACK.outermost is None

        tr_pyro = prune_subsample_sites(tr_pyro.copy())
        tr_funsor = prune_subsample_sites(tr_funsor.copy())
        _check_traces(tr_pyro, tr_funsor)
Example #2
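Checks that supported primitives resolve on a backend while an unknown primitive raises NotImplementedError.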
def test_not_implemented(backend):
    pytest.importorskip(PACKAGE_NAME[backend])
    with pyro_backend(backend):
        pyro.sample  # should be implemented
        pyro.param  # should be implemented
        with pytest.raises(NotImplementedError):
            pyro.nonexistent_primitive
Example #3
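Checks that dimension allocation alternates across nested pyro.markov contexts, so non-adjacent levels can reuse the same tensor dimension.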
def test_nesting():

    def testing():

        with pyro.markov():
            v1 = pyro.to_data(Tensor(torch.ones(2), OrderedDict([(str(1), funsor.Bint[2])]), 'real'))
            print(1, v1.shape)  # shapes should alternate
            assert v1.shape == (2,)

            with pyro.markov():
                v2 = pyro.to_data(Tensor(torch.ones(2), OrderedDict([(str(2), funsor.Bint[2])]), 'real'))
                print(2, v2.shape)  # shapes should alternate
                assert v2.shape == (2, 1)

                with pyro.markov():
                    v3 = pyro.to_data(Tensor(torch.ones(2), OrderedDict([(str(3), funsor.Bint[2])]), 'real'))
                    print(3, v3.shape)  # shapes should alternate
                    assert v3.shape == (2,)

                    with pyro.markov():
                        v4 = pyro.to_data(Tensor(torch.ones(2), OrderedDict([(str(4), funsor.Bint[2])]), 'real'))
                        print(4, v4.shape)  # shapes should alternate
                        assert v4.shape == (2, 1)

    with pyro_backend("contrib.funsor"), NamedMessenger(first_available_dim=-1):
        testing()
Example #4
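Smoke test: traces each registered model with handlers.trace under each backend.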
def test_trace_handler(model, backend):
    pytest.importorskip(PACKAGE_NAME[backend])
    with pyro_backend(backend), handlers.seed(rng_seed=2):
        f = MODELS[model]()
        model, model_args, model_kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {})
        # should be implemented
        handlers.trace(model).get_trace(*model_args, **model_kwargs)
Example #5
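Smoke test: runs each registered model forward under each backend.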
def test_model_sample(model, backend):
    pytest.importorskip(PACKAGE_NAME[backend])
    with pyro_backend(backend), handlers.seed(rng_seed=2):
        f = MODELS[model]()
        model, model_args, model_kwargs = f['model'], f.get(
            'model_args', ()), f.get('model_kwargs', {})
        model(*model_args, **model_kwargs)
Example #6
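Checks that the vectorized TraceMarkovEnum_ELBO matches TraceEnum_ELBO in loss and parameter gradients on an enumerated model.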
def test_model_enumerated_elbo(model, guide, data, history):
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        if history > 1:
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")

        model = infer.config_enumerate(model, default="parallel")
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        expected_loss = elbo.loss_and_grads(model, guide, data, history, False)
        # Snapshot and clone the gradients before the second backward pass;
        # a lazy generator here would read the accumulated grads afterwards
        # and compare each tensor with itself.
        expected_grads = [
            value.grad.clone()
            for name, value in pyro.get_param_store().named_parameters()
        ]

        # Reset grads so the two backward passes do not accumulate.
        for _, value in pyro.get_param_store().named_parameters():
            value.grad = None

        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
        actual_loss = vectorized_elbo.loss_and_grads(model, guide, data,
                                                     history, True)
        actual_grads = [
            value.grad
            for name, value in pyro.get_param_store().named_parameters()
        ]

        assert_close(actual_loss, expected_loss)
        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            assert_close(actual_grad, expected_grad)
Example #7
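Checks that TraceMarkovEnum_ELBO rejects guide-side Markov enumeration with NotImplementedError.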
def test_guide_enumerated_elbo(model, guide, data, history):
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"), \
        pytest.raises(
            NotImplementedError,
            match="TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration"):

        if history > 1:
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")

        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        expected_loss = elbo.loss_and_grads(model, guide, data, history, False)
        expected_grads = (
            value.grad
            for name, value in pyro.get_param_store().named_parameters())

        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
        actual_loss = vectorized_elbo.loss_and_grads(model, guide, data,
                                                     history, True)
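        # loss_and_grads above is expected to raise NotImplementedError
        # (matched by pytest.raises), so the checks below are never reached.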
        actual_grads = (
            value.grad
            for name, value in pyro.get_param_store().named_parameters())

        assert_close(actual_loss, expected_loss)
        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            assert_close(actual_grad, expected_grad)
Example #8
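Builds a guide from a model by enabling parallel enumeration and blocking observed sites; falls back to the raw model when the funsor backend is unavailable.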
def _guide_from_model(model):
    try:
        with pyro_backend("contrib.funsor"):
            return handlers.block(
                infer.config_enumerate(model, default="parallel"),
                lambda msg: msg.get("is_observed", False))
    except KeyError:  # for test collection without funsor
        return model
Example #9
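Smoke test: runs a short NUTS/MCMC chain against each registered model and prints a summary.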
def test_mcmc_interface(model, backend):
    pytest.importorskip(PACKAGE_NAME[backend])
    with pyro_backend(backend), handlers.seed(rng_seed=20):
        f = MODELS[model]()
        model, args, kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {})
        nuts_kernel = infer.NUTS(model=model)
        mcmc = infer.MCMC(nuts_kernel, num_samples=10, warmup_steps=10)
        mcmc.run(*args, **kwargs)
        mcmc.summary()
Example #10
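Checks that pyro.to_funsor introduces fresh named inputs with the expected Bint domain sizes.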
def test_fresh_inputs_to_funsor():

    def testing():
        x = pyro.to_funsor(torch.tensor([0., 1.]), funsor.Real, dim_to_name={-1: "x"})
        assert set(x.inputs) == {"x"}
        px = pyro.to_funsor(torch.ones(2, 3), funsor.Real, dim_to_name={-2: "x", -1: "y"})
        assert px.inputs["x"].dtype == 2 and px.inputs["y"].dtype == 3

    with pyro_backend("contrib.funsor"), NamedMessenger():
        testing()
Example #11
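Checks that handlers.seed makes model draws reproducible across runs with the same seed.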
def test_rng_seed(backend):
    def model():
        return pyro.sample("x", dist.Normal(0, 1))

    with pyro_backend(backend):
        with handlers.seed(rng_seed=0):
            expected = model()
        with handlers.seed(rng_seed=0):
            actual = model()
        assert ops.allclose(actual, expected)
Example #12
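Draws a single sample from a simple parameterized model and checks that it is a scalar.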
def test_generate_data(backend):
    def model():
        loc = pyro.param("loc", torch.tensor(2.0))
        scale = pyro.param("scale", torch.tensor(1.0))
        x = pyro.sample("x", dist.Normal(loc, scale))
        return x

    with pyro_backend(backend):
        data = model()
        data = data.data
        assert data.shape == ()
Example #13
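Registers a new backend alias, foo, that routes the generic infer, optim, and pyro namespaces to pyro.contrib.minipyro, then traces a model under it.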
def test_register_backend(model):
    pytest.importorskip("pyro")
    register_backend("foo", {
        "infer": "pyro.contrib.minipyro",
        "optim": "pyro.contrib.minipyro",
        "pyro": "pyro.contrib.minipyro",
    })
    with pyro_backend("foo"):
        f = MODELS[model]()
        model, model_args, model_kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {})
        handlers.trace(model).get_trace(*model_args, **model_kwargs)
Example #14
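Checks dimension allocation for fresh to_funsor values created intermittently inside a pyro.markov loop.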
def test_staggered_fresh():
    def testing():
        for i in pyro.markov(range(12)):
            if i % 4 == 0:
                fv2 = pyro.to_funsor(torch.zeros(2), funsor.Real, dim_to_name={-1: "a"})
                v2 = pyro.to_data(fv2)
                assert v2.shape == (2,)
                print("a", v2.shape)
                print("a", fv2.inputs)

    with pyro_backend("contrib.funsor"), NamedMessenger(first_available_dim=-1):
        testing()
Example #15
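Checks that SVI accepts a nonempty model paired with an empty guide when all model sites are observed or parameters.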
def test_nonempty_model_empty_guide_ok(backend, jit):
    def model(data):
        loc = pyro.param("loc", torch.tensor(0.0))
        pyro.sample("x", dist.Normal(loc, 1.), obs=data)

    def guide(data):
        pass

    data = torch.tensor(2.)
    with pyro_backend(backend):
        Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        elbo = Elbo(ignore_jit_warnings=True)
        assert_ok(model, guide, elbo, data)
Example #16
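Variant of Example #14 that first builds a funsor Tensor, converts it with to_data, and round-trips back through to_funsor.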
def test_staggered():

    def testing():
        for i in pyro.markov(range(12)):
            if i % 4 == 0:
                v2 = pyro.to_data(Tensor(torch.zeros(2), OrderedDict([('a', funsor.Bint[2])]), 'real'))
                fv2 = pyro.to_funsor(v2, funsor.Real)
                assert v2.shape == (2,)
                print('a', v2.shape)
                print('a', fv2.inputs)

    with pyro_backend("contrib.funsor"), NamedMessenger(first_available_dim=-1):
        testing()
Example #17
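Checks that TraceMeanField_ELBO issues a warning when the guide samples sites in an order incompatible with the mean-field assumption.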
def test_mean_field_warn(backend):
    def model():
        x = pyro.sample("x", dist.Normal(0., 1.))
        pyro.sample("y", dist.Normal(x, 1.))

    def guide():
        loc = pyro.param("loc", torch.tensor(0.))
        y = pyro.sample("y", dist.Normal(loc, 1.))
        pyro.sample("x", dist.Normal(y, 1.))

    with pyro_backend(backend):
        elbo = infer.TraceMeanField_ELBO()
        assert_warning(model, guide, elbo)
Example #18
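Smoke-tests SVI on a factorial probit HMM with Gaussian latent dynamics under the funsor backend.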
def test_gaussian_probit_hmm_smoke(exact, jit):
    def model(data):
        T, N, D = data.shape  # time steps, individuals, features

        # Gaussian initial distribution.
        init_loc = pyro.param("init_loc", torch.zeros(D))
        init_scale = pyro.param("init_scale",
                                1e-2 * torch.eye(D),
                                constraint=constraints.lower_cholesky)

        # Linear dynamics with Gaussian noise.
        trans_const = pyro.param("trans_const", torch.zeros(D))
        trans_coeff = pyro.param("trans_coeff", torch.eye(D))
        noise = pyro.param("noise",
                           1e-2 * torch.eye(D),
                           constraint=constraints.lower_cholesky)

        obs_plate = pyro.plate("channel", D, dim=-1)
        with pyro.plate("data", N, dim=-2):
            state = None
            for t in range(T):
                # Transition.
                if t == 0:
                    loc = init_loc
                    scale_tril = init_scale
                else:
                    loc = trans_const + funsor.torch.torch_tensordot(
                        trans_coeff, state, 1)
                    scale_tril = noise
                state = pyro.sample("state_{}".format(t),
                                    dist.MultivariateNormal(loc, scale_tril),
                                    infer={"exact": exact})

                # Factorial probit likelihood model.
                with obs_plate:
                    pyro.sample("obs_{}".format(t),
                                dist.Bernoulli(logits=state["channel"]),
                                obs=data[t])

    def guide(data):
        pass

    data = torch.distributions.Bernoulli(0.5).sample((3, 4, 2))

    with pyro_backend("funsor"):
        Elbo = infer.JitTraceEnum_ELBO if jit else infer.TraceEnum_ELBO
        elbo = Elbo()
        adam = optim.Adam({"lr": 1e-3})
        svi = infer.SVI(model, guide, adam, elbo)
        svi.step(data)
Example #19
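A complete minipyro-style training script: fits a conjugate Normal-Normal model with SVI under a configurable backend and checks that the learned variational mean is near the true value of 3.0.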
def main(args):
    funsor.set_backend("torch")

    # Define a basic model with a single Normal latent random variable `loc`
    # and a batch of Normally distributed observations.
    def model(data):
        loc = pyro.sample("loc", dist.Normal(0., 1.))
        with pyro.plate("data", len(data), dim=-1):
            pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

    # Define a guide (i.e. variational distribution) with a Normal
    # distribution over the latent random variable `loc`.
    def guide(data):
        guide_loc = pyro.param("guide_loc", torch.tensor(0.))
        guide_scale = pyro.param("guide_scale", torch.tensor(1.),
                                 constraint=constraints.positive)
        pyro.sample("loc", dist.Normal(guide_loc, guide_scale))

    # Generate some data.
    torch.manual_seed(0)
    data = torch.randn(100) + 3.0

    # Because the API in minipyro matches that of Pyro proper,
    # training code works with generic Pyro implementations.
    with pyro_backend(args.backend), interpretation(MonteCarlo()):
        # Construct an SVI object so we can do variational inference on our
        # model/guide pair.
        Elbo = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO
        elbo = Elbo()
        adam = optim.Adam({"lr": args.learning_rate})
        svi = infer.SVI(model, guide, adam, elbo)

        # Basic training loop
        pyro.get_param_store().clear()
        for step in range(args.num_steps):
            loss = svi.step(data)
            if args.verbose and step % 100 == 0:
                print("step {} loss = {}".format(step, loss))

        # Report the final values of the variational parameters
        # in the guide after training.
        if args.verbose:
            for name in pyro.get_param_store():
                value = pyro.param(name).data
                print("{} = {}".format(name, value.detach().cpu().numpy()))

        # For this simple (conjugate) model we know the exact posterior. In
        # particular we know that the variational distribution should be
        # centered near 3.0. So let's check this explicitly.
        assert (pyro.param("guide_loc") - 3.0).abs() < 0.1
Example #20
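Tests an enumerated ELBO under two overlapping plates against a closed-form expected loss (a sum of KL divergences) and its analytic gradient.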
def test_elbo_plate_plate(backend, outer_dim, inner_dim):
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        num_particles = 1
        q = pyro.param("q", torch.tensor([0.75, 0.25], requires_grad=True))
        p = 0.2693204236205713  # for which kl(Categorical(q), Categorical(p)) = 0.5
        p = torch.tensor([p, 1 - p])

        def model():
            d = dist.Categorical(p)
            context1 = pyro.plate("outer", outer_dim, dim=-1)
            context2 = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d)
            with context1:
                pyro.sample("x", d)
            with context2:
                pyro.sample("y", d)
            with context1, context2:
                pyro.sample("z", d)

        def guide():
            d = dist.Categorical(pyro.param("q"))
            context1 = pyro.plate("outer", outer_dim, dim=-1)
            context2 = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d, infer={"enumerate": "parallel"})
            with context1:
                pyro.sample("x", d, infer={"enumerate": "parallel"})
            with context2:
                pyro.sample("y", d, infer={"enumerate": "parallel"})
            with context1, context2:
                pyro.sample("z", d, infer={"enumerate": "parallel"})

        kl_node = kl_divergence(
            torch.distributions.Categorical(funsor.to_data(q)),
            torch.distributions.Categorical(funsor.to_data(p)))
        kl = (1 + outer_dim + inner_dim + outer_dim * inner_dim) * kl_node
        expected_loss = kl
        expected_grad = grad(kl, [funsor.to_data(q)])[0]

        elbo = infer.TraceEnum_ELBO(num_particles=num_particles,
                                    vectorize_particles=True,
                                    strict_enumeration_warning=True)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        actual_loss = funsor.to_data(elbo(model, guide))
        actual_loss.backward()
        actual_grad = funsor.to_data(pyro.param('q')).grad

        assert ops.allclose(actual_loss, expected_loss, atol=1e-5)
        assert ops.allclose(actual_grad, expected_grad, atol=1e-5)
Example #21
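Samples 1000 points inside a plate from a parameterized Normal and checks that the sample mean is near loc = 2.0.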
def test_generate_data_plate(backend):
    num_points = 1000

    def model(data=None):
        loc = pyro.param("loc", torch.tensor(2.0))
        scale = pyro.param("scale", torch.tensor(1.0))
        with pyro.plate("data", 1000, dim=-1):
            x = pyro.sample("x", dist.Normal(loc, scale), obs=data)
        return x

    with pyro_backend(backend):
        data = model().data
        assert data.shape == (num_points, )
        mean = data.sum().item() / num_points
        assert 1.9 <= mean <= 2.1
Example #22
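Runs two SVI steps with each optimizer under test on a one-parameter Bernoulli model.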
def test_optimizer(backend, optim_name, jit):
    def model(data):
        p = pyro.param("p", torch.tensor(0.5))
        pyro.sample("x", dist.Bernoulli(p), obs=data)

    def guide(data):
        pass

    data = torch.tensor(0.)
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        elbo = Elbo(ignore_jit_warnings=True)
        optimizer = getattr(optim, optim_name)({"lr": 1e-6})
        inference = infer.SVI(model, guide, optimizer, elbo)
        for i in range(2):
            inference.step(data)
Example #23
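Checks that differently constrained parameters (real, positive, simplex) pass through SVI.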
def test_constraints(backend, jit):
    data = torch.tensor(0.5)

    def model():
        locs = pyro.param("locs", torch.randn(3), constraint=constraints.real)
        scales = pyro.param("scales", torch.randn(3).exp(), constraint=constraints.positive)
        p = torch.tensor([0.5, 0.3, 0.2])
        x = pyro.sample("x", dist.Categorical(p))
        pyro.sample("obs", dist.Normal(locs[x], scales[x]), obs=data)

    def guide():
        q = pyro.param("q", torch.randn(3).exp(), constraint=constraints.simplex)
        pyro.sample("x", dist.Categorical(q))

    with pyro_backend(backend):
        Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        elbo = Elbo(ignore_jit_warnings=True)
        assert_ok(model, guide, elbo)
Example #24
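Inside a pyro.markov loop, checks that per-iteration names alternate dimensions while a repeated name ('a') keeps a stable shape.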
def test_iteration_fresh():
    def testing():
        for i in pyro.markov(range(5)):
            fv1 = pyro.to_funsor(torch.zeros(2), funsor.Real, dim_to_name={-1: str(i)})
            fv2 = pyro.to_funsor(torch.ones(2), funsor.Real, dim_to_name={-1: "a"})
            v1 = pyro.to_data(fv1)
            v2 = pyro.to_data(fv2)
            print(i, v1.shape)  # shapes should alternate
            if i % 2 == 0:
                assert v1.shape == (2,)
            else:
                assert v1.shape == (2, 1, 1)
            assert v2.shape == (2, 1)
            print(i, fv1.inputs)
            print("a", v2.shape)  # shapes should stay the same
            print("a", fv2.inputs)

    with pyro_backend("contrib.funsor"), NamedMessenger(first_available_dim=-1):
        testing()
Example #25
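Checks SVI on a mixture-style model with a categorical latent inside a data plate.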
def test_plate_ok(backend, jit):
    data = torch.randn(10)

    def model():
        locs = pyro.param("locs", torch.tensor([0.2, 0.3, 0.5]))
        p = torch.tensor([0.2, 0.3, 0.5])
        with pyro.plate("plate", len(data), dim=-1):
            x = pyro.sample("x", dist.Categorical(p))
            pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data)

    def guide():
        p = pyro.param("p", torch.tensor([0.5, 0.3, 0.2]))
        with pyro.plate("plate", len(data), dim=-1):
            pyro.sample("x", dist.Categorical(p))

    with pyro_backend(backend):
        Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        elbo = Elbo(ignore_jit_warnings=True)
        assert_ok(model, guide, elbo)
Example #26
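Checks SVI with nested plates, where the guide only covers the outer plate.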
def test_nested_plate_plate_ok(backend, jit):
    data = torch.randn(2, 3)

    def model():
        loc = torch.tensor(3.0)
        with pyro.plate("plate_outer", data.size(-1), dim=-1):
            x = pyro.sample("x", dist.Normal(loc, 1.))
            with pyro.plate("plate_inner", data.size(-2), dim=-2):
                pyro.sample("y", dist.Normal(x, 1.), obs=data)

    def guide():
        loc = pyro.param("loc", torch.tensor(0.))
        scale = pyro.param("scale", torch.tensor(1.))
        with pyro.plate("plate_outer", data.size(-1), dim=-1):
            pyro.sample("x", dist.Normal(loc, scale))

    with pyro_backend(backend):
        Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        elbo = Elbo(ignore_jit_warnings=True)
        assert_ok(model, guide, elbo)
Example #27
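Same checks as Example #24, with the values constructed as explicit funsor Tensors and converted via to_data.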
def test_iteration():

    def testing():
        for i in pyro.markov(range(5)):
            v1 = pyro.to_data(Tensor(torch.ones(2), OrderedDict([(str(i), funsor.Bint[2])]), 'real'))
            v2 = pyro.to_data(Tensor(torch.zeros(2), OrderedDict([('a', funsor.Bint[2])]), 'real'))
            fv1 = pyro.to_funsor(v1, funsor.Real)
            fv2 = pyro.to_funsor(v2, funsor.Real)
            print(i, v1.shape)  # shapes should alternate
            if i % 2 == 0:
                assert v1.shape == (2,)
            else:
                assert v1.shape == (2, 1, 1)
            assert v2.shape == (2, 1)
            print(i, fv1.inputs)
            print('a', v2.shape)  # shapes should stay the same
            print('a', fv2.inputs)

    with pyro_backend("contrib.funsor"), NamedMessenger(first_available_dim=-1):
        testing()
Example #28
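Checks plate-local parameters with event_dim, and that pyro.param can be read back afterwards without an initial value.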
def test_local_param_ok(backend, jit):
    data = torch.randn(10)

    def model():
        locs = pyro.param("locs", torch.tensor([-1., 0., 1.]))
        with pyro.plate("plate", len(data), dim=-1):
            x = pyro.sample("x", dist.Categorical(torch.ones(3) / 3))
            pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data)

    def guide():
        with pyro.plate("plate", len(data), dim=-1):
            p = pyro.param("p", torch.ones(len(data), 3) / 3, event_dim=1)
            pyro.sample("x", dist.Categorical(p))
        return p

    with pyro_backend(backend):
        Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        elbo = Elbo(ignore_jit_warnings=True)
        assert_ok(model, guide, elbo)

        # Check that pyro.param() can be called without init_value.
        expected = guide()
        actual = pyro.param("p")
        assert ops.allclose(actual, expected)
Example #29
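A pytest fixture that activates the requested backend around each test.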
@pytest.fixture(params=["pyro", "minipyro"])  # illustrative params; the real list is supplied by the test suite
def backend(request):
    with pyro_backend(request.param):
        yield
Example #30
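Compares TraceTMC_ELBO loss and gradients between the pyro and contrib.funsor backends on a depth-parameterized chain of categorical latents.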
def test_tmc_categoricals(depth, max_plate_nesting, num_samples, tmc_strategy):
    def model():
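        # NOTE: data and the q* parameters are looked up lazily from the
        # enclosing scope; they are initialized below, before model is run.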
        x = pyro.sample("x0", dist.Categorical(pyro.param("q0")))
        with pyro.plate("local", 3):
            for i in range(1, depth):
                x = pyro.sample(
                    "x{}".format(i),
                    dist.Categorical(pyro.param("q{}".format(i))[..., x, :]))
            with pyro.plate("data", 4):
                pyro.sample("y",
                            dist.Bernoulli(pyro.param("qy")[..., x]),
                            obs=data)

    with pyro_backend("pyro"):
        # initialize
        qs = [pyro.param("q0", torch.tensor([0.4, 0.6], requires_grad=True))]
        for i in range(1, depth):
            qs.append(
                pyro.param("q{}".format(i),
                           torch.randn(2, 2).abs().detach().requires_grad_(),
                           constraint=constraints.simplex))
        qs.append(
            pyro.param("qy", torch.tensor([0.75, 0.25], requires_grad=True)))
        qs = [q.unconstrained() for q in qs]
        data = (torch.rand(4, 3) > 0.5).to(dtype=qs[-1].dtype,
                                           device=qs[-1].device)

    with pyro_backend("pyro"):
        elbo = infer.TraceTMC_ELBO(max_plate_nesting=max_plate_nesting)
        enum_model = infer.config_enumerate(model,
                                            default="parallel",
                                            expand=False,
                                            num_samples=num_samples,
                                            tmc=tmc_strategy)
        expected_loss = (
            -elbo.differentiable_loss(enum_model, lambda: None)).exp()
        expected_grads = grad(expected_loss, qs)

    with pyro_backend("contrib.funsor"):
        tmc = infer.TraceTMC_ELBO(max_plate_nesting=max_plate_nesting)
        tmc_model = infer.config_enumerate(model,
                                           default="parallel",
                                           expand=False,
                                           num_samples=num_samples,
                                           tmc=tmc_strategy)
        actual_loss = (-tmc.differentiable_loss(tmc_model, lambda: None)).exp()
        actual_grads = grad(actual_loss, qs)

    prec = 0.05
    assert_equal(actual_loss,
                 expected_loss,
                 prec=prec,
                 msg="".join([
                     "\nexpected loss = {}".format(expected_loss),
                     "\n  actual loss = {}".format(actual_loss),
                 ]))

    for actual_grad, expected_grad in zip(actual_grads, expected_grads):
        assert_equal(actual_grad,
                     expected_grad,
                     prec=prec,
                     msg="".join([
                         "\nexpected grad = {}".format(
                             expected_grad.detach().cpu().numpy()),
                         "\n  actual grad = {}".format(
                             actual_grad.detach().cpu().numpy()),
                     ]))
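A note on imports: the snippets above are excerpted without their import headers, and several names (MODELS, PACKAGE_NAME, Trace, NamedMessenger, _DIM_STACK, assert_ok, assert_close, assert_equal, assert_warning, prune_subsample_sites, iter_discrete_escape, iter_discrete_extend) are helpers defined in the pyro and funsor test suites. A minimal preamble that covers most of the examples might look like the following sketch; exact module paths vary across pyro, funsor, and pyroapi versions.

from collections import OrderedDict
from queue import LifoQueue

import pytest
import torch
from torch.autograd import grad
from torch.distributions import constraints, kl_divergence

import funsor
from funsor.tensor import Tensor  # funsor.torch.Tensor on older funsor versions

# pyroapi routes these generic namespaces to whichever backend is active.
from pyroapi import distributions as dist
from pyroapi import handlers, infer, ops, optim, pyro
from pyroapi import pyro_backend, register_backend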