def model():
    """Model with one un-plated categorical draw choosing a Normal component.

    Registers two learnable length-3 parameters: "locs" (real-valued) and
    "scales" (positive). Both are randomly initialized. The category
    probabilities are fixed at [0.5, 0.3, 0.2]. The chosen component's
    location and scale parameterize the "obs" site. ``data`` is looked up
    from the enclosing scope.
    """
    component_locs = pyro.param(
        "locs",
        ops.randn(3),
        constraint=dist.constraints.real,
    )
    component_scales = pyro.param(
        "scales",
        ops.exp(ops.randn(3)),  # exp keeps the initial scales strictly positive
        constraint=dist.constraints.positive,
    )
    weights = ops.tensor([0.5, 0.3, 0.2])
    k = pyro.sample("x", dist.Categorical(weights))
    likelihood = dist.Normal(component_locs[k], component_scales[k])
    pyro.sample("obs", likelihood, obs=data)
def test_local_param_ok(backend, jit):
    """Check that a parameter declared inside a plate is accepted by the ELBO.

    The guide registers a per-datum categorical parameter "p" of shape
    ``(len(data), 3)`` inside the plate, with ``event_dim=1``. After the
    validity check, the guide is re-run. The stored "p" is then fetched via
    ``pyro.param("p")`` with no init value and compared to what the guide
    returned.

    Args:
        backend: not referenced in the body; presumably a fixture that
            selects the Pyro backend — TODO confirm.
        jit: when truthy, use ``infer.JitTrace_ELBO``; otherwise
            ``infer.Trace_ELBO``.
    """
    obs_values = ops.randn(10)

    def model():
        locs = pyro.param("locs", ops.tensor([-1., 0., 1.]))
        with pyro.plate("plate", len(obs_values), dim=-1):
            uniform = ops.ones(3) / 3
            assignment = pyro.sample("x", dist.Categorical(uniform))
            pyro.sample("obs", dist.Normal(locs[assignment], 1.), obs=obs_values)

    def guide():
        with pyro.plate("plate", len(obs_values), dim=-1):
            # One row of categorical probabilities per datum (event_dim=1).
            init = ops.ones(len(obs_values), 3) / 3
            local_probs = pyro.param("p", init, event_dim=1)
            pyro.sample("x", dist.Categorical(local_probs))
        return local_probs

    if jit:
        elbo_cls = infer.JitTrace_ELBO
    else:
        elbo_cls = infer.Trace_ELBO
    assert_ok(model, guide, elbo_cls(ignore_jit_warnings=True))

    # pyro.param() must also work as a pure lookup, without an init value.
    returned = guide()
    stored = pyro.param("p")
    assert ops.allclose(stored, returned)
def model(data):
    """Single Normal observation site with a learnable location.

    Registers parameter "loc" (initialized to 0.0) and conditions site "x"
    on ``data`` using a fixed unit scale.
    """
    mean = pyro.param("loc", ops.tensor(0.0))
    likelihood = dist.Normal(mean, 1.)
    pyro.sample("x", likelihood, obs=data)
def model(data):
    """Single Bernoulli observation site with a learnable success probability.

    Registers parameter "p" (initialized to 0.5) and conditions site "x" on
    ``data``.

    NOTE(review): "p" is registered without a unit-interval constraint, so
    nothing here keeps it in [0, 1] during optimization.
    """
    prob = pyro.param("p", ops.tensor(0.5))
    likelihood = dist.Bernoulli(prob)
    pyro.sample("x", likelihood, obs=data)
def model(data=None):
    """Plated Normal model with learnable location and scale.

    Registers parameters "loc" (init 2.0) and "scale" (init 1.0). Site "x" is
    drawn inside a plate named "data" of size 1000 and conditioned on
    ``data``. With the default ``data=None`` the site is unobserved.

    Returns:
        The value of site "x".
    """
    mu = pyro.param("loc", ops.tensor(2.0))
    sigma = pyro.param("scale", ops.tensor(1.0))
    with pyro.plate("data", 1000, dim=-1):
        draws = pyro.sample("x", dist.Normal(mu, sigma), obs=data)
    return draws
def model(data=None):
    """Un-plated Normal model with learnable location and scale.

    Registers parameters "loc" (init 2.0) and "scale" (init 1.0) and
    conditions site "x" on ``data``. With the default ``data=None`` the site
    is unobserved.

    Returns:
        The value of site "x".
    """
    mu = pyro.param("loc", ops.tensor(2.0))
    sigma = pyro.param("scale", ops.tensor(1.0))
    draw = pyro.sample("x", dist.Normal(mu, sigma), obs=data)
    return draw
def guide():
    """Guide with two chained Normal sites.

    Site "x" is centered at the learnable parameter "loc" (init 0.0). Site
    "y" is centered at the sampled value of "x". Both use unit scale.
    """
    loc_q = pyro.param("loc", ops.tensor(0.))
    x_draw = pyro.sample("x", dist.Normal(loc_q, 1.))
    # "y" depends on the sampled "x", not on "loc" directly.
    pyro.sample("y", dist.Normal(x_draw, 1.))
def guide():
    """Guide with a learnable simplex-constrained categorical over 3 states.

    Registers parameter "q" constrained to the probability simplex and
    samples site "x" from ``Categorical(q)``.
    """
    # exp(randn) yields a random strictly positive vector. Dividing by its sum
    # puts it on the simplex, so the initial value satisfies the declared
    # constraint. An unnormalized init lies outside the constraint's support,
    # and mapping it to unconstrained space gives a degenerate initial value.
    q_init = ops.exp(ops.randn(3))
    q_init = q_init / q_init.sum()
    q = pyro.param("q", q_init, constraint=dist.constraints.simplex)
    pyro.sample("x", dist.Categorical(q))
def guide():
    """Guide with a plate-local categorical parameter.

    Inside plate "plate" (one entry per element of ``data``, read from the
    enclosing scope), registers "p" of shape ``(len(data), 3)`` with
    ``event_dim=1``, initialized uniformly. Samples site "x" from
    ``Categorical(p)``.

    Returns:
        The parameter value "p".
    """
    with pyro.plate("plate", len(data), dim=-1):
        uniform_rows = ops.ones(len(data), 3) / 3
        local_probs = pyro.param("p", uniform_rows, event_dim=1)
        pyro.sample("x", dist.Categorical(local_probs))
    return local_probs
def model():
    """Plated model with a per-datum categorical assignment to one of 3 locs.

    Registers parameter "locs" (init [-1, 0, 1]). Inside plate "plate" (one
    entry per element of ``data``, read from the enclosing scope), each datum
    draws a uniform categorical assignment "x". Site "obs" is then a unit-scale
    Normal at the assigned location, conditioned on ``data``.
    """
    centers = pyro.param("locs", ops.tensor([-1., 0., 1.]))
    with pyro.plate("plate", len(data), dim=-1):
        uniform = ops.ones(3) / 3
        assignment = pyro.sample("x", dist.Categorical(uniform))
        pyro.sample("obs", dist.Normal(centers[assignment], 1.), obs=data)
def guide():
    """Guide with a plated Normal site sharing learnable loc/scale.

    Registers "loc" (init 0.0) and "scale" (init 1.0). Samples site "x"
    inside plate "plate_outer", whose size is the last dimension of ``data``
    (read from the enclosing scope).

    NOTE(review): "scale" has no positivity constraint here.
    """
    mu = pyro.param("loc", ops.tensor(0.))
    sigma = pyro.param("scale", ops.tensor(1.))
    outer_size = data.shape[-1]
    with pyro.plate("plate_outer", outer_size, dim=-1):
        pyro.sample("x", dist.Normal(mu, sigma))
def guide():
    """Guide with one global categorical parameter used inside a plate.

    Registers "p" (init [0.5, 0.3, 0.2]) outside the plate, then samples site
    "x" from ``Categorical(p)`` inside plate "plate". The plate has one entry
    per element of ``data``, read from the enclosing scope.
    """
    shared_probs = pyro.param("p", ops.tensor([0.5, 0.3, 0.2]))
    n_items = len(data)
    with pyro.plate("plate", n_items, dim=-1):
        pyro.sample("x", dist.Categorical(shared_probs))
def model():
    """Plated model with a fixed-probability categorical assignment per datum.

    Registers parameter "locs" (init [0.2, 0.3, 0.5]). The category
    probabilities are the fixed tensor [0.2, 0.3, 0.5]. Inside plate "plate"
    (one entry per element of ``data``, read from the enclosing scope), each
    datum draws "x". Site "obs" is then a unit-scale Normal at the selected
    location, conditioned on ``data``.
    """
    centers = pyro.param("locs", ops.tensor([0.2, 0.3, 0.5]))
    weights = ops.tensor([0.2, 0.3, 0.5])
    with pyro.plate("plate", len(data), dim=-1):
        idx = pyro.sample("x", dist.Categorical(weights))
        likelihood = dist.Normal(centers[idx], 1.)
        pyro.sample("obs", likelihood, obs=data)