コード例 #1
0
ファイル: test_svi.py プロジェクト: feynmanliang/pyro-api
 def model():
     locs = pyro.param("locs",
                       ops.randn(3),
                       constraint=dist.constraints.real)
     scales = pyro.param("scales",
                         ops.exp(ops.randn(3)),
                         constraint=dist.constraints.positive)
     p = ops.tensor([0.5, 0.3, 0.2])
     x = pyro.sample("x", dist.Categorical(p))
     pyro.sample("obs", dist.Normal(locs[x], scales[x]), obs=data)
コード例 #2
0
ファイル: test_svi.py プロジェクト: feynmanliang/pyro-api
def test_local_param_ok(backend, jit):
    data = ops.randn(10)

    def model():
        locs = pyro.param("locs", ops.tensor([-1., 0., 1.]))
        with pyro.plate("plate", len(data), dim=-1):
            x = pyro.sample("x", dist.Categorical(ops.ones(3) / 3))
            pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data)

    def guide():
        with pyro.plate("plate", len(data), dim=-1):
            p = pyro.param("p", ops.ones(len(data), 3) / 3, event_dim=1)
            pyro.sample("x", dist.Categorical(p))
        return p

    Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
    elbo = Elbo(ignore_jit_warnings=True)
    assert_ok(model, guide, elbo)

    # Check that pyro.param() can be called without init_value.
    expected = guide()
    actual = pyro.param("p")
    assert ops.allclose(actual, expected)
コード例 #3
0
ファイル: test_svi.py プロジェクト: feynmanliang/pyro-api
 def model(data):
     loc = pyro.param("loc", ops.tensor(0.0))
     pyro.sample("x", dist.Normal(loc, 1.), obs=data)
コード例 #4
0
ファイル: test_svi.py プロジェクト: feynmanliang/pyro-api
 def model(data):
     p = pyro.param("p", ops.tensor(0.5))
     pyro.sample("x", dist.Bernoulli(p), obs=data)
コード例 #5
0
ファイル: test_svi.py プロジェクト: feynmanliang/pyro-api
 def model(data=None):
     loc = pyro.param("loc", ops.tensor(2.0))
     scale = pyro.param("scale", ops.tensor(1.0))
     with pyro.plate("data", 1000, dim=-1):
         x = pyro.sample("x", dist.Normal(loc, scale), obs=data)
     return x
コード例 #6
0
ファイル: test_svi.py プロジェクト: feynmanliang/pyro-api
 def model(data=None):
     loc = pyro.param("loc", ops.tensor(2.0))
     scale = pyro.param("scale", ops.tensor(1.0))
     x = pyro.sample("x", dist.Normal(loc, scale), obs=data)
     return x
コード例 #7
0
ファイル: test_svi.py プロジェクト: feynmanliang/pyro-api
 def guide():
     loc = pyro.param("loc", ops.tensor(0.))
     x = pyro.sample("x", dist.Normal(loc, 1.))
     pyro.sample("y", dist.Normal(x, 1.))
コード例 #8
0
ファイル: test_svi.py プロジェクト: feynmanliang/pyro-api
 def guide():
     q = pyro.param("q",
                    ops.exp(ops.randn(3)),
                    constraint=dist.constraints.simplex)
     pyro.sample("x", dist.Categorical(q))
コード例 #9
0
ファイル: test_svi.py プロジェクト: feynmanliang/pyro-api
 def guide():
     with pyro.plate("plate", len(data), dim=-1):
         p = pyro.param("p", ops.ones(len(data), 3) / 3, event_dim=1)
         pyro.sample("x", dist.Categorical(p))
     return p
コード例 #10
0
ファイル: test_svi.py プロジェクト: feynmanliang/pyro-api
 def model():
     locs = pyro.param("locs", ops.tensor([-1., 0., 1.]))
     with pyro.plate("plate", len(data), dim=-1):
         x = pyro.sample("x", dist.Categorical(ops.ones(3) / 3))
         pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data)
コード例 #11
0
ファイル: test_svi.py プロジェクト: feynmanliang/pyro-api
 def guide():
     loc = pyro.param("loc", ops.tensor(0.))
     scale = pyro.param("scale", ops.tensor(1.))
     with pyro.plate("plate_outer", data.shape[-1], dim=-1):
         pyro.sample("x", dist.Normal(loc, scale))
コード例 #12
0
ファイル: test_svi.py プロジェクト: feynmanliang/pyro-api
 def guide():
     p = pyro.param("p", ops.tensor([0.5, 0.3, 0.2]))
     with pyro.plate("plate", len(data), dim=-1):
         pyro.sample("x", dist.Categorical(p))
コード例 #13
0
ファイル: test_svi.py プロジェクト: feynmanliang/pyro-api
 def model():
     locs = pyro.param("locs", ops.tensor([0.2, 0.3, 0.5]))
     p = ops.tensor([0.2, 0.3, 0.5])
     with pyro.plate("plate", len(data), dim=-1):
         x = pyro.sample("x", dist.Categorical(p))
         pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data)