Ejemplo n.º 1
0
def assert_ok(model, guide, elbo, *args, **kwargs):
    """
    Run two SVI steps on ``model``/``guide`` and expect both to succeed.

    The global parameter store is emptied first so parameters left behind by
    an earlier test cannot leak in. Every extra positional and keyword
    argument is passed unchanged to each ``step`` call. Any exception raised
    by a step propagates to the caller and fails the test.

    NOTE(review): warnings are not inspected here; presumably the test
    configuration escalates them to errors — TODO confirm.
    """
    pyro.get_param_store().clear()
    svi = infer.SVI(model, guide, optim.Adam({"lr": 1e-6}), elbo)
    for _ in range(2):
        svi.step(*args, **kwargs)
Ejemplo n.º 2
0
def assert_error(model, guide, elbo, match=None):
    """
    Run one SVI step and require it to raise.

    The step must raise NotImplementedError, UserWarning (only counts when
    the warning is raised as an exception), KeyError, ValueError or
    RuntimeError. When ``match`` is given, ``pytest.raises`` additionally
    requires the exception's text to match that regular expression.
    """
    expected_errors = (NotImplementedError, UserWarning, KeyError, ValueError, RuntimeError)

    pyro.get_param_store().clear()
    optimizer = optim.Adam({"lr": 1e-6})
    svi = infer.SVI(model, guide, optimizer, elbo)

    with pytest.raises(expected_errors, match=match):
        svi.step()
Ejemplo n.º 3
0
def assert_warning(model, guide, elbo, *args, **kwargs):
    """
    Assert that inference works but with a warning.

    Clears the global parameter store and runs a single SVI step while
    recording every warning. The test fails if no warning was recorded.
    Each recorded warning is printed to aid debugging.

    Args:
        model: Pyro model callable.
        guide: Pyro guide callable.
        elbo: ELBO instance used as the SVI loss.
        *args, **kwargs: forwarded unchanged to ``SVI.step``, so models and
            guides that take data can also be checked. With no extra
            arguments the behaviour is the same as before.

    Raises:
        AssertionError: if the step emitted no warnings.
    """
    pyro.get_param_store().clear()
    adam = optim.Adam({"lr": 1e-6})
    inference = infer.SVI(model, guide, adam, elbo)
    with warnings.catch_warnings(record=True) as caught:
        # "always" bypasses the default once-per-location de-duplication, so a
        # warning already emitted earlier in the session is still recorded.
        warnings.simplefilter("always")
        inference.step(*args, **kwargs)
    # The recorded list stays populated after the context exits.
    assert caught, 'No warnings were raised'
    for warning in caught:
        print(warning)
Ejemplo n.º 4
0
def test_gaussian_probit_hmm_smoke(exact, jit):
    """
    Smoke test: one SVI step on an HMM with multivariate-Normal latent states
    and factorised Bernoulli observations, run under the "funsor" backend.

    Args:
        exact: forwarded as ``infer={"exact": exact}`` to every latent
            ``state_t`` sample site.
        jit: if true, use ``infer.JitTraceEnum_ELBO``; otherwise
            ``infer.TraceEnum_ELBO``.

    The guide is empty, so the latent states are left entirely to the ELBO
    (presumably integrated out exactly when ``exact`` is true — TODO confirm
    against the funsor backend). Only checks that a step runs without error.
    """
    def model(data):
        """HMM over ``data`` of shape (T, N, D) with a learnable linear-Gaussian transition."""
        T, N, D = data.shape  # time steps, individuals, features

        # Gaussian initial distribution.
        # Both matrix params below are constrained to lower-Cholesky factors.
        init_loc = pyro.param("init_loc", torch.zeros(D))
        init_scale = pyro.param("init_scale",
                                1e-2 * torch.eye(D),
                                constraint=constraints.lower_cholesky)

        # Linear dynamics with Gaussian noise.
        trans_const = pyro.param("trans_const", torch.zeros(D))
        trans_coeff = pyro.param("trans_coeff", torch.eye(D))
        noise = pyro.param("noise",
                           1e-2 * torch.eye(D),
                           constraint=constraints.lower_cholesky)

        # Created once here and re-entered for each time step below.
        obs_plate = pyro.plate("channel", D, dim=-1)
        with pyro.plate("data", N, dim=-2):
            state = None
            for t in range(T):
                # Transition.
                if t == 0:
                    loc = init_loc
                    scale_tril = init_scale
                else:
                    # Presumably computes trans_coeff @ state by contracting one
                    # dimension (tensordot with dims=1) — TODO confirm funsor semantics.
                    loc = trans_const + funsor.torch.torch_tensordot(
                        trans_coeff, state, 1)
                    scale_tril = noise
                # NOTE(review): scale_tril is passed positionally. In
                # torch.distributions the second positional argument of
                # MultivariateNormal is covariance_matrix; under the funsor
                # backend it is presumably scale_tril — confirm.
                state = pyro.sample("state_{}".format(t),
                                    dist.MultivariateNormal(loc, scale_tril),
                                    infer={"exact": exact})

                # Factorial probit likelihood model.
                # NOTE(review): the likelihood below is Bernoulli(logits=...),
                # i.e. a logistic link rather than a probit one.
                # state["channel"] presumably indexes the state's D-vector by
                # the "channel" plate, giving one logit per feature — TODO confirm.
                with obs_plate:
                    pyro.sample("obs_{}".format(t),
                                dist.Bernoulli(logits=state["channel"]),
                                obs=data[t])

    def guide(data):
        # Empty guide: no latent site is sampled here.
        pass

    # Random binary data: T=3 time steps, N=4 individuals, D=2 features.
    data = torch.distributions.Bernoulli(0.5).sample((3, 4, 2))

    with pyro_backend("funsor"):
        Elbo = infer.JitTraceEnum_ELBO if jit else infer.TraceEnum_ELBO
        elbo = Elbo()
        adam = optim.Adam({"lr": 1e-3})
        svi = infer.SVI(model, guide, adam, elbo)
        svi.step(data)
Ejemplo n.º 5
0
def main(args):
    """
    Fit a one-latent Normal model with SVI and check the learned location.

    ``args`` must provide the attributes read below: ``backend``, ``jit``,
    ``learning_rate``, ``num_steps`` and ``verbose``.

    Raises:
        AssertionError: if the fitted ``guide_loc`` is not within 0.1 of 3.0.
    """
    def model(data):
        # Prior loc ~ Normal(0, 1); each observation ~ Normal(loc, 1),
        # batched along a "data" plate.
        loc = pyro.sample("loc", dist.Normal(0., 1.))
        with pyro.plate("data", len(data), dim=-1):
            pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

    def guide(data):
        # Variational family for loc: Normal(q_loc, q_scale), with the
        # scale held positive by its constraint.
        q_loc = pyro.param("guide_loc", torch.tensor(0.))
        q_scale = pyro.param("guide_scale",
                             torch.tensor(1.),
                             constraint=constraints.positive)
        pyro.sample("loc", dist.Normal(q_loc, q_scale))

    # Seeded synthetic data: 100 standard-normal draws shifted to mean 3.0.
    torch.manual_seed(0)
    data = torch.randn(100) + 3.0

    # The same training code runs against any Pyro-compatible backend.
    with pyro_backend(args.backend), interpretation(monte_carlo):
        if args.jit:
            elbo = infer.JitTrace_ELBO()
        else:
            elbo = infer.Trace_ELBO()
        svi = infer.SVI(model, guide, optim.Adam({"lr": args.learning_rate}), elbo)

        pyro.get_param_store().clear()
        for step in range(args.num_steps):
            loss = svi.step(data)
            if args.verbose and step % 100 == 0:
                print("step {} loss = {}".format(step, loss))

        if args.verbose:
            # Dump the final variational parameters.
            for name in pyro.get_param_store():
                value = pyro.param(name).data
                print("{} = {}".format(name, value.detach().cpu().numpy()))

        # The model is conjugate, so the variational distribution is known
        # to be centred near 3.0 after training.
        assert (pyro.param("guide_loc") - 3.0).abs() < 0.1
Ejemplo n.º 6
0
def test_optimizer(backend, optim_name, jit):
    """
    Smoke-test the optimizer named ``optim_name`` with two SVI steps.

    The model has a single learnable Bernoulli probability and one observed
    value, so the guide is empty. ``jit`` selects the JIT-compiled ELBO, and
    JIT warnings are ignored either way.
    """
    def model(data):
        # One learnable probability, observed once.
        p = pyro.param("p", torch.tensor(0.5))
        pyro.sample("x", dist.Bernoulli(p), obs=data)

    def guide(data):
        # Nothing latent to sample.
        pass

    data = torch.tensor(0.)
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        if jit:
            elbo = infer.JitTrace_ELBO(ignore_jit_warnings=True)
        else:
            elbo = infer.Trace_ELBO(ignore_jit_warnings=True)
        optimizer_cls = getattr(optim, optim_name)
        optimizer = optimizer_cls({"lr": 1e-6})
        inference = infer.SVI(model, guide, optimizer, elbo)
        for _ in range(2):
            inference.step(data)
Ejemplo n.º 7
0
def build_svi(model, guide, elbo):
    """
    Return a new SVI object for ``model``/``guide`` using ``elbo``.

    Side effect: the global parameter store is cleared before construction.
    The optimizer is Adam with learning rate 1e-6, presumably so steps
    exercise the machinery rather than move parameters meaningfully.
    """
    pyro.get_param_store().clear()
    svi = infer.SVI(model, guide, optim.Adam({"lr": 1e-6}), elbo)
    return svi