예제 #1
0
def logistic_regression():
    """Build a synthetic Bayesian logistic-regression problem.

    Draws a design matrix with ``sample_shape=(3000, 3)`` from
    ``Normal(0, 1)``. Labels are then generated from a Bernoulli whose logits
    are the row-wise dot product with the coefficients
    ``ops.arange(1., 3 + 1.)`` (no intercept in the generating process).
    A ``model`` with unknown coefficients and an intercept is returned,
    together with the arguments needed to condition it on that data.

    Returns:
        dict with keys ``'model'`` (the model callable), ``'model_args'``
        (a 1-tuple holding the sampled design matrix) and ``'model_kwargs'``
        (``{'y': labels}``).
    """
    num_points, num_features = 3000, 3

    def model(x, y=None):
        # Standard-normal priors over the per-feature weights and the bias.
        weights = pyro.sample('coefs', dist.Normal(ops.zeros(num_features),
                                                   ops.ones(num_features)))
        bias = pyro.sample('intercept', dist.Normal(0., 1.))
        linear = ops.sum(weights * x, axis=-1) + bias
        return pyro.sample('obs', dist.Bernoulli(logits=linear), obs=y)

    # Backend-agnostic way to draw a batch of i.i.d. samples: a scalar
    # distribution plus an explicit sample_shape.
    features = pyro.sample('data', dist.Normal(0., 1.),
                           sample_shape=(num_points, num_features))
    weights_true = ops.arange(1., num_features + 1.)
    true_logits = ops.sum(weights_true * features, axis=-1)
    labels = pyro.sample('labels', dist.Bernoulli(logits=true_logits))

    spec = {'model': model, 'model_args': (features,)}
    spec['model_kwargs'] = {'y': labels}
    return spec
예제 #2
0
 def model(x, y=None):
     """Bayesian logistic-regression model over the rows of ``x``.

     Draws coefficients from ``Normal(zeros(dim), ones(dim))`` and a scalar
     intercept from ``Normal(0, 1)``. It then scores ``y`` (when given) under
     a Bernoulli likelihood parameterised by logits.

     ``dim`` is a free variable resolved from an enclosing scope that is not
     part of this snippet. Presumably it equals the number of columns of ``x``
     (TODO confirm).

     Returns the value of the ``'obs'`` sample site.
     """
     weight_prior = dist.Normal(ops.zeros(dim), ops.ones(dim))
     coefs = pyro.sample('coefs', weight_prior)
     intercept = pyro.sample('intercept', dist.Normal(0., 1.))
     # Linear predictor: elementwise product reduced over the last axis.
     linear_term = ops.sum(coefs * x, axis=-1)
     return pyro.sample('obs',
                        dist.Bernoulli(logits=linear_term + intercept),
                        obs=y)
예제 #3
0
 def model(data):
     """Condition a Bernoulli sample site ``"x"`` on ``data``.

     The success probability is the parameter ``"p"``, registered via
     ``pyro.param`` with initial value ``ops.tensor(0.5)``. Returns ``None``.
     """
     # NOTE(review): "p" is declared without a constraint. Nothing visible
     # here keeps it inside [0, 1] if it is optimised, so confirm the
     # backend/guide handles this.
     prob = pyro.param("p", ops.tensor(0.5))
     likelihood = dist.Bernoulli(prob)
     pyro.sample("x", likelihood, obs=data)