def main(args):
    """Train the funsor VAE on MNIST by minimizing a Monte Carlo ELBO loss.

    Reads ``args.batch_size``, ``args.num_epochs`` and ``args.smoke_test``.
    Every 50 batches it prints the current batch loss. After each epoch it prints
    the summed batch losses. With ``args.smoke_test`` set, it returns as soon as
    a nonzero batch index divisible by 50 is reached.
    """
    funsor.set_backend("torch")

    # XXX Temporary fix after https://github.com/pyro-ppl/pyro/pull/2701
    import pyro
    pyro.enable_validation(False)

    encoder, decoder = Encoder(), Decoder()

    # Wrap the torch modules as funsor functions with declared input/output domains.
    encode = funsor.function(Reals[28, 28], (Reals[20], Reals[20]))(encoder)
    decode = funsor.function(Reals[20], Reals[28, 28])(decoder)

    @funsor.interpretation(funsor.montecarlo.MonteCarlo())
    def loss_function(data, subsample_scale):
        # Guide q: Normal(loc['i'], scale['i']) made Independent over 'i',
        # lazily representing the latent 'z'.
        loc, scale = encode(data)
        guide = funsor.Independent(
            dist.Normal(loc['i'], scale['i'], value='z_i'),
            'z', 'i', 'z_i')

        # Model p: Bernoulli over the data's 'x'/'y' dims evaluated at the lazy
        # value 'z', then summed over 'x' and 'y'.
        probs = decode('z')
        model = dist.Bernoulli(probs['x', 'y'], value=data['x', 'y'])
        model = model.reduce(ops.add, {'x', 'y'})

        # ELBO: integrate (p - q) against q over 'z'; sampling happens here.
        elbo = funsor.Integrate(guide, model - guide, 'z')
        # Sum over the minibatch, rescale by subsample_scale, and negate.
        return -(elbo.reduce(ops.add, 'batch') * subsample_scale)

    mnist = datasets.MNIST(DATA_PATH, train=True, download=True,
                           transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(
        mnist, batch_size=args.batch_size, shuffle=True)
    num_examples = len(train_loader.dataset)

    encoder.train()
    decoder.train()
    params = [*encoder.parameters(), *decoder.parameters()]
    optimizer = optim.Adam(params, lr=1e-3)

    for epoch in range(args.num_epochs):
        epoch_loss = 0
        for step, (images, _) in enumerate(train_loader):
            batch_size = len(images)
            subsample_scale = float(num_examples / batch_size)
            # Keep channel 0 of the 4-D batch and name the leading axis 'batch'.
            images = funsor.Tensor(images[:, 0, :, :],
                                   OrderedDict(batch=Bint[batch_size]))

            optimizer.zero_grad()
            loss = loss_function(images, subsample_scale)
            assert isinstance(loss, funsor.Tensor), loss.pretty()
            loss.data.backward()
            epoch_loss += loss.item()
            optimizer.step()

            if step % 50 != 0:
                continue
            print(' loss = {}'.format(loss.item()))
            if step and args.smoke_test:
                return
        print('epoch {} train_loss = {}'.format(epoch, epoch_loss))
def test_function_of_numeric_array():
    """Check that a ``funsor.function``-wrapped matmul agrees on raw arrays and Tensors."""
    # Choose the matrix-multiply implementation native to the active backend.
    current = get_backend()
    if current == "jax":
        import jax
        native_matmul = jax.numpy.matmul
    elif current == "torch":
        import torch
        native_matmul = torch.matmul
    else:
        native_matmul = np.matmul

    lhs = randn((4, 3))
    rhs = randn((3, 2))
    wrapped = funsor.function(reals(4, 3), reals(3, 2), reals(4, 2))(native_matmul)

    # Calling with raw backend arrays must match calling with funsor Tensors.
    assert_close(wrapped(lhs, rhs), wrapped(Tensor(lhs), Tensor(rhs)))