Example #1
    def compute_score_parts(self):
        """
        Compute the batched local score parts at each site of the trace.
        Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
        Each ``log_prob_sum`` is a scalar.
        All computations are memoized.
        """
        for name, site in self.nodes.items():
            if site["type"] == "sample" and "score_parts" not in site:
                # Note that ScoreParts overloads the multiplication operator
                # to correctly scale each of its three parts.
                try:
                    value = site["fn"].score_parts(site["value"],
                                                   *site["args"],
                                                   **site["kwargs"])
                except ValueError:
                    _, exc_value, traceback = sys.exc_info()
                    shapes = self.format_shapes(last_site=site["name"])
                    raise ValueError(
                        "Error while computing score_parts at site '{}':\n{}\n{}"
                        .format(name, exc_value,
                                shapes)).with_traceback(traceback)
                site["unscaled_log_prob"] = value.log_prob
                value = value.scale_and_mask(site["scale"], site["mask"])
                site["score_parts"] = value
                site["log_prob"] = value.log_prob
                site["log_prob_sum"] = value.log_prob.sum()
                if is_validation_enabled():
                    warn_if_nan(site["log_prob_sum"],
                                "log_prob_sum at site '{}'".format(name))
                    warn_if_inf(site["log_prob_sum"],
                                "log_prob_sum at site '{}'".format(name),
                                allow_neginf=True)
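
The method above is part of Pyro's Trace data structure, so it is exercised through a traced model rather than called directly. A minimal usage sketch, assuming the standard pyro.poutine tracing API (the two-site model below is hypothetical):

import torch

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine


def model():
    # Hypothetical model used only to produce a trace with one latent
    # site "z" and one observed site "x".
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    pyro.sample("x", dist.Normal(z, 1.0), obs=torch.tensor(0.5))


tr = poutine.trace(model).get_trace()
tr.compute_score_parts()
# Each sample site now carries memoized "score_parts", "log_prob",
# and "log_prob_sum" entries.
print(tr.nodes["z"]["log_prob_sum"])
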
Example #2
    def log_prob_sum(self, site_filter=lambda name, site: True):
        """
        Compute the site-wise log probabilities of the trace.
        Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
        Each ``log_prob_sum`` is a scalar.
        The computation of ``log_prob_sum`` is memoized.

        :returns: total log probability.
        :rtype: torch.Tensor
        """
        result = 0.0
        for name, site in self.nodes.items():
            if site["type"] == "sample" and site_filter(name, site):
                if "log_prob_sum" in site:
                    log_p = site["log_prob_sum"]
                else:
                    try:
                        log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
                    except ValueError:
                        _, exc_value, traceback = sys.exc_info()
                        shapes = self.format_shapes(last_site=site["name"])
                        raise ValueError("Error while computing log_prob_sum at site '{}':\n{}\n"
                                         .format(name, exc_value, shapes)).with_traceback(traceback)
                    log_p = scale_and_mask(log_p, site["scale"], site["mask"]).sum()
                    site["log_prob_sum"] = log_p
                    if is_validation_enabled():
                        warn_if_nan(log_p, "log_prob_sum at site '{}'".format(name))
                        warn_if_inf(log_p, "log_prob_sum at site '{}'".format(name), allow_neginf=True)
                result = result + log_p
        return result
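
A usage sketch for the method above, assuming the trace tr built in the sketch after Example #1; "is_observed" is a standard key on Pyro sample sites:

total = tr.log_prob_sum()
# The site_filter callback restricts the sum to a subset of sites,
# here the observed ones only.
obs_total = tr.log_prob_sum(
    site_filter=lambda name, site: site["is_observed"])
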
Example #3
    def compute_log_prob(self, site_filter=lambda name, site: True):
        """
        Compute the site-wise log probabilities of the trace.
        Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
        Each ``log_prob_sum`` is a scalar.
        Both computations are memoized.
        """
        for name, site in self.nodes.items():
            if site["type"] == "sample" and site_filter(name, site):
                if "log_prob" not in site:
                    try:
                        log_p = site["fn"].log_prob(site["value"],
                                                    *site["args"],
                                                    **site["kwargs"])
                    except ValueError:
                        _, exc_value, traceback = sys.exc_info()
                        shapes = self.format_shapes(last_site=site["name"])
                        raise ValueError(
                            "Error while computing log_prob at site '{}':\n{}\n{}"
                            .format(name, exc_value,
                                    shapes)).with_traceback(traceback)
                    site["unscaled_log_prob"] = log_p
                    log_p = scale_and_mask(log_p, site["scale"], site["mask"])
                    site["log_prob"] = log_p
                    site["log_prob_sum"] = log_p.sum()
                    if is_validation_enabled():
                        warn_if_nan(site["log_prob_sum"],
                                    "log_prob_sum at site '{}'".format(name))
                        warn_if_inf(site["log_prob_sum"],
                                    "log_prob_sum at site '{}'".format(name),
                                    allow_neginf=True)
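
Unlike log_prob_sum, this method returns nothing and instead stores its results on the trace. A sketch of reading them back, again assuming the trace tr from the first sketch:

tr.compute_log_prob()
site = tr.nodes["z"]
site["log_prob"]      # tensor shaped like the site's batch_shape
site["log_prob_sum"]  # scalar tensor
# A second compute_log_prob() call is a no-op for this site, since
# "log_prob" is already memoized.
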
Example #4
import warnings

import torch

from pyro import util


def test_warn_if_inf():
    # scalar case
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        x = 3
        msg = "example message"
        y = util.warn_if_inf(x, msg, allow_posinf=True, allow_neginf=True)
        assert y is x
        assert len(w) == 0
        x = float("inf")
        util.warn_if_inf(x, msg, allow_posinf=True)
        assert len(w) == 0
        util.warn_if_inf(x, msg, allow_neginf=True)
        assert len(w) == 1
        assert msg in str(w[-1].message)

    # tensor case
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        x = torch.ones(2)
        util.warn_if_inf(x, msg, allow_posinf=True, allow_neginf=True)
        assert len(w) == 0
        x[0] = float("inf")
        util.warn_if_inf(x, msg, allow_posinf=True)
        assert len(w) == 0
        util.warn_if_inf(x, msg, allow_neginf=True)
        assert len(w) == 1
        assert msg in str(w[-1].message)

    # grad case
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        x = torch.ones(2, requires_grad=True)
        util.warn_if_inf(x, msg, allow_posinf=True)
        y = x.sum()
        y.backward([torch.tensor(float("inf"))])
        assert len(w) == 0

        x.grad = None
        y.backward([torch.tensor(-float("inf"))])
        assert len(w) == 1
        assert msg in str(w[-1].message)

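    # grad case, mirrored: the hook allows -inf but warns on +inf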
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        z = torch.ones(2, requires_grad=True)
        y = z.sum()
        util.warn_if_inf(z, msg, allow_neginf=True)
        y.backward([torch.tensor(-float("inf"))])
        assert len(w) == 0
        z.grad = None
        y.backward([torch.tensor(float("inf"))])
        assert len(w) == 1
        assert msg in str(w[-1].message)
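
Outside the test harness, the scalar and tensor behaviour is easy to demonstrate directly. A small sketch, assuming util here is pyro.util as in the test above:

import warnings

import torch

from pyro import util

x = torch.tensor([1.0, float("inf")])
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    util.warn_if_inf(x, "loss", allow_posinf=True)  # +inf allowed: silent
    util.warn_if_inf(x, "loss")                     # warns about the +inf entry
    assert len(w) == 1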