def __init__(self, guide_trace, ordering):
    log_denom = defaultdict(float)  # avoids double-counting when sequentially enumerating
    log_probs = defaultdict(list)  # accounts for upstream probabilities

    for name, site in guide_trace.nodes.items():
        if site["type"] != "sample":
            continue

        log_prob = site["packed"]["score_parts"].score_function  # not scaled by subsampling
        dims = getattr(log_prob, "_pyro_dims", "")
        ordinal = ordering[name]
        if site["infer"].get("enumerate"):
            num_samples = site["infer"].get("num_samples")
            if num_samples is not None:  # site was multiply sampled
                if not is_identically_zero(log_prob):
                    log_prob = log_prob - log_prob.detach()
                log_prob = log_prob - math.log(num_samples)
                if not isinstance(log_prob, torch.Tensor):
                    log_prob = site["value"].new_tensor(log_prob)
                log_prob._pyro_dims = dims
                # I don't know why the following broadcast is needed, but it makes tests pass:
                log_prob, _ = packed.broadcast_all(log_prob, site["packed"]["log_prob"])
            elif site["infer"]["enumerate"] == "sequential":
                log_denom[ordinal] += math.log(site["infer"]["_enum_total"])
        else:  # site was monte carlo sampled
            if is_identically_zero(log_prob):
                continue
            log_prob = log_prob - log_prob.detach()
            log_prob._pyro_dims = dims
        log_probs[ordinal].append(log_prob)

    self.log_denom = log_denom
    self.log_probs = log_probs
def _differentiable_loss_particle(self, model_trace, guide_trace):
    elbo_particle = 0
    surrogate_elbo_particle = 0
    log_r = None

    # compute elbo and surrogate elbo
    for name, site in model_trace.nodes.items():
        if site["type"] == "sample":
            elbo_particle = elbo_particle + torch_item(site["log_prob_sum"])
            surrogate_elbo_particle = surrogate_elbo_particle + site["log_prob_sum"]

    for name, site in guide_trace.nodes.items():
        if site["type"] == "sample":
            log_prob, score_function_term, entropy_term = site["score_parts"]
            elbo_particle = elbo_particle - torch_item(site["log_prob_sum"])

            if not is_identically_zero(entropy_term):
                surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum()

            if not is_identically_zero(score_function_term):
                if log_r is None:
                    log_r = _compute_log_r(model_trace, guide_trace)
                site = log_r.sum_to(site["cond_indep_stack"])
                surrogate_elbo_particle = surrogate_elbo_particle + (site * score_function_term).sum()

    return -elbo_particle, -surrogate_elbo_particle
def compute_site_dice_factor(site):
    log_denom = 0
    log_prob = site["packed"]["score_parts"].score_function  # not scaled by subsampling
    dims = getattr(log_prob, "_pyro_dims", "")
    if site["infer"].get("enumerate"):
        num_samples = site["infer"].get("num_samples")
        if num_samples is not None:  # site was multiply sampled
            if not is_identically_zero(log_prob):
                log_prob = log_prob - log_prob.detach()
            log_prob = log_prob - math.log(num_samples)
            if not isinstance(log_prob, torch.Tensor):
                log_prob = torch.tensor(float(log_prob), device=site["value"].device)
            log_prob._pyro_dims = dims
            # I don't know why the following broadcast is needed, but it makes tests pass:
            log_prob, _ = packed.broadcast_all(log_prob, site["packed"]["log_prob"])
        elif site["infer"]["enumerate"] == "sequential":
            log_denom = math.log(site["infer"].get("_enum_total", num_samples))
    else:  # site was monte carlo sampled
        if not is_identically_zero(log_prob):
            log_prob = log_prob - log_prob.detach()
            log_prob._pyro_dims = dims

    return log_prob, log_denom
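# For intuition, the `log_prob - log_prob.detach()` trick above is the heart of the
# DiCE estimator: the factor exponentiates to exactly 1.0 in the forward pass, but
# its gradient is the score function d/dtheta log q(z). A minimal self-contained
# PyTorch sketch (illustrative names, not part of the library):
import torch

logits = torch.tensor(0.3, requires_grad=True)
q = torch.distributions.Bernoulli(logits=logits)
z = q.sample()
log_q = q.log_prob(z)

dice_factor = (log_q - log_q.detach()).exp()
assert dice_factor.item() == 1.0  # identically 1 in the forward pass

cost = (z - 0.5) ** 2  # some downstream cost that depends on z
(dice_factor * cost).backward()
print(logits.grad)  # equals cost * d/dlogits log q(z), the score-function estimate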
def loss_and_grads(self, model, guide, *args, **kwargs): """ :returns: returns an estimate of the ELBO :rtype: float Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator. Performs backward on the latter. Num_particle many samples are used to form the estimators. """ elbo = 0.0 # grab a trace from the generator for model_trace, guide_trace in self._get_traces( model, guide, *args, **kwargs): elbo_particle = 0 surrogate_elbo_particle = 0 log_r = None # compute elbo and surrogate elbo for name, site in model_trace.nodes.items(): if site["type"] == "sample": elbo_particle = elbo_particle + torch_item( site["log_prob_sum"]) surrogate_elbo_particle = surrogate_elbo_particle + site[ "log_prob_sum"] for name, site in guide_trace.nodes.items(): if site["type"] == "sample": log_prob, score_function_term, entropy_term = site[ "score_parts"] elbo_particle = elbo_particle - torch_item( site["log_prob_sum"]) if not is_identically_zero(entropy_term): surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum( ) if not is_identically_zero(score_function_term): if log_r is None: log_r = _compute_log_r(model_trace, guide_trace) site = log_r.sum_to(site["cond_indep_stack"]) surrogate_elbo_particle = surrogate_elbo_particle + ( site * score_function_term).sum() elbo += elbo_particle / self.num_particles # collect parameters to train from model and guide trainable_params = any(site["type"] == "param" for trace in (model_trace, guide_trace) for site in trace.nodes.values()) if trainable_params and getattr(surrogate_elbo_particle, 'requires_grad', False): surrogate_loss_particle = -surrogate_elbo_particle / self.num_particles surrogate_loss_particle.backward() loss = -elbo if torch_isnan(loss): warnings.warn('Encountered NAN loss') return loss
def check_fully_reparametrized(guide_site):
    log_prob, score_function_term, entropy_term = guide_site["score_parts"]
    fully_rep = (guide_site["fn"].has_rsample
                 and not is_identically_zero(entropy_term)
                 and is_identically_zero(score_function_term))
    if not fully_rep:
        raise NotImplementedError("All distributions in the guide must be fully reparameterized.")
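# For reference, "fully reparameterized" corresponds to torch.distributions'
# `has_rsample` flag: continuous distributions like Normal support pathwise
# gradients via rsample(), while discrete ones like Bernoulli do not and would
# trip the check above. A quick illustration:
import torch.distributions as tdist

print(tdist.Normal(0.0, 1.0).has_rsample)  # True: pathwise gradients available
print(tdist.Bernoulli(0.5).has_rsample)    # False: only score-function gradients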
def _compute_elbo_reparam(model_trace, guide_trace):
    # In ref [1], section 3.2, the part of the surrogate loss computed here is
    # \sum{cost}, which in this case is the ELBO. Instead of using the ELBO,
    # this implementation uses a surrogate ELBO which modifies some entropy
    # terms depending on the parameterization. This reduces the variance of the
    # gradient under some conditions.

    elbo = 0.0
    surrogate_elbo = 0.0

    # Bring log p(x, z|...) terms into both the ELBO and the surrogate
    for name, site in model_trace.nodes.items():
        if site["type"] == "sample":
            elbo += site["log_prob_sum"]
            surrogate_elbo += site["log_prob_sum"]

    # Bring log q(z|...) terms into the ELBO, and effective terms into the
    # surrogate. Depending on the parameterization of a site, its log q(z|...)
    # cost term may not contribute (in expectation) to the gradient. To reduce
    # the variance under some conditions, the default entropy terms from
    # site[`score_parts`] are used.
    for name, site in guide_trace.nodes.items():
        if site["type"] == "sample":
            elbo -= site["log_prob_sum"]
            entropy_term = site["score_parts"].entropy_term

            # For fully reparameterized terms, this entropy_term is log q(z|...)
            # For fully non-reparameterized terms, it is zero
            if not is_identically_zero(entropy_term):
                surrogate_elbo -= entropy_term.sum()

    return elbo, surrogate_elbo
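# To make the reparameterized case concrete: because rsample() routes gradients
# through the sample itself, backprop through `log_p - log_q` alone yields a
# low-variance gradient with no score-function term. A minimal hand-rolled sketch
# in plain PyTorch (illustrative, not the library code):
import torch
import torch.distributions as tdist

loc = torch.tensor(0.5, requires_grad=True)
guide_dist = tdist.Normal(loc, 1.0)

z = guide_dist.rsample()                    # pathwise (reparameterized) sample
log_p = tdist.Normal(2.0, 1.0).log_prob(z)  # model term log p(z)
log_q = guide_dist.log_prob(z)              # guide term log q(z)

elbo = log_p - log_q                        # single-sample ELBO estimate
(-elbo).backward()                          # gradient flows through z via rsample
print(loc.grad)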
def differentiable_loss(self, model, guide, *args, **kwargs):
    """
    :returns: a differentiable estimate of the ELBO
    :rtype: torch.Tensor
    :raises ValueError: if the ELBO is not differentiable (e.g. is
        identically zero)

    Estimates a differentiable ELBO using ``num_particles`` many samples
    (particles). The result should be infinitely differentiable (as long
    as underlying derivatives have been implemented).
    """
    elbo = 0.0
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
        elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
        if is_identically_zero(elbo_particle):
            continue

        elbo = elbo + elbo_particle
    elbo = elbo / self.num_particles

    if not torch.is_tensor(elbo) or not elbo.requires_grad:
        raise ValueError('ELBO cannot be differentiated: {}'.format(elbo))

    loss = -elbo
    warn_if_nan(loss, "loss")
    return loss
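# Typical usage of differentiable_loss (a sketch, assuming the standard Pyro API;
# the model and guide here are illustrative):
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import TraceEnum_ELBO

def model():
    pyro.sample("z", dist.Normal(0.0, 1.0))

def guide():
    loc = pyro.param("loc", torch.tensor(0.0))
    pyro.sample("z", dist.Normal(loc, 1.0))

elbo = TraceEnum_ELBO(num_particles=4)
loss = elbo.differentiable_loss(model, guide)  # a torch.Tensor with a grad_fn
loss.backward()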
def loss_and_grads(self, model, guide, *args, **kwargs): """ :returns: returns an estimate of the ELBO :rtype: float Estimates the ELBO using ``num_particles`` many samples (particles). Performs backward on the ELBO of each particle. """ elbo = 0.0 for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs): elbo_particle = _compute_dice_elbo(model_trace, guide_trace) if is_identically_zero(elbo_particle): continue elbo += elbo_particle.item() / self.num_particles # collect parameters to train from model and guide trainable_params = any(site["type"] == "param" for trace in (model_trace, guide_trace) for site in trace.nodes.values()) if trainable_params and elbo_particle.requires_grad: loss_particle = -elbo_particle (loss_particle / self.num_particles).backward(retain_graph=True) loss = -elbo if torch_isnan(loss): warnings.warn('Encountered NAN loss') return loss
def loss_and_grads(self, model, guide, *args, **kwargs): """ :returns: an estimate of the ELBO :rtype: float Estimates the ELBO using ``num_particles`` many samples (particles). Performs backward on the ELBO of each particle. """ elbo = 0.0 for model_trace, guide_trace in self._get_traces( model, guide, args, kwargs): elbo_particle = _compute_dice_elbo(model_trace, guide_trace) if is_identically_zero(elbo_particle): continue elbo += elbo_particle.item() / self.num_particles # collect parameters to train from model and guide trainable_params = any(site["type"] == "param" for trace in (model_trace, guide_trace) for site in trace.nodes.values()) if trainable_params and elbo_particle.requires_grad: loss_particle = -elbo_particle (loss_particle / self.num_particles).backward(retain_graph=True) loss = -elbo warn_if_nan(loss, "loss") return loss
def loss_and_surrogate_loss(*args, **kwargs):
    kwargs.pop("_pyro_model_id")
    kwargs.pop("_pyro_guide_id")
    self = weakself()
    loss = 0.0
    surrogate_loss = 0.0
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
        elbo_particle = 0
        surrogate_elbo_particle = 0
        log_r = None

        # compute elbo and surrogate elbo
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample":
                elbo_particle = elbo_particle + site["log_prob_sum"]
                surrogate_elbo_particle = surrogate_elbo_particle + site["log_prob_sum"]

        for name, site in guide_trace.nodes.items():
            if site["type"] == "sample":
                log_prob, score_function_term, entropy_term = site["score_parts"]
                elbo_particle = elbo_particle - site["log_prob_sum"]

                if not is_identically_zero(entropy_term):
                    surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum()

                if not is_identically_zero(score_function_term):
                    if log_r is None:
                        log_r = _compute_log_r(model_trace, guide_trace)
                    site = log_r.sum_to(site["cond_indep_stack"])
                    surrogate_elbo_particle = (surrogate_elbo_particle
                                               + (site * score_function_term).sum())

        loss = loss - elbo_particle / self.num_particles
        surrogate_loss = surrogate_loss - surrogate_elbo_particle / self.num_particles

    return loss, surrogate_loss
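# The `weakself()` call above reflects a common pattern for closures handed to
# the JIT: holding only a weak reference to `self` avoids a reference cycle
# between the compiled closure and the ELBO object. A minimal sketch of the
# pattern (illustrative class, not the library code):
import weakref

class Elbo:
    num_particles = 10

    def make_closure(self):
        weakself = weakref.ref(self)  # avoid a self -> closure -> self cycle

        def closure():
            self = weakself()  # resolve the weak reference at call time
            return self.num_particles

        return closure

elbo = Elbo()
closure = elbo.make_closure()
print(closure())  # 10, as long as `elbo` itself is still alive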
def __init__(self, guide_trace, ordering):
    log_denoms = defaultdict(float)  # avoids double-counting when sequentially enumerating
    log_probs = defaultdict(list)  # accounts for upstream probabilities

    for name, site in guide_trace.nodes.items():
        if site["type"] != "sample":
            continue

        ordinal = ordering[name]
        log_prob, log_denom = compute_site_dice_factor(site)
        if not is_identically_zero(log_prob):
            log_probs[ordinal].append(log_prob)
        if not is_identically_zero(log_denom):
            log_denoms[ordinal] += log_denom

    self.log_denom = log_denoms
    self.log_probs = log_probs
def _compute_dice_factors(model_trace, guide_trace):
    """
    Compute per-site DiCE log-factors for non-reparameterized proposal sites.
    This logic is adapted from pyro.infer.util.Dice.__init__.
    """
    log_probs = []
    for role, trace in zip(("model", "guide"), (model_trace, guide_trace)):
        for name, site in trace.nodes.items():
            if site["type"] != "sample" or site["is_observed"]:
                continue
            if role == "model" and name in guide_trace:
                continue

            log_prob, log_denom = compute_site_dice_factor(site)
            if not is_identically_zero(log_denom):
                dims = log_prob._pyro_dims
                log_prob = log_prob - log_denom
                log_prob._pyro_dims = dims
            if not is_identically_zero(log_prob):
                log_probs.append(log_prob)

    return log_probs
def _get_log_factors(self, target_ordinal): """ Returns a list of DiCE factors at a given ordinal. """ log_denom = 0 for ordinal, term in self.log_denom.items(): if not ordinal <= target_ordinal: # not downstream log_denom += term # term = log(# times this ordinal is counted) log_factors = [] if is_identically_zero(log_denom) else [-log_denom] for ordinal, terms in self.log_probs.items(): if ordinal <= target_ordinal: # upstream log_factors.extend(terms) # terms = [log(dice weight of this ordinal)] return log_factors
def loss(self, model, guide, *args, **kwargs): """ :returns: an estimate of the ELBO :rtype: float Estimates the ELBO using ``num_particles`` many samples (particles). """ elbo = 0.0 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): elbo_particle = _compute_dice_elbo(model_trace, guide_trace) if is_identically_zero(elbo_particle): continue elbo += elbo_particle.item() / self.num_particles loss = -elbo warn_if_nan(loss, "loss") return loss
def _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes):
    elbo = 0.0
    surrogate_elbo = 0.0

    # deal with log p(z|...) terms
    for name, site in model_trace.nodes.items():
        if site["type"] == "sample":
            elbo += site["log_prob_sum"]
            surrogate_elbo += site["log_prob_sum"]

    # deal with log q(z|...) terms
    for name, site in guide_trace.nodes.items():
        if site["type"] == "sample":
            elbo -= site["log_prob_sum"]
            entropy_term = site["score_parts"].entropy_term
            if not is_identically_zero(entropy_term):
                surrogate_elbo -= entropy_term.sum()

    return elbo, surrogate_elbo
def loss(self, model, guide, *args, **kwargs): """ :returns: returns an estimate of the ELBO :rtype: float Estimates the ELBO using ``num_particles`` many samples (particles). """ elbo = 0.0 for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs): elbo_particle = _compute_dice_elbo(model_trace, guide_trace) if is_identically_zero(elbo_particle): continue elbo += elbo_particle.item() / self.num_particles loss = -elbo if torch_isnan(loss): warnings.warn('Encountered NAN loss') return loss
def loss(self, model, guide, *args, **kwargs): """ :returns: returns an estimate of the ELBO :rtype: float Estimates the ELBO using ``num_particles`` many samples (particles). """ elbo = 0.0 for model_trace, guide_trace in self._get_traces( model, guide, *args, **kwargs): elbo_particle = _compute_dice_elbo(model_trace, guide_trace) if is_identically_zero(elbo_particle): continue elbo += elbo_particle.item() / self.num_particles loss = -elbo if torch_isnan(loss): warnings.warn('Encountered NAN loss') return loss
def _get_log_factors(self, target_ordinal): """ Returns a list of DiCE factors ordinal. """ # memoize try: return self._log_factors_cache[target_ordinal] except KeyError: pass log_denom = 0 for ordinal, term in self.log_denom.items(): if not ordinal <= target_ordinal: # not downstream log_denom += term # term = log(# times this ordinal is counted) log_factors = [] if is_identically_zero(log_denom) else [-log_denom] for ordinal, term in self.log_probs.items(): if ordinal <= target_ordinal: # upstream log_factors += term # term = [log(dice weight of this ordinal)] self._log_factors_cache[target_ordinal] = log_factors return log_factors
def differentiable_loss(self, model, guide, *args, **kwargs): """ :returns: a differentiable estimate of the marginal log-likelihood :rtype: torch.Tensor :raises ValueError: if the ELBO is not differentiable (e.g. is identically zero) Computes a differentiable TMC estimate using ``num_particles`` many samples (particles). The result should be infinitely differentiable (as long as underlying derivatives have been implemented). """ elbo = 0.0 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): elbo_particle = _compute_tmc_estimate(model_trace, guide_trace) if is_identically_zero(elbo_particle): continue elbo = elbo + elbo_particle elbo = elbo / self.num_particles loss = -elbo warn_if_nan(loss, "loss") return loss
def __init__(self, guide_trace, ordering):
    log_denom = defaultdict(lambda: 0.0)  # avoids double-counting when sequentially enumerating
    log_probs = defaultdict(list)  # accounts for upstream probabilities

    for name, site in guide_trace.nodes.items():
        if site["type"] != "sample":
            continue

        log_prob = site['score_parts'].score_function  # not scaled by subsampling
        if is_identically_zero(log_prob):
            continue

        ordinal = ordering[name]
        if site["infer"].get("enumerate"):
            if site["infer"]["enumerate"] == "sequential":
                log_denom[ordinal] += math.log(site["infer"]["_enum_total"])
        else:  # site was monte carlo sampled
            log_prob = log_prob - log_prob.detach()
        log_probs[ordinal].append(log_prob)

    self.log_denom = log_denom
    self.log_probs = log_probs

    self._log_factors_cache = {}
    self._prob_cache = {}
def get_deltas(self, save_params=None):
    deltas = {}
    aux_values = {}
    compute_density = poutine.get_mask() is not False
    for name, site in self._sorted_sites:
        if save_params is not None and name not in save_params:
            continue

        # Sample zero-mean blockwise independent Delta/Normal/MVN.
        log_density = 0.0
        loc = deep_getattr(self.locs, name)
        zero = torch.zeros_like(loc)
        conditional = self.conditionals[name]
        if callable(conditional):
            aux_value = deep_getattr(self.conds, name)()
        elif conditional == "delta":
            aux_value = zero
        elif conditional == "normal":
            aux_value = pyro.sample(
                name + "_aux",
                dist.Normal(zero, 1).to_event(1),
                infer={"is_auxiliary": True},
            )
            scale = deep_getattr(self.scales, name)
            aux_value = aux_value * scale
            if compute_density:
                log_density = (-scale.log()).expand_as(aux_value)
        elif conditional == "mvn":
            # This overparametrizes by learning (scale,scale_tril),
            # enabling faster learning of the more-global scale parameter.
            aux_value = pyro.sample(
                name + "_aux",
                dist.Normal(zero, 1).to_event(1),
                infer={"is_auxiliary": True},
            )
            scale = deep_getattr(self.scales, name)
            scale_tril = deep_getattr(self.scale_trils, name)
            aux_value = aux_value @ scale_tril.T * scale
            if compute_density:
                log_density = (-scale_tril.diagonal(dim1=-2, dim2=-1).log()
                               - scale.log()).expand_as(aux_value)
        else:
            raise ValueError(f"Unsupported conditional type: {conditional}")

        # Accumulate upstream dependencies.
        # Note: by accumulating upstream dependencies before updating the
        # aux_values dict, we encode a block-sparse structure of the
        # precision matrix; if we had instead accumulated after updating
        # aux_values, we would encode a block-sparse structure of the
        # covariance matrix.
        # Note: these shear transforms have no effect on the Jacobian
        # determinant, and can therefore be excluded from the log_density
        # computation below, even for nonlinear dep().
        deps = deep_getattr(self.deps, name)
        for upstream in self.dependencies.get(name, {}):
            dep = deep_getattr(deps, upstream)
            aux_value = aux_value + dep(aux_values[upstream])
        aux_values[name] = aux_value

        # Shift by loc and reshape.
        batch_shape = torch.broadcast_shapes(aux_value.shape[:-1],
                                             self._batch_shapes[name])
        unconstrained = (aux_value + loc).reshape(
            batch_shape + self._unconstrained_event_shapes[name])
        if not is_identically_zero(log_density):
            log_density = log_density.reshape(batch_shape + (-1,)).sum(-1)

        # Transform to constrained space.
        transform = biject_to(site["fn"].support)
        value = transform(unconstrained)
        if compute_density and conditional != "delta":
            assert transform.codomain.event_dim == site["fn"].event_dim
            log_density = log_density + transform.inv.log_abs_det_jacobian(
                value, unconstrained)

        # Create a reparametrized Delta distribution.
        deltas[name] = dist.Delta(value, log_density, site["fn"].event_dim)

    return deltas
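# For context, get_deltas underlies Pyro's AutoStructured guide; a sketch of
# typical usage (assuming the standard pyro.infer.autoguide API; the model,
# conditionals, and dependencies here are illustrative):
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoStructured

def model():
    a = pyro.sample("a", dist.Normal(0.0, 1.0))
    b = pyro.sample("b", dist.Normal(a, 1.0))
    pyro.sample("x", dist.Normal(b, 1.0), obs=torch.tensor(0.5))

guide = AutoStructured(
    model,
    conditionals={"a": "normal", "b": "mvn"},
    dependencies={"b": {"a": "linear"}},  # b's posterior shifts linearly with a
)
print(guide())  # dict of posterior samples for "a" and "b"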
def loss_and_grads(self, model, guide, *args, **kwargs): """ :returns: returns an estimate of the ELBO :rtype: float Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator. Performs backward on the latter. Num_particle many samples are used to form the estimators. """ elbo_particles = [] surrogate_elbo_particles = [] is_vectorized = self.vectorize_particles and self.num_particles > 1 tensor_holder = None # grab a vectorized trace from the generator for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): elbo_particle = 0 surrogate_elbo_particle = 0 # compute elbo and surrogate elbo for name, site in model_trace.nodes.items(): if site["type"] == "sample": if is_vectorized: log_prob_sum = site["log_prob"].reshape(self.num_particles, -1).sum(-1) else: log_prob_sum = site["log_prob_sum"] elbo_particle = elbo_particle + log_prob_sum.detach() surrogate_elbo_particle = surrogate_elbo_particle + log_prob_sum for name, site in guide_trace.nodes.items(): if site["type"] == "sample": log_prob, score_function_term, entropy_term = site["score_parts"] if is_vectorized: log_prob_sum = log_prob.reshape(self.num_particles, -1).sum(-1) else: log_prob_sum = site["log_prob_sum"] elbo_particle = elbo_particle - log_prob_sum.detach() if not is_identically_zero(entropy_term): surrogate_elbo_particle = surrogate_elbo_particle - log_prob_sum if not is_identically_zero(score_function_term): # link to the issue: https://github.com/pyro-ppl/pyro/issues/1222 raise NotImplementedError if not is_identically_zero(score_function_term): surrogate_elbo_particle = (surrogate_elbo_particle + (self.alpha / (1. - self.alpha)) * log_prob_sum) if is_identically_zero(elbo_particle): if tensor_holder is not None: elbo_particle = torch.zeros_like(tensor_holder) surrogate_elbo_particle = torch.zeros_like(tensor_holder) else: # elbo_particle is not None if tensor_holder is None: tensor_holder = torch.zeros_like(elbo_particle) # change types of previous `elbo_particle`s for i in range(len(elbo_particles)): elbo_particles[i] = torch.zeros_like(tensor_holder) surrogate_elbo_particles[i] = torch.zeros_like(tensor_holder) elbo_particles.append(elbo_particle) surrogate_elbo_particles.append(surrogate_elbo_particle) if tensor_holder is None: return 0. if is_vectorized: elbo_particles = elbo_particles[0] surrogate_elbo_particles = surrogate_elbo_particles[0] else: elbo_particles = torch.stack(elbo_particles) surrogate_elbo_particles = torch.stack(surrogate_elbo_particles) log_weights = (1. - self.alpha) * elbo_particles log_mean_weight = torch.logsumexp(log_weights, dim=0) - math.log(self.num_particles) elbo = log_mean_weight.sum().item() / (1. - self.alpha) # collect parameters to train from model and guide trainable_params = any(site["type"] == "param" for trace in (model_trace, guide_trace) for site in trace.nodes.values()) if trainable_params and getattr(surrogate_elbo_particles, 'requires_grad', False): normalized_weights = (log_weights - log_mean_weight).exp() surrogate_elbo = (normalized_weights * surrogate_elbo_particles).sum() / self.num_particles surrogate_loss = -surrogate_elbo surrogate_loss.backward() loss = -elbo warn_if_nan(loss, "loss") return loss
def loss_and_grads(self, model, guide, *args, **kwargs):
    # TODO: add argument lambda --> assigns weights to losses
    # TODO: Normalize loss elbo value if not done
    """
    :returns: an estimate of the ELBO
    :rtype: float

    Computes the ELBO as well as the surrogate ELBO that is used to form
    the gradient estimator. Performs backward on the latter.
    ``num_particles`` many samples are used to form the estimators.
    """
    elbo = 0.0
    dyn_loss = 0.0
    dim_loss = 0.0

    # grab a trace from the generator
    for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
        elbo_particle = 0
        surrogate_elbo_particle = 0
        log_r = None
        ys = []

        # compute elbo and surrogate elbo
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample":
                elbo_particle = elbo_particle + torch_item(site["log_prob_sum"])
                surrogate_elbo_particle = surrogate_elbo_particle + site["log_prob_sum"]

        for name, site in guide_trace.nodes.items():
            if site["type"] == "sample":
                log_prob, score_function_term, entropy_term = site["score_parts"]
                elbo_particle = elbo_particle - torch_item(site["log_prob_sum"])

                if not is_identically_zero(entropy_term):
                    surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum()

                if not is_identically_zero(score_function_term):
                    if log_r is None:
                        log_r = _compute_log_r(model_trace, guide_trace)
                    # use a separate name here; `site` is still needed below
                    log_r_term = log_r.sum_to(site["cond_indep_stack"])
                    surrogate_elbo_particle = surrogate_elbo_particle + (log_r_term * score_function_term).sum()

                if site["name"].startswith("y_"):
                    # TODO: check order of y
                    ys.append(site["value"])

        man = torch.stack(ys, dim=1)
        mean_man = man.mean(dim=1, keepdim=True)
        man = man - mean_man
        dyn_loss += self._get_logdet_loss(man, delta=self.delta)  # TODO: Normalize
        dim_loss += self._get_traceK_loss(man)
        elbo += elbo_particle / self.num_particles

        # collect parameters to train from model and guide
        trainable_params = any(site["type"] == "param"
                               for trace in (model_trace, guide_trace)
                               for site in trace.nodes.values())

        if trainable_params and getattr(surrogate_elbo_particle, 'requires_grad', False):
            surrogate_loss_particle = (-surrogate_elbo_particle / self.num_particles
                                       + self.lam * dyn_loss
                                       + self.gam * dim_loss)
            surrogate_loss_particle.backward()

    loss = -elbo
    if torch_isnan(loss):
        warnings.warn('Encountered NAN loss')
    return loss, dyn_loss.item(), dim_loss.item(), man
def loss_and_grads(self, model, guide, *args, **kwargs):
    loss = self.differentiable_loss(model, guide, *args, **kwargs)
    if is_identically_zero(loss) or not loss.requires_grad:
        return torch_item(loss)
    loss.backward()
    return loss.item()
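# For completeness, a loss_and_grads implementation like the one above is usually
# driven by SVI's optimization loop (standard Pyro usage; the model and guide are
# illustrative):
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model():
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    pyro.sample("x", dist.Normal(z, 1.0), obs=torch.tensor(0.5))

def guide():
    loc = pyro.param("loc", torch.tensor(0.0))
    pyro.sample("z", dist.Normal(loc, 1.0))

svi = SVI(model, guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())
for step in range(100):
    svi.step()  # calls loss_and_grads, applies the optimizer, zeroes grads
print(pyro.param("loc"))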