Example #1
 def log_abs_det_jacobian(self, x, y):
     if not self.parts:
         return torch.zeros_like(x)
     result = 0
     for part in self.parts[:-1]:
         y_tmp = part(x)
         result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y_tmp),
                                          self.event_dim - part.event_dim)
         x = y_tmp
     part = self.parts[-1]
     result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y),
                                      self.event_dim - part.event_dim)
     return result
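
Every snippet on this page leans on the same helper. For reference, here is a minimal sketch of `_sum_rightmost(value, dim)` consistent with how these examples use it (it matches the helper in `torch.distributions.utils`): sum a tensor over its `dim` rightmost dimensions, returning the input unchanged when `dim == 0`.

def _sum_rightmost(value, dim):
    """Sum out the `dim` rightmost dimensions of `value`."""
    if dim == 0:
        return value
    required_shape = value.shape[:-dim] + (-1,)
    return value.reshape(required_shape).sum(-1)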
Example #2
    def _log_prob(self, y, transforms):
        # TODO: fix dtypes
        event_dim = len(self.event_shape)
        assert torch.isnan(y).sum() == 0
        assert (y.abs() == float("inf")).any() == 0
        if not transforms:
            log_prob = _sum_rightmost(
                self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
            ).float()
            assert torch.isnan(log_prob).sum() == 0
            return log_prob

        transform, *transforms = transforms

        if isinstance(transform, Transform):
            x = transform.inv(y)
            log_prob = -_sum_rightmost(
                transform.log_abs_det_jacobian(x, y), event_dim - transform.event_dim
            )
            next_log_prob = self._log_prob(x, transforms)
            assert torch.isnan(log_prob).sum() == 0
            assert torch.isnan(next_log_prob).sum() == 0
            sum_log_prob = log_prob.float() + next_log_prob.float()
            assert torch.isnan(sum_log_prob).sum() == 0
            return sum_log_prob
        else:
            x, xset, mask = transform.inverse_set(y)

            # First propagate x back to use caching
            x_log_prob = -_sum_rightmost(
                transform.log_abs_det_jacobian(x, y), event_dim - transform.event_dim
            )
            x_next_log_prob = self._log_prob(x, transforms)
            x_term = x_log_prob.float() + x_next_log_prob.float()

            # Now propagate others
            xset_log_prob = -_sum_rightmost(
                transform.log_abs_det_jacobian(xset, y), event_dim - transform.event_dim
            )
            xset_next_log_prob = self._log_prob(xset, transforms)
            xset_terms = torch.where(
                mask,
                xset_log_prob.float() + xset_next_log_prob.float(),
                torch.tensor([float("-inf")], device=xset_log_prob.device),
            )

            terms = torch.cat([x_term[None], xset_terms])
            assert torch.isnan(terms).sum() == 0
            return torch.logsumexp(terms, dim=0)
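
The `else` branch handles non-injective transforms: `inverse_set` returns every pre-image of `y`, each pre-image is scored through the remaining transforms, and the terms are combined with `logsumexp`. A minimal standalone illustration of that idea (not this class's API) for `y = |x|` with a standard normal base, where each positive `y` has the two pre-images `±y` and the Jacobian term is zero:

import torch
from torch.distributions import Normal

base = Normal(0.0, 1.0)
y = torch.tensor(1.5)
# p(y) = p_base(y) + p_base(-y), and log|dy/dx| = 0 for y = |x|,
# so the log-density is a logsumexp over the inverse set {y, -y}
log_prob = torch.logsumexp(torch.stack([base.log_prob(y), base.log_prob(-y)]), dim=0)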
Example #3
 def log_abs_det_jacobian(self, x, y):
     assert -x.dim() <= self.dim < x.dim()
     assert x.size(self.dim) == self.length
     assert -y.dim() <= self.dim < y.dim()
     assert y.size(self.dim) == self.length
     logdetjacs = []
     start = 0
     for trans, length in zip(self.transforms, self.lengths):
         xslice = x.narrow(self.dim, start, length)
         yslice = y.narrow(self.dim, start, length)
         logdetjac = trans.log_abs_det_jacobian(xslice, yslice)
         if trans.event_dim < self.event_dim:
             logdetjac = _sum_rightmost(logdetjac,
                                        self.event_dim - trans.event_dim)
         logdetjacs.append(logdetjac)
         start = start + length  # avoid += for jit compat
     # Decide whether to concatenate or sum.
     dim = self.dim
     if dim >= 0:
         dim = dim - x.dim()
     dim = dim + self.event_dim
     if dim < 0:
         return torch.cat(logdetjacs, dim=dim)
     else:
         return sum(logdetjacs)
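
This slice-and-reduce pattern is what `torch.distributions.transforms.CatTransform` implements: each sub-transform scores its own slice (taken with `narrow`), and the per-slice terms are either concatenated elementwise or summed, depending on the event dimensions. A small usage sketch with two elementwise parts, so the result stays elementwise:

import torch
from torch.distributions.transforms import AffineTransform, CatTransform, ExpTransform

# exp on the first 3 entries, x -> 1 + 2x on the last 2, along the last dim
t = CatTransform([ExpTransform(), AffineTransform(1.0, 2.0)], dim=-1, lengths=[3, 2])
x = torch.randn(4, 5)
y = t(x)
ldj = t.log_abs_det_jacobian(x, y)  # both parts have event_dim 0, so shape (4, 5)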
Example #4
def test_independent(base_dist, sample_shape, batch_shape,
                     reinterpreted_batch_ndims):
    if batch_shape:
        base_dist = base_dist.expand_by(batch_shape)
    if reinterpreted_batch_ndims > len(base_dist.batch_shape):
        with pytest.raises(ValueError):
            d = dist.Independent(base_dist, reinterpreted_batch_ndims)
    else:
        d = dist.Independent(base_dist, reinterpreted_batch_ndims)
        assert (d.batch_shape == batch_shape[:len(batch_shape) -
                                             reinterpreted_batch_ndims])
        assert (d.event_shape == batch_shape[len(batch_shape) -
                                             reinterpreted_batch_ndims:] +
                base_dist.event_shape)

        assert d.sample().shape == batch_shape + base_dist.event_shape
        assert d.mean.shape == batch_shape + base_dist.event_shape
        assert d.variance.shape == batch_shape + base_dist.event_shape
        x = d.sample(sample_shape)
        assert x.shape == sample_shape + d.batch_shape + d.event_shape

        log_prob = d.log_prob(x)
        assert (log_prob.shape == sample_shape +
                batch_shape[:len(batch_shape) - reinterpreted_batch_ndims])
        assert not torch_isnan(log_prob)
        log_prob_0 = base_dist.log_prob(x)
        assert_equal(log_prob,
                     _sum_rightmost(log_prob_0, reinterpreted_batch_ndims))
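
The invariant checked in the last two lines is the whole contract of `Independent`: its `log_prob` is the base distribution's `log_prob` with the reinterpreted batch dimensions summed out. The test above uses Pyro's wrappers; a minimal check with stock `torch.distributions`:

import torch
from torch.distributions import Independent, Normal

base = Normal(torch.zeros(3, 4), torch.ones(3, 4))  # batch_shape (3, 4)
d = Independent(base, 1)                            # batch_shape (3,), event_shape (4,)
x = d.sample()
# Independent.log_prob == base log_prob summed over the rightmost (event) dim
assert torch.allclose(d.log_prob(x), base.log_prob(x).sum(-1))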
Example #5
 def log_prob(self, value):
     """
     Scores the sample by inverting the transform(s) and computing the score using the score
     of the base distribution and the log abs det jacobian
     """
     event_dim = len(self.event_shape)
     log_prob = 0.0
     y = value
     for transform in reversed(self.transforms):
         x = transform.inv(y)
         log_prob -= _sum_rightmost(transform.log_abs_det_jacobian(x, y),
                                    event_dim - transform.event_dim)
         y = x
     log_prob += _sum_rightmost(self.base_dist.log_prob(y),
                                event_dim - len(self.base_dist.event_shape))
     return log_prob
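
The effect of this loop is easiest to see in the one-transform case. Pushing a standard normal through `exp` must reproduce the log-normal density, which is exactly what subtracting the log abs det Jacobian buys:

import torch
from torch.distributions import LogNormal, Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform

base = Normal(torch.zeros(5), torch.ones(5))
d = TransformedDistribution(base, [ExpTransform()])
x = d.sample()
# log p(x) = base.log_prob(log x) - log|d exp(u)/du|, i.e. the LogNormal density
assert torch.allclose(d.log_prob(x), LogNormal(torch.zeros(5), torch.ones(5)).log_prob(x))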
Example #6
def _kl_transformed_transformed(p, q):
    if p.transforms != q.transforms:
        raise NotImplementedError
    if p.event_shape != q.event_shape:
        raise NotImplementedError
    extra_event_dim = len(p.event_shape) - len(p.base_dist.event_shape)
    base_kl_divergence = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(base_kl_divergence, extra_event_dim)
Example #7
 def log_prob(self, value):
     """
     Scores the sample by inverting the transform(s) and computing the score
     using the score of the base distribution and the log abs det jacobian.
     """
     self.base_dist._validate_log_prob_arg(value)
     event_dim = len(self.event_shape)
     log_prob = 0.0
     y = value
     for transform in reversed(self.transforms):
         x = transform.inv(y)
         log_prob -= _sum_rightmost(transform.log_abs_det_jacobian(x, y),
                                    event_dim - transform.event_dim)
         y = x
     log_prob += _sum_rightmost(self.base_dist.log_prob(y),
                                event_dim - len(self.base_dist.event_shape))
     return log_prob
Example #8
 def log_abs_det_jacobian(self, x, y):
     if not self.parts:
         return x.new([0]).expand_as(x)
     result = 0
     for part in self.parts:
         y = part(x)
         result += _sum_rightmost(part.log_abs_det_jacobian(x, y),
                                  self.event_dim - part.event_dim)
         x = y
     return result
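
The loop accumulates one `_sum_rightmost`-reduced term per part while threading the intermediate values through. A small usage sketch with stock transforms, where the composed log-det is log 2 from the affine part plus the exp part's term evaluated at its input 2x:

import math
import torch
from torch.distributions.transforms import AffineTransform, ComposeTransform, ExpTransform

t = ComposeTransform([AffineTransform(0.0, 2.0), ExpTransform()])
x = torch.randn(4)
y = t(x)
# log|det J| = log 2 (affine) + 2x (exp's elementwise term at its input 2x)
assert torch.allclose(t.log_abs_det_jacobian(x, y), math.log(2.0) + 2 * x)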
Example #10
 def log_abs_det_jacobian(self, x, y):
     if not self.parts:
         return torch.zeros_like(x)
     result = 0
     for part in self.parts:
         y = part(x)
         result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y),
                                          self.event_dim - part.event_dim)
         x = y
     return result
Example #11
    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the score of the base distribution and the log abs det jacobian.
        """
        if self._validate_args:
            self._validate_sample(value)
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
                                                 event_dim - transform.domain.event_dim)
            y = x

        log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),
                                             event_dim - len(self.base_dist.event_shape))
        return log_prob
Example #12
    def log_prob(
        self, y: torch.Tensor, context: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the score of the base distribution and the log abs det jacobian.
        """
        if context is None:
            context = self._context
        event_dim = len(self.event_shape)

        x = self.bijector.inverse(y, self.params(), context)
        log_prob = -_sum_rightmost(
            self.bijector.log_abs_det_jacobian(x, y, self.params(), context),
            event_dim - self.bijector.event_dim,
        )
        log_prob = log_prob + _sum_rightmost(
            self.base_dist.log_prob(x),
            event_dim - len(self.base_dist.event_shape),
        )

        return log_prob
Example #13
 def log_abs_det_jacobian(self, x, y):
     """
     Calculates the elementwise determinant of the log jacobian
     """
     x_old, y_old = self._cached_x_y
     if self._cached_log_scale is not None and x is x_old and y is y_old:
         log_scale = self._cached_log_scale
     else:
         x1, x2 = x.split([self.split_dim, x.size(self.dim) - self.split_dim], dim=self.dim)
         _, log_scale = self.nn(x1.reshape(x1.shape[:-self.event_dim] + (-1,)))
         log_scale = log_scale.reshape(log_scale.shape[:-1] + x2.shape[-self.event_dim:])
         log_scale = clamp_preserve_gradients(log_scale, self.log_scale_min_clip, self.log_scale_max_clip)
     return _sum_rightmost(log_scale, self.event_dim)
Example #14
 def log_abs_det_jacobian(self, x, y, params=None, context=None):
     """
     Computes the log det jacobian `log |dy/dx|` given input and output.
     By default, assumes a volume preserving bijection.
     """
     ldj = _sum_rightmost(
         torch.zeros_like(y),
         self.event_dim,
     )
     for bijector, param in zip(reversed(self.bijectors), reversed(params)):
         y_inv = bijector.inverse(y, param, context)
         ldj += bijector.log_abs_det_jacobian(y_inv, y, param, context)
         y = y_inv
     return ldj
Example #15
 def log_abs_det_jacobian(self, x, y):
     assert -x.dim() <= self.dim < x.dim()
     assert x.size(self.dim) == self.length
     assert -y.dim() <= self.dim < y.dim()
     assert y.size(self.dim) == self.length
     logdetjacs = []
     start = 0
     for trans, length in zip(self.transforms, self.lengths):
         xslice = x.narrow(self.dim, start, length)
         yslice = y.narrow(self.dim, start, length)
         logdetjacs.append(
             trans.log_abs_det_jacobian(
                 xslice, yslice).reshape(x.shape + (1, ) * self.event_dim))
         start = start + length  # avoid += for jit compat
     return _sum_rightmost(torch.cat(logdetjacs, dim=self.dim),
                           self.event_dim)
Example #16
    def log_abs_det_jacobian(self, x, y):
        if not self.parts:
            return torch.zeros_like(x)

        # Compute intermediates. This will be free if parts[:-1] are all cached.
        xs = [x]
        for part in self.parts[:-1]:
            xs.append(part(xs[-1]))
        xs.append(y)

        terms = []
        event_dim = self.domain.event_dim
        for part, x, y in zip(self.parts, xs[:-1], xs[1:]):
            terms.append(
                _sum_rightmost(part.log_abs_det_jacobian(x, y),
                               event_dim - part.domain.event_dim))
            event_dim += part.codomain.event_dim - part.domain.event_dim
        return functools.reduce(operator.add, terms)
Example #17
 def log_prob(self, value):
     log_prob = self.base_dist.log_prob(value)
     return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)
Example #20
def _kl_independent_independent(p, q):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)
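
Usage-wise, the rule is the same as for `log_prob`: the KL divergence between two `Independent` wrappers is the base KL with the reinterpreted dimensions summed out. A quick check:

import torch
from torch.distributions import Independent, Normal, kl_divergence

p = Independent(Normal(torch.zeros(3, 4), torch.ones(3, 4)), 1)
q = Independent(Normal(torch.ones(3, 4), torch.ones(3, 4)), 1)
# KL of the wrapped distributions is the base KL summed over the event dim
assert torch.allclose(kl_divergence(p, q), kl_divergence(p.base_dist, q.base_dist).sum(-1))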
Example #21
 def entropy(self):
     entropy = self.base_dist.entropy()
     return _sum_rightmost(entropy, self.reinterpreted_batch_ndims)
Example #22
 def log_abs_det_jacobian(self, x, y):
     result = self.base_transform.log_abs_det_jacobian(x, y)
     result = _sum_rightmost(result, self.reinterpreted_batch_ndims)
     return result
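
This is the `IndependentTransform` pattern: reinterpret some batch dimensions of a base transform as event dimensions and sum the corresponding Jacobian terms. A short usage sketch with the stock `torch.distributions.transforms.IndependentTransform`:

import torch
from torch.distributions.transforms import ExpTransform, IndependentTransform

t = IndependentTransform(ExpTransform(), reinterpreted_batch_ndims=1)
x = torch.randn(2, 5)
y = t(x)
# exp's elementwise log-det term is x itself; the wrapper sums the rightmost dim
assert torch.allclose(t.log_abs_det_jacobian(x, y), x.sum(-1))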