Exemple #1
0
 def auto_guide(data):
     """Sample site ``a`` once, then site ``c`` inside a size-2 plate.

     Both probability tables are looked up by name with ``pyro.param``; no
     initial values are supplied here. ``data`` is accepted but not read.
     """
     a_probs = pyro.param("guide_probs_a")
     c_table = pyro.param("guide_probs_c")
     a_dist = dist.Categorical(a_probs)
     a = pyro.sample("a", a_dist, infer={"enumerate": "parallel"})
     with pyro.plate("data", 2, dim=-1):
         # The value of ``a`` selects which entry of the table drives ``c``.
         c_dist = dist.Categorical(c_table[a])
         pyro.sample("c", c_dist)
Exemple #2
0
def constrained_model(data):
    """Three-component Normal mixture observed at ``data``.

    Registers ``locs`` (real-constrained) and ``scales`` (positive-constrained)
    with random initial values, draws the component index ``x`` from a fixed
    Categorical, and scores ``data`` under the selected Normal at site ``obs``.
    """
    loc_table = pyro.param("locs", torch.randn(3), constraint=constraints.real)
    scale_init = ops.exp(torch.randn(3))
    scale_table = pyro.param("scales", scale_init,
                             constraint=constraints.positive)
    mixture_probs = torch.tensor([0.5, 0.3, 0.2])
    x = pyro.sample("x", dist.Categorical(mixture_probs))
    likelihood = dist.Normal(loc_table[x], scale_table[x])
    pyro.sample("obs", likelihood, obs=data)
Exemple #3
0
 def hand_guide(data):
     """Sample ``a`` once, then ``c_0`` and ``c_1`` in an explicit Python loop.

     Both probability tables are looked up by name with ``pyro.param``; no
     initial values are supplied here. ``data`` is accepted but not read.
     """
     a_probs = pyro.param("guide_probs_a")
     c_table = pyro.param("guide_probs_c")
     a_dist = dist.Categorical(a_probs)
     a = pyro.sample("a", a_dist, infer={"enumerate": "parallel"})
     for i in range(2):
         site_name = "c_{}".format(i)
         pyro.sample(site_name, dist.Categorical(c_table[a]))
Exemple #4
0
 def guide():
     """Sample ``w``, ``x``, ``y``, ``z`` from one Categorical under varying plates.

     ``w`` is outside any plate, ``x`` is inside ``outer`` only, ``y`` is
     inside ``inner`` only, and ``z`` is inside both. Every site requests
     parallel enumeration.
     """
     q_dist = dist.Categorical(pyro.param("q"))
     outer = pyro.plate("outer", outer_dim, dim=-1)
     inner = pyro.plate("inner", inner_dim, dim=-2)

     def enumerated(site_name):
         # A fresh ``infer`` dict is built for each site.
         return pyro.sample(site_name, q_dist, infer={"enumerate": "parallel"})

     enumerated("w")
     with outer:
         enumerated("x")
     with inner:
         enumerated("y")
     with outer, inner:
         enumerated("z")
Exemple #5
0
 def model():
     """Sample ``w``, ``x``, ``y``, ``z`` from ``Categorical(p)`` under varying plates.

     ``w`` is outside any plate, ``x`` is inside ``outer`` only, ``y`` is
     inside ``inner`` only, and ``z`` is inside both. ``p``, ``outer_dim`` and
     ``inner_dim`` come from the enclosing scope.
     """
     prior = dist.Categorical(p)
     outer = pyro.plate("outer", outer_dim, dim=-1)
     inner = pyro.plate("inner", inner_dim, dim=-2)

     pyro.sample("w", prior)
     with outer:
         pyro.sample("x", prior)
     with inner:
         pyro.sample("y", prior)
     with outer, inner:
         pyro.sample("z", prior)
Exemple #6
0
    def model(data):
        """Linear-Gaussian state-space model with per-channel Bernoulli observations.

        ``data.shape`` is unpacked as ``(T, N, D)``: time steps, individuals,
        features. For each time step a latent ``state_t`` is sampled inside
        the ``data`` plate and ``data[t]`` is observed at ``obs_t`` inside the
        ``channel`` plate. ``exact`` comes from the enclosing scope.
        """
        T, N, D = data.shape

        # Parameters of the initial-state distribution.
        init_loc = pyro.param("init_loc", torch.zeros(D))
        init_scale = pyro.param(
            "init_scale",
            1e-2 * torch.eye(D),
            constraint=constraints.lower_cholesky,
        )

        # Parameters of the linear transition and its noise.
        trans_const = pyro.param("trans_const", torch.zeros(D))
        trans_coeff = pyro.param("trans_coeff", torch.eye(D))
        noise = pyro.param(
            "noise",
            1e-2 * torch.eye(D),
            constraint=constraints.lower_cholesky,
        )

        channel_plate = pyro.plate("channel", D, dim=-1)
        with pyro.plate("data", N, dim=-2):
            state = None
            for t in range(T):
                if t > 0:
                    # Later steps: affine function of the previous state.
                    loc = trans_const + funsor.torch.torch_tensordot(
                        trans_coeff, state, 1)
                    scale_tril = noise
                else:
                    # First step: use the initial distribution.
                    loc = init_loc
                    scale_tril = init_scale
                state_dist = dist.MultivariateNormal(loc, scale_tril)
                state = pyro.sample("state_{}".format(t),
                                    state_dist,
                                    infer={"exact": exact})

                # The state, indexed by "channel", supplies the Bernoulli
                # logits for this step's observation.
                with channel_plate:
                    obs_dist = dist.Bernoulli(logits=state["channel"])
                    pyro.sample("obs_{}".format(t), obs_dist, obs=data[t])
Exemple #7
0
 def auto_model():
     """Two independent plated sub-models: ``a -> b`` and ``c -> d``.

     Inside plate ``a_axis`` (size 2), ``a`` is sampled with parallel
     enumeration and ``b_data`` is observed at ``b``. Inside plate ``c_axis``
     (size 3), ``c`` is sampled the same way and ``d_data`` is observed at
     ``d``. ``b_data`` and ``d_data`` come from the enclosing scope.
     """
     a_probs = pyro.param("probs_a")
     b_table = pyro.param("probs_b")
     c_probs = pyro.param("probs_c")
     d_table = pyro.param("probs_d")

     with pyro.plate("a_axis", 2, dim=-1):
         a_dist = dist.Categorical(a_probs)
         a = pyro.sample("a", a_dist, infer={"enumerate": "parallel"})
         pyro.sample("b", dist.Categorical(b_table[a]), obs=b_data)

     with pyro.plate("c_axis", 3, dim=-1):
         c_dist = dist.Categorical(c_probs)
         c = pyro.sample("c", c_dist, infer={"enumerate": "parallel"})
         pyro.sample("d", dist.Categorical(d_table[c]), obs=d_data)
Exemple #8
0
 def auto_model(data):
     """Discrete model ``a -> b`` with plated ``c``, ``d`` and observation ``obs``.

     ``a`` and ``c`` are sampled without an ``infer`` setting; ``b`` and ``d``
     request parallel enumeration. ``d`` depends on both ``b`` and ``c``
     through ``Vindex``; ``data`` is observed at ``obs`` inside a size-2 plate.
     """
     a_probs = pyro.param("model_probs_a")
     b_table = pyro.param("model_probs_b")
     c_table = pyro.param("model_probs_c")
     d_table = pyro.param("model_probs_d")
     e_table = pyro.param("model_probs_e")

     a = pyro.sample("a", dist.Categorical(a_probs))
     b_dist = dist.Categorical(b_table[a])
     b = pyro.sample("b", b_dist, infer={"enumerate": "parallel"})

     with pyro.plate("data", 2, dim=-1):
         c = pyro.sample("c", dist.Categorical(c_table[a]))
         d_dist = dist.Categorical(Vindex(d_table)[b, c])
         d = pyro.sample("d", d_dist, infer={"enumerate": "parallel"})
         pyro.sample("obs", dist.Categorical(e_table[d]), obs=data)
Exemple #9
0
 def hand_model(data):
     """Discrete model ``a -> b`` with ``c_i``, ``d_i``, ``obs_i`` unrolled for i in 0..1.

     ``a`` and each ``c_i`` are sampled without an ``infer`` setting; ``b``
     and each ``d_i`` request parallel enumeration. ``data[i]`` is observed
     at ``obs_i``.
     """
     a_probs = pyro.param("model_probs_a")
     b_table = pyro.param("model_probs_b")
     c_table = pyro.param("model_probs_c")
     d_table = pyro.param("model_probs_d")
     e_table = pyro.param("model_probs_e")

     a = pyro.sample("a", dist.Categorical(a_probs))
     b_dist = dist.Categorical(b_table[a])
     b = pyro.sample("b", b_dist, infer={"enumerate": "parallel"})

     for i in range(2):
         c = pyro.sample("c_{}".format(i), dist.Categorical(c_table[a]))
         d_dist = dist.Categorical(Vindex(d_table)[b, c])
         d = pyro.sample("d_{}".format(i), d_dist,
                         infer={"enumerate": "parallel"})
         obs_dist = dist.Categorical(e_table[d])
         pyro.sample("obs_{}".format(i), obs_dist, obs=data[i])
Exemple #10
0
 def hand_model():
     """Two unrolled sub-models: ``a_i -> b_i`` (i in 0..1) and ``c_j -> d_j`` (j in 0..2).

     Each ``a_i`` / ``c_j`` is sampled with parallel enumeration; ``b_data[i]``
     and ``d_data[j]`` are observed at ``b_i`` and ``d_j``. ``b_data`` and
     ``d_data`` come from the enclosing scope.
     """
     a_probs = pyro.param("probs_a")
     b_table = pyro.param("probs_b")
     c_probs = pyro.param("probs_c")
     d_table = pyro.param("probs_d")

     for i in range(2):
         a_dist = dist.Categorical(a_probs)
         a = pyro.sample("a_{}".format(i), a_dist,
                         infer={"enumerate": "parallel"})
         b_dist = dist.Categorical(b_table[a])
         pyro.sample("b_{}".format(i), b_dist, obs=b_data[i])

     for j in range(3):
         c_dist = dist.Categorical(c_probs)
         c = pyro.sample("c_{}".format(j), c_dist,
                         infer={"enumerate": "parallel"})
         d_dist = dist.Categorical(d_table[c])
         pyro.sample("d_{}".format(j), d_dist, obs=d_data[j])
Exemple #11
0
 def guide():
     """Sample ``x`` from a per-datum Categorical and return its probabilities.

     The parameter ``p`` is registered inside the plate with ``event_dim=1``
     and a uniform initial value of shape ``(len(data), 3)``. ``data`` comes
     from the enclosing scope.

     Returns:
         The value returned by ``pyro.param`` for ``p``.
     """
     with pyro.plate("plate", len(data), dim=-1):
         uniform_init = torch.ones(len(data), 3) / 3
         p = pyro.param("p", uniform_init, event_dim=1)
         pyro.sample("x", dist.Categorical(p))
     return p
Exemple #12
0
 def model():
     """Three-component unit-scale Normal mixture over ``data``.

     Inside a plate of size ``len(data)``, the component index ``x`` is drawn
     from a uniform Categorical and ``data`` is observed at ``obs`` with the
     selected entry of the ``locs`` parameter as mean. ``data`` comes from the
     enclosing scope.
     """
     component_locs = pyro.param("locs", torch.tensor([-1.0, 0.0, 1.0]))
     with pyro.plate("plate", len(data), dim=-1):
         uniform_probs = torch.ones(3) / 3
         x = pyro.sample("x", dist.Categorical(uniform_probs))
         likelihood = dist.Normal(component_locs[x], 1.0)
         pyro.sample("obs", likelihood, obs=data)
Exemple #13
0
 def guide():
     """Sample ``x`` from ``Normal(loc, scale)`` inside ``plate_outer``.

     ``loc`` and ``scale`` are parameters initialized to 0.0 and 1.0. The
     plate size is ``data.size(-1)``; ``data`` comes from the enclosing scope.
     """
     mean = pyro.param("loc", torch.tensor(0.0))
     stddev = pyro.param("scale", torch.tensor(1.0))
     posterior = dist.Normal(mean, stddev)
     with pyro.plate("plate_outer", data.size(-1), dim=-1):
         pyro.sample("x", posterior)
Exemple #14
0
 def model():
     """Nested-plate Normal model: ``x`` in the outer plate, ``y`` in both plates.

     ``x`` has prior ``Normal(3.0, 1.0)``; ``data`` is observed at ``y`` under
     ``Normal(x, 1.0)``. Plate sizes are ``data.size(-1)`` (outer) and
     ``data.size(-2)`` (inner); ``data`` comes from the enclosing scope.
     """
     prior_mean = torch.tensor(3.0)
     outer = pyro.plate("plate_outer", data.size(-1), dim=-1)
     with outer:
         x = pyro.sample("x", dist.Normal(prior_mean, 1.0))
         inner = pyro.plate("plate_inner", data.size(-2), dim=-2)
         with inner:
             pyro.sample("y", dist.Normal(x, 1.0), obs=data)
Exemple #15
0
 def guide():
     """Sample ``x`` from ``Categorical(p)`` inside a plate of size ``len(data)``.

     ``p`` is a parameter initialized to ``[0.5, 0.3, 0.2]``; ``data`` comes
     from the enclosing scope.
     """
     class_probs = pyro.param("p", torch.tensor([0.5, 0.3, 0.2]))
     posterior = dist.Categorical(class_probs)
     with pyro.plate("plate", len(data), dim=-1):
         pyro.sample("x", posterior)
Exemple #16
0
 def model():
     """Three-component unit-scale Normal mixture with fixed mixing weights.

     Inside a plate of size ``len(data)``, ``x`` is drawn from
     ``Categorical([0.2, 0.3, 0.5])`` and ``data`` is observed at ``obs`` with
     the selected entry of the ``locs`` parameter as mean. ``data`` comes from
     the enclosing scope.
     """
     component_locs = pyro.param("locs", torch.tensor([0.2, 0.3, 0.5]))
     mixing_probs = torch.tensor([0.2, 0.3, 0.5])
     with pyro.plate("plate", len(data), dim=-1):
         x = pyro.sample("x", dist.Categorical(mixing_probs))
         likelihood = dist.Normal(component_locs[x], 1.0)
         pyro.sample("obs", likelihood, obs=data)
Exemple #17
0
 def model(data):
     """Observe ``data`` at site ``x`` under ``Bernoulli(p)``.

     ``p`` is a parameter initialized to 0.5.
     """
     success_prob = pyro.param("p", torch.tensor(0.5))
     likelihood = dist.Bernoulli(success_prob)
     pyro.sample("x", likelihood, obs=data)
Exemple #18
0
 def model(data):
     """Normal model with a latent mean and unit observation scale.

     ``loc`` has prior ``Normal(0, 1)``; ``data`` is observed at ``obs`` under
     ``Normal(loc, 1)`` inside a plate of size ``len(data)``.
     """
     latent_mean = pyro.sample("loc", dist.Normal(0., 1.))
     likelihood = dist.Normal(latent_mean, 1.)
     with pyro.plate("data", len(data), dim=-1):
         pyro.sample("obs", likelihood, obs=data)
Exemple #19
0
 def model(data=None):
     """Sample (or observe) ``x`` from ``Normal(loc, scale)`` and return it.

     ``loc`` and ``scale`` are parameters initialized to 2.0 and 1.0. ``data``
     is forwarded as ``obs``.

     Returns:
         The value returned by ``pyro.sample`` for site ``x``.
     """
     mean = pyro.param("loc", torch.tensor(2.0))
     stddev = pyro.param("scale", torch.tensor(1.0))
     return pyro.sample("x", dist.Normal(mean, stddev), obs=data)
Exemple #20
0
 def model():
     """Two-site Normal chain: ``x ~ Normal(0, 1)`` then ``y ~ Normal(x, 1)``."""
     x_prior = dist.Normal(0., 1.)
     x = pyro.sample("x", x_prior)
     y_given_x = dist.Normal(x, 1.)
     pyro.sample("y", y_given_x)
Exemple #21
0
 def guide():
     """Sample ``x`` from ``Categorical(q)`` with a simplex-constrained ``q``.

     ``q`` is initialized from three standard-normal draws pushed through a
     softmax, so the initial value is positive and sums to one.
     """
     # The softmax equals ``exp(z) / exp(z).sum()``. Bare ``exp(z)`` does not
     # lie on the simplex, so it is not a valid value for this constraint.
     q_init = torch.randn(3).softmax(dim=-1)
     q = pyro.param("q", q_init, constraint=constraints.simplex)
     pyro.sample("x", dist.Categorical(q))
Exemple #22
0
 def model(data):
     """Observe ``data`` at site ``x`` under ``Normal(loc, 1.0)``.

     ``loc`` is a parameter initialized to 0.0.
     """
     mean = pyro.param("loc", torch.tensor(0.0))
     likelihood = dist.Normal(mean, 1.0)
     pyro.sample("x", likelihood, obs=data)
Exemple #23
0
def guide_constrained_model(data):
    """Sample ``x`` from ``Categorical(q)`` with a simplex-constrained ``q``.

    ``q`` is initialized from three exponentiated standard-normal draws,
    normalized to sum to one. ``data`` is accepted but not read.
    """
    weights = ops.exp(torch.randn(3))
    # Normalize so the initial value lies on the simplex. The raw positive
    # weights are not a valid value for this constraint.
    q_init = weights / weights.sum()
    q = pyro.param("q", q_init, constraint=constraints.simplex)
    pyro.sample("x", dist.Categorical(q))
Exemple #24
0
 def guide():
     """Two-site Normal chain: ``y ~ Normal(loc, 1)`` then ``x ~ Normal(y, 1)``.

     ``loc`` is a parameter initialized to 0.
     """
     mean = pyro.param("loc", torch.tensor(0.))
     y_dist = dist.Normal(mean, 1.)
     y = pyro.sample("y", y_dist)
     x_given_y = dist.Normal(y, 1.)
     pyro.sample("x", x_given_y)
Exemple #25
0
 def model(data=None):
     """Sample (or observe) ``x`` from ``Normal(loc, scale)`` in a size-1000 plate.

     ``loc`` is a parameter initialized to ``expected_mean`` (from the
     enclosing scope) and ``scale`` to 1.0. ``data`` is forwarded as ``obs``.

     Returns:
         The value returned by ``pyro.sample`` for site ``x``.
     """
     mean = pyro.param("loc", torch.tensor(expected_mean))
     stddev = pyro.param("scale", torch.tensor(1.0))
     likelihood = dist.Normal(mean, stddev)
     with pyro.plate("data", 1000, dim=-1):
         x = pyro.sample("x", likelihood, obs=data)
     return x
Exemple #26
0
 def guide(data):
     """Sample ``loc`` from ``Normal(guide_loc, guide_scale)``.

     ``guide_loc`` is initialized to 0 and ``guide_scale`` to 1 under a
     positive constraint. ``data`` is accepted but not read.
     """
     mean = pyro.param("guide_loc", torch.tensor(0.))
     stddev = pyro.param(
         "guide_scale",
         torch.tensor(1.),
         constraint=constraints.positive,
     )
     posterior = dist.Normal(mean, stddev)
     pyro.sample("loc", posterior)
Exemple #27
0
 def guide(data):
     """Sample ``loc`` from a Normal whose scale is ``exp(guide_scale_log)``.

     ``guide_loc`` and ``guide_scale_log`` are both initialized to 0.
     ``data`` is accepted but not read.
     """
     mean = pyro.param("guide_loc", torch.tensor(0.))
     log_stddev = pyro.param("guide_scale_log", torch.tensor(0.))
     # Exponentiate the stored log-scale to obtain the Normal's scale.
     stddev = ops.exp(log_stddev)
     pyro.sample("loc", dist.Normal(mean, stddev))
 def guide(data):
     """Sample ``loc`` from a Normal whose scale is ``exp(guide_scale_log)``.

     ``guide_loc`` and ``guide_scale_log`` are both initialized to 0.
     ``data`` is accepted but not read.
     """
     guide_loc = pyro.param("guide_loc", torch.tensor(0.))
     # The stored quantity is a log-scale, so it is left unconstrained: the
     # initial value 0 is not a positive number, and the ``.exp()`` below
     # already makes the resulting scale positive.
     guide_scale_log = pyro.param("guide_scale_log", torch.tensor(0.))
     guide_scale = guide_scale_log.exp()
     pyro.sample("loc", dist.Normal(guide_loc, guide_scale))