コード例 #1
0
ファイル: test_autoguide.py プロジェクト: jamestwebber/pyro
def test_guide_list(auto_class):
    """Combine two ``auto_class`` guides in an AutoGuideList and call it once.

    The first sub-guide sees only site ``"x"`` and the second only site
    ``"y"``; the test passes if invoking the combined guide does not raise.
    """
    def model():
        x_prior = dist.Normal(0., 1.).expand([2])
        pyro.sample("x", x_prior)
        y_prior = dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5))
        pyro.sample("y", y_prior)

    guide = AutoGuideList(model)
    # One sub-guide per latent site, each restricted to that site via block.
    for site, site_prefix in (("x", "auto_x"), ("y", "auto_y")):
        restricted = poutine.block(model, expose=[site])
        guide.add(auto_class(restricted, prefix=site_prefix))
    guide()
コード例 #2
0
ファイル: test_autoguide.py プロジェクト: jamestwebber/pyro
def test_callable(auto_class):
    """Mix a hand-written guide for ``"x"`` with an ``auto_class`` guide for ``"y"``.

    The plain callable ``guide_x`` and the automatic guide (restricted to
    ``"y"``) are added to one AutoGuideList.  The test asserts that the
    combined guide's return value has exactly the key ``"y"``.
    """
    def model():
        pyro.sample("x", dist.Normal(0., 1.))
        y_prior = dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5))
        pyro.sample("y", y_prior)

    def guide_x():
        loc = pyro.param("x_loc", torch.tensor(0.))
        pyro.sample("x", dist.Delta(loc))

    guide = AutoGuideList(model)
    guide.add(guide_x)
    auto_y = auto_class(poutine.block(model, expose=["y"]), prefix="auto_y")
    guide.add(auto_y)
    assert set(guide()) == {"y"}
コード例 #3
0
ファイル: test_autoguide.py プロジェクト: jamestwebber/pyro
def auto_guide_callable(model):
    """Build an AutoGuideList pairing a custom guide for ``"x"`` with an autoguide.

    Site ``"x"`` is handled by an ``AutoCallable`` that wraps a Normal guide
    with learnable ``"x_loc"`` / ``"x_scale"`` params (scale constrained
    positive) plus a median function.  All other sites are handled by
    ``AutoDiagonalNormal`` applied to the model with ``"x"`` hidden.

    :param model: the model callable to guide.
    :returns: the assembled ``AutoGuideList``.
    """
    def guide_x():
        loc = pyro.param("x_loc", torch.tensor(1.))
        scale = pyro.param("x_scale", torch.tensor(.1),
                           constraint=constraints.positive)
        pyro.sample("x", dist.Normal(loc, scale))

    def median_x():
        # The median of a Normal distribution is its location parameter.
        x_median = pyro.param("x_loc", torch.tensor(1.))
        return {"x": x_median}

    combined = AutoGuideList(model)
    combined.add(AutoCallable(model, guide_x, median_x))
    combined.add(AutoDiagonalNormal(poutine.block(model, hide=["x"])))
    return combined
コード例 #4
0
def test_subsample_guide(auto_class, init_fn):
    """Fit ``auto_class`` with SVI on a model whose ``"data"`` plate is subsampled.

    The model (from tutorial/source/easyguide.ipynb) is first run forward with
    no observations to simulate a ``(num_time_steps, full_size)`` data tensor.
    It is then trained for two epochs in mini-batches of at most
    ``batch_size`` data points.  When ``auto_class`` is ``AutoGuideList``,
    ``"drift"`` is guided by ``AutoDelta`` and every other site by
    ``AutoNormal``.

    NOTE(review): ``init_fn`` is accepted but not used in this body.
    """

    def model(batch, subsample, full_size):
        num_time_steps = len(batch)
        result = [None] * num_time_steps
        drift = pyro.sample("drift", dist.LogNormal(-1, 0.5))
        plate = pyro.plate("data", full_size, subsample=subsample)
        # The plate's size is the full dataset size even when only a
        # subsample is passed.  Compare against the argument rather than a
        # hard-coded constant so the check tracks the test's data size.
        assert plate.size == full_size
        with plate:
            z = 0.
            for t in range(num_time_steps):
                z = pyro.sample("state_{}".format(t), dist.Normal(z, drift))
                result[t] = pyro.sample("obs_{}".format(t),
                                        dist.Bernoulli(logits=z),
                                        obs=batch[t])

        return torch.stack(result)

    def create_plates(batch, subsample, full_size):
        # Mirrors the model's "data" plate so the guide subsamples the same
        # indices (presumably invoked by the autoguide with the model's args).
        return pyro.plate("data", full_size, subsample=subsample)

    if auto_class == AutoGuideList:
        guide = AutoGuideList(model, create_plates=create_plates)
        guide.add(AutoDelta(poutine.block(model, expose=["drift"])))
        guide.add(AutoNormal(poutine.block(model, hide=["drift"])))
    else:
        guide = auto_class(model, create_plates=create_plates)

    full_size = 50
    batch_size = 20
    num_time_steps = 8
    pyro.set_rng_seed(123456789)
    # Passing None for every batch entry leaves obs=None, so the model
    # samples the observations itself and returns simulated data.
    data = model([None] * num_time_steps, torch.arange(full_size), full_size)
    assert data.shape == (num_time_steps, full_size)

    pyro.get_param_store().clear()
    pyro.set_rng_seed(123456789)
    svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO())
    for _ in range(2):
        # Consecutive mini-batches [beg, end); the last one may be shorter.
        for beg in range(0, full_size, batch_size):
            end = min(full_size, beg + batch_size)
            subsample = torch.arange(beg, end)
            batch = data[:, beg:end]
            svi.step(batch, subsample, full_size=full_size)
コード例 #5
0
ファイル: test_autoguide.py プロジェクト: jamestwebber/pyro
def test_discrete_parallel(continuous_class):
    """Check that a continuous autoguide plus AutoDiscreteParallel gives a finite loss.

    The model is a two-component Gaussian mixture with a Categorical
    ``"assignment"`` site inside the ``"data"`` plate.  ``continuous_class``
    guides every site except ``"assignment"``, and ``AutoDiscreteParallel``
    guides only ``"assignment"``.  The loss from
    ``TraceEnum_ELBO.loss_and_grads`` must be finite.
    """
    num_components = 2
    data = torch.tensor([0., 1., 10., 11., 12.])

    def model(data):
        mix_prior = dist.Dirichlet(0.5 * torch.ones(num_components))
        weights = pyro.sample('weights', mix_prior)
        locs_prior = dist.Normal(0, 10).expand_by([num_components]).to_event(1)
        locs = pyro.sample('locs', locs_prior)
        scale = pyro.sample('scale', dist.LogNormal(0, 1))

        num_points = len(data)
        with pyro.plate('data', num_points):
            # Repeat the mixture weights once per data point.
            per_point_shape = torch.Size((num_points, )) + weights.shape
            weights = weights.expand(per_point_shape)
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

    guide = AutoGuideList(model)
    continuous_part = continuous_class(
        poutine.block(model, hide=["assignment"]))
    guide.add(continuous_part)
    discrete_part = AutoDiscreteParallel(
        poutine.block(model, expose=["assignment"]))
    guide.add(discrete_part)

    enum_elbo = TraceEnum_ELBO(max_plate_nesting=1)
    loss = enum_elbo.loss_and_grads(model, guide, data)
    assert np.isfinite(loss), loss
コード例 #6
0
ファイル: test_autoguide.py プロジェクト: jamestwebber/pyro
def auto_guide_list_x(model):
    """Build an AutoGuideList with a point estimate for ``"x"``.

    Site ``"x"`` is guided by ``AutoDelta`` (the model with only ``"x"``
    exposed).  All remaining sites are guided by ``AutoDiagonalNormal`` (the
    model with ``"x"`` hidden).

    :param model: the model callable to guide.
    :returns: the assembled ``AutoGuideList``.
    """
    guide = AutoGuideList(model)
    x_point_estimate = AutoDelta(poutine.block(model, expose=["x"]))
    guide.add(x_point_estimate)
    rest_diag_normal = AutoDiagonalNormal(poutine.block(model, hide=["x"]))
    guide.add(rest_diag_normal)
    return guide