def test_guide_list(auto_class):
    """Smoke-test an AutoGuideList assembled from two ``auto_class`` sub-guides.

    Each sub-guide is given a view of the model that exposes exactly one
    sample site, and its own parameter prefix (``auto_x`` / ``auto_y``).
    The combined guide is then run once; the test passes if that call
    does not raise.
    """
    def model():
        # Univariate Normal expanded to a batch of 2.
        pyro.sample("x", dist.Normal(0., 1.).expand([2]))
        # Five-dimensional multivariate Normal with identity covariance.
        mvn_loc, mvn_cov = torch.zeros(5), torch.eye(5, 5)
        pyro.sample("y", dist.MultivariateNormal(mvn_loc, mvn_cov))

    guide = AutoGuideList(model)
    for site_name in ("x", "y"):
        single_site_model = poutine.block(model, expose=[site_name])
        guide.add(auto_class(single_site_model, prefix="auto_" + site_name))
    guide()
def test_callable(auto_class):
    """Check that a plain callable can be mixed into an AutoGuideList.

    Site ``"x"`` is handled by a hand-written guide function and site ``"y"``
    by an ``auto_class`` guide. The test asserts that the combined guide's
    return value contains only the key ``"y"``.
    """
    def model():
        pyro.sample("x", dist.Normal(0., 1.))
        pyro.sample("y", dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5)))

    def hand_written_x_guide():
        # Point-mass guide for "x" with a single learnable location.
        loc = pyro.param("x_loc", torch.tensor(0.))
        pyro.sample("x", dist.Delta(loc))

    y_only_model = poutine.block(model, expose=["y"])

    guide = AutoGuideList(model)
    guide.add(hand_written_x_guide)
    guide.add(auto_class(y_only_model, prefix="auto_y"))

    returned_sites = set(guide())
    assert returned_sites == {"y"}
def auto_guide_callable(model):
    """Build an AutoGuideList that mixes a hand-written guide with an autoguide.

    Site ``"x"`` is covered by an ``AutoCallable`` wrapping a Normal guide
    together with a matching median function. All remaining sites of
    ``model`` are covered by an ``AutoDiagonalNormal``.

    :param model: the Pyro model to guide; expected to contain a site ``"x"``.
    :returns: the assembled ``AutoGuideList``.
    """
    def x_location():
        # Single lookup point so the guide and its median read the same param.
        return pyro.param("x_loc", torch.tensor(1.))

    def guide_x():
        loc = x_location()
        scale = pyro.param("x_scale", torch.tensor(.1), constraint=constraints.positive)
        pyro.sample("x", dist.Normal(loc, scale))

    def median_x():
        return {"x": x_location()}

    combined = AutoGuideList(model)
    combined.add(AutoCallable(model, guide_x, median_x))
    model_without_x = poutine.block(model, hide=["x"])
    combined.add(AutoDiagonalNormal(model_without_x))
    return combined
def test_subsample_guide(auto_class, init_fn):
    """Train a guide with ``create_plates`` over subsampled mini-batches.

    The model is a latent random walk of length ``num_time_steps`` observed
    through Bernoulli likelihoods, vectorized over a ``"data"`` plate of
    size ``full_size``. Synthetic data is drawn from the model, then SVI is
    run for two passes over the data in mini-batches of ``batch_size``.
    ``create_plates`` lets the guide build the same subsampled plate as the
    model.

    NOTE(review): ``init_fn`` is accepted (presumably via parametrization)
    but is not used anywhere in this test. Confirm whether it should be
    forwarded to the guide constructors.
    """
    # The model from tutorial/source/easyguide.ipynb
    def model(batch, subsample, full_size):
        num_time_steps = len(batch)
        result = [None] * num_time_steps
        drift = pyro.sample("drift", dist.LogNormal(-1, 0.5))
        plate = pyro.plate("data", full_size, subsample=subsample)
        # The plate must report the full data size, not the subsample size.
        # Compare against the argument rather than a duplicated literal so
        # the check stays correct if the test's full_size changes.
        assert plate.size == full_size
        with plate:
            z = 0.
            for t in range(num_time_steps):
                z = pyro.sample("state_{}".format(t), dist.Normal(z, drift))
                result[t] = pyro.sample("obs_{}".format(t), dist.Bernoulli(logits=z),
                                        obs=batch[t])
        return torch.stack(result)

    def create_plates(batch, subsample, full_size):
        # Must mirror the model's plate so the guide subsamples identically.
        return pyro.plate("data", full_size, subsample=subsample)

    if auto_class == AutoGuideList:
        guide = AutoGuideList(model, create_plates=create_plates)
        guide.add(AutoDelta(poutine.block(model, expose=["drift"])))
        guide.add(AutoNormal(poutine.block(model, hide=["drift"])))
    else:
        guide = auto_class(model, create_plates=create_plates)

    full_size = 50
    batch_size = 20
    num_time_steps = 8

    # Generate synthetic data by running the model with no observations.
    pyro.set_rng_seed(123456789)
    data = model([None] * num_time_steps, torch.arange(full_size), full_size)
    assert data.shape == (num_time_steps, full_size)

    pyro.get_param_store().clear()
    pyro.set_rng_seed(123456789)
    svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO())
    for _ in range(2):
        beg = 0
        while beg < full_size:
            # The last mini-batch may be smaller than batch_size.
            end = min(full_size, beg + batch_size)
            subsample = torch.arange(beg, end)
            batch = data[:, beg:end]
            beg = end
            svi.step(batch, subsample, full_size=full_size)
def test_discrete_parallel(continuous_class):
    """Check that a continuous autoguide and AutoDiscreteParallel compose.

    The model is a small mixture with ``num_components`` components, whose
    per-point ``"assignment"`` site is discrete. The continuous sites go to
    ``continuous_class`` and the assignment site to
    ``AutoDiscreteParallel``. A single enumerated ELBO evaluation must
    produce a finite loss.
    """
    num_components = 2
    data = torch.tensor([0., 1., 10., 11., 12.])

    def model(data):
        num_points = len(data)
        weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(num_components)))
        component_prior = dist.Normal(0, 10).expand_by([num_components]).to_event(1)
        locs = pyro.sample('locs', component_prior)
        scale = pyro.sample('scale', dist.LogNormal(0, 1))
        with pyro.plate('data', num_points):
            # Prepend the data dimension so each point gets its own Categorical.
            weights = weights.expand(torch.Size([num_points]) + weights.shape)
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

    continuous_part = poutine.block(model, hide=["assignment"])
    discrete_part = poutine.block(model, expose=["assignment"])

    guide = AutoGuideList(model)
    guide.add(continuous_class(continuous_part))
    guide.add(AutoDiscreteParallel(discrete_part))

    loss = TraceEnum_ELBO(max_plate_nesting=1).loss_and_grads(model, guide, data)
    assert np.isfinite(loss), loss
def auto_guide_list_x(model):
    """Build an AutoGuideList that splits ``model`` on site ``"x"``.

    Site ``"x"`` is guided by an ``AutoDelta``; every other site is guided
    by an ``AutoDiagonalNormal``.

    :param model: the Pyro model to guide; expected to contain a site ``"x"``.
    :returns: the assembled ``AutoGuideList``.
    """
    only_x = poutine.block(model, expose=["x"])
    all_but_x = poutine.block(model, hide=["x"])

    guide_list = AutoGuideList(model)
    guide_list.add(AutoDelta(only_x))
    guide_list.add(AutoDiagonalNormal(all_but_x))
    return guide_list