def test_em_nested_in_svi(assignment_grad):
    args = make_args()
    args.assignment_grad = assignment_grad
    detections = generate_data(args)

    pyro.clear_param_store()
    pyro.param('noise_scale', torch.tensor(args.init_noise_scale),
               constraint=constraints.positive)
    pyro.param('objects_loc', torch.randn(args.max_num_objects, 1))

    # Learn object_loc via EM and noise_scale via SVI.
    optim = Adam({'lr': 0.1})
    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    newton = Newton(trust_radii={'objects_loc': 1.0})
    svi = SVI(poutine.block(model, hide=['objects_loc']),
              poutine.block(guide, hide=['objects_loc']),
              optim, elbo)
    for svi_step in range(50):
        for em_step in range(2):
            objects_loc = pyro.param('objects_loc').detach_().requires_grad_()
            assert pyro.param('objects_loc').grad_fn is None
            loss = elbo.differentiable_loss(model, guide, detections, args)  # E-step
            updated = newton.get_step(loss, {'objects_loc': objects_loc})  # M-step
            assert updated['objects_loc'].grad_fn is not None
            pyro.get_param_store()['objects_loc'] = updated['objects_loc']
            assert pyro.param('objects_loc').grad_fn is not None
        loss = svi.step(detections, args)
        logger.debug('step {: >2d}, loss = {:0.6f}, noise_scale = {:0.6f}'.format(
            svi_step, loss, pyro.param('noise_scale').item()))

def test_prob(nderivs):
    #      +-------+
    #  z --|--> x  |
    #      +-------+
    num_particles = 10000
    data = torch.tensor([0.5, 1., 1.5])
    p = pyro.param("p", torch.tensor(0.25))

    @config_enumerate
    def model(num_particles):
        p = pyro.param("p")
        with pyro.plate("num_particles", num_particles, dim=-2):
            z = pyro.sample("z", dist.Bernoulli(p))
            with pyro.plate("data", 3):
                pyro.sample("x", dist.Normal(z, 1.), obs=data)

    def guide(num_particles):
        pass

    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    expected_logprob = -elbo.differentiable_loss(model, guide, num_particles=1)

    posterior_model = infer_discrete(config_enumerate(model, "parallel"),
                                     first_available_dim=-3)
    posterior_trace = poutine.trace(posterior_model).get_trace(
        num_particles=num_particles)
    actual_logprob = log_mean_prob(posterior_trace, particle_dim=-2)

    if nderivs == 0:
        assert_equal(expected_logprob, actual_logprob, prec=1e-3)
    elif nderivs == 1:
        expected_grad = grad(expected_logprob, [p])[0]
        actual_grad = grad(actual_logprob, [p])[0]
        assert_equal(expected_grad, actual_grad, prec=1e-3)

def test_enum_discrete_parallel_iarange_ok():
    enum_discrete = "defined below"

    def model():
        p2 = torch.ones(2) / 2
        p34 = torch.ones(3, 4) / 4
        p536 = torch.ones(5, 3, 6) / 6

        x2 = pyro.sample("x2", dist.Categorical(p2))
        with pyro.iarange("outer", 3):
            x34 = pyro.sample("x34", dist.Categorical(p34))
            with pyro.iarange("inner", 5):
                x536 = pyro.sample("x536", dist.Categorical(p536))

        if enum_discrete == "sequential":
            # All dimensions are iarange dimensions.
            assert x2.shape == torch.Size([])
            assert x34.shape == torch.Size([3])
            assert x536.shape == torch.Size([5, 3])
        else:
            # Meaning of dimensions:  [ enum dims | iarange dims ]
            assert x2.shape == torch.Size([2, 1, 1])          # noqa: E201
            assert x34.shape == torch.Size([4, 1, 1, 3])      # noqa: E201
            assert x536.shape == torch.Size([6, 1, 1, 5, 3])  # noqa: E201

    enum_discrete = "sequential"
    assert_ok(model, config_enumerate(model, "sequential"),
              TraceEnum_ELBO(max_iarange_nesting=2))

    enum_discrete = "parallel"
    assert_ok(model, config_enumerate(model, "parallel"),
              TraceEnum_ELBO(max_iarange_nesting=2))

def test_svi_multi():
    args = make_args()
    args.assignment_grad = True
    detections = generate_data(args)

    pyro.clear_param_store()
    pyro.param('noise_scale', torch.tensor(args.init_noise_scale),
               constraint=constraints.positive)
    pyro.param('objects_loc', torch.randn(args.max_num_objects, 1))

    # Learn object_loc via Newton and noise_scale via Adam.
    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    adam = Adam({'lr': 0.1})
    newton = Newton(trust_radii={'objects_loc': 1.0})
    optim = MixedMultiOptimizer([(['noise_scale'], adam),
                                 (['objects_loc'], newton)])
    for svi_step in range(50):
        with poutine.trace(param_only=True) as param_capture:
            loss = elbo.differentiable_loss(model, guide, detections, args)
        params = {name: pyro.param(name).unconstrained()
                  for name in param_capture.trace.nodes.keys()}
        optim.step(loss, params)
        logger.debug('step {: >2d}, loss = {:0.6f}, noise_scale = {:0.6f}'.format(
            svi_step, loss.item(), pyro.param('noise_scale').item()))

def test_svi_enum(plate_dim, enumerate1, enumerate2):
    pyro.clear_param_store()
    num_particles = 10
    q = pyro.param("q", constant(0.75), constraint=constraints.unit_interval)
    p = 0.2693204236205713  # for which kl(Bernoulli(q), Bernoulli(p)) = 0.5

    def model():
        pyro.sample("x", dist.Bernoulli(p))
        for i in pyro.plate("plate", plate_dim):
            pyro.sample("y_{}".format(i), dist.Bernoulli(p))

    def guide():
        q = pyro.param("q")
        pyro.sample("x", dist.Bernoulli(q), infer={"enumerate": enumerate1})
        for i in pyro.plate("plate", plate_dim):
            pyro.sample(
                "y_{}".format(i), dist.Bernoulli(q), infer={"enumerate": enumerate2}
            )

    kl = (1 + plate_dim) * kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p))
    expected_loss = kl.item()
    expected_grad = grad(kl, [q.unconstrained()])[0]

    inner_particles = 2
    outer_particles = num_particles // inner_particles
    elbo = TraceEnum_ELBO(
        max_plate_nesting=0,
        strict_enumeration_warning=any([enumerate1, enumerate2]),
        num_particles=inner_particles,
        ignore_jit_warnings=True,
    )
    actual_loss = (
        sum(elbo.loss_and_grads(model, guide) for i in range(outer_particles))
        / outer_particles
    )
    actual_grad = q.unconstrained().grad / outer_particles

    assert_equal(
        actual_loss,
        expected_loss,
        prec=0.3,
        msg="".join(
            [
                "\nexpected loss = {}".format(expected_loss),
                "\n  actual loss = {}".format(actual_loss),
            ]
        ),
    )
    assert_equal(
        actual_grad,
        expected_grad,
        prec=0.5,
        msg="".join(
            [
                "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()),
                "\n  actual grad = {}".format(actual_grad.detach().cpu().numpy()),
            ]
        ),
    )

def test_enum_discrete_iarange_dependency_warning(enumerate_, is_validate):
    def model():
        pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'})
        with pyro.iarange("iarange", 10, 5):
            x = pyro.sample("x", dist.Bernoulli(0.5).expand_by([5]),
                            infer={'enumerate': enumerate_})
        pyro.sample("y", dist.Bernoulli(x.mean()))  # user should move this line up

    with pyro.validation_enabled(is_validate):
        if enumerate_ and is_validate:
            assert_warning(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
        else:
            assert_ok(model, model, TraceEnum_ELBO(max_iarange_nesting=1))

def test_enum_discrete_iarange_shape_broadcasting_ok(enumerate_):
    @poutine.broadcast
    @config_enumerate(default=enumerate_)
    def model():
        x_iarange = pyro.iarange("x_iarange", 10, 5, dim=-1)
        y_iarange = pyro.iarange("y_iarange", 11, 6, dim=-2)
        with pyro.iarange("num_particles", 50, dim=-3):
            with x_iarange:
                b = pyro.sample("b", dist.Beta(torch.tensor(1.1), torch.tensor(1.1)))
                assert b.shape == torch.Size((50, 1, 5))
            with y_iarange:
                c = pyro.sample("c", dist.Bernoulli(0.5))
                if enumerate_ == "parallel":
                    assert c.shape == torch.Size((2, 50, 6, 1))
                else:
                    assert c.shape == torch.Size((50, 6, 1))
            with x_iarange, y_iarange:
                d = pyro.sample("d", dist.Bernoulli(b))
                if enumerate_ == "parallel":
                    assert d.shape == torch.Size((2, 1, 50, 6, 5))
                else:
                    assert d.shape == torch.Size((50, 6, 5))

    assert_ok(model, model,
              TraceEnum_ELBO(max_iarange_nesting=3,
                             strict_enumeration_warning=(enumerate_ == "parallel")))

def test_discrete_hmm_categorical(num_steps):
    state_dim = 3
    obs_dim = 4
    init_logits = torch.randn(state_dim)
    trans_logits = torch.randn(num_steps, state_dim, state_dim)
    obs_dist = dist.Categorical(logits=torch.randn(num_steps, state_dim, obs_dim))
    d = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)

    data = dist.Categorical(logits=torch.zeros(num_steps, obs_dim)).sample()
    actual = d.log_prob(data)
    assert actual.shape == d.batch_shape
    check_expand(d, data)

    # Check loss against TraceEnum_ELBO.
    @config_enumerate
    def model(data):
        x = pyro.sample("x_init", dist.Categorical(logits=init_logits))
        for t in range(num_steps):
            x = pyro.sample(
                "x_{}".format(t),
                dist.Categorical(logits=Vindex(trans_logits)[..., t, x, :]))
            pyro.sample("obs_{}".format(t),
                        dist.Categorical(logits=Vindex(obs_dist.logits)[..., t, x, :]),
                        obs=data[..., t])

    expected_loss = TraceEnum_ELBO().loss(model, empty_guide, data)
    actual_loss = -float(actual.sum())
    assert_close(actual_loss, expected_loss)

def test_discrete_hmm_diag_normal(num_steps):
    state_dim = 3
    event_size = 2
    init_logits = torch.randn(state_dim)
    trans_logits = torch.randn(num_steps, state_dim, state_dim)
    loc = torch.randn(num_steps, state_dim, event_size)
    scale = torch.randn(num_steps, state_dim, event_size).exp()
    obs_dist = dist.Normal(loc, scale).to_event(1)
    d = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)

    data = obs_dist.sample()[:, 0]
    actual = d.log_prob(data)
    assert actual.shape == d.batch_shape
    check_expand(d, data)

    # Check loss against TraceEnum_ELBO.
    @config_enumerate
    def model(data):
        x = pyro.sample("x_init", dist.Categorical(logits=init_logits))
        for t in range(num_steps):
            x = pyro.sample(
                "x_{}".format(t),
                dist.Categorical(logits=Vindex(trans_logits)[..., t, x, :]))
            pyro.sample("obs_{}".format(t),
                        dist.Normal(Vindex(loc)[..., t, x, :],
                                    Vindex(scale)[..., t, x, :]).to_event(1),
                        obs=data[..., t, :])

    expected_loss = TraceEnum_ELBO().loss(model, empty_guide, data)
    actual_loss = -float(actual.sum())
    assert_close(actual_loss, expected_loss)

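# Beyond log_prob, DiscreteHMM can also summarize the posterior over the final
# hidden state. A minimal sketch reusing `d`, `data`, and `state_dim` from the
# test above; it assumes the installed Pyro version provides
# DiscreteHMM.filter(), which returns a Categorical over the last hidden state.
def sketch_filter_final_state(d, data, state_dim=3):
    final_posterior = d.filter(data)
    assert final_posterior.logits.shape[-1] == state_dim
    return final_posterior
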
def test_traceenum_elbo(length):
    hidden_dim = 10
    transition = pyro.param("transition",
                            0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim),
                            constraint=constraints.positive)
    means = pyro.param("means", torch.arange(float(hidden_dim)))
    data = 1 + 2 * torch.randn(length)

    @ignore_jit_warnings()
    def model(data):
        transition = pyro.param("transition")
        means = pyro.param("means")
        states = [torch.tensor(0)]
        for t in pyro.markov(range(len(data))):
            states.append(pyro.sample("states_{}".format(t),
                                      dist.Categorical(transition[states[-1]]),
                                      infer={"enumerate": "parallel"}))
            pyro.sample("obs_{}".format(t),
                        dist.Normal(means[states[-1]], 1.),
                        obs=data[t])
        return tuple(states)

    def guide(data):
        pass

    expected_loss = TraceEnum_ELBO(max_plate_nesting=0).differentiable_loss(
        model, guide, data)
    actual_loss = JitTraceEnum_ELBO(max_plate_nesting=0).differentiable_loss(
        model, guide, data)
    assert_equal(expected_loss, actual_loss)

    expected_grads = grad(expected_loss, [transition, means], allow_unused=True)
    actual_grads = grad(actual_loss, [transition, means], allow_unused=True)
    for e, a, name in zip(expected_grads, actual_grads, ["transition", "means"]):
        assert_equal(e, a, msg="bad gradient for {}".format(name))

def initialize(data):
    pyro.clear_param_store()
    optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
    elbo = TraceEnum_ELBO(max_iarange_nesting=1)
    svi = SVI(model, full_guide, optim, loss=elbo)

    # Initialize weights to uniform.
    pyro.param('auto_weights', 0.5 * torch.ones(K), constraint=constraints.simplex)

    # Assume half of the data variance is due to intra-component noise.
    scale = (data.var() / 2).sqrt()
    pyro.param('auto_scale', torch.tensor([scale] * K),
               constraint=constraints.positive)

    # Initialize means from a subsample of data.
    pyro.param('auto_locs',
               data[torch.multinomial(torch.ones(len(data)) / len(data), K)])

    loss = svi.loss(model, full_guide, data)
    return loss, svi

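# A minimal sketch (restart count and seeding policy are assumptions, not from
# the source): since initialize() clears the param store and draws random
# means, restart it from several seeds and keep the seed with the lowest
# initial loss, then re-run initialize() so the param store holds the winner.
def best_initialization(data, num_seeds=100):
    best_seed, best_loss = None, float('inf')
    for seed in range(num_seeds):
        pyro.set_rng_seed(seed)
        loss, _ = initialize(data)
        if loss < best_loss:
            best_seed, best_loss = seed, loss
    pyro.set_rng_seed(best_seed)
    return initialize(data)  # (loss, svi) for the best seed
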
def main(model, guide, args):
    # init
    if args.seed is not None:
        pyro.set_rng_seed(args.seed)
    logger = get_logger(args.log, __name__)
    logger.info(args)

    # generate data
    args.num_docs = 1000
    args.batch_size = 32
    true_topic_weights, true_topic_words, data = generate_model(args=args)

    # setup svi
    pyro.clear_param_store()
    optim = Adam({'lr': args.learning_rate})
    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    svi = SVI(model.main, guide.main, optim, elbo)

    # train
    times = [time.time()]
    logger.info('\nstep\t' + 'epoch\t' + 'elbo\t' + 'time(sec)')
    for i in range(1, args.num_steps + 1):
        loss = svi.step(data, args=args, batch_size=args.batch_size)
        if (args.eval_frequency > 0 and i % args.eval_frequency == 0) or (i == 1):
            times.append(time.time())
            logger.info(f'{i:06d}\t'
                        f'{(i * args.batch_size) / args.num_docs:.3f}\t'
                        f'{-loss:.4f}\t'
                        f'{times[-1] - times[-2]:.3f}')

def _get_initial_trace():
    guide = AutoDelta(poutine.block(
        model,
        expose_fn=lambda msg: not msg["name"].startswith("x")
        and not msg["name"].startswith("y")))
    elbo = TraceEnum_ELBO(max_plate_nesting=1)
    svi = SVI(model, guide, optim.Adam({"lr": .01}), elbo)
    for _ in range(100):
        svi.step(data)
    return poutine.trace(guide).get_trace(data)

def _get_initial_trace():
    guide = AutoDelta(poutine.block(
        model,
        expose_fn=lambda msg: not msg["name"].startswith("x")
        and not msg["name"].startswith("y")))
    elbo = TraceEnum_ELBO(max_plate_nesting=1)
    svi = SVI(model, guide, optim.Adam({"lr": .01}), elbo,
              num_steps=100).run(data)
    return svi.exec_traces[-1]

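# A hedged sketch of a likely consumer (sampler settings are illustrative
# assumptions): Pyro's HMC/NUTS kernels expose an `initial_trace` setter, so
# the trace returned above can warm-start MCMC over the remaining sites.
from pyro.infer.mcmc import MCMC, NUTS

def run_mcmc(data):
    kernel = NUTS(model, max_plate_nesting=1)
    kernel.initial_trace = _get_initial_trace()
    mcmc = MCMC(kernel, num_samples=500, warmup_steps=100)
    mcmc.run(data)
    return mcmc
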
def main(args):
    """
    run inference for CVAE
    :param args: arguments for CVAE
    :return: None
    """
    if args.seed is not None:
        set_seed(args.seed, args.cuda)

    if os.path.exists('cvae.model.pt'):
        print('Loading model %s' % 'cvae.model.pt')
        cvae = torch.load('cvae.model.pt')
    else:
        cvae = CVAE(z_dim=args.z_dim, y_dim=8, x_dim=32612,
                    hidden_dim=args.hidden_dimension, use_cuda=args.cuda)
    print(cvae)

    # setup the optimizer
    adam_params = {"lr": args.learning_rate,
                   "betas": (args.beta_1, 0.999),
                   "clip_norm": 0.5}
    optimizer = ClippedAdam(adam_params)
    guide = config_enumerate(cvae.guide, args.enum_discrete)

    # set up the loss for inference.
    loss = SVI(cvae.model, guide, optimizer,
               loss=TraceEnum_ELBO(max_iarange_nesting=1))

    try:
        # setup the logger if a filename is provided
        logger = open(args.logfile, "w") if args.logfile else None

        data_loaders = setup_data_loaders(NHANES, args.cuda, args.batch_size)
        print(len(data_loaders['prediction']))
        # torch.save(cvae, 'cvae.model.pt')

        mu, sigma, actuals, lods, masks = get_predictions(
            data_loaders["prediction"], cvae.sim_measurements)
        torch.save((mu, sigma, actuals, lods, masks), 'cvae.predictions.pt')
    finally:
        # close the logger file object if we opened it earlier
        if args.logfile:
            logger.close()

def initialize(data):
    pyro.clear_param_store()
    optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
    elbo = TraceEnum_ELBO(max_plate_nesting=2)

    global global_guide
    global_guide = AutoDelta(
        poutine.block(model, expose=['weights', 'mus', 'lambdas']))
    svi = SVI(model, global_guide, optim, loss=elbo)
    svi.loss(model, global_guide, data)
    return svi

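# A minimal sketch (iteration count is an assumption): train the AutoDelta
# guide with the SVI object returned by initialize(), then read off the MAP
# point estimates by calling the guide.
def train(data, num_steps=200):
    svi = initialize(data)
    losses = [svi.step(data) for _ in range(num_steps)]
    map_estimates = global_guide(data)  # AutoDelta returns point estimates
    return losses, map_estimates
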
def test_no_iarange_enum_discrete_batch_error():
    def model():
        p = torch.tensor(0.5)
        pyro.sample("x", dist.Bernoulli(p).expand_by([5]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        pyro.sample("x", dist.Bernoulli(p).expand_by([5]))

    assert_error(model, config_enumerate(guide), TraceEnum_ELBO())

def test_enum_discrete_single_ok():
    def model():
        p = torch.tensor(0.5)
        pyro.sample("x", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        pyro.sample("x", dist.Bernoulli(p))

    assert_ok(model, config_enumerate(guide), TraceEnum_ELBO())

def test_em(assignment_grad):
    args = make_args()
    args.assignment_grad = assignment_grad
    detections = generate_data(args)

    pyro.clear_param_store()
    pyro.param('noise_scale', torch.tensor(args.init_noise_scale),
               constraint=constraints.positive)
    pyro.param('objects_loc', torch.randn(args.max_num_objects, 1))

    # Learn object_loc via EM algorithm.
    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    newton = Newton(trust_radii={'objects_loc': 1.0})
    for step in range(10):
        # Detach previous iterations.
        objects_loc = pyro.param('objects_loc').detach_().requires_grad_()
        loss = elbo.differentiable_loss(model, guide, detections, args)  # E-step
        newton.step(loss, {'objects_loc': objects_loc})  # M-step
        logger.debug('step {}, loss = {}'.format(step, loss.item()))

def test_enum_discrete_iranges_iarange_dependency_warning(
        enumerate_, is_validate):
    def model():
        pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'})
        inner_iarange = pyro.iarange("iarange", 10, 5)

        for i in pyro.irange("irange1", 2):
            with inner_iarange:
                pyro.sample("x_{}".format(i), dist.Bernoulli(0.5).expand_by([5]),
                            infer={'enumerate': enumerate_})

        for i in pyro.irange("irange2", 2):
            pyro.sample("y_{}".format(i), dist.Bernoulli(0.5))

    with pyro.validation_enabled(is_validate):
        if enumerate_ and is_validate:
            assert_warning(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
        else:
            assert_ok(model, model, TraceEnum_ELBO(max_iarange_nesting=1))

def update_posterior(self, X, y):
    X = torch.cat([self.gpmodel.X, X])
    y = torch.cat([self.gpmodel.y, y])
    self.gpmodel.set_data(X, y)
    optimizer = torch.optim.Adam(self.gpmodel.parameters(), lr=0.001)
    gp.util.train(self.gpmodel, optimizer,
                  loss_fn=TraceEnum_ELBO(
                      strict_enumeration_warning=False).differentiable_loss,
                  retain_graph=True)

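# A hedged sketch (all names here are assumptions, not from the source): the
# outer Bayesian-optimization loop that update_posterior typically serves.
# `bo` is an instance of the surrounding class, `bo.next_x()` is an assumed
# acquisition-maximization step, and `f` is the expensive objective.
def optimize(bo, f, num_steps=20):
    for _ in range(num_steps):
        x_new = bo.next_x()
        bo.update_posterior(x_new, f(x_new))
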
def test_discrete_parallel(continuous_class):
    K = 2
    data = torch.tensor([0., 1., 10., 11., 12.])

    def model(data):
        weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
        locs = pyro.sample('locs', dist.Normal(0, 10).expand_by([K]).independent(1))
        scale = pyro.sample('scale', dist.LogNormal(0, 1))

        with pyro.iarange('data', len(data)):
            weights = weights.expand(torch.Size((len(data),)) + weights.shape)
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

    guide = AutoGuideList(model)
    guide.add(continuous_class(poutine.block(model, hide=["assignment"])))
    guide.add(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))

    elbo = TraceEnum_ELBO(max_iarange_nesting=1)
    loss = elbo.loss_and_grads(model, guide, data)
    assert np.isfinite(loss), loss

def test_discrete_parallel(continuous_class):
    K = 2
    data = torch.tensor([0., 1., 10., 11., 12.])

    def model(data):
        weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
        locs = pyro.sample('locs', dist.Normal(0, 10).expand_by([K]).to_event(1))
        scale = pyro.sample('scale', dist.LogNormal(0, 1))

        with pyro.plate('data', len(data)):
            weights = weights.expand(torch.Size((len(data),)) + weights.shape)
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

    guide = AutoGuideList(model)
    guide.append(continuous_class(poutine.block(model, hide=["assignment"])))
    guide.append(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))

    elbo = TraceEnum_ELBO(max_plate_nesting=1)
    loss = elbo.loss_and_grads(model, guide, data)
    assert np.isfinite(loss), loss

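# A minimal sketch (optimizer choice and step count are assumptions): the same
# mixed continuous/discrete guide can be trained, not just scored once.
def train_mixture(model, guide, elbo, data, num_steps=100):
    svi = SVI(model, guide, Adam({'lr': 0.01}), elbo)
    return [svi.step(data) for _ in range(num_steps)]
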
def test_enum_discrete_irange_single_ok():
    def model():
        p = torch.tensor(0.5)
        for i in pyro.irange("irange", 10, 5):
            pyro.sample("x_{}".format(i), dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        for i in pyro.irange("irange", 10, 5):
            pyro.sample("x_{}".format(i), dist.Bernoulli(p))

    assert_ok(model, config_enumerate(guide), TraceEnum_ELBO())

def test_iarange_enum_discrete_batch_ok():
    def model():
        p = torch.tensor(0.5)
        with pyro.iarange("iarange", 10, 5) as ind:
            pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind)]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        with pyro.iarange("iarange", 10, 5) as ind:
            pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind)]))

    assert_ok(model, config_enumerate(guide), TraceEnum_ELBO())

def test_enum_discrete_irange_iarange_dependency_ok(enumerate_):
    def model():
        pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'})
        inner_iarange = pyro.iarange("iarange", 10, 5)

        for i in pyro.irange("irange", 3):
            pyro.sample("y_{}".format(i), dist.Bernoulli(0.5))
            with inner_iarange:
                pyro.sample("x_{}".format(i), dist.Bernoulli(0.5).expand_by([5]),
                            infer={'enumerate': enumerate_})

    assert_ok(model, model, TraceEnum_ELBO(max_iarange_nesting=1))

def test_enum_discrete_parallel_nested_ok(max_iarange_nesting):
    iarange_shape = torch.Size([1] * max_iarange_nesting)

    def model():
        # Note: plain tensor construction here; wrapping an existing tensor in
        # torch.tensor(...) triggers a copy-construct warning.
        p2 = torch.ones(2) / 2
        p3 = torch.ones(3) / 3
        x2 = pyro.sample("x2", dist.OneHotCategorical(p2))
        x3 = pyro.sample("x3", dist.OneHotCategorical(p3))
        assert x2.shape == torch.Size([2]) + iarange_shape + p2.shape
        assert x3.shape == torch.Size([3, 1]) + iarange_shape + p3.shape

    assert_ok(model, config_enumerate(model, "parallel"),
              TraceEnum_ELBO(max_iarange_nesting=max_iarange_nesting))

def test_svi_enum(Elbo, irange_dim, enumerate1, enumerate2):
    pyro.clear_param_store()
    num_particles = 10
    q = pyro.param("q", torch.tensor(0.75), constraint=constraints.unit_interval)
    p = 0.2693204236205713  # for which kl(Bernoulli(q), Bernoulli(p)) = 0.5

    def model():
        pyro.sample("x", dist.Bernoulli(p))
        for i in pyro.irange("irange", irange_dim):
            pyro.sample("y_{}".format(i), dist.Bernoulli(p))

    def guide():
        q = pyro.param("q")
        pyro.sample("x", dist.Bernoulli(q), infer={"enumerate": enumerate1})
        for i in pyro.irange("irange", irange_dim):
            pyro.sample("y_{}".format(i), dist.Bernoulli(q),
                        infer={"enumerate": enumerate2})

    kl = (1 + irange_dim) * kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p))
    expected_loss = kl.item()
    expected_grad = grad(kl, [q.unconstrained()])[0]

    inner_particles = 2
    outer_particles = num_particles // inner_particles
    elbo = TraceEnum_ELBO(max_iarange_nesting=0,
                          strict_enumeration_warning=any([enumerate1, enumerate2]),
                          num_particles=inner_particles)
    actual_loss = sum(elbo.loss_and_grads(model, guide)
                      for i in range(outer_particles)) / outer_particles
    actual_grad = q.unconstrained().grad / outer_particles

    assert_equal(actual_loss, expected_loss, prec=0.3, msg="".join([
        "\nexpected loss = {}".format(expected_loss),
        "\n  actual loss = {}".format(actual_loss),
    ]))
    assert_equal(actual_grad, expected_grad, prec=0.5, msg="".join([
        "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()),
        "\n  actual grad = {}".format(actual_grad.detach().cpu().numpy()),
    ]))

def test_enum_discrete_parallel_ok(max_iarange_nesting):
    iarange_shape = torch.Size([1] * max_iarange_nesting)

    def model():
        p = torch.tensor(0.5)
        x = pyro.sample("x", dist.Bernoulli(p))
        assert x.shape == torch.Size([2]) + iarange_shape

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        x = pyro.sample("x", dist.Bernoulli(p))
        assert x.shape == torch.Size([2]) + iarange_shape

    assert_ok(model, config_enumerate(guide, "parallel"),
              TraceEnum_ELBO(max_iarange_nesting=max_iarange_nesting))

def test_enum_discrete_missing_config_warning(strict_enumeration_warning):
    def model():
        p = torch.tensor(0.5)
        pyro.sample("x", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        pyro.sample("x", dist.Bernoulli(p))

    elbo = TraceEnum_ELBO(strict_enumeration_warning=strict_enumeration_warning)
    if strict_enumeration_warning:
        assert_warning(model, guide, elbo)
    else:
        assert_ok(model, guide, elbo)

def test_enum_discrete_iaranges_dependency_ok(enumerate_):
    def model():
        pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'})
        x_iarange = pyro.iarange("x_iarange", 10, 5, dim=-1)
        y_iarange = pyro.iarange("y_iarange", 11, 6, dim=-2)
        pyro.sample("a", dist.Bernoulli(0.5))
        with x_iarange:
            pyro.sample("b", dist.Bernoulli(0.5).expand_by([5]))
        with y_iarange:
            # Note that it is difficult to check that c does not depend on b.
            pyro.sample("c", dist.Bernoulli(0.5).expand_by([6, 1]))
        with x_iarange, y_iarange:
            pyro.sample("d", dist.Bernoulli(0.5).expand_by([6, 5]))

    assert_ok(model, model, TraceEnum_ELBO(max_iarange_nesting=2))

def aic_num_parameters(model, guide=None):
    """
    hacky AIC param count that includes all parameters in the model and guide
    """
    def _size(tensor):
        """product of shape"""
        s = 1
        for d in tensor.shape:
            s = s * d
        return s

    with poutine.block(), poutine.trace(param_only=True) as param_capture:
        TraceEnum_ELBO(max_plate_nesting=2).differentiable_loss(model, guide)

    return sum(_size(node["value"])
               for node in param_capture.trace.nodes.values())

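# A hedged sketch (helper is hypothetical, not from the source): combine the
# parameter count above with an ELBO-based likelihood estimate to form
# AIC = 2 * k - 2 * log_likelihood. With a point-estimate guide such as
# AutoDelta, the negative TraceEnum_ELBO loss approximates the maximized
# log joint, which is the usual hacky stand-in here.
def aic(model, guide, *args):
    log_likelihood = -TraceEnum_ELBO(max_plate_nesting=2).loss(model, guide, *args)
    k = aic_num_parameters(model, guide)
    return 2 * k - 2 * log_likelihood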