def get_importance_trace(graph_type, max_plate_nesting, model, guide, *args, **kwargs):
    """
    Trace the guide once, then trace the model replayed against that guide trace.

    Subsample sites are pruned from both traces, after which log-probabilities
    are computed on the model trace and score parts on the guide trace.

    :param graph_type: graph type forwarded to ``poutine.trace`` for both traces.
    :param max_plate_nesting: forwarded to the model/guide and site-shape checks,
        which only run when validation is enabled.
    :param model: the model callable.
    :param guide: the guide callable.
    :param args: positional arguments passed to both ``guide`` and ``model``.
    :param kwargs: keyword arguments passed to both ``guide`` and ``model``.
    :returns: the pair ``(model_trace, guide_trace)``.
    """
    guide_tracer = poutine.trace(guide, graph_type=graph_type)
    guide_trace = guide_tracer.get_trace(*args, **kwargs)

    # The model is run with its latent sites replayed from the guide trace.
    replayed_model = poutine.replay(model, trace=guide_trace)
    model_tracer = poutine.trace(replayed_model, graph_type=graph_type)
    model_trace = model_tracer.get_trace(*args, **kwargs)

    if is_validation_enabled():
        check_model_guide_match(model_trace, guide_trace, max_plate_nesting)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)

    model_trace.compute_log_prob()
    guide_trace.compute_score_parts()

    if is_validation_enabled():
        # Model sites are checked first, then guide sites.
        for trace in (model_trace, guide_trace):
            for site in trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, max_plate_nesting)

    return model_trace, guide_trace
def _differentiable_loss_particle(self, model_trace, guide_trace):
    """
    Compute the ELBO contribution of one (model_trace, guide_trace) pair.

    Observed model sites contribute their ``log_prob_sum``. For each latent
    model site an analytic KL divergence between the guide and model
    distributions is used when ``kl_divergence`` is implemented; otherwise the
    sampled ``log_prob_sum`` minus the guide's entropy term is used. Guide
    sample sites absent from the model must be flagged ``is_auxiliary`` and
    contribute only their (negated) entropy term.

    :returns: ``(loss, surrogate_loss)`` where ``surrogate_loss`` is the
        negated differentiable ELBO and ``loss`` is its detached value (a
        detached tensor while JIT tracing, otherwise ``torch_item`` of it).
    """
    elbo_particle = 0

    for name, model_site in model_trace.nodes.items():
        if model_site["type"] != "sample":
            continue

        if model_site["is_observed"]:
            elbo_particle = elbo_particle + model_site["log_prob_sum"]
            continue

        guide_site = guide_trace.nodes[name]
        if is_validation_enabled():
            check_fully_reparametrized(guide_site)

        # use kl divergence if available, else fall back on sampling
        try:
            kl_qp = kl_divergence(guide_site["fn"], model_site["fn"])
            kl_qp = scale_and_mask(
                kl_qp, scale=guide_site["scale"], mask=guide_site["mask"]
            )
            assert kl_qp.shape == guide_site["fn"].batch_shape
            elbo_particle = elbo_particle - kl_qp.sum()
        except NotImplementedError:
            entropy_term = guide_site["score_parts"].entropy_term
            elbo_particle = (
                elbo_particle + model_site["log_prob_sum"] - entropy_term.sum()
            )

    # handle auxiliary sites in the guide
    for name, guide_site in guide_trace.nodes.items():
        is_guide_only_sample = (
            guide_site["type"] == "sample" and name not in model_trace.nodes
        )
        if not is_guide_only_sample:
            continue
        assert guide_site["infer"].get("is_auxiliary")
        if is_validation_enabled():
            check_fully_reparametrized(guide_site)
        entropy_term = guide_site["score_parts"].entropy_term
        elbo_particle = elbo_particle - entropy_term.sum()

    # While JIT tracing, keep the loss as a (detached) tensor.
    if torch._C._get_tracing_state():
        loss = -elbo_particle.detach()
    else:
        loss = -torch_item(elbo_particle)
    surrogate_loss = -elbo_particle
    return loss, surrogate_loss
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Runs the guide and runs the model against the guide with
    the result packaged as a trace generator.

    One ``(model_trace, guide_trace)`` pair is yielded per particle
    (``self.num_particles`` in total). Subsample sites are pruned from both
    traces and score parts are computed on the guide trace. When validation is
    enabled, the model and guide are checked for consistency, a warning is
    issued if any guide sample site is configured for enumeration, and site
    shapes are checked against ``self.max_iarange_nesting``.

    :param model: the model callable.
    :param guide: the guide callable.
    :param args: positional arguments passed to both ``guide`` and ``model``.
    :param kwargs: keyword arguments passed to both ``guide`` and ``model``.
    :returns: generator of ``(model_trace, guide_trace)`` pairs.
    """
    for _ in range(self.num_particles):
        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(model, trace=guide_trace)
        ).get_trace(*args, **kwargs)

        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace)
            enumerated_sites = [
                name
                for name, site in guide_trace.nodes.items()
                if site["type"] == "sample" and site["infer"].get("enumerate")
            ]
            if enumerated_sites:
                # Each list element is its own line of the message. The comma
                # after the header is required: without it the header literal
                # is concatenated with ', ' and becomes the join separator.
                warnings.warn('\n'.join([
                    'Trace_ELBO found sample sites configured for enumeration:',
                    ', '.join(enumerated_sites),
                    'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.',
                ]))

        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)

        # model_trace.compute_log_prob()
        # TODO: the call above is disabled because it does not work when
        # there are no decoder parameters.
        # NOTE(review): model-site log-probs are therefore not computed here;
        # confirm that check_site_shape below does not rely on them.
        guide_trace.compute_score_parts()

        if is_validation_enabled():
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)

        yield model_trace, guide_trace
def _get_traces(self, model, guide, args, kwargs):
    """
    Trace the guide and the unconditioned model under a particle plate.

    If ``self.max_plate_nesting`` is infinite it is first guessed via
    ``ELBO._guess_max_plate_nesting`` (with validation disabled). The guide is
    traced inside a ``num_particles_vectorized`` plate; the model is then
    traced with ``poutine.uncondition`` while replaying the guide trace, and
    sites flagged ``was_observed`` in ``infer`` are re-marked as observed.
    Prior log-probabilities of non-observed model sites are computed only
    when ``self.prior_scale > 0``.

    :param model: the model callable.
    :param guide: the guide callable.
    :param args: tuple of positional arguments for ``model`` and ``guide``.
    :param kwargs: dict of keyword arguments for ``model`` and ``guide``.
    :returns: the pair ``(guide_trace, model_trace)`` (guide first).
    :raises ValueError: when validation is enabled and a guide sample site,
        or an observed model sample site, has a distribution without
        ``has_rsample``.
    """
    if self.max_plate_nesting == float("inf"):
        with validation_enabled(False):  # Avoid calling .log_prob() when undefined.
            # TODO factor this out as a stand-alone helper.
            ELBO._guess_max_plate_nesting(self, model, guide, args, kwargs)
    vectorize = pyro.plate(
        "num_particles_vectorized",
        self.num_particles,
        dim=-self.max_plate_nesting,
    )

    # Trace the guide as in ELBO.
    with poutine.trace() as tr, vectorize:
        guide(*args, **kwargs)
    guide_trace = tr.trace

    # Trace the model, drawing posterior predictive samples.
    with poutine.trace() as tr, poutine.uncondition():
        with poutine.replay(trace=guide_trace), vectorize:
            model(*args, **kwargs)
    model_trace = tr.trace
    for site in model_trace.nodes.values():
        if site["type"] == "sample" and site["infer"].get("was_observed", False):
            site["is_observed"] = True
    if is_validation_enabled():
        check_model_guide_match(model_trace, guide_trace, self.max_plate_nesting)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    if is_validation_enabled():
        for site in guide_trace.nodes.values():
            if site["type"] == "sample":
                warn_if_nan(site["value"], site["name"])
                if not getattr(site["fn"], "has_rsample", False):
                    raise ValueError(
                        "EnergyDistance requires fully reparametrized guides"
                    )
        # Iterate with `site` as the loop variable: the previous spelling
        # bound the loop variable to another name, so this body kept testing
        # the last guide site instead of each model site.
        for site in model_trace.nodes.values():
            if site["type"] == "sample":
                if site["is_observed"]:
                    warn_if_nan(site["value"], site["name"])
                    if not getattr(site["fn"], "has_rsample", False):
                        raise ValueError(
                            "EnergyDistance requires reparametrized likelihoods"
                        )

    if self.prior_scale > 0:
        # Only latent (non-observed) sites contribute to the prior term.
        model_trace.compute_log_prob(
            site_filter=lambda name, site: not site["is_observed"]
        )
        if is_validation_enabled():
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    if not site["is_observed"]:
                        check_site_shape(site, self.max_plate_nesting)

    return guide_trace, model_trace
def _populate_cache(self, model_trace):
    """
    Populate the ordinals (set of ``CondIndepStack`` frames)
    and enum_dims for each sample site.

    Does nothing when ``self.has_enumerable_sites`` is false. Otherwise the
    trace's log-probs are computed and its tensors packed, ``self.ordering``
    is filled per non-subsample sample site, and ``self._enum_dims`` is set to
    the packed dimensions that are not plate symbols.

    :param model_trace: the model trace to inspect (mutated by
        ``compute_log_prob`` and ``pack_tensors``).
    :raises ValueError: if ``self.max_plate_nesting`` is ``None``.
    """
    if not self.has_enumerable_sites:
        return
    if self.max_plate_nesting is None:
        raise ValueError(
            "Finite value required for `max_plate_nesting` when model "
            "has discrete (enumerable) sites."
        )

    model_trace.compute_log_prob()
    model_trace.pack_tensors()

    plate_symbols = model_trace.plate_to_symbol
    for name, site in model_trace.nodes.items():
        if site["type"] != "sample" or isinstance(site["fn"], _Subsample):
            continue
        if is_validation_enabled():
            check_site_shape(site, self.max_plate_nesting)
        # The ordinal is the set of symbols of the vectorized plates
        # enclosing this site.
        vectorized_frames = (f for f in site["cond_indep_stack"] if f.vectorized)
        self.ordering[name] = frozenset(
            plate_symbols[frame.name] for frame in vectorized_frames
        )

    all_symbols = set(model_trace.symbol_to_dim)
    self._enum_dims = all_symbols - set(model_trace.plate_to_symbol.values())
def _differentiable_loss_particle(self, model_trace, guide_trace):
    """
    Build the rank-weighted surrogate loss from vectorized particles.

    Per-particle ``log_p`` and ``log_q`` are accumulated by reshaping each
    sample site's ``log_prob`` to ``(num_particles, -1)`` and summing the last
    axis. Particles are ranked by ``log_p - log_q`` and weighted by
    ``rank ** self.tail_adaptive_beta`` (detached).

    :returns: ``(float('inf'), surrogate_loss)``; the loss itself is not
        computed.
    :raises NotImplementedError: if ``self.vectorize_particles`` is false.
    """
    if not self.vectorize_particles:
        raise NotImplementedError("TraceTailAdaptive_ELBO only implemented for vectorize_particles==True")

    num_particles = self.num_particles
    if num_particles == 1:
        warnings.warn("For num_particles==1 TraceTailAdaptive_ELBO uses the same loss function as Trace_ELBO. " +
                      "Increase num_particles to get an adaptive f-divergence.")

    log_p = 0
    for site in model_trace.nodes.values():
        if site["type"] != "sample":
            continue
        log_p = log_p + site["log_prob"].reshape(self.num_particles, -1).sum(-1)

    log_q = 0
    for site in guide_trace.nodes.values():
        if site["type"] != "sample":
            continue
        log_q = log_q + site["log_prob"].reshape(self.num_particles, -1).sum(-1)
        if is_validation_enabled():
            check_fully_reparametrized(site)

    # rank the particles according to p/q
    log_pq = log_p - log_q
    order = torch.argsort(log_pq, descending=False)
    one_based = torch.arange(self.num_particles, device=log_pq.device) + 1
    rank = torch.index_select(one_based, -1, order).type_as(log_pq)

    # compute the particle-specific weights used to construct the surrogate loss
    gamma = torch.pow(rank, self.tail_adaptive_beta).detach()
    surrogate_loss = -(log_pq * gamma).sum() / gamma.sum()

    # we do not compute the loss, so return `inf`
    return float('inf'), surrogate_loss
def loss_and_surrogate_loss(*args):
    """
    Accumulate the weighted loss and surrogate loss over all traces.

    This function reads ``weakself``, ``model``, ``guide`` and ``kwargs`` from
    an enclosing scope that is not part of this block.

    :param args: positional arguments forwarded to ``self._get_traces``.
    :returns: ``(loss, surrogate_loss)``, each the negated weighted sum of the
        per-trace ELBO and surrogate ELBO respectively.
    """
    self = weakself()
    total_loss = 0.0
    total_surrogate = 0.0

    traces = self._get_traces(model, guide, *args, **kwargs)
    for weight, model_trace, guide_trace in traces:
        model_trace.compute_log_prob()
        guide_trace.compute_score_parts()

        if is_validation_enabled():
            # Model sites are checked first, then guide sites.
            for trace in (model_trace, guide_trace):
                for site in trace.nodes.values():
                    if site["type"] == "sample":
                        check_site_shape(site, self.max_iarange_nesting)

        # compute elbo for reparameterized nodes
        non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
        elbo, surrogate_elbo = _compute_elbo_reparam(
            model_trace, guide_trace, non_reparam_nodes
        )

        # the following computations are only necessary if we have
        # non-reparameterizable nodes
        baseline_loss = 0.0  # computed below but not part of the returned values
        if non_reparam_nodes:
            downstream_costs, _ = _compute_downstream_costs(
                model_trace, guide_trace, non_reparam_nodes
            )
            surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(
                guide_trace, non_reparam_nodes, downstream_costs
            )
            surrogate_elbo += surrogate_elbo_term

        total_loss = total_loss - weight * elbo
        total_surrogate = total_surrogate - weight * surrogate_elbo

    return total_loss, total_surrogate
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Runs the guide and runs the model against the guide with
    the result packaged as a tracegraph generator.

    One ``(weight, model_trace, guide_trace)`` triple is yielded per particle,
    with ``weight = 1.0 / self.num_particles``. Both traces use
    ``graph_type="dense"`` and have their subsample sites pruned. When
    validation is enabled, the model and guide are checked for consistency and
    a warning is issued if any guide sample site is configured for
    enumeration.

    :param model: the model callable.
    :param guide: the guide callable.
    :param args: positional arguments passed to both ``guide`` and ``model``.
    :param kwargs: keyword arguments passed to both ``guide`` and ``model``.
    :returns: generator of ``(weight, model_trace, guide_trace)`` triples.
    """
    for _ in range(self.num_particles):
        guide_trace = poutine.trace(guide, graph_type="dense").get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(model, trace=guide_trace), graph_type="dense"
        ).get_trace(*args, **kwargs)

        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace)
            enumerated_sites = [
                name
                for name, site in guide_trace.nodes.items()
                if site["type"] == "sample" and site["infer"].get("enumerate")
            ]
            if enumerated_sites:
                # The comma after the header line is required: without it the
                # header literal is concatenated with ', ' and used as the
                # separator between site names.
                warnings.warn('\n'.join([
                    'TraceGraph_ELBO found sample sites configured for enumeration:',
                    ', '.join(enumerated_sites),
                    'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.',
                ]))

        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)

        weight = 1.0 / self.num_particles
        yield weight, model_trace, guide_trace
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Runs the guide and runs the model against the guide with
    the result packaged as a tracegraph generator.

    For each of ``self.num_particles`` particles the guide is traced, the
    model is traced while replaying the guide trace (both with
    ``graph_type="dense"``), subsample sites are pruned, and the triple
    ``(weight, model_trace, guide_trace)`` is yielded with
    ``weight = 1.0 / self.num_particles``.

    :param model: the model callable.
    :param guide: the guide callable.
    :param args: positional arguments passed to both ``guide`` and ``model``.
    :param kwargs: keyword arguments passed to both ``guide`` and ``model``.
    :returns: generator of ``(weight, model_trace, guide_trace)`` triples.
    """
    for _ in range(self.num_particles):
        guide_trace = poutine.trace(guide, graph_type="dense").get_trace(
            *args, **kwargs)
        model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
                                    graph_type="dense").get_trace(
                                        *args, **kwargs)
        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace)
            enumerated_sites = [
                name for name, site in guide_trace.nodes.items()
                if site["type"] == "sample" and site["infer"].get("enumerate")
            ]
            if enumerated_sites:
                # Three separate message lines. A missing comma after the
                # first literal would fuse it with ', ' and make the header
                # the join separator, garbling the warning.
                header = 'TraceGraph_ELBO found sample sites configured for enumeration:'
                advice = 'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.'
                warnings.warn('\n'.join([
                    header,
                    ', '.join(enumerated_sites),
                    advice,
                ]))

        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)

        weight = 1.0 / self.num_particles
        yield weight, model_trace, guide_trace
def _get_trace(self, model, guide, args, kwargs):
    """
    Returns a single trace from the guide, and the model that is run
    against it.

    The traces come from ``get_importance_trace`` with a ``"flat"`` graph
    type. Under validation the enumeration and TMC constraints are checked and
    a warning is issued when no sample site is configured for enumeration
    (unless ``self.strict_enumeration_warning`` is false). Both traces then
    have their tensors packed, the model trace reusing the guide trace's
    ``plate_to_symbol``.

    :returns: the pair ``(model_trace, guide_trace)``.
    """
    model_trace, guide_trace = get_importance_trace(
        "flat", self.max_plate_nesting, model, guide, args, kwargs
    )

    if is_validation_enabled():
        check_traceenum_requirements(model_trace, guide_trace)
        _check_tmc_elbo_constraint(model_trace, guide_trace)

        # Guide sites are scanned before model sites.
        has_enumerated_sites = any(
            site["infer"].get("enumerate")
            for trace in (guide_trace, model_trace)
            for site in trace.nodes.values()
            if site["type"] == "sample"
        )
        if self.strict_enumeration_warning and not has_enumerated_sites:
            warnings.warn(
                'TraceEnum_ELBO found no sample sites configured for enumeration. '
                'If you want to enumerate sites, you need to @config_enumerate or set '
                'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                'If you do not want to enumerate, consider using Trace_ELBO instead.'
            )

    guide_trace.pack_tensors()
    model_trace.pack_tensors(guide_trace.plate_to_symbol)
    return model_trace, guide_trace
def __call__(self, name, fn, obs):
    """
    Rewrite a symmetric ``Stable`` site as a Gaussian with a random scale.

    Two auxiliary sites, ``"{name}_uniform"`` and ``"{name}_exponential"``,
    are sampled and transformed into a scale; the returned distribution is a
    ``Normal`` centred at ``fn.loc`` with that scale times ``sqrt(2)``.

    :param name: site name, used as the prefix of the auxiliary site names.
    :param fn: the (possibly wrapped) ``dist.Stable`` with ``coords == "S0"``.
    :param obs: passed through unchanged.
    :returns: the pair ``(new_fn, obs)``.
    :raises ValueError: when validation is enabled and ``fn`` has nonzero
        skew or stability ``>= 2``.
    """
    fn, event_dim = self._unwrap(fn)
    assert isinstance(fn, dist.Stable) and fn.coords == "S0"
    if is_validation_enabled():
        skew_is_zero = (fn.skew == 0).all()
        if not skew_is_zero:
            raise ValueError("SymmetricStableReparam found nonzero skew")
        stability_below_two = (fn.stability < 2).all()
        if not stability_below_two:
            raise ValueError("SymmetricStableReparam found stability >= 2")

    # Draw parameter-free noise.
    proto = fn.stability
    shape = proto.shape
    half_pi = proto.new_full(shape, math.pi / 2)
    one = proto.new_ones(shape)
    uniform_dist = self._wrap(dist.Uniform(-half_pi, half_pi), event_dim)
    u = pyro.sample("{}_uniform".format(name), uniform_dist)
    exponential_dist = self._wrap(dist.Exponential(one), event_dim)
    e = pyro.sample("{}_exponential".format(name), exponential_dist)

    # Differentiably transform to scale drawn from a totally-skewed stable variable.
    a = fn.stability
    z = _unsafe_standard_stable(a / 2, 1, u, e, coords="S")
    assert (z >= 0).all()
    cos_factor = (math.pi / 4 * a).cos().pow(a.reciprocal())
    scale = fn.scale * cos_factor * z.sqrt()
    # Keep the scale strictly positive.
    smallest = torch.finfo(scale.dtype).tiny
    scale = scale.clamp(min=smallest)

    # Construct a scaled Gaussian, using Stable(2,0,s,m) == Normal(m,s*sqrt(2)).
    gaussian = dist.Normal(fn.loc, scale * (2 ** 0.5))
    new_fn = self._wrap(gaussian, event_dim)
    return new_fn, obs
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Runs the guide and runs the model against the guide with
    the result packaged as a trace generator.

    The guide is wrapped with ``poutine.enum``; for every particle and every
    guide trace produced by ``iter_discrete_traces`` the model is replayed
    against it, subsample sites are pruned, and log-probs / score parts are
    computed before the pair is yielded.

    :returns: generator of ``(model_trace, guide_trace)`` pairs.
    """
    # enable parallel enumeration
    guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)

    for _ in range(self.num_particles):
        for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs):
            replayed = poutine.replay(model, trace=guide_trace)
            model_trace = poutine.trace(replayed, graph_type="flat").get_trace(*args, **kwargs)

            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)

            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)
            if is_validation_enabled():
                check_traceenum_requirements(model_trace, guide_trace)

            model_trace.compute_log_prob()
            guide_trace.compute_score_parts()

            if is_validation_enabled():
                for site in model_trace.nodes.values():
                    if site["type"] != "sample":
                        continue
                    check_site_shape(site, self.max_iarange_nesting)

                any_enumerated = False
                for site in guide_trace.nodes.values():
                    if site["type"] != "sample":
                        continue
                    check_site_shape(site, self.max_iarange_nesting)
                    if site["infer"].get("enumerate"):
                        any_enumerated = True

                if self.strict_enumeration_warning and not any_enumerated:
                    warnings.warn(
                        'TraceEnum_ELBO found no sample sites configured for enumeration. '
                        'If you want to enumerate sites, you need to @config_enumerate or set '
                        'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                        'If you do not want to enumerate, consider using Trace_ELBO instead.'
                    )

            yield model_trace, guide_trace
def _get_trace(self, model, guide, args, kwargs):
    """
    Returns a single trace from the guide, and the model that is run
    against it.

    The traces are obtained from ``get_importance_trace`` with a ``"flat"``
    graph type; when validation is enabled the guide trace is additionally
    passed to ``check_if_enumerated``.

    :returns: the pair ``(model_trace, guide_trace)``.
    """
    traced_model, traced_guide = get_importance_trace(
        "flat",
        self.max_plate_nesting,
        model,
        guide,
        args,
        kwargs,
    )
    if is_validation_enabled():
        check_if_enumerated(traced_guide)
    return traced_model, traced_guide
def _disallow_latent_variables(section_name):
    """
    Generator that yields once and rejects latent sample sites seen meanwhile.

    With validation disabled it simply yields. Otherwise the code running
    during the yield is recorded with ``poutine.trace`` and, afterwards, any
    recorded sample site that is not observed triggers an error.

    NOTE(review): presumably used through ``contextlib.contextmanager``; the
    decorator is not visible in this block.

    :param section_name: label used in the error message.
    :raises NotImplementedError: if a non-observed sample site was recorded.
    """
    if is_validation_enabled():
        with poutine.trace() as tr:
            yield
        for name, site in tr.trace.nodes.items():
            is_latent = site["type"] == "sample" and not site["is_observed"]
            if is_latent:
                raise NotImplementedError("{} contained latent variable {}"
                                          .format(section_name, name))
    else:
        yield
def __setitem__(self, key, value):
    """
    Store ``value`` under ``key`` unless the state is locked.

    When validation is enabled, ``value`` must be a ``torch.Tensor`` whose
    leading dimension has size ``self._num_particles``.

    :raises RuntimeError: if ``self._locked`` is set.
    :raises TypeError: (validation only) if ``value`` is not a tensor.
    :raises ValueError: (validation only) if ``value`` is zero-dimensional or
        its leading dimension does not match ``self._num_particles``.
    """
    if self._locked:
        raise RuntimeError("Guide cannot write to SMCState")

    if is_validation_enabled():
        if not isinstance(value, torch.Tensor):
            type_name = type(value).__name__
            raise TypeError(
                "Only Tensors can be stored in an SMCState, but got {}".format(type_name)
            )
        leading_dim_ok = value.dim() != 0 and value.size(0) == self._num_particles
        if not leading_dim_ok:
            raise ValueError(
                "Expected leading dim of size {} but got shape {}".format(
                    self._num_particles, value.shape
                )
            )

    super().__setitem__(key, value)
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Runs the guide and runs the model against the guide with
    the result packaged as a trace generator.

    Yields one ``(model_trace, guide_trace)`` pair for each guide trace
    produced by ``iter_discrete_traces``, repeated ``self.num_particles``
    times. Validation checks (model/guide match, enumeration requirements,
    site shapes, and the "no enumerated sites" warning) run only when
    validation is enabled.

    :returns: generator of ``(model_trace, guide_trace)`` pairs.
    """
    nesting = self.max_iarange_nesting

    # enable parallel enumeration
    guide = poutine.enum(guide, first_available_dim=nesting)

    particle = 0
    while particle < self.num_particles:
        particle += 1
        for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs):
            model_tracer = poutine.trace(
                poutine.replay(model, trace=guide_trace), graph_type="flat"
            )
            model_trace = model_tracer.get_trace(*args, **kwargs)

            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)
            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)
            if is_validation_enabled():
                check_traceenum_requirements(model_trace, guide_trace)

            model_trace.compute_log_prob()
            guide_trace.compute_score_parts()

            if is_validation_enabled():
                for model_site in model_trace.nodes.values():
                    if model_site["type"] == "sample":
                        check_site_shape(model_site, self.max_iarange_nesting)

                found_enumerated = False
                for guide_site in guide_trace.nodes.values():
                    if guide_site["type"] == "sample":
                        check_site_shape(guide_site, self.max_iarange_nesting)
                        if guide_site["infer"].get("enumerate"):
                            found_enumerated = True

                if self.strict_enumeration_warning and not found_enumerated:
                    warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. '
                                  'If you want to enumerate sites, you need to @config_enumerate or set '
                                  'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                                  'If you do not want to enumerate, consider using Trace_ELBO instead.')

            yield model_trace, guide_trace
def _get_log_factors(self, model_trace):
    """
    Aggregates the `log_prob` terms into a list for each
    ordinal.

    Log-probs are computed and tensors packed on ``model_trace`` first. Each
    non-subsample sample site's packed ``log_prob`` is appended to the list
    keyed by ``self.ordering[name]``.

    :param model_trace: the model trace (mutated by ``compute_log_prob`` and
        ``pack_tensors``).
    :returns: an ``OrderedDict`` mapping ordinal to a list of packed log-probs.
    """
    model_trace.compute_log_prob()
    model_trace.pack_tensors()

    log_probs = OrderedDict()
    # Collect log prob terms per independence context.
    for name, site in model_trace.nodes.items():
        if site["type"] != "sample" or isinstance(site["fn"], _Subsample):
            continue
        if is_validation_enabled():
            check_site_shape(site, self.max_plate_nesting)
        ordinal = self.ordering[name]
        if ordinal not in log_probs:
            log_probs[ordinal] = []
        log_probs[ordinal].append(site["packed"]["log_prob"])
    return log_probs
def _pyro_sample(self, msg):
    """
    Set the value of a latent sample site using ``self.init_fn``.

    Sites that are already done, observed, or whose distribution type is
    named ``_Subsample`` are left untouched. The value is computed under
    ``torch.no_grad()``. When validation is enabled and the message already
    carries a value, the new value must have the same type and shape.

    :param msg: the sample-site message; ``msg["value"]`` and ``msg["done"]``
        are updated in place.
    :raises ValueError: (validation only) on a type or shape mismatch with
        the existing ``msg["value"]``.
    """
    is_subsample = type(msg["fn"]).__name__ == "_Subsample"
    if msg["done"] or msg["is_observed"] or is_subsample:
        return

    with torch.no_grad():
        value = self.init_fn(msg)

    if is_validation_enabled() and msg["value"] is not None:
        expected = msg["value"]
        if not isinstance(value, type(expected)):
            raise ValueError(
                "{} provided invalid type for site {}:\nexpected {}\nactual {}"
                .format(self.init_fn, msg["name"], type(expected), type(value)))
        if value.shape != expected.shape:
            raise ValueError(
                "{} provided invalid shape for site {}:\nexpected {}\nactual {}"
                .format(self.init_fn, msg["name"], expected.shape, value.shape))

    msg["value"] = value
    msg["done"] = True
def _transition_bwd(self, params, prev, curr, t):
    """
    Helper to collect probability factors from .transition() conditioned on
    previous and current enumerated states.

    The conditioning data consists of ``curr`` re-keyed as ``"{name}_{t}"``
    plus the output of ``self.compute_flows(prev, curr, t)``. A copy of
    ``prev`` is passed to ``self.transition`` so that ``prev`` itself is not
    mutated.

    :raises ValueError: (validation only) if, after ``.transition()``, the
        state of any compartment is not close to the value in ``curr``.
    """
    # Run .transition() conditioned on computed flows.
    cond_data = {}
    for key, value in curr.items():
        cond_data["{}_{}".format(key, t)] = value
    flows = self.compute_flows(prev, curr, t)
    cond_data.update(flows)

    with poutine.condition(data=cond_data):
        state = prev.copy()
        self.transition(params, state, t)  # Mutates state.

    # Validate that .transition() matches .compute_flows().
    if not is_validation_enabled():
        return
    for key in self.compartments:
        if torch.allclose(state[key], curr[key]):
            continue
        raise ValueError("Incorrect state['{}'] update in .transition(), "
                         "check that .transition() matches .compute_flows()."
                         .format(key))
def _compute_log_prob_terms(self, model_trace):
    """
    Computes the conditional probabilities for each of the sites
    in the model trace, and stores the result in `self._log_probs`.

    ``self._log_probs`` maps each ordinal (the frozenset of a site's
    ``cond_indep_stack``) to the list of ``log_prob`` tensors of the sample
    sites sharing it. If ``self._log_prob_shapes`` is still empty it is
    filled with the broadcast shape of each ordinal's tensors.

    :param model_trace: the model trace (mutated by ``compute_log_prob``).
    """
    model_trace.compute_log_prob()
    self._log_probs = defaultdict(list)

    # Ordinals are resolved for every sample site before any term is collected.
    ordering = {}
    for name, site in model_trace.nodes.items():
        if site["type"] == "sample":
            ordering[name] = frozenset(site["cond_indep_stack"])

    # Collect log prob terms per independence context.
    for name, site in model_trace.nodes.items():
        if site["type"] != "sample":
            continue
        if is_validation_enabled():
            check_site_shape(site, self.max_plate_nesting)
        self._log_probs[ordering[name]].append(site["log_prob"])

    if self._log_prob_shapes:
        return
    for ordinal, terms in self._log_probs.items():
        term_shapes = (t.shape for t in terms)
        self._log_prob_shapes[ordinal] = broadcast_shape(*term_shapes)
def _guess_max_plate_nesting(self, model, guide, args, kwargs):
    """
    Guesses max_plate_nesting by running the (model,guide) pair once
    without enumeration. This optimistically assumes static model
    structure.

    The result is stored in ``self.max_plate_nesting``: the negated minimum
    ``dim`` over all vectorized frames of all sample sites (0 if there are
    none), plus one when particles are vectorized and
    ``self.num_particles > 1``.

    :param model: the model callable.
    :param guide: the guide callable.
    :param args: tuple of positional arguments for ``model`` and ``guide``.
    :param kwargs: dict of keyword arguments for ``model`` and ``guide``.
    """
    # Ignore validation to allow model-enumerated sites absent from the guide.
    with poutine.block():
        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
        replayed_model = poutine.replay(model, trace=guide_trace)
        model_trace = poutine.trace(replayed_model).get_trace(*args, **kwargs)
    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)

    # Model sites first, then guide sites.
    sites = []
    for trace in (model_trace, guide_trace):
        for site in trace.nodes.values():
            if site["type"] == "sample":
                sites.append(site)

    # Validate shapes now, since shape constraints will be weaker once
    # max_plate_nesting is changed from float('inf') to some finite value.
    # Here we know the traces are not enumerated, but later we'll need to
    # allow broadcasting of dims to the left of max_plate_nesting.
    if is_validation_enabled():
        guide_trace.compute_log_prob()
        model_trace.compute_log_prob()
        for site in sites:
            check_site_shape(site, max_plate_nesting=float("inf"))

    dims = []
    for site in sites:
        for frame in site["cond_indep_stack"]:
            if frame.vectorized:
                dims.append(frame.dim)

    if dims:
        self.max_plate_nesting = -min(dims)
    else:
        self.max_plate_nesting = 0
    # Vectorized particles take up one additional plate dimension.
    if self.vectorize_particles and self.num_particles > 1:
        self.max_plate_nesting += 1
    logging.info("Guessed max_plate_nesting = {}".format(self.max_plate_nesting))
def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
    """
    Compute the weighted loss for one particle and backpropagate.

    Log-probs (model) and score parts (guide) are computed on the traces.
    The ELBO is assembled from the reparameterized part plus, when the guide
    has non-reparameterizable nodes, a surrogate term and a baseline loss.
    ``torch_backward`` is called only if either trace contains a ``param``
    site.

    :param weight: scalar multiplying both the backpropagated quantity and
        the returned loss.
    :returns: ``weight * loss`` where ``loss`` is ``-torch_item(elbo)``.
    """
    # have the trace compute all the individual (batch) log pdf terms
    # and score function terms (if present) so that they are available below
    model_trace.compute_log_prob()
    guide_trace.compute_score_parts()
    if is_validation_enabled():
        # Model sites are checked first, then guide sites.
        for trace in (model_trace, guide_trace):
            for site in trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)

    # compute elbo for reparameterized nodes
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    elbo, surrogate_elbo = _compute_elbo_reparam(
        model_trace, guide_trace, non_reparam_nodes
    )

    # the following computations are only necessary if we have
    # non-reparameterizable nodes
    baseline_loss = 0.0
    if non_reparam_nodes:
        downstream_costs, _ = _compute_downstream_costs(
            model_trace, guide_trace, non_reparam_nodes
        )
        surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(
            guide_trace, non_reparam_nodes, downstream_costs
        )
        surrogate_elbo += surrogate_elbo_term

    # collect parameters to train from model and guide
    trainable_params = any(
        site["type"] == "param"
        for trace in (model_trace, guide_trace)
        for site in trace.nodes.values()
    )
    if trainable_params:
        surrogate_loss = -surrogate_elbo
        torch_backward(weight * (surrogate_loss + baseline_loss))

    loss = -torch_item(elbo)
    if torch_isnan(loss):
        warnings.warn('Encountered NAN loss')
    return weight * loss
def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
    """
    Evaluate one particle: compute its loss and accumulate gradients.

    The reparameterized ELBO and its surrogate come from
    ``_compute_elbo_reparam``. If the guide has non-reparameterizable
    stochastic nodes, downstream costs are computed and the non-reparam
    surrogate term and baseline loss are added in. Gradients are taken via
    ``torch_backward`` only when a ``param`` site appears in either trace.

    :param weight: factor applied to the backpropagated objective and to
        the returned loss.
    :returns: the weighted loss ``weight * (-torch_item(elbo))``.
    """
    # have the trace compute all the individual (batch) log pdf terms
    # and score function terms (if present) so that they are available below
    model_trace.compute_log_prob()
    guide_trace.compute_score_parts()

    if is_validation_enabled():
        nesting = self.max_iarange_nesting
        for model_site in model_trace.nodes.values():
            if model_site["type"] != "sample":
                continue
            check_site_shape(model_site, nesting)
        for guide_site in guide_trace.nodes.values():
            if guide_site["type"] != "sample":
                continue
            check_site_shape(guide_site, nesting)

    # compute elbo for reparameterized nodes
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    elbo, surrogate_elbo = _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes)

    # the following computations are only necessary if we have non-reparameterizable nodes
    baseline_loss = 0.0
    if non_reparam_nodes:
        costs_and_extra = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
        downstream_costs, _ = costs_and_extra
        surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(
            guide_trace, non_reparam_nodes, downstream_costs)
        surrogate_elbo += surrogate_elbo_term

    # collect parameters to train from model and guide
    has_params = False
    for trace in (model_trace, guide_trace):
        for site in trace.nodes.values():
            if site["type"] == "param":
                has_params = True
                break
        if has_params:
            break

    if has_params:
        surrogate_loss = -surrogate_elbo
        torch_backward(weight * (surrogate_loss + baseline_loss))

    loss = -torch_item(elbo)
    if torch_isnan(loss):
        warnings.warn('Encountered NAN loss')
    return weight * loss
def loss_and_surrogate_loss(*args):
    """
    Sum the weighted negative ELBO and surrogate ELBO over every trace pair.

    ``weakself``, ``model``, ``guide`` and ``kwargs`` are free variables
    resolved from an enclosing scope outside this block.

    :param args: positional arguments forwarded to ``self._get_traces``.
    :returns: ``(loss, surrogate_loss)``.
    """
    self = weakself()
    loss = 0.0
    surrogate_loss = 0.0

    for weight, model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
        model_trace.compute_log_prob()
        guide_trace.compute_score_parts()

        if is_validation_enabled():
            for model_site in model_trace.nodes.values():
                if model_site["type"] != "sample":
                    continue
                check_site_shape(model_site, self.max_iarange_nesting)
            for guide_site in guide_trace.nodes.values():
                if guide_site["type"] != "sample":
                    continue
                check_site_shape(guide_site, self.max_iarange_nesting)

        # compute elbo for reparameterized nodes
        non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
        reparam_result = _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes)
        elbo, surrogate_elbo = reparam_result

        # the following computations are only necessary if we have non-reparameterizable nodes
        # (the baseline loss is computed here but does not enter the returned values)
        baseline_loss = 0.0
        if non_reparam_nodes:
            downstream_costs, _ = _compute_downstream_costs(
                model_trace, guide_trace, non_reparam_nodes)
            surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(
                guide_trace, non_reparam_nodes, downstream_costs)
            surrogate_elbo += surrogate_elbo_term

        loss = loss - weight * elbo
        surrogate_loss = surrogate_loss - weight * surrogate_elbo

    return loss, surrogate_loss
def _get_trace(self, model, guide, *args, **kwargs):
    """
    Obtain the parent's trace pair and, under validation, check it with
    ``_check_mean_field_requirement``.

    :param model: the model callable.
    :param guide: the guide callable.
    :param args: positional arguments forwarded to the parent ``_get_trace``.
    :param kwargs: keyword arguments forwarded to the parent ``_get_trace``.
    :returns: the pair ``(model_trace, guide_trace)``.
    """
    parent = super(TraceMeanField_ELBO, self)
    traced_model, traced_guide = parent._get_trace(model, guide, *args, **kwargs)
    if is_validation_enabled():
        _check_mean_field_requirement(traced_model, traced_guide)
    return traced_model, traced_guide
def _get_trace(self, model, guide, args, kwargs):
    """
    Delegate to the parent ``_get_trace`` and, when validation is enabled,
    pass the resulting traces to ``_check_mean_field_requirement``.

    :param model: the model callable.
    :param guide: the guide callable.
    :param args: positional-argument tuple forwarded to the parent.
    :param kwargs: keyword-argument dict forwarded to the parent.
    :returns: the pair ``(model_trace, guide_trace)``.
    """
    traced_model, traced_guide = super()._get_trace(
        model,
        guide,
        args,
        kwargs,
    )
    if is_validation_enabled():
        _check_mean_field_requirement(traced_model, traced_guide)
    return traced_model, traced_guide