Example #1
def test_em_nested_in_svi(assignment_grad):
    args = make_args()
    args.assignment_grad = assignment_grad
    detections = generate_data(args)

    pyro.clear_param_store()
    pyro.param('noise_scale',
               torch.tensor(args.init_noise_scale),
               constraint=constraints.positive)
    pyro.param('objects_loc', torch.randn(args.max_num_objects, 1))

    # Learn object_loc via EM and noise_scale via SVI.
    optim = Adam({'lr': 0.1})
    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    newton = Newton(trust_radii={'objects_loc': 1.0})
    svi = SVI(poutine.block(model, hide=['objects_loc']),
              poutine.block(guide, hide=['objects_loc']), optim, elbo)
    for svi_step in range(50):
        for em_step in range(2):
            objects_loc = pyro.param('objects_loc').detach_().requires_grad_()
            assert pyro.param('objects_loc').grad_fn is None
            loss = elbo.differentiable_loss(model, guide, detections,
                                            args)  # E-step
            updated = newton.get_step(loss,
                                      {'objects_loc': objects_loc})  # M-step
            assert updated['objects_loc'].grad_fn is not None
            pyro.get_param_store()['objects_loc'] = updated['objects_loc']
            assert pyro.param('objects_loc').grad_fn is not None
        loss = svi.step(detections, args)
        logger.debug(
            'step {: >2d}, loss = {:0.6f}, noise_scale = {:0.6f}'.format(
                svi_step, loss,
                pyro.param('noise_scale').item()))
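Note (not from the source): poutine.block(..., hide=['objects_loc']) hides the objects_loc param site from SVI's traces, so the Adam optimizer only ever updates noise_scale; objects_loc is trained exclusively by the inner Newton EM loop.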
Example #2
def test_prob(nderivs):
    #      +-------+
    #  z --|--> x  |
    #      +-------+
    num_particles = 10000
    data = torch.tensor([0.5, 1., 1.5])
    p = pyro.param("p", torch.tensor(0.25))

    @config_enumerate
    def model(num_particles):
        p = pyro.param("p")
        with pyro.plate("num_particles", num_particles, dim=-2):
            z = pyro.sample("z", dist.Bernoulli(p))
            with pyro.plate("data", 3):
                pyro.sample("x", dist.Normal(z, 1.), obs=data)

    def guide(num_particles):
        pass

    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    expected_logprob = -elbo.differentiable_loss(model, guide, num_particles=1)

    posterior_model = infer_discrete(config_enumerate(model, "parallel"),
                                     first_available_dim=-3)
    posterior_trace = poutine.trace(posterior_model).get_trace(
        num_particles=num_particles)
    actual_logprob = log_mean_prob(posterior_trace, particle_dim=-2)

    if nderivs == 0:
        assert_equal(expected_logprob, actual_logprob, prec=1e-3)
    elif nderivs == 1:
        expected_grad = grad(expected_logprob, [p])[0]
        actual_grad = grad(actual_logprob, [p])[0]
        assert_equal(expected_grad, actual_grad, prec=1e-3)
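Why the two estimates agree (a note, not from the source): with the single latent z fully enumerated and an empty guide, the TraceEnum_ELBO loss is exactly -log p(data); log_mean_prob, a helper defined elsewhere in the test module, forms a Monte Carlo estimate of the same log-probability from the 10000 posterior samples drawn via infer_discrete. The values, and for nderivs == 1 their gradients with respect to p, should therefore match.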
Example #3
def test_enum_discrete_parallel_iarange_ok():
    enum_discrete = "defined below"

    def model():
        p2 = torch.ones(2) / 2
        p34 = torch.ones(3, 4) / 4
        p536 = torch.ones(5, 3, 6) / 6

        x2 = pyro.sample("x2", dist.Categorical(p2))
        with pyro.iarange("outer", 3):
            x34 = pyro.sample("x34", dist.Categorical(p34))
            with pyro.iarange("inner", 5):
                x536 = pyro.sample("x536", dist.Categorical(p536))

        if enum_discrete == "sequential":
            # All dimensions are iarange dimensions.
            assert x2.shape == torch.Size([])
            assert x34.shape == torch.Size([3])
            assert x536.shape == torch.Size([5, 3])
        else:
            # Meaning of dimensions:    [ enum dims | iarange dims ]
            assert x2.shape == torch.Size([2, 1, 1])  # noqa: E201
            assert x34.shape == torch.Size([4, 1, 1, 3])  # noqa: E201
            assert x536.shape == torch.Size([6, 1, 1, 5, 3])  # noqa: E201

    enum_discrete = "sequential"
    assert_ok(model, config_enumerate(model, "sequential"),
              TraceEnum_ELBO(max_iarange_nesting=2))

    enum_discrete = "parallel"
    assert_ok(model, config_enumerate(model, "parallel"),
              TraceEnum_ELBO(max_iarange_nesting=2))
Example #4
def test_svi_multi():
    args = make_args()
    args.assignment_grad = True
    detections = generate_data(args)

    pyro.clear_param_store()
    pyro.param('noise_scale',
               torch.tensor(args.init_noise_scale),
               constraint=constraints.positive)
    pyro.param('objects_loc', torch.randn(args.max_num_objects, 1))

    # Learn object_loc via Newton and noise_scale via Adam.
    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    adam = Adam({'lr': 0.1})
    newton = Newton(trust_radii={'objects_loc': 1.0})
    optim = MixedMultiOptimizer([(['noise_scale'], adam),
                                 (['objects_loc'], newton)])
    for svi_step in range(50):
        with poutine.trace(param_only=True) as param_capture:
            loss = elbo.differentiable_loss(model, guide, detections, args)
        params = {
            name: pyro.param(name).unconstrained()
            for name in param_capture.trace.nodes.keys()
        }
        optim.step(loss, params)
        logger.debug(
            'step {: >2d}, loss = {:0.6f}, noise_scale = {:0.6f}'.format(
                svi_step, loss.item(),
                pyro.param('noise_scale').item()))
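Note (not from the source): MixedMultiOptimizer.step takes a dict of raw tensors, which is why each param is fetched as pyro.param(name).unconstrained(): that returns the underlying unconstrained leaf tensor that actually receives gradients (for noise_scale under constraints.positive, effectively its log).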
Example #5
def test_svi_enum(plate_dim, enumerate1, enumerate2):
    pyro.clear_param_store()
    num_particles = 10
    q = pyro.param("q", constant(0.75), constraint=constraints.unit_interval)
    p = 0.2693204236205713  # for which kl(Bernoulli(q), Bernoulli(p)) = 0.5

    def model():
        pyro.sample("x", dist.Bernoulli(p))
        for i in pyro.plate("plate", plate_dim):
            pyro.sample("y_{}".format(i), dist.Bernoulli(p))

    def guide():
        q = pyro.param("q")
        pyro.sample("x", dist.Bernoulli(q), infer={"enumerate": enumerate1})
        for i in pyro.plate("plate", plate_dim):
            pyro.sample(
                "y_{}".format(i), dist.Bernoulli(q), infer={"enumerate": enumerate2}
            )

    kl = (1 + plate_dim) * kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p))
    expected_loss = kl.item()
    expected_grad = grad(kl, [q.unconstrained()])[0]

    inner_particles = 2
    outer_particles = num_particles // inner_particles
    elbo = TraceEnum_ELBO(
        max_plate_nesting=0,
        strict_enumeration_warning=any([enumerate1, enumerate2]),
        num_particles=inner_particles,
        ignore_jit_warnings=True,
    )
    actual_loss = (
        sum(elbo.loss_and_grads(model, guide) for i in range(outer_particles))
        / outer_particles
    )
    actual_grad = q.unconstrained().grad / outer_particles

    assert_equal(
        actual_loss,
        expected_loss,
        prec=0.3,
        msg="".join(
            [
                "\nexpected loss = {}".format(expected_loss),
                "\n  actual loss = {}".format(actual_loss),
            ]
        ),
    )
    assert_equal(
        actual_grad,
        expected_grad,
        prec=0.5,
        msg="".join(
            [
                "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()),
                "\n  actual grad = {}".format(actual_grad.detach().cpu().numpy()),
            ]
        ),
    )
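A quick sanity check of the magic constant above (not part of the source test), using the closed form KL(Bernoulli(q) || Bernoulli(p)) = q*log(q/p) + (1-q)*log((1-q)/(1-p)):

import math

q, p = 0.75, 0.2693204236205713
kl = q * math.log(q / p) + (1 - q) * math.log((1 - q) / (1 - p))
print(kl)  # ~0.5, matching the comment above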
Example #6
def test_enum_discrete_iarange_dependency_warning(enumerate_, is_validate):
    def model():
        pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'})
        with pyro.iarange("iarange", 10, 5):
            x = pyro.sample("x",
                            dist.Bernoulli(0.5).expand_by([5]),
                            infer={'enumerate': enumerate_})
        pyro.sample("y",
                    dist.Bernoulli(x.mean()))  # user should move this line up

    with pyro.validation_enabled(is_validate):
        if enumerate_ and is_validate:
            assert_warning(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
        else:
            assert_ok(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
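Note (not from the source): the warning fires because y depends, through x.mean(), on x, which may be enumerated inside the iarange, while y itself is sampled outside it; TraceEnum_ELBO cannot soundly handle a sample site outside an iarange depending on an enumerated site inside it, which is what the inline comment is pointing at.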
Example #7
def test_enum_discrete_iarange_shape_broadcasting_ok(enumerate_):
    @poutine.broadcast
    @config_enumerate(default=enumerate_)
    def model():
        x_iarange = pyro.iarange("x_iarange", 10, 5, dim=-1)
        y_iarange = pyro.iarange("y_iarange", 11, 6, dim=-2)
        with pyro.iarange("num_particles", 50, dim=-3):
            with x_iarange:
                b = pyro.sample(
                    "b", dist.Beta(torch.tensor(1.1), torch.tensor(1.1)))
                assert b.shape == torch.Size((50, 1, 5))
            with y_iarange:
                c = pyro.sample("c", dist.Bernoulli(0.5))
                if enumerate_ == "parallel":
                    assert c.shape == torch.Size((2, 50, 6, 1))
                else:
                    assert c.shape == torch.Size((50, 6, 1))
            with x_iarange, y_iarange:
                d = pyro.sample("d", dist.Bernoulli(b))
                if enumerate_ == "parallel":
                    assert d.shape == torch.Size((2, 1, 50, 6, 5))
                else:
                    assert d.shape == torch.Size((50, 6, 5))

    assert_ok(
        model, model,
        TraceEnum_ELBO(max_iarange_nesting=3,
                       strict_enumeration_warning=(enumerate_ == "parallel")))
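Shape bookkeeping for the asserts above (a note, not from the source): with max_iarange_nesting=3, dims -1, -2, -3 are reserved for the three iaranges, and parallel enumeration allocates fresh dims to their left starting at dim -4. Hence c is enumerated into dim -4, giving shape (2, 50, 6, 1), and d into dim -5, giving shape (2, 1, 50, 6, 5), where the size-1 slot at dim -4 is c's (broadcast-away) enumeration dim.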
Example #8
def test_discrete_hmm_categorical(num_steps):
    state_dim = 3
    obs_dim = 4
    init_logits = torch.randn(state_dim)
    trans_logits = torch.randn(num_steps, state_dim, state_dim)
    obs_dist = dist.Categorical(
        logits=torch.randn(num_steps, state_dim, obs_dim))
    d = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    data = dist.Categorical(logits=torch.zeros(num_steps, obs_dim)).sample()
    actual = d.log_prob(data)
    assert actual.shape == d.batch_shape
    check_expand(d, data)

    # Check loss against TraceEnum_ELBO.
    @config_enumerate
    def model(data):
        x = pyro.sample("x_init", dist.Categorical(logits=init_logits))
        for t in range(num_steps):
            x = pyro.sample(
                "x_{}".format(t),
                dist.Categorical(logits=Vindex(trans_logits)[..., t, x, :]))
            pyro.sample("obs_{}".format(t),
                        dist.Categorical(logits=Vindex(obs_dist.logits)[..., t,
                                                                        x, :]),
                        obs=data[..., t])

    expected_loss = TraceEnum_ELBO().loss(model, empty_guide, data)
    actual_loss = -float(actual.sum())
    assert_close(actual_loss, expected_loss)
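Note (not from the source): with an empty guide and every hidden state enumerated, TraceEnum_ELBO marginalizes the latents exactly, so its loss equals -log p(data). The test thus checks DiscreteHMM's vectorized log_prob against a naively unrolled Categorical chain; the next example repeats the check with vector-valued Normal observations.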
Example #9
def test_discrete_hmm_diag_normal(num_steps):
    state_dim = 3
    event_size = 2
    init_logits = torch.randn(state_dim)
    trans_logits = torch.randn(num_steps, state_dim, state_dim)
    loc = torch.randn(num_steps, state_dim, event_size)
    scale = torch.randn(num_steps, state_dim, event_size).exp()
    obs_dist = dist.Normal(loc, scale).to_event(1)
    d = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    data = obs_dist.sample()[:, 0]
    actual = d.log_prob(data)
    assert actual.shape == d.batch_shape
    check_expand(d, data)

    # Check loss against TraceEnum_ELBO.
    @config_enumerate
    def model(data):
        x = pyro.sample("x_init", dist.Categorical(logits=init_logits))
        for t in range(num_steps):
            x = pyro.sample(
                "x_{}".format(t),
                dist.Categorical(logits=Vindex(trans_logits)[..., t, x, :]))
            pyro.sample("obs_{}".format(t),
                        dist.Normal(
                            Vindex(loc)[..., t, x, :],
                            Vindex(scale)[..., t, x, :]).to_event(1),
                        obs=data[..., t, :])

    expected_loss = TraceEnum_ELBO().loss(model, empty_guide, data)
    actual_loss = -float(actual.sum())
    assert_close(actual_loss, expected_loss)
Example #10
def test_traceenum_elbo(length):
    hidden_dim = 10
    transition = pyro.param("transition",
                            0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim),
                            constraint=constraints.positive)
    means = pyro.param("means", torch.arange(float(hidden_dim)))
    data = 1 + 2 * torch.randn(length)

    @ignore_jit_warnings()
    def model(data):
        transition = pyro.param("transition")
        means = pyro.param("means")
        states = [torch.tensor(0)]
        for t in pyro.markov(range(len(data))):
            states.append(pyro.sample("states_{}".format(t),
                                      dist.Categorical(transition[states[-1]]),
                                      infer={"enumerate": "parallel"}))
            pyro.sample("obs_{}".format(t),
                        dist.Normal(means[states[-1]], 1.),
                        obs=data[t])
        return tuple(states)

    def guide(data):
        pass

    expected_loss = TraceEnum_ELBO(max_plate_nesting=0).differentiable_loss(model, guide, data)
    actual_loss = JitTraceEnum_ELBO(max_plate_nesting=0).differentiable_loss(model, guide, data)
    assert_equal(expected_loss, actual_loss)

    expected_grads = grad(expected_loss, [transition, means], allow_unused=True)
    actual_grads = grad(actual_loss, [transition, means], allow_unused=True)
    for e, a, name in zip(expected_grads, actual_grads, ["transition", "means"]):
        assert_equal(e, a, msg="bad gradient for {}".format(name))
Example #11
def initialize(data):
    pyro.clear_param_store()

    optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
    elbo = TraceEnum_ELBO(max_iarange_nesting=1)
    svi = SVI(model, full_guide, optim, loss=elbo)

    # Initialize weights to uniform.
    pyro.param('auto_weights',
               0.5 * torch.ones(K),
               constraint=constraints.simplex)

    # Assume half of the data variance is due to intra-component noise.
    var = (data.var() / 2).sqrt()
    pyro.param('auto_scale',
               torch.tensor([var] * K),
               constraint=constraints.positive)

    # Initialize means from a subsample of data.
    pyro.param('auto_locs',
               data[torch.multinomial(torch.ones(len(data)) / len(data), K)])

    loss = svi.loss(model, full_guide, data)

    return loss, svi
Example #12
def main(model, guide, args):
    # init
    if args.seed is not None: pyro.set_rng_seed(args.seed)
    logger = get_logger(args.log, __name__)
    logger.info(args)

    # generate data
    args.num_docs = 1000
    args.batch_size = 32
    true_topic_weights, true_topic_words, data = generate_model(args=args)

    # setup svi
    pyro.clear_param_store()
    optim = Adam({'lr': args.learning_rate})
    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    svi = SVI(model.main, guide.main, optim, elbo)

    # train
    times = [time.time()]
    logger.info('\nstep\t' + 'epoch\t' + 'elbo\t' + 'time(sec)')

    for i in range(1, args.num_steps + 1):
        loss = svi.step(data, args=args, batch_size=args.batch_size)

        if (args.eval_frequency > 0
                and i % args.eval_frequency == 0) or (i == 1):
            times.append(time.time())
            logger.info(f'{i:06d}\t'
                        f'{(i * args.batch_size) / args.num_docs:.3f}\t'
                        f'{-loss:.4f}\t'
                        f'{times[-1]-times[-2]:.3f}')
Example #13
def _get_initial_trace():
    guide = AutoDelta(poutine.block(model, expose_fn=lambda msg: not msg["name"].startswith("x") and
                                    not msg["name"].startswith("y")))
    elbo = TraceEnum_ELBO(max_plate_nesting=1)
    svi = SVI(model, guide, optim.Adam({"lr": .01}), elbo)
    for _ in range(100):
        svi.step(data)
    return poutine.trace(guide).get_trace(data)
Example #14
def _get_initial_trace():
    guide = AutoDelta(
        poutine.block(model,
                      expose_fn=lambda msg: not msg["name"].startswith("x")
                      and not msg["name"].startswith("y")))
    elbo = TraceEnum_ELBO(max_plate_nesting=1)
    svi = SVI(model, guide, optim.Adam({"lr": .01}), elbo,
              num_steps=100).run(data)
    return svi.exec_traces[-1]
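Examples #13 and #14 are two versions of the same helper: #13 steps SVI manually for 100 iterations and then traces the trained guide, while #14 uses the older SVI(..., num_steps=100).run() interface and pulls the last trace recorded in exec_traces.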
Example #15
def main(args):
    """
    run inference for CVAE
    :param args: arguments for CVAE
    :return: None
    """
    if args.seed is not None:
        set_seed(args.seed, args.cuda)

    if os.path.exists('cvae.model.pt'):
        print('Loading model %s' % 'cvae.model.pt')
        cvae = torch.load('cvae.model.pt')

    else:

        cvae = CVAE(z_dim=args.z_dim,
                    y_dim=8,
                    x_dim=32612,
                    hidden_dim=args.hidden_dimension,
                    use_cuda=args.cuda)

    print(cvae)

    # setup the optimizer
    adam_params = {
        "lr": args.learning_rate,
        "betas": (args.beta_1, 0.999),
        "clip_norm": 0.5
    }
    optimizer = ClippedAdam(adam_params)
    guide = config_enumerate(cvae.guide, args.enum_discrete)

    # set up the loss for inference.
    loss = SVI(cvae.model,
               guide,
               optimizer,
               loss=TraceEnum_ELBO(max_iarange_nesting=1))

    try:
        # setup the logger if a filename is provided
        logger = open(args.logfile, "w") if args.logfile else None

        data_loaders = setup_data_loaders(NHANES, args.cuda, args.batch_size)
        print(len(data_loaders['prediction']))

        #torch.save(cvae, 'cvae.model.pt')

        mu, sigma, actuals, lods, masks = get_predictions(
            data_loaders["prediction"], cvae.sim_measurements)

        torch.save((mu, sigma, actuals, lods, masks), 'cvae.predictions.pt')

    finally:
        # close the logger file object if we opened it earlier
        if args.logfile:
            logger.close()
Example #16
def initialize(data):
    pyro.clear_param_store()
    optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    # global global_guide
    global_guide = AutoDelta(
        poutine.block(model, expose=['weights', 'mus', 'lambdas']))
    svi = SVI(model, global_guide, optim, loss=elbo)
    svi.loss(model, global_guide, data)
    return svi
Example #17
def test_no_iarange_enum_discrete_batch_error():
    def model():
        p = torch.tensor(0.5)
        pyro.sample("x", dist.Bernoulli(p).expand_by([5]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        pyro.sample("x", dist.Bernoulli(p).expand_by([5]))

    assert_error(model, config_enumerate(guide), TraceEnum_ELBO())
Example #18
def test_enum_discrete_single_ok():
    def model():
        p = torch.tensor(0.5)
        pyro.sample("x", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        pyro.sample("x", dist.Bernoulli(p))

    assert_ok(model, config_enumerate(guide), TraceEnum_ELBO())
Example #19
def test_em(assignment_grad):
    args = make_args()
    args.assignment_grad = assignment_grad
    detections = generate_data(args)

    pyro.clear_param_store()
    pyro.param('noise_scale', torch.tensor(args.init_noise_scale),
               constraint=constraints.positive)
    pyro.param('objects_loc', torch.randn(args.max_num_objects, 1))

    # Learn object_loc via EM algorithm.
    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    newton = Newton(trust_radii={'objects_loc': 1.0})
    for step in range(10):
        # Detach previous iterations.
        objects_loc = pyro.param('objects_loc').detach_().requires_grad_()
        loss = elbo.differentiable_loss(model, guide, detections, args)  # E-step
        newton.step(loss, {'objects_loc': objects_loc})  # M-step
        logger.debug('step {}, loss = {}'.format(step, loss.item()))
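The detach_().requires_grad_() call above implements a common idiom: after a Newton update the tensor in the param store still carries the graph that produced it, and each new E-step needs a fresh leaf. A minimal illustration of the in-place detach (not from the source):

import torch

p = torch.zeros(3, requires_grad=True)
p = p * 1.0                    # non-leaf: carries a grad_fn
p.detach_().requires_grad_()   # in place: drop the old graph, make p a leaf again
assert p.grad_fn is None and p.requires_grad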
Example #20
def test_enum_discrete_iranges_iarange_dependency_warning(
        enumerate_, is_validate):
    def model():
        pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'})
        inner_iarange = pyro.iarange("iarange", 10, 5)

        for i in pyro.irange("irange1", 2):
            with inner_iarange:
                pyro.sample("x_{}".format(i),
                            dist.Bernoulli(0.5).expand_by([5]),
                            infer={'enumerate': enumerate_})

        for i in pyro.irange("irange2", 2):
            pyro.sample("y_{}".format(i), dist.Bernoulli(0.5))

    with pyro.validation_enabled(is_validate):
        if enumerate_ and is_validate:
            assert_warning(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
        else:
            assert_ok(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
Example #21
def update_posterior(self, X, y):
    X = torch.cat([self.gpmodel.X, X])
    y = torch.cat([self.gpmodel.y, y])
    self.gpmodel.set_data(X, y)
    optimizer = torch.optim.Adam(self.gpmodel.parameters(), lr=0.001)
    gp.util.train(
        self.gpmodel,
        optimizer,
        loss_fn=TraceEnum_ELBO(
            strict_enumeration_warning=False).differentiable_loss,
        retain_graph=True)
Example #22
def test_discrete_parallel(continuous_class):
    K = 2
    data = torch.tensor([0., 1., 10., 11., 12.])

    def model(data):
        weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
        locs = pyro.sample('locs', dist.Normal(0, 10).expand_by([K]).independent(1))
        scale = pyro.sample('scale', dist.LogNormal(0, 1))

        with pyro.iarange('data', len(data)):
            weights = weights.expand(torch.Size((len(data),)) + weights.shape)
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

    guide = AutoGuideList(model)
    guide.add(continuous_class(poutine.block(model, hide=["assignment"])))
    guide.add(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))

    elbo = TraceEnum_ELBO(max_iarange_nesting=1)
    loss = elbo.loss_and_grads(model, guide, data)
    assert np.isfinite(loss), loss
Example #23
def test_discrete_parallel(continuous_class):
    K = 2
    data = torch.tensor([0., 1., 10., 11., 12.])

    def model(data):
        weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
        locs = pyro.sample('locs', dist.Normal(0, 10).expand_by([K]).to_event(1))
        scale = pyro.sample('scale', dist.LogNormal(0, 1))

        with pyro.plate('data', len(data)):
            weights = weights.expand(torch.Size((len(data),)) + weights.shape)
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

    guide = AutoGuideList(model)
    guide.append(continuous_class(poutine.block(model, hide=["assignment"])))
    guide.append(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))

    elbo = TraceEnum_ELBO(max_plate_nesting=1)
    loss = elbo.loss_and_grads(model, guide, data)
    assert np.isfinite(loss), loss
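Example #23 is the same test as #22 after Pyro's API renames: pyro.iarange became pyro.plate (and max_iarange_nesting became max_plate_nesting), .independent(1) became .to_event(1), and AutoGuideList.add was replaced by append.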
Example #24
def test_enum_discrete_irange_single_ok():
    def model():
        p = torch.tensor(0.5)
        for i in pyro.irange("irange", 10, 5):
            pyro.sample("x_{}".format(i), dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        for i in pyro.irange("irange", 10, 5):
            pyro.sample("x_{}".format(i), dist.Bernoulli(p))

    assert_ok(model, config_enumerate(guide), TraceEnum_ELBO())
Example #25
def test_iarange_enum_discrete_batch_ok():
    def model():
        p = torch.tensor(0.5)
        with pyro.iarange("iarange", 10, 5) as ind:
            pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind)]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        with pyro.iarange("iarange", 10, 5) as ind:
            pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind)]))

    assert_ok(model, config_enumerate(guide), TraceEnum_ELBO())
Example #26
def test_enum_discrete_irange_iarange_dependency_ok(enumerate_):
    def model():
        pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'})
        inner_iarange = pyro.iarange("iarange", 10, 5)
        for i in pyro.irange("irange", 3):
            pyro.sample("y_{}".format(i), dist.Bernoulli(0.5))
            with inner_iarange:
                pyro.sample("x_{}".format(i),
                            dist.Bernoulli(0.5).expand_by([5]),
                            infer={'enumerate': enumerate_})

    assert_ok(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
Example #27
def test_enum_discrete_parallel_nested_ok(max_iarange_nesting):
    iarange_shape = torch.Size([1] * max_iarange_nesting)

    def model():
        p2 = torch.ones(2) / 2
        p3 = torch.ones(3) / 3
        x2 = pyro.sample("x2", dist.OneHotCategorical(p2))
        x3 = pyro.sample("x3", dist.OneHotCategorical(p3))
        assert x2.shape == torch.Size([2]) + iarange_shape + p2.shape
        assert x3.shape == torch.Size([3, 1]) + iarange_shape + p3.shape

    assert_ok(model, config_enumerate(model, "parallel"),
              TraceEnum_ELBO(max_iarange_nesting=max_iarange_nesting))
Example #28
def test_svi_enum(Elbo, irange_dim, enumerate1, enumerate2):
    pyro.clear_param_store()
    num_particles = 10
    q = pyro.param("q", torch.tensor(0.75), constraint=constraints.unit_interval)
    p = 0.2693204236205713  # for which kl(Bernoulli(q), Bernoulli(p)) = 0.5

    def model():
        pyro.sample("x", dist.Bernoulli(p))
        for i in pyro.irange("irange", irange_dim):
            pyro.sample("y_{}".format(i), dist.Bernoulli(p))

    def guide():
        q = pyro.param("q")
        pyro.sample("x", dist.Bernoulli(q), infer={"enumerate": enumerate1})
        for i in pyro.irange("irange", irange_dim):
            pyro.sample("y_{}".format(i), dist.Bernoulli(q), infer={"enumerate": enumerate2})

    kl = (1 + irange_dim) * kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p))
    expected_loss = kl.item()
    expected_grad = grad(kl, [q.unconstrained()])[0]

    inner_particles = 2
    outer_particles = num_particles // inner_particles
    elbo = TraceEnum_ELBO(max_iarange_nesting=0,
                          strict_enumeration_warning=any([enumerate1, enumerate2]),
                          num_particles=inner_particles)
    actual_loss = sum(elbo.loss_and_grads(model, guide)
                      for i in range(outer_particles)) / outer_particles
    actual_grad = q.unconstrained().grad / outer_particles

    assert_equal(actual_loss, expected_loss, prec=0.3, msg="".join([
        "\nexpected loss = {}".format(expected_loss),
        "\n  actual loss = {}".format(actual_loss),
    ]))
    assert_equal(actual_grad, expected_grad, prec=0.5, msg="".join([
        "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()),
        "\n  actual grad = {}".format(actual_grad.detach().cpu().numpy()),
    ]))
Example #29
def test_enum_discrete_parallel_ok(max_iarange_nesting):
    iarange_shape = torch.Size([1] * max_iarange_nesting)

    def model():
        p = torch.tensor(0.5)
        x = pyro.sample("x", dist.Bernoulli(p))
        assert x.shape == torch.Size([2]) + iarange_shape

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        x = pyro.sample("x", dist.Bernoulli(p))
        assert x.shape == torch.Size([2]) + iarange_shape

    assert_ok(model, config_enumerate(guide, "parallel"),
              TraceEnum_ELBO(max_iarange_nesting=max_iarange_nesting))
Example #30
def test_enum_discrete_missing_config_warning(strict_enumeration_warning):
    def model():
        p = torch.tensor(0.5)
        pyro.sample("x", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        pyro.sample("x", dist.Bernoulli(p))

    elbo = TraceEnum_ELBO(
        strict_enumeration_warning=strict_enumeration_warning)
    if strict_enumeration_warning:
        assert_warning(model, guide, elbo)
    else:
        assert_ok(model, guide, elbo)
Example #31
def test_enum_discrete_iaranges_dependency_ok(enumerate_):
    def model():
        pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'})
        x_iarange = pyro.iarange("x_iarange", 10, 5, dim=-1)
        y_iarange = pyro.iarange("y_iarange", 11, 6, dim=-2)
        pyro.sample("a", dist.Bernoulli(0.5))
        with x_iarange:
            pyro.sample("b", dist.Bernoulli(0.5).expand_by([5]))
        with y_iarange:
            # Note that it is difficult to check that c does not depend on b.
            pyro.sample("c", dist.Bernoulli(0.5).expand_by([6, 1]))
        with x_iarange, y_iarange:
            pyro.sample("d", dist.Bernoulli(0.5).expand_by([6, 5]))

    assert_ok(model, model, TraceEnum_ELBO(max_iarange_nesting=2))
Example #32
def aic_num_parameters(model, guide=None):
    """
    hacky AIC param count that includes all parameters in the model and guide
    """
    def _size(tensor):
        """product of shape"""
        s = 1
        for d in tensor.shape:
            s = s * d
        return s

    with poutine.block(), poutine.trace(param_only=True) as param_capture:
        TraceEnum_ELBO(max_plate_nesting=2).differentiable_loss(model, guide)

    return sum(
        _size(node["value"]) for node in param_capture.trace.nodes.values())