Ejemplo n.º 1
0
    def score_parts(self, x, *args, **kwargs):
        """
        Compute the pieces needed by stochastic-gradient ELBO estimators.

        This generic implementation is valid for distributions that are either
        fully reparameterized or not reparameterized at all. A partially
        reparameterized distribution must override it so that the
        `.score_function` and `.entropy_term` parts come out right.

        The value of ``.has_rsample`` on an instance decides whether inference
        engines such as :class:`~pyro.infer.svi.SVI` rely on reparameterized
        sampling or on the score function estimator.

        :param torch.Tensor x: A single value or batch of values.
        :return: A `ScoreParts` object containing parts of the ELBO estimator.
        :rtype: ScoreParts
        """
        log_density = self.log_prob(x, *args, **kwargs)
        if not self.has_rsample:
            # Score-function estimator: log_prob is the score function term and
            # the entropy term is the constant 0.
            # XXX should the user be able to control inclusion of the entropy term?
            # See Roeder, Wu, Duvenaud (2017) "Sticking the Landing" https://arxiv.org/abs/1703.09194
            return ScoreParts(log_prob=log_density,
                              score_function=log_density,
                              entropy_term=0)
        # Reparameterized: no score function term; log_prob doubles as the entropy term.
        return ScoreParts(log_prob=log_density, score_function=0, entropy_term=log_density)
Ejemplo n.º 2
0
 def pack_tensors(self, plate_to_symbol=None):
     """
     Build packed representations of the tensors stored at sample sites.

     Call this only after :meth:`compute_log_prob` or :meth:`compute_score_parts`.
     For every site whose ``"type"`` is ``"sample"``, packed versions of the mask
     and (when present) the score parts, ``log_prob`` and ``unscaled_log_prob``
     are written into the site's ``"packed"`` dict.

     :param plate_to_symbol: Forwarded unchanged to :meth:`symbolize_dims`.
     :raises ValueError: If packing fails at a site; the new error names the site,
         includes the original message and the output of
         ``format_shapes(last_site=<site name>)``, and keeps the original traceback.
     """
     self.symbolize_dims(plate_to_symbol)
     sample_sites = (node for node in self.nodes.values() if node["type"] == "sample")
     for site in sample_sites:
         to_symbol = site["infer"]["_dim_to_symbol"]
         out = site.setdefault("packed", {})
         try:
             out["mask"] = pack(site["mask"], to_symbol)
             if "score_parts" in site:
                 # Unpacking into exactly three names is deliberate: a malformed
                 # tuple raises ValueError, which is reported like a packing error.
                 raw_lp, raw_sf, raw_et = site["score_parts"]
                 packed_lp = pack(raw_lp, to_symbol)
                 packed_sf = pack(raw_sf, to_symbol)
                 packed_et = pack(raw_et, to_symbol)
                 out["score_parts"] = ScoreParts(packed_lp, packed_sf, packed_et)
                 out["log_prob"] = packed_lp
                 out["unscaled_log_prob"] = pack(site["unscaled_log_prob"], to_symbol)
             elif "log_prob" in site:
                 out["log_prob"] = pack(site["log_prob"], to_symbol)
                 out["unscaled_log_prob"] = pack(site["unscaled_log_prob"], to_symbol)
         except ValueError as err:
             shapes = self.format_shapes(last_site=site["name"])
             message = "Error while packing tensors at site '{}':\n  {}\n{}".format(
                 site["name"], err, shapes)
             raise ValueError(message).with_traceback(err.__traceback__)
Ejemplo n.º 3
0
 def score_parts(self, boosted_x=None):
     """
     Return ELBO score parts for the most recently cached boosted sample.

     ``boosted_x`` defaults to, and is asserted to be (by identity), the first
     entry of ``self._unboost_x_cache``. The score function term is taken from
     ``self._rejection_gamma.score_parts`` evaluated at the cached un-boosted
     value (second cache entry). ``self.log_prob(boosted_x)`` is used as both the
     log-probability and the entropy term.
     """
     cache = self._unboost_x_cache
     boosted_x = cache[0] if boosted_x is None else boosted_x
     # Only the cached sample is supported; other values cannot be un-boosted here.
     assert boosted_x is cache[0]
     unboosted = cache[1]
     _, score_function, _ = self._rejection_gamma.score_parts(unboosted)
     log_density = self.log_prob(boosted_x)
     return ScoreParts(log_density, score_function, log_density)
Ejemplo n.º 4
0
 def score_parts(self, value):
     """
     Score parts of the base distribution with reinterpreted batch dims summed out.

     Each tensor part is reduced with ``sum_rightmost`` over the
     ``self.reinterpreted_batch_ndims`` rightmost dims and then expanded to the
     broadcast of ``self.batch_shape`` with the batch portion of ``value.shape``.
     ``log_prob`` is always reduced. The score function and entropy term are left
     untouched when they are plain numbers (e.g. a constant 0).
     """
     batch_part = value.shape[:value.dim() - self.event_dim]
     target_shape = broadcast_shape(self.batch_shape, batch_part)
     log_prob, score_function, entropy_term = self.base_dist.score_parts(value)
     ndims = self.reinterpreted_batch_ndims

     def _reduce(term):
         # Sum out the reinterpreted dims, then broadcast to the final batch shape.
         return sum_rightmost(term, ndims).expand(target_shape)

     log_prob = _reduce(log_prob)
     score_function = score_function if isinstance(score_function, numbers.Number) else _reduce(score_function)
     entropy_term = entropy_term if isinstance(entropy_term, numbers.Number) else _reduce(entropy_term)
     return ScoreParts(log_prob, score_function, entropy_term)
Ejemplo n.º 5
0
 def score_parts(self, value):
     """
     Score parts of the base distribution, expanded to this distribution's batch shape.

     When ``self.batch_shape`` differs from ``self.base_dist.batch_shape``, every
     tensor part is expanded to the broadcast of ``self.batch_shape`` with the
     batch portion of ``value.shape``. Non-tensor parts (e.g. a constant ``0``
     score function or entropy term) are returned unchanged.
     """
     shape = broadcast_shape(self.batch_shape,
                             value.shape[:value.dim() - self.event_dim])
     log_prob, score_function, entropy_term = self.base_dist.score_parts(value)
     if self.batch_shape != self.base_dist.batch_shape:
         log_prob = log_prob.expand(shape)
         if isinstance(score_function, torch.Tensor):
             score_function = score_function.expand(shape)
         # Check entropy_term itself, not score_function. The two are independent:
         # a non-reparameterized base yields a tensor score function with
         # entropy_term == 0, and a reparameterized base yields the reverse.
         if isinstance(entropy_term, torch.Tensor):
             entropy_term = entropy_term.expand(shape)
     return ScoreParts(log_prob, score_function, entropy_term)
Ejemplo n.º 6
0
 def score_parts(self, x):
     """
     Score parts for a rejection sampler.

     The score function term is ``self._log_prob_accept(x)``. ``self.log_prob(x)``
     is used as both the log-probability and the entropy term.
     """
     accept_term = self._log_prob_accept(x)
     log_density = self.log_prob(x)
     return ScoreParts(log_density, accept_term, log_density)
Ejemplo n.º 7
0
 def score_parts(self, x):
     """
     Score parts computed via ``self._standard_gamma`` at the rescaled value.

     The standard gamma's score parts are evaluated at ``x * self.rate``. Its
     log-probability is shifted by ``log(self.rate)``, which is the Jacobian term
     of the substitution ``y = x * rate``. The score function is taken from the
     standard gamma, and the shifted log-probability is used as both log_prob and
     the entropy term.
     """
     rescaled = x * self.rate
     base_log_prob, score_function, _ = self._standard_gamma.score_parts(rescaled)
     log_density = base_log_prob + torch.log(self.rate)
     return ScoreParts(log_density, score_function, log_density)