Ejemplo n.º 1
0
def nested_auto_guide_callable(model):
    """Build an ``AutoGuideList`` that contains another ``AutoGuideList``.

    The outer list is created on ``model`` and receives, in order:

    1. an ``AutoDelta`` over ``model`` blocked so that only site ``'x'`` is
       exposed;
    2. an inner ``AutoGuideList`` created on ``model`` blocked so that only
       site ``'y'`` is exposed.

    The inner list gets an ``AutoIAFNormal`` (again over the ``'y'``-only
    block) assigned to its attribute ``z``; it is set as an attribute, not
    appended.

    :param model: the model callable handed to every guide constructor.
    :return: the outer ``AutoGuideList``.
    """
    outer = AutoGuideList(model)

    x_only = poutine.block(model, expose=['x'])
    outer.append(AutoDelta(x_only))

    inner = AutoGuideList(poutine.block(model, expose=['y']))
    # Deliberately an attribute assignment rather than ``inner.append(...)``.
    iaf_on_y = AutoIAFNormal(poutine.block(model, expose=['y']))
    inner.z = iaf_on_y

    outer.append(inner)
    return outer
Ejemplo n.º 2
0
def test_guide_list(auto_class):
    """Smoke-test an ``AutoGuideList`` made of one ``auto_class`` guide per site.

    A two-site model (``"x"`` and ``"y"``) is defined locally; for each site a
    guide ``auto_class(poutine.block(model, expose=[site]))`` is appended to
    the list, and the combined guide is then called once. The test passes if
    nothing raises.

    :param auto_class: guide class (or factory) called with a blocked model.
    """

    def model():
        x_prior = dist.Normal(0., 1.).expand([2])
        pyro.sample("x", x_prior)
        y_prior = dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5))
        pyro.sample("y", y_prior)

    combined = AutoGuideList(model)
    # Same order as the sites appear in the model: "x" first, then "y".
    for site in ("x", "y"):
        combined.append(auto_class(poutine.block(model, expose=[site])))
    combined()
Ejemplo n.º 3
0
def auto_guide_callable(model):
    """Build an ``AutoGuideList`` mixing a hand-written guide with an autoguide.

    Site ``"x"`` is handled by a hand-written guide wrapped in
    ``AutoCallable`` (together with a companion function that returns the
    ``"x_loc"`` parameter under key ``"x"``). Everything else is delegated to
    an ``AutoDiagonalNormal`` over ``model`` with ``"x"`` hidden.

    :param model: the model callable handed to the guide constructors.
    :return: the assembled ``AutoGuideList``.
    """

    def guide_x():
        # Normal guide for "x" with a learnable location and positive scale.
        loc = pyro.param("x_loc", torch.tensor(1.))
        scale = pyro.param("x_scale", torch.tensor(.1), constraint=constraints.positive)
        pyro.sample("x", dist.Normal(loc, scale))

    def median_x():
        # Reads the same "x_loc" parameter that guide_x uses.
        loc = pyro.param("x_loc", torch.tensor(1.))
        return {"x": loc}

    combined = AutoGuideList(model)

    manual_part = AutoCallable(model, guide_x, median_x)
    combined.append(manual_part)

    everything_but_x = poutine.block(model, hide=["x"])
    combined.append(AutoDiagonalNormal(everything_but_x))

    return combined
Ejemplo n.º 4
0
def AutoMixed(model_full, init_loc=None, delta=None):
    """Build an ``AutoGuideList`` that treats site ``'tau'`` separately.

    The model is split into two blocks with ``poutine.block``: one with
    ``'tau'`` hidden and everything else exposed, and one exposing only
    ``'tau'``. Which guide class is attached to each block depends on
    ``delta``:

    * ``None``   -- ``AutoNormal`` for the non-``'tau'`` block and
      ``AutoMultivariateNormal`` for the ``'tau'`` block (both with
      ``init_scale=0.05``);
    * ``'part'`` -- ``AutoDelta`` for the non-``'tau'`` block and
      ``AutoMultivariateNormal`` for the ``'tau'`` block;
    * ``'all'``  -- ``AutoDelta`` for both blocks.

    Any other value of ``delta`` adds no guide for the non-``'tau'`` block
    and an ``AutoDelta`` for the ``'tau'`` block (this mirrors the original
    branching and is left unchanged).

    :param model_full: the model callable to build guides for.
    :param init_loc: values passed to ``autoguide.init_to_value(values=...)``
        for every guide. Defaults to an empty dict created per call.
    :param delta: ``None``, ``'part'`` or ``'all'``; see above.
    :return: the assembled ``AutoGuideList``.
    """
    # A fresh dict per call: a ``{}`` default would be a single object shared
    # by every invocation of this function.
    if init_loc is None:
        init_loc = {}

    guide = AutoGuideList(model_full)

    # Block 1: every site except 'tau'.
    marginalised_guide_block = poutine.block(model_full,
                                             expose_all=True,
                                             hide_all=False,
                                             hide=['tau'])
    if delta is None:
        guide.append(
            AutoNormal(marginalised_guide_block,
                       init_loc_fn=autoguide.init_to_value(values=init_loc),
                       init_scale=0.05))
    elif delta in ('part', 'all'):
        guide.append(
            AutoDelta(marginalised_guide_block,
                      init_loc_fn=autoguide.init_to_value(values=init_loc)))

    # Block 2: only 'tau'.
    full_rank_guide_block = poutine.block(model_full,
                                          hide_all=True,
                                          expose=['tau'])
    if delta is None or delta == 'part':
        guide.append(
            AutoMultivariateNormal(
                full_rank_guide_block,
                init_loc_fn=autoguide.init_to_value(values=init_loc),
                init_scale=0.05))
    else:
        guide.append(
            AutoDelta(full_rank_guide_block,
                      init_loc_fn=autoguide.init_to_value(values=init_loc)))

    return guide
Ejemplo n.º 5
0
def test_callable(auto_class):
    """Check an ``AutoGuideList`` that mixes a plain callable with an autoguide.

    A hand-written guide for site ``"x"`` is appended directly (not wrapped),
    followed by ``auto_class`` over the model with only ``"y"`` exposed. The
    test asserts that calling the combined guide returns a collection whose
    keys are exactly ``{"y"}``.

    :param auto_class: guide class (or factory) called with a blocked model.
    """

    def model():
        pyro.sample("x", dist.Normal(0., 1.))
        y_prior = dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5))
        pyro.sample("y", y_prior)

    def guide_x():
        # Point-mass guide for "x" at a learnable location.
        loc = pyro.param("x_loc", torch.tensor(0.))
        pyro.sample("x", dist.Delta(loc))

    combined = AutoGuideList(model)
    combined.append(guide_x)
    y_only = poutine.block(model, expose=["y"])
    combined.append(auto_class(y_only))

    returned = combined()
    assert set(returned) == {"y"}
Ejemplo n.º 6
0
def test_discrete_parallel(continuous_class):
    """Check that a continuous autoguide plus ``AutoDiscreteParallel`` gives a finite loss.

    The local model samples ``'weights'``, ``'locs'`` and ``'scale'``, then,
    inside a ``'data'`` plate, samples a categorical ``'assignment'`` per
    datum and observes ``'obs'``. The guide list uses ``continuous_class``
    for everything except ``'assignment'`` and ``AutoDiscreteParallel`` for
    ``'assignment'``. The loss from ``TraceEnum_ELBO.loss_and_grads`` must be
    finite.

    :param continuous_class: guide class (or factory) for the non-discrete
        sites, called with a blocked model.
    """
    n_components = 2
    observations = torch.tensor([0., 1., 10., 11., 12.])

    def model(data):
        n_data = len(data)
        weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(n_components)))
        locs = pyro.sample('locs', dist.Normal(0, 10).expand_by([n_components]).to_event(1))
        scale = pyro.sample('scale', dist.LogNormal(0, 1))

        with pyro.plate('data', n_data):
            # One copy of the weight vector per datum.
            per_datum_weights = weights.expand(torch.Size((n_data,)) + weights.shape)
            assignment = pyro.sample('assignment', dist.Categorical(per_datum_weights))
            pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

    combined = AutoGuideList(model)
    continuous_part = poutine.block(model, hide=["assignment"])
    combined.append(continuous_class(continuous_part))
    discrete_part = poutine.block(model, expose=["assignment"])
    combined.append(AutoDiscreteParallel(discrete_part))

    objective = TraceEnum_ELBO(max_plate_nesting=1)
    loss = objective.loss_and_grads(model, combined, observations)
    assert np.isfinite(loss), loss
Ejemplo n.º 7
0
def auto_guide_list_x(model):
    """Build an ``AutoGuideList`` that splits ``model`` on site ``"x"``.

    Site ``"x"`` gets an ``AutoDelta`` (model blocked to expose only ``"x"``);
    the remaining sites get an ``AutoDiagonalNormal`` (model blocked to hide
    ``"x"``). The two guides are appended in that order.

    :param model: the model callable handed to the guide constructors.
    :return: the assembled ``AutoGuideList``.
    """
    combined = AutoGuideList(model)

    x_only = poutine.block(model, expose=["x"])
    combined.append(AutoDelta(x_only))

    without_x = poutine.block(model, hide=["x"])
    combined.append(AutoDiagonalNormal(without_x))

    return combined