def compute_score_parts(self):
    """
    Populate the memoized score parts of every sample site in the trace.

    Every ``"sample"`` site that does not yet carry ``"score_parts"`` is
    scored by its distribution (``site["fn"].score_parts``). The result is
    then scaled and masked with the site's ``scale`` and ``mask``. The site
    dict is updated in place with:

    * ``unscaled_log_prob`` -- the log prob before scaling/masking,
    * ``score_parts``       -- the scaled-and-masked score parts,
    * ``log_prob``          -- shaped like the site's ``batch_shape``,
    * ``log_prob_sum``      -- a scalar.

    Sites that already hold ``"score_parts"`` are skipped, so repeated calls
    are memoized.

    :raises ValueError: if a distribution fails to score its value; the new
        error names the site and appends ``self.format_shapes(...)``.
    """
    for name, site in self.nodes.items():
        # Guard clause: only sample sites without cached score parts need work.
        if site["type"] != "sample" or "score_parts" in site:
            continue
        try:
            parts = site["fn"].score_parts(site["value"], *site["args"], **site["kwargs"])
        except ValueError as err:
            # Re-raise with the site name and trace shapes, keeping the
            # original traceback so the failing frame is still visible.
            shapes = self.format_shapes(last_site=site["name"])
            raise ValueError(
                "Error while computing score_parts at site '{}':\n{}\n{}"
                .format(name, err, shapes)).with_traceback(err.__traceback__)
        site["unscaled_log_prob"] = parts.log_prob
        # ScoreParts overloads scaling so that each of its parts is scaled.
        scaled = parts.scale_and_mask(site["scale"], site["mask"])
        site["score_parts"] = scaled
        site["log_prob"] = scaled.log_prob
        total = scaled.log_prob.sum()
        site["log_prob_sum"] = total
        if is_validation_enabled():
            label = "log_prob_sum at site '{}'".format(name)
            warn_if_nan(total, label)
            # -inf is a legitimate log probability (zero-probability value).
            warn_if_inf(total, label, allow_neginf=True)
def log_prob_sum(self, site_filter=lambda name, site: True):
    """
    Compute the total log probability of the selected sample sites.

    Only ``"sample"`` sites for which ``site_filter(name, site)`` is true
    contribute; the filter is not called for non-sample sites. For each such
    site, the scaled-and-masked log probability is reduced to a scalar and
    memoized in ``site["log_prob_sum"]``. A site that already carries
    ``"log_prob_sum"`` is not re-scored. The unreduced per-element
    ``log_prob`` is not stored by this method.

    :param site_filter: predicate ``(name, site) -> bool`` selecting which
        sample sites to include; defaults to all of them.
    :returns: total log probability (``0.0`` if no site is selected).
    :rtype: torch.Tensor
    :raises ValueError: if a distribution fails to score its value; the new
        error names the site and appends ``self.format_shapes(...)``.
    """
    result = 0.0
    for name, site in self.nodes.items():
        if site["type"] != "sample" or not site_filter(name, site):
            continue
        if "log_prob_sum" in site:
            log_p = site["log_prob_sum"]
        else:
            try:
                log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
            except ValueError as err:
                shapes = self.format_shapes(last_site=site["name"])
                # Three placeholders: previously the shapes argument was
                # silently dropped because the format string had only two.
                raise ValueError(
                    "Error while computing log_prob_sum at site '{}':\n{}\n{}"
                    .format(name, err, shapes)).with_traceback(err.__traceback__)
            log_p = scale_and_mask(log_p, site["scale"], site["mask"]).sum()
            site["log_prob_sum"] = log_p
            if is_validation_enabled():
                warn_if_nan(log_p, "log_prob_sum at site '{}'".format(name))
                # -inf is a legitimate log probability (zero-probability value).
                warn_if_inf(log_p, "log_prob_sum at site '{}'".format(name), allow_neginf=True)
        result = result + log_p
    return result
def compute_log_prob(self, site_filter=lambda name, site: True):
    """
    Populate the memoized log probabilities of the selected sample sites.

    A site is selected when it is a ``"sample"`` site and
    ``site_filter(name, site)`` is true; the filter is not called for
    non-sample sites. Each selected site lacking ``"log_prob"`` is scored by
    ``site["fn"].log_prob`` and updated in place with:

    * ``unscaled_log_prob`` -- the raw log prob,
    * ``log_prob``          -- scaled and masked, shaped like ``batch_shape``,
    * ``log_prob_sum``      -- the scalar sum of ``log_prob``.

    Sites that already hold ``"log_prob"`` are left untouched.

    :param site_filter: predicate ``(name, site) -> bool`` selecting sites.
    :raises ValueError: if a distribution fails to score its value; the new
        error names the site and appends ``self.format_shapes(...)``.
    """
    # Lazily evaluated, so the filter runs for each site just before that
    # site is processed (same interleaving as a plain loop).
    selected = (
        (name, site)
        for name, site in self.nodes.items()
        if site["type"] == "sample" and site_filter(name, site)
    )
    for name, site in selected:
        if "log_prob" in site:
            continue  # memoized from an earlier call
        try:
            raw = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
        except ValueError as err:
            # Re-raise with the site name and trace shapes, keeping the
            # original traceback so the failing frame is still visible.
            shapes = self.format_shapes(last_site=site["name"])
            raise ValueError(
                "Error while computing log_prob at site '{}':\n{}\n{}"
                .format(name, err, shapes)).with_traceback(err.__traceback__)
        site["unscaled_log_prob"] = raw
        scaled = scale_and_mask(raw, site["scale"], site["mask"])
        site["log_prob"] = scaled
        total = scaled.sum()
        site["log_prob_sum"] = total
        if is_validation_enabled():
            label = "log_prob_sum at site '{}'".format(name)
            warn_if_nan(total, label)
            # -inf is a legitimate log probability (zero-probability value).
            warn_if_inf(total, label, allow_neginf=True)
def test_warn_if_inf():
    """
    Check that ``util.warn_if_inf`` warns only for disallowed infinities.

    Covers Python scalars, tensors, and gradients of ``requires_grad``
    tensors. In each case the test expects no warning when the infinity's
    sign is allowed and exactly one warning containing ``msg`` when it is
    not. It also expects the scalar input to be returned as the same object.
    """
    # scalar case
    with warnings.catch_warnings(record=True) as w:
        # Record every warning, even repeats from the same call site.
        warnings.simplefilter("always")
        x = 3
        msg = "example message"
        y = util.warn_if_inf(x, msg, allow_posinf=True, allow_neginf=True)
        # The input is expected to be returned unchanged (identity, not a copy).
        assert y is x
        assert len(w) == 0
        x = float("inf")
        # +inf is allowed here, so no warning is expected.
        util.warn_if_inf(x, msg, allow_posinf=True)
        assert len(w) == 0
        # +inf is not allowed here (only -inf is), so one warning is expected.
        util.warn_if_inf(x, msg, allow_neginf=True)
        assert len(w) == 1
        assert msg in str(w[-1].message)

    # tensor case
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        x = torch.ones(2)
        util.warn_if_inf(x, msg, allow_posinf=True, allow_neginf=True)
        assert len(w) == 0
        # A single +inf element should be enough to trigger the check.
        x[0] = float("inf")
        util.warn_if_inf(x, msg, allow_posinf=True)
        assert len(w) == 0
        util.warn_if_inf(x, msg, allow_neginf=True)
        assert len(w) == 1
        assert msg in str(w[-1].message)

    # grad case
    # NOTE(review): these blocks expect warn_if_inf on a requires_grad tensor to
    # also check gradients that arrive during backward (presumably via a hook);
    # no warning is emitted at call time because the values themselves are finite.
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        x = torch.ones(2, requires_grad=True)
        util.warn_if_inf(x, msg, allow_posinf=True)
        y = x.sum()
        # +inf gradient is allowed: no warning.
        y.backward([torch.tensor(float("inf"))])
        assert len(w) == 0
        # Clear the accumulated gradient before the second backward pass.
        x.grad = None
        # NOTE(review): y is backpropagated a second time without retain_graph;
        # this presumably works because sum's backward keeps no saved tensors.
        y.backward([torch.tensor(-float("inf"))])
        assert len(w) == 1
        assert msg in str(w[-1].message)

    # Same as above with the allowed sign flipped. Here the graph (y) is
    # built before warn_if_inf is called on the leaf z.
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        z = torch.ones(2, requires_grad=True)
        y = z.sum()
        util.warn_if_inf(z, msg, allow_neginf=True)
        y.backward([torch.tensor(-float("inf"))])
        assert len(w) == 0
        z.grad = None
        y.backward([torch.tensor(float("inf"))])
        assert len(w) == 1
        assert msg in str(w[-1].message)