def loss(self, model, guide, *args, **kwargs):
    """
    :returns: returns an estimate of the ELBO
    :rtype: float

    Evaluates the ELBO with an estimator that uses num_particles many
    samples/particles.
    """
    elbo_particles = []
    is_vectorized = self.vectorize_particles and self.num_particles > 1

    # grab a vectorized trace from the generator
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
        elbo_particle = 0.
        sum_dims = get_dependent_plate_dims(model_trace.nodes.values())

        # compute elbo
        for name, site in model_trace.nodes.items():
            if name in self.block_names:
                continue
            if site["type"] == "sample":
                log_prob_sum = torch_sum(site["log_prob"], sum_dims)
                elbo_particle = elbo_particle + log_prob_sum

        for name, site in guide_trace.nodes.items():
            if name in self.block_names:
                continue
            if site["type"] == "sample":
                log_prob, score_function_term, entropy_term = site["score_parts"]
                log_prob_sum = torch_sum(site["log_prob"], sum_dims)
                elbo_particle = elbo_particle - log_prob_sum

        elbo_particles.append(elbo_particle)

    if is_vectorized:
        elbo_particles = elbo_particles[0]
    else:
        elbo_particles = torch.stack(elbo_particles)

    log_weights = (1. - self.alpha) * elbo_particles
    log_mean_weight = torch.logsumexp(log_weights, dim=0) - math.log(self.num_particles)
    elbo = log_mean_weight.sum().item() / (1. - self.alpha)

    loss = -elbo
    warn_if_nan(loss, "loss")
    return loss
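
# Usage sketch (illustration only, not part of the class above): this estimator
# mirrors Pyro's built-in ``pyro.infer.RenyiELBO``, which plugs into ``SVI`` as a
# loss object. The model, guide, and the helper name ``_renyi_svi_usage_sketch``
# below are hypothetical stand-ins, assuming the same constructor arguments
# (``alpha``, ``num_particles``, ``vectorize_particles``).
def _renyi_svi_usage_sketch():
    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, RenyiELBO
    from pyro.optim import Adam

    def model(data):
        # Single latent location with a Normal likelihood over the data plate.
        loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    def guide(data):
        # Mean-field Normal variational posterior over ``loc``.
        q_loc = pyro.param("q_loc", torch.tensor(0.0))
        pyro.sample("loc", dist.Normal(q_loc, 1.0))

    data = torch.randn(100) + 3.0
    elbo = RenyiELBO(alpha=0.5, num_particles=8, vectorize_particles=True)
    svi = SVI(model, guide, Adam({"lr": 1e-2}), loss=elbo)
    for _ in range(10):
        svi.step(data)
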
def loss_and_grads(self, model, guide, *args, **kwargs):
    """
    :returns: returns an estimate of the ELBO
    :rtype: float

    Computes the ELBO as well as the surrogate ELBO that is used to form the
    gradient estimator. Performs backward on the latter. ``num_particles``
    many samples are used to form the estimators.
    """
    elbo_particles = []
    surrogate_elbo_particles = []
    is_vectorized = self.vectorize_particles and self.num_particles > 1
    tensor_holder = None

    # grab a vectorized trace from the generator
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
        elbo_particle = 0
        surrogate_elbo_particle = 0
        sum_dims = get_dependent_plate_dims(model_trace.nodes.values())

        # compute elbo and surrogate elbo
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample":
                log_prob_sum = torch_sum(site["log_prob"], sum_dims)
                elbo_particle = elbo_particle + log_prob_sum.detach()
                surrogate_elbo_particle = surrogate_elbo_particle + log_prob_sum

        for name, site in guide_trace.nodes.items():
            if site["type"] == "sample":
                log_prob, score_function_term, entropy_term = site["score_parts"]
                log_prob_sum = torch_sum(site["log_prob"], sum_dims)

                elbo_particle = elbo_particle - log_prob_sum.detach()

                if not is_identically_zero(entropy_term):
                    surrogate_elbo_particle = surrogate_elbo_particle - log_prob_sum

                    if not is_identically_zero(score_function_term):
                        # link to the issue: https://github.com/pyro-ppl/pyro/issues/1222
                        raise NotImplementedError

                if not is_identically_zero(score_function_term):
                    surrogate_elbo_particle = (
                        surrogate_elbo_particle +
                        (self.alpha / (1.0 - self.alpha)) * log_prob_sum)

        if is_identically_zero(elbo_particle):
            if tensor_holder is not None:
                elbo_particle = torch.zeros_like(tensor_holder)
                surrogate_elbo_particle = torch.zeros_like(tensor_holder)
        else:  # elbo_particle is not None
            if tensor_holder is None:
                tensor_holder = torch.zeros_like(elbo_particle)
                # change types of previous `elbo_particle`s
                for i in range(len(elbo_particles)):
                    elbo_particles[i] = torch.zeros_like(tensor_holder)
                    surrogate_elbo_particles[i] = torch.zeros_like(tensor_holder)

        elbo_particles.append(elbo_particle)
        surrogate_elbo_particles.append(surrogate_elbo_particle)

    if tensor_holder is None:
        return 0.0

    if is_vectorized:
        elbo_particles = elbo_particles[0]
        surrogate_elbo_particles = surrogate_elbo_particles[0]
    else:
        elbo_particles = torch.stack(elbo_particles)
        surrogate_elbo_particles = torch.stack(surrogate_elbo_particles)

    log_weights = (1.0 - self.alpha) * elbo_particles
    log_mean_weight = torch.logsumexp(log_weights, dim=0,
                                      keepdim=True) - math.log(self.num_particles)
    elbo = log_mean_weight.sum().item() / (1.0 - self.alpha)

    # collect parameters to train from model and guide
    trainable_params = any(site["type"] == "param"
                           for trace in (model_trace, guide_trace)
                           for site in trace.nodes.values())

    if trainable_params and getattr(surrogate_elbo_particles, "requires_grad", False):
        normalized_weights = (log_weights - log_mean_weight).exp()
        surrogate_elbo = (normalized_weights * surrogate_elbo_particles).sum() / self.num_particles
        surrogate_loss = -surrogate_elbo
        surrogate_loss.backward()

    loss = -elbo
    warn_if_nan(loss, "loss")
    return loss
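
# A minimal numeric sketch (illustration only) of the weighting math used in
# ``loss_and_grads`` above: the reported loss is -1/(1 - alpha) times the
# log-mean-exp of (1 - alpha) * elbo_particles, and the surrogate is reweighted
# by self-normalized importance weights whose mean is exactly one. The helper
# name ``_renyi_weighting_sketch`` and the toy values are hypothetical.
def _renyi_weighting_sketch(alpha=0.5, num_particles=4):
    import math
    import torch

    # Toy per-particle ELBO estimates standing in for the traced values.
    elbo_particles = torch.tensor([-10.2, -9.8, -11.0, -10.5])

    log_weights = (1.0 - alpha) * elbo_particles
    log_mean_weight = torch.logsumexp(log_weights, dim=0,
                                      keepdim=True) - math.log(num_particles)
    elbo = log_mean_weight.sum().item() / (1.0 - alpha)

    # Self-normalized importance weights: they average to one, so the surrogate
    # ELBO is a weighted mean of the per-particle surrogate terms.
    normalized_weights = (log_weights - log_mean_weight).exp()
    assert torch.allclose(normalized_weights.mean(), torch.tensor(1.0))

    return -elbo  # the value ``loss_and_grads`` would report for these particles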