Example #1
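None of the snippets below include their import block. Based on the APIs they call (the pyroapi dispatch layer plus torch and funsor utilities), a plausible reconstruction is the following; treat it as an assumption, not source text:

# Assumed imports (reconstructed; not part of the original snippets).
import warnings

import pytest
import torch
from torch.autograd import grad
from torch.distributions import constraints, kl_divergence

import funsor
from funsor.testing import assert_close
# The generic Pyro API: lets the same code run under Pyro or funsor backends.
from pyroapi import distributions as dist
from pyroapi import infer, optim, pyro, pyro_backend
# Used only by Example #4; the exact module path for these funsor
# interpretation utilities has varied across funsor releases.
from funsor.interpreter import interpretation
from funsor.montecarlo import monte_carlo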
def assert_ok(model, guide, elbo, *args, **kwargs):
    """
    Assert that inference works without warnings or errors.
    """
    pyro.get_param_store().clear()
    adam = optim.Adam({"lr": 1e-6})
    inference = infer.SVI(model, guide, adam, elbo)
    # Two steps are enough to exercise the full inference path.
    for _ in range(2):
        inference.step(*args, **kwargs)
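A minimal usage sketch for assert_ok; the toy model/guide pair here is hypothetical, not from the source:

def toy_model(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

def toy_guide(data):
    q_loc = pyro.param("q_loc", torch.tensor(0.))
    pyro.sample("loc", dist.Normal(q_loc, 1.))

assert_ok(toy_model, toy_guide, infer.Trace_ELBO(), torch.tensor(0.5))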
Example #2
def assert_error(model, guide, elbo, match=None):
    """
    Assert that inference fails with an error.
    """
    pyro.get_param_store().clear()
    adam = optim.Adam({"lr": 1e-6})
    inference = infer.SVI(model, guide, adam, elbo)
    with pytest.raises((NotImplementedError, UserWarning, KeyError, ValueError, RuntimeError),
                       match=match):
        inference.step()
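For instance, a model whose observation shape is inconsistent with its distribution should trip one of the listed exceptions. The pair below is illustrative only, and the exact exception type is backend-dependent:

def bad_model():
    # Observation of shape (3,) against a batch shape of (2,): the log_prob
    # computation should raise a shape error.
    pyro.sample("x", dist.Normal(torch.zeros(2), 1.), obs=torch.zeros(3))

def empty_guide():
    pass

assert_error(bad_model, empty_guide, infer.Trace_ELBO())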
Example #3
def assert_warning(model, guide, elbo):
    """
    Assert that inference works but with a warning.
    """
    pyro.get_param_store().clear()
    adam = optim.Adam({"lr": 1e-6})
    inference = infer.SVI(model, guide, adam, elbo)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        inference.step()
        assert len(w), "No warnings were raised"
        for warning in w:
            print(warning)
Example #4
def main(args):
    # Define a basic model with a single Normal latent random variable `loc`
    # and a batch of Normally distributed observations.
    def model(data):
        loc = pyro.sample("loc", dist.Normal(0., 1.))
        with pyro.plate("data", len(data), dim=-1):
            pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

    # Define a guide (i.e. variational distribution) with a Normal
    # distribution over the latent random variable `loc`.
    def guide(data):
        guide_loc = pyro.param("guide_loc", torch.tensor(0.))
        guide_scale = pyro.param("guide_scale",
                                 torch.tensor(1.),
                                 constraint=constraints.positive)
        pyro.sample("loc", dist.Normal(guide_loc, guide_scale))

    # Generate some data.
    torch.manual_seed(0)
    data = torch.randn(100) + 3.0

    # Because the API in minipyro matches that of Pyro proper,
    # training code works with generic Pyro implementations.
    with pyro_backend(args.backend), interpretation(monte_carlo):
        # Construct an SVI object so we can do variational inference on our
        # model/guide pair.
        Elbo = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO
        elbo = Elbo()
        adam = optim.Adam({"lr": args.learning_rate})
        svi = infer.SVI(model, guide, adam, elbo)

        # Basic training loop
        pyro.get_param_store().clear()
        for step in range(args.num_steps):
            loss = svi.step(data)
            if args.verbose and step % 100 == 0:
                print("step {} loss = {}".format(step, loss))

        # Report the final values of the variational parameters
        # in the guide after training.
        if args.verbose:
            for name in pyro.get_param_store():
                value = pyro.param(name)
                print("{} = {}".format(name, value.detach().cpu().numpy()))

        # For this simple (conjugate) model we know the exact posterior. In
        # particular we know that the variational distribution should be
        # centered near 3.0. So let's check this explicitly.
        assert (pyro.param("guide_loc") - 3.0).abs() < 0.1
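main only touches args.backend, args.jit, args.learning_rate, args.num_steps, and args.verbose, so a command-line wrapper along these lines would drive it; the defaults are assumptions, not from the source:

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="minipyro SVI example")
    parser.add_argument("--backend", default="funsor")
    parser.add_argument("--jit", action="store_true")
    parser.add_argument("--learning-rate", type=float, default=0.05)
    parser.add_argument("--num-steps", type=int, default=1000)
    parser.add_argument("--verbose", action="store_true")
    main(parser.parse_args())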
Example #5
def test_elbo_plate_plate(backend, outer_dim, inner_dim):
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        num_particles = 1
        q = pyro.param("q", torch.tensor([0.75, 0.25], requires_grad=True))
        p = 0.2693204236205713  # for which kl(Categorical(q), Categorical(p)) = 0.5
        p = torch.tensor([p, 1 - p])

        def model():
            d = dist.Categorical(p)
            context1 = pyro.plate("outer", outer_dim, dim=-1)
            context2 = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d)
            with context1:
                pyro.sample("x", d)
            with context2:
                pyro.sample("y", d)
            with context1, context2:
                pyro.sample("z", d)

        def guide():
            d = dist.Categorical(pyro.param("q"))
            context1 = pyro.plate("outer", outer_dim, dim=-1)
            context2 = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d, infer={"enumerate": "parallel"})
            with context1:
                pyro.sample("x", d, infer={"enumerate": "parallel"})
            with context2:
                pyro.sample("y", d, infer={"enumerate": "parallel"})
            with context1, context2:
                pyro.sample("z", d, infer={"enumerate": "parallel"})

        # Enumeration makes the ELBO exact here, so the expected loss is a
        # plain KL divergence: one term per replica of each sample site.
        # "w" is unplated (1 copy), "x" has outer_dim copies, "y" has
        # inner_dim copies, and "z" has outer_dim * inner_dim copies.
        kl_node = kl_divergence(
            torch.distributions.Categorical(funsor.to_data(q)),
            torch.distributions.Categorical(funsor.to_data(p)))
        kl = (1 + outer_dim + inner_dim + outer_dim * inner_dim) * kl_node
        expected_loss = kl
        expected_grad = grad(kl, [funsor.to_data(q)])[0]

        elbo = infer.TraceEnum_ELBO(num_particles=num_particles,
                                    vectorize_particles=True,
                                    strict_enumeration_warning=True)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        actual_loss = funsor.to_data(elbo(model, guide))
        actual_loss.backward()
        actual_grad = funsor.to_data(pyro.param("q")).grad

        assert_close(actual_loss, expected_loss, atol=1e-5)
        assert_close(actual_grad, expected_grad, atol=1e-5)
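backend, outer_dim, and inner_dim arrive as test parameters; in a pytest suite they would typically be supplied by parametrize decorators like the following sketch (the concrete values are assumptions):

@pytest.mark.parametrize("backend", ["pyro", "funsor"])
@pytest.mark.parametrize("outer_dim", [1, 2])
@pytest.mark.parametrize("inner_dim", [1, 3])
def test_elbo_plate_plate(backend, outer_dim, inner_dim):
    ...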
Example #6
def test_optimizer(backend, optim_name, jit):
    def model(data):
        p = pyro.param("p", torch.tensor(0.5))
        pyro.sample("x", dist.Bernoulli(p), obs=data)

    def guide(data):
        pass

    data = torch.tensor(0.)
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        elbo = Elbo(ignore_jit_warnings=True)
        optimizer = getattr(optim, optim_name)({"lr": 1e-6})
        inference = infer.SVI(model, guide, optimizer, elbo)
        for _ in range(2):
            inference.step(data)
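Because the optimizer is resolved with getattr(optim, optim_name), optim_name can be any optimizer the backend's optim module exposes. A direct invocation sketch (argument values are assumptions):

test_optimizer("pyro", "Adam", jit=False)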
Example #7
def _check_loss_and_grads(expected_loss, actual_loss, atol=1e-4, rtol=1e-4):
    # copied from pyro
    expected_loss = funsor.to_data(expected_loss)
    actual_loss = funsor.to_data(actual_loss)
    assert_close(actual_loss, expected_loss, atol=atol, rtol=rtol)
    names = list(pyro.get_param_store().keys())
    params = [funsor.to_data(pyro.param(name)).unconstrained()
              for name in names]
    actual_grads = grad(actual_loss,
                        params,
                        allow_unused=True,
                        retain_graph=True)
    expected_grads = grad(expected_loss,
                          params,
                          allow_unused=True,
                          retain_graph=True)
    for name, actual_grad, expected_grad in zip(names, actual_grads,
                                                expected_grads):
        if actual_grad is None or expected_grad is None:
            continue
        assert_close(actual_grad, expected_grad, atol=atol, rtol=rtol)
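A hypothetical call site, mirroring the structure of Example #5: compute a closed-form expected loss and a differentiable actual loss, then let the helper compare both values and their parameter gradients:

# `model`, `guide`, `q`, and `p` are assumed to be set up as in Example #5;
# `.differentiable_loss` applies under the "pyro" backend, where the funsor
# backend calls the ELBO object directly.
expected = kl_divergence(torch.distributions.Categorical(funsor.to_data(q)),
                         torch.distributions.Categorical(funsor.to_data(p)))
actual = infer.TraceEnum_ELBO().differentiable_loss(model, guide)
_check_loss_and_grads(expected, actual)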
Example #8
def build_svi(model, guide, elbo):
    """
    Construct a fresh SVI object over a cleared parameter store.
    """
    pyro.get_param_store().clear()
    adam = optim.Adam({"lr": 1e-6})
    return infer.SVI(model, guide, adam, elbo)
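A typical call site (hypothetical model, guide, and data):

svi = build_svi(model, guide, infer.Trace_ELBO())
for _ in range(100):
    loss = svi.step(data)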