예제 #1
0
def main(args):
    """Train the encoder/decoder pair on MNIST with a Monte Carlo ELBO.

    Sets the funsor backend to torch, wraps ``Encoder``/``Decoder`` as funsor
    functions, and runs Adam (lr=1e-3) for ``args.num_epochs`` epochs over
    shuffled MNIST training batches. The current batch loss is printed every
    50 batches, and the summed batch losses are printed at the end of each epoch.

    Args:
        args: parsed options; this function reads ``args.batch_size``,
            ``args.num_epochs`` and ``args.smoke_test``. When ``smoke_test``
            is truthy, training stops at the first logged batch after batch 0.

    Raises:
        TypeError: if evaluating the loss under the Monte Carlo interpretation
            does not produce a concrete ``funsor.Tensor``.
    """
    funsor.set_backend("torch")

    # XXX Temporary fix after https://github.com/pyro-ppl/pyro/pull/2701
    import pyro
    pyro.enable_validation(False)

    encoder = Encoder()
    decoder = Decoder()

    # Encoder: 28x28 input -> (loc, scale), each 20-dim.
    # Decoder: 20-dim latent -> 28x28 Bernoulli probs.
    encode = funsor.function(Reals[28, 28], (Reals[20], Reals[20]))(encoder)
    decode = funsor.function(Reals[20], Reals[28, 28])(decoder)

    @funsor.interpretation(funsor.montecarlo.MonteCarlo())
    def loss_function(data, subsample_scale):
        # Lazily sample from the guide.
        loc, scale = encode(data)
        q = funsor.Independent(dist.Normal(loc['i'], scale['i'], value='z_i'),
                               'z', 'i', 'z_i')

        # Evaluate the model likelihood at the lazy value z.
        probs = decode('z')
        p = dist.Bernoulli(probs['x', 'y'], value=data['x', 'y'])
        p = p.reduce(ops.add, {'x', 'y'})

        # Construct an elbo. This is where sampling happens.
        elbo = funsor.Integrate(q, p - q, 'z')
        # Sum over the minibatch, then rescale by dataset_size / batch_size.
        elbo = elbo.reduce(ops.add, 'batch') * subsample_scale
        return -elbo

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(DATA_PATH, train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=args.batch_size,
        shuffle=True,
    )

    encoder.train()
    decoder.train()
    optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

    for epoch in range(args.num_epochs):
        train_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            # Computed per batch because the final batch may be smaller.
            # int / int is already a float in Python 3.
            subsample_scale = len(train_loader.dataset) / len(data)
            # Keep index 0 of dim 1 (presumably the single MNIST channel -- confirm).
            data = data[:, 0, :, :]
            data = funsor.Tensor(data, OrderedDict(batch=Bint[len(data)]))

            optimizer.zero_grad()
            loss = loss_function(data, subsample_scale)
            # Explicit check instead of ``assert`` so it is not stripped by ``python -O``.
            if not isinstance(loss, funsor.Tensor):
                raise TypeError(loss.pretty())
            loss.data.backward()
            # Read the scalar once; it is reused for the running total and the log line.
            batch_loss = loss.item()
            train_loss += batch_loss
            optimizer.step()
            if batch_idx % 50 == 0:
                print('  loss = {}'.format(batch_loss))
                if batch_idx and args.smoke_test:
                    return
        print('epoch {} train_loss = {}'.format(epoch, train_loss))
예제 #2
0
def test_function_of_numeric_array():
    """A funsor.function around the backend's matmul agrees on raw arrays and Tensors.

    The same wrapped function is called once with the raw arrays returned by
    ``randn`` and once with those arrays wrapped in ``Tensor``; both results
    must be close.
    """
    active_backend = get_backend()
    # Import lazily so only the active backend's package has to be installed.
    if active_backend == "jax":
        import jax

        backend_matmul = jax.numpy.matmul
    elif active_backend == "torch":
        import torch

        backend_matmul = torch.matmul
    else:
        backend_matmul = np.matmul

    lhs = randn((4, 3))
    rhs = randn((3, 2))
    wrapped_matmul = funsor.function(
        reals(4, 3), reals(3, 2), reals(4, 2)
    )(backend_matmul)

    from_raw = wrapped_matmul(lhs, rhs)
    from_funsors = wrapped_matmul(Tensor(lhs), Tensor(rhs))
    assert_close(from_raw, from_funsors)