def _dist_and_values(self):
    """
    Build a categorical distribution over the distinct values recorded at
    ``self.sites`` across the execution traces of ``self.trace_dist``.

    Traces whose site values hash identically are merged: their log weights
    are combined with ``logsumexp`` so the merged entry carries their total
    weight.

    :returns: a tuple ``(d, values_map)`` where ``d`` is a
        ``dist.Categorical`` over the distinct values (with normalized
        logits) and ``values_map`` is an ``OrderedDict`` mapping each value's
        hash to the value itself, in the same order as the categories of ``d``.
    """
    # XXX currently this whole object is very inefficient
    # NOTE(review): values are keyed by hash only, so two distinct values with
    # colliding hashes (or tensors with equal bytes but different shape/dtype)
    # would be merged. Callers may rely on these keys, so they are kept as-is.
    values_map, logits = collections.OrderedDict(), collections.OrderedDict()
    for tr, logit in zip(self.trace_dist.exec_traces,
                         self.trace_dist.log_weights):
        if isinstance(self.sites, str):
            value = tr.nodes[self.sites]["value"]
        else:
            value = {site: tr.nodes[site]["value"] for site in self.sites}
        if not torch.is_tensor(logit):
            logit = torch.tensor(logit)

        if torch.is_tensor(value):
            # .detach() is required because Tensor.numpy() raises on tensors
            # that require grad. It does not change the underlying bytes, so
            # the resulting hash is identical to the non-grad case.
            value_hash = hash(value.detach().cpu().contiguous().numpy().tobytes())
        elif isinstance(value, dict):
            value_hash = hash(self._dict_to_tuple(value))
        else:
            value_hash = hash(value)

        if value_hash in logits:
            # Value has already been seen: accumulate its weight in log space.
            logits[value_hash] = logsumexp(torch.stack(
                [logits[value_hash], logit]), dim=-1)
        else:
            logits[value_hash] = logit
            values_map[value_hash] = value

    logits = torch.stack(list(logits.values())).contiguous().view(-1)
    # Normalize so the category probabilities sum to one.
    logits = logits - logsumexp(logits, dim=-1)
    d = dist.Categorical(logits=logits)
    return d, values_map
def test_ubersum_5(impl):
    #   z {ij}  <--- target
    #     |
    #   y {i}
    #     |
    #   x {}
    i, j, a, b, c = 2, 3, 6, 5, 4
    x = torch.randn(a)
    y = torch.randn(a, b, i)
    z = torch.randn(b, c, i, j)
    (actual,) = impl('a,abi,bcij->cij', x, y, z, plates='ij', modulo_total=True)

    # Eliminate plate j: take the log-normalizer of z over dim c, send its
    # plate-j product down as a message, and keep the normalized remainder.
    z_lse = logsumexp(z, 1)
    assert z_lse.shape == (b, i, j)
    z_msg = z_lse.sum(2)
    assert z_msg.shape == (b, i)
    z_norm = z - z_lse.unsqueeze(-3)
    assert z_norm.shape == (b, c, i, j)

    # Eliminate plate i: same treatment for y combined with the message from z.
    yz = y + z_msg
    assert yz.shape == (a, b, i)
    yz_lse = logsumexp(yz, 1)
    assert yz_lse.shape == (a, i)
    yz_msg = yz_lse.sum(1)
    assert yz_msg.shape == (a,)
    yz_norm = yz - yz_lse.unsqueeze(-2)
    assert yz_norm.shape == (a, b, i)

    expected = opt_einsum.contract('a,a,abi,bcij->cij', x, yz_msg, yz_norm, z_norm,
                                   backend='pyro.ops.einsum.torch_log')
    assert_equal(actual, expected)
def test_ubersum_1(impl):
    #  y {a}   z {b}
    #      \   /
    #      x {}  <--- target
    a, b, c, d, e = 2, 3, 4, 5, 6
    x = torch.randn(c)
    y = torch.randn(c, d, a)
    z = torch.randn(e, c, b)
    (actual,) = impl('c,cda,ecb->', x, y, z, plates='ab', modulo_total=True)

    # Marginalize the non-plate dims (d for y, e for z), then take the
    # product over each plate (a sum in log space).
    y_msg = logsumexp(y, -2).sum(-1)
    z_msg = logsumexp(z, -3).sum(-1)
    expected = logsumexp(x + y_msg + z_msg, -1)
    assert_equal(actual, expected)
def test_ubersum_2(impl):
    #  y {a}   z {b}  <--- target
    #      \   /
    #      x {}
    a, b, c, d, e = 2, 3, 4, 5, 6
    x = torch.randn(c)
    y = torch.randn(c, d, a)
    z = torch.randn(e, c, b)
    # Pass the plate dims via ``plates=``, the keyword every other ubersum
    # test in this file uses with the same ``impl`` fixture.
    actual, = impl('c,cda,ecb->b', x, y, z, plates='ab', modulo_total=True)
    xyz = logsumexp(x + logsumexp(y, -2).sum(-1) + logsumexp(z, -3).sum(-1), -1)
    # The scalar result is broadcast along the target plate b.
    expected = xyz.expand(b)
    assert_equal(actual, expected)
def test_ubersum_3(impl):
    #      z {b,c}
    #        |
    # w {a}  y {b}  <--- target
    #     \  /
    #     x {}
    a, b, c, d, e = 2, 3, 4, 5, 6
    w = torch.randn(a, e)
    x = torch.randn(d)
    y = torch.randn(b, d)
    z = torch.randn(b, c, d, e)
    (actual, ) = impl("ae,d,bd,bcde->be", w, x, y, z, plates="abc", modulo_total=True)

    z_no_c = z.sum(-3)  # eliminate c
    yz = y.reshape(b, d, 1) + z_no_c
    assert yz.shape == (b, d, e)
    yz = yz.sum(0)  # eliminate b
    assert yz.shape == (d, e)
    w_no_a = w.sum(0)  # eliminate a
    wxyz = w_no_a + x.reshape(d, 1) + yz
    assert wxyz.shape == (d, e)
    marginal = logsumexp(wxyz, 0)  # eliminate d
    assert marginal.shape == (e, )
    expected = marginal.expand(b, e)  # broadcast to b
    assert_equal(actual, expected)
def test_ubersum_4(impl):
    #  x,y {b}  <--- target
    #     |
    #    {}
    a, b, c, d = 2, 3, 4, 5
    x = torch.randn(a, b)
    y = torch.randn(d, b, c)
    (actual,) = impl('ab,dbc->dc', x, y, plates='d', modulo_total=True)

    x_msg = logsumexp(x, 0).unsqueeze(-1)
    assert x_msg.shape == (b, 1)
    y_lse = logsumexp(y, 2, keepdim=True)
    assert y_lse.shape == (d, b, 1)
    # inclusion-exclusion: log-normalizers of every other slice of plate d,
    # plus the full factor of the current slice.
    y_excl = y_lse.sum(0) - y_lse + y
    assert y_excl.shape == (d, b, c)
    expected = logsumexp(x_msg + y_excl, 1)
    assert expected.shape == (d, c)
    assert_equal(actual, expected)
def loss(self, model, guide, *args, **kwargs):
    """
    :returns: returns an estimate of the ELBO
    :rtype: float

    Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
    Per particle k the (detached) quantity ``elbo_k = log p - log q`` is formed;
    the estimate is ``log(mean_k exp((1 - alpha) * elbo_k)) / (1 - alpha)``.
    """
    is_vectorized = self.vectorize_particles and self.num_particles > 1

    def per_particle_sum(log_prob):
        # Fold every non-particle dim so a single value remains per particle.
        return log_prob.detach().reshape(self.num_particles, -1).sum(-1)

    elbo_particles = []
    # grab a vectorized trace from the generator
    for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
        elbo_particle = 0.
        for site in model_trace.nodes.values():
            if site["type"] != "sample":
                continue
            if is_vectorized:
                elbo_particle = elbo_particle + per_particle_sum(site["log_prob"])
            else:
                elbo_particle = elbo_particle + torch_item(site["log_prob_sum"])
        for site in guide_trace.nodes.values():
            if site["type"] != "sample":
                continue
            guide_log_prob, _, _ = site["score_parts"]
            if is_vectorized:
                elbo_particle = elbo_particle - per_particle_sum(guide_log_prob)
            else:
                elbo_particle = elbo_particle - torch_item(site["log_prob_sum"])
        elbo_particles.append(elbo_particle)

    if is_vectorized:
        elbo_particles = elbo_particles[0]
    else:
        elbo_particles = torch.tensor(elbo_particles)  # no need to use .new*() here

    scale = 1. - self.alpha
    log_weights = scale * elbo_particles
    log_mean_weight = logsumexp(log_weights, dim=0) - math.log(self.num_particles)
    elbo = log_mean_weight.sum().item() / scale
    loss = -elbo
    warn_if_nan(loss, "loss")
    return loss
def get_normalized_weights(self, log_scale=False):
    """
    Compute the normalized importance weights.

    :param bool log_scale: if True, return log weights instead of weights.
    :returns: a tensor of weights that sum to one (or their logs), or ``None``
        with a warning when ``self.log_weights`` is empty.
    """
    if not self.log_weights:
        warnings.warn(
            "The log_weights list is empty. There is nothing to normalize."
        )
        return None
    raw_log_weights = torch.tensor(self.log_weights)
    normalized = raw_log_weights - logsumexp(raw_log_weights, 0)
    if log_scale:
        return normalized
    return torch.exp(normalized)
def get_ESS(self):
    """
    Compute (Importance Sampling) Effective Sample Size (ESS).

    Evaluates ``1 / sum_i w_i**2`` over the normalized weights in log space.
    Warns and returns ``0`` when no log weights have been collected.
    """
    if not self.log_weights:
        warnings.warn(
            "The log_weights list is empty, effective sample size is zero."
        )
        return 0
    log_w_norm = self.get_normalized_weights(log_scale=True)
    return torch.exp(-logsumexp(2 * log_w_norm, 0))
def get_log_normalizer(self):
    """
    Estimator of the normalizing constant of the target distribution.
    (mean of the unnormalized weights)

    The result is in log space, ``log(sum_i exp(log_w_i) / num_samples)``.
    Warns and returns ``None`` when no log weights have been collected.
    """
    # ensure list is not empty
    if not self.log_weights:
        warnings.warn(
            "The log_weights list is empty, can not compute normalizing constant estimate."
        )
        return None
    log_w = torch.tensor(self.log_weights)
    log_n = torch.tensor(self.num_samples * 1.).log()
    return logsumexp(log_w - log_n, 0)
def evaluate_log_predictive_density(posterior_predictive, baseball_dataset):
    """
    Evaluate the log probability density of observing the unseen data (season hits)
    given a model and empirical distribution over the parameters.

    The result is logged rather than returned.
    """
    _, test, _ = train_test_split(baseball_dataset)
    at_bats_season = test[:, 0]
    hits_season = test[:, 1]
    test_eval = posterior_predictive.run(at_bats_season, hits_season)
    # Joint log density of each posterior-predictive execution trace.
    trace_log_pdf = [tr.log_prob_sum() for tr in test_eval.exec_traces]
    # Use LogSumExp trick to evaluate $log(1/num_samples \sum_i p(new_data | \theta^{i})) $,
    # where $\theta^{i}$ are parameter samples from the model's posterior.
    num_traces = len(trace_log_pdf)
    posterior_pred_density = logsumexp(torch.stack(trace_log_pdf), dim=-1) - math.log(num_traces)
    logging.info("\nLog posterior predictive density")
    logging.info("--------------------------------")
    logging.info("{:.4f}\n".format(posterior_pred_density))
def _reduce(self, ordinal, agg_log_prob=torch.tensor(0.)):
    """
    Reduce the log prob terms for the given ordinal:

     - taking log_sum_exp of factors in enum dims (i.e.
       adding up the probability terms).
     - summing up the dims within `max_plate_nesting`.
       (i.e. multiplying probs within independent batches).

    :param ordinal: node (ordinal)
    :param torch.Tensor agg_log_prob: aggregated `log_prob` terms from the
        downstream nodes.
    :return: `log_prob` with marginalized `plate` and `enum` dims.
    """
    local_log_prob = sum(self._log_probs[ordinal])
    result = local_log_prob + agg_log_prob
    # keepdim=True throughout so the remaining dim indices stay valid.
    for dim in self._enum_dims[ordinal]:
        result = logsumexp(result, dim=dim, keepdim=True)
    for dim in self._plate_dims[ordinal]:
        result = result.sum(dim=dim, keepdim=True)
    return result
def sample(self, trace):
    """
    Run one NUTS iteration starting from the latent values stored in ``trace``.

    Steps visible here:
     - read the latent values, move them to unconstrained space via
       ``self.transforms`` and resample the momentum ``r``;
     - set ``log_slice`` either to ``-energy`` (multinomial sampling) or to a
       slice level drawn in log space (slice sampling);
     - repeatedly double the trajectory in a random direction with
       ``self._build_tree`` until a U-turn, a divergence, or
       ``self._max_tree_depth + 1`` doublings, accepting the new subtree's
       proposal with probability ``new_tree.weight / tree_weight``;
     - during warmup, pass the last subtree's average acceptance probability
       to ``self._adapter``.

    :returns: ``self._get_trace(z)`` with ``z`` mapped back through
        ``transform.inv`` to the constrained space.
    """
    z = {
        name: node["value"].detach()
        for name, node in self._iter_latent_nodes(trace)
    }
    potential_energy, z_grads = self._fetch_from_cache()
    # automatically transform `z` to unconstrained space, if needed.
    for name, transform in self.transforms.items():
        z[name] = transform(z[name])

    r, r_flat = self._sample_r(name="r_t={}".format(self._t))

    # Parses as `(kinetic + potential) if potential is cached else self._energy(z, r)`.
    energy_current = self._kinetic_energy(r) + potential_energy if potential_energy is not None \
        else self._energy(z, r)

    # Ideally, following a symplectic integrator trajectory, the energy is constant.
    # In that case, we can sample the proposal uniformly, and there is no need to use "slice".
    # However, it is not the case for real situation: there are errors during the computation.
    # To deal with that problem, as in [1], we introduce an auxiliary "slice" variable (denoted
    # by u).
    # The sampling process goes as follows:
    #   first sampling u from initial state (z_0, r_0) according to
    #     u ~ Uniform(0, p(z_0, r_0)),
    #   then sampling state (z, r) from the integrator trajectory according to
    #     (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
    #
    # For more information about slice sampling method, see [3].
    # For another version of NUTS which uses multinomial sampling instead of slice sampling,
    # see [2].

    if self.use_multinomial_sampling:
        log_slice = -energy_current
    else:
        # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can
        # sample log_slice directly using `energy`, so as to avoid potential underflow or
        # overflow issues ([2]).
        slice_exp_term = pyro.sample(
            "slicevar_exp_t={}".format(self._t),
            dist.Exponential(energy_current.new_tensor(1.)))
        log_slice = -energy_current - slice_exp_term

    # Both leaves of the (so far single-node) trajectory start at the current state.
    z_left = z_right = z
    r_left = r_right = r
    z_left_grads = z_right_grads = z_grads
    accepted = False
    r_sum = r_flat
    # tree_weight is a log-weight (starts at log 1 = 0) for multinomial
    # sampling and a plain weight (starts at 1) for slice sampling.
    if self.use_multinomial_sampling:
        tree_weight = energy_current.new_zeros(())
    else:
        tree_weight = energy_current.new_ones(())

    # Temporarily disable distributions args checking as
    # NaNs are expected during step size adaptation.
    with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
        # doubling process, stop when turning or diverging
        for tree_depth in range(self._max_tree_depth + 1):
            direction = pyro.sample(
                "direction_t={}_treedepth={}".format(self._t, tree_depth),
                dist.Bernoulli(probs=torch.ones(1) * 0.5))
            direction = int(direction.item())
            if direction == 1:  # go to the right, start from the right leaf of current tree
                new_tree = self._build_tree(z_right, r_right, z_right_grads, log_slice,
                                            direction, tree_depth, energy_current)
                # update leaf for the next doubling process
                z_right = new_tree.z_right
                r_right = new_tree.r_right
                z_right_grads = new_tree.z_right_grads
            else:  # go the the left, start from the left leaf of current tree
                new_tree = self._build_tree(z_left, r_left, z_left_grads, log_slice,
                                            direction, tree_depth, energy_current)
                z_left = new_tree.z_left
                r_left = new_tree.r_left
                z_left_grads = new_tree.z_left_grads

            if new_tree.turning or new_tree.diverging:  # stop doubling
                break

            # Probability of moving to the new subtree's proposal; its
            # weight is compared with the weight accumulated so far.
            if self.use_multinomial_sampling:
                new_tree_prob = (new_tree.weight - tree_weight).exp()
            else:
                new_tree_prob = new_tree.weight / tree_weight
            rand = pyro.sample(
                "rand_t={}_treedepth={}".format(self._t, tree_depth),
                dist.Uniform(new_tree_prob.new_tensor(0.),
                             new_tree_prob.new_tensor(1.)))
            if rand < new_tree_prob:
                accepted = True
                z = new_tree.z_proposal
                # Cache the proposal's potential energy and grads so the next
                # call to `_fetch_from_cache` can reuse them.
                self._cache(new_tree.z_proposal_pe, new_tree.z_proposal_grads)

            r_sum = r_sum + new_tree.r_sum
            if self._is_turning(r_left, r_right, r_sum):  # stop doubling
                break
            else:  # update tree_weight
                if self.use_multinomial_sampling:
                    tree_weight = logsumexp(torch.stack([tree_weight, new_tree.weight]), dim=0)
                else:
                    tree_weight = tree_weight + new_tree.weight

    # Step-size adaptation uses the statistics of the last subtree built.
    if self._t < self._warmup_steps:
        accept_prob = new_tree.sum_accept_probs / new_tree.num_proposals
        self._adapter.step(self._t, z, accept_prob)

    if accepted:
        self._accept_cnt += 1

    self._t += 1
    # get trace with the constrained values for `z`.
    for name, transform in self.transforms.items():
        z[name] = transform.inv(z[name])
    return self._get_trace(z)
def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_current):
    """
    Recursively build a trajectory subtree of depth ``tree_depth`` starting
    from ``(z, r, z_grads)`` and extending in ``direction`` (1 = right,
    otherwise left).

    At depth 0 this delegates to ``self._build_basetree``. Otherwise it builds
    one half of depth ``tree_depth - 1`` and, unless that half is turning or
    diverging, a second half starting from the outer leaf of the first, so a
    full tree is made of ``2 ** tree_depth`` base trees. The proposal is
    taken from the second half with probability proportional to its weight.

    :returns: a ``_TreeInfo`` with the tree's left/right leaves, its proposal
        (with potential energy and grads), ``r_sum``, weight (a log-weight
        when ``self.use_multinomial_sampling``), turning/diverging flags and
        acceptance statistics. If the first half stops early, that half is
        returned as-is.
    """
    if tree_depth == 0:
        return self._build_basetree(z, r, z_grads, log_slice, direction, energy_current)

    # build the first half of tree
    half_tree = self._build_tree(z, r, z_grads, log_slice,
                                 direction, tree_depth - 1, energy_current)
    z_proposal = half_tree.z_proposal
    z_proposal_pe = half_tree.z_proposal_pe
    z_proposal_grads = half_tree.z_proposal_grads

    # Check conditions to stop doubling. If we meet that condition,
    # there is no need to build the other tree.
    if half_tree.turning or half_tree.diverging:
        return half_tree

    # Else, build remaining half of tree.
    # If we are going to the right, start from the right leaf of the first half.
    if direction == 1:
        z = half_tree.z_right
        r = half_tree.r_right
        z_grads = half_tree.z_right_grads
    else:  # otherwise, start from the left leaf of the first half
        z = half_tree.z_left
        r = half_tree.r_left
        z_grads = half_tree.z_left_grads
    other_half_tree = self._build_tree(z, r, z_grads, log_slice,
                                       direction, tree_depth - 1, energy_current)

    # Combine weights: log-space sum for multinomial sampling, plain sum otherwise.
    if self.use_multinomial_sampling:
        tree_weight = logsumexp(torch.stack([half_tree.weight, other_half_tree.weight]), dim=0)
    else:
        tree_weight = half_tree.weight + other_half_tree.weight
    sum_accept_probs = half_tree.sum_accept_probs + other_half_tree.sum_accept_probs
    num_proposals = half_tree.num_proposals + other_half_tree.num_proposals
    r_sum = half_tree.r_sum + other_half_tree.r_sum

    # The probability of that proposal belongs to which half of tree
    # is computed based on the weights of each half.
    if self.use_multinomial_sampling:
        other_half_tree_prob = (other_half_tree.weight - tree_weight).exp()
    else:
        # For the special case that the weights of each half are both 0,
        # we choose the proposal from the first half
        # (any is fine, because the probability of picking it at the end is 0!).
        other_half_tree_prob = (other_half_tree.weight / tree_weight if tree_weight > 0
                                else tree_weight.new_zeros(()))
    # NOTE(review): this site name is reused at every recursion level; that
    # presumably relies on this code running without a tracing handler - confirm.
    is_other_half_tree = pyro.sample("is_other_half_tree",
                                     dist.Bernoulli(probs=other_half_tree_prob))

    if is_other_half_tree == 1:
        z_proposal = other_half_tree.z_proposal
        z_proposal_pe = other_half_tree.z_proposal_pe
        z_proposal_grads = other_half_tree.z_proposal_grads

    # leaves of the full tree are determined by the direction
    if direction == 1:
        z_left = half_tree.z_left
        r_left = half_tree.r_left
        z_left_grads = half_tree.z_left_grads
        z_right = other_half_tree.z_right
        r_right = other_half_tree.r_right
        z_right_grads = other_half_tree.z_right_grads
    else:
        z_left = other_half_tree.z_left
        r_left = other_half_tree.r_left
        z_left_grads = other_half_tree.z_left_grads
        z_right = half_tree.z_right
        r_right = half_tree.r_right
        z_right_grads = half_tree.z_right_grads

    # We already check if first half tree is turning. Now, we check
    # if the other half tree or full tree are turning.
    turning = other_half_tree.turning or self._is_turning(r_left, r_right, r_sum)

    # The divergence is checked by the second half tree (the first half is already checked).
    diverging = other_half_tree.diverging

    return _TreeInfo(z_left, r_left, z_left_grads, z_right, r_right, z_right_grads, z_proposal,
                     z_proposal_pe, z_proposal_grads, r_sum, tree_weight, turning, diverging,
                     sum_accept_probs, num_proposals)
def _normalize(tensor, dims, plates):
    """
    Normalize ``tensor`` in log space over every dim that is not a plate.

    ``dims`` labels the tensor's dims positionally (``dims[i]`` names dim
    ``i``). Dims whose label appears in ``plates`` are left unnormalized.
    """
    log_total = tensor
    non_plate_positions = [pos for pos, name in enumerate(dims) if name not in plates]
    for pos in non_plate_positions:
        # keepdim=True so later positional indices still line up.
        log_total = logsumexp(log_total, pos, keepdim=True)
    return tensor - log_total
def loss_and_grads(self, model, guide, *args, **kwargs):
    """
    :returns: returns an estimate of the ELBO
    :rtype: float

    Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator.
    Performs backward on the latter. Num_particle many samples are used to form the estimators.

    Returns ``0.`` without calling backward when every particle's ELBO is
    identically zero (no tensor-valued log probs were found). Raises
    ``NotImplementedError`` for a guide site that has both a non-zero entropy
    term and a non-zero score-function term.
    """
    elbo_particles = []
    surrogate_elbo_particles = []
    is_vectorized = self.vectorize_particles and self.num_particles > 1
    # Template tensor (shape/dtype/device) taken from the first particle whose
    # ELBO is not identically zero; used to materialize zero particles.
    tensor_holder = None

    # grab a vectorized trace from the generator
    for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
        elbo_particle = 0
        surrogate_elbo_particle = 0

        # compute elbo and surrogate elbo
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample":
                if is_vectorized:
                    log_prob_sum = site["log_prob"].reshape(self.num_particles, -1).sum(-1)
                else:
                    log_prob_sum = site["log_prob_sum"]
                # The ELBO itself is detached; gradients flow only through the surrogate.
                elbo_particle = elbo_particle + log_prob_sum.detach()
                surrogate_elbo_particle = surrogate_elbo_particle + log_prob_sum

        for name, site in guide_trace.nodes.items():
            if site["type"] == "sample":
                log_prob, score_function_term, entropy_term = site["score_parts"]
                if is_vectorized:
                    log_prob_sum = log_prob.reshape(self.num_particles, -1).sum(-1)
                else:
                    log_prob_sum = site["log_prob_sum"]

                elbo_particle = elbo_particle - log_prob_sum.detach()

                if not is_identically_zero(entropy_term):
                    surrogate_elbo_particle = surrogate_elbo_particle - log_prob_sum

                    if not is_identically_zero(score_function_term):
                        # link to the issue: https://github.com/uber/pyro/issues/1222
                        raise NotImplementedError

                if not is_identically_zero(score_function_term):
                    surrogate_elbo_particle = (surrogate_elbo_particle +
                                               (self.alpha / (1. - self.alpha)) * log_prob_sum)

        if is_identically_zero(elbo_particle):
            if tensor_holder is not None:
                elbo_particle = tensor_holder.new_zeros(tensor_holder.shape)
                surrogate_elbo_particle = tensor_holder.new_zeros(tensor_holder.shape)
        else:  # elbo_particle is not None
            if tensor_holder is None:
                tensor_holder = elbo_particle.new_empty(elbo_particle.shape)
                # change types of previous `elbo_particle`s
                # (every earlier particle was identically zero, so replacing
                # them with zero tensors loses nothing and makes stacking work).
                for i in range(len(elbo_particles)):
                    elbo_particles[i] = tensor_holder.new_zeros(tensor_holder.shape)
                    surrogate_elbo_particles[i] = tensor_holder.new_zeros(tensor_holder.shape)

        elbo_particles.append(elbo_particle)
        surrogate_elbo_particles.append(surrogate_elbo_particle)

    if tensor_holder is None:
        return 0.

    if is_vectorized:
        # A single vectorized trace already holds one value per particle.
        elbo_particles = elbo_particles[0]
        surrogate_elbo_particles = surrogate_elbo_particles[0]
    else:
        elbo_particles = torch.stack(elbo_particles)
        surrogate_elbo_particles = torch.stack(surrogate_elbo_particles)

    log_weights = (1. - self.alpha) * elbo_particles
    log_mean_weight = logsumexp(log_weights, dim=0) - math.log(self.num_particles)
    elbo = log_mean_weight.sum().item() / (1. - self.alpha)

    # collect parameters to train from model and guide
    # NOTE(review): `model_trace` / `guide_trace` here are the loop variables
    # left over from the last iteration, so only the last traces are inspected.
    trainable_params = any(site["type"] == "param"
                           for trace in (model_trace, guide_trace)
                           for site in trace.nodes.values())

    if trainable_params and getattr(surrogate_elbo_particles, 'requires_grad', False):
        # Self-normalized particle weights; built from detached ELBO values,
        # so they act as constants in the backward pass.
        normalized_weights = (log_weights - log_mean_weight).exp()
        surrogate_elbo = (normalized_weights * surrogate_elbo_particles).sum() / self.num_particles
        surrogate_loss = -surrogate_elbo
        surrogate_loss.backward()
    loss = -elbo
    warn_if_nan(loss, "loss")
    return loss