Example #1
0
 def model(z1=None, z2=None):
     """Mixture-style model with one shared and one per-item discrete latent.

     ``z1`` is sampled once outside any plate and selects the mean of ``x1``;
     ``z2`` is sampled inside the ``"data[1]"`` plate and selects the mean of
     ``x2``. Both arguments are forwarded as ``obs=`` to their sample sites.
     ``data`` is not a parameter; it is looked up in the enclosing scope and
     indexed as ``data[0]`` and ``data[1]``.
     """
     probs = pyro.param("p", torch.tensor([0.25, 0.75]))
     means = pyro.param("loc", torch.tensor([-1.0, 1.0]))

     z1_value = pyro.sample("z1", dist.Categorical(probs), obs=z1)

     with pyro.plate("data[0]", 3):
         x1_dist = dist.Normal(means[z1_value], 1.0)
         pyro.sample("x1", x1_dist, obs=data[0])

     with pyro.plate("data[1]", 2):
         z2_value = pyro.sample("z2", dist.Categorical(probs), obs=z2)
         x2_dist = dist.Normal(means[z2_value], 1.0)
         pyro.sample("x2", x2_dist, obs=data[1])
Example #2
0
 def model(z1=None, z2=None):
     """Chain of two discrete latents, ``z1 -> z2``, each driving one observation.

     ``z1`` is drawn from the first row of the ``"p"`` parameter and ``z2``
     from the row selected by ``z1``. Both arguments are forwarded as ``obs=``
     to their sample sites. The shapes of the two values are logged at INFO
     level. ``data`` is looked up in the enclosing scope.
     """
     transition_probs = pyro.param(
         "p", torch.tensor([[0.25, 0.75], [0.1, 0.9]]))
     means = pyro.param("loc", torch.tensor([-1.0, 1.0]))

     z1_value = pyro.sample(
         "z1", dist.Categorical(transition_probs[0]), obs=z1)
     z2_value = pyro.sample(
         "z2", dist.Categorical(transition_probs[z1_value]), obs=z2)

     logger.info("z1.shape = {}".format(z1_value.shape))
     logger.info("z2.shape = {}".format(z2_value.shape))

     with pyro.plate("data", 3):
         x1_dist = dist.Normal(means[z1_value], 1.0)
         pyro.sample("x1", x1_dist, obs=data[0])
         x2_dist = dist.Normal(means[z2_value], 1.0)
         pyro.sample("x2", x2_dist, obs=data[1])
Example #3
0
def model2():
    """Two-plate model whose observation scale depends on the latent ``z1``.

    The ``"scale"`` site is a Normal draw whose mean is picked by ``z1``;
    the draw is exponentiated before being used as the standard deviation
    of both observation sites.
    """
    data = [torch.tensor([-1.0, -1.0, 0.0]), torch.tensor([-1.0, 1.0])]
    probs = pyro.param("p", torch.tensor([0.25, 0.75]))
    means = pyro.sample("loc", dist.Normal(0, 1).expand([2]).to_event(1))

    # FIXME results in infinite loop in transformeddist_to_funsor.
    # scale = pyro.sample("scale", dist.LogNormal(0, 1))
    z1 = pyro.sample("z1", dist.Categorical(probs))
    log_scale_mean = torch.tensor([0.0, 1.0])[z1]
    log_scale = pyro.sample("scale", dist.Normal(log_scale_mean, 1))
    scale = log_scale.exp()

    with pyro.plate("data[0]", 3):
        x1_dist = dist.Normal(means[z1], scale)
        pyro.sample("x1", x1_dist, obs=data[0])

    with pyro.plate("data[1]", 2):
        z2 = pyro.sample("z2", dist.Categorical(probs))
        x2_dist = dist.Normal(means[z2], scale)
        pyro.sample("x2", x2_dist, obs=data[1])
Example #4
0
def model_0(data, history, vectorized):
    """Discrete-state Markov model over time, inside a ``"sequences"`` plate.

    Args:
        data: indexed as ``data.shape[0]`` (sequences plate size),
            ``data.shape[1]`` (number of time steps) and ``data.shape[2]``
            (``"tones"`` plate size).
        history: forwarded unchanged to ``pyro.markov`` /
            ``pyro.vectorized_markov``.
        vectorized: if truthy, iterate with ``pyro.vectorized_markov``;
            otherwise use a sequential ``pyro.markov`` loop.
    """
    num_states = 3
    init = pyro.param(
        "init",
        lambda: torch.rand(num_states),
        constraint=constraints.simplex,
    )
    trans = pyro.param(
        "trans",
        lambda: torch.rand((num_states, num_states)),
        constraint=constraints.simplex,
    )
    locs = pyro.param("locs", lambda: torch.rand(num_states))

    with pyro.plate("sequences", data.shape[0], dim=-3) as sequences:
        # Add a trailing axis to the plate indices before using them to
        # index ``data`` below.
        seq_index = sequences[:, None]

        if vectorized:
            time_steps = pyro.vectorized_markov(
                name="time", size=data.shape[1], dim=-2, history=history)
        else:
            time_steps = pyro.markov(range(data.shape[1]), history=history)

        x_prev = None
        for t in time_steps:
            # Only a plain-int first step uses the initial distribution;
            # every other step conditions on the previous state.
            if isinstance(t, int) and t < 1:
                x_probs = init
            else:
                x_probs = trans[x_prev]

            x_curr = pyro.sample("x_{}".format(t), dist.Categorical(x_probs))

            with pyro.plate("tones", data.shape[2], dim=-1):
                y_dist = dist.Normal(Vindex(locs)[..., x_curr], 1.)
                pyro.sample("y_{}".format(t),
                            y_dist,
                            obs=Vindex(data)[seq_index, t])

            x_prev = x_curr
Example #5
0
 def model(z=None):
     """Single discrete latent whose index picks the observation mean.

     The argument ``z`` is forwarded as ``obs=`` to the ``"z"`` site. The
     resulting index selects a value from ``[0.0, 1.0]``, whose shape is
     logged at INFO level. ``data`` is looked up in the enclosing scope.
     """
     probs = pyro.param("p", torch.tensor([0.75, 0.25]))
     z_index = pyro.sample("z", dist.Categorical(probs), obs=z)

     z_values = torch.tensor([0.0, 1.0])
     z_loc = z_values[z_index]
     logger.info("z.shape = {}".format(z_loc.shape))

     with pyro.plate("data", 3):
         x_dist = dist.Normal(z_loc, 1.0)
         pyro.sample("x", x_dist, obs=data)
Example #6
0
def model_1(data, history, vectorized):
    """Discrete-state Markov model over time with a ``"tones"`` plate.

    Args:
        data: ``len(data)`` gives the number of time steps, ``data[i]`` the
            observation at step ``i`` and ``data.shape[-1]`` the size of the
            ``"tones"`` plate.
        history: forwarded unchanged to ``pyro.markov`` /
            ``pyro.vectorized_markov``.
        vectorized: if truthy, iterate with ``pyro.vectorized_markov``;
            otherwise use a sequential ``pyro.markov`` loop.
    """
    num_states = 3
    init = pyro.param(
        "init",
        lambda: torch.rand(num_states),
        constraint=constraints.simplex,
    )
    trans = pyro.param(
        "trans",
        lambda: torch.rand((num_states, num_states)),
        constraint=constraints.simplex,
    )
    locs = pyro.param("locs", lambda: torch.rand(num_states))

    if vectorized:
        time_steps = pyro.vectorized_markov(
            name="time", size=len(data), dim=-2, history=history)
    else:
        time_steps = pyro.markov(range(len(data)), history=history)

    x_prev = None
    for t in time_steps:
        # Only a plain-int first step uses the initial distribution; every
        # other step conditions on the previous state.
        if isinstance(t, int) and t < 1:
            x_probs = init
        else:
            x_probs = trans[x_prev]

        x_curr = pyro.sample("x_{}".format(t), dist.Categorical(x_probs))

        with pyro.plate("tones", data.shape[-1], dim=-1):
            y_dist = dist.Normal(Vindex(locs)[..., x_curr], 1.0)
            pyro.sample("y_{}".format(t), y_dist, obs=data[t])

        x_prev = x_curr
Example #7
0
def model_6(data, history, vectorized):
    """Markov model with a separate transition matrix for every time step.

    The ``"x_trans"`` parameter holds ``len(data) - 1`` transition matrices;
    step ``i`` (for ``i >= 1``) uses entry ``i - 1``.

    Args:
        data: ``len(data)`` gives the number of time steps, ``data[i]`` the
            observation at step ``i`` and ``data.shape[-1]`` the size of the
            ``"tones"`` plate.
        history: forwarded unchanged to ``pyro.markov`` /
            ``pyro.vectorized_markov``.
        vectorized: if truthy, iterate with ``pyro.vectorized_markov``;
            otherwise use a sequential ``pyro.markov`` loop.
    """
    num_states = 3
    x_init = pyro.param(
        "x_init",
        lambda: torch.rand(num_states),
        constraint=constraints.simplex,
    )
    x_trans = pyro.param(
        "x_trans",
        lambda: torch.rand((len(data) - 1, num_states, num_states)),
        constraint=constraints.simplex,
    )
    locs = pyro.param("locs", lambda: torch.rand(num_states))

    if vectorized:
        time_steps = pyro.vectorized_markov(
            name="time", size=len(data), dim=-2, history=history)
    else:
        time_steps = pyro.markov(range(len(data)), history=history)

    x_prev = None
    for t in time_steps:
        if not isinstance(t, int):
            # Non-int step index: index the per-step transitions with
            # Vindex, adding a trailing axis to the step index first.
            x_probs = Vindex(x_trans)[(t - 1)[:, None], x_prev]
        elif t < 1:
            x_probs = x_init
        else:
            x_probs = x_trans[t - 1, x_prev]

        x_curr = pyro.sample("x_{}".format(t), dist.Categorical(x_probs))

        with pyro.plate("tones", data.shape[-1], dim=-1):
            y_dist = dist.Normal(Vindex(locs)[..., x_curr], 1.)
            pyro.sample("y_{}".format(t), y_dist, obs=data[t])

        x_prev = x_curr
Example #8
0
 def model():
     """Three-component Gaussian mixture with learnable locations and scales.

     A categorical latent ``"x"`` with fixed probabilities picks which
     location/scale pair generates the ``"obs"`` site. ``data`` is looked up
     in the enclosing scope.
     """
     # The two random initial values are drawn in this order (locs first).
     initial_locs = torch.randn(3)
     locs = pyro.param("locs", initial_locs, constraint=constraints.real)
     initial_scales = torch.randn(3).exp()
     scales = pyro.param(
         "scales", initial_scales, constraint=constraints.positive)

     mixture_probs = torch.tensor([0.5, 0.3, 0.2])
     component = pyro.sample("x", dist.Categorical(mixture_probs))

     obs_dist = dist.Normal(locs[component], scales[component])
     pyro.sample("obs", obs_dist, obs=data)
Example #9
0
def model_zzxx():
    """Two-plate model sharing global ``loc`` and ``scale`` latents.

    ``z1`` is sampled outside the plates and ``z2`` inside the second plate;
    each selects the mean of its observation site.
    """
    #                  loc,scale
    #                 /         \
    #       +-------/-+  +--------\------+
    #  z1 --|--> x1   |  |  z2 ---> x2   |
    #       |       3 |  |             2 |
    #       +---------+  +---------------+
    data = [torch.tensor([-1.0, -1.0, 0.0]), torch.tensor([-1.0, 1.0])]
    probs = pyro.param("p", torch.tensor([0.25, 0.75]))
    means = pyro.sample("loc", dist.Normal(0, 1).expand([2]).to_event(1))

    # FIXME results in infinite loop in transformeddist_to_funsor.
    # scale = pyro.sample("scale", dist.LogNormal(0, 1))
    log_scale = pyro.sample("scale", dist.Normal(0, 1))
    scale = log_scale.exp()

    z1 = pyro.sample("z1", dist.Categorical(probs))

    with pyro.plate("data[0]", 3):
        x1_dist = dist.Normal(means[z1], scale)
        pyro.sample("x1", x1_dist, obs=data[0])

    with pyro.plate("data[1]", 2):
        z2 = pyro.sample("z2", dist.Categorical(probs))
        x2_dist = dist.Normal(means[z2], scale)
        pyro.sample("x2", x2_dist, obs=data[1])
Example #10
0
 def hmm(data, hidden_dim=10):
     """Hidden Markov model with Gaussian emissions.

     Args:
         data: indexable sequence of observations. Each entry is passed as
             ``obs=`` to its emission site and then overwritten in place
             with the value that site returns.
         hidden_dim: number of hidden states.

     Returns:
         A tuple ``(states, data)``. ``states`` starts with the fixed initial
         state ``0`` followed by one sampled state per time step; ``data`` is
         the same object that was passed in, after the in-place updates.
     """
     # 0.3 / hidden_dim everywhere, plus an extra 0.7 on the diagonal.
     transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
     means = torch.arange(float(hidden_dim))

     states = [0]
     for t in pyro.markov(range(len(data))):
         previous_state = states[-1]
         state_dist = dist.Categorical(transition[previous_state])
         current_state = pyro.sample("states_{}".format(t), state_dist)
         states.append(current_state)

         emission_dist = dist.Normal(means[current_state], 1.0)
         data[t] = pyro.sample("obs_{}".format(t),
                               emission_dist,
                               obs=data[t])
     return states, data
Example #11
0
 def model():
     """Draw ``"x"`` from a Normal with learnable ``"loc"`` and ``"scale"``.

     Returns:
         The value returned by the ``"x"`` sample site.
     """
     mean = pyro.param("loc", torch.tensor(2.0))
     stddev = pyro.param("scale", torch.tensor(1.0))
     return pyro.sample("x", dist.Normal(mean, stddev))
Example #12
0
 def guide():
     """Guide that samples ``"y"`` first and then ``"x"`` centred on ``y``.

     The mean of ``"y"`` is the learnable ``"loc"`` parameter.
     """
     y_mean = pyro.param("loc", torch.tensor(0.))
     y_value = pyro.sample("y", dist.Normal(y_mean, 1.))
     x_dist = dist.Normal(y_value, 1.)
     pyro.sample("x", x_dist)
Example #13
0
 def model():
     """Sample ``"x"`` from a standard Normal, then ``"y"`` centred on ``x``."""
     x_value = pyro.sample("x", dist.Normal(0., 1.))
     y_dist = dist.Normal(x_value, 1.)
     pyro.sample("y", y_dist)
Example #14
0
 def model():
     """Mixture model with a uniform categorical assignment per data point.

     Inside a plate of size ``len(data)``, each point draws a component index
     ``"x"`` that selects its mean from the ``"locs"`` parameter. ``data`` is
     looked up in the enclosing scope.
     """
     component_locs = pyro.param("locs", torch.tensor([-1., 0., 1.]))
     with pyro.plate("plate", len(data), dim=-1):
         uniform_probs = torch.ones(3) / 3
         assignment = pyro.sample("x", dist.Categorical(uniform_probs))
         obs_dist = dist.Normal(component_locs[assignment], 1.)
         pyro.sample("obs", obs_dist, obs=data)
Example #15
0
 def model(z=None):
     """Discrete latent ``"z"`` used directly as the mean of a masked site.

     The argument ``z`` is forwarded as ``obs=`` to the ``"z"`` site, and the
     shape of the resulting value is logged at INFO level. ``data`` and
     ``mask`` are looked up in the enclosing scope.
     """
     probs = pyro.param("p", torch.tensor([0.75, 0.25]))
     z_value = pyro.sample("z", dist.Categorical(probs), obs=z)
     logger.info("z.shape = {}".format(z_value.shape))

     with pyro.plate("data", 3):
         with handlers.mask(mask=mask):
             # Cast the latent to the type of ``data`` before using it as
             # the mean.
             x_loc = z_value.type_as(data)
             pyro.sample("x", dist.Normal(x_loc, 1.0), obs=data)
Example #16
0
 def model(data):
     """Normal likelihood with a latent mean and unit scale.

     Args:
         data: observations for the ``"obs"`` site; ``len(data)`` sets the
             size of the ``"data"`` plate.
     """
     latent_loc = pyro.sample("loc", dist.Normal(0., 1.))
     num_points = len(data)
     with pyro.plate("data", num_points, dim=-1):
         likelihood = dist.Normal(latent_loc, 1.)
         pyro.sample("obs", likelihood, obs=data)
Example #17
0
 def model(data):
     """Observe ``data`` at site ``"x"`` under a Normal with learnable mean.

     Args:
         data: value passed as ``obs=`` to the ``"x"`` sample site.
     """
     mean = pyro.param("loc", torch.tensor(0.0))
     likelihood = dist.Normal(mean, 1.)
     pyro.sample("x", likelihood, obs=data)
Example #18
0
 def model(data=None):
     """Plate of 1000 Normal draws with learnable ``"loc"`` and ``"scale"``.

     Args:
         data: forwarded as ``obs=`` to the ``"x"`` sample site.

     Returns:
         The value returned by the ``"x"`` sample site.
     """
     mean = pyro.param("loc", torch.tensor(2.0))
     stddev = pyro.param("scale", torch.tensor(1.0))
     with pyro.plate("data", 1000, dim=-1):
         samples = pyro.sample("x", dist.Normal(mean, stddev), obs=data)
     return samples
Example #19
0
 def model():
     """Return the value of a single ``"x"`` site drawn from ``Normal(0, 1)``."""
     x_value = pyro.sample("x", dist.Normal(0, 1))
     return x_value
Example #20
0
 def model():
     """Mixture model with fixed, non-uniform assignment probabilities.

     Inside a plate of size ``len(data)``, each point draws a component index
     ``"x"`` that selects its mean from the ``"locs"`` parameter. ``data`` is
     looked up in the enclosing scope.
     """
     component_locs = pyro.param("locs", torch.tensor([0.2, 0.3, 0.5]))
     mixture_probs = torch.tensor([0.2, 0.3, 0.5])
     with pyro.plate("plate", len(data), dim=-1):
         assignment = pyro.sample("x", dist.Categorical(mixture_probs))
         obs_dist = dist.Normal(component_locs[assignment], 1.)
         pyro.sample("obs", obs_dist, obs=data)
Example #21
0
 def guide():
     """Guide for ``"x"`` with learnable ``"loc"`` and ``"scale"`` parameters.

     The ``"plate_outer"`` plate is sized by ``data.size(-1)``; ``data`` is
     looked up in the enclosing scope.
     """
     mean = pyro.param("loc", torch.tensor(0.))
     stddev = pyro.param("scale", torch.tensor(1.))
     outer_size = data.size(-1)
     with pyro.plate("plate_outer", outer_size, dim=-1):
         pyro.sample("x", dist.Normal(mean, stddev))
Example #22
0
 def guide(data):
     """Normal guide for the ``"loc"`` site with learnable mean and scale.

     Args:
         data: accepted but not used inside this function.
     """
     mean = pyro.param("guide_loc", torch.tensor(0.))
     stddev = pyro.param(
         "guide_scale",
         torch.tensor(1.),
         constraint=constraints.positive,
     )
     posterior = dist.Normal(mean, stddev)
     pyro.sample("loc", posterior)
Example #23
0
 def model():
     """Nested-plate model: latent ``"x"`` per outer index, observed ``"y"``.

     ``"plate_outer"`` is sized by ``data.size(-1)`` and ``"plate_inner"`` by
     ``data.size(-2)``. ``data`` is looked up in the enclosing scope and
     observed at the ``"y"`` site.
     """
     prior_mean = torch.tensor(3.0)
     with pyro.plate("plate_outer", data.size(-1), dim=-1):
         x_value = pyro.sample("x", dist.Normal(prior_mean, 1.))
         with pyro.plate("plate_inner", data.size(-2), dim=-2):
             likelihood = dist.Normal(x_value, 1.)
             pyro.sample("y", likelihood, obs=data)