Example #1
def _squared_error(x, y, scale, mask):
    diff = x - y
    if getattr(scale, "shape", ()) or getattr(mask, "shape", ()):
        error = torch.einsum("nbe,nbe->nb", diff, diff)
        return scale_and_mask(error, scale, mask).sum(-1)
    else:
        error = torch.einsum("nbe,nbe->n", diff, diff)
        return scale_and_mask(error, scale, mask)
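Throughout these examples, `scale_and_mask` is Pyro's helper from `pyro.distributions.util`. A minimal sketch of the behavior the snippets rely on, assuming the helper multiplies by `scale` and zeroes out entries where `mask` is False:

import torch

def scale_and_mask_sketch(tensor, scale=1.0, mask=None):
    # Hypothetical stand-in, not Pyro's implementation: scale the tensor and
    # zero out the entries that the boolean mask switches off.
    if mask is None:
        return tensor * scale
    return torch.where(mask, tensor * scale, tensor.new_zeros(()))

log_p = torch.tensor([0.5, -1.0, 2.0])
mask = torch.tensor([True, False, True])
print(scale_and_mask_sketch(log_p, scale=2.0, mask=mask))  # -> tensor([1., 0., 4.])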
Example #2
    def scale_and_mask(self, scale=1.0, mask=None):
        """
        Scale and mask appropriate terms of a gradient estimator by a data multiplicity factor.
        Note that the `score_function` term should not be scaled or masked.

        :param scale: a positive scale
        :type scale: torch.Tensor or number
        :param mask: an optional masking tensor
        :type mask: torch.ByteTensor or None
        """
        log_prob = scale_and_mask(self.log_prob, scale, mask)
        score_function = self.score_function  # not scaled
        entropy_term = scale_and_mask(self.entropy_term, scale, mask)
        return ScoreParts(log_prob, score_function, entropy_term)
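A short usage sketch of the method above; the `ScoreParts` import path is an assumption, but the fields mirror those used in the snippet:

import torch
from pyro.distributions.score_parts import ScoreParts  # assumed import path

log_prob = torch.randn(4)
parts = ScoreParts(log_prob=log_prob, score_function=log_prob, entropy_term=torch.zeros(4))
mask = torch.tensor([True, True, False, False])

scaled = parts.scale_and_mask(scale=2.0, mask=mask)
print(scaled.log_prob)        # scaled by 2 and zeroed where the mask is False
print(scaled.score_function)  # left untouched, as the docstring requires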
Example #3
    def compute_log_prob(self, site_filter=lambda name, site: True):
        """
        Compute the site-wise log probabilities of the trace.
        Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
        Each ``log_prob_sum`` is a scalar.
        Both computations are memoized.
        """
        for name, site in self.nodes.items():
            if site["type"] == "sample" and site_filter(name, site):
                if "log_prob" not in site:
                    try:
                        log_p = site["fn"].log_prob(site["value"],
                                                    *site["args"],
                                                    **site["kwargs"])
                    except ValueError:
                        _, exc_value, traceback = sys.exc_info()
                        shapes = self.format_shapes(last_site=site["name"])
                        raise ValueError(
                            "Error while computing log_prob at site '{}':\n{}\n{}"
                            .format(name, exc_value,
                                    shapes)).with_traceback(traceback)
                    site["unscaled_log_prob"] = log_p
                    log_p = scale_and_mask(log_p, site["scale"], site["mask"])
                    site["log_prob"] = log_p
                    site["log_prob_sum"] = log_p.sum()
                    if is_validation_enabled():
                        warn_if_nan(site["log_prob_sum"],
                                    "log_prob_sum at site '{}'".format(name))
                        warn_if_inf(site["log_prob_sum"],
                                    "log_prob_sum at site '{}'".format(name),
                                    allow_neginf=True)
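A sketch of how this method is usually reached in practice, assuming a plain Pyro model traced with `pyro.poutine.trace`:

import torch
import pyro
import pyro.distributions as dist

def model():
    x = pyro.sample("x", dist.Normal(0., 1.))
    pyro.sample("obs", dist.Normal(x, 1.), obs=torch.tensor(0.5))

trace = pyro.poutine.trace(model).get_trace()
trace.compute_log_prob()                   # fills in site["log_prob"] and site["log_prob_sum"]
print(trace.nodes["obs"]["log_prob"])      # shaped like the site's batch_shape
print(trace.nodes["obs"]["log_prob_sum"])  # scalar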
Example #4
    def log_prob_sum(self, site_filter=lambda name, site: True):
        """
        Compute the site-wise log probabilities of the trace.
        Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
        Each ``log_prob_sum`` is a scalar.
        The computation of ``log_prob_sum`` is memoized.

        :returns: total log probability.
        :rtype: torch.Tensor
        """
        result = 0.0
        for name, site in self.nodes.items():
            if site["type"] == "sample" and site_filter(name, site):
                if "log_prob_sum" in site:
                    log_p = site["log_prob_sum"]
                else:
                    try:
                        log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
                    except ValueError:
                        _, exc_value, traceback = sys.exc_info()
                        shapes = self.format_shapes(last_site=site["name"])
                        raise ValueError("Error while computing log_prob_sum at site '{}':\n{}\n"
                                         .format(name, exc_value, shapes)).with_traceback(traceback)
                    log_p = scale_and_mask(log_p, site["scale"], site["mask"]).sum()
                    site["log_prob_sum"] = log_p
                    if is_validation_enabled():
                        warn_if_nan(log_p, "log_prob_sum at site '{}'".format(name))
                        warn_if_inf(log_p, "log_prob_sum at site '{}'".format(name), allow_neginf=True)
                result = result + log_p
        return result
Example #5
    def _differentiable_loss_particle(self, model_trace, guide_trace):
        elbo_particle = 0

        for name, model_site in model_trace.nodes.items():
            if model_site["type"] == "sample":
                if model_site["is_observed"]:
                    elbo_particle = elbo_particle + model_site["log_prob_sum"]
                else:
                    guide_site = guide_trace.nodes[name]
                    if is_validation_enabled():
                        check_fully_reparametrized(guide_site)

                    # use kl divergence if available, else fall back on sampling
                    try:
                        kl_qp = kl_divergence(guide_site["fn"], model_site["fn"])
                        kl_qp = scale_and_mask(kl_qp, scale=guide_site["scale"], mask=guide_site["mask"])
                        assert kl_qp.shape == guide_site["fn"].batch_shape
                        elbo_particle = elbo_particle - kl_qp.sum()
                    except NotImplementedError:
                        entropy_term = guide_site["score_parts"].entropy_term
                        elbo_particle = elbo_particle + model_site["log_prob_sum"] - entropy_term.sum()

        # handle auxiliary sites in the guide
        for name, guide_site in guide_trace.nodes.items():
            if guide_site["type"] == "sample" and name not in model_trace.nodes:
                assert guide_site["infer"].get("is_auxiliary")
                if is_validation_enabled():
                    check_fully_reparametrized(guide_site)
                entropy_term = guide_site["score_parts"].entropy_term
                elbo_particle = elbo_particle - entropy_term.sum()

        loss = -(elbo_particle.detach() if torch._C._get_tracing_state() else torch_item(elbo_particle))
        surrogate_loss = -elbo_particle
        return loss, surrogate_loss
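The analytic-KL branch above leans on `torch.distributions.kl_divergence`; a stand-alone sketch of that path for a single non-observed site (the Normal/Normal pair is only an illustration):

import torch
from torch.distributions import Normal, kl_divergence
from pyro.distributions.util import scale_and_mask

guide_fn = Normal(torch.zeros(3), torch.ones(3))   # q(z) at the site
model_fn = Normal(torch.ones(3), torch.ones(3))    # p(z) at the site

kl_qp = kl_divergence(guide_fn, model_fn)          # shape == batch_shape == (3,)
kl_qp = scale_and_mask(kl_qp, scale=1.0, mask=torch.tensor([True, True, False]))
elbo_term = -kl_qp.sum()                           # this site's contribution to the ELBO
print(elbo_term)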
Example #6
def test_mask(batch_dim, event_dim, mask_dim):
    # Construct base distribution.
    shape = torch.Size([2, 3, 4, 5, 6][:batch_dim + event_dim])
    batch_shape = shape[:batch_dim]
    mask_shape = batch_shape[batch_dim - mask_dim:]
    base_dist = Bernoulli(0.1).expand_by(shape).to_event(event_dim)

    # Construct masked distribution.
    mask = checker_mask(mask_shape)
    dist = base_dist.mask(mask)

    # Check shape.
    sample = base_dist.sample()
    assert dist.batch_shape == base_dist.batch_shape
    assert dist.event_shape == base_dist.event_shape
    assert dist.sample().shape == sample.shape
    assert dist.log_prob(sample).shape == base_dist.log_prob(sample).shape

    # Check values.
    assert_equal(dist.mean, base_dist.mean)
    assert_equal(dist.variance, base_dist.variance)
    assert_equal(dist.log_prob(sample),
                 scale_and_mask(base_dist.log_prob(sample), mask=mask))
    assert_equal(dist.score_parts(sample),
                 base_dist.score_parts(sample).scale_and_mask(mask=mask),
                 prec=0)
    if not dist.event_shape:
        assert_equal(dist.enumerate_support(), base_dist.enumerate_support())
        assert_equal(dist.enumerate_support(expand=True),
                     base_dist.enumerate_support(expand=True))
        assert_equal(dist.enumerate_support(expand=False),
                     base_dist.enumerate_support(expand=False))
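`checker_mask` is a test helper that is not shown in this listing; a plausible stand-in (the exact pattern is an assumption) builds a checkerboard-style boolean mask of the requested shape:

import torch

def checker_mask(shape):
    # Hypothetical helper: True wherever the sum of the indices is even,
    # giving an alternating (checkerboard) boolean mask of the given shape.
    mask = torch.zeros(())
    for size in shape:
        mask = mask.unsqueeze(-1) + torch.arange(size)
    return mask % 2 == 0

print(checker_mask(torch.Size([2, 3])))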
Example #7
    def log_prob(self, value):
        if self._mask is False:
            shape = broadcast_shape(self.base_dist.batch_shape,
                                    value.shape[:value.dim() - self.event_dim])
            return torch.zeros((), device=value.device).expand(shape)
        if self._mask is True:
            return self.base_dist.log_prob(value)
        return scale_and_mask(self.base_dist.log_prob(value), mask=self._mask)
Example #8
def _kl_masked_masked(p, q):
    if p._mask is False or q._mask is False:
        mask = False
    elif p._mask is True:
        mask = q._mask
    elif q._mask is True:
        mask = p._mask
    elif p._mask is q._mask:
        mask = p._mask
    else:
        mask = p._mask & q._mask

    if mask is False:
        return 0.  # Return a float, since we cannot determine device.
    if mask is True:
        return kl_divergence(p.base_dist, q.base_dist)
    kl = kl_divergence(p.base_dist, q.base_dist)
    return scale_and_mask(kl, mask=mask)
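A sketch of how this registered KL might be exercised, assuming `MaskedDistribution` objects come from `.mask(...)` on Pyro distributions and that `kl_divergence` dispatches to the function above:

import torch
from torch.distributions import kl_divergence
import pyro.distributions as dist

mask = torch.tensor([True, False, True])
p = dist.Normal(torch.zeros(3), torch.ones(3)).mask(mask)
q = dist.Normal(torch.ones(3), torch.ones(3)).mask(mask)

kl = kl_divergence(p, q)  # masked-out entries contribute zero
print(kl)                 # roughly tensor([0.5, 0.0, 0.5]) for these parameters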
Example #9
    def f(tensor, scale, mask):
        return scale_and_mask(tensor, scale=scale, mask=mask)
Example #10
def _kl_masked_masked(p, q):
    mask = p._mask if p._mask is q._mask else p._mask & q._mask
    kl = kl_divergence(p.base_dist, q.base_dist)
    return scale_and_mask(kl, mask=mask)
Example #11
    def log_prob(self, value):
        return scale_and_mask(self.base_dist.log_prob(value), mask=self._mask)
Example #12
    def f(tensor, scale, mask): return scale_and_mask(tensor, scale=scale, mask=mask)

    x = torch.tensor([-float('inf'), -1., 0., 1., float('inf')])
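A hedged continuation of the excerpt: the infinite entries in `x` are there so that masking them out yields zeros rather than propagating `inf` or `nan`; the mask below is an assumption:

    mask = torch.isfinite(x)           # hypothetical mask keeping only the finite entries
    print(f(x, scale=2.0, mask=mask))  # masked-out infinities become 0., finite entries are doubled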