Exemple #1
0
 def model(dim):
     """Sample a scalar site 'y' and a transformed site 'x'.

     'y' is drawn from ``Normal(0, 3)``. 'x' is drawn from a
     ``Normal(ops.zeros(dim - 1), 1)`` base distribution pushed through
     ``AffineTransform(0, ops.exp(y / 2))``, so its transform depends on
     the sampled value of 'y'.

     Args:
         dim: Integer used to size the base distribution via
             ``ops.zeros(dim - 1)``.
     """
     y = pyro.sample('y', dist.Normal(0, 3))

     # Build the pieces of the transformed distribution separately so the
     # dependence on ``y`` is easy to see.
     base_dist = dist.Normal(ops.zeros(dim - 1), 1)
     rescale = dist.transforms.AffineTransform(0, ops.exp(y / 2))
     pyro.sample('x', dist.TransformedDistribution(base_dist, rescale))
Exemple #2
0
 def model():
     """Register 'locs'/'scales' parameters and score ``data`` at site 'obs'.

     A categorical index 'x' is sampled with fixed probabilities
     ``[0.5, 0.3, 0.2]`` and used to pick the location and scale of the
     Normal observation distribution.

     ``data`` is not a parameter; it is read from the enclosing scope.
     """
     # Initial values are drawn in the same order as the params are
     # registered: first the locations, then the scales.
     initial_locs = ops.randn(3)
     locs = pyro.param("locs", initial_locs,
                       constraint=dist.constraints.real)
     initial_scales = ops.exp(ops.randn(3))
     scales = pyro.param("scales", initial_scales,
                         constraint=dist.constraints.positive)

     probs = ops.tensor([0.5, 0.3, 0.2])
     idx = pyro.sample("x", dist.Categorical(probs))
     likelihood = dist.Normal(locs[idx], scales[idx])
     pyro.sample("obs", likelihood, obs=data)
Exemple #3
0
def logistic_regression():
    """Build a logistic-regression model together with synthetic inputs.

    Features are sampled at site 'data' with ``sample_shape=(N, dim)`` where
    ``N, dim = 3000, 3``; labels are sampled at site 'labels' from a
    Bernoulli whose logits are the features weighted by
    ``ops.arange(1., dim + 1.)`` and summed over the last axis.

    Returns:
        A dict with keys 'model' (the model callable), 'model_args'
        (a 1-tuple holding the sampled features) and 'model_kwargs'
        (a dict mapping 'y' to the sampled labels).
    """
    N, dim = 3000, 3

    # generic way to sample from distributions
    features = pyro.sample('data', dist.Normal(0., 1.), sample_shape=(N, dim))
    true_coefs = ops.arange(1., dim + 1.)
    true_logits = ops.sum(true_coefs * features, axis=-1)
    labels = pyro.sample('labels', dist.Bernoulli(logits=true_logits))

    def model(x, y=None):
        """Sample 'coefs' and 'intercept', then sample/score 'obs' against y."""
        coef_prior = dist.Normal(ops.zeros(dim), ops.ones(dim))
        coefs = pyro.sample('coefs', coef_prior)
        intercept = pyro.sample('intercept', dist.Normal(0., 1.))
        linear = ops.sum(coefs * x, axis=-1)
        return pyro.sample('obs',
                           dist.Bernoulli(logits=linear + intercept),
                           obs=y)

    model_args = (features, )
    model_kwargs = {'y': labels}
    return {
        'model': model,
        'model_args': model_args,
        'model_kwargs': model_kwargs,
    }
Exemple #4
0
 def model(J, sigma, y=None):
     """Hierarchical Normal model with a plate named 'J'.

     Samples 'mu' from ``Normal(0, 5)`` and 'tau' from ``HalfCauchy(5)``,
     then, inside the plate, samples 'theta' from ``Normal(mu, tau)`` and
     samples/scores 'obs' from ``Normal(theta, sigma)``.

     Args:
         J: Size passed to ``pyro.plate('J', J)``.
         sigma: Scale of the observation distribution.
         y: Optional observed values for site 'obs'.
     """
     mu_prior = dist.Normal(0, 5)
     mu = pyro.sample('mu', mu_prior)
     tau_prior = dist.HalfCauchy(5)
     tau = pyro.sample('tau', tau_prior)

     with pyro.plate('J', J):
         theta = pyro.sample('theta', dist.Normal(mu, tau))
         likelihood = dist.Normal(theta, sigma)
         pyro.sample('obs', likelihood, obs=y)
Exemple #5
0
 def model(x, y=None):
     """Logistic-regression model over inputs ``x``.

     Samples 'coefs' from ``Normal(ops.zeros(dim), ops.ones(dim))`` and
     'intercept' from ``Normal(0., 1.)``, then samples/scores 'obs' from a
     Bernoulli with logits ``sum(coefs * x, axis=-1) + intercept``.

     ``dim`` is not a parameter; it is read from the enclosing scope.

     Args:
         x: Inputs multiplied elementwise with the coefficients.
         y: Optional observed values for site 'obs'.

     Returns:
         The value produced by the 'obs' sample statement.
     """
     coef_prior = dist.Normal(ops.zeros(dim), ops.ones(dim))
     coefs = pyro.sample('coefs', coef_prior)
     intercept = pyro.sample('intercept', dist.Normal(0., 1.))

     linear = ops.sum(coefs * x, axis=-1)
     likelihood = dist.Bernoulli(logits=linear + intercept)
     return pyro.sample('obs', likelihood, obs=y)
Exemple #6
0
 def model(data):
     """Score ``data`` at site 'x' under ``Normal(loc, 1.)``.

     Args:
         data: Observed values for site 'x'.
     """
     # 'loc' is a parameter initialised to 0.0.
     loc = pyro.param("loc", ops.tensor(0.0))
     likelihood = dist.Normal(loc, 1.)
     pyro.sample("x", likelihood, obs=data)
Exemple #7
0
 def model(data=None):
     """Sample or score site 'x' inside a plate of size 1000.

     Parameters 'loc' (initialised to 2.0) and 'scale' (initialised to 1.0)
     define the Normal distribution used at site 'x'.

     Args:
         data: Optional observed values for site 'x'.

     Returns:
         The value produced by the 'x' sample statement.
     """
     loc = pyro.param("loc", ops.tensor(2.0))
     scale = pyro.param("scale", ops.tensor(1.0))

     with pyro.plate("data", 1000, dim=-1):
         normal = dist.Normal(loc, scale)
         value = pyro.sample("x", normal, obs=data)
     return value
Exemple #8
0
 def model(data=None):
     """Sample or score site 'x' under ``Normal(loc, scale)``.

     'loc' and 'scale' are parameters initialised to 2.0 and 1.0.

     Args:
         data: Optional observed values for site 'x'.

     Returns:
         The value produced by the 'x' sample statement.
     """
     loc = pyro.param("loc", ops.tensor(2.0))
     scale = pyro.param("scale", ops.tensor(1.0))
     normal = dist.Normal(loc, scale)
     return pyro.sample("x", normal, obs=data)
Exemple #9
0
 def guide():
     """Sample 'x' from ``Normal(loc, 1.)`` and then 'y' from ``Normal(x, 1.)``.

     'loc' is a parameter initialised to 0.
     """
     loc = pyro.param("loc", ops.tensor(0.))
     x_dist = dist.Normal(loc, 1.)
     x = pyro.sample("x", x_dist)

     # 'y' is centred on the value just drawn for 'x'.
     y_dist = dist.Normal(x, 1.)
     pyro.sample("y", y_dist)
Exemple #10
0
 def model():
     """Sample 'x' from ``Normal(0., 1.)`` and then 'y' from ``Normal(x, 1.)``."""
     x_prior = dist.Normal(0., 1.)
     x = pyro.sample("x", x_prior)

     # 'y' is centred on the value just drawn for 'x'.
     y_dist = dist.Normal(x, 1.)
     pyro.sample("y", y_dist)
Exemple #11
0
 def model():
     """Three-component model scoring ``data`` at site 'obs'.

     Inside a plate of size ``len(data)``, a categorical index 'x' is
     sampled with probabilities ``ops.ones(3) / 3`` and selects which entry
     of the 'locs' parameter centres the Normal observation distribution.

     ``data`` is not a parameter; it is read from the enclosing scope.
     """
     locs = pyro.param("locs", ops.tensor([-1., 0., 1.]))

     with pyro.plate("plate", len(data), dim=-1):
         probs = ops.ones(3) / 3
         idx = pyro.sample("x", dist.Categorical(probs))
         likelihood = dist.Normal(locs[idx], 1.)
         pyro.sample("obs", likelihood, obs=data)
Exemple #12
0
 def guide():
     """Sample site 'x' from ``Normal(loc, scale)`` inside 'plate_outer'.

     'loc' and 'scale' are parameters initialised to 0 and 1. The plate
     size is ``data.shape[-1]``; ``data`` is read from the enclosing scope.
     """
     loc = pyro.param("loc", ops.tensor(0.))
     scale = pyro.param("scale", ops.tensor(1.))

     with pyro.plate("plate_outer", data.shape[-1], dim=-1):
         proposal = dist.Normal(loc, scale)
         pyro.sample("x", proposal)
Exemple #13
0
 def model():
     """Nested-plate model scoring ``data`` at site 'y'.

     'x' is sampled from ``Normal(3.0, 1.)`` inside 'plate_outer'
     (size ``data.shape[-1]``, dim=-1). 'y' is then scored against ``data``
     under ``Normal(x, 1.)`` inside 'plate_inner'
     (size ``data.shape[-2]``, dim=-2).

     ``data`` is not a parameter; it is read from the enclosing scope.
     """
     loc = ops.tensor(3.0)

     with pyro.plate("plate_outer", data.shape[-1], dim=-1):
         x_prior = dist.Normal(loc, 1.)
         x = pyro.sample("x", x_prior)

         with pyro.plate("plate_inner", data.shape[-2], dim=-2):
             likelihood = dist.Normal(x, 1.)
             pyro.sample("y", likelihood, obs=data)
Exemple #14
0
 def model():
     """Three-component model with fixed weights, scoring ``data`` at 'obs'.

     Inside a plate of size ``len(data)``, a categorical index 'x' is
     sampled with probabilities ``[0.2, 0.3, 0.5]`` and selects which entry
     of the 'locs' parameter centres the Normal observation distribution.

     ``data`` is not a parameter; it is read from the enclosing scope.
     """
     locs = pyro.param("locs", ops.tensor([0.2, 0.3, 0.5]))
     probs = ops.tensor([0.2, 0.3, 0.5])

     with pyro.plate("plate", len(data), dim=-1):
         idx = pyro.sample("x", dist.Categorical(probs))
         likelihood = dist.Normal(locs[idx], 1.)
         pyro.sample("obs", likelihood, obs=data)
Exemple #15
0
 def model():
     """Sample a single site 'x' from ``Normal(0, 1)``."""
     prior = dist.Normal(0, 1)
     pyro.sample("x", prior)