# Imports assumed by the snippets below.
import numpy as np
import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import TraceEnum_ELBO, autoguide
from pyro.infer.autoguide import (
    AutoCallable,
    AutoDelta,
    AutoDiagonalNormal,
    AutoDiscreteParallel,
    AutoGuideList,
    AutoIAFNormal,
    AutoMultivariateNormal,
    AutoNormal,
)


def nested_auto_guide_callable(model):
    guide = AutoGuideList(model)
    guide.append(AutoDelta(poutine.block(model, expose=['x'])))
    # The inner AutoGuideList is itself composed of parts; here one part is
    # registered via attribute assignment (guide_y.z) rather than append.
    guide_y = AutoGuideList(poutine.block(model, expose=['y']))
    guide_y.z = AutoIAFNormal(poutine.block(model, expose=['y']))
    guide.append(guide_y)
    return guide
def test_guide_list(auto_class):
    def model():
        pyro.sample("x", dist.Normal(0., 1.).expand([2]))
        pyro.sample("y", dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5)))

    guide = AutoGuideList(model)
    guide.append(auto_class(poutine.block(model, expose=["x"])))
    guide.append(auto_class(poutine.block(model, expose=["y"])))
    guide()
def auto_guide_callable(model):
    def guide_x():
        x_loc = pyro.param("x_loc", torch.tensor(1.))
        x_scale = pyro.param("x_scale", torch.tensor(.1), constraint=constraints.positive)
        pyro.sample("x", dist.Normal(x_loc, x_scale))

    def median_x():
        return {"x": pyro.param("x_loc", torch.tensor(1.))}

    guide = AutoGuideList(model)
    guide.append(AutoCallable(model, guide_x, median_x))
    guide.append(AutoDiagonalNormal(poutine.block(model, hide=["x"])))
    return guide
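# A minimal usage sketch, assumed rather than taken from the original: the
# model and helper below are hypothetical. It illustrates how the AutoCallable
# part above contributes "x" through median_x when the combined guide's
# median() is queried, while the AutoDiagonalNormal part covers the rest.
def _example_median_model():
    pyro.sample("x", dist.Normal(0., 1.))
    pyro.sample("z", dist.Normal(0., 1.))


def _example_median_usage():
    pyro.clear_param_store()
    guide = auto_guide_callable(_example_median_model)
    guide()  # run once so all parts are initialized
    medians = guide.median()
    assert set(medians) == {"x", "z"}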
def AutoMixed(model_full, init_loc={}, delta=None):
    guide = AutoGuideList(model_full)

    # Every latent site except 'tau' is handled by the marginalised block.
    marginalised_guide_block = poutine.block(
        model_full, expose_all=True, hide_all=False, hide=['tau'])
    if delta is None:
        guide.append(AutoNormal(
            marginalised_guide_block,
            init_loc_fn=autoguide.init_to_value(values=init_loc),
            init_scale=0.05))
    elif delta == 'part' or delta == 'all':
        guide.append(AutoDelta(
            marginalised_guide_block,
            init_loc_fn=autoguide.init_to_value(values=init_loc)))

    # 'tau' alone gets a full-rank (or delta) guide.
    full_rank_guide_block = poutine.block(model_full, hide_all=True, expose=['tau'])
    if delta is None or delta == 'part':
        guide.append(AutoMultivariateNormal(
            full_rank_guide_block,
            init_loc_fn=autoguide.init_to_value(values=init_loc),
            init_scale=0.05))
    else:
        guide.append(AutoDelta(
            full_rank_guide_block,
            init_loc_fn=autoguide.init_to_value(values=init_loc)))
    return guide
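# A hypothetical usage sketch of AutoMixed; the model, data, and optimizer
# settings below are illustrative assumptions, not part of the original code.
# 'tau' falls under the full-rank (or delta) block, every other latent site
# under the marginalised block.
def _example_mixed_model(data):
    tau = pyro.sample("tau", dist.LogNormal(0., 1.))
    mu = pyro.sample("mu", dist.Normal(0., 10.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, tau), obs=data)


def _example_mixed_usage():
    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import Adam

    pyro.clear_param_store()
    data = torch.randn(20) + 3.
    guide = AutoMixed(_example_mixed_model,
                      init_loc={"tau": torch.tensor(1.), "mu": torch.tensor(0.)},
                      delta="part")
    svi = SVI(_example_mixed_model, guide, Adam({"lr": 0.01}), Trace_ELBO())
    for _ in range(10):
        svi.step(data)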
def test_callable(auto_class):
    def model():
        pyro.sample("x", dist.Normal(0., 1.))
        pyro.sample("y", dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5)))

    def guide_x():
        x_loc = pyro.param("x_loc", torch.tensor(0.))
        pyro.sample("x", dist.Delta(x_loc))

    guide = AutoGuideList(model)
    guide.append(guide_x)
    guide.append(auto_class(poutine.block(model, expose=["y"])))
    values = guide()
    assert set(values) == set(["y"])
def test_discrete_parallel(continuous_class):
    K = 2
    data = torch.tensor([0., 1., 10., 11., 12.])

    def model(data):
        weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
        locs = pyro.sample('locs', dist.Normal(0, 10).expand_by([K]).to_event(1))
        scale = pyro.sample('scale', dist.LogNormal(0, 1))
        with pyro.plate('data', len(data)):
            weights = weights.expand(torch.Size((len(data),)) + weights.shape)
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

    guide = AutoGuideList(model)
    guide.append(continuous_class(poutine.block(model, hide=["assignment"])))
    guide.append(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))

    elbo = TraceEnum_ELBO(max_plate_nesting=1)
    loss = elbo.loss_and_grads(model, guide, data)
    assert np.isfinite(loss), loss
def auto_guide_list_x(model):
    guide = AutoGuideList(model)
    guide.append(AutoDelta(poutine.block(model, expose=["x"])))
    guide.append(AutoDiagonalNormal(poutine.block(model, hide=["x"])))
    return guide
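# A minimal end-to-end sketch, assumed rather than original: the toy model and
# SVI settings below are illustrative. It exercises auto_guide_list_x, which
# pairs an AutoDelta over "x" with an AutoDiagonalNormal over the remaining
# latent sites.
def _example_guide_list_x_usage():
    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import Adam

    def model():
        x = pyro.sample("x", dist.Normal(0., 1.))
        z = pyro.sample("z", dist.Normal(0., 1.))
        pyro.sample("obs", dist.Normal(x + z, 1.), obs=torch.tensor(0.5))

    pyro.clear_param_store()
    guide = auto_guide_list_x(model)
    svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
    for _ in range(10):
        svi.step()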