def score_parts(self, x, *args, **kwargs):
    """
    Computes ingredients for stochastic gradient estimators of ELBO.

    The default implementation is correct both for non-reparameterized and
    for fully reparameterized distributions. Partially reparameterized
    distributions should override this method to compute correct
    `.score_function` and `.entropy_term` parts.

    Setting ``.has_rsample`` on a distribution instance will determine
    whether inference engines like :class:`~pyro.infer.svi.SVI` use
    reparameterized samplers or the score function estimator.

    :param torch.Tensor x: A single value or batch of values.
    :param args: Extra positional arguments forwarded to :meth:`log_prob`.
    :param kwargs: Extra keyword arguments forwarded to :meth:`log_prob`.
    :return: A `ScoreParts` object containing parts of the ELBO estimator.
    :rtype: ScoreParts
    """
    log_density = self.log_prob(x, *args, **kwargs)
    if not self.has_rsample:
        # Score-function estimator: log_prob doubles as the score function and
        # the entropy term is dropped (scalar 0).
        # XXX should the user be able to control inclusion of the entropy term?
        # See Roeder, Wu, Duvenaud (2017) "Sticking the Landing" https://arxiv.org/abs/1703.09194
        return ScoreParts(log_prob=log_density, score_function=log_density, entropy_term=0)
    # Reparameterized: no score-function term (scalar 0); log_prob is the entropy term.
    return ScoreParts(log_prob=log_density, score_function=0, entropy_term=log_density)
def pack_tensors(self, plate_to_symbol=None):
    """
    Computes packed representations of tensors in the trace.
    This should be called after :meth:`compute_log_prob` or :meth:`compute_score_parts`.

    For every ``"sample"`` site, the packed tensors are written into the
    site's ``"packed"`` dict (created if missing): always ``"mask"``, plus
    ``"score_parts"``, ``"log_prob"`` and ``"unscaled_log_prob"`` when the
    site carries score parts, or only ``"log_prob"`` and
    ``"unscaled_log_prob"`` when it carries just a log-probability.

    :param plate_to_symbol: Optional mapping forwarded to :meth:`symbolize_dims`.
    :raises ValueError: If packing fails at a site. The re-raised error names
        the site and appends the trace's shape table from :meth:`format_shapes`.
    """
    self.symbolize_dims(plate_to_symbol)
    sample_sites = (node for node in self.nodes.values() if node["type"] == "sample")
    for site in sample_sites:
        to_symbol = site["infer"]["_dim_to_symbol"]
        packed = site.setdefault("packed", {})
        try:
            packed["mask"] = pack(site["mask"], to_symbol)
            if "score_parts" in site:
                raw_log_prob, raw_score_function, raw_entropy_term = site["score_parts"]
                packed_log_prob = pack(raw_log_prob, to_symbol)
                packed_score_function = pack(raw_score_function, to_symbol)
                packed_entropy_term = pack(raw_entropy_term, to_symbol)
                packed["score_parts"] = ScoreParts(packed_log_prob,
                                                   packed_score_function,
                                                   packed_entropy_term)
                packed["log_prob"] = packed_log_prob
                packed["unscaled_log_prob"] = pack(site["unscaled_log_prob"], to_symbol)
            elif "log_prob" in site:
                packed["log_prob"] = pack(site["log_prob"], to_symbol)
                packed["unscaled_log_prob"] = pack(site["unscaled_log_prob"], to_symbol)
        except ValueError as err:
            # Re-raise with the site name and a shape table, keeping the
            # original traceback so the failing pack() call is still visible.
            shapes = self.format_shapes(last_site=site["name"])
            message = "Error while packing tensors at site '{}':\n {}\n{}".format(
                site["name"], err, shapes)
            raise ValueError(message).with_traceback(err.__traceback__)
def score_parts(self, boosted_x=None):
    """
    Computes score parts for a boosted sample using the cached unboosted value.

    The score function comes from ``self._rejection_gamma`` evaluated at the
    cached unboosted value. ``self.log_prob(boosted_x)`` serves as both
    ``log_prob`` and ``entropy_term``.

    :param boosted_x: The boosted value. Defaults to the cached one, and must
        be the very same object as ``self._unboost_x_cache[0]``.
    :return: ``ScoreParts(log_prob, score_function, entropy_term)``.
    """
    cache = self._unboost_x_cache
    if boosted_x is None:
        boosted_x = cache[0]
    # Only the most recently cached boosted value is supported (identity check).
    assert boosted_x is cache[0]
    unboosted_x = cache[1]
    _log_p, score_function, _entropy = self._rejection_gamma.score_parts(unboosted_x)
    boosted_log_prob = self.log_prob(boosted_x)
    return ScoreParts(boosted_log_prob, score_function, boosted_log_prob)
def score_parts(self, value):
    """
    Computes score parts of the base distribution and reduces them over the
    reinterpreted batch dimensions.

    Each tensor-valued part is summed over its rightmost
    ``self.reinterpreted_batch_ndims`` dims and expanded to the broadcast of
    ``self.batch_shape`` with the batch shape of ``value``. Parts that are
    plain numbers (e.g. a scalar 0) are passed through unchanged;
    ``log_prob`` is always reduced.

    :param value: Value(s) to score.
    :return: ``ScoreParts(log_prob, score_function, entropy_term)``.
    """
    batch_ndims = value.dim() - self.event_dim
    target_shape = broadcast_shape(self.batch_shape, value.shape[:batch_ndims])
    base_log_prob, base_score_function, base_entropy_term = self.base_dist.score_parts(value)
    reduce_ndims = self.reinterpreted_batch_ndims

    def _reduce(term):
        # Sum out the reinterpreted dims, then broadcast to the full batch shape.
        return sum_rightmost(term, reduce_ndims).expand(target_shape)

    reduced_log_prob = _reduce(base_log_prob)
    reduced_score_function = (base_score_function
                              if isinstance(base_score_function, numbers.Number)
                              else _reduce(base_score_function))
    reduced_entropy_term = (base_entropy_term
                            if isinstance(base_entropy_term, numbers.Number)
                            else _reduce(base_entropy_term))
    return ScoreParts(reduced_log_prob, reduced_score_function, reduced_entropy_term)
def score_parts(self, value):
    """
    Computes score parts of the base distribution, expanded to this
    distribution's batch shape.

    Expansion happens only when ``self.batch_shape`` differs from
    ``self.base_dist.batch_shape``. The target shape is the broadcast of
    ``self.batch_shape`` with the batch shape of ``value``. ``log_prob`` is
    always expanded. ``score_function`` and ``entropy_term`` are expanded
    only when they are tensors, because either may be a plain number
    (e.g. 0) rather than a tensor.

    :param value: Value(s) to score.
    :return: ``ScoreParts(log_prob, score_function, entropy_term)``.
    """
    shape = broadcast_shape(self.batch_shape, value.shape[:value.dim() - self.event_dim])
    log_prob, score_function, entropy_term = self.base_dist.score_parts(value)
    if self.batch_shape != self.base_dist.batch_shape:
        log_prob = log_prob.expand(shape)
        if isinstance(score_function, torch.Tensor):
            score_function = score_function.expand(shape)
        # Guard on entropy_term itself. The previous code checked
        # score_function here, which crashed on a scalar entropy_term and
        # skipped expanding a tensor entropy_term when score_function was scalar.
        if isinstance(entropy_term, torch.Tensor):
            entropy_term = entropy_term.expand(shape)
    return ScoreParts(log_prob, score_function, entropy_term)
def score_parts(self, x):
    """
    Computes score parts for a rejection-sampled value.

    The score function is ``self._log_prob_accept(x)``. ``self.log_prob(x)``
    serves as both ``log_prob`` and ``entropy_term``.

    :param x: Value(s) to score.
    :return: ``ScoreParts(log_prob, score_function, entropy_term)``.
    """
    accept_log_prob = self._log_prob_accept(x)
    log_density = self.log_prob(x)
    return ScoreParts(log_prob=log_density,
                      score_function=accept_log_prob,
                      entropy_term=log_density)
def score_parts(self, x):
    """
    Computes score parts by delegating to the standard (rate-1) gamma on ``x * rate``.

    The score function is taken from ``self._standard_gamma`` unchanged. Its
    log density is shifted by ``log(self.rate)`` to account for the
    ``x -> x * rate`` rescaling. The shifted value serves as both
    ``log_prob`` and ``entropy_term``.

    :param x: Value(s) to score.
    :return: ``ScoreParts(log_prob, score_function, entropy_term)``.
    """
    standardized = x * self.rate
    standard_log_prob, score_function, _unused = self._standard_gamma.score_parts(standardized)
    scaled_log_prob = standard_log_prob + torch.log(self.rate)
    return ScoreParts(log_prob=scaled_log_prob,
                      score_function=score_function,
                      entropy_term=scaled_log_prob)