Example #1
    def condition(self, value):
        """
        Condition this Gaussian on a trailing subset of its state.
        This should satisfy::

            g.condition(y).dim() == g.dim() - y.size(-1)

        Note that since this is a non-normalized Gaussian, we include the
        density of ``y`` in the result. Thus :meth:`condition` is similar to a
        ``functools.partial`` binding of arguments::

            left = x[..., :n]
            right = x[..., n:]
            g.log_density(x) == g.condition(right).log_density(left)
        """
        assert isinstance(value, torch.Tensor)
        right = value.size(-1)
        dim = self.dim()
        assert right <= dim

        n = dim - right
        info_a = self.info_vec[..., :n]
        info_b = self.info_vec[..., n:]
        P_aa = self.precision[..., :n, :n]
        P_ab = self.precision[..., :n, n:]
        P_bb = self.precision[..., n:, n:]
        b = value

        info_vec = info_a - matvecmul(P_ab, b)
        precision = P_aa
        log_normalizer = (self.log_normalizer +
                          -0.5 * matvecmul(P_bb, b).mul(b).sum(-1) +
                          b.mul(info_b).sum(-1))
        return Gaussian(log_normalizer, info_vec, precision)
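A quick, hedged sanity check of the partial-binding property documented above; it assumes ``mvn_to_gaussian`` from Example #2 below (importable from ``pyro.ops.gaussian``, as the ``:rtype:`` there suggests) just to build a small Gaussian:

import torch
from pyro.ops.gaussian import mvn_to_gaussian

cov = torch.eye(3) + 0.1 * torch.ones(3, 3)          # any positive-definite covariance
g = mvn_to_gaussian(torch.distributions.MultivariateNormal(torch.zeros(3), cov))

x = torch.randn(3)
left, right = x[..., :2], x[..., 2:]
assert g.condition(right).dim() == g.dim() - right.size(-1)
# The density of ``right`` is folded into the conditioned factor:
assert torch.allclose(g.log_density(x), g.condition(right).log_density(left), atol=1e-6)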
Example #2
def mvn_to_gaussian(mvn):
    """
    Convert a MultivariateNormal distribution to a Gaussian.

    :param ~torch.distributions.MultivariateNormal mvn: A multivariate normal distribution.
    :return: An equivalent Gaussian object.
    :rtype: ~pyro.ops.gaussian.Gaussian
    """
    assert (isinstance(mvn, torch.distributions.MultivariateNormal) or
            (isinstance(mvn, torch.distributions.Independent) and
             isinstance(mvn.base_dist, torch.distributions.Normal)))
    if isinstance(mvn, torch.distributions.Independent):
        mvn = mvn.base_dist
        precision_diag = mvn.scale.pow(-2)
        precision = precision_diag.diag_embed()
        info_vec = mvn.loc * precision_diag
        scale_diag = mvn.scale
    else:
        precision = mvn.precision_matrix
        info_vec = matvecmul(precision, mvn.loc)
        scale_diag = mvn.scale_tril.diagonal(dim1=-2, dim2=-1)

    n = mvn.loc.size(-1)
    log_normalizer = (-0.5 * n * math.log(2 * math.pi) +
                      -0.5 * (info_vec * mvn.loc).sum(-1) -
                      scale_diag.log().sum(-1))
    return Gaussian(log_normalizer, info_vec, precision)
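A hedged sketch of how the conversion can be checked: since the log normalizer above includes the full Gaussian normalization constant, the converted factor's ``log_density`` should agree with ``MultivariateNormal.log_prob``:

import torch
from pyro.ops.gaussian import mvn_to_gaussian

scale_tril = torch.tril(torch.rand(2, 2)) + torch.eye(2)   # any valid Cholesky factor
mvn = torch.distributions.MultivariateNormal(torch.randn(2), scale_tril=scale_tril)
g = mvn_to_gaussian(mvn)

x = torch.randn(2)
assert torch.allclose(g.log_density(x), mvn.log_prob(x), atol=1e-6)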
Example #3
    def condition(self, value):
        """
        Condition on an observed value of ``y`` (of size ``y_dim``), returning
        a Gaussian over ``x`` whose log density at ``x`` equals ``log p(y=value | x)``.
        """
        if value.size(-1) == self.loc.size(-1):
            prec_sqrt = self.matrix / self.scale.unsqueeze(-2)
            precision = matmul(prec_sqrt, prec_sqrt.transpose(-1, -2))
            delta = (value - self.loc) / self.scale
            info_vec = matvecmul(prec_sqrt, delta)
            log_normalizer = (-0.5 * self.loc.size(-1) * math.log(2 * math.pi)
                              - 0.5 * delta.pow(2).sum(-1) - self.scale.log().sum(-1))
            return Gaussian(log_normalizer, info_vec, precision)
        else:
            return self.to_gaussian().condition(value)
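A minimal usage sketch, assuming these methods belong to ``AffineNormal`` from ``pyro.ops.gaussian`` with the ``(matrix, loc, scale)`` constructor seen in Example #4 below; the matrix, loc and scale values here are purely illustrative:

import torch
from pyro.ops.gaussian import AffineNormal

x_dim, y_dim = 3, 2
affine = AffineNormal(torch.randn(x_dim, y_dim),   # matrix mapping x to the mean of y
                      torch.zeros(y_dim),          # loc
                      torch.ones(y_dim))           # scale (diagonal noise)

y_obs = torch.randn(y_dim)
g = affine.condition(y_obs)   # fast branch: y_obs.size(-1) == y_dim
assert g.dim() == x_dim       # an unnormalized Gaussian over x remains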
Example #4
    def left_condition(self, value):
        """
        If ``value.size(-1) == x_dim``, this returns an ``AffineNormal`` with
        no remaining input dimensions, which acts as a Normal distribution
        with ``event_dim=1``. After applying this method, the cost to draw a
        sample is ``O(y_dim)`` instead of ``O(y_dim ** 3)``.
        """
        if value.size(-1) == self.matrix.size(-2):
            loc = matvecmul(self.matrix.transpose(-1, -2), value) + self.loc
            matrix = value.new_zeros(loc.shape[:-1] + (0, loc.size(-1)))
            scale = self.scale.expand(loc.shape)
            return AffineNormal(matrix, loc, scale)
        else:
            return self.to_gaussian().left_condition(value)
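A hedged sketch of the fast path, reusing the illustrative ``AffineNormal`` setup from the previous sketch; the conditioned object keeps only the diagonal ``scale``, so drawing ``y`` is O(y_dim):

import torch
from pyro.ops.gaussian import AffineNormal

x_dim, y_dim = 3, 2
affine = AffineNormal(torch.randn(x_dim, y_dim), torch.zeros(y_dim), torch.ones(y_dim))

x = torch.randn(x_dim)
cond = affine.left_condition(x)                # fast branch: x.size(-1) == x_dim
assert cond.matrix.shape[-2:] == (0, y_dim)    # no remaining input dimensions
assert torch.allclose(cond.loc, x @ affine.matrix + affine.loc)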
Example #5
def _matrix_and_gaussian_to_gaussian(matrix, y_gaussian):
    """
    Combine a Gaussian over ``y`` with a matrix mapping ``x`` to a shift of
    ``y`` into a joint Gaussian factor over ``(x, y)`` that evaluates
    ``y_gaussian`` at ``y - x @ matrix``.
    """
    P_yy = y_gaussian.precision
    neg_P_xy = matmul(matrix, P_yy)
    P_xy = -neg_P_xy
    P_yx = P_xy.transpose(-1, -2)
    P_xx = matmul(neg_P_xy, matrix.transpose(-1, -2))
    # Assemble the block precision [[P_xx, P_xy], [P_yx, P_yy]].
    precision = torch.cat([torch.cat([P_xx, P_xy], -1),
                           torch.cat([P_yx, P_yy], -1)], -2)
    info_y = y_gaussian.info_vec
    info_x = -matvecmul(matrix, info_y)
    info_vec = torch.cat([info_x, info_y], -1)
    log_normalizer = y_gaussian.log_normalizer

    result = Gaussian(log_normalizer, info_vec, precision)
    return result
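A hedged consistency check of the block structure above: the joint factor should evaluate ``y_gaussian`` at ``y - x @ matrix``. It assumes ``mvn_to_gaussian`` from Example #2 to build ``y_gaussian``:

import torch
from pyro.ops.gaussian import mvn_to_gaussian

x_dim, y_dim = 2, 3
matrix = torch.randn(x_dim, y_dim)
y_mvn = torch.distributions.MultivariateNormal(torch.randn(y_dim), torch.eye(y_dim))
y_gaussian = mvn_to_gaussian(y_mvn)
joint = _matrix_and_gaussian_to_gaussian(matrix, y_gaussian)

x, y = torch.randn(x_dim), torch.randn(y_dim)
expected = y_gaussian.log_density(y - x @ matrix)
assert torch.allclose(joint.log_density(torch.cat([x, y], -1)), expected, atol=1e-5)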
Example #6
    def log_density(self, value):
        """
        Evaluate the log density of this Gaussian at a point value::

            -0.5 * value.T @ precision @ value + value.T @ info_vec + log_normalizer

        This is mainly used for testing.
        """
        if value.size(-1) == 0:
            batch_shape = broadcast_shape(value.shape[:-1], self.batch_shape)
            return self.log_normalizer.expand(batch_shape)
        result = (-0.5) * matvecmul(self.precision, value)
        result = result + self.info_vec
        result = (value * result).sum(-1)
        return result + self.log_normalizer
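A small hedged sketch exercising the ``value.size(-1) == 0`` branch above: fully conditioning a Gaussian leaves a dim-0 factor whose ``log_normalizer`` carries the whole density (again assuming ``mvn_to_gaussian`` from Example #2):

import torch
from pyro.ops.gaussian import mvn_to_gaussian

g = mvn_to_gaussian(torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2)))
x = torch.randn(2)

empty = g.condition(x)                 # dim() == 0 after conditioning on everything
value = x.new_empty(0)                 # an empty trailing event dimension
assert torch.allclose(empty.log_density(value), g.log_density(x), atol=1e-6)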