def test_hmc_conjugate_gaussian(fixture, num_samples, warmup_steps, hmc_params, expected_means, expected_precs, mean_tol, std_tol):
    """Run HMC on a conjugate Gaussian chain and compare each latent's posterior
    mean and standard deviation against the fixture's analytic values.

    :param fixture: test fixture providing ``model``, ``data``, ``chain_len`` and ``dim``
    :param num_samples: number of MCMC samples to draw
    :param warmup_steps: number of warmup (adaptation) steps
    :param hmc_params: kwargs forwarded to the HMC kernel
    :param expected_means, expected_precs: analytic posterior means/precisions per site
    :param mean_tol, std_tol: RMSE tolerances for means and standard deviations
    """
    pyro.get_param_store().clear()
    hmc_kernel = HMC(fixture.model, **hmc_params)
    mcmc_run = MCMC(hmc_kernel, num_samples, warmup_steps).run(fixture.data)
    for i in range(1, fixture.chain_len + 1):
        param_name = 'loc_' + str(i)
        marginal = EmpiricalMarginal(mcmc_run, sites=param_name)
        latent_loc = marginal.mean
        latent_std = marginal.variance.sqrt()
        expected_mean = torch.ones(fixture.dim) * expected_means[i - 1]
        # std = 1 / sqrt(precision)
        expected_std = 1 / torch.sqrt(torch.ones(fixture.dim) * expected_precs[i - 1])
        # Actual vs expected posterior means for the latents
        logger.info('Posterior mean (actual) - {}'.format(param_name))
        logger.info(latent_loc)
        logger.info('Posterior mean (expected) - {}'.format(param_name))
        logger.info(expected_mean)
        assert_equal(rmse(latent_loc, expected_mean).item(), 0.0, prec=mean_tol)
        # Actual vs expected posterior precisions for the latents
        logger.info('Posterior std (actual) - {}'.format(param_name))
        logger.info(latent_std)
        logger.info('Posterior std (expected) - {}'.format(param_name))
        logger.info(expected_std)
        assert_equal(rmse(latent_std, expected_std).item(), 0.0, prec=std_tol)
def test_categorical_gradient_with_logits(init_tensor_type):
    """A -inf logit must give a zero log-pdf contribution and a zero gradient
    (i.e. no NaNs propagate through backward)."""
    logits = Variable(init_tensor_type([-float('inf'), 0]), requires_grad=True)
    values = Variable(init_tensor_type([0, 1]))
    lp = Categorical(logits=logits).batch_log_pdf(values)
    lp.sum().backward()
    assert_equal(lp.data[0], 0)
    assert_equal(logits.grad.data[0], 0)
def test_decorator_interface_primitives():
    """poutine.trace works both as a bare decorator (default flat graph) and as a
    decorator factory with kwargs (graph_type="dense"); replay of a traced model
    reproduces the recorded sample values."""

    @poutine.trace
    def model():
        pyro.param("p", torch.zeros(1, requires_grad=True))
        pyro.sample("a", Bernoulli(torch.tensor([0.5])),
                    infer={"enumerate": "parallel"})
        pyro.sample("b", Bernoulli(torch.tensor([0.5])))

    tr = model.get_trace()
    assert isinstance(tr, poutine.Trace)
    assert tr.graph_type == "flat"

    @poutine.trace(graph_type="dense")
    def model():
        pyro.param("p", torch.zeros(1, requires_grad=True))
        pyro.sample("a", Bernoulli(torch.tensor([0.5])),
                    infer={"enumerate": "parallel"})
        pyro.sample("b", Bernoulli(torch.tensor([0.5])))

    tr = model.get_trace()
    assert isinstance(tr, poutine.Trace)
    assert tr.graph_type == "dense"

    # replaying against the recorded trace must reproduce sampled values
    tr2 = poutine.trace(poutine.replay(model, trace=tr)).get_trace()
    assert_equal(tr2.nodes["a"]["value"], tr.nodes["a"]["value"])
def test_mean_and_var(self):
    """Empirical mean and variance of Delta samples match the analytic values."""
    draws = [dist.Delta(self.v).sample().detach().cpu().numpy()
             for _ in range(self.n_samples)]
    assert_equal(np.mean(draws), self.analytic_mean)
    assert_equal(np.var(draws), self.analytic_var)
def test_batch_log_dims(dim, vs, one_hot, ps):
    """batch_log_pdf over the enumerated support has shape (3, 1, ..., 1) with
    ``dim`` trailing singleton dims, and values equal to log(ps)."""
    batch_pdf_shape = (3,) + (1,) * dim
    expected_log_pdf = np.array(wrap_nested(list(np.log(ps)), dim-1)).reshape(*batch_pdf_shape)
    ps, vs = modify_params_using_dims(ps, vs, dim)
    support = dist.categorical.enumerate_support(ps, vs, one_hot=one_hot)
    batch_log_pdf = dist.categorical.batch_log_pdf(support, ps, vs, one_hot=one_hot)
    assert_equal(batch_log_pdf.data.cpu().numpy(), expected_log_pdf)
def test_bern_elbo_gradient(enum_discrete, trace_graph):
    """Compare ELBO parameter gradients against finite-difference estimates for a
    simple Bernoulli model/guide pair, under enumeration and/or TraceGraph."""
    pyro.clear_param_store()
    num_particles = 2000

    def model():
        p = Variable(torch.Tensor([0.25]))
        pyro.sample("z", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", Variable(torch.Tensor([0.5]), requires_grad=True))
        pyro.sample("z", dist.Bernoulli(p))

    print("Computing gradients using surrogate loss")
    Elbo = TraceGraph_ELBO if trace_graph else Trace_ELBO
    # enumeration is exact, so a single particle suffices in that case
    elbo = Elbo(enum_discrete=enum_discrete,
                num_particles=(1 if enum_discrete else num_particles))
    with xfail_if_not_implemented():
        elbo.loss_and_grads(model, guide)
    params = sorted(pyro.get_param_store().get_all_param_names())
    assert params, "no params found"
    actual_grads = {name: pyro.param(name).grad.clone() for name in params}

    print("Computing gradients using finite difference")
    elbo = Trace_ELBO(num_particles=num_particles)
    expected_grads = finite_difference(lambda: elbo.loss(model, guide))

    for name in params:
        print("{} {}{}{}".format(name, "-" * 30, actual_grads[name].data,
                                 expected_grads[name].data))
    assert_equal(actual_grads, expected_grads, prec=0.1)
def test_optimizers(factory):
    """Each optimizer from ``factory`` should drive all three params to the MVN mode.

    The model scores params "x", "y", "z" (the last one constrained) against the
    same MultivariateNormal, so after optimization each should equal ``loc``
    broadcast to its own shape.
    """
    optim = factory()

    def model(loc, cov):
        x = pyro.param("x", torch.randn(2))
        y = pyro.param("y", torch.randn(3, 2))
        z = pyro.param("z", torch.randn(4, 2).abs(), constraint=constraints.greater_than(-1))
        pyro.sample("obs_x", dist.MultivariateNormal(loc, cov), obs=x)
        with pyro.iarange("y_iarange", 3):
            pyro.sample("obs_y", dist.MultivariateNormal(loc, cov), obs=y)
        with pyro.iarange("z_iarange", 4):
            pyro.sample("obs_z", dist.MultivariateNormal(loc, cov), obs=z)

    loc = torch.tensor([-0.5, 0.5])
    cov = torch.tensor([[1.0, 0.09], [0.09, 0.1]])
    for step in range(100):
        tr = poutine.trace(model).get_trace(loc, cov)
        loss = -tr.log_prob_sum()
        # optimizer steps are taken in unconstrained space
        params = {name: pyro.param(name).unconstrained() for name in ["x", "y", "z"]}
        optim.step(loss, params)

    for name in ["x", "y", "z"]:
        actual = pyro.param(name)
        expected = loc.expand(actual.shape)
        # BUG FIX: message previously read "{} in correct" instead of "{} incorrect".
        assert_equal(actual, expected, prec=1e-2,
                     msg='{} incorrect: {} vs {}'.format(name, actual, expected))
def test_quantiles(auto_class, Elbo):
    """Autoguide quantiles are ordered, well separated, lie in plausible ranges
    for each latent's support, and agree with ``median()`` at the 0.5 quantile."""

    def model():
        pyro.sample("x", dist.Normal(0.0, 1.0))
        pyro.sample("y", dist.LogNormal(0.0, 1.0))
        pyro.sample("z", dist.Beta(2.0, 2.0))

    guide = auto_class(model)
    infer = SVI(model, guide, Adam({'lr': 0.01}), Elbo(strict_enumeration_warning=False))
    for _ in range(100):
        infer.step()

    quantiles = guide.quantiles([0.1, 0.5, 0.9])
    median = guide.median()
    for name in ["x", "y", "z"]:
        # the median must coincide with the 0.5 quantile
        assert_equal(median[name], quantiles[name][1])
    quantiles = {name: [v.item() for v in value] for name, value in quantiles.items()}

    # x ~ Normal(0, 1): quantiles ordered with at least unit gaps, within (-3, 3)
    assert -3.0 < quantiles["x"][0]
    assert quantiles["x"][0] + 1.0 < quantiles["x"][1]
    assert quantiles["x"][1] + 1.0 < quantiles["x"][2]
    assert quantiles["x"][2] < 3.0

    # y ~ LogNormal(0, 1): positive, multiplicatively separated
    assert 0.01 < quantiles["y"][0]
    assert quantiles["y"][0] * 2.0 < quantiles["y"][1]
    assert quantiles["y"][1] * 2.0 < quantiles["y"][2]
    assert quantiles["y"][2] < 100.0

    # z ~ Beta(2, 2): inside the unit interval
    assert 0.01 < quantiles["z"][0]
    assert quantiles["z"][0] + 0.1 < quantiles["z"][1]
    assert quantiles["z"][1] + 0.1 < quantiles["z"][2]
    assert quantiles["z"][2] < 0.99
def test_iter_discrete_traces_vector(graph_type):
    """iter_discrete_traces enumerates vectorized discrete sites and reports the
    correct importance scale for each enumerated trace."""
    pyro.clear_param_store()

    def model():
        p = pyro.param("p", Variable(torch.Tensor([[0.05], [0.15]])))
        ps = pyro.param("ps", Variable(torch.Tensor([[0.1, 0.2, 0.3, 0.4],
                                                     [0.4, 0.3, 0.2, 0.1]])))
        x = pyro.sample("x", dist.Bernoulli(p))
        y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
        assert x.size() == (2, 1)
        assert y.size() == (2, 1)
        return dict(x=x, y=y)

    traces = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    # one trace per joint assignment: 2 Bernoulli values x ps.size(-1) categories
    assert len(traces) == 2 * ps.size(-1)
    for scale, trace in traces:
        x = trace.nodes["x"]["value"].data.squeeze().long()[0]
        y = trace.nodes["y"]["value"].data.squeeze().long()[0]
        expected_scale = torch.exp(dist.Bernoulli(p).log_pdf(x) *
                                   dist.Categorical(ps, one_hot=False).log_pdf(y))
        expected_scale = expected_scale.data.view(-1)[0]
        assert_equal(scale, expected_scale)
def test_compute_downstream_costs_iarange_reuse(dim1, dim2):
    """_compute_downstream_costs agrees with the brute-force reference when the
    model/guide reuse iaranges; also spot-checks the 'c1' cost by hand."""
    guide_trace = poutine.trace(iarange_reuse_model_guide,
                                graph_type="dense").get_trace(include_obs=False, dim1=dim1, dim2=dim2)
    model_trace = poutine.trace(poutine.replay(iarange_reuse_model_guide, trace=guide_trace),
                                graph_type="dense").get_trace(include_obs=True, dim1=dim1, dim2=dim2)
    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace,
                                                                     non_reparam_nodes)
    assert dc_nodes == dc_nodes_brute
    for k in dc:
        # each downstream cost must keep the site's log_prob shape
        assert(guide_trace.nodes[k]['log_prob'].size() == dc[k].size())
        assert_equal(dc[k], dc_brute[k])

    # hand-built downstream cost for site 'c1'
    expected_c1 = model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']
    expected_c1 += (model_trace.nodes['b1']['log_prob'] - guide_trace.nodes['b1']['log_prob']).sum()
    expected_c1 += model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']
    expected_c1 += model_trace.nodes['obs']['log_prob']
    assert_equal(expected_c1, dc['c1'])
def test_mask(batch_dim, event_dim, mask_dim):
    """Masking must preserve shapes and zero out masked log-prob/score terms."""
    # Construct base distribution.
    shape = torch.Size([2, 3, 4, 5, 6][:batch_dim + event_dim])
    batch_shape = shape[:batch_dim]
    mask_shape = batch_shape[batch_dim - mask_dim:]
    base_dist = Bernoulli(0.1).expand_by(shape).independent(event_dim)

    # Construct masked distribution.
    mask = checker_mask(mask_shape)
    dist = base_dist.mask(mask)

    # Check shape.
    sample = base_dist.sample()
    assert dist.batch_shape == base_dist.batch_shape
    assert dist.event_shape == base_dist.event_shape
    # BUG FIX: this previously read `sample.shape == sample.shape`, a tautology;
    # compare the masked distribution's sample shape against the base sample's.
    assert dist.sample().shape == sample.shape
    assert dist.log_prob(sample).shape == base_dist.log_prob(sample).shape

    # Check values.
    assert_equal(dist.mean, base_dist.mean)
    assert_equal(dist.variance, base_dist.variance)
    assert_equal(dist.log_prob(sample), base_dist.log_prob(sample) * mask)
    assert_equal(dist.score_parts(sample), base_dist.score_parts(sample) * mask, prec=0)
    if not dist.event_shape:
        assert_equal(dist.enumerate_support(), base_dist.enumerate_support())
def test_bernoulli_with_logits_overflow_gradient(init_tensor_type):
    """A huge positive logit must not overflow: log-pdf of observing 1 is 0, and
    the gradient through the logit is 0 (no NaN/inf)."""
    logits = Variable(init_tensor_type([1e40]), requires_grad=True)
    observed = Variable(init_tensor_type([1]))
    lp = Bernoulli(logits=logits).batch_log_pdf(observed)
    lp.sum().backward()
    assert_equal(lp.data[0], 0)
    assert_equal(logits.grad.data[0], 0)
def test_bernoulli_underflow_gradient(init_tensor_type):
    """A probability that underflows to exactly 0 must still give a finite
    log-pdf for observing 0 and a zero gradient through the parameter."""
    raw = Variable(init_tensor_type([0]), requires_grad=True)
    observed = Variable(init_tensor_type([0]))
    lp = Bernoulli(sigmoid(raw) * 0.0).batch_log_pdf(observed)
    lp.sum().backward()
    assert_equal(lp.data[0], 0)
    assert_equal(raw.grad.data[0], 0)
def test_unweighted_samples(batch_shape, sample_shape, dtype):
    """Samples from an Empirical built from values 0..4 have the right shape and
    cover exactly those five values."""
    empirical = Empirical()
    for value in range(5):
        empirical.add(torch.ones(batch_shape, dtype=dtype) * value)
    draws = empirical.sample(sample_shape=sample_shape)
    assert_equal(draws.size(), sample_shape + batch_shape)
    assert_equal(set(draws.view(-1).tolist()), set(range(5)))
def test_elbo_bern(quantity, enumerate1):
    """TraceEnum_ELBO loss and gradient match the analytic KL(q || Bernoulli(0.25)),
    exactly under enumeration (1 particle) or approximately by sampling (10k)."""
    pyro.clear_param_store()
    num_particles = 1 if enumerate1 else 10000
    prec = 0.001 if enumerate1 else 0.1
    q = pyro.param("q", torch.tensor(0.5, requires_grad=True))
    kl = kl_divergence(dist.Bernoulli(q), dist.Bernoulli(0.25))

    def model():
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(0.25).expand_by([num_particles]))

    @config_enumerate(default=enumerate1)
    def guide():
        q = pyro.param("q")
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(q).expand_by([num_particles]))

    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))

    if quantity == "loss":
        actual = elbo.loss(model, guide) / num_particles
        expected = kl.item()
        assert_equal(actual, expected, prec=prec, msg="".join([
            "\nexpected = {}".format(expected),
            "\n actual = {}".format(actual),
        ]))
    else:
        elbo.loss_and_grads(model, guide)
        actual = q.grad / num_particles
        expected = grad(kl, [q])[0]
        assert_equal(actual, expected, prec=prec, msg="".join([
            "\nexpected = {}".format(expected.detach().cpu().numpy()),
            "\n actual = {}".format(actual.detach().cpu().numpy()),
        ]))
def test_elbo_hmm_in_guide(enumerate1, num_steps):
    """Enumerating an HMM guide yields gradients matching golden values for each
    chain length (the goldens assert agreement of parallel vs sequential)."""
    pyro.clear_param_store()
    data = torch.ones(num_steps)
    init_probs = torch.tensor([0.5, 0.5])

    def model(data):
        transition_probs = pyro.param("transition_probs",
                                      torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                      constraint=constraints.simplex)
        emission_probs = pyro.param("emission_probs",
                                    torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                    constraint=constraints.simplex)
        x = None
        for i, y in enumerate(data):
            probs = init_probs if x is None else transition_probs[x]
            x = pyro.sample("x_{}".format(i), dist.Categorical(probs))
            pyro.sample("y_{}".format(i), dist.Categorical(emission_probs[x]), obs=y)

    @config_enumerate(default=enumerate1)
    def guide(data):
        transition_probs = pyro.param("transition_probs",
                                      torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                      constraint=constraints.simplex)
        x = None
        for i, y in enumerate(data):
            probs = init_probs if x is None else transition_probs[x]
            x = pyro.sample("x_{}".format(i), dist.Categorical(probs))

    elbo = TraceEnum_ELBO(max_iarange_nesting=0)
    elbo.loss_and_grads(model, guide, data)

    # These golden values simply test agreement between parallel and sequential.
    expected_grads = {
        2: {
            "transition_probs": [[0.1029949, -0.1029949], [0.1029949, -0.1029949]],
            "emission_probs": [[0.75, -0.75], [0.25, -0.25]],
        },
        3: {
            "transition_probs": [[0.25748726, -0.25748726], [0.25748726, -0.25748726]],
            "emission_probs": [[1.125, -1.125], [0.375, -0.375]],
        },
        10: {
            "transition_probs": [[1.64832076, -1.64832076], [1.64832076, -1.64832076]],
            "emission_probs": [[3.75, -3.75], [1.25, -1.25]],
        },
        20: {
            "transition_probs": [[3.70781687, -3.70781687], [3.70781687, -3.70781687]],
            "emission_probs": [[7.5, -7.5], [2.5, -2.5]],
        },
    }

    for name, value in pyro.get_param_store().named_parameters():
        actual = value.grad
        expected = torch.tensor(expected_grads[num_steps][name])
        assert_equal(actual, expected, msg=''.join([
            '\nexpected {}.grad = {}'.format(name, expected.cpu().numpy()),
            '\n actual {}.grad = {}'.format(name, actual.detach().cpu().numpy()),
        ]))
def test_unweighted_mean_and_var(size, dtype):
    """An Empirical over the values 0..4 has elementwise mean 2 and variance 2."""
    empirical = Empirical()
    for value in range(5):
        empirical.add(torch.ones(size, dtype=dtype) * value)
    assert_equal(empirical.mean, torch.ones(size) * 2)
    assert_equal(empirical.variance, torch.ones(size) * 2)
def test_log_pdf(dist):
    """Pyro's log_pdf agrees with the scipy reference for each test datum."""
    pyro_dist = dist.pyro_dist
    for idx in dist.get_test_data_indices():
        params = dist.get_dist_params(idx)
        datum = dist.get_test_data(idx)
        actual = unwrap_variable(pyro_dist.log_pdf(datum, **params))[0]
        expected = dist.get_scipy_logpdf(idx)
        assert_equal(actual, expected)
def test_batch_log_pdf(dist):
    """The summed batched log-pdf matches the summed scipy log-pdf."""
    pyro_dist = dist.pyro_dist
    for idx in dist.get_batch_data_indices():
        params = dist.get_dist_params(idx)
        datum = dist.get_test_data(idx)
        total_pyro = unwrap_variable(torch.sum(pyro_dist.batch_log_pdf(datum, **params)))[0]
        total_scipy = np.sum(dist.get_scipy_batch_logpdf(-1))
        assert_equal(total_pyro, total_scipy)
def test_double_type(test_data, alpha, beta):
    """Double-precision inputs yield a DoubleTensor log-prob that matches scipy."""
    log_px_torch = dist.Beta(alpha, beta).log_prob(test_data).data
    assert isinstance(log_px_torch, torch.DoubleTensor)
    log_px_np = sp.beta.logpdf(test_data.detach().cpu().numpy(),
                               alpha.detach().cpu().numpy(),
                               beta.detach().cpu().numpy())
    assert_equal(log_px_torch.numpy(), log_px_np, prec=1e-4)
def test_sample_shape(dist):
    """Functional and object-style sampling agree on shape, and the sample shape
    matches d.shape() where implemented (xfail otherwise)."""
    d = dist.pyro_dist
    for idx in range(dist.get_num_test_data()):
        dist_params = dist.get_dist_params(idx)
        x_func = dist.pyro_dist.sample(**dist_params)
        x_obj = dist.pyro_dist_obj(**dist_params).sample()
        assert_equal(x_obj.size(), x_func.size())
        with xfail_if_not_implemented():
            assert(x_func.size() == d.shape(x_func, **dist_params))
def test_float_type(float_test_data, float_alpha, float_beta, test_data, alpha, beta):
    """Single-precision inputs yield a FloatTensor log-prob matching the scipy
    value computed from the double-precision fixtures."""
    log_px_torch = dist.Beta(float_alpha, float_beta).log_prob(float_test_data).data
    assert isinstance(log_px_torch, torch.FloatTensor)
    log_px_np = sp.beta.logpdf(test_data.detach().cpu().numpy(),
                               alpha.detach().cpu().numpy(),
                               beta.detach().cpu().numpy())
    assert_equal(log_px_torch.numpy(), log_px_np, prec=1e-4)
def test_batch_log_prob(dist):
    """Summed log_prob over batched data matches the scipy reference (skipped
    when the fixture has no scipy equivalent)."""
    if dist.scipy_arg_fn is None:
        pytest.skip('{}.log_prob_sum has no scipy equivalent'.format(dist.pyro_dist.__name__))
    for idx in dist.get_batch_data_indices():
        params = dist.get_dist_params(idx)
        datum = dist.get_test_data(idx)
        total_pyro = dist.pyro_dist(**params).log_prob(datum).sum().item()
        total_scipy = np.sum(dist.get_scipy_batch_logpdf(-1))
        assert_equal(total_pyro, total_scipy)
def test_enumerate_support(discrete_dist):
    """enumerate_support matches the fixture's expected values, both for the
    vectorized (-1) and non-vectorized (0) parameterizations."""
    expected_support = discrete_dist.expected_support
    expected_support_non_vec = discrete_dist.expected_support_non_vec
    if not expected_support:
        pytest.skip("enumerate_support not tested for distribution")
    Dist = discrete_dist.pyro_dist
    support_non_vec = Dist(**discrete_dist.get_dist_params(0)).enumerate_support()
    support_vec = Dist(**discrete_dist.get_dist_params(-1)).enumerate_support()
    assert_equal(support_vec.data, torch.tensor(expected_support))
    assert_equal(support_non_vec.data, torch.tensor(expected_support_non_vec))
def test_scale_tril():
    """LowRankMultivariateNormal's scale_tril matches the dense MVN built from
    the equivalent covariance D.diag() + W^T W."""
    loc = torch.tensor([1.0, 2.0, 1.0, 2.0, 0.0])
    D = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
    W = torch.tensor([[1.0, -1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 1.0, 2.0, 4.0]])
    dense = MultivariateNormal(loc, D.diag() + W.t().matmul(W))
    lowrank = LowRankMultivariateNormal(loc, W, D)
    assert_equal(dense.scale_tril, lowrank.scale_tril)
def assert_correct_dimensions(sample, ps, vs, one_hot):
    """Assert a categorical sample's shape: equal to ps' shape when one-hot with
    no value list, otherwise ps' shape with the last dim replaced by 1."""
    ps_shape = list(ps.data.size())
    if isinstance(sample, torch.autograd.Variable):
        sample_shape = list(sample.data.size())
    else:
        sample_shape = list(sample.shape)
    expected_shape = ps_shape if (one_hot and not vs) else ps_shape[:-1] + [1]
    assert_equal(sample_shape, expected_shape)
def test_posterior_predictive():
    """NUTS posterior predictive mean over binomial successes is close to the
    true expected count (1000 trials at p=0.7 => ~700 per element)."""
    true_probs = torch.ones(5) * 0.7
    num_trials = torch.ones(5) * 1000
    num_success = dist.Binomial(num_trials, true_probs).sample()
    conditioned_model = poutine.condition(model, data={"obs": num_success})
    nuts_kernel = NUTS(conditioned_model, adapt_step_size=True)
    mcmc_run = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200).run(num_trials)
    posterior_predictive = TracePredictive(model, mcmc_run, num_samples=10000).run(num_trials)
    marginal_return_vals = EmpiricalMarginal(posterior_predictive)
    assert_equal(marginal_return_vals.mean, torch.ones(5) * 700, prec=30)
def test_trajectory(example):
    """velocity_verlet integrates (q_i, p_i) to the fixture's expected endpoint."""
    model, args = example
    q_end, p_end = velocity_verlet(args.q_i, args.p_i, model.potential_fn,
                                   args.step_size, args.num_steps)
    logger.info("initial q: {}".format(args.q_i))
    logger.info("final q: {}".format(q_end))
    assert_equal(q_end, args.q_f, args.prec)
    assert_equal(p_end, args.p_f, args.prec)
def test_log_prob():
    """LowRankMultivariateNormal log_prob at a fixed point matches the dense MVN
    built from the equivalent covariance D.diag() + W^T W."""
    loc = torch.tensor([2.0, 1.0, 1.0, 2.0, 2.0])
    D = torch.tensor([1.0, 2.0, 3.0, 1.0, 3.0])
    W = torch.tensor([[1.0, -1.0, 2.0, 2.0, 4.0], [2.0, 1.0, 1.0, 2.0, 6.0]])
    x = torch.tensor([2.0, 3.0, 4.0, 1.0, 7.0])
    dense = MultivariateNormal(loc, D.diag() + W.t().matmul(W))
    lowrank = LowRankMultivariateNormal(loc, W, D)
    assert_equal(dense.log_prob(x), lowrank.log_prob(x))
def test_replay_partial(self):
    """Replaying only some sites copies those values from the guide trace and
    leaves the remaining sites independently sampled."""
    guide_trace = poutine.trace(self.guide).get_trace()
    replayed = poutine.replay(self.model, guide_trace, sites=self.partial_sample_sites)
    model_trace = poutine.trace(replayed).get_trace()
    for name in self.full_sample_sites.keys():
        model_value = model_trace.nodes[name]["value"]
        guide_value = guide_trace.nodes[name]["value"]
        if name in self.partial_sample_sites:
            assert_equal(model_value, guide_value)
        else:
            assert not eq(model_value, guide_value)
def test_tmc_categoricals(depth, max_plate_nesting, num_samples, tmc_strategy):
    """TraceTMC_ELBO loss/gradients agree with exact TraceEnum_ELBO on a
    depth-``depth`` categorical chain with Bernoulli observations."""
    qs = [pyro.param("q0", torch.tensor([0.4, 0.6], requires_grad=True))]
    for i in range(1, depth):
        qs.append(pyro.param("q{}".format(i),
                             torch.randn(2, 2).abs().detach().requires_grad_(),
                             constraint=constraints.simplex))
    qs.append(pyro.param("qy", torch.tensor([0.75, 0.25], requires_grad=True)))
    # gradients are taken w.r.t. the unconstrained parameterizations
    qs = [q.unconstrained() for q in qs]
    data = (torch.rand(4, 3) > 0.5).to(dtype=qs[-1].dtype, device=qs[-1].device)

    def model():
        x = pyro.sample("x0", dist.Categorical(pyro.param("q0")))
        with pyro.plate("local", 3):
            for i in range(1, depth):
                x = pyro.sample(
                    "x{}".format(i),
                    dist.Categorical(pyro.param("q{}".format(i))[..., x, :]))
            with pyro.plate("data", 4):
                pyro.sample("y", dist.Bernoulli(pyro.param("qy")[..., x]), obs=data)

    # exact reference via full enumeration
    elbo = TraceEnum_ELBO(max_plate_nesting=max_plate_nesting)
    enum_model = config_enumerate(model, default="parallel", expand=False,
                                  num_samples=None, tmc=tmc_strategy)
    expected_loss = (-elbo.differentiable_loss(enum_model, lambda: None)).exp()
    expected_grads = grad(expected_loss, qs)

    # sampled TMC estimate
    tmc = TraceTMC_ELBO(max_plate_nesting=max_plate_nesting)
    tmc_model = config_enumerate(model, default="parallel", expand=False,
                                 num_samples=num_samples, tmc=tmc_strategy)
    actual_loss = (-tmc.differentiable_loss(tmc_model, lambda: None)).exp()
    actual_grads = grad(actual_loss, qs)

    prec = 0.05
    assert_equal(actual_loss, expected_loss, prec=prec, msg="".join([
        "\nexpected loss = {}".format(expected_loss),
        "\n actual loss = {}".format(actual_loss),
    ]))
    for actual_grad, expected_grad in zip(actual_grads, expected_grads):
        assert_equal(actual_grad, expected_grad, prec=prec, msg="".join([
            "\nexpected grad = {}".format(
                expected_grad.detach().cpu().numpy()),
            "\n actual grad = {}".format(
                actual_grad.detach().cpu().numpy()),
        ]))
def test_support(self):
    """Categorical enumerate_support returns the fixture's expected support."""
    enumerated = dist.Categorical(self.d_ps).enumerate_support()
    assert_equal(enumerated.data, self.support)
def test_trace_return(self):
    """The trace's _RETURN node records the model's return value (latent1)."""
    tr = poutine.trace(self.model).get_trace()
    assert_equal(tr.nodes["latent1"]["value"], tr.nodes["_RETURN"]["value"])
def test_elbo_mapdata(batch_size, map_type):
    """SVI on a conjugate normal-normal model converges to the analytic posterior
    under each map_data flavor ("list", "tensor", or a plain python loop)."""
    # normal-normal: known covariance
    lam0 = Variable(torch.Tensor([0.1, 0.1]))  # precision of prior
    mu0 = Variable(torch.Tensor([0.0, 0.5]))  # prior mean
    # known precision of observation noise
    lam = Variable(torch.Tensor([6.0, 4.0]))
    data = []
    sum_data = Variable(torch.zeros(2))

    def add_data_point(x, y):
        data.append(Variable(torch.Tensor([x, y])))
        sum_data.data.add_(data[-1].data)

    add_data_point(0.1, 0.21)
    add_data_point(0.16, 0.11)
    add_data_point(0.06, 0.31)
    add_data_point(-0.01, 0.07)
    add_data_point(0.23, 0.25)
    add_data_point(0.19, 0.18)
    add_data_point(0.09, 0.41)
    add_data_point(-0.04, 0.17)

    n_data = Variable(torch.Tensor([len(data)]))
    # conjugate posterior precision and mean in closed form
    analytic_lam_n = lam0 + n_data.expand_as(lam) * lam
    analytic_log_sig_n = -0.5 * torch.log(analytic_lam_n)
    analytic_mu_n = sum_data * (lam / analytic_lam_n) +\
        mu0 * (lam0 / analytic_lam_n)
    verbose = True
    n_steps = 7000
    if verbose:
        print("DOING ELBO TEST [bs = {}, map_type = {}]".format(
            batch_size, map_type))
    pyro.clear_param_store()

    def model():
        mu_latent = pyro.sample("mu_latent", dist.normal,
                                mu0, torch.pow(lam0, -0.5))
        if map_type == "list":
            pyro.map_data("aaa", data,
                          lambda i, x: pyro.observe(
                              "obs_%d" % i, dist.normal,
                              x, mu_latent, torch.pow(lam, -0.5)),
                          batch_size=batch_size)
        elif map_type == "tensor":
            tdata = torch.cat([xi.view(1, -1) for xi in data], 0)
            pyro.map_data("aaa", tdata,
                          # XXX get batch size args to dist right
                          lambda i, x: pyro.observe("obs", dist.normal, x, mu_latent,
                                                    torch.pow(lam, -0.5)),
                          batch_size=batch_size)
        else:
            for i, x in enumerate(data):
                pyro.observe('obs_%d' % i, dist.normal,
                             x, mu_latent, torch.pow(lam, -0.5))
        return mu_latent

    def guide():
        # start the variational parameters off-center to exercise the optimizer
        mu_q = pyro.param("mu_q", Variable(analytic_mu_n.data + torch.Tensor([-0.18, 0.23]),
                                           requires_grad=True))
        log_sig_q = pyro.param("log_sig_q", Variable(
            analytic_log_sig_n.data - torch.Tensor([-0.18, 0.23]),
            requires_grad=True))
        sig_q = torch.exp(log_sig_q)
        pyro.sample("mu_latent", dist.normal, mu_q, sig_q)
        if map_type == "list" or map_type is None:
            pyro.map_data("aaa", data, lambda i, x: None, batch_size=batch_size)
        elif map_type == "tensor":
            tdata = torch.cat([xi.view(1, -1) for xi in data], 0)
            # dummy map_data to do subsampling for observe
            pyro.map_data("aaa", tdata, lambda i, x: None, batch_size=batch_size)
        else:
            pass

    adam = optim.Adam({"lr": 0.0008, "betas": (0.95, 0.999)})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)

    for k in range(n_steps):
        svi.step()

        mu_error = torch.sum(
            torch.pow(
                analytic_mu_n - pyro.param("mu_q"),
                2.0))
        log_sig_error = torch.sum(
            torch.pow(
                analytic_log_sig_n - pyro.param("log_sig_q"),
                2.0))
        if verbose and k % 500 == 0:
            print("errors", mu_error.data.cpu().numpy()[0],
                  log_sig_error.data.cpu().numpy()[0])

    assert_equal(Variable(torch.zeros(1)), mu_error, prec=0.05)
    assert_equal(Variable(torch.zeros(1)), log_sig_error, prec=0.06)
def test_sample_dims(dim, probs):
    """A categorical sample's size matches the distribution's declared shape()."""
    probs = modify_params_using_dims(probs, dim)
    categorical = dist.Categorical(probs)
    assert_equal(categorical.sample().size(), categorical.shape())
def do_elbo_test(self, reparameterized, n_steps, lr, prec, difficulty=1.0):
    """Run SVI on the Gaussian chain and assert parameter MSEs fall below ``prec``.

    BUG FIX: one debug line had been corrupted to
    ``logger.debug("target_kappas: "******"lambda_posts: " + ...)``, which is not
    valid Python; restored the two separate ``logger.debug`` statements.
    """
    n_repa_nodes = (torch.sum(self.which_nodes_reparam)
                    if not reparameterized else self.N)
    logger.info(
        " - - - - - DO GAUSSIAN %d-CHAIN ELBO TEST [reparameterized = %s; %d/%d] - - - - - "
        % (self.N, reparameterized, n_repa_nodes, self.N))
    if self.N < 0:
        def array_to_string(y):
            return str(map(lambda x: "%.3f" % x.detach().cpu().numpy()[0], y))

        logger.debug("lambdas: " + array_to_string(self.lambdas))
        logger.debug("target_mus: " + array_to_string(self.target_mus[1:]))
        logger.debug("target_kappas: " + array_to_string(self.target_kappas[1:]))
        logger.debug("lambda_posts: " + array_to_string(self.lambda_posts[1:]))
        logger.debug("lambda_tilde_posts: " + array_to_string(self.lambda_tilde_posts))

    pyro.clear_param_store()
    adam = optim.Adam({"lr": lr, "betas": (0.95, 0.999)})
    elbo = TraceGraph_ELBO()
    loss_and_grads = elbo.loss_and_grads
    # loss_and_grads = elbo.jit_loss_and_grads  # This fails.
    svi = SVI(self.model, self.guide, adam, loss=elbo.loss,
              loss_and_grads=loss_and_grads)

    for step in range(n_steps):
        t0 = time.time()
        svi.step(reparameterized=reparameterized, difficulty=difficulty)

        # periodically (and on the final step) measure parameter errors
        if step % 5000 == 0 or step == n_steps - 1:
            kappa_errors, log_sig_errors, loc_errors = [], [], []
            for k in range(1, self.N + 1):
                # there is no kappa parameter for the last node in the chain
                if k != self.N:
                    kappa_error = param_mse("kappa_q_%d" % k, self.target_kappas[k])
                    kappa_errors.append(kappa_error)
                loc_errors.append(param_mse("loc_q_%d" % k, self.target_mus[k]))
                log_sig_error = param_mse("log_sig_q_%d" % k,
                                          -0.5 * torch.log(self.lambda_posts[k]))
                log_sig_errors.append(log_sig_error)

            max_errors = (np.max(loc_errors), np.max(log_sig_errors), np.max(kappa_errors))
            min_errors = (np.min(loc_errors), np.min(log_sig_errors), np.min(kappa_errors))
            mean_errors = (np.mean(loc_errors), np.mean(log_sig_errors), np.mean(kappa_errors))
            logger.debug(
                "[max errors] (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)" % max_errors)
            logger.debug(
                "[min errors] (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)" % min_errors)
            logger.debug(
                "[mean errors] (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)" % mean_errors)
            logger.debug("[step time = %.3f; N = %d; step = %d]\n"
                         % (time.time() - t0, self.N, step))

    # max_errors was last computed on the final step (step == n_steps - 1)
    assert_equal(0.0, max_errors[0], prec=prec)
    assert_equal(0.0, max_errors[1], prec=prec)
    assert_equal(0.0, max_errors[2], prec=prec)
def test_mean_gradient(K, D, flat_logits, cost_function, mix_dist, batch_mode):
    """Pathwise (rsample) gradients of mixture distributions match analytic
    gradients of simple cost functions (cosine or quadratic) in their parameters."""
    n_samples = 200000
    if batch_mode:
        sample_shape = torch.Size(())
    else:
        sample_shape = torch.Size((n_samples,))
    if mix_dist == GaussianScaleMixture:
        locs = torch.zeros(K, D, requires_grad=True)
    else:
        locs = torch.rand(K, D).requires_grad_(True)
    if mix_dist == GaussianScaleMixture:
        component_scale = 1.5 * torch.ones(K) + 0.5 * torch.rand(K)
        component_scale.requires_grad_(True)
    else:
        component_scale = torch.ones(K, requires_grad=True)
    if mix_dist == MixtureOfDiagNormals:
        coord_scale = torch.ones(K, D) + 0.5 * torch.rand(K, D)
        coord_scale.requires_grad_(True)
    else:
        coord_scale = torch.ones(D) + 0.5 * torch.rand(D)
        coord_scale.requires_grad_(True)
    if not flat_logits:
        component_logits = (1.5 * torch.rand(K)).requires_grad_(True)
    else:
        component_logits = (0.1 * torch.rand(K)).requires_grad_(True)
    omega = (0.2 * torch.ones(D) + 0.1 * torch.rand(D)).requires_grad_(False)

    # mixture weights from logits
    _pis = torch.exp(component_logits)
    pis = _pis / _pis.sum()

    # analytic expectation of the cost under the mixture, then backprop
    if cost_function == 'cosine':
        analytic1 = torch.cos((omega * locs).sum(-1))
        analytic2 = torch.exp(-0.5 * torch.pow(
            omega * coord_scale * component_scale.unsqueeze(-1), 2.0).sum(-1))
        analytic = (pis * analytic1 * analytic2).sum()
        analytic.backward()
    elif cost_function == 'quadratic':
        analytic = torch.pow(coord_scale * component_scale.unsqueeze(-1), 2.0).sum(-1) \
            + torch.pow(locs, 2.0).sum(-1)
        analytic = (pis * analytic).sum()
        analytic.backward()

    analytic_grads = {}
    analytic_grads['locs'] = locs.grad.clone()
    analytic_grads['coord_scale'] = coord_scale.grad.clone()
    analytic_grads['component_logits'] = component_logits.grad.clone()
    analytic_grads['component_scale'] = component_scale.grad.clone()

    assert locs.grad.shape == locs.shape
    assert coord_scale.grad.shape == coord_scale.shape
    assert component_logits.grad.shape == component_logits.shape
    assert component_scale.grad.shape == component_scale.shape

    # reset grads before the Monte Carlo pass
    coord_scale.grad.zero_()
    component_logits.grad.zero_()
    locs.grad.zero_()
    component_scale.grad.zero_()

    if mix_dist == MixtureOfDiagNormalsSharedCovariance:
        params = {'locs': locs, 'coord_scale': coord_scale,
                  'component_logits': component_logits}
        if batch_mode:
            locs = locs.unsqueeze(0).expand(n_samples, K, D)
            coord_scale = coord_scale.unsqueeze(0).expand(n_samples, D)
            component_logits = component_logits.unsqueeze(0).expand(n_samples, K)
            dist_params = {'locs': locs, 'coord_scale': coord_scale,
                           'component_logits': component_logits}
        else:
            dist_params = params
    elif mix_dist == MixtureOfDiagNormals:
        params = {'locs': locs, 'coord_scale': coord_scale,
                  'component_logits': component_logits}
        if batch_mode:
            locs = locs.unsqueeze(0).expand(n_samples, K, D)
            coord_scale = coord_scale.unsqueeze(0).expand(n_samples, K, D)
            component_logits = component_logits.unsqueeze(0).expand(n_samples, K)
            dist_params = {'locs': locs, 'coord_scale': coord_scale,
                           'component_logits': component_logits}
        else:
            dist_params = params
    elif mix_dist == GaussianScaleMixture:
        params = {'coord_scale': coord_scale, 'component_logits': component_logits,
                  'component_scale': component_scale}
        if batch_mode:
            return  # distribution does not support batched parameters
        else:
            dist_params = params

    dist = mix_dist(**dist_params)
    z = dist.rsample(sample_shape=sample_shape)
    assert z.shape == (n_samples, D)
    # Monte Carlo estimate of the expected cost
    if cost_function == 'cosine':
        cost = torch.cos((omega * z).sum(-1)).sum() / float(n_samples)
    elif cost_function == 'quadratic':
        cost = torch.pow(z, 2.0).sum() / float(n_samples)
    cost.backward()

    assert_equal(analytic, cost, prec=0.1,
                 msg='bad cost function evaluation for {} test (expected {}, got {})'.format(
                     mix_dist.__name__, analytic.item(), cost.item()))
    logger.debug("analytic_grads_logit: {}"
                 .format(analytic_grads['component_logits'].detach().cpu().numpy()))

    for param_name, param in params.items():
        assert_equal(param.grad, analytic_grads[param_name], prec=0.1,
                     msg='bad {} grad for {} (expected {}, got {})'.format(
                         param_name, mix_dist.__name__,
                         analytic_grads[param_name], param.grad))
def test_elbo_mapdata(batch_size, map_type):
    """Run SVI on a conjugate normal-normal model and compare the learned guide
    parameters to the analytic posterior, exercising different subsampling styles.

    :param batch_size: subsample size used by the plate (may be None/full).
    :param map_type: 'iplate' (sequential plate), 'plate' (vectorized plate),
        or anything else for a plain enumerate over data.
    """
    # normal-normal: known covariance
    lam0 = torch.tensor([0.1, 0.1])   # precision of prior
    loc0 = torch.tensor([0.0, 0.5])   # prior mean
    # known precision of observation noise
    lam = torch.tensor([6.0, 4.0])

    data = []
    sum_data = torch.zeros(2)

    def add_data_point(x, y):
        # append observation and accumulate the running sum used by the
        # analytic posterior formulas below
        data.append(torch.tensor([x, y]))
        sum_data.data.add_(data[-1].data)

    add_data_point(0.1, 0.21)
    add_data_point(0.16, 0.11)
    add_data_point(0.06, 0.31)
    add_data_point(-0.01, 0.07)
    add_data_point(0.23, 0.25)
    add_data_point(0.19, 0.18)
    add_data_point(0.09, 0.41)
    add_data_point(-0.04, 0.17)

    data = torch.stack(data)
    n_data = torch.tensor([float(len(data))])
    # conjugate (exact) posterior for the latent location
    analytic_lam_n = lam0 + n_data.expand_as(lam) * lam
    analytic_log_sig_n = -0.5 * torch.log(analytic_lam_n)
    analytic_loc_n = sum_data * (lam / analytic_lam_n) +\
        loc0 * (lam0 / analytic_lam_n)
    n_steps = 7000

    logger.debug("DOING ELBO TEST [bs = {}, map_type = {}]".format(
        batch_size, map_type))
    pyro.clear_param_store()

    def model():
        loc_latent = pyro.sample(
            "loc_latent",
            dist.Normal(loc0, torch.pow(lam0, -0.5)).to_event(1))
        if map_type == "iplate":
            # sequential subsampling: one observe site per subsampled index
            for i in pyro.plate("aaa", len(data), batch_size):
                pyro.sample("obs_%d" % i,
                            dist.Normal(loc_latent, torch.pow(lam, -0.5)).to_event(1),
                            obs=data[i]),
        elif map_type == "plate":
            # vectorized subsampling: single observe site over the indexed batch
            with pyro.plate("aaa", len(data), batch_size) as ind:
                pyro.sample("obs",
                            dist.Normal(loc_latent, torch.pow(lam, -0.5)).to_event(1),
                            obs=data[ind]),
        else:
            # no subsampling: observe every data point
            for i, x in enumerate(data):
                pyro.sample('obs_%d' % i,
                            dist.Normal(loc_latent, torch.pow(lam, -0.5)).to_event(1),
                            obs=x)
        return loc_latent

    def guide():
        # initialize near (but offset from) the analytic solution
        loc_q = pyro.param(
            "loc_q",
            analytic_loc_n.detach().clone() + torch.tensor([-0.18, 0.23]))
        log_sig_q = pyro.param(
            "log_sig_q",
            analytic_log_sig_n.detach().clone() - torch.tensor([-0.18, 0.23]))
        sig_q = torch.exp(log_sig_q)
        pyro.sample("loc_latent", dist.Normal(loc_q, sig_q).to_event(1))
        # mirror the model's plate structure so subsample indices agree
        if map_type == "iplate" or map_type is None:
            for i in pyro.plate("aaa", len(data), batch_size):
                pass
        elif map_type == "plate":
            # dummy plate to do subsampling for observe
            with pyro.plate("aaa", len(data), batch_size):
                pass
        else:
            pass

    adam = optim.Adam({"lr": 0.0008, "betas": (0.95, 0.999)})
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

    for k in range(n_steps):
        svi.step()

        # errors w.r.t. the analytic posterior; final iteration's values are asserted below
        loc_error = torch.sum(
            torch.pow(analytic_loc_n - pyro.param("loc_q"), 2.0))
        log_sig_error = torch.sum(
            torch.pow(analytic_log_sig_n - pyro.param("log_sig_q"), 2.0))
        if k % 500 == 0:
            logger.debug("errors - {}, {}".format(loc_error, log_sig_error))

    assert_equal(loc_error.item(), 0, prec=0.05)
    assert_equal(log_sig_error.item(), 0, prec=0.06)
def test_persistent_independent_subproblems(num_objects, num_frames, num_detections, bp_iters): # solve a random assignment problem exists_logits_1 = -2 * torch.rand(num_objects) assign_logits_1 = 2 * torch.rand(num_frames, num_detections, num_objects) - 1 assignment_1 = MarginalAssignmentPersistent(exists_logits_1, assign_logits_1, bp_iters) exists_probs_1 = assignment_1.exists_dist.probs assign_probs_1 = assignment_1.assign_dist.probs # solve another random assignment problem exists_logits_2 = -2 * torch.rand(num_objects) assign_logits_2 = 2 * torch.rand(num_frames, num_detections, num_objects) - 1 assignment_2 = MarginalAssignmentPersistent(exists_logits_2, assign_logits_2, bp_iters) exists_probs_2 = assignment_2.exists_dist.probs assign_probs_2 = assignment_2.assign_dist.probs # solve a unioned assignment problem exists_logits = torch.cat([exists_logits_1, exists_logits_2]) assign_logits = torch.full((num_frames, num_detections * 2, num_objects * 2), -INF) assign_logits[:, :num_detections, :num_objects] = assign_logits_1 assign_logits[:, num_detections:, num_objects:] = assign_logits_2 assignment = MarginalAssignmentPersistent(exists_logits, assign_logits, bp_iters) exists_probs = assignment.exists_dist.probs assign_probs = assignment.assign_dist.probs # check agreement assert_equal(exists_probs_1, exists_probs[:num_objects]) assert_equal(exists_probs_2, exists_probs[num_objects:]) assert_equal(assign_probs_1[:, :, :-1], assign_probs[:, :num_detections, :num_objects]) assert_equal(assign_probs_1[:, :, -1], assign_probs[:, :num_detections, -1]) assert_equal(assign_probs_2[:, :, :-1], assign_probs[:, num_detections:, num_objects:-1]) assert_equal(assign_probs_2[:, :, -1], assign_probs[:, num_detections:, -1])
def test_replay_full(self): guide_trace = poutine.trace(self.guide).get_trace() model_trace = poutine.trace(poutine.replay(self.model, guide_trace)).get_trace() for name in self.full_sample_sites.keys(): assert_equal(model_trace.nodes[name]["value"], guide_trace.nodes[name]["value"])
def test_generic_lgssm_forecast(model_class, state_dim, obs_dim, T):
    """Check 3-step forecast means and covariances of generic LGSSM models against
    hand-computed matrix formulas (A z, A^2 z, A^3 z pushed through the obs matrix).

    :param model_class: 'lgssm' or 'lgssmgp' (LGSSM with GP noise model).
    :param state_dim: latent state dimension.
    :param obs_dim: observation dimension.
    :param T: number of observed timesteps to filter on.
    """
    torch.set_default_tensor_type('torch.DoubleTensor')

    if model_class == 'lgssm':
        model = GenericLGSSM(state_dim=state_dim, obs_dim=obs_dim,
                             obs_noise_scale_init=0.1 + torch.rand(obs_dim))
    elif model_class == 'lgssmgp':
        model = GenericLGSSMWithGPNoiseModel(state_dim=state_dim, obs_dim=obs_dim, nu=1.5,
                                             obs_noise_scale_init=0.1 + torch.rand(obs_dim))
        # with these hyperparameters we essentially turn off the GP contributions
        model.kernel.length_scale = 1.0e-6 * torch.ones(obs_dim)
        model.kernel.kernel_scale = 1.0e-6 * torch.ones(obs_dim)

    targets = torch.randn(T, obs_dim)
    filtering_state = model._filter(targets)
    actual_loc, actual_cov = model._forecast(3, filtering_state, include_observation_noise=False)

    # transition/observation matrices; lgssmgp keeps them on z_-prefixed attributes
    obs_matrix = model.obs_matrix if model_class == 'lgssm' else model.z_obs_matrix
    trans_matrix = model.trans_matrix if model_class == 'lgssm' else model.z_trans_matrix
    trans_matrix_sq = torch.mm(trans_matrix, trans_matrix)
    trans_matrix_cubed = torch.mm(trans_matrix_sq, trans_matrix)

    # A B, A^2 B, A^3 B: state-to-observation maps for horizons 1, 2, 3
    trans_obs = torch.mm(trans_matrix, obs_matrix)
    trans_trans_obs = torch.mm(trans_matrix_sq, obs_matrix)
    trans_trans_trans_obs = torch.mm(trans_matrix_cubed, obs_matrix)

    # we only compute contributions for the state space portion for lgssmgp
    fs_loc = filtering_state.loc if model_class == 'lgssm' else filtering_state.loc[-state_dim:]

    predicted_mean1 = torch.mm(fs_loc.unsqueeze(-2), trans_obs).squeeze(-2)
    predicted_mean2 = torch.mm(fs_loc.unsqueeze(-2), trans_trans_obs).squeeze(-2)
    predicted_mean3 = torch.mm(fs_loc.unsqueeze(-2), trans_trans_trans_obs).squeeze(-2)

    # check predicted means for 3 timesteps
    assert_equal(actual_loc[0], predicted_mean1)
    assert_equal(actual_loc[1], predicted_mean2)
    assert_equal(actual_loc[2], predicted_mean3)

    # check predicted covariances for 3 timesteps
    fs_covar, process_covar = None, None
    if model_class == 'lgssm':
        process_covar = model._get_trans_dist().covariance_matrix
        fs_covar = filtering_state.covariance_matrix
    elif model_class == 'lgssmgp':
        # we only compute contributions for the state space portion
        process_covar = model.trans_noise_scale_sq.diag_embed()
        fs_covar = filtering_state.covariance_matrix[-state_dim:, -state_dim:]

    # horizon-k covariance: propagated filtering covariance plus one process-noise
    # term per intermediate transition
    predicted_covar1 = torch.mm(trans_obs.t(), torch.mm(fs_covar, trans_obs)) + \
        torch.mm(obs_matrix.t(), torch.mm(process_covar, obs_matrix))

    predicted_covar2 = torch.mm(trans_trans_obs.t(), torch.mm(fs_covar, trans_trans_obs)) + \
        torch.mm(trans_obs.t(), torch.mm(process_covar, trans_obs)) + \
        torch.mm(obs_matrix.t(), torch.mm(process_covar, obs_matrix))

    predicted_covar3 = torch.mm(trans_trans_trans_obs.t(), torch.mm(fs_covar, trans_trans_trans_obs)) + \
        torch.mm(trans_trans_obs.t(), torch.mm(process_covar, trans_trans_obs)) + \
        torch.mm(trans_obs.t(), torch.mm(process_covar, trans_obs)) + \
        torch.mm(obs_matrix.t(), torch.mm(process_covar, obs_matrix))

    assert_equal(actual_cov[0], predicted_covar1)
    assert_equal(actual_cov[1], predicted_covar2)
    assert_equal(actual_cov[2], predicted_covar3)
def test_log_prob_sum(self): log_px_torch = dist.Categorical(self.probs).log_prob(self.test_data).sum().item() log_px_np = float(sp.multinomial.logpmf(np.array([0, 0, 1]), 1, self.probs.detach().cpu().numpy())) assert_equal(log_px_torch, log_px_np, prec=1e-4)
def test_mask(batch_dim, event_dim, mask_dim): # Construct base distribution. shape = torch.Size([2, 3, 4, 5, 6][:batch_dim + event_dim]) batch_shape = shape[:batch_dim] mask_shape = batch_shape[batch_dim - mask_dim:] base_dist = Bernoulli(0.1).expand_by(shape).to_event(event_dim) # Construct masked distribution. mask = checker_mask(mask_shape) dist = base_dist.mask(mask) # Check shape. sample = base_dist.sample() assert dist.batch_shape == base_dist.batch_shape assert dist.event_shape == base_dist.event_shape assert sample.shape == sample.shape assert dist.log_prob(sample).shape == base_dist.log_prob(sample).shape # Check values. assert_equal(dist.mean, base_dist.mean) assert_equal(dist.variance, base_dist.variance) assert_equal(dist.log_prob(sample), scale_and_mask(base_dist.log_prob(sample), mask=mask)) assert_equal(dist.score_parts(sample), base_dist.score_parts(sample).scale_and_mask(mask=mask), prec=0) if not dist.event_shape: assert_equal(dist.enumerate_support(), base_dist.enumerate_support()) assert_equal(dist.enumerate_support(expand=True), base_dist.enumerate_support(expand=True)) assert_equal(dist.enumerate_support(expand=False), base_dist.enumerate_support(expand=False))
def test_tmc_normals_chain_iwae(depth, num_samples, warmup_steps=None, **_ignore):
    """(signature unchanged below; see def line)"""
def do_elbo_test(
    self,
    reparameterized,
    n_steps,
    lr,
    prec,
    beta1,
    difficulty=1.0,
    model_permutation=False,
):
    """Run SVI on the Gaussian pyramid model and check learned parameters against
    the known analytic solution.

    :param reparameterized: whether latent sites use reparameterized normals
        (nodes flagged in ``self.which_nodes_reparam`` are reparameterized regardless).
    :param n_steps: number of SVI optimization steps.
    :param lr: Adam learning rate.
    :param prec: tolerance for the final error assertions.
    :param beta1: Adam first-moment decay rate.
    :param difficulty: scales the random offsets of the guide's initial parameters.
    :param model_permutation: whether to sample model variables in permuted order.
    """
    n_repa_nodes = (torch.sum(self.which_nodes_reparam)
                    if not reparameterized else len(self.q_topo_sort))
    logger.info((
        " - - - DO GAUSSIAN %d-LAYERED PYRAMID ELBO TEST " +
        "(with a total of %d RVs) [reparameterized=%s; %d/%d; perm=%s] - - -"
    ) % (
        self.N,
        (2**self.N) - 1,
        reparameterized,
        n_repa_nodes,
        len(self.q_topo_sort),
        model_permutation,
    ))
    pyro.clear_param_store()

    # check graph structure is as expected but only for N=2
    if self.N == 2:
        guide_trace = pyro.poutine.trace(
            self.guide, graph_type="dense").get_trace(
                reparameterized=reparameterized,
                model_permutation=model_permutation,
                difficulty=difficulty,
            )
        expected_nodes = set([
            "log_sig_1R",
            "kappa_1_1L",
            "_INPUT",
            "constant_term_loc_latent_1R",
            "_RETURN",
            "loc_latent_1R",
            "loc_latent_1",
            "constant_term_loc_latent_1",
            "loc_latent_1L",
            "constant_term_loc_latent_1L",
            "log_sig_1L",
            "kappa_1_1R",
            "kappa_1R_1L",
            "log_sig_1",
        ])
        expected_edges = set([
            ("loc_latent_1R", "loc_latent_1"),
            ("loc_latent_1L", "loc_latent_1R"),
            ("loc_latent_1L", "loc_latent_1"),
        ])
        assert expected_nodes == set(guide_trace.nodes)
        assert expected_edges == set(guide_trace.edges)

    adam = optim.Adam({"lr": lr, "betas": (beta1, 0.999)})
    svi = SVI(self.model, self.guide, adam, loss=TraceGraph_ELBO())

    for step in range(n_steps):
        t0 = time.time()
        svi.step(
            reparameterized=reparameterized,
            model_permutation=model_permutation,
            difficulty=difficulty,
        )

        # periodically compute errors; values from the final iteration are asserted below
        if step % 5000 == 0 or step == n_steps - 1:
            log_sig_errors = []
            for node in self.target_lambdas:
                target_log_sig = -0.5 * torch.log(
                    self.target_lambdas[node])
                log_sig_error = param_mse("log_sig_" + node, target_log_sig)
                log_sig_errors.append(log_sig_error)
            max_log_sig_error = np.max(log_sig_errors)
            min_log_sig_error = np.min(log_sig_errors)
            mean_log_sig_error = np.mean(log_sig_errors)
            leftmost_node = self.q_topo_sort[0]
            leftmost_constant_error = param_mse(
                "constant_term_" + leftmost_node,
                self.target_leftmost_constant)
            almost_leftmost_constant_error = param_mse(
                "constant_term_" + leftmost_node[:-1] + "R",
                self.target_almost_leftmost_constant,
            )

            logger.debug(
                "[mean function constant errors (partial)] %.4f %.4f" %
                (leftmost_constant_error, almost_leftmost_constant_error))
            logger.debug(
                "[min/mean/max log(scale) errors] %.4f %.4f %.4f" %
                (min_log_sig_error, mean_log_sig_error, max_log_sig_error))
            logger.debug("[step time = %.3f; N = %d; step = %d]\n" %
                         (time.time() - t0, self.N, step))

    assert_equal(0.0, max_log_sig_error, prec=prec)
    assert_equal(0.0, leftmost_constant_error, prec=prec)
    assert_equal(0.0, almost_leftmost_constant_error, prec=prec)
def test_log_pdf(self): log_px_torch = dist.delta.log_pdf(self.test_data, self.v).data assert_equal(torch.sum(log_px_torch), 0)
def test_undo_uncondition(self): unconditioned_model = poutine.uncondition(self.model) reconditioned_model = pyro.condition(unconditioned_model, {"obs": torch.ones(2)}) reconditioned_trace = poutine.trace(reconditioned_model).get_trace() assert_equal(reconditioned_trace.nodes["obs"]["value"], torch.ones(2))
def test_uncondition(self): unconditioned_model = poutine.uncondition(self.model) unconditioned_trace = poutine.trace(unconditioned_model).get_trace() conditioned_trace = poutine.trace(self.model).get_trace() assert_equal(conditioned_trace.nodes["obs"]["value"], torch.ones(2)) assert_not_equal(unconditioned_trace.nodes["obs"]["value"], torch.ones(2))
def test_timeseries_models(model, nu_statedim, obs_dim, T):
    """Smoke-test the contrib.timeseries models: log_prob shapes, agreement of
    Matern GPs with the exact MVN computation, and forecasting behavior.

    :param model: one of 'lcmgp', 'imgp', 'glgssm', 'ssmgp', 'dmgp'.
    :param nu_statedim: Matern smoothness nu for GP models, or the state
        dimension for 'glgssm'.
    :param obs_dim: observation dimension.
    :param T: number of observed timesteps.
    """
    torch.set_default_tensor_type('torch.DoubleTensor')
    dt = 0.1 + torch.rand(1).item()

    if model == 'lcmgp':
        num_gps = 2
        gp = LinearlyCoupledMaternGP(nu=nu_statedim, obs_dim=obs_dim, dt=dt, num_gps=num_gps,
                                     log_length_scale_init=torch.randn(num_gps),
                                     log_kernel_scale_init=torch.randn(num_gps),
                                     log_obs_noise_scale_init=torch.randn(obs_dim))
    elif model == 'imgp':
        gp = IndependentMaternGP(nu=nu_statedim, obs_dim=obs_dim, dt=dt,
                                 log_length_scale_init=torch.randn(obs_dim),
                                 log_kernel_scale_init=torch.randn(obs_dim),
                                 log_obs_noise_scale_init=torch.randn(obs_dim))
    elif model == 'glgssm':
        gp = GenericLGSSM(state_dim=nu_statedim, obs_dim=obs_dim,
                          log_obs_noise_scale_init=torch.randn(obs_dim))
    elif model == 'ssmgp':
        # state dim required to represent a Matern GP of the given smoothness
        state_dim = {0.5: 4, 1.5: 3, 2.5: 2}[nu_statedim]
        gp = GenericLGSSMWithGPNoiseModel(nu=nu_statedim, state_dim=state_dim, obs_dim=obs_dim,
                                          log_obs_noise_scale_init=torch.randn(obs_dim))
    elif model == 'dmgp':
        gp = DependentMaternGP(nu=nu_statedim, obs_dim=obs_dim, dt=dt,
                               log_length_scale_init=torch.randn(obs_dim))

    targets = torch.randn(T, obs_dim)
    gp_log_prob = gp.log_prob(targets)
    # imgp returns one log prob per independent output dimension; others a scalar
    if model == 'imgp':
        assert gp_log_prob.shape == (obs_dim,)
    else:
        assert gp_log_prob.dim() == 0

    # compare matern log probs to vanilla GP result via multivariate normal
    if model == 'imgp':
        times = dt * torch.arange(T).double()
        for dim in range(obs_dim):
            lengthscale = gp.kernel.log_length_scale.exp()[dim]
            variance = (2.0 * gp.kernel.log_kernel_scale).exp()[dim]
            obs_noise = (2.0 * gp.log_obs_noise_scale).exp()[dim]
            kernel = {0.5: pyro.contrib.gp.kernels.Exponential,
                      1.5: pyro.contrib.gp.kernels.Matern32,
                      2.5: pyro.contrib.gp.kernels.Matern52}[nu_statedim]
            kernel = kernel(input_dim=1, lengthscale=lengthscale, variance=variance)
            kernel = kernel(times) + obs_noise * torch.eye(T)
            mvn = torch.distributions.MultivariateNormal(torch.zeros(T), kernel)
            mvn_log_prob = mvn.log_prob(targets[:, dim])
            assert_equal(mvn_log_prob, gp_log_prob[dim], prec=1e-4)

    for S in [1, 5]:
        # GP-based models forecast at arbitrary (increasing) time offsets;
        # LGSSM-style models forecast a fixed number of steps
        if model in ['imgp', 'lcmgp', 'dmgp']:
            dts = torch.rand(S).cumsum(dim=-1)
            predictive = gp.forecast(targets, dts)
        else:
            predictive = gp.forecast(targets, S)
        assert predictive.loc.shape == (S, obs_dim)
        if model == 'imgp':
            assert predictive.scale.shape == (S, obs_dim)
            # assert monotonic increase of predictive noise
            if S > 1:
                delta = predictive.scale[1:S, :] - predictive.scale[0:S-1, :]
                assert (delta > 0.0).sum() == (S - 1) * obs_dim
        else:
            assert predictive.covariance_matrix.shape == (S, obs_dim, obs_dim)
            # assert monotonic increase of predictive noise
            if S > 1:
                dets = predictive.covariance_matrix.det()
                delta = dets[1:S] - dets[0:S-1]
                assert (delta > 0.0).sum() == (S - 1)

    if model in ['imgp', 'lcmgp', 'dmgp']:
        # the distant future
        dts = torch.tensor([500.0])
        predictive = gp.forecast(targets, dts)
        # assert mean reverting behavior for GP models
        assert_equal(predictive.loc, torch.zeros(1, obs_dim))
def test_support(self): s = dist.one_hot_categorical.enumerate_support(self.d_ps) assert_equal(s.data, self.support_one_hot)
def test_batch_log_dims(dim, probs): probs = modify_params_using_dims(probs, dim) log_prob_shape = torch.Size((3,) + dist.Categorical(probs).batch_shape) support = dist.Categorical(probs).enumerate_support() log_prob = dist.Categorical(probs).log_prob(support) assert_equal(log_prob.size(), log_prob_shape)
def test_support_non_vectorized(self): s = dist.one_hot_categorical.enumerate_support(self.d_ps[0].squeeze(0)) assert_equal(s.data, self.support_one_hot_non_vec)
def test_support_dims(dim, probs): probs = modify_params_using_dims(probs, dim) support = dist.Categorical(probs).enumerate_support() assert_equal(support.size(), torch.Size((probs.size(-1),) + probs.size()[:-1]))
def test_compute_downstream_costs_big_model_guide_pair(include_inner_1, include_single, flip_c23,
                                                       include_triple, include_z1):
    """Check ``_compute_downstream_costs`` against a brute-force reference and
    against hand-assembled expected costs for a large model/guide pair with
    many structural variants.

    The downstream cost of a non-reparameterizable node is the sum of
    (model log_prob - guide log_prob) over that node and everything downstream
    of it in the guide's dependency structure.
    """
    guide_trace = poutine.trace(big_model_guide,
                                graph_type="dense").get_trace(include_obs=False,
                                                              include_inner_1=include_inner_1,
                                                              include_single=include_single,
                                                              flip_c23=flip_c23,
                                                              include_triple=include_triple,
                                                              include_z1=include_z1)
    model_trace = poutine.trace(poutine.replay(big_model_guide, trace=guide_trace),
                                graph_type="dense").get_trace(include_obs=True,
                                                              include_inner_1=include_inner_1,
                                                              include_single=include_single,
                                                              flip_c23=flip_c23,
                                                              include_triple=include_triple,
                                                              include_z1=include_z1)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    # the efficient and brute-force implementations must find the same node sets
    assert dc_nodes == dc_nodes_brute

    expected_nodes_full_model = {
        'a1': {'c2', 'a1', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'b0'},
        'd2': {'obs', 'd2'},
        'd1': {'obs', 'd1', 'd2'},
        'c3': {'d2', 'obs', 'd1', 'c3'},
        'b0': {'b0', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'c2'},
        'b1': {'obs', 'b1', 'd1', 'd2', 'c3', 'c1', 'c2'},
        'c1': {'d1', 'c1', 'obs', 'd2', 'c3', 'c2'},
        'c2': {'obs', 'd1', 'c3', 'd2', 'c2'}
    }
    # node sets are only spelled out for the one fully-featured configuration
    if not include_triple and include_inner_1 and include_single and not flip_c23:
        assert (dc_nodes == expected_nodes_full_model)

    # hand-assemble the expected downstream cost for b1
    expected_b1 = (model_trace.nodes['b1']['log_prob'] -
                   guide_trace.nodes['b1']['log_prob'])
    expected_b1 += (model_trace.nodes['d2']['log_prob'] -
                    guide_trace.nodes['d2']['log_prob']).sum(0)
    expected_b1 += (model_trace.nodes['d1']['log_prob'] -
                    guide_trace.nodes['d1']['log_prob']).sum(0)
    expected_b1 += model_trace.nodes['obs']['log_prob'].sum(0, keepdim=False)
    if include_inner_1:
        expected_b1 += (model_trace.nodes['c1']['log_prob'] -
                        guide_trace.nodes['c1']['log_prob']).sum(0)
        expected_b1 += (model_trace.nodes['c2']['log_prob'] -
                        guide_trace.nodes['c2']['log_prob']).sum(0)
        expected_b1 += (model_trace.nodes['c3']['log_prob'] -
                        guide_trace.nodes['c3']['log_prob']).sum(0)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)

    if include_single:
        # b0 is outside the plates, so downstream contributions are fully summed
        expected_b0 = (model_trace.nodes['b0']['log_prob'] -
                       guide_trace.nodes['b0']['log_prob'])
        expected_b0 += (model_trace.nodes['b1']['log_prob'] -
                        guide_trace.nodes['b1']['log_prob']).sum()
        expected_b0 += (model_trace.nodes['d2']['log_prob'] -
                        guide_trace.nodes['d2']['log_prob']).sum()
        expected_b0 += (model_trace.nodes['d1']['log_prob'] -
                        guide_trace.nodes['d1']['log_prob']).sum()
        expected_b0 += model_trace.nodes['obs']['log_prob'].sum()
        if include_inner_1:
            expected_b0 += (model_trace.nodes['c1']['log_prob'] -
                            guide_trace.nodes['c1']['log_prob']).sum()
            expected_b0 += (model_trace.nodes['c2']['log_prob'] -
                            guide_trace.nodes['c2']['log_prob']).sum()
            expected_b0 += (model_trace.nodes['c3']['log_prob'] -
                            guide_trace.nodes['c3']['log_prob']).sum()
        assert_equal(expected_b0, dc['b0'], prec=1.0e-6)
        assert dc['b0'].size() == (5, )

    if include_inner_1:
        expected_c3 = (model_trace.nodes['c3']['log_prob'] -
                       guide_trace.nodes['c3']['log_prob'])
        expected_c3 += (model_trace.nodes['d1']['log_prob'] -
                        guide_trace.nodes['d1']['log_prob']).sum(0)
        expected_c3 += (model_trace.nodes['d2']['log_prob'] -
                        guide_trace.nodes['d2']['log_prob']).sum(0)
        expected_c3 += model_trace.nodes['obs']['log_prob'].sum(0)

        expected_c2 = (model_trace.nodes['c2']['log_prob'] -
                       guide_trace.nodes['c2']['log_prob'])
        expected_c2 += (model_trace.nodes['d1']['log_prob'] -
                        guide_trace.nodes['d1']['log_prob']).sum(0)
        expected_c2 += (model_trace.nodes['d2']['log_prob'] -
                        guide_trace.nodes['d2']['log_prob']).sum(0)
        expected_c2 += model_trace.nodes['obs']['log_prob'].sum(0)

        expected_c1 = (model_trace.nodes['c1']['log_prob'] -
                       guide_trace.nodes['c1']['log_prob'])
        # flip_c23 reverses the c2 -> c3 dependency direction in the guide
        # NOTE(review): the exact cross-terms below are delicate; nesting was
        # reconstructed from flattened source — verify against upstream.
        if flip_c23:
            expected_c3 += model_trace.nodes['c2'][
                'log_prob'] - guide_trace.nodes['c2']['log_prob']
            expected_c2 += model_trace.nodes['c3']['log_prob']
        else:
            expected_c2 += model_trace.nodes['c3'][
                'log_prob'] - guide_trace.nodes['c3']['log_prob']
            expected_c2 += model_trace.nodes['c2'][
                'log_prob'] - guide_trace.nodes['c2']['log_prob']
        expected_c1 += expected_c3
        assert_equal(expected_c1, dc['c1'], prec=1.0e-6)
        assert_equal(expected_c2, dc['c2'], prec=1.0e-6)
        assert_equal(expected_c3, dc['c3'], prec=1.0e-6)

    expected_d1 = model_trace.nodes['d1']['log_prob'] - guide_trace.nodes[
        'd1']['log_prob']
    expected_d1 += model_trace.nodes['d2']['log_prob'] - guide_trace.nodes[
        'd2']['log_prob']
    expected_d1 += model_trace.nodes['obs']['log_prob']

    expected_d2 = (model_trace.nodes['d2']['log_prob'] -
                   guide_trace.nodes['d2']['log_prob'])
    expected_d2 += model_trace.nodes['obs']['log_prob']

    if include_triple:
        expected_z0 = dc['a1'] + model_trace.nodes['z0'][
            'log_prob'] - guide_trace.nodes['z0']['log_prob']
        assert_equal(expected_z0, dc['z0'], prec=1.0e-6)
    assert_equal(expected_d2, dc['d2'], prec=1.0e-6)
    assert_equal(expected_d1, dc['d1'], prec=1.0e-6)
    assert dc['b1'].size() == (2, )
    assert dc['d2'].size() == (4, 2)

    # every downstream cost matches the brute-force reference, shape included
    for k in dc:
        assert (guide_trace.nodes[k]['log_prob'].size() == dc[k].size())
        assert_equal(dc[k], dc_brute[k])
def test_mean_and_var(self): torch_samples = [dist.Categorical(self.probs).sample().detach().cpu().numpy() for _ in range(self.n_samples)] _, counts = np.unique(torch_samples, return_counts=True) computed_mean = float(counts[0]) / self.n_samples assert_equal(computed_mean, self.analytic_mean.detach().cpu().numpy()[0], prec=0.05)
def test_NcpContinuous(): framerate = 100 # Hz dt = 1.0 / framerate d = 3 ncp = NcpContinuous(dimension=d, sv2=2.0) assert ncp.dimension == d assert ncp.dimension_pv == 2*d assert ncp.num_process_noise_parameters == 1 x = torch.rand(d) y = ncp(x, dt) assert_equal(y, x) dx = ncp.geodesic_difference(x, y) assert_equal(dx, torch.zeros(d)) x_pv = ncp.mean2pv(x) assert len(x_pv) == 6 assert_equal(x, x_pv[:d]) assert_equal(torch.zeros(d), x_pv[d:]) P = torch.eye(d) P_pv = ncp.cov2pv(P) assert P_pv.shape == (2*d, 2*d) P_pv_ref = torch.zeros((2*d, 2*d)) P_pv_ref[:d, :d] = P assert_equal(P_pv_ref, P_pv) Q = ncp.process_noise_cov(dt) Q1 = ncp.process_noise_cov(dt) # Test caching. assert_equal(Q, Q1) assert Q1.shape == (d, d) assert_cov_validity(Q1) dx = ncp.process_noise_dist(dt).sample() assert dx.shape == (ncp.dimension,)
def do_elbo_test(self, reparameterized, n_steps, lr, prec, difficulty=1.0):
    """Run SVI on a Gaussian chain model and check the learned variational
    parameters against the known analytic posterior.

    :param reparameterized: whether latent sites use reparameterized normals
        (nodes flagged in ``self.which_nodes_reparam`` are reparameterized regardless).
    :param n_steps: number of SVI optimization steps.
    :param lr: Adam learning rate.
    :param prec: tolerance for the final max-error assertions.
    :param difficulty: scales the random offsets applied to the guide's initial
        parameters away from the analytic solution.
    """
    n_repa_nodes = torch.sum(self.which_nodes_reparam) if not reparameterized else self.N
    logger.info(" - - - - - DO GAUSSIAN %d-CHAIN ELBO TEST [reparameterized = %s; %d/%d] - - - - - " %
                (self.N, reparameterized, n_repa_nodes, self.N))
    # debug dump of targets; the `self.N < 0` guard means it is currently disabled
    if self.N < 0:

        def array_to_string(y):
            return str(map(lambda x: "%.3f" % x.data.cpu().numpy()[0], y))

        logger.debug("lambdas: " + array_to_string(self.lambdas))
        logger.debug("target_mus: " + array_to_string(self.target_mus[1:]))
        # BUG FIX: these two calls were garbled/fused in the source
        # ('"target_kappas: "******"lambda_posts: "'); reconstructed to match
        # the pattern of the surrounding logger.debug calls.
        logger.debug("target_kappas: " + array_to_string(self.target_kappas[1:]))
        logger.debug("lambda_posts: " + array_to_string(self.lambda_posts[1:]))
        logger.debug("lambda_tilde_posts: " + array_to_string(self.lambda_tilde_posts))

    pyro.clear_param_store()

    def model(*args, **kwargs):
        # generative chain: mu_0 -> mu_1 -> ... -> mu_N -> data
        next_mean = self.mu0
        for k in range(1, self.N + 1):
            latent_dist = dist.Normal(next_mean, torch.pow(self.lambdas[k - 1], -0.5))
            mu_latent = pyro.sample("mu_latent_%d" % k, latent_dist)
            next_mean = mu_latent

        mu_N = next_mean
        for i, x in enumerate(self.data):
            pyro.observe("obs_%d" % i, dist.normal, x, mu_N,
                         torch.pow(self.lambdas[self.N], -0.5))
        return mu_N

    def guide(*args, **kwargs):
        # the guide runs the chain in reverse; each site's mean is affine in the
        # previously drawn value (kappa * prev + mu), matching the analytic posterior
        previous_sample = None
        for k in reversed(range(1, self.N + 1)):
            mu_q = pyro.param("mu_q_%d" % k,
                              Variable(self.target_mus[k].data +
                                       difficulty * (0.1 * torch.randn(1) - 0.53),
                                       requires_grad=True))
            log_sig_q = pyro.param("log_sig_q_%d" % k,
                                   Variable(-0.5 * torch.log(self.lambda_posts[k]).data +
                                            difficulty * (0.1 * torch.randn(1) - 0.53),
                                            requires_grad=True))
            sig_q = torch.exp(log_sig_q)
            # the last (first-sampled) site has no predecessor, hence no kappa
            kappa_q = None if k == self.N \
                else pyro.param("kappa_q_%d" % k,
                                Variable(self.target_kappas[k].data +
                                         difficulty * (0.1 * torch.randn(1) - 0.53),
                                         requires_grad=True))
            mean_function = mu_q if k == self.N else kappa_q * previous_sample + mu_q
            node_flagged = True if self.which_nodes_reparam[k - 1] == 1.0 else False
            normal = dist.normal if reparameterized or node_flagged else fakes.nonreparameterized_normal
            mu_latent = pyro.sample("mu_latent_%d" % k, normal, mean_function, sig_q,
                                    baseline=dict(use_decaying_avg_baseline=True))
            previous_sample = mu_latent
        return previous_sample

    adam = optim.Adam({"lr": lr, "betas": (0.95, 0.999)})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)

    for step in range(n_steps):
        t0 = time.time()
        svi.step()

        # periodically compute errors; the final iteration's values are asserted below
        if step % 5000 == 0 or step == n_steps - 1:
            kappa_errors, log_sig_errors, mu_errors = [], [], []
            for k in range(1, self.N + 1):
                if k != self.N:
                    kappa_error = param_mse("kappa_q_%d" % k, self.target_kappas[k])
                    kappa_errors.append(kappa_error)
                mu_errors.append(param_mse("mu_q_%d" % k, self.target_mus[k]))
                log_sig_error = param_mse("log_sig_q_%d" % k,
                                          -0.5 * torch.log(self.lambda_posts[k]))
                log_sig_errors.append(log_sig_error)

            max_errors = (np.max(mu_errors), np.max(log_sig_errors), np.max(kappa_errors))
            min_errors = (np.min(mu_errors), np.min(log_sig_errors), np.min(kappa_errors))
            mean_errors = (np.mean(mu_errors), np.mean(log_sig_errors), np.mean(kappa_errors))
            logger.debug("[max errors] (mu, log_sigma, kappa) = (%.4f, %.4f, %.4f)" % max_errors)
            logger.debug("[min errors] (mu, log_sigma, kappa) = (%.4f, %.4f, %.4f)" % min_errors)
            logger.debug("[mean errors] (mu, log_sigma, kappa) = (%.4f, %.4f, %.4f)" % mean_errors)
            logger.debug("[step time = %.3f; N = %d; step = %d]\n" %
                         (time.time() - t0, self.N, step))

    assert_equal(0.0, max_errors[0], prec=prec)
    assert_equal(0.0, max_errors[1], prec=prec)
    assert_equal(0.0, max_errors[2], prec=prec)
def test_tmc_normals_chain_gradient(depth, num_samples, max_plate_nesting, expand,
                                    guide_type, reparameterized, tmc_strategy):
    """Check TMC gradient estimates on a Gaussian chain against gold values.

    The chain is x0 -> x1 -> ... -> x_{depth-1} -> y with variances chosen so
    the marginal is depth-independent; the gradient of the (linear-space) TMC
    objective w.r.t. q2 is compared against reference values computed with Funsor.
    """
    # compare reparameterized and nonreparameterized gradient estimates
    q2 = pyro.param("q2", torch.tensor(0.5, requires_grad=True))
    qs = (q2.unconstrained(),)

    def model(reparameterized):
        Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
        x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth)))
        for i in range(1, depth):
            x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1. / depth)))
        pyro.sample("y", Normal(x, 1.), obs=torch.tensor(float(1)))

    def factorized_guide(reparameterized):
        # mean-field guide: each site is independent with the marginal's scale
        Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
        pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth)))
        for i in range(1, depth):
            pyro.sample("x{}".format(i),
                        Normal(0., math.sqrt(float(i + 1) / depth)))

    def nonfactorized_guide(reparameterized):
        # autoregressive guide mirroring the model's chain structure
        Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
        x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth)))
        for i in range(1, depth):
            x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1. / depth)))

    tmc = TraceTMC_ELBO(max_plate_nesting=max_plate_nesting)
    tmc_model = config_enumerate(model, default="parallel", expand=expand,
                                 num_samples=num_samples, tmc=tmc_strategy)
    # the fallback guide samples nothing (pure prior enumeration)
    guide = factorized_guide if guide_type == "factorized" else \
        nonfactorized_guide if guide_type == "nonfactorized" else \
        lambda *args: None
    tmc_guide = config_enumerate(guide, default="parallel", expand=expand,
                                 num_samples=num_samples, tmc=tmc_strategy)

    # gold values from Funsor
    expected_grads = (torch.tensor({
        1: 0.0999,
        2: 0.0860,
        3: 0.0802,
        4: 0.0771
    }[depth]),)

    # convert to linear space for unbiasedness
    actual_loss = (-tmc.differentiable_loss(tmc_model, tmc_guide, reparameterized)).exp()
    actual_grads = grad(actual_loss, qs)

    grad_prec = 0.05 if reparameterized else 0.1

    # BUG FIX: removed a leftover debug `print(actual_loss)` from this loop
    for actual_grad, expected_grad in zip(actual_grads, expected_grads):
        assert_equal(actual_grad, expected_grad, prec=grad_prec, msg="".join([
            "\nexpected grad = {}".format(
                expected_grad.detach().cpu().numpy()),
            "\n actual grad = {}".format(
                actual_grad.detach().cpu().numpy()),
        ]))
def do_elbo_test(self, reparameterized, n_steps, lr, prec, beta1,
                 difficulty=1.0, model_permutation=False):
    """Run SVI on an N-layer Gaussian pyramid and check learned variational params.

    The model is a binary tree of depth ``self.N`` of Normal latents (2**N - 1
    RVs total) with Normal observations at the leaves; the guide mirrors the
    tree via ``self.q_topo_sort`` / ``self.q_dag``. After ``n_steps`` of SVI,
    the learned log-sigmas and mean-function constants are compared against
    analytic targets with tolerance ``prec``.

    NOTE(review): this uses a legacy Pyro API surface (``pyro.observe``,
    lowercase ``dist.normal``, ``Variable``, ``SVI(..., loss="ELBO",
    trace_graph=True)``) — predates modern Pyro; verify against the pinned
    Pyro version before porting.

    :param reparameterized: if False, only nodes flagged in
        ``self.which_nodes_reparam`` use reparameterized Normals.
    :param n_steps: number of SVI optimization steps.
    :param lr: Adam learning rate.
    :param prec: tolerance for the final error assertions.
    :param beta1: Adam beta1 momentum coefficient.
    :param difficulty: scales how far initial guide params start from targets.
    :param model_permutation: if True, sample model RVs in permuted order.
    """
    # Count how many RVs are reparameterized (all of them when
    # reparameterized=True, else just the flagged subset).
    n_repa_nodes = torch.sum(self.which_nodes_reparam) if not reparameterized \
        else len(self.q_topo_sort)
    logger.info((" - - - DO GAUSSIAN %d-LAYERED PYRAMID ELBO TEST " +
                 "(with a total of %d RVs) [reparameterized=%s; %d/%d; perm=%s] - - -") %
                (self.N, (2 ** self.N) - 1, reparameterized, n_repa_nodes,
                 len(self.q_topo_sort), model_permutation))
    pyro.clear_param_store()

    def model(*args, **kwargs):
        # Root of the pyramid: Normal with precision self.lambdas[0]
        # (std = lambda**-0.5).
        top_latent_dist = dist.Normal(self.mu0, torch.pow(self.lambdas[0], -0.5))
        previous_names = ["mu_latent_1"]
        top_latent = pyro.sample(previous_names[0], top_latent_dist)
        previous_latents_and_names = list(zip([top_latent], previous_names))

        # for sampling model variables in different sequential orders
        def permute(x, n):
            if model_permutation:
                return [x[self.model_permutations[n - 1][i]] for i in range(len(x))]
            return x

        def unpermute(x, n):
            if model_permutation:
                return [x[self.model_unpermutations[n - 1][i]] for i in range(len(x))]
            return x

        # Grow the tree level by level: each parent spawns an 'L' and an 'R'
        # child whose names are built by appending to the parent's name.
        for n in range(2, self.N + 1):
            new_latents_and_names = []
            for prev_latent, prev_name in permute(previous_latents_and_names, n - 1):
                latent_dist = dist.Normal(prev_latent, torch.pow(self.lambdas[n - 1], -0.5))
                couple = []
                for LR in ['L', 'R']:
                    new_name = prev_name + LR
                    mu_latent_LR = pyro.sample(new_name, latent_dist)
                    couple.append([mu_latent_LR, new_name])
                new_latents_and_names.append(couple)

            # Undo the permutation so bookkeeping stays in canonical order,
            # then flatten the [L, R] couples into a single list.
            _previous_latents_and_names = unpermute(new_latents_and_names, n - 1)
            previous_latents_and_names = []
            for x in _previous_latents_and_names:
                previous_latents_and_names.append(x[0])
                previous_latents_and_names.append(x[1])

        # Leaf observations: each leaf latent generates its data points.
        for i, data_i in enumerate(self.data):
            for k, x in enumerate(data_i):
                pyro.observe("obs_%s_%d" % (previous_latents_and_names[i][1], k),
                             dist.normal, x, previous_latents_and_names[i][0],
                             torch.pow(self.lambdas[-1], -0.5))
        return top_latent

    def guide(*args, **kwargs):
        latents_dict = {}

        n_nodes = len(self.q_topo_sort)
        # Visit nodes in topological order so each node's parents are already
        # in latents_dict when its mean function is assembled.
        for i, node in enumerate(self.q_topo_sort):
            deps = self.q_dag.predecessors(node)
            # node names look like "mu_latent_<suffix>"; strip the prefix.
            node_suffix = node[10:]
            # Initialize log-sigma near its analytic target, offset by an
            # amount scaled by `difficulty`.
            log_sig_node = pyro.param("log_sig_" + node_suffix,
                                      Variable(-0.5 * torch.log(self.target_lambdas[node_suffix]).data +
                                               difficulty * (torch.Tensor([-0.3]) -
                                                             0.3 * (torch.randn(1) ** 2)),
                                               requires_grad=True))
            # Guide mean = constant term + sum of kappa-weighted parent latents.
            mean_function_node = pyro.param("constant_term_" + node,
                                            Variable(self.mu0.data +
                                                     torch.Tensor([difficulty * i / n_nodes]),
                                                     requires_grad=True))
            for dep in deps:
                kappa_dep = pyro.param("kappa_" + node_suffix + '_' + dep[10:],
                                       Variable(torch.Tensor([0.5 + difficulty * i / n_nodes]),
                                                requires_grad=True))
                mean_function_node = mean_function_node + kappa_dep * latents_dict[dep]

            # Flagged nodes are reparameterized even in the non-repa run.
            node_flagged = True if self.which_nodes_reparam[i] == 1.0 else False
            normal = dist.normal if reparameterized or node_flagged else fakes.nonreparameterized_normal
            # Decaying-average baseline reduces score-function gradient variance.
            latent_node = pyro.sample(node, normal, mean_function_node, torch.exp(log_sig_node),
                                      baseline=dict(use_decaying_avg_baseline=True,
                                                    baseline_beta=0.96))
            latents_dict[node] = latent_node

        return latents_dict['mu_latent_1']

    # check graph structure is as expected but only for N=2
    if self.N == 2:
        guide_trace = pyro.poutine.trace(guide, graph_type="dense").get_trace()
        expected_nodes = set(['log_sig_1R', 'kappa_1_1L', '_INPUT', 'constant_term_mu_latent_1R',
                              '_RETURN', 'mu_latent_1R', 'mu_latent_1', 'constant_term_mu_latent_1',
                              'mu_latent_1L', 'constant_term_mu_latent_1L', 'log_sig_1L',
                              'kappa_1_1R', 'kappa_1R_1L', 'log_sig_1'])
        expected_edges = set([('mu_latent_1R', 'mu_latent_1'), ('mu_latent_1L', 'mu_latent_1R'),
                              ('mu_latent_1L', 'mu_latent_1')])
        assert expected_nodes == set(guide_trace.nodes)
        assert expected_edges == set(guide_trace.edges)

    adam = optim.Adam({"lr": lr, "betas": (beta1, 0.999)})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)

    for step in range(n_steps):
        t0 = time.time()
        svi.step()

        # Periodically (and on the final step) measure parameter errors; the
        # final-step values feed the assertions below.
        if step % 5000 == 0 or step == n_steps - 1:
            log_sig_errors = []
            for node in self.target_lambdas:
                target_log_sig = -0.5 * torch.log(self.target_lambdas[node])
                log_sig_error = param_mse('log_sig_' + node, target_log_sig)
                log_sig_errors.append(log_sig_error)
            max_log_sig_error = np.max(log_sig_errors)
            min_log_sig_error = np.min(log_sig_errors)
            mean_log_sig_error = np.mean(log_sig_errors)
            leftmost_node = self.q_topo_sort[0]
            leftmost_constant_error = param_mse('constant_term_' + leftmost_node,
                                                self.target_leftmost_constant)
            almost_leftmost_constant_error = param_mse('constant_term_' + leftmost_node[:-1] + 'R',
                                                       self.target_almost_leftmost_constant)

            logger.debug("[mean function constant errors (partial)] %.4f %.4f" %
                         (leftmost_constant_error, almost_leftmost_constant_error))
            logger.debug("[min/mean/max log(sigma) errors] %.4f %.4f %.4f" %
                         (min_log_sig_error, mean_log_sig_error, max_log_sig_error))
            logger.debug("[step time = %.3f; N = %d; step = %d]\n" %
                         (time.time() - t0, self.N, step))

    # The variables below were last assigned on the final step of the loop.
    assert_equal(0.0, max_log_sig_error, prec=prec)
    assert_equal(0.0, leftmost_constant_error, prec=prec)
    assert_equal(0.0, almost_leftmost_constant_error, prec=prec)
def _test_plate_in_elbo(self, n_superfluous_top, n_superfluous_bottom, n_steps):
    """Run TraceGraph_ELBO SVI with nested plates and superfluous latents.

    The data is split into three ragged groups of sizes 4/3/2; the model adds
    ``n_superfluous_top + n_superfluous_bottom`` extra ("superfluous") latents
    per group that do not touch the data, each given a learned baseline in the
    guide. After training, the core posterior params must match the analytic
    solution and every superfluous mean must have shrunk to ~0.

    :param n_superfluous_top: extra latents sampled before the observation.
    :param n_superfluous_bottom: extra latents sampled after the observation.
    :param n_steps: number of SVI optimization steps.
    """
    pyro.clear_param_store()

    # Pack self.data (nested lists of 2-vectors) into one 9x2 tensor.
    # NOTE(review): the fill index `3 * _out + _in` assumes n_inner == 3,
    # while data_as_list re-slices into ragged chunks of 4/3/2 rows —
    # presumably intentional for the ragged-plate test; confirm with fixture.
    self.data_tensor = torch.zeros(9, 2)
    for _out in range(self.n_outer):
        for _in in range(self.n_inner):
            self.data_tensor[3 * _out + _in, :] = self.data[_out][_in]
    self.data_as_list = [self.data_tensor[0:4, :], self.data_tensor[4:7, :],
                         self.data_tensor[7:9, :]]

    def model():
        # Global latent shared by all observations (nonreparameterized so the
        # score-function/baseline machinery is exercised).
        loc_latent = pyro.sample("loc_latent",
                                 fakes.NonreparameterizedNormal(self.loc0, torch.pow(self.lam0, -0.5))
                                 .to_event(1))
        # Outer plate over the 3 ragged data groups; inner plate over the
        # group's rows (group i has 4 - i rows).
        for i in pyro.plate("outer", 3):
            x_i = self.data_as_list[i]
            with pyro.plate("inner_%d" % i, x_i.size(0)):
                # "Top" superfluous latents: sampled before the observation.
                for k in range(n_superfluous_top):
                    z_i_k = pyro.sample("z_%d_%d" % (i, k),
                                        fakes.NonreparameterizedNormal(0, 1).expand_by([4 - i]))
                    assert z_i_k.shape == (4 - i,)
                obs_i = pyro.sample("obs_%d" % i,
                                    dist.Normal(loc_latent, torch.pow(self.lam, -0.5))
                                    .to_event(1),
                                    obs=x_i)
                assert obs_i.shape == (4 - i, 2)
                # "Bottom" superfluous latents: sampled after the observation.
                for k in range(n_superfluous_top, n_superfluous_top + n_superfluous_bottom):
                    z_i_k = pyro.sample("z_%d_%d" % (i, k),
                                        fakes.NonreparameterizedNormal(0, 1).expand_by([4 - i]))
                    assert z_i_k.shape == (4 - i,)

    # Neural baselines, created outside the guide so their weights persist
    # across SVI steps: one trivial baseline for loc_latent, and per (i, k)
    # baselines sized to each group's 4 - i latents (indexed as 3 * k + i).
    pt_loc_baseline = torch.nn.Linear(1, 1)
    pt_superfluous_baselines = []
    for k in range(n_superfluous_top + n_superfluous_bottom):
        pt_superfluous_baselines.extend([torch.nn.Linear(2, 4), torch.nn.Linear(2, 3),
                                         torch.nn.Linear(2, 2)])

    def guide():
        # Variational params start slightly off the analytic optimum so SVI
        # has something to correct.
        loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.094)
        log_sig_q = pyro.param("log_sig_q", self.analytic_log_sig_n.expand(2) - 0.07)
        sig_q = torch.exp(log_sig_q)
        trivial_baseline = pyro.module("loc_baseline", pt_loc_baseline)
        baseline_value = trivial_baseline(torch.ones(1)).squeeze()
        loc_latent = pyro.sample("loc_latent",
                                 fakes.NonreparameterizedNormal(loc_q, sig_q).to_event(1),
                                 infer=dict(baseline=dict(baseline_value=baseline_value)))
        for i in pyro.plate("outer", 3):
            with pyro.plate("inner_%d" % i, 4 - i):
                for k in range(n_superfluous_top + n_superfluous_bottom):
                    z_baseline = pyro.module("z_baseline_%d_%d" % (i, k),
                                             pt_superfluous_baselines[3 * k + i])
                    # detach: the baseline must not backprop into loc_latent.
                    baseline_value = z_baseline(loc_latent.detach())
                    mean_i = pyro.param("mean_%d_%d" % (i, k),
                                        0.5 * torch.ones(4 - i))
                    z_i_k = pyro.sample("z_%d_%d" % (i, k),
                                        fakes.NonreparameterizedNormal(mean_i, 1),
                                        infer=dict(baseline=dict(baseline_value=baseline_value)))
                    assert z_i_k.shape == (4 - i,)

    def per_param_callable(module_name, param_name):
        # Baseline nets get a higher learning rate than model/guide params.
        if 'baseline' in param_name or 'baseline' in module_name:
            return {"lr": 0.010, "betas": (0.95, 0.999)}
        else:
            return {"lr": 0.0012, "betas": (0.95, 0.999)}

    adam = optim.Adam(per_param_callable)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

    for step in range(n_steps):
        svi.step()

        loc_error = param_abs_error("loc_q", self.analytic_loc_n)
        log_sig_error = param_abs_error("log_sig_q", self.analytic_log_sig_n)

        if n_superfluous_top > 0 or n_superfluous_bottom > 0:
            # Each superfluous mean should be driven to zero; track the worst
            # (max) squared norm across the three groups for each k.
            superfluous_errors = []
            for k in range(n_superfluous_top + n_superfluous_bottom):
                mean_0_error = torch.sum(torch.pow(pyro.param("mean_0_%d" % k), 2.0))
                mean_1_error = torch.sum(torch.pow(pyro.param("mean_1_%d" % k), 2.0))
                mean_2_error = torch.sum(torch.pow(pyro.param("mean_2_%d" % k), 2.0))
                superfluous_error = torch.max(torch.max(mean_0_error, mean_1_error),
                                              mean_2_error)
                superfluous_errors.append(superfluous_error.detach().cpu().numpy())

        if step % 500 == 0:
            logger.debug("loc error, log(scale) error:  %.4f, %.4f" % (loc_error, log_sig_error))
            if n_superfluous_top > 0 or n_superfluous_bottom > 0:
                logger.debug("superfluous error: %.4f" % np.max(superfluous_errors))

    # Errors from the final SVI step must be within tolerance.
    assert_equal(0.0, loc_error, prec=0.04)
    assert_equal(0.0, log_sig_error, prec=0.05)
    if n_superfluous_top > 0 or n_superfluous_bottom > 0:
        assert_equal(0.0, np.max(superfluous_errors), prec=0.04)