def model():
    """Three-component mixture model with one shared component index.

    A single component index ``x`` is drawn from fixed mixing weights, and
    the module-level ``data`` is scored under the Normal distribution picked
    by ``x``.  The component means ``locs`` and scales ``scales`` are
    learnable parameters with random initial values.
    """
    # Component means: unconstrained, random initial value.
    component_means = pyro.param(
        "locs", ops.randn(3), constraint=dist.constraints.real
    )
    # Component scales: exp() makes the random initial value positive,
    # matching the positivity constraint.
    component_scales = pyro.param(
        "scales", ops.exp(ops.randn(3)), constraint=dist.constraints.positive
    )
    mixing_weights = ops.tensor([0.5, 0.3, 0.2])
    k = pyro.sample("x", dist.Categorical(mixing_weights))
    # No plate: the single sampled component is used for all of `data`.
    # NOTE(review): `data` is a module-level name not visible here — confirm
    # its shape against the caller.
    likelihood = dist.Normal(component_means[k], component_scales[k])
    pyro.sample("obs", likelihood, obs=data)
def guide():
    """Variational guide for the latent component index ``x``.

    Learns a probability vector ``q`` over three components, constrained to
    the simplex, and samples ``x`` from ``Categorical(q)``.
    """
    # Draw a random positive vector and normalise it, so the initial value
    # already lies on the simplex the constraint declares.  An off-simplex
    # initial value relies on how the backend inverts the simplex transform.
    q_init = ops.exp(ops.randn(3))
    q_init = q_init / q_init.sum()
    q = pyro.param("q", q_init, constraint=dist.constraints.simplex)
    pyro.sample("x", dist.Categorical(q))
def guide():
    """Per-datum guide for the component assignments ``x``.

    Keeps one learnable length-3 probability vector for each element of the
    module-level ``data``.  Each assignment is sampled independently inside
    the ``"plate"`` plate.

    Returns:
        The probability parameter ``p``, initialised with shape
        ``(len(data), 3)``.
    """
    with pyro.plate("plate", len(data), dim=-1):
        # event_dim=1: the trailing axis (the 3 probabilities) is the event;
        # the leading axis is indexed by the plate.  The simplex constraint
        # keeps every row a valid probability vector during optimisation.
        # Without it, gradient steps can drive entries negative.
        p = pyro.param(
            "p",
            ops.ones(len(data), 3) / 3,
            event_dim=1,
            constraint=dist.constraints.simplex,
        )
        pyro.sample("x", dist.Categorical(p))
    return p
def model():
    """Mixture model with one component assignment per observation.

    Each element of the module-level ``data`` gets its own assignment ``x``,
    drawn uniformly from three components.  The component means ``locs``
    are learnable; the observation scale is fixed at 1.
    """
    means = pyro.param("locs", ops.tensor([-1., 0., 1.]))
    uniform_weights = ops.ones(3) / 3
    with pyro.plate("plate", len(data), dim=-1):
        assignment = pyro.sample("x", dist.Categorical(uniform_weights))
        pyro.sample("obs", dist.Normal(means[assignment], 1.), obs=data)
def guide():
    """Guide for ``x`` with one learnable probability vector shared by all data.

    A single probability vector ``p`` over three components is used for
    every position in the ``"plate"`` plate.
    """
    # The simplex constraint keeps `p` non-negative and summing to one while
    # it is optimised.  The initial value already satisfies it.
    p = pyro.param(
        "p", ops.tensor([0.5, 0.3, 0.2]), constraint=dist.constraints.simplex
    )
    with pyro.plate("plate", len(data), dim=-1):
        pyro.sample("x", dist.Categorical(p))
def model():
    """Mixture model with fixed mixing weights and learnable component means.

    Each element of the module-level ``data`` gets its own assignment ``x``,
    drawn from the fixed weights ``[0.2, 0.3, 0.5]``.  Observations are
    scored with unit scale around the selected entry of ``locs``.
    """
    weights = ops.tensor([0.2, 0.3, 0.5])
    # NOTE(review): `locs` starts from the same numbers as the mixing
    # weights — confirm this initial value is intentional.
    component_locs = pyro.param("locs", ops.tensor([0.2, 0.3, 0.5]))
    with pyro.plate("plate", len(data), dim=-1):
        z = pyro.sample("x", dist.Categorical(weights))
        pyro.sample("obs", dist.Normal(component_locs[z], 1.), obs=data)