def _differentiable_loss_particle(self, model_trace, guide_trace):
    """
    Compute the loss contributions of a single (model, guide) trace pair.

    :param model_trace: trace whose ``sample`` sites carry ``log_prob_sum``.
    :param guide_trace: trace whose ``sample`` sites carry ``log_prob_sum``,
        ``score_parts`` and ``cond_indep_stack``.
    :returns: the pair ``(-elbo, -surrogate_elbo)``; the first entry is
        accumulated from ``torch_item`` values, the second from the raw
        ``log_prob_sum`` / entropy / score-function terms.
    :rtype: tuple
    """
    elbo = 0
    surrogate = 0
    log_r = None  # computed lazily, only if a score-function term shows up

    # model sites add their log-probability to both accumulators
    for node in model_trace.nodes.values():
        if node["type"] != "sample":
            continue
        elbo = elbo + torch_item(node["log_prob_sum"])
        surrogate = surrogate + node["log_prob_sum"]

    # guide sites subtract theirs and contribute the gradient-estimator terms
    for node in guide_trace.nodes.values():
        if node["type"] != "sample":
            continue
        _, score_term, entropy = node["score_parts"]

        elbo = elbo - torch_item(node["log_prob_sum"])

        if not is_identically_zero(entropy):
            surrogate = surrogate - entropy.sum()

        if is_identically_zero(score_term):
            continue
        if log_r is None:
            log_r = _compute_log_r(model_trace, guide_trace)
        site_log_r = log_r.sum_to(node["cond_indep_stack"])
        surrogate = surrogate + (site_log_r * score_term).sum()

    return -elbo, -surrogate
def do_fit_prior_test(self, reparameterized, n_steps, loss, debug=False):
    """
    Run SVI with a Normal model/guide pair and check the fitted parameters.

    The model samples ``loc_latent`` from a Normal built from ``self.loc0``
    and ``self.lam0 ** -0.5``; the guide samples the same site from a Normal
    parameterized by the learnable ``loc_q`` and ``exp(log_sig_q)``. After
    ``n_steps`` optimizer steps, the errors reported by ``param_mse`` for
    both parameters must equal zero to within ``prec=0.05``.

    :param bool reparameterized: use ``dist.Normal`` in the guide when true,
        ``fakes.NonreparameterizedNormal`` otherwise.
    :param int n_steps: number of ``svi.step()`` calls.
    :param loss: loss object handed to ``SVI``; with ``debug`` set it must
        also offer ``_differentiable_loss_parts(model, guide)``.
    :param bool debug: when true, keep exponential running averages of the
        two loss parts and print them (with the parameter errors) every 100
        steps.
    """
    pyro.clear_param_store()

    def tile(tensor):
        # one copy of ``tensor`` per element of the 'samples' plate
        return torch.stack([tensor] * self.sample_batch_size, dim=0)

    def model():
        with pyro.plate('samples', self.sample_batch_size):
            pyro.sample(
                "loc_latent",
                dist.Normal(
                    tile(self.loc0),
                    tile(torch.pow(self.lam0, -0.5)),
                ).to_event(1),
            )

    def guide():
        loc_q = pyro.param("loc_q", self.loc0.detach() + 0.134)
        log_sig_q = pyro.param(
            "log_sig_q",
            -0.5 * torch.log(self.lam0).data.detach() - 0.14,
        )
        sig_q = torch.exp(log_sig_q)
        if reparameterized:
            guide_dist = dist.Normal
        else:
            guide_dist = fakes.NonreparameterizedNormal
        with pyro.plate('samples', self.sample_batch_size):
            pyro.sample(
                "loc_latent",
                guide_dist(tile(loc_q), tile(sig_q)).to_event(1),
            )

    svi = SVI(model, guide, optim.Adam({"lr": .001}), loss=loss)

    decay = 0.99
    for step_idx in range(n_steps):
        svi.step()
        if not debug:
            continue

        loc_error = param_mse("loc_q", self.loc0)
        log_sig_error = param_mse("log_sig_q", -0.5 * torch.log(self.lam0))
        with torch.no_grad():
            if step_idx == 0:
                # seed the running averages on the very first step
                first_ll, first_penalty = loss._differentiable_loss_parts(
                    model, guide)
                avg_loglikelihood = torch_item(first_ll)
                avg_penalty = torch_item(first_penalty)
            loglikelihood, penalty = loss._differentiable_loss_parts(
                model, guide)
            avg_loglikelihood = (decay * avg_loglikelihood
                                 + (1 - decay) * torch_item(loglikelihood))
            avg_penalty = (decay * avg_penalty
                           + (1 - decay) * torch_item(penalty))
        if step_idx % 100 == 0:
            print(loc_error, log_sig_error)
            print(avg_loglikelihood, avg_penalty)
            print()

    loc_error = param_mse("loc_q", self.loc0)
    log_sig_error = param_mse("log_sig_q", -0.5 * torch.log(self.lam0))
    assert_equal(0.0, loc_error, prec=0.05)
    assert_equal(0.0, log_sig_error, prec=0.05)
def loss_and_grads(self, model, guide, *args, **kwargs):
    """
    :returns: the negated ELBO estimate (``-elbo``)
    :rtype: float

    For every trace pair yielded by ``self._get_traces`` this accumulates
    the ELBO and a surrogate ELBO, and calls ``backward()`` on the negated
    surrogate scaled by ``1 / self.num_particles``. A warning is issued if
    the resulting loss is NaN.
    """
    total_elbo = 0.0

    for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
        particle_elbo = 0
        particle_surrogate = 0
        log_r = None  # computed at most once per trace pair

        for node in model_trace.nodes.values():
            if node["type"] != "sample":
                continue
            particle_elbo = particle_elbo + torch_item(node["log_prob_sum"])
            particle_surrogate = particle_surrogate + node["log_prob_sum"]

        for node in guide_trace.nodes.values():
            if node["type"] != "sample":
                continue
            _, score_term, entropy = node["score_parts"]

            particle_elbo = particle_elbo - torch_item(node["log_prob_sum"])

            if not is_identically_zero(entropy):
                particle_surrogate = particle_surrogate - entropy.sum()

            if is_identically_zero(score_term):
                continue
            if log_r is None:
                log_r = _compute_log_r(model_trace, guide_trace)
            site_log_r = log_r.sum_to(node["cond_indep_stack"])
            particle_surrogate = particle_surrogate + (site_log_r * score_term).sum()

        total_elbo += particle_elbo / self.num_particles

        # only differentiate when some param site exists and the surrogate
        # actually carries gradient information
        has_params = any(
            node["type"] == "param"
            for trace in (model_trace, guide_trace)
            for node in trace.nodes.values()
        )
        if has_params and getattr(particle_surrogate, 'requires_grad', False):
            particle_loss = -particle_surrogate / self.num_particles
            particle_loss.backward()

    loss = -total_elbo
    if torch_isnan(loss):
        warnings.warn('Encountered NAN loss')
    return loss
def loss(self, model, guide, *args, **kwargs):
    """
    :returns: the negated bound estimate
    :rtype: float

    Collects one ELBO value per particle (model log-probabilities minus
    guide log-probabilities), scales them by ``1 - self.alpha``, takes the
    log of their mean via ``logsumexp`` and divides by ``1 - self.alpha``
    again. In vectorized mode the single trace holds all particles and its
    log-probabilities are reshaped to ``(num_particles, -1)`` and summed
    over the trailing axis.
    """
    vectorized = self.vectorize_particles and self.num_particles > 1
    collected = []

    for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
        running = 0.

        for node in model_trace.nodes.values():
            if node["type"] != "sample":
                continue
            if vectorized:
                term = node["log_prob"].detach().reshape(
                    self.num_particles, -1).sum(-1)
            else:
                term = torch_item(node["log_prob_sum"])
            running = running + term

        for node in guide_trace.nodes.values():
            if node["type"] != "sample":
                continue
            guide_log_prob, _, _ = node["score_parts"]
            if vectorized:
                term = guide_log_prob.detach().reshape(
                    self.num_particles, -1).sum(-1)
            else:
                term = torch_item(node["log_prob_sum"])
            running = running - term

        collected.append(running)

    if vectorized:
        particle_values = collected[0]
    else:
        particle_values = torch.tensor(collected)  # no need to use .new*() here

    scale = 1. - self.alpha
    log_mean_weight = (torch.logsumexp(scale * particle_values, dim=0)
                       - math.log(self.num_particles))
    bound = log_mean_weight.sum().item() / scale

    loss = -bound
    warn_if_nan(loss, "loss")
    return loss
def loss(self, model, guide, *args, **kwargs):
    """
    :returns: mean over the outer dimension of ``1 / sum(w ** 2)``, where
        ``w`` are the per-particle weights normalized along the inner
        dimension
    :rtype: float

    Per-particle ELBO values (model log-probabilities minus guide
    log-probabilities) are arranged as ``(self.num_outer, self.num_inner)``,
    normalized in log space along ``dim=1`` and reduced to one value per
    outer row; the mean of those values is returned.
    """
    vectorized = self.vectorize_particles and self.num_particles > 1
    collected = []

    for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
        running = 0.

        for node in model_trace.nodes.values():
            if node["type"] != "sample":
                continue
            running = running + (
                node["log_prob"].detach().reshape(self.num_particles, -1).sum(-1)
                if vectorized
                else torch_item(node["log_prob_sum"])
            )

        for node in guide_trace.nodes.values():
            if node["type"] != "sample":
                continue
            guide_log_prob, _, _ = node["score_parts"]
            running = running - (
                guide_log_prob.detach().reshape(self.num_particles, -1).sum(-1)
                if vectorized
                else torch_item(node["log_prob_sum"])
            )

        collected.append(running)

    if vectorized:
        particle_values = collected[0]
    else:
        particle_values = torch.tensor(collected)  # no need to use .new*() here

    grid = particle_values.view(self.num_outer, self.num_inner)

    # log of the weights after normalizing each row to sum to one
    log_w_norm = grid - torch.logsumexp(grid, dim=1, keepdim=True)
    # exp(-log(sum(w^2))) == 1 / sum(w^2), one value per outer row
    ess_val = torch.exp(-torch.logsumexp(2 * log_w_norm, dim=1))

    loss = ess_val.mean()
    warn_if_nan(loss, "loss")
    return loss.item()
def do_fit_prior_test(self, reparameterized, n_steps, loss, debug=False):
    """
    Run SVI with a Beta model/guide pair and check the fitted parameters.

    The model samples ``p_latent`` from a Beta built from ``self.alpha0`` and
    ``self.beta0``; the guide samples the same site from a Beta whose
    concentrations are ``exp`` of the learnable ``alpha_q_log`` and
    ``beta_q_log``. After ``n_steps`` optimizer steps, the errors reported by
    ``param_abs_error`` must equal zero to within ``prec=0.08``.

    :param bool reparameterized: use ``dist.Beta`` when true,
        ``fakes.NonreparameterizedBeta`` otherwise (for model and guide).
    :param int n_steps: number of ``svi.step()`` calls.
    :param loss: loss object handed to ``SVI``; with ``debug`` set it must
        also offer ``_differentiable_loss_parts(model, guide)``.
    :param bool debug: when true, keep exponential running averages of the
        two loss parts and print them every 100 steps.
    """
    pyro.clear_param_store()
    if reparameterized:
        Beta = dist.Beta
    else:
        Beta = fakes.NonreparameterizedBeta

    def tile(value):
        # wrap in a length-1 event dimension, then repeat once per plate element
        return torch.stack([torch.stack([value])] * self.sample_batch_size)

    def model():
        with pyro.plate('samples', self.sample_batch_size):
            pyro.sample(
                "p_latent",
                Beta(tile(self.alpha0), tile(self.beta0)).to_event(1),
            )

    def guide():
        alpha_q_log = pyro.param("alpha_q_log", torch.log(self.alpha0) + 0.17)
        beta_q_log = pyro.param("beta_q_log", torch.log(self.beta0) - 0.143)
        alpha_q = torch.exp(alpha_q_log)
        beta_q = torch.exp(beta_q_log)
        with pyro.plate('samples', self.sample_batch_size):
            pyro.sample(
                "p_latent",
                Beta(tile(alpha_q), tile(beta_q)).to_event(1),
            )

    adam = optim.Adam({"lr": .001, "betas": (0.97, 0.999)})
    svi = SVI(model, guide, adam, loss=loss)

    decay = 0.99
    for step_idx in range(n_steps):
        svi.step()
        if not debug:
            continue

        alpha_error = param_abs_error("alpha_q_log", torch.log(self.alpha0))
        beta_error = param_abs_error("beta_q_log", torch.log(self.beta0))
        with torch.no_grad():
            if step_idx == 0:
                # seed the running averages on the very first step
                first_ll, first_penalty = loss._differentiable_loss_parts(model, guide)
                avg_loglikelihood = torch_item(first_ll)
                avg_penalty = torch_item(first_penalty)
            loglikelihood, penalty = loss._differentiable_loss_parts(model, guide)
            avg_loglikelihood = (decay * avg_loglikelihood
                                 + (1 - decay) * torch_item(loglikelihood))
            avg_penalty = decay * avg_penalty + (1 - decay) * torch_item(penalty)
        if step_idx % 100 == 0:
            print(alpha_error, beta_error)
            print(avg_loglikelihood, avg_penalty)
            print()

    alpha_error = param_abs_error("alpha_q_log", torch.log(self.alpha0))
    beta_error = param_abs_error("beta_q_log", torch.log(self.beta0))
    assert_equal(0.0, alpha_error, prec=0.08)
    assert_equal(0.0, beta_error, prec=0.08)
def do_fit_prior_test(self, reparameterized, n_steps, loss, debug=False, lr=0.0002):
    """
    Run SVI with a Gamma model/guide pair and check the fitted parameters.

    The model samples ``lambda_latent`` from a Gamma built from
    ``self.alpha0`` and ``self.beta0``; the guide samples the same site from
    a Gamma with learnable, positivity-constrained ``alpha_q`` and
    ``beta_q``. After ``n_steps`` optimizer steps the parameters are
    compared against ``self.alpha0`` (``prec=0.2``) and ``self.beta0``
    (``prec=0.15``).

    :param bool reparameterized: use ``dist.Gamma`` when true,
        ``fakes.NonreparameterizedGamma`` otherwise (for model and guide).
    :param int n_steps: number of ``svi.step()`` calls.
    :param loss: loss object handed to ``SVI``; with ``debug`` set it must
        also offer ``_differentiable_loss_parts(model, guide, (), {})``.
    :param bool debug: when true, keep exponential running averages of the
        two loss parts and print them every 100 steps.
    :param float lr: learning rate given to Adam.
    """
    pyro.clear_param_store()
    if reparameterized:
        Gamma = dist.Gamma
    else:
        Gamma = fakes.NonreparameterizedGamma

    def tile(value):
        # wrap in a length-1 event dimension, then repeat once per plate element
        return torch.stack([torch.stack([value])] * self.sample_batch_size)

    def model():
        with pyro.plate('samples', self.sample_batch_size):
            pyro.sample(
                "lambda_latent",
                Gamma(tile(self.alpha0), tile(self.beta0)).to_event(1),
            )

    def guide():
        alpha_q = pyro.param(
            "alpha_q",
            self.alpha0.detach() + math.exp(0.17),
            constraint=constraints.positive,
        )
        beta_q = pyro.param(
            "beta_q",
            self.beta0.detach() / math.exp(0.143),
            constraint=constraints.positive,
        )
        with pyro.plate('samples', self.sample_batch_size):
            pyro.sample(
                "lambda_latent",
                Gamma(tile(alpha_q), tile(beta_q)).to_event(1),
            )

    adam = optim.Adam({"lr": lr, "betas": (0.97, 0.999)})
    svi = SVI(model, guide, adam, loss)

    decay = 0.99
    for step_idx in range(n_steps):
        svi.step()
        if not debug:
            continue

        alpha_error = param_mse("alpha_q", self.alpha0)
        beta_error = param_mse("beta_q", self.beta0)
        with torch.no_grad():
            if step_idx == 0:
                # seed the running averages on the very first step
                first_ll, first_penalty = loss._differentiable_loss_parts(
                    model, guide, (), {})
                avg_loglikelihood = torch_item(first_ll)
                avg_penalty = torch_item(first_penalty)
            loglikelihood, penalty = loss._differentiable_loss_parts(
                model, guide, (), {})
            avg_loglikelihood = (decay * avg_loglikelihood
                                 + (1 - decay) * torch_item(loglikelihood))
            avg_penalty = decay * avg_penalty + (1 - decay) * torch_item(penalty)
        if step_idx % 100 == 0:
            print(alpha_error, beta_error)
            print(avg_loglikelihood, avg_penalty)
            print()

    alpha_msg = '{} vs {}'.format(
        pyro.param("alpha_q").detach().cpu().numpy(),
        self.alpha0.detach().cpu().numpy(),
    )
    assert_equal(pyro.param("alpha_q"), self.alpha0, prec=0.2, msg=alpha_msg)

    beta_msg = '{} vs {}'.format(
        pyro.param("beta_q").detach().cpu().numpy(),
        self.beta0.detach().cpu().numpy(),
    )
    assert_equal(pyro.param("beta_q"), self.beta0, prec=0.15, msg=beta_msg)
def loss(self, model, guide, *args, **kwargs):
    """
    :returns: the negated ELBO estimate (``-elbo``)
    :rtype: float

    Averages, over the trace pairs yielded by ``self._get_traces``, the
    difference between the model's and the guide's ``log_prob_sum()``.
    ``args`` and ``kwargs`` are forwarded to ``_get_traces`` as a tuple and
    a dict respectively.
    """
    running_elbo = 0.0
    trace_pairs = self._get_traces(model, guide, args, kwargs)
    for model_trace, guide_trace in trace_pairs:
        model_lp = torch_item(model_trace.log_prob_sum())
        guide_lp = torch_item(guide_trace.log_prob_sum())
        running_elbo += (model_lp - guide_lp) / float(self.num_particles)

    negated = -running_elbo
    warn_if_nan(negated, "loss")
    return negated
def _differentiable_loss_particle(self, model_trace, guide_trace):
    """
    Compute the loss and surrogate loss of a single trace pair.

    Observed model sites contribute their ``log_prob_sum``. Latent model
    sites contribute ``-KL(guide || model)`` when ``kl_divergence`` is
    implemented for the pair of distributions, and otherwise the model's
    ``log_prob_sum`` minus the guide's entropy term. Guide-only sample sites
    must be flagged ``is_auxiliary`` and contribute minus their entropy term.

    :param model_trace: trace of the model.
    :param guide_trace: trace of the guide.
    :returns: ``(loss, surrogate_loss)``, where ``surrogate_loss`` is the
        negated accumulated ELBO and ``loss`` is its ``torch_item`` value
        (or a detached tensor while a tracing state is active).
    :rtype: tuple
    """
    elbo = 0

    for site_name, model_site in model_trace.nodes.items():
        if model_site["type"] != "sample":
            continue

        if model_site["is_observed"]:
            elbo = elbo + model_site["log_prob_sum"]
            continue

        guide_site = guide_trace.nodes[site_name]
        if is_validation_enabled():
            check_fully_reparametrized(guide_site)

        # use kl divergence if available, else fall back on sampling
        try:
            kl_qp = kl_divergence(guide_site["fn"], model_site["fn"])
            kl_qp = scale_and_mask(
                kl_qp, scale=guide_site["scale"], mask=guide_site["mask"])
            assert kl_qp.shape == guide_site["fn"].batch_shape
            elbo = elbo - kl_qp.sum()
        except NotImplementedError:
            entropy = guide_site["score_parts"].entropy_term
            elbo = elbo + model_site["log_prob_sum"] - entropy.sum()

    # handle auxiliary sites in the guide
    for site_name, guide_site in guide_trace.nodes.items():
        if guide_site["type"] != "sample" or site_name in model_trace.nodes:
            continue
        assert guide_site["infer"].get("is_auxiliary")
        if is_validation_enabled():
            check_fully_reparametrized(guide_site)
        entropy = guide_site["score_parts"].entropy_term
        elbo = elbo - entropy.sum()

    if torch._C._get_tracing_state():
        reported = elbo.detach()
    else:
        reported = torch_item(elbo)
    loss = -reported
    surrogate_loss = -elbo
    return loss, surrogate_loss
def loss_and_grads(self, model, guide, *args, **kwargs):
    """
    :returns: the ``torch_item`` value of the differentiable loss
    :rtype: float

    Evaluates ``self.differentiable_loss``, calls ``backward()`` on the
    result and warns if the reported value is NaN.
    """
    differentiable = self.differentiable_loss(model, guide, *args, **kwargs)
    differentiable.backward()

    value = torch_item(differentiable)
    warn_if_nan(value, "loss")
    return value
def step(self, *args, **kwargs):
    """
    :returns: estimate of the loss; a tuple of estimates when
        ``loss_and_grads`` returns a tuple
    :rtype: float or tuple

    Take a gradient step on the loss function (and any auxiliary loss
    functions generated under the hood by `loss_and_grads`).
    Any args or kwargs are passed to the model and guide.
    """
    # run the loss while recording every param site that gets touched
    with poutine.trace(param_only=True) as param_capture:
        loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)

    captured_sites = param_capture.trace.nodes.values()
    params = {site["value"].unconstrained() for site in captured_sites}

    # torch.optim objects get instantiated for any params not seen before
    self.optim(params)

    # clear gradients so that the next step starts fresh
    pyro.infer.util.zero_grads(params)

    if not isinstance(loss, tuple):
        return torch_item(loss)
    # Support losses that return a tuple, e.g. ReweightedWakeSleep.
    return type(loss)(torch_item(part) for part in loss)
def train(gpmodule, optimizer=None, loss_fn=None, retain_graph=None, num_steps=1000):
    """
    A helper to optimize parameters for a GP module.

    :param ~pyro.contrib.gp.models.GPModel gpmodule: A GP module.
    :param ~torch.optim.Optimizer optimizer: A PyTorch optimizer instance.
        By default, we use Adam with ``lr=0.01``.
    :param callable loss_fn: A loss function which takes
        ``gpmodule.model`` and ``gpmodule.guide`` as inputs and returns the
        ELBO loss. By default,
        ``loss_fn=TraceMeanField_ELBO().differentiable_loss``.
    :param bool retain_graph: An optional flag of ``torch.autograd.backward``.
    :param int num_steps: Number of steps to run SVI.
    :returns: a list of losses during the training procedure
    :rtype: list
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(gpmodule.parameters(), lr=0.01)
    # TODO: add support for JIT loss
    if loss_fn is None:
        loss_fn = TraceMeanField_ELBO().differentiable_loss

    def closure():
        # re-evaluated by the optimizer on every step
        optimizer.zero_grad()
        current = loss_fn(gpmodule.model, gpmodule.guide)
        torch_backward(current, retain_graph)
        return current

    return [torch_item(optimizer.step(closure)) for _ in range(num_steps)]
def loss(self, model, guide, *args, **kwargs):
    """
    :returns: the negated ELBO estimate (``-elbo``)
    :rtype: float

    Averages, over the trace pairs yielded by ``self._get_traces``, the
    difference between the model's and the guide's ``log_prob_sum()``.
    A warning is issued if the result is NaN.
    """
    running_elbo = 0.0
    for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
        model_lp = torch_item(model_trace.log_prob_sum())
        guide_lp = torch_item(guide_trace.log_prob_sum())
        running_elbo += (model_lp - guide_lp) / self.num_particles

    negated = -running_elbo
    if torch_isnan(negated):
        warnings.warn('Encountered NAN loss')
    return negated
def loss(self, model, guide, *args, **kwargs):
    """
    :returns: the negated ELBO estimate (``-elbo``)
    :rtype: float

    ``self._get_traces`` yields ``(weight, model_trace, guide_trace)``
    triples; each triple adds ``weight`` times the difference between the
    model's and the guide's ``log_prob_sum()``. A warning is issued if the
    result is NaN.
    """
    running_elbo = 0.0
    for weight, model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
        model_lp = torch_item(model_trace.log_prob_sum())
        guide_lp = torch_item(guide_trace.log_prob_sum())
        running_elbo += weight * (model_lp - guide_lp)

    negated = -running_elbo
    if torch_isnan(negated):
        warnings.warn('Encountered NAN loss')
    return negated
def evaluate_loss(self, *args, **kwargs):
    """
    :returns: estimate of the loss
    :rtype: float

    Evaluate ``self.loss`` on ``self.model`` and ``self.guide`` with
    gradient tracking disabled. Any args or kwargs are passed to the model
    and guide.
    """
    with torch.no_grad():
        raw_loss = self.loss(self.model, self.guide, *args, **kwargs)
        return torch_item(raw_loss)
def step(self, *args, **kwargs):
    """
    :returns: estimate of the loss
    :rtype: float

    Take a gradient step on the loss function. Param sites that carry a
    non-``None`` ``'free'`` entry have their gradient multiplied by that
    entry before the optimizer runs. Any args or kwargs are passed to the
    model and guide.
    """
    # run the loss while recording every param site that gets touched
    with poutine.trace(param_only=True) as param_capture:
        loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)

    params = []
    for site in param_capture.trace.nodes.values():
        param = site["value"].unconstrained()
        free_mask = site.get('free')
        # A captured parameter that did not take part in the loss has no
        # gradient; multiplying the mask with ``None`` would raise TypeError.
        if free_mask is not None and param.grad is not None:
            param.grad = free_mask * param.grad
        params.append(param)

    self.optim(params)
    pyro.infer.util.zero_grads(params)
    return torch_item(loss)
def differentiable_loss(self, model, guide, *args, **kwargs):
    """
    Computes the surrogate loss that can be differentiated with autograd
    to produce gradient estimates for the model and guide parameters.

    The returned value equals the averaged per-particle loss numerically,
    because the surrogate enters as ``surrogate - torch_item(surrogate)``.
    ``args`` and ``kwargs`` are forwarded to ``self._get_traces`` as a tuple
    and a dict respectively.
    """
    total_loss = 0.
    total_surrogate = 0.

    trace_pairs = self._get_traces(model, guide, args, kwargs)
    for model_trace, guide_trace in trace_pairs:
        particle_loss, particle_surrogate = self._differentiable_loss_particle(
            model_trace, guide_trace)
        total_surrogate += particle_surrogate / self.num_particles
        total_loss += particle_loss / self.num_particles

    warn_if_nan(total_surrogate, "loss")
    surrogate_offset = total_surrogate - torch_item(total_surrogate)
    return total_loss + surrogate_offset
def evaluate_loss(self, *args, **kwargs):
    """
    :returns: estimate of the loss; a tuple of estimates when ``self.loss``
        returns a tuple
    :rtype: float or tuple

    Evaluate the loss function with gradient tracking disabled. Any args or
    kwargs are passed to the model and guide.
    """
    with torch.no_grad():
        raw = self.loss(self.model, self.guide, *args, **kwargs)
        if not isinstance(raw, tuple):
            return torch_item(raw)
        # Support losses that return a tuple, e.g. ReweightedWakeSleep.
        return type(raw)(torch_item(part) for part in raw)
def custom_step(self, *batch):
    """
    Take one optimizer step on the differentiable loss plus
    ``self.beta_l1_loss()``.

    :param batch: positional arguments forwarded to the model and guide.
    :returns: ``torch_item`` value of the combined loss
    """
    # record every param site touched while the two loss terms are evaluated
    with poutine.trace(param_only=True) as param_capture:
        main_term = self.loss.differentiable_loss(self.model, self.guide, *batch)
        penalty_term = self.beta_l1_loss()
        total = main_term + penalty_term

    params = {
        site["value"].unconstrained()
        for site in param_capture.trace.nodes.values()
    }

    total.backward()
    self.optimizer(params)
    pyro.infer.util.zero_grads(params)
    return torch_item(total)
def loss_and_grads(self, grads, batch, *args, **kwargs):
    """
    :returns: an estimate of the loss (expectation over p(x, y) of
        -log q(x, y) ) - where p is the model and q is the guide
    :rtype: float

    If a batch is provided, the loss is estimated using these traces.
    Otherwise, a fresh batch of ``self.training_batch_size`` traces is drawn
    lazily from ``self._sample_from_joint``.

    If grads is True, gradients of each particle's loss are computed with
    ``torch.autograd.grad`` and accumulated into the ``.grad`` attribute of
    the captured unconstrained parameters.

    `args` and `kwargs` are passed to the model and guide.
    """
    if batch is not None:
        batch_size = len(batch)
    else:
        # generator: model traces are sampled one at a time as the loop runs
        batch = (
            self._sample_from_joint(*args, **kwargs)
            for _ in range(self.training_batch_size)
        )
        batch_size = self.training_batch_size

    total = 0
    for model_trace in batch:
        with poutine.trace(param_only=True) as particle_param_capture:
            guide_trace = self._get_matched_trace(model_trace, *args, **kwargs)
        particle_loss = self._differentiable_loss_particle(guide_trace)
        particle_loss /= batch_size

        if grads:
            guide_params = {
                site["value"].unconstrained()
                for site in particle_param_capture.trace.nodes.values()
            }
            guide_grads = torch.autograd.grad(
                particle_loss, guide_params, allow_unused=True
            )
            for grad, param in zip(guide_grads, guide_params):
                if grad is None:
                    # parameter did not contribute to this particle
                    continue
                if param.grad is None:
                    param.grad = grad
                else:
                    param.grad = param.grad + grad

        total += torch_item(particle_loss)

    warn_if_nan(total, "loss")
    return total
def loss(self, model, guide, *args, **kwargs):
    """
    :returns: returns an estimate of the MMD-VAE-type loss [1]
    :rtype: float

    Evaluates ``self.differentiable_loss`` and reports its ``torch_item``
    value. Computes the MMD-VAE-type loss with an estimator that uses
    num_particles many samples/particles.

    References

    [1] `A Tutorial on Information Maximizing Variational Autoencoders (InfoVAE)`
        Shengjia Zhao
        https://ermongroup.github.io/blog/a-tutorial-on-mmd-variational-autoencoders/
    """
    return torch_item(self.differentiable_loss(model, guide, *args, **kwargs))
def evaluate(self, raw_expr, encoded_expr, read_depth):
    """
    Sum the loss over all evaluation batches.

    Batches come from ``self.epoch_batch`` (``batch_size=512``, ``bar=False``)
    and each batch's loss is evaluated with gradient tracking disabled.

    :returns: sum of the per-batch ``torch_item`` loss values
    """
    per_batch = []
    batches = self.epoch_batch(
        raw_expr, encoded_expr, read_depth, batch_size=512, bar=False)
    for batch in batches:
        with torch.no_grad():
            batch_loss = self.loss.loss(self.model, self.guide, *batch)
            per_batch.append(torch_item(batch_loss))

    # loss is negative log-likelihood
    return np.array(per_batch).sum()
def loss_and_grads(self, model, guide, *args, **kwargs):
    """
    :returns: the negated ELBO estimate (``-elbo``)
    :rtype: float

    Computes the ELBO as well as the surrogate loss that is used to form
    the gradient estimator, and performs backward on the latter (honoring
    ``self.retain_graph``). Num_particle many samples are used to form the
    estimators. If baselines are present, a baseline loss is also
    constructed and differentiated.
    """
    elbo_estimate, surrogate = self._loss_and_surrogate_loss(
        model, guide, args, kwargs)

    torch_backward(surrogate, retain_graph=self.retain_graph)

    negated = -torch_item(elbo_estimate)
    warn_if_nan(negated, "loss")
    return negated
def _differentiable_loss_particle(self, model_trace, guide_trace):
    """
    Parent-class particle loss with extra ``log q`` terms for observed
    guide sites.

    Observed sample sites are removed from a copy of the guide trace before
    delegating to the parent implementation; afterwards ``100`` times each
    removed site's ``log_prob_sum`` is subtracted from both returned values.

    :returns: ``(loss, surrogate_loss)``
    :rtype: tuple
    """
    log_q_weight = 100

    # Construct -ELBO part.
    observed = []
    for site_name, site in guide_trace.nodes.items():
        if site["type"] == "sample" and site["is_observed"]:
            observed.append(site_name)

    trimmed_guide = guide_trace.copy()
    for site_name in observed:
        del trimmed_guide.nodes[site_name]

    loss, surrogate_loss = super()._differentiable_loss_particle(
        model_trace, trimmed_guide)

    # Add log q terms.
    for site_name in observed:
        log_q = guide_trace.nodes[site_name]["log_prob_sum"]
        loss = loss - log_q_weight * torch_item(log_q)
        surrogate_loss = surrogate_loss - log_q_weight * log_q

    return loss, surrogate_loss
def loss_and_grads(self, model, guide, *args, **kwargs):
    """
    :returns: returns an estimate of the MMD-VAE-type loss [1]
    :rtype: float

    Computes the MMD-VAE-type loss and performs backward on it (honoring
    ``self.retain_graph``). Leads to valid gradient estimates as long as
    latent variables in both the guide and the model are reparameterizable.
    Num_particles many samples are used to form the estimators.

    References

    [1] `A Tutorial on Information Maximizing Variational Autoencoders (InfoVAE)`
        Shengjia Zhao
        https://ermongroup.github.io/blog/a-tutorial-on-mmd-variational-autoencoders/
    """
    differentiable = self.differentiable_loss(model, guide, *args, **kwargs)
    differentiable.backward(retain_graph=self.retain_graph)
    return torch_item(differentiable)
def step(self, *args, **kwargs):
    """
    :returns: estimate of the loss
    :rtype: float

    Take a gradient step on the loss function. Arguments are passed to the
    model and guide.

    Only parameters that actually received a gradient are handed to the
    optimizer. The check is made on the unconstrained tensor, because that
    is the tensor passed to the optimizer and to ``zero_grads`` below.
    """
    with poutine.trace(param_only=True) as param_capture:
        loss = self.loss_and_grads(True, None, *args, **kwargs)

    params = set()
    for site in param_capture.trace.nodes.values():
        unconstrained = site["value"].unconstrained()
        # Inspect the tensor that is optimized; the constrained value may be
        # a derived tensor whose ``.grad`` is never populated.
        if unconstrained.grad is not None:
            params.add(unconstrained)

    self.optim(params)
    pyro.infer.util.zero_grads(params)
    return torch_item(loss)
def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
    """
    Compute the weighted loss of one particle and backpropagate through it.

    :param weight: factor applied to the backpropagated loss and to the
        returned value.
    :param model_trace: trace of the model.
    :param guide_trace: trace of the guide.
    :returns: ``weight * (-elbo)`` where the ELBO is reported via
        ``torch_item``
    """
    # have the traces compute all the individual (batch) log pdf terms
    # and score function terms (if present) so that they are available below
    model_trace.compute_log_prob()
    guide_trace.compute_score_parts()

    if is_validation_enabled():
        for trace in (model_trace, guide_trace):
            for node in trace.nodes.values():
                if node["type"] == "sample":
                    check_site_shape(node, self.max_iarange_nesting)

    # compute elbo for reparameterized nodes
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    elbo, surrogate_elbo = _compute_elbo_reparam(
        model_trace, guide_trace, non_reparam_nodes)

    # the following computations are only necessary if we have
    # non-reparameterizable nodes
    baseline_loss = 0.0
    if non_reparam_nodes:
        downstream_costs, _ = _compute_downstream_costs(
            model_trace, guide_trace, non_reparam_nodes)
        extra_surrogate, baseline_loss = _compute_elbo_non_reparam(
            guide_trace, non_reparam_nodes, downstream_costs)
        surrogate_elbo += extra_surrogate

    # backpropagate only if the model or guide has at least one param site
    has_params = any(
        node["type"] == "param"
        for trace in (model_trace, guide_trace)
        for node in trace.nodes.values()
    )
    if has_params:
        torch_backward(weight * (-surrogate_elbo + baseline_loss))

    loss = -torch_item(elbo)
    if torch_isnan(loss):
        warnings.warn('Encountered NAN loss')
    return weight * loss
def step(self, *args, **kwargs):
    """
    Take one optimizer step on ``self.loss_fn`` plus ``self.compute_l2_term``.

    After backpropagation the guide is traced once more to collect the
    unconstrained tensors of its param sites, which are then passed to the
    optimizer and have their gradients zeroed.

    :returns: ``torch_item`` value of the combined loss
    """
    # compute loss
    objective = self.loss_fn(self.model, self.guide, *args, **kwargs)
    penalty = self.compute_l2_term(*args, **kwargs)
    total = objective + penalty
    total.backward()

    guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
    params = []
    for site in guide_trace.nodes.values():
        if site['type'] == 'param':
            params.append(site['value'].unconstrained())

    # Copied from Pyro SVI source code
    # actually perform gradient steps
    # torch.optim objects gets instantiated for any params that haven't been seen yet
    self.optimizer(params)

    # zero gradients
    zero_grads(params)

    return torch_item(total)
def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
    """
    Compute the weighted loss of one particle and backpropagate through it.

    :param weight: factor applied to the backpropagated loss and to the
        returned value.
    :param model_trace: trace of the model.
    :param guide_trace: trace of the guide.
    :returns: ``weight * (-elbo)`` where the ELBO is reported via
        ``torch_item``
    """
    # have the traces compute all the individual (batch) log pdf terms
    # and score function terms (if present) so that they are available below
    model_trace.compute_log_prob()
    guide_trace.compute_score_parts()

    if is_validation_enabled():
        model_samples = [
            node for node in model_trace.nodes.values()
            if node["type"] == "sample"
        ]
        for node in model_samples:
            check_site_shape(node, self.max_iarange_nesting)
        guide_samples = [
            node for node in guide_trace.nodes.values()
            if node["type"] == "sample"
        ]
        for node in guide_samples:
            check_site_shape(node, self.max_iarange_nesting)

    # compute elbo for reparameterized nodes
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    elbo, surrogate_elbo = _compute_elbo_reparam(
        model_trace, guide_trace, non_reparam_nodes)

    # the following computations are only necessary if we have
    # non-reparameterizable nodes
    baseline_loss = 0.0
    if non_reparam_nodes:
        downstream_costs, _ = _compute_downstream_costs(
            model_trace, guide_trace, non_reparam_nodes)
        non_reparam_term, baseline_loss = _compute_elbo_non_reparam(
            guide_trace, non_reparam_nodes, downstream_costs)
        surrogate_elbo += non_reparam_term

    # collect parameters to train from model and guide
    trainable = False
    for trace in (model_trace, guide_trace):
        if any(node["type"] == "param" for node in trace.nodes.values()):
            trainable = True
            break

    if trainable:
        surrogate_loss = -surrogate_elbo
        torch_backward(weight * (surrogate_loss + baseline_loss))

    negated_elbo = -torch_item(elbo)
    if torch_isnan(negated_elbo):
        warnings.warn('Encountered NAN loss')
    return weight * negated_elbo
def step(self, *args, **kwargs):
    """
    :returns: estimate of the loss
    :rtype: float

    Take a gradient step on the loss function (and any auxiliary loss
    functions generated under the hood by `loss_and_grads`).
    Any args or kwargs are passed to the model and guide.
    """
    # run the loss while recording every param site that gets touched
    with poutine.trace(param_only=True) as param_capture:
        loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)

    captured_sites = param_capture.trace.nodes.values()
    params = {site["value"].unconstrained() for site in captured_sites}

    # torch.optim objects get instantiated for any params not seen before
    self.optim(params)

    # clear gradients so that the next step starts fresh
    pyro.infer.util.zero_grads(params)

    return torch_item(loss)
def loss(self, model, guide, *args, **kwargs):
    """
    :returns: the ``torch_item`` value of ``self.differentiable_loss``

    All arguments are forwarded unchanged to ``differentiable_loss``.
    """
    differentiable = self.differentiable_loss(model, guide, *args, **kwargs)
    return torch_item(differentiable)
def loss_and_grads(self, model, guide, *args, **kwargs):
    """
    :returns: ``(loss, dyn_loss, dim_loss, man)`` where ``loss`` is the
        negated ELBO estimate, ``dyn_loss`` / ``dim_loss`` are the summed
        regularizer values over all particles, and ``man`` is the centered
        stack of ``y_*`` guide values of the last particle (``None`` if no
        trace was produced)
    :rtype: tuple

    Computes the ELBO as well as the surrogate ELBO that is used to form
    the gradient estimator. Performs backward on the latter, augmented by
    ``self.lam`` times the log-det term and ``self.gam`` times the trace
    term of the current particle. Num_particle many samples are used to
    form the estimators.
    """
    # TODO: add argument lambda --> assigns weights to losses
    # TODO: Normalize loss elbo value if not done
    elbo = 0.0
    dyn_loss = 0.0
    dim_loss = 0.0
    man = None  # stays None when the trace generator yields nothing

    # grab a trace from the generator
    for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
        elbo_particle = 0
        surrogate_elbo_particle = 0
        log_r = None
        ys = []

        # compute elbo and surrogate elbo
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample":
                elbo_particle = elbo_particle + torch_item(site["log_prob_sum"])
                surrogate_elbo_particle = surrogate_elbo_particle + site["log_prob_sum"]

        for name, site in guide_trace.nodes.items():
            if site["type"] != "sample":
                continue

            log_prob, score_function_term, entropy_term = site["score_parts"]

            elbo_particle = elbo_particle - torch_item(site["log_prob_sum"])

            if not is_identically_zero(entropy_term):
                surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum()

            if not is_identically_zero(score_function_term):
                if log_r is None:
                    log_r = _compute_log_r(model_trace, guide_trace)
                # Use a separate name: rebinding ``site`` here would break
                # the ``site["name"]`` lookup below.
                site_log_r = log_r.sum_to(site["cond_indep_stack"])
                surrogate_elbo_particle = surrogate_elbo_particle + (
                    site_log_r * score_function_term).sum()

            if site["name"].startswith("y_"):
                # TODO: check order of y
                ys.append(site["value"])

        man = torch.stack(ys, dim=1)
        man = man - man.mean(dim=1, keepdims=True)

        dyn_loss_particle = self._get_logdet_loss(man, delta=self.delta)  # TODO: Normalize
        dim_loss_particle = self._get_traceK_loss(man)
        # Totals are kept for reporting only; out-of-place adds leave each
        # particle's own graph untouched.
        dyn_loss = dyn_loss + dyn_loss_particle
        dim_loss = dim_loss + dim_loss_particle

        elbo += elbo_particle / self.num_particles

        # collect parameters to train from model and guide
        trainable_params = any(site["type"] == "param"
                               for trace in (model_trace, guide_trace)
                               for site in trace.nodes.values())

        if trainable_params and getattr(surrogate_elbo_particle, 'requires_grad', False):
            # Differentiate only this particle's regularizers: going through
            # the running totals would re-traverse graphs of earlier
            # particles that have already been backpropagated.
            surrogate_loss_particle = (
                -surrogate_elbo_particle / self.num_particles
                + self.lam * dyn_loss_particle
                + self.gam * dim_loss_particle
            )
            surrogate_loss_particle.backward()

    loss = -elbo
    if torch_isnan(loss):
        warnings.warn('Encountered NAN loss')
    return loss, torch_item(dyn_loss), torch_item(dim_loss), man
def loss_and_grads(self, model, guide, *args, **kwargs):
    """
    :returns: the ``torch_item`` value of ``self._loss``

    Evaluates ``self._loss`` (``args`` and ``kwargs`` are forwarded as a
    tuple and a dict), backpropagates through it honoring
    ``self.retain_graph`` and warns if the reported value is NaN.
    """
    differentiable = self._loss(model, guide, args, kwargs)
    torch_backward(differentiable, retain_graph=self.retain_graph)

    value = torch_item(differentiable)
    warn_if_nan(value, "loss")
    return value