def __call__(self, model, guide, *args, **kwargs):
    """
    Computes the surrogate loss that can be differentiated with autograd
    to produce gradient estimates for the model and guide parameters.
    """
    guide_trace, model_trace = self._get_traces(model, guide, args, kwargs)

    # Extract observations and posterior predictive samples.
    data = OrderedDict()
    samples = OrderedDict()
    for name, site in model_trace.nodes.items():
        if site["type"] == "sample" and site["is_observed"]:
            data[name] = site["infer"]["obs"]
            samples[name] = site["value"]
    assert list(data.keys()) == list(samples.keys())
    if not data:
        raise ValueError("Found no observations")

    # Compute energy distance from mean average error and generalized entropy.
    squared_error = []  # E[ (X - x)^2 ]
    squared_entropy = []  # E[ (X - X')^2 ]
    prototype = next(iter(data.values()))
    pairs = prototype.new_ones(self.num_particles, self.num_particles).tril(-1).nonzero()
    for name, obs in data.items():
        sample = samples[name]
        scale = model_trace.nodes[name]["scale"]
        mask = model_trace.nodes[name]["mask"]

        # Flatten to subshapes of (num_particles, batch_size, event_size).
        event_dim = model_trace.nodes[name]["fn"].event_dim
        batch_shape = obs.shape[:obs.dim() - event_dim]
        event_shape = obs.shape[obs.dim() - event_dim:]
        if getattr(scale, 'shape', ()):
            scale = scale.expand(batch_shape).reshape(-1)
        if getattr(mask, 'shape', ()):
            mask = mask.expand(batch_shape).reshape(-1)
        obs = obs.reshape(batch_shape.numel(), event_shape.numel())
        sample = sample.reshape(self.num_particles, batch_shape.numel(), event_shape.numel())

        squared_error.append(_squared_error(sample, obs, scale, mask))
        squared_entropy.append(
            _squared_error(*sample[pairs].unbind(1), scale, mask))

    squared_error = reduce(operator.add, squared_error)
    squared_entropy = reduce(operator.add, squared_entropy)
    error = self._pow(squared_error).mean()  # E[ ||X-x||^beta ]
    entropy = self._pow(squared_entropy).mean()  # E[ ||X-X'||^beta ]
    energy = error - 0.5 * entropy

    # Compute prior.
    log_prior = 0
    if self.prior_scale > 0:
        for site in model_trace.nodes.values():
            if site["type"] == "sample" and not site["is_observed"]:
                log_prior = log_prior + site["log_prob_sum"]

    # Compute final loss.
    loss = energy - self.prior_scale * log_prior
    warn_if_nan(loss, "loss")
    return loss
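# A minimal usage sketch for the energy-distance loss above. It assumes this is
# pyro.infer.EnergyDistance (or an equivalent callable loss) and that SVI accepts
# a callable returning a differentiable loss; the model, guide, and data below
# are illustrative placeholders, not part of the class.
def _example_energy_distance_usage():
    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, EnergyDistance
    from pyro.optim import Adam

    def model(data):
        loc = pyro.sample("loc", dist.Normal(0., 10.))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

    def guide(data):
        loc_q = pyro.param("loc_q", torch.tensor(0.))
        pyro.sample("loc", dist.Delta(loc_q))

    data = torch.randn(100) + 3.
    loss_fn = EnergyDistance(num_particles=8, prior_scale=1.0)
    svi = SVI(model, guide, Adam({"lr": 0.01}), loss=loss_fn)
    for _ in range(100):
        svi.step(data)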
def loss_and_grads(self, grads, batch, *args, **kwargs):
    """
    :returns: an estimate of the loss (expectation over p(x, y) of
        -log q(x, y)), where p is the model and q is the guide
    :rtype: float

    If a batch is provided, the loss is estimated using these traces;
    otherwise, a fresh batch is generated from the model.

    If ``grads`` is True, ``backward`` is also called on the loss.

    ``args`` and ``kwargs`` are passed to the model and guide.
    """
    if batch is None:
        indices = np.random.choice(len(self.simulations),
                                   size=self.training_batch_size,
                                   replace=False)
        batch = [self.simulations[i] for i in indices]
        batch_size = self.training_batch_size
    else:
        batch_size = len(batch)

    # Collect all cross-matched guide traces.
    with poutine.trace(param_only=True) as particle_param_capture:
        guide_traces = []
        for i in range(batch_size):
            # model_x: true model against which we contrast the rest.
            model_x_trace = batch[i]
            guide_traces.append([])
            for j in range(batch_size):
                # model_z: contrasting model parameters.
                model_z_trace = batch[j]
                # Evaluate matched guide.
                guide_trace = self._get_matched_cross_trace(
                    model_x_trace, model_z_trace, *args, **kwargs)
                guide_traces[-1].append(guide_trace)

    loss = torch.tensor(0.)

    # Calculate losses per site.
    for site_name in self.site_names:
        for i in range(batch_size):
            model_x_trace = batch[i]
            log_prob_priors = []
            for j in range(batch_size):
                model_z_trace = batch[j]
                log_prob_prior = (
                    model_x_trace.nodes[site_name]['fn'].log_prob(
                        model_z_trace.nodes[site_name]['value']))
                log_prob_priors.append(log_prob_prior.unsqueeze(0))
            log_prob_priors = torch.cat(log_prob_priors, 0)
            guide_losses = torch.cat(
                [self._differentiable_loss_particle(
                    guide_trace, site_name=site_name).unsqueeze(0)
                 for guide_trace in guide_traces[i]], 0)
            f_phis = guide_losses + log_prob_priors
            r = -torch.log_softmax(-f_phis, 0)
            particle_loss = r[i].sum() / batch_size
            loss += particle_loss

    warn_if_nan(loss, "loss")

    if grads:
        guide_params = set(site["value"].unconstrained()
                           for site in particle_param_capture.trace.nodes.values())
        guide_params = list(guide_params)
        guide_grads = torch.autograd.grad(loss, guide_params,
                                          allow_unused=True, retain_graph=True)
        for guide_grad, guide_param in zip(guide_grads, guide_params):
            if guide_param.grad is None:
                guide_param.grad = guide_grad
            elif guide_grad is not None:
                guide_param.grad = guide_param.grad + guide_grad

    return torch_item(loss)
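# A hedged, standalone illustration of the contrastive normalization used above:
# for a fixed reference simulation i, each entry of `f_phis` acts as an
# unnormalized negative log-weight over the contrastive batch, and
# r[i] = -log_softmax(-f_phis)[i] is the negative log of the weight assigned to
# the matched pair (i, i). The tensor below is made up for illustration and is
# not produced by the method.
def _example_contrastive_normalization():
    import torch

    f_phis = torch.tensor([1.3, 0.2, 2.5])  # losses for one reference simulation
    r = -torch.log_softmax(-f_phis, 0)
    i = 0
    expected = f_phis[i] + torch.logsumexp(-f_phis, 0)
    assert torch.allclose(r[i], expected)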
def compute_marginals_persistent_bp(exists_logits, assign_logits, bp_iters,
                                    bp_momentum=0.5):
    """
    This implements approximate inference of pairwise marginals via loopy
    belief propagation, adapting the approach of [1], [2].

    See :class:`MarginalAssignmentPersistent` for args and problem description.

    [1] Jason L. Williams, Roslyn A. Lau (2014)
        Approximate evaluation of marginal association probabilities with
        belief propagation
        https://arxiv.org/abs/1209.6299
    [2] Ryan Turner, Steven Bottone, Bhargav Avasarala (2014)
        A Complete Variational Tracker
        https://papers.nips.cc/paper/5572-a-complete-variational-tracker.pdf
    """
    # This implements forward-backward message passing among three sets of variables:
    #
    #   a[t,j] ~ Categorical(num_objects + 1), detection -> object assignment
    #   b[t,i] ~ Categorical(num_detections + 1), object -> detection assignment
    #   e[i] ~ Bernoulli, whether each object exists
    #
    # Only assign = a and exists = e are returned.
    assert 0 <= bp_momentum < 1, bp_momentum
    old, new = bp_momentum, 1 - bp_momentum
    num_frames, num_detections, num_objects = assign_logits.shape

    message_b_to_a = assign_logits.new_zeros(num_frames, num_detections, num_objects)
    message_a_to_b = assign_logits.new_zeros(num_frames, num_detections, num_objects)
    message_b_to_e = assign_logits.new_zeros(num_frames, num_objects)
    message_e_to_b = assign_logits.new_zeros(num_frames, num_objects)

    for i in range(bp_iters):
        odds_a = (assign_logits + message_b_to_a).exp()
        message_a_to_b = (old * message_a_to_b +
                          new * (assign_logits - (odds_a.sum(2, True) - odds_a).log1p()))
        message_b_to_e = (old * message_b_to_e +
                          new * message_a_to_b.exp().sum(1).log1p())
        message_e_to_b = (old * message_e_to_b +
                          new * (exists_logits + message_b_to_e.sum(0) - message_b_to_e))
        odds_b = message_a_to_b.exp()
        message_b_to_a = (old * message_b_to_a -
                          new * ((-message_e_to_b).exp().unsqueeze(1) +
                                 (1 + odds_b.sum(1, True) - odds_b)).log())

        warn_if_nan(message_a_to_b, 'message_a_to_b iter {}'.format(i))
        warn_if_nan(message_b_to_e, 'message_b_to_e iter {}'.format(i))
        warn_if_nan(message_e_to_b, 'message_e_to_b iter {}'.format(i))
        warn_if_nan(message_b_to_a, 'message_b_to_a iter {}'.format(i))

    # Combine messages into marginal logits.
    exists = exists_logits + message_b_to_e.sum(0)
    assign = assign_logits + message_b_to_a
    warn_if_nan(exists, 'exists')
    warn_if_nan(assign, 'assign')
    return exists, assign
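# A minimal shape-level sketch of calling the belief-propagation routine above.
# Shapes follow the code: exists_logits is (num_objects,), assign_logits is
# (num_frames, num_detections, num_objects); the returned `exists` and `assign`
# are marginal logits with those same shapes. The random logits are illustrative only.
def _example_persistent_bp_shapes():
    import torch

    num_frames, num_detections, num_objects = 5, 4, 3
    exists_logits = torch.randn(num_objects)
    assign_logits = torch.randn(num_frames, num_detections, num_objects)
    exists, assign = compute_marginals_persistent_bp(
        exists_logits, assign_logits, bp_iters=10, bp_momentum=0.5)
    assert exists.shape == (num_objects,)
    assert assign.shape == (num_frames, num_detections, num_objects)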
def loss_and_grads(self, model, guide, *args, **kwargs):
    """
    :returns: returns an estimate of the ELBO
    :rtype: float

    Computes the ELBO as well as the surrogate ELBO that is used to form the
    gradient estimator. Performs backward on the latter. ``num_particles``
    many samples are used to form the estimators.
    """
    elbo_particles = []
    surrogate_elbo_particles = []
    is_vectorized = self.vectorize_particles and self.num_particles > 1
    tensor_holder = None

    # Grab a vectorized trace from the generator.
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
        elbo_particle = 0
        surrogate_elbo_particle = 0

        # Compute elbo and surrogate elbo.
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample":
                if is_vectorized:
                    log_prob_sum = site["log_prob"].reshape(self.num_particles, -1).sum(-1)
                else:
                    log_prob_sum = site["log_prob_sum"]
                elbo_particle = elbo_particle + log_prob_sum.detach()
                surrogate_elbo_particle = surrogate_elbo_particle + log_prob_sum

        for name, site in guide_trace.nodes.items():
            if site["type"] == "sample":
                log_prob, score_function_term, entropy_term = site["score_parts"]
                if is_vectorized:
                    log_prob_sum = log_prob.reshape(self.num_particles, -1).sum(-1)
                else:
                    log_prob_sum = site["log_prob_sum"]

                elbo_particle = elbo_particle - log_prob_sum.detach()

                if not is_identically_zero(entropy_term):
                    surrogate_elbo_particle = surrogate_elbo_particle - log_prob_sum

                if not is_identically_zero(score_function_term):
                    # Link to the issue: https://github.com/uber/pyro/issues/1222
                    raise NotImplementedError

        if is_identically_zero(elbo_particle):
            if tensor_holder is not None:
                elbo_particle = tensor_holder.new_zeros(tensor_holder.shape)
                surrogate_elbo_particle = tensor_holder.new_zeros(tensor_holder.shape)
        else:  # elbo_particle is not identically zero
            if tensor_holder is None:
                tensor_holder = elbo_particle.new_empty(elbo_particle.shape)
                # Change types of previous `elbo_particle`s.
                for i in range(len(elbo_particles)):
                    elbo_particles[i] = tensor_holder.new_zeros(tensor_holder.shape)
                    surrogate_elbo_particles[i] = tensor_holder.new_zeros(tensor_holder.shape)

        elbo_particles.append(elbo_particle)
        surrogate_elbo_particles.append(surrogate_elbo_particle)

    if tensor_holder is None:
        return 0.

    if is_vectorized:
        elbo_particles = elbo_particles[0]
        surrogate_elbo_particles = surrogate_elbo_particles[0]
    else:
        elbo_particles = torch.stack(elbo_particles)
        surrogate_elbo_particles = torch.stack(surrogate_elbo_particles)

    log_weights = elbo_particles
    log_mean_weight = log_sum_exp(elbo_particles, dim=0) - math.log(self.num_particles)
    elbo = log_mean_weight.sum().item()

    # Collect parameters to train from model and guide.
    trainable_params = any(site["type"] == "param"
                           for trace in (model_trace, guide_trace)
                           for site in trace.nodes.values())

    if trainable_params and getattr(surrogate_elbo_particles, 'requires_grad', False):
        normalized_weights = (log_weights - log_mean_weight).exp()
        surrogate_elbo = (normalized_weights * surrogate_elbo_particles).sum() / self.num_particles
        surrogate_loss = -surrogate_elbo
        surrogate_loss.backward()

    loss = -elbo
    warn_if_nan(loss, "loss")
    return loss
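# A hedged, standalone sketch of the self-normalized importance weighting used
# above. `log_weights` here is an illustrative tensor rather than the
# per-particle ELBO estimates computed by the method.
def _example_normalized_importance_weights():
    import math
    import torch

    num_particles = 4
    log_weights = torch.randn(num_particles)
    log_mean_weight = torch.logsumexp(log_weights, dim=0) - math.log(num_particles)
    normalized_weights = (log_weights - log_mean_weight).exp()
    # Normalization makes the weights average to 1 (i.e. sum to num_particles).
    assert torch.allclose(normalized_weights.mean(), torch.tensor(1.0))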
def _loss(self, model, guide, args, kwargs):
    """
    :returns: returns model loss and guide loss
    :rtype: float, float

    Computes the re-weighted wake-sleep estimators for the model (wake-theta)
    and the guide (insomnia * wake-phi + (1 - insomnia) * sleep-phi).
    Performs backward as appropriate on both, over the specified number of particles.
    """
    wake_theta_loss = torch.tensor(100.)
    if self.model_has_params or self.insomnia > 0.:
        # compute quantities for wake theta and wake phi
        log_joints = []
        log_qs = []

        for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
            log_joint = 0.
            log_q = 0.

            for _, site in model_trace.nodes.items():
                if site["type"] == "sample":
                    if self.vectorize_particles:
                        log_p_site = site["log_prob"].reshape(self.num_particles, -1).sum(-1)
                    else:
                        log_p_site = site["log_prob_sum"]
                    log_joint = log_joint + log_p_site

            for _, site in guide_trace.nodes.items():
                if site["type"] == "sample":
                    if self.vectorize_particles:
                        log_q_site = site["log_prob"].reshape(self.num_particles, -1).sum(-1)
                    else:
                        log_q_site = site["log_prob_sum"]
                    log_q = log_q + log_q_site

            log_joints.append(log_joint)
            log_qs.append(log_q)

        log_joints = log_joints[0] if self.vectorize_particles else torch.stack(log_joints)
        log_qs = log_qs[0] if self.vectorize_particles else torch.stack(log_qs)
        log_weights = log_joints - log_qs.detach()

        # compute wake theta loss
        log_sum_weight = torch.logsumexp(log_weights, dim=0)
        wake_theta_loss = -(log_sum_weight - math.log(self.num_particles)).sum()
        warn_if_nan(wake_theta_loss, "wake theta loss")

    if self.insomnia > 0:
        # compute wake phi loss
        normalised_weights = (log_weights - log_sum_weight).exp().detach()
        wake_phi_loss = -(normalised_weights * log_qs).sum()
        warn_if_nan(wake_phi_loss, "wake phi loss")

    if self.insomnia < 1:
        # compute sleep phi loss
        _model = pyro.poutine.uncondition(model)
        _guide = guide
        _log_q = 0.

        if self.vectorize_particles:
            if self.max_plate_nesting == float('inf'):
                self._guess_max_plate_nesting(_model, _guide, args, kwargs)
            _model = self._vectorized_num_sleep_particles(_model)
            _guide = self._vectorized_num_sleep_particles(guide)

        for _ in range(1 if self.vectorize_particles else self.num_sleep_particles):
            _model_trace = poutine.trace(_model).get_trace(*args, **kwargs)
            _model_trace.detach_()
            _guide_trace = self._get_matched_trace(_model_trace, _guide, args, kwargs)
            _log_q += _guide_trace.log_prob_sum()

        sleep_phi_loss = -_log_q / self.num_sleep_particles
        warn_if_nan(sleep_phi_loss, "sleep phi loss")

    # compute phi loss
    phi_loss = (sleep_phi_loss if self.insomnia == 0
                else wake_phi_loss if self.insomnia == 1
                else self.insomnia * wake_phi_loss + (1. - self.insomnia) * sleep_phi_loss)

    return wake_theta_loss, phi_loss
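# A hedged, standalone sketch of the wake-theta objective computed above: the
# negative importance-weighted bound -(logsumexp(log_weights) - log K). The
# made-up tensors stand in for the per-particle log-joint and guide log-probs.
def _example_wake_theta_loss():
    import math
    import torch

    num_particles = 4
    log_joints = torch.randn(num_particles)
    log_qs = torch.randn(num_particles)
    log_weights = log_joints - log_qs.detach()
    # The bracketed term estimates log p(x); its negation is minimized with
    # respect to the model parameters (the "wake-theta" phase).
    wake_theta_loss = -(torch.logsumexp(log_weights, dim=0) - math.log(num_particles))
    return wake_theta_loss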
def _differentiable_loss_parts(self, model, guide, *args, **kwargs):
    all_model_samples = defaultdict(list)
    all_guide_samples = defaultdict(list)

    loglikelihood = 0.0
    penalty = 0.0
    for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
        if self.vectorize_particles:
            model_trace_independent = poutine.trace(
                self._vectorized_num_particles(model)).get_trace(*args, **kwargs)
        else:
            model_trace_independent = poutine.trace(
                model, graph_type='flat').get_trace(*args, **kwargs)

        loglikelihood_particle = 0.0
        for name, model_site in model_trace.nodes.items():
            if model_site['type'] == 'sample':
                if name in guide_trace and not model_site['is_observed']:
                    guide_site = guide_trace.nodes[name]
                    independent_model_site = model_trace_independent.nodes[name]
                    if not independent_model_site["fn"].has_rsample:
                        raise ValueError(
                            "Model site {} is not reparameterizable".format(name))
                    if not guide_site["fn"].has_rsample:
                        raise ValueError(
                            "Guide site {} is not reparameterizable".format(name))

                    particle_dim = -self.max_plate_nesting - independent_model_site["fn"].event_dim

                    model_samples = independent_model_site['value']
                    guide_samples = guide_site['value']

                    if self.vectorize_particles:
                        model_samples = model_samples.transpose(-model_samples.dim(), particle_dim)
                        model_samples = model_samples.view(model_samples.shape[0], -1)

                        guide_samples = guide_samples.transpose(-guide_samples.dim(), particle_dim)
                        guide_samples = guide_samples.view(guide_samples.shape[0], -1)
                    else:
                        model_samples = model_samples.view(1, -1)
                        guide_samples = guide_samples.view(1, -1)

                    all_model_samples[name].append(model_samples)
                    all_guide_samples[name].append(guide_samples)
                else:
                    loglikelihood_particle = loglikelihood_particle + model_site['log_prob_sum']

        loglikelihood = loglikelihood_particle / self.num_particles + loglikelihood

    for name in all_model_samples.keys():
        all_model_samples[name] = torch.cat(all_model_samples[name])
        all_guide_samples[name] = torch.cat(all_guide_samples[name])
        divergence = _compute_mmd(all_model_samples[name], all_guide_samples[name],
                                  kernel=self._kernel[name])
        penalty = self._mmd_scale[name] * divergence + penalty

    warn_if_nan(loglikelihood, "loglikelihood")
    warn_if_nan(penalty, "penalty")
    return loglikelihood, penalty
def _quantized_model(self):
    """
    Quantized vectorized model used for parallel-scan enumerated inference.
    This method is called only outside particle_plate.
    """
    C = len(self.compartments)
    T = self.duration
    Q = self.num_quant_bins
    R_shape = getattr(self.population, "shape", ())  # Region shape.

    # Sample global parameters and auxiliary variables.
    params = self.global_model()
    auxiliary, non_compartmental = self._sample_auxiliary()

    # Manually enumerate.
    curr, logp = quantize_enumerate(auxiliary, min=0, max=self.population,
                                    num_quant_bins=self.num_quant_bins)
    curr = OrderedDict(zip(self.compartments, curr.unbind(0)))
    logp = OrderedDict(zip(self.compartments, logp.unbind(0)))
    curr.update(non_compartmental)

    # Truncate final value from the right then pad initial value onto the left.
    init = self.initialize(params)
    prev = {}
    for name, value in init.items():
        if name in self.compartments:
            if isinstance(value, torch.Tensor):
                value = value[..., None]  # Because curr is enumerated on the right.
            prev[name] = cat2(value, curr[name][:-1],
                              dim=-3 if self.is_regional else -2)
        else:  # non-compartmental
            prev[name] = cat2(init[name], curr[name][:-1],
                              dim=-curr[name].dim())

    # Reshape to support broadcasting, similar to EnumMessenger.
    def enum_reshape(tensor, position):
        assert tensor.size(-1) == Q
        assert tensor.dim() <= self.max_plate_nesting + 2
        tensor = tensor.permute(tensor.dim() - 1, *range(tensor.dim() - 1))
        shape = [Q] + [1] * (position + self.max_plate_nesting - (tensor.dim() - 2))
        shape.extend(tensor.shape[1:])
        return tensor.reshape(shape)

    for e, name in enumerate(self.compartments):
        curr[name] = enum_reshape(curr[name], e)
        logp[name] = enum_reshape(logp[name], e)
        prev[name] = enum_reshape(prev[name], e + C)

    # Enable approximate inference by using aux as a non-enumerated proxy
    # for enumerated compartment values.
    for name in self.approximate:
        aux = auxiliary[self.compartments.index(name)]
        curr[name + "_approx"] = aux
        prev[name + "_approx"] = cat2(init[name], aux[:-1],
                                      dim=-2 if self.is_regional else -1)

    # Record transition factors.
    with poutine.block(), poutine.trace() as tr:
        with self.time_plate:
            t = slice(0, T, 1)  # Used to slice data tensors.
            self._transition_bwd(params, prev, curr, t)
    tr.trace.compute_log_prob()
    for name, site in tr.trace.nodes.items():
        if site["type"] == "sample":
            log_prob = site["log_prob"]
            if log_prob.dim() <= self.max_plate_nesting:  # Not enumerated.
                pyro.factor("transition_" + name, site["log_prob_sum"])
                continue
            if self.is_regional and log_prob.shape[-1:] != R_shape:
                # Poor man's tensor variable elimination.
                log_prob = log_prob.expand(log_prob.shape[:-1] + R_shape) / R_shape[0]
            logp[name] = log_prob

    # Manually perform variable elimination.
    logp = reduce(operator.add, logp.values())
    logp = logp.reshape(Q ** C, Q ** C, T, -1)  # prev, curr, T, batch
    logp = logp.permute(3, 2, 0, 1).squeeze(0)  # batch, T, prev, curr
    logp = pyro.distributions.hmm._sequential_logmatmulexp(logp)  # batch, prev, curr
    logp = logp.reshape(-1, Q ** C * Q ** C).logsumexp(-1).sum()
    warn_if_nan(logp)
    pyro.factor("transition", logp)

    self._clear_plates()
def differentiable_loss(self, model, guide, *args, **kwargs):
    loss, surrogate_loss = self.loss_and_surrogate_loss(model, guide, *args, **kwargs)
    warn_if_nan(loss, "loss")
    return loss + (surrogate_loss - surrogate_loss.detach())
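# The return value above equals `loss` numerically, but backward() through it
# yields the gradient of `surrogate_loss`: subtracting the detached copy cancels
# the surrogate's value while keeping its autograd graph. A standalone sketch of
# the same trick (assuming, as in typical usage, that `loss` itself carries no
# gradient of its own):
def _example_surrogate_loss_trick():
    import torch

    x = torch.tensor(2.0, requires_grad=True)
    loss = (x ** 2).detach()      # value to report; no gradient flows through it
    surrogate_loss = 3.0 * x      # carries the gradient we want
    combined = loss + (surrogate_loss - surrogate_loss.detach())
    combined.backward()
    assert combined.item() == 4.0
    assert x.grad.item() == 3.0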