def test_batch_log_pdf_mask(dist):
    if dist.get_test_distribution_name() not in ('Normal', 'Bernoulli', 'Categorical'):
        pytest.skip('Batch pdf masking not supported for the distribution.')
    d = dist.pyro_dist
    for idx in range(dist.get_num_test_data()):
        dist_params = dist.get_dist_params(idx)
        x = dist.get_test_data(idx)
        with xfail_if_not_implemented():
            batch_pdf_shape = d.batch_shape(**dist_params) + (1,)
            batch_pdf_shape_broadcasted = d.batch_shape(x, **dist_params) + (1,)
            zeros_mask = ng_zeros(1)  # should be broadcasted to data dims
            ones_mask = ng_ones(batch_pdf_shape)  # should be broadcasted to data dims
            half_mask = ng_ones(1) * 0.5
            batch_log_pdf = d.batch_log_pdf(x, **dist_params)
            batch_log_pdf_zeros_mask = d.batch_log_pdf(x, log_pdf_mask=zeros_mask, **dist_params)
            batch_log_pdf_ones_mask = d.batch_log_pdf(x, log_pdf_mask=ones_mask, **dist_params)
            batch_log_pdf_half_mask = d.batch_log_pdf(x, log_pdf_mask=half_mask, **dist_params)
            assert_equal(batch_log_pdf_ones_mask, batch_log_pdf)
            assert_equal(batch_log_pdf_zeros_mask, ng_zeros(batch_pdf_shape_broadcasted))
            assert_equal(batch_log_pdf_half_mask, 0.5 * batch_log_pdf)
def model(): mu_latent = pyro.sample("mu_latent", dist.normal, self.mu0, torch.pow(self.tau0, -0.5)) bijector = AffineExp(torch.pow(self.tau, -0.5), mu_latent) x_dist = TransformedDistribution(dist.normal, bijector) pyro.observe("obs0", x_dist, self.data[0], ng_zeros(1), ng_ones(1)) pyro.observe("obs1", x_dist, self.data[1], ng_zeros(1), ng_ones(1)) return mu_latent
def obs_inner(i, _i, _x):
    for k in range(n_superfluous_top):
        pyro.sample("z_%d_%d" % (i, k),
                    dist.Normal(ng_zeros(4 - i, 1), ng_ones(4 - i, 1),
                                reparameterized=False))
    pyro.observe("obs_%d" % i, dist.normal, _x, mu_latent,
                 torch.pow(self.lam, -0.5))
    for k in range(n_superfluous_top, n_superfluous_top + n_superfluous_bottom):
        pyro.sample("z_%d_%d" % (i, k),
                    dist.Normal(ng_zeros(4 - i, 1), ng_ones(4 - i, 1),
                                reparameterized=False))
def test_categorical_batch_log_pdf_shape(one_hot):
    ps = ng_ones(3, 2, 4) / 4
    if one_hot:
        x = ng_zeros(3, 2, 4)
        x[:, :, 0] = 1
    else:
        x = ng_zeros(3, 2, 1)
    d = dist.Categorical(ps, one_hot=one_hot)
    assert d.batch_log_pdf(x).size() == (3, 2, 1)
def model(self, x):
    u_mu = ng_zeros(self.U_size)
    u_sigma = ng_ones(self.U_size)
    U = pyro.sample('u', dist.normal, u_mu, u_sigma)
    v_mu = ng_zeros(self.V_size)
    v_sigma = ng_ones(self.V_size)
    V = pyro.sample('v', dist.normal, v_mu, v_sigma)
    pyro.observe('x', dist.bernoulli, x, torch.matmul(U, torch.t(V)),
                 # ng_ones(x.size(), type_as=x.data)
                 )
def model(self, data):
    decoder = pyro.module('decoder', self.vae_decoder)
    z_mean, z_std = ng_zeros([data.size(0), 20]), ng_ones([data.size(0), 20])
    z = pyro.sample('latent', Normal(z_mean, z_std))
    img = decoder.forward(z)
    pyro.sample('obs', Bernoulli(img), obs=data.view(-1, 784))
def model(self):
    # register PyTorch module `decoder` with Pyro
    pyro.module("decoder", self.decoder)
    # setup hyperparameters for prior p(z)
    z_mu = ng_zeros([self.n_samples, self.n_latent])
    z_sigma = ng_ones([self.n_samples, self.n_latent])
    # sample from prior
    z = pyro.sample("latent", dist.normal, z_mu, z_sigma)
    # decode the latent code z
    z_adj = self.decoder(z)
    # Subsampling
    if self.subsampling:
        with pyro.iarange("data", self.n_subsample, subsample=self.sample()) as ind:
            pyro.observe('obs', dist.bernoulli,
                         self.adj_labels.view(1, -1)[0][ind],
                         z_adj.view(1, -1)[0][ind])
    # Reweighting
    else:
        with pyro.iarange("data"):
            pyro.observe('obs', weighted_bernoulli,
                         self.adj_labels.view(1, -1),
                         z_adj.view(1, -1),
                         weight=self.pos_weight)
def setUp(self):
    # normal-normal; known covariance
    self.lam0 = Variable(torch.Tensor([0.1, 0.1]))  # precision of prior
    self.mu0 = Variable(torch.Tensor([0.0, 0.5]))  # prior mean
    # known precision of observation noise
    self.lam = Variable(torch.Tensor([6.0, 4.0]))
    self.n_outer = 3
    self.n_inner = 3
    self.n_data = Variable(torch.Tensor([self.n_outer * self.n_inner]))
    self.data = []
    self.sum_data = ng_zeros(2)
    for _out in range(self.n_outer):
        data_in = []
        for _in in range(self.n_inner):
            data_in.append(Variable(torch.Tensor([-0.1, 0.3]) +
                                    torch.randn(2) / torch.sqrt(self.lam.data)))
            self.sum_data += data_in[-1]
        self.data.append(data_in)
    self.analytic_lam_n = self.lam0 + self.n_data.expand_as(self.lam) * self.lam
    self.analytic_log_sig_n = -0.5 * torch.log(self.analytic_lam_n)
    self.analytic_mu_n = self.sum_data * (self.lam / self.analytic_lam_n) + \
        self.mu0 * (self.lam0 / self.analytic_lam_n)
def model(self, data):
    decoder = pyro.module('decoder', self.decoder)
    # Normal prior
    if self.prior == 0:
        z_mu = ng_zeros([data.size(0), self.z_dim])
        z_sigma = ng_ones([data.size(0), self.z_dim])
        if self.cuda_mode:
            z_mu, z_sigma = z_mu.cuda(), z_sigma.cuda()
        z = pyro.sample("latent", dist.normal, z_mu, z_sigma)
    elif self.prior == 1:
        z_mu, z_sigma = self.vampprior()
        z_mu_avg = torch.mean(z_mu, 0)
        z_sigma_square = z_sigma * z_sigma
        z_sigma_square_avg = torch.mean(z_sigma_square, 0)
        z_sigma_avg = torch.sqrt(z_sigma_square_avg)
        z_mu_avg = z_mu_avg.expand(data.size(0), z_mu_avg.size(0))
        z_sigma_avg = z_sigma_avg.expand(data.size(0), z_sigma_avg.size(0))
        z = pyro.sample("latent", dist.normal, z_mu_avg, z_sigma_avg)
    x1, x2, t, y = decoder.forward(
        z,
        data[:, :len(self.dataset.binfeats)],
        data[:, len(self.dataset.binfeats):(len(self.dataset.binfeats) + len(self.dataset.contfeats))],
        data[:, -2],
        data[:, -1])
def test_normal_shape():
    mu = ng_zeros(3, 2)
    sigma = ng_ones(3, 2)
    d = dist.Normal(mu, sigma)
    assert d.batch_shape() == (3,)
    assert d.event_shape() == (2,)
    assert d.shape() == (3, 2)
    assert d.sample().size() == d.shape()
def model(self, input_variable, target_variable, step):
    # register PyTorch modules `decoder_dense` and `decoder_rnn` with Pyro
    pyro.module("decoder_dense", self.decoder_dense)
    pyro.module("decoder_rnn", self.decoder_rnn)
    # setup hyperparameters for prior p(z)
    # the type_as ensures we get CUDA Tensors if x is on gpu
    z_mu = ng_zeros([self.num_layers, self.z_dim], type_as=target_variable.data)
    z_sigma = ng_ones([self.num_layers, self.z_dim], type_as=target_variable.data)
    # sample from prior
    # (value will be sampled by guide when computing the ELBO)
    z = pyro.sample("latent", dist.normal, z_mu, z_sigma)
    # init vars
    target_length = target_variable.shape[0]
    decoder_input = dataset.to_onehot([[self.dataset.SOS_index]])
    decoder_input = decoder_input.cuda() if USE_CUDA else decoder_input
    decoder_outputs = np.ones((target_length))
    decoder_hidden = self.decoder_dense(z)
    # Teacher forcing
    for di in range(target_length):
        decoder_output, decoder_hidden = self.decoder_rnn(decoder_input, decoder_hidden)
        decoder_input = target_variable[di]
        if self.use_cuda:
            decoder_outputs[di] = np.argmax(decoder_output.cpu().data.numpy())
        else:
            decoder_outputs[di] = np.argmax(decoder_output.data.numpy())
        pyro.observe("obs_{}".format(di), dist.bernoulli,
                     target_variable[di], decoder_output[0])
    # ----------------------------------------------------------------
    # prepare offer
    if self.use_cuda:
        offer = np.argmax(input_variable.cpu().data.numpy(), axis=1).astype(int)
    else:
        offer = np.argmax(input_variable.data.numpy(), axis=1).astype(int)
    # prepare answer
    if self.use_cuda:
        answer = np.argmax(target_variable.cpu().data.numpy(), axis=1).astype(int)
    else:
        answer = np.argmax(target_variable.data.numpy(), axis=1).astype(int)
    # prepare rnn
    rnn_response = list(map(int, decoder_outputs))
    # print output
    if step % 10 == 0:
        print("---------------------------")
        print("Offer: ", dataset.to_phrase(offer))
        print("Answer:", self.dataset.to_phrase(answer))
        print("RNN:", self.dataset.to_phrase(rnn_response))
def reverse_sequences_torch(mini_batch, seq_lengths):
    reversed_mini_batch = ng_zeros(mini_batch.size(), type_as=mini_batch.data)
    for b in range(mini_batch.size(0)):
        T = seq_lengths[b]
        time_slice = np.arange(T - 1, -1, -1)
        time_slice = Variable(torch.cuda.LongTensor(time_slice)) if 'cuda' in mini_batch.data.type() \
            else Variable(torch.LongTensor(time_slice))
        reversed_sequence = torch.index_select(mini_batch[b, :, :], 0, time_slice)
        reversed_mini_batch[b, 0:T, :] = reversed_sequence
    return reversed_mini_batch
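# Hypothetical usage sketch for reverse_sequences_torch above (not from the original
# source): batch size, feature size, and sequence lengths are made-up values, and the
# old Variable-based PyTorch API used throughout these snippets is assumed.
mini_batch = Variable(torch.randn(2, 3, 5))  # (batch, max_T, features)
seq_lengths = [3, 2]
reversed_batch = reverse_sequences_torch(mini_batch, seq_lengths)
# reversed_batch[b, 0:T, :] now holds sequence b in reversed time order;
# positions beyond each sequence's length remain zero-padded.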
def model(self, x, y):
    pyro.module('decoder', self.decoder)
    i, y = y
    U_mu = ng_zeros([self.u_dim, self.z_dim])
    U_sigma = ng_ones([self.u_dim, self.z_dim])
    U = pyro.sample('U', dist.normal, U_mu, U_sigma)
    V_mu = ng_zeros([self.v_dim, self.z_dim])
    V_sigma = ng_ones([self.v_dim, self.z_dim])
    V = pyro.sample('V', dist.normal, V_mu, V_sigma)
    z_sigma = ng_ones([x.size(0), self.z_dim], type_as=x.data)
    z = pyro.sample('latent', dist.normal, U[i, :] * V[y, :], z_sigma)
    img_mu = self.decoder(z)
    pyro.sample('obs', dist.bernoulli, img_mu, obs=x.view(-1, 784))
def model(self, x):
    pyro.module('decoder', self.decoder)
    z_mu = ng_zeros([x.size(0), self.z_dim], type_as=x.data)
    z_sigma = ng_ones([x.size(0), self.z_dim], type_as=x.data)
    z = pyro.sample('latent', dist.normal, z_mu, z_sigma)
    img_mu = self.decoder(z)
    pyro.sample('obs', dist.bernoulli, img_mu, obs=x.view(-1, 784))
def expand_z_where(z_where):
    # Takes a batch of three-vectors and massages them into a batch of
    # 2x3 matrices with elements like so:
    # [s, x, y] -> [[s, 0, x],
    #               [0, s, y]]
    n = z_where.size(0)
    out = torch.cat((ng_zeros([1, 1]).type_as(z_where).expand(n, 1), z_where), 1)
    ix = Variable(expansion_indices)
    if z_where.is_cuda:
        ix = ix.cuda()
    out = torch.index_select(out, 1, ix)
    out = out.view(n, 2, 3)
    return out
def model(self, x):
    # register PyTorch module `decoder` with Pyro
    pyro.module("decoder", self.decoder)
    # setup hyperparameters for prior p(z)
    # the type_as ensures we get cuda Tensors if x is on gpu
    z_mu = ng_zeros([x.size(0), self.z_dim], type_as=x.data)
    z_sigma = ng_ones([x.size(0), self.z_dim], type_as=x.data)
    # sample from prior (value will be sampled by guide when computing the ELBO)
    z = pyro.sample("latent", dist.normal, z_mu, z_sigma)
    # decode the latent code z
    mu_img = self.decoder.forward(z)
    # score against actual images
    pyro.observe("obs", dist.bernoulli, x.view(-1, 784), mu_img)
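# A minimal sketch (not from the original source) of the guide that would normally
# accompany the VAE model above, assuming an `encoder` module that maps a batch of
# images x to the parameters (z_mu, z_sigma) of the approximate posterior q(z|x).
def guide(self, x):
    # register PyTorch module `encoder` with Pyro
    pyro.module("encoder", self.encoder)
    # use the encoder to get the parameters of q(z|x)
    z_mu, z_sigma = self.encoder.forward(x)
    # sample the latent code z from q(z|x)
    pyro.sample("latent", dist.normal, z_mu, z_sigma)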
def model(self, x):
    # register decoder with pyro
    pyro.module("decoder", self.decoder)
    # setup hyper-parameters for prior p(z)
    z_mu = ng_zeros([x.size(0), self.z_dim], type_as=x.data)
    z_sigma = ng_ones([x.size(0), self.z_dim], type_as=x.data)
    # sample from prior (value will be sampled by guide when computing the ELBO),
    # decode the latent code z, and score against actual frame
    z = pyro.sample("latent", dist.normal, z_mu, z_sigma)
    mu_frame = self.decoder.forward(z)
    pyro.observe("obs", dist.bernoulli, x.view(-1, self.frame), mu_frame)
def model(batch_size_outer=2, batch_size_inner=2):
    mu_latent = pyro.sample("mu_latent", dist.normal, ng_zeros(1), ng_ones(1))

    def outer(i, x):
        pyro.map_data("map_inner_%d" % i, x, lambda _i, _x: inner(i, _i, _x),
                      batch_size=batch_size_inner)

    def inner(i, _i, _x):
        pyro.sample("z_%d_%d" % (i, _i), dist.normal, mu_latent + _x, ng_ones(1))

    pyro.map_data("map_outer", [[ng_ones(1)] * 2] * 2, lambda i, x: outer(i, x),
                  batch_size=batch_size_outer)
    return mu_latent
def test_do_propagation(self):
    pyro.clear_param_store()

    def model():
        z = pyro.sample("z", Normal(10.0 * ng_ones(1), 0.0001 * ng_ones(1)))
        latent_prob = torch.exp(z) / (torch.exp(z) + ng_ones(1))
        flip = pyro.sample("flip", Bernoulli(latent_prob))
        return flip

    sample_from_model = model()
    z_data = {"z": -10.0 * ng_ones(1)}
    # under model flip = 1 with high probability;
    # so do indirect DO surgery to make flip = 0
    sample_from_do_model = poutine.trace(poutine.do(model, data=z_data))()
    assert eq(sample_from_model, ng_ones(1))
    assert eq(sample_from_do_model, ng_zeros(1))
def model(self, x):
    raise NotImplementedError("don't use it plzz")
    pyro.module("decoder", self.decoders)
    # set up prior params
    z_mu = ng_zeros([x.size(0), self.platent["dim"]], type_as=x.data)
    z_sigma = ng_ones([x.size(0), self.platent["dim"]], type_as=x.data)
    z = pyro.sample("latent", self.platent["dist"], z_mu, z_sigma)
    decoder_out = self.decoder.forward(z)
    if type(decoder_out) != tuple:
        decoder_out = (decoder_out,)
    if self.pinput["dist"] == dist.normal:
        decoder_out = list(decoder_out)
        decoder_out[1] = torch.exp(decoder_out[1])
        decoder_out = tuple(decoder_out)
    pyro.sample("obs", self.pinput["dist"], *decoder_out,
                obs=x.view(-1, self.pinput["dim"]))
def _compute_elbo_non_reparam(guide_trace, guide_vec_md_nodes,  #
                              non_reparam_nodes, downstream_costs):
    # construct all the reinforce-like terms.
    # we include only downstream costs to reduce variance
    # optionally include baselines to further reduce variance
    # XXX should the average baseline be in the param store as below?
    surrogate_elbo = 0.0
    baseline_loss = 0.0
    for node in non_reparam_nodes:
        guide_site = guide_trace.nodes[node]
        log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
        downstream_cost = downstream_costs[node]
        baseline = 0.0
        (nn_baseline, nn_baseline_input, use_decaying_avg_baseline,
         baseline_beta, baseline_value) = _get_baseline_options(guide_site)
        use_nn_baseline = nn_baseline is not None
        use_baseline_value = baseline_value is not None
        assert(not (use_nn_baseline and use_baseline_value)), \
            "cannot use baseline_value and nn_baseline simultaneously"
        if use_decaying_avg_baseline:
            avg_downstream_cost_old = pyro.param("__baseline_avg_downstream_cost_" + node,
                                                 ng_zeros(1),
                                                 tags="__tracegraph_elbo_internal_tag")
            avg_downstream_cost_new = (1 - baseline_beta) * downstream_cost + \
                baseline_beta * avg_downstream_cost_old
            avg_downstream_cost_old.data = avg_downstream_cost_new.data  # XXX copy_() ?
            baseline += avg_downstream_cost_old
        if use_nn_baseline:
            # block nn_baseline_input gradients except in baseline loss
            baseline += nn_baseline(detach_iterable(nn_baseline_input))
        elif use_baseline_value:
            # it's on the user to make sure baseline_value tape only points to baseline params
            baseline += baseline_value
        if use_nn_baseline or use_baseline_value:
            # accumulate baseline loss
            baseline_loss += torch.pow(downstream_cost.detach() - baseline, 2.0).sum()
        guide_log_pdf = guide_site[log_pdf_key] / guide_site["scale"]  # not scaled by subsampling
        if use_nn_baseline or use_decaying_avg_baseline or use_baseline_value:
            if downstream_cost.size() != baseline.size():
                raise ValueError("Expected baseline at site {} to be {} instead got {}".format(
                    node, downstream_cost.size(), baseline.size()))
            downstream_cost = downstream_cost - baseline
        surrogate_elbo += (guide_log_pdf * downstream_cost.detach()).sum()
    return surrogate_elbo, baseline_loss
def test_subsample_gradient(trace_graph, reparameterized):
    pyro.clear_param_store()
    data_size = 2
    subsample_size = 1
    num_particles = 1000
    precision = 0.333
    data = dist.normal(ng_zeros(data_size), ng_ones(data_size))

    def model(subsample_size):
        with pyro.iarange("data", len(data), subsample_size) as ind:
            x = data[ind]
            z = pyro.sample("z", dist.Normal(ng_zeros(len(x)), ng_ones(len(x)),
                                             reparameterized=reparameterized))
            pyro.observe("x", dist.Normal(z, ng_ones(len(x)),
                                          reparameterized=reparameterized), x)

    def guide(subsample_size):
        mu = pyro.param("mu", lambda: Variable(torch.zeros(len(data)), requires_grad=True))
        sigma = pyro.param("sigma", lambda: Variable(torch.ones(1), requires_grad=True))
        with pyro.iarange("data", len(data), subsample_size) as ind:
            mu = mu[ind]
            sigma = sigma.expand(subsample_size)
            pyro.sample("z", dist.Normal(mu, sigma, reparameterized=reparameterized))

    optim = Adam({"lr": 0.1})
    inference = SVI(model, guide, optim, loss="ELBO",
                    trace_graph=trace_graph, num_particles=num_particles)

    # Compute gradients without subsampling.
    inference.loss_and_grads(model, guide, subsample_size=data_size)
    params = dict(pyro.get_param_store().named_parameters())
    expected_grads = {name: param.grad.data.clone() for name, param in params.items()}
    zero_grads(params.values())

    # Compute gradients with subsampling.
    inference.loss_and_grads(model, guide, subsample_size=subsample_size)
    actual_grads = {name: param.grad.data.clone() for name, param in params.items()}

    for name in sorted(params):
        print('\nexpected {} = {}'.format(name, expected_grads[name].cpu().numpy()))
        print('actual {} = {}'.format(name, actual_grads[name].cpu().numpy()))
    assert_equal(actual_grads, expected_grads, prec=precision)
def setUp(self):
    # normal-normal; known covariance
    self.lam0 = Variable(torch.Tensor([0.1, 0.1]))  # precision of prior
    self.mu0 = Variable(torch.Tensor([0.0, 0.5]))  # prior mean
    # known precision of observation noise
    self.lam = Variable(torch.Tensor([6.0, 4.0]))
    self.n_outer = 3
    self.n_inner = 3
    self.n_data = Variable(torch.Tensor([self.n_outer * self.n_inner]))
    self.data = []
    self.sum_data = ng_zeros(2)
    for _out in range(self.n_outer):
        data_in = []
        for _in in range(self.n_inner):
            data_in.append(Variable(torch.Tensor([-0.1, 0.3]) +
                                    torch.randn(2) / torch.sqrt(self.lam.data)))
            self.sum_data += data_in[-1]
        self.data.append(data_in)
    self.analytic_lam_n = self.lam0 + self.n_data.expand_as(self.lam) * self.lam
    self.analytic_log_sig_n = -0.5 * torch.log(self.analytic_lam_n)
    self.analytic_mu_n = self.sum_data * (self.lam / self.analytic_lam_n) + \
        self.mu0 * (self.lam0 / self.analytic_lam_n)
    self.verbose = True
def test_one_hot_categorical_batch_log_pdf_shape():
    ps = ng_ones(3, 2, 4) / 4
    x = ng_zeros(3, 2, 4)
    x[:, :, 0] = 1
    d = dist.OneHotCategorical(ps)
    assert d.batch_log_pdf(x).size() == (3, 2, 1)
def test_normal_batch_log_pdf_shape():
    mu = ng_zeros(3, 2)
    sigma = ng_ones(3, 2)
    x = ng_zeros(3, 2)
    d = dist.Normal(mu, sigma)
    assert d.batch_log_pdf(x).size() == (3, 1)
def test_categorical_batch_log_pdf_shape():
    ps = ng_ones(3, 2, 4) / 4
    x = ng_zeros(3, 2, 1)
    d = dist.Categorical(ps)
    assert d.batch_log_pdf(x).size() == (3, 2, 1)
def model(subsample_size):
    with pyro.iarange("data", len(data), subsample_size) as ind:
        x = data[ind]
        z = pyro.sample("z", dist.Normal(ng_zeros(len(x)), ng_ones(len(x)),
                                         reparameterized=reparameterized))
        pyro.observe("x", dist.Normal(z, ng_ones(len(x)),
                                      reparameterized=reparameterized), x)
def model_dup():
    pyro.param("mu_q", Variable(torch.ones(1), requires_grad=True))
    pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))
def guide():
    p = pyro.param("p", Variable(torch.ones(1), requires_grad=True))
    pyro.sample("mu_q", dist.normal, ng_zeros(1), p)
    pyro.sample("mu_q_2", dist.normal, ng_zeros(1), p)
def model():
    mu_latent = pyro.sample("mu_latent",
                            dist.Normal(ng_zeros(1), ng_ones(1),
                                        reparameterized=reparameterized))
    pyro.observe('obs', dist.normal, Variable(torch.Tensor([0.23])),
                 mu_latent, ng_ones(1))
    return mu_latent
def model_obs_dup():
    pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))
    pyro.observe("mu_q", dist.normal, ng_zeros(1), ng_ones(1), ng_zeros(1))
def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
    # get info regarding rao-blackwellization of vectorized map_data
    guide_vec_md_info = guide_trace.graph["vectorized_map_data_info"]
    model_vec_md_info = model_trace.graph["vectorized_map_data_info"]
    guide_vec_md_condition = guide_vec_md_info['rao-blackwellization-condition']
    model_vec_md_condition = model_vec_md_info['rao-blackwellization-condition']
    do_vec_rb = guide_vec_md_condition and model_vec_md_condition
    if not do_vec_rb:
        warnings.warn(
            "Unable to do fully-vectorized Rao-Blackwellization in TraceGraph_ELBO. "
            "Falling back to higher-variance gradient estimator. "
            "Try to avoid these issues in your model and guide:\n{}".format("\n".join(
                guide_vec_md_info["warnings"] | model_vec_md_info["warnings"])))
    guide_vec_md_nodes = guide_vec_md_info['nodes'] if do_vec_rb else set()
    model_vec_md_nodes = model_vec_md_info['nodes'] if do_vec_rb else set()

    # have the trace compute all the individual (batch) log pdf terms
    # so that they are available below
    guide_trace.compute_batch_log_pdf(site_filter=lambda name, site: name in guide_vec_md_nodes)
    guide_trace.log_pdf()
    model_trace.compute_batch_log_pdf(site_filter=lambda name, site: name in model_vec_md_nodes)
    model_trace.log_pdf()

    # prepare a list of all the cost nodes, each of which is +- log_pdf
    cost_nodes = []
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    for name, model_site in model_trace.nodes.items():
        if model_site["type"] == "sample":
            if model_site["is_observed"]:
                cost_nodes.append(CostNode(model_site["log_pdf"], True))
            else:
                # cost node from model sample
                cost_nodes.append(CostNode(model_site["log_pdf"], True))
                # cost node from guide sample
                guide_site = guide_trace.nodes[name]
                zero_expectation = name in non_reparam_nodes
                cost_nodes.append(CostNode(-guide_site["log_pdf"], not zero_expectation))

    # compute the elbo; if all stochastic nodes are reparameterizable, we're done
    # this bit is never differentiated: it's here for getting an estimate of the elbo itself
    elbo = torch_data_sum(sum(c.cost for c in cost_nodes))

    # compute the surrogate elbo, removing terms whose gradient is zero
    # this is the bit that's actually differentiated
    # XXX should the user be able to control if these terms are included?
    surrogate_elbo = sum(c.cost for c in cost_nodes if c.nonzero_expectation)

    # the following computations are only necessary if we have non-reparameterizable nodes
    baseline_loss = 0.0
    if non_reparam_nodes:
        # recursively compute downstream cost nodes for all sample sites in model and guide
        # (even though ultimately just need for non-reparameterizable sample sites)
        # 1. downstream costs used for rao-blackwellization
        # 2. model observe sites (as well as terms that arise from the model and guide having
        #    different dependency structures) are taken care of via 'children_in_model' below
        topo_sort_guide_nodes = list(reversed(list(networkx.topological_sort(guide_trace))))
        topo_sort_guide_nodes = [x for x in topo_sort_guide_nodes
                                 if guide_trace.nodes[x]["type"] == "sample"]
        downstream_guide_cost_nodes = {}
        downstream_costs = {}

        for node in topo_sort_guide_nodes:
            node_log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
            downstream_costs[node] = model_trace.nodes[node][node_log_pdf_key] - \
                guide_trace.nodes[node][node_log_pdf_key]
            nodes_included_in_sum = set([node])
            downstream_guide_cost_nodes[node] = set([node])
            for child in guide_trace.successors(node):
                child_cost_nodes = downstream_guide_cost_nodes[child]
                downstream_guide_cost_nodes[node].update(child_cost_nodes)
                if nodes_included_in_sum.isdisjoint(child_cost_nodes):  # avoid duplicates
                    if node_log_pdf_key == 'log_pdf':
                        downstream_costs[node] += downstream_costs[child].sum()
                    else:
                        downstream_costs[node] += downstream_costs[child]
                    nodes_included_in_sum.update(child_cost_nodes)
            missing_downstream_costs = downstream_guide_cost_nodes[node] - nodes_included_in_sum
            # include terms we missed because we had to avoid duplicates
            for missing_node in missing_downstream_costs:
                mn_log_pdf_key = 'batch_log_pdf' if missing_node in guide_vec_md_nodes else 'log_pdf'
                if node_log_pdf_key == 'log_pdf':
                    downstream_costs[node] += (model_trace.nodes[missing_node][mn_log_pdf_key] -
                                               guide_trace.nodes[missing_node][mn_log_pdf_key]).sum()
                else:
                    downstream_costs[node] += model_trace.nodes[missing_node][mn_log_pdf_key] - \
                        guide_trace.nodes[missing_node][mn_log_pdf_key]

        # finish assembling complete downstream costs
        # (the above computation may be missing terms from model)
        # XXX can we cache some of the sums over children_in_model to make things more efficient?
        for site in non_reparam_nodes:
            children_in_model = set()
            for node in downstream_guide_cost_nodes[site]:
                children_in_model.update(model_trace.successors(node))
            # remove terms accounted for above
            children_in_model.difference_update(downstream_guide_cost_nodes[site])
            for child in children_in_model:
                child_log_pdf_key = 'batch_log_pdf' if child in model_vec_md_nodes else 'log_pdf'
                site_log_pdf_key = 'batch_log_pdf' if site in guide_vec_md_nodes else 'log_pdf'
                assert (model_trace.nodes[child]["type"] == "sample")
                if site_log_pdf_key == 'log_pdf':
                    downstream_costs[site] += model_trace.nodes[child][child_log_pdf_key].sum()
                else:
                    downstream_costs[site] += model_trace.nodes[child][child_log_pdf_key]

        # construct all the reinforce-like terms.
        # we include only downstream costs to reduce variance
        # optionally include baselines to further reduce variance
        # XXX should the average baseline be in the param store as below?
        elbo_reinforce_terms = 0.0
        for node in non_reparam_nodes:
            guide_site = guide_trace.nodes[node]
            log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
            downstream_cost = downstream_costs[node]
            baseline = 0.0
            (nn_baseline, nn_baseline_input, use_decaying_avg_baseline,
             baseline_beta, baseline_value) = _get_baseline_options(guide_site)
            use_nn_baseline = nn_baseline is not None
            use_baseline_value = baseline_value is not None
            assert(not (use_nn_baseline and use_baseline_value)), \
                "cannot use baseline_value and nn_baseline simultaneously"
            if use_decaying_avg_baseline:
                avg_downstream_cost_old = pyro.param("__baseline_avg_downstream_cost_" + node,
                                                     ng_zeros(1),
                                                     tags="__tracegraph_elbo_internal_tag")
                avg_downstream_cost_new = (1 - baseline_beta) * downstream_cost + \
                    baseline_beta * avg_downstream_cost_old
                avg_downstream_cost_old.data = avg_downstream_cost_new.data  # XXX copy_() ?
                baseline += avg_downstream_cost_old
            if use_nn_baseline:
                # block nn_baseline_input gradients except in baseline loss
                baseline += nn_baseline(detach_iterable(nn_baseline_input))
            elif use_baseline_value:
                # it's on the user to make sure baseline_value tape only points to baseline params
                baseline += baseline_value
            if use_nn_baseline or use_baseline_value:
                # accumulate baseline loss
                baseline_loss += torch.pow(downstream_cost.detach() - baseline, 2.0).sum()
            guide_log_pdf = guide_site[log_pdf_key] / guide_site["scale"]  # not scaled by subsampling
            if use_nn_baseline or use_decaying_avg_baseline or use_baseline_value:
                if downstream_cost.size() != baseline.size():
                    raise ValueError("Expected baseline at site {} to be {} instead got {}".format(
                        node, downstream_cost.size(), baseline.size()))
                downstream_cost = downstream_cost - baseline
            elbo_reinforce_terms += (guide_log_pdf * downstream_cost.detach()).sum()

        surrogate_elbo += elbo_reinforce_terms

    # collect parameters to train from model and guide
    trainable_params = set(site["value"]
                           for trace in (model_trace, guide_trace)
                           for site in trace.nodes.values()
                           if site["type"] == "param")

    if trainable_params:
        surrogate_loss = -surrogate_elbo
        torch_backward(weight * (surrogate_loss + baseline_loss))
        pyro.get_param_store().mark_params_active(trainable_params)

    loss = -elbo
    return weight * loss
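# Hypothetical guide-side sketch (not from the original source) showing how the
# baseline options consumed by _get_baseline_options above could be attached to a
# non-reparameterized sample site in this version of Pyro; the site name, the
# distribution parameters, and the `baseline` keyword route into the site dict
# (suggested by the guide_trace.nodes[node]["baseline"] lookup) are assumptions.
z = pyro.sample("z_0",
                dist.Normal(ng_zeros(1), ng_ones(1), reparameterized=False),
                baseline={"use_decaying_avg_baseline": True,
                          "baseline_beta": 0.95})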
def ng_zeros(self, *args, **kwargs):
    t = ng_zeros(*args, **kwargs)
    if self.use_cuda:
        t = t.cuda()
    return t
def model(data):
    latent = named.Object("latent")
    latent.z.sample_(dist.normal, ng_zeros(1), ng_ones(1))
    model_recurse(data, latent)
def loss_and_grads(self, model, guide, *args, **kwargs):
    """
    :returns: returns an estimate of the ELBO
    :rtype: float

    Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator.
    Performs backward on the latter. Num_particle many samples are used to form the estimators.
    If baselines are present, a baseline loss is also constructed and differentiated.
    """
    elbo = 0.0
    for weight, model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
        # get info regarding rao-blackwellization of vectorized map_data
        guide_vec_md_info = guide_trace.graph["vectorized_map_data_info"]
        model_vec_md_info = model_trace.graph["vectorized_map_data_info"]
        guide_vec_md_condition = guide_vec_md_info['rao-blackwellization-condition']
        model_vec_md_condition = model_vec_md_info['rao-blackwellization-condition']
        do_vec_rb = guide_vec_md_condition and model_vec_md_condition
        if not do_vec_rb:
            warnings.warn(
                "Unable to do fully-vectorized Rao-Blackwellization in TraceGraph_ELBO. "
                "Falling back to higher-variance gradient estimator. "
                "Try to avoid these issues in your model and guide:\n{}".format("\n".join(
                    guide_vec_md_info["warnings"] | model_vec_md_info["warnings"])))
        guide_vec_md_nodes = guide_vec_md_info['nodes'] if do_vec_rb else set()
        model_vec_md_nodes = model_vec_md_info['nodes'] if do_vec_rb else set()

        # have the trace compute all the individual (batch) log pdf terms
        # so that they are available below
        guide_trace.compute_batch_log_pdf(
            site_filter=lambda name, site: name in guide_vec_md_nodes)
        guide_trace.log_pdf()
        model_trace.compute_batch_log_pdf(
            site_filter=lambda name, site: name in model_vec_md_nodes)
        model_trace.log_pdf()

        # prepare a list of all the cost nodes, each of which is +- log_pdf
        cost_nodes = []
        non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
        for site in model_trace.nodes.keys():
            model_trace_site = model_trace.nodes[site]
            log_pdf_key = 'batch_log_pdf' if site in model_vec_md_nodes else 'log_pdf'
            if model_trace_site["type"] == "sample":
                if model_trace_site["is_observed"]:
                    cost_node = (model_trace_site[log_pdf_key], True)
                    cost_nodes.append(cost_node)
                else:
                    # cost node from model sample
                    cost_node1 = (model_trace_site[log_pdf_key], True)
                    # cost node from guide sample
                    zero_expectation = site in non_reparam_nodes
                    cost_node2 = (-guide_trace.nodes[site][log_pdf_key],
                                  not zero_expectation)
                    cost_nodes.extend([cost_node1, cost_node2])

        elbo_particle = 0.0
        surrogate_elbo_particle = 0.0
        baseline_loss_particle = 0.0
        elbo_reinforce_terms_particle = 0.0
        elbo_no_zero_expectation_terms_particle = 0.0

        # compute the elbo; if all stochastic nodes are reparameterizable, we're done
        # this bit is never differentiated: it's here for getting an estimate of the elbo itself
        for cost_node in cost_nodes:
            elbo_particle += cost_node[0].sum()
        elbo += weight * torch_data_sum(elbo_particle)

        # compute the elbo, removing terms whose gradient is zero
        # this is the bit that's actually differentiated
        # XXX should the user be able to control if these terms are included?
        for cost_node in cost_nodes:
            if cost_node[1]:
                elbo_no_zero_expectation_terms_particle += cost_node[0].sum()
        surrogate_elbo_particle += weight * elbo_no_zero_expectation_terms_particle

        # the following computations are only necessary if we have non-reparameterizable nodes
        if len(non_reparam_nodes) > 0:
            # recursively compute downstream cost nodes for all sample sites in model and guide
            # (even though ultimately just need for non-reparameterizable sample sites)
            # 1. downstream costs used for rao-blackwellization
            # 2. model observe sites (as well as terms that arise from the model and guide having
            #    different dependency structures) are taken care of via 'children_in_model' below
            topo_sort_guide_nodes = list(reversed(list(networkx.topological_sort(guide_trace))))
            topo_sort_guide_nodes = [x for x in topo_sort_guide_nodes
                                     if guide_trace.nodes[x]["type"] == "sample"]
            downstream_guide_cost_nodes = {}
            downstream_costs = {}

            for node in topo_sort_guide_nodes:
                node_log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
                downstream_costs[node] = model_trace.nodes[node][node_log_pdf_key] - \
                    guide_trace.nodes[node][node_log_pdf_key]
                nodes_included_in_sum = set([node])
                downstream_guide_cost_nodes[node] = set([node])
                for child in guide_trace.successors(node):
                    child_cost_nodes = downstream_guide_cost_nodes[child]
                    downstream_guide_cost_nodes[node].update(child_cost_nodes)
                    if nodes_included_in_sum.isdisjoint(child_cost_nodes):  # avoid duplicates
                        if node_log_pdf_key == 'log_pdf':
                            downstream_costs[node] += downstream_costs[child].sum()
                        else:
                            downstream_costs[node] += downstream_costs[child]
                        nodes_included_in_sum.update(child_cost_nodes)
                missing_downstream_costs = downstream_guide_cost_nodes[node] - nodes_included_in_sum
                # include terms we missed because we had to avoid duplicates
                for missing_node in missing_downstream_costs:
                    mn_log_pdf_key = 'batch_log_pdf' if missing_node in guide_vec_md_nodes else 'log_pdf'
                    if node_log_pdf_key == 'log_pdf':
                        downstream_costs[node] += (
                            model_trace.nodes[missing_node][mn_log_pdf_key] -
                            guide_trace.nodes[missing_node][mn_log_pdf_key]).sum()
                    else:
                        downstream_costs[node] += model_trace.nodes[missing_node][mn_log_pdf_key] - \
                            guide_trace.nodes[missing_node][mn_log_pdf_key]

            # finish assembling complete downstream costs
            # (the above computation may be missing terms from model)
            # XXX can we cache some of the sums over children_in_model to make things more efficient?
            for site in non_reparam_nodes:
                children_in_model = set()
                for node in downstream_guide_cost_nodes[site]:
                    children_in_model.update(model_trace.successors(node))
                # remove terms accounted for above
                children_in_model.difference_update(downstream_guide_cost_nodes[site])
                for child in children_in_model:
                    child_log_pdf_key = 'batch_log_pdf' if child in model_vec_md_nodes else 'log_pdf'
                    site_log_pdf_key = 'batch_log_pdf' if site in guide_vec_md_nodes else 'log_pdf'
                    assert (model_trace.nodes[child]["type"] == "sample")
                    if site_log_pdf_key == 'log_pdf':
                        downstream_costs[site] += model_trace.nodes[child][child_log_pdf_key].sum()
                    else:
                        downstream_costs[site] += model_trace.nodes[child][child_log_pdf_key]

            # construct all the reinforce-like terms.
            # we include only downstream costs to reduce variance
            # optionally include baselines to further reduce variance
            # XXX should the average baseline be in the param store as below?

            # for extracting baseline options from site["baseline"]
            # XXX default for baseline_beta currently set here
            def get_baseline_options(site_baseline):
                options_dict = site_baseline.copy()
                options_tuple = (options_dict.pop('nn_baseline', None),
                                 options_dict.pop('nn_baseline_input', None),
                                 options_dict.pop('use_decaying_avg_baseline', False),
                                 options_dict.pop('baseline_beta', 0.90),
                                 options_dict.pop('baseline_value', None))
                if options_dict:
                    raise ValueError("Unrecognized baseline options: {}".format(options_dict.keys()))
                return options_tuple

            baseline_loss_particle = 0.0
            for node in non_reparam_nodes:
                log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
                downstream_cost = downstream_costs[node]
                baseline = 0.0
                (nn_baseline, nn_baseline_input, use_decaying_avg_baseline,
                 baseline_beta, baseline_value) = get_baseline_options(
                    guide_trace.nodes[node]["baseline"])
                use_nn_baseline = nn_baseline is not None
                use_baseline_value = baseline_value is not None
                assert(not (use_nn_baseline and use_baseline_value)), \
                    "cannot use baseline_value and nn_baseline simultaneously"
                if use_decaying_avg_baseline:
                    avg_downstream_cost_old = pyro.param(
                        "__baseline_avg_downstream_cost_" + node, ng_zeros(1),
                        tags="__tracegraph_elbo_internal_tag")
                    avg_downstream_cost_new = (1 - baseline_beta) * downstream_cost + \
                        baseline_beta * avg_downstream_cost_old
                    avg_downstream_cost_old.data = avg_downstream_cost_new.data  # XXX copy_() ?
                    baseline += avg_downstream_cost_old
                if use_nn_baseline:
                    # block nn_baseline_input gradients except in baseline loss
                    baseline += nn_baseline(detach_iterable(nn_baseline_input))
                elif use_baseline_value:
                    # it's on the user to make sure baseline_value tape only points to baseline params
                    baseline += baseline_value
                if use_nn_baseline or use_baseline_value:
                    # construct baseline loss
                    baseline_loss = torch.pow(downstream_cost.detach() - baseline, 2.0).sum()
                    baseline_loss_particle += weight * baseline_loss
                if use_nn_baseline or use_decaying_avg_baseline or use_baseline_value:
                    if downstream_cost.size() != baseline.size():
                        raise ValueError(
                            "Expected baseline at site {} to be {} instead got {}".format(
                                node, downstream_cost.size(), baseline.size()))
                    elbo_reinforce_terms_particle += (
                        guide_trace.nodes[node][log_pdf_key] *
                        (downstream_cost - baseline).detach()).sum()
                else:
                    elbo_reinforce_terms_particle += (
                        guide_trace.nodes[node][log_pdf_key] *
                        downstream_cost.detach()).sum()

            surrogate_elbo_particle += weight * elbo_reinforce_terms_particle
            torch_backward(baseline_loss_particle)

        # collect parameters to train from model and guide
        trainable_params = set(site["value"]
                               for trace in (model_trace, guide_trace)
                               for site in trace.nodes.values()
                               if site["type"] == "param")

        surrogate_loss_particle = -surrogate_elbo_particle
        if trainable_params:
            torch_backward(surrogate_loss_particle)
            # mark all params seen in trace as active so that gradient steps are taken downstream
            pyro.get_param_store().mark_params_active(trainable_params)

    loss = -elbo
    return loss
def model():
    pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))
def model(self, data):
    decoder = pyro.module('decoder', self.vae_decoder)
    z_mean, z_std = ng_zeros([data.size(0), 20]), ng_ones([data.size(0), 20])
    z = pyro.sample('latent', Normal(z_mean, z_std))
    img = decoder.forward(z)
    pyro.observe('obs', Bernoulli(img), data.view(-1, 784))
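# Hypothetical training-loop sketch (not from the original source) for a model/guide
# pair like the ones above, using the SVI/Adam interface that appears in
# test_subsample_gradient earlier in this section; `vae`, `train_loader`, and
# `num_epochs` are assumed names.
optim = Adam({"lr": 1.0e-3})
svi = SVI(vae.model, vae.guide, optim, loss="ELBO")
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch in train_loader:
        # one gradient step on the negative ELBO for this mini-batch
        epoch_loss += svi.step(batch)
    print("epoch {}: average loss {}".format(epoch, epoch_loss / len(train_loader)))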