Esempio n. 1
0
def test_trace_handler(model, backend):
    """Check that ``handlers.trace`` can record a trace of a registered model.

    ``model`` is a key into ``MODELS``; the looked-up factory returns a
    mapping with the callable under ``'model'`` and optional ``'model_args'``
    and ``'model_kwargs'`` entries.
    """
    with pyro_backend(backend), handlers.seed(rng_seed=2), \
            xfail_if_not_implemented():
        spec = MODELS[model]()
        model_fn = spec['model']
        call_args = spec.get('model_args', ())
        call_kwargs = spec.get('model_kwargs', {})
        # should be implemented
        handlers.trace(model_fn).get_trace(*call_args, **call_kwargs)
Esempio n. 2
0
def test_constraints(backend, jit):
    """Pass the constrained model/guide pair to ``assert_ok`` with ``steps=2``.

    ``jit`` selects ``JitTrace_ELBO`` over ``Trace_ELBO``.
    """
    observation = torch.tensor(0.5)
    num_steps = 2

    with pyro_backend(backend):
        elbo_cls = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        loss = elbo_cls(ignore_jit_warnings=True)
        assert_ok(constrained_model, guide_constrained_model, loss,
                  num_steps, observation)
Esempio n. 3
0
def test_mcmc_interface(model, backend):
    """Run a short NUTS/MCMC chain on a registered model and call ``summary()``.

    ``model`` is a key into ``MODELS``; 10 warmup steps and 10 samples are
    requested.
    """
    with pyro_backend(backend), handlers.seed(rng_seed=20):
        spec = MODELS[model]()
        model_fn = spec['model']
        args = spec.get('model_args', ())
        kwargs = spec.get('model_kwargs', {})

        kernel = infer.NUTS(model=model_fn)
        sampler = infer.MCMC(kernel, num_samples=10, warmup_steps=10)
        sampler.run(*args, **kwargs)
        sampler.summary()
Esempio n. 4
0
def test_generate_data(backend):
    """Call an unobserved Normal model and check the returned value is a scalar."""
    def model(data=None):
        # Arguments are evaluated left to right, so "loc" is registered
        # before "scale".
        return pyro.sample(
            "x",
            dist.Normal(pyro.param("loc", torch.tensor(2.0)),
                        pyro.param("scale", torch.tensor(1.0))),
            obs=data)

    with pyro_backend(backend):
        generated = model().data
        assert generated.shape == ()
Esempio n. 5
0
def elbo_test_case(backend, jit, expected_elbo, data, steps=None):
    """Check the ELBO of the constrained model/guide pair against a reference.

    When ``steps`` is truthy, ``assert_ok`` is first called with that step
    count. The ELBO value is then compared with ``expected_elbo`` using a
    relative tolerance of 0.1.
    """
    with pyro_backend(backend):
        elbo_cls = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        elbo = elbo_cls(ignore_jit_warnings=True)
        if steps:
            assert_ok(constrained_model, guide_constrained_model, elbo, steps, data)
        # TODO: this is a difference between the two implementations
        loss_fn = elbo.loss if backend == "pyro" else elbo
        actual = loss_fn(constrained_model, guide_constrained_model, data)
        assert actual == approx(expected_elbo, rel=0.1)
Esempio n. 6
0
def test_rng_seed(model, backend):
    """Check that running a model twice under the same seed gives close results.

    The test is skipped when the model returns ``None``.
    """
    with pyro_backend(backend), handlers.seed(rng_seed=2), \
            xfail_if_not_implemented():
        spec = MODELS[model]()
        model_fn = spec['model']
        model_args = spec.get('model_args', ())

        def run_seeded():
            # Each call opens a fresh seed handler with the same seed.
            with handlers.seed(rng_seed=0):
                return model_fn(*model_args)

        expected = run_seeded()
        if expected is None:
            pytest.skip()
        actual = run_seeded()
        assert ops.allclose(actual, expected)
Esempio n. 7
0
def test_nonempty_model_empty_guide_ok(backend, jit):
    """Pass a model with one observed site and a guide with no sites to ``assert_ok``."""
    def model(data):
        mean = pyro.param("loc", torch.tensor(0.0))
        likelihood = dist.Normal(mean, 1.)
        pyro.sample("x", likelihood, obs=data)

    def guide(data):
        # Intentionally empty: the guide declares no sample sites.
        return None

    observation = torch.tensor(2.)
    with pyro_backend(backend):
        elbo_cls = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        loss = elbo_cls(ignore_jit_warnings=True)
        assert_ok(model, guide, loss, observation)
Esempio n. 8
0
def test_mean_field_warn(backend):
    """Pass a model/guide pair with opposite site ordering to ``assert_warning``.

    The model samples "x" then "y" (conditioned on "x"); the guide samples
    "y" then "x" (conditioned on "y"). The loss is ``TraceMeanField_ELBO``.
    """
    def model():
        latent_x = pyro.sample("x", dist.Normal(0., 1.))
        pyro.sample("y", dist.Normal(latent_x, 1.))

    def guide():
        center = pyro.param("loc", torch.tensor(0.))
        latent_y = pyro.sample("y", dist.Normal(center, 1.))
        pyro.sample("x", dist.Normal(latent_y, 1.))

    with pyro_backend(backend):
        loss = infer.TraceMeanField_ELBO()
        assert_warning(model, guide, loss)
Esempio n. 9
0
def test_gaussian_probit_hmm_smoke(exact, jit):
    """Smoke test: take one SVI step on a Gaussian HMM with Bernoulli observations.

    Runs under the "funsor" backend. ``exact`` is forwarded to each state
    site through ``infer={"exact": exact}``; ``jit`` selects
    ``JitTraceEnum_ELBO`` over ``TraceEnum_ELBO``.
    """
    def model(data):
        # time steps, individuals, features
        num_steps, num_individuals, num_features = data.shape

        # Gaussian initial distribution.
        init_loc = pyro.param("init_loc", torch.zeros(num_features))
        init_scale = pyro.param(
            "init_scale",
            1e-2 * torch.eye(num_features),
            constraint=constraints.lower_cholesky)

        # Linear dynamics with Gaussian noise.
        trans_const = pyro.param("trans_const", torch.zeros(num_features))
        trans_coeff = pyro.param("trans_coeff", torch.eye(num_features))
        noise = pyro.param(
            "noise",
            1e-2 * torch.eye(num_features),
            constraint=constraints.lower_cholesky)

        channel_plate = pyro.plate("channel", num_features, dim=-1)
        with pyro.plate("data", num_individuals, dim=-2):
            state = None
            for t in range(num_steps):
                # Transition.
                if t == 0:
                    loc, scale_tril = init_loc, init_scale
                else:
                    drift = funsor.torch.torch_tensordot(
                        trans_coeff, state, 1)
                    loc, scale_tril = trans_const + drift, noise
                state = pyro.sample(
                    "state_{}".format(t),
                    dist.MultivariateNormal(loc, scale_tril),
                    infer={"exact": exact})

                # Factorial probit likelihood model.
                with channel_plate:
                    pyro.sample(
                        "obs_{}".format(t),
                        dist.Bernoulli(logits=state["channel"]),
                        obs=data[t])

    def guide(data):
        pass

    data = torch.distributions.Bernoulli(0.5).sample((3, 4, 2))

    with pyro_backend("funsor"):
        elbo_cls = infer.JitTraceEnum_ELBO if jit else infer.TraceEnum_ELBO
        elbo = elbo_cls()
        optimizer = optim.Adam({"lr": 1e-3})
        infer.SVI(model, guide, optimizer, elbo).step(data)
Esempio n. 10
0
def main(args):
    """Fit a Normal guide to a simple Normal model with SVI and check the fit.

    Reads ``args.backend``, ``args.jit``, ``args.learning_rate``,
    ``args.num_steps`` and ``args.verbose``. Asserts at the end that the
    learned ``guide_loc`` is within 0.1 of 3.0.
    """
    # Model: one Normal latent `loc` and a batch of Normal observations
    # centered on it.
    def model(data):
        latent_loc = pyro.sample("loc", dist.Normal(0., 1.))
        with pyro.plate("data", len(data), dim=-1):
            pyro.sample("obs", dist.Normal(latent_loc, 1.), obs=data)

    # Guide (variational distribution): a Normal over `loc` with a learnable
    # mean and a positive learnable scale.
    def guide(data):
        q_loc = pyro.param("guide_loc", torch.tensor(0.))
        q_scale = pyro.param(
            "guide_scale",
            torch.tensor(1.),
            constraint=constraints.positive)
        pyro.sample("loc", dist.Normal(q_loc, q_scale))

    # Synthetic observations: 100 standard-normal draws shifted by 3.0.
    torch.manual_seed(0)
    data = torch.randn(100) + 3.0

    # Because the API in minipyro matches that of Pyro proper,
    # training code works with generic Pyro implementations.
    with pyro_backend(args.backend), interpretation(monte_carlo):
        # Build the SVI object for the model/guide pair.
        elbo_cls = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO
        elbo = elbo_cls()
        optimizer = optim.Adam({"lr": args.learning_rate})
        svi = infer.SVI(model, guide, optimizer, elbo)

        # Basic training loop, starting from an empty parameter store.
        pyro.get_param_store().clear()
        for step in range(args.num_steps):
            loss = svi.step(data)
            if not args.verbose or step % 100 != 0:
                continue
            print("step {} loss = {}".format(step, loss))

        # Report the final values of the guide's variational parameters.
        if args.verbose:
            for name in pyro.get_param_store():
                value = pyro.param(name).data
                as_numpy = value.detach().cpu().numpy()
                print("{} = {}".format(name, as_numpy))

        # For this simple (conjugate) model we know the exact posterior. In
        # particular we know that the variational distribution should be
        # centered near 3.0. So let's check this explicitly.
        loc_error = (pyro.param("guide_loc") - 3.0).abs()
        assert loc_error < 0.1
Esempio n. 11
0
def test_elbo_plate_plate(backend, outer_dim, inner_dim):
    """Compare ``TraceEnum_ELBO`` loss and gradient with an analytic KL value.

    Four Categorical sites are sampled: one outside any plate, one in each
    of the "outer" and "inner" plates, and one in both. The expected loss is
    the per-node KL multiplied by the total number of sampled elements.
    """
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        q = pyro.param("q", torch.tensor([0.75, 0.25], requires_grad=True))
        # for which kl(Categorical(q), Categorical(p)) = 0.5
        p_first = 0.2693204236205713
        p = torch.tensor([p_first, 1 - p_first])

        def model():
            d = dist.Categorical(p)
            outer = pyro.plate("outer", outer_dim, dim=-1)
            inner = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d)
            with outer:
                pyro.sample("x", d)
            with inner:
                pyro.sample("y", d)
            with outer, inner:
                pyro.sample("z", d)

        def guide():
            d = dist.Categorical(pyro.param("q"))
            outer = pyro.plate("outer", outer_dim, dim=-1)
            inner = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d, infer={"enumerate": "parallel"})
            with outer:
                pyro.sample("x", d, infer={"enumerate": "parallel"})
            with inner:
                pyro.sample("y", d, infer={"enumerate": "parallel"})
            with outer, inner:
                pyro.sample("z", d, infer={"enumerate": "parallel"})

        kl_node = kl_divergence(
            torch.distributions.Categorical(funsor.to_data(q)),
            torch.distributions.Categorical(funsor.to_data(p)))
        # One term per sampled element: w, x (outer), y (inner), z (both).
        num_nodes = 1 + outer_dim + inner_dim + outer_dim * inner_dim
        expected_loss = num_nodes * kl_node
        expected_grad = grad(expected_loss, [funsor.to_data(q)])[0]

        elbo = infer.TraceEnum_ELBO(num_particles=1,
                                    vectorize_particles=True,
                                    strict_enumeration_warning=True)
        loss_fn = elbo.differentiable_loss if backend == "pyro" else elbo
        actual_loss = funsor.to_data(loss_fn(model, guide))
        actual_loss.backward()
        actual_grad = funsor.to_data(pyro.param('q')).grad

        assert_close(actual_loss, expected_loss, atol=1e-5)
        assert_close(actual_grad, expected_grad, atol=1e-5)
Esempio n. 12
0
def test_generate_data_plate(backend):
    """Generate data from a plated Normal model and check its shape and mean.

    The model draws ``num_points`` values inside a plate with ``loc=2.0``
    and ``scale=1.0``. The test asserts the output has shape
    ``(num_points,)`` and that its empirical mean lies in ``[1.9, 2.1]``.
    """
    num_points = 1000

    def model(data=None):
        loc = pyro.param("loc", torch.tensor(2.0))
        scale = pyro.param("scale", torch.tensor(1.0))
        # Size the plate from ``num_points`` so the plate and the assertions
        # below cannot drift apart.
        with pyro.plate("data", num_points, dim=-1):
            x = pyro.sample("x", dist.Normal(loc, scale), obs=data)
        return x

    with pyro_backend(backend):
        data = model().data
        assert data.shape == (num_points,)
        mean = float(ops.sum(data)) / num_points
        assert 1.9 <= mean <= 2.1
Esempio n. 13
0
def test_optimizer(backend, optim_name, jit):
    """Take two SVI steps using the optimizer named ``optim_name`` from ``optim``."""
    def model(data):
        prob = pyro.param("p", torch.tensor(0.5))
        pyro.sample("x", dist.Bernoulli(prob), obs=data)

    def guide(data):
        pass

    data = torch.tensor(0.)
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        elbo_cls = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        elbo = elbo_cls(ignore_jit_warnings=True)
        make_optimizer = getattr(optim, optim_name)
        svi = infer.SVI(model, guide, make_optimizer({"lr": 1e-6}), elbo)
        for _ in range(2):
            svi.step(data)
Esempio n. 14
0
def test_constraints(backend, jit):
    """Pass a model/guide pair using real, positive and simplex constraints to ``assert_ok``."""
    data = torch.tensor(0.5)

    def model():
        means = pyro.param("locs", torch.randn(3), constraint=constraints.real)
        stds = pyro.param("scales", ops.exp(torch.randn(3)), constraint=constraints.positive)
        mixture_probs = torch.tensor([0.5, 0.3, 0.2])
        component = pyro.sample("x", dist.Categorical(mixture_probs))
        pyro.sample("obs", dist.Normal(means[component], stds[component]), obs=data)

    def guide():
        weights = pyro.param("q", ops.exp(torch.randn(3)), constraint=constraints.simplex)
        pyro.sample("x", dist.Categorical(weights))

    with pyro_backend(backend):
        elbo_cls = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        loss = elbo_cls(ignore_jit_warnings=True)
        assert_ok(model, guide, loss)
Esempio n. 15
0
def test_plate_ok(backend, jit):
    """Pass a model/guide pair sharing a single plate named "plate" to ``assert_ok``."""
    data = torch.randn(10)

    def model():
        means = pyro.param("locs", torch.tensor([0.2, 0.3, 0.5]))
        prior_probs = torch.tensor([0.2, 0.3, 0.5])
        with pyro.plate("plate", len(data), dim=-1):
            component = pyro.sample("x", dist.Categorical(prior_probs))
            pyro.sample("obs", dist.Normal(means[component], 1.0), obs=data)

    def guide():
        posterior_probs = pyro.param("p", torch.tensor([0.5, 0.3, 0.2]))
        with pyro.plate("plate", len(data), dim=-1):
            pyro.sample("x", dist.Categorical(posterior_probs))

    with pyro_backend(backend):
        elbo_cls = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        loss = elbo_cls(ignore_jit_warnings=True)
        assert_ok(model, guide, loss)
Esempio n. 16
0
def test_nested_plate_plate_ok(backend, jit):
    """Pass a model with two nested plates and a guide with only the outer plate to ``assert_ok``."""
    data = torch.randn(2, 3)

    def model():
        prior_mean = torch.tensor(3.0)
        with pyro.plate("plate_outer", data.size(-1), dim=-1):
            latent = pyro.sample("x", dist.Normal(prior_mean, 1.0))
            with pyro.plate("plate_inner", data.size(-2), dim=-2):
                pyro.sample("y", dist.Normal(latent, 1.0), obs=data)

    def guide():
        q_mean = pyro.param("loc", torch.tensor(0.0))
        q_std = pyro.param("scale", torch.tensor(1.0))
        with pyro.plate("plate_outer", data.size(-1), dim=-1):
            pyro.sample("x", dist.Normal(q_mean, q_std))

    with pyro_backend(backend):
        elbo_cls = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        loss = elbo_cls(ignore_jit_warnings=True)
        assert_ok(model, guide, loss)
Esempio n. 17
0
def test_local_param_ok(backend, jit):
    """Exercise a ``pyro.param`` declared inside a plate with ``event_dim=1``.

    After ``assert_ok``, the value returned by the guide is compared with
    ``pyro.param("p")`` fetched without an initial value.
    """
    data = torch.randn(10)

    def model():
        means = pyro.param("locs", torch.tensor([-1.0, 0.0, 1.0]))
        with pyro.plate("plate", len(data), dim=-1):
            component = pyro.sample("x", dist.Categorical(torch.ones(3) / 3))
            pyro.sample("obs", dist.Normal(means[component], 1.0), obs=data)

    def guide():
        with pyro.plate("plate", len(data), dim=-1):
            local_probs = pyro.param("p", torch.ones(len(data), 3) / 3, event_dim=1)
            pyro.sample("x", dist.Categorical(local_probs))
        return local_probs

    with pyro_backend(backend):
        elbo_cls = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        loss = elbo_cls(ignore_jit_warnings=True)
        assert_ok(model, guide, loss)

        # Check that pyro.param() can be called without init_value.
        from_guide = guide()
        from_store = pyro.param("p")
        assert_close(from_store, from_guide)
Esempio n. 18
0
def test_elbo_enumerate_plate_7(backend):
    """Compare a plated model/guide against a hand-unrolled equivalent.

    ``auto_model``/``auto_guide`` use a ``pyro.plate`` of size 2, while
    ``hand_model``/``hand_guide`` unroll the same structure into a Python
    loop with per-index site names. Both losses are computed with
    ``TraceEnum_ELBO`` and passed to ``_check_loss_and_grads``.
    """
    #  Guide    Model
    #    a -----> b
    #    |        |
    #  +-|--------|----------------+
    #  | V        V                |
    #  | c -----> d -----> e   N=2 |
    #  +---------------------------+
    # This tests a mixture of model and guide enumeration.
    with pyro_backend(backend):
        # Register all model and guide parameters up front; the model and
        # guide functions below fetch them by name only.
        pyro.param("model_probs_a",
                   torch.tensor([0.45, 0.55]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_b",
                   torch.tensor([[0.6, 0.4], [0.4, 0.6]]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_c",
                   torch.tensor([[0.75, 0.25], [0.55, 0.45]]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_d",
                   torch.tensor([[[0.4, 0.6], [0.3, 0.7]],
                                 [[0.3, 0.7], [0.2, 0.8]]]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_e",
                   torch.tensor([[0.75, 0.25], [0.55, 0.45]]),
                   constraint=constraints.simplex)
        # NOTE(review): this initial value sums to 0.99, not 1.0, despite the
        # simplex constraint. Both the auto and hand versions read the same
        # parameter, so the comparison is unaffected, but confirm whether
        # 0.65 was intended.
        pyro.param("guide_probs_a",
                   torch.tensor([0.35, 0.64]),
                   constraint=constraints.simplex)
        pyro.param(
            "guide_probs_c",
            torch.tensor([[0., 1.], [1., 0.]]),  # deterministic
            constraint=constraints.simplex)

        # Plated version: "c", "d" and "obs" live inside a size-2 plate.
        def auto_model(data):
            probs_a = pyro.param("model_probs_a")
            probs_b = pyro.param("model_probs_b")
            probs_c = pyro.param("model_probs_c")
            probs_d = pyro.param("model_probs_d")
            probs_e = pyro.param("model_probs_e")
            a = pyro.sample("a", dist.Categorical(probs_a))
            b = pyro.sample("b",
                            dist.Categorical(probs_b[a]),
                            infer={"enumerate": "parallel"})
            with pyro.plate("data", 2, dim=-1):
                c = pyro.sample("c", dist.Categorical(probs_c[a]))
                d = pyro.sample("d",
                                dist.Categorical(Vindex(probs_d)[b, c]),
                                infer={"enumerate": "parallel"})
                pyro.sample("obs", dist.Categorical(probs_e[d]), obs=data)

        def auto_guide(data):
            probs_a = pyro.param("guide_probs_a")
            probs_c = pyro.param("guide_probs_c")
            a = pyro.sample("a",
                            dist.Categorical(probs_a),
                            infer={"enumerate": "parallel"})
            with pyro.plate("data", 2, dim=-1):
                pyro.sample("c", dist.Categorical(probs_c[a]))

        # Hand-unrolled version: the plate is replaced by a loop over two
        # indices, with sites named "c_i", "d_i" and "obs_i".
        def hand_model(data):
            probs_a = pyro.param("model_probs_a")
            probs_b = pyro.param("model_probs_b")
            probs_c = pyro.param("model_probs_c")
            probs_d = pyro.param("model_probs_d")
            probs_e = pyro.param("model_probs_e")
            a = pyro.sample("a", dist.Categorical(probs_a))
            b = pyro.sample("b",
                            dist.Categorical(probs_b[a]),
                            infer={"enumerate": "parallel"})
            for i in range(2):
                c = pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a]))
                d = pyro.sample("d_{}".format(i),
                                dist.Categorical(Vindex(probs_d)[b, c]),
                                infer={"enumerate": "parallel"})
                pyro.sample("obs_{}".format(i),
                            dist.Categorical(probs_e[d]),
                            obs=data[i])

        def hand_guide(data):
            probs_a = pyro.param("guide_probs_a")
            probs_c = pyro.param("guide_probs_c")
            a = pyro.sample("a",
                            dist.Categorical(probs_a),
                            infer={"enumerate": "parallel"})
            for i in range(2):
                pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a]))

        data = torch.tensor([0, 0])
        # The plated pair uses max_plate_nesting=1; the unrolled pair has no
        # plates and uses max_plate_nesting=0.
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        auto_loss = elbo(auto_model, auto_guide, data)
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        hand_loss = elbo(hand_model, hand_guide, data)
        _check_loss_and_grads(hand_loss, auto_loss)
Esempio n. 19
0
def test_not_implemented(backend):
    """Check attribute lookup on ``pyro`` under the given backend.

    ``sample`` and ``param`` must resolve; looking up
    ``nonexistent_primitive`` must raise ``NotImplementedError``.
    """
    with pyro_backend(backend):
        # should be implemented
        for primitive in ("sample", "param"):
            getattr(pyro, primitive)
        with pytest.raises(NotImplementedError):
            getattr(pyro, "nonexistent_primitive")
Esempio n. 20
0
def test_elbo_enumerate_plates_1(backend):
    """Compare a model with two separate plates against a hand-unrolled one.

    ``auto_model`` uses two plates ("a_axis" of size 2, "c_axis" of size 3);
    ``hand_model`` unrolls each plate into a Python loop with per-index site
    names. Both are paired with an empty guide, and the two
    ``TraceEnum_ELBO`` losses are passed to ``_check_loss_and_grads``.
    """
    #  +-----------------+
    #  | a ----> b   M=2 |
    #  +-----------------+
    #  +-----------------+
    #  | c ----> d   N=3 |
    #  +-----------------+
    # This tests two unrelated plates.
    # Each should remain uncontracted.
    with pyro_backend(backend):
        # Register the parameters up front; the models fetch them by name.
        pyro.param("probs_a",
                   torch.tensor([0.45, 0.55]),
                   constraint=constraints.simplex)
        pyro.param("probs_b",
                   torch.tensor([[0.6, 0.4], [0.4, 0.6]]),
                   constraint=constraints.simplex)
        pyro.param("probs_c",
                   torch.tensor([0.75, 0.25]),
                   constraint=constraints.simplex)
        pyro.param("probs_d",
                   torch.tensor([[0.4, 0.6], [0.3, 0.7]]),
                   constraint=constraints.simplex)
        # Observations for sites "b" (length 2) and "d" (length 3).
        b_data = torch.tensor([0, 1])
        d_data = torch.tensor([0, 0, 1])

        # Plated version: both plates are declared with dim=-1 and entered
        # one after the other, not nested.
        def auto_model():
            probs_a = pyro.param("probs_a")
            probs_b = pyro.param("probs_b")
            probs_c = pyro.param("probs_c")
            probs_d = pyro.param("probs_d")
            with pyro.plate("a_axis", 2, dim=-1):
                a = pyro.sample("a",
                                dist.Categorical(probs_a),
                                infer={"enumerate": "parallel"})
                pyro.sample("b", dist.Categorical(probs_b[a]), obs=b_data)
            with pyro.plate("c_axis", 3, dim=-1):
                c = pyro.sample("c",
                                dist.Categorical(probs_c),
                                infer={"enumerate": "parallel"})
                pyro.sample("d", dist.Categorical(probs_d[c]), obs=d_data)

        # Hand-unrolled version: each plate becomes a loop with sites named
        # "a_i"/"b_i" and "c_j"/"d_j".
        def hand_model():
            probs_a = pyro.param("probs_a")
            probs_b = pyro.param("probs_b")
            probs_c = pyro.param("probs_c")
            probs_d = pyro.param("probs_d")
            for i in range(2):
                a = pyro.sample("a_{}".format(i),
                                dist.Categorical(probs_a),
                                infer={"enumerate": "parallel"})
                pyro.sample("b_{}".format(i),
                            dist.Categorical(probs_b[a]),
                            obs=b_data[i])
            for j in range(3):
                c = pyro.sample("c_{}".format(j),
                                dist.Categorical(probs_c),
                                infer={"enumerate": "parallel"})
                pyro.sample("d_{}".format(j),
                            dist.Categorical(probs_d[c]),
                            obs=d_data[j])

        # Empty guide: all latent sites are enumerated in the model.
        def guide():
            pass

        # The plated model uses max_plate_nesting=1; the unrolled model has
        # no plates and uses max_plate_nesting=0.
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        auto_loss = elbo(auto_model, guide)
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        hand_loss = elbo(hand_model, guide)
        _check_loss_and_grads(hand_loss, auto_loss)
Esempio n. 21
0
def test_model_sample(model, backend):
    """Smoke test: call a registered model once with its positional arguments.

    ``model`` is a key into ``MODELS``; keyword arguments are not forwarded.
    """
    with pyro_backend(backend), handlers.seed(rng_seed=2), \
            xfail_if_not_implemented():
        spec = MODELS[model]()
        model_fn = spec['model']
        model_fn(*spec.get('model_args', ()))