Exemple #1
0
def gaussian_tensordot(x, y, dims=0):
    """
    Contract two Gaussians by integrating out the variables they share.

    With ``x`` a Gaussian over variables ``(a, b)`` and ``y`` a Gaussian over
    variables ``(b, c)``, this returns the Gaussian over ``(a, c)`` given by::

        (x @ y)(a, c) = log(integral(exp(x(a, b) + y(b, c)), b))

    Each of ``a``, ``b`` and ``c`` may consist of zero or more variables.
    ``b`` is formed by the trailing ``dims`` variables of ``x`` and the
    leading ``dims`` variables of ``y``.

    :param Gaussian x: left operand, over ``(a, b)``.
    :param Gaussian y: right operand, over ``(b, c)``.
    :param int dims: number of shared variables ``b`` to integrate out.
    :return: a Gaussian over ``(a, c)``.
    :rtype: Gaussian
    """
    assert isinstance(x, Gaussian)
    assert isinstance(y, Gaussian)
    nb = dims
    na = x.dim() - nb
    nc = y.dim() - nb
    assert na >= 0
    assert nb >= 0
    assert nc >= 0

    x_prec = x.precision
    y_prec = y.precision
    # Information vectors are precision @ mean; split them like the precisions.
    x_info_a, x_info_b = x.info_vec[..., :na], x.info_vec[..., na:]
    y_info_b, y_info_c = y.info_vec[..., :nb], y.info_vec[..., nb:]

    # The (a, a) block of x and the (c, c) block of y occupy disjoint corners
    # of the (a, c) result, so zero-padding and summing places them.
    precision = (pad(x_prec[..., :na, :na], (0, nc, 0, nc))
                 + pad(y_prec[..., nb:, nb:], (na, 0, na, 0)))
    info_vec = pad(x_info_a, (0, nc)) + pad(y_info_c, (na, 0))
    log_normalizer = x.log_normalizer + y.log_normalizer
    if nb == 0:
        # Nothing shared: the result is just the (block-diagonal) product.
        return Gaussian(log_normalizer, info_vec, precision)

    # Terms coupling b to (a, c), and the combined information vector of b.
    cross = pad(x_prec[..., na:, :na], (0, nc)) + pad(y_prec[..., :nb, nb:], (na, 0))
    info_b = x_info_b + y_info_b

    # The b block must be positive definite for the integral over b to be finite.
    chol = cholesky(x_prec[..., na:, na:] + y_prec[..., :nb, :nb])
    chol_inv_cross = triangular_solve(cross, chol, upper=False)
    chol_inv_cross_t = chol_inv_cross.transpose(-2, -1)
    chol_inv_info = triangular_solve(info_b.unsqueeze(-1), chol, upper=False)

    # Schur complement of the b block.
    precision = precision - matmul(chol_inv_cross_t, chol_inv_cross)
    # NB: the result may be singular, so mean = precision^-1 @ info_vec need not exist.
    if na + nc > 0:
        info_vec = info_vec - matmul(chol_inv_cross_t, chol_inv_info).squeeze(-1)
    # Sum of log-diagonal of the Cholesky factor, i.e. half the log-determinant.
    half_logdet = torch.diagonal(chol, dim1=-2, dim2=-1).log().sum(-1)
    correction = (0.5 * nb * math.log(2 * math.pi)
                  + 0.5 * chol_inv_info.squeeze(-1).pow(2).sum(-1)
                  - half_logdet)
    return Gaussian(log_normalizer + correction, info_vec, precision)
Exemple #2
0
def _matrix_and_gaussian_to_gaussian(matrix, y_gaussian):
    """
    Lift a Gaussian over a residual to a joint Gaussian over ``(x, y)``.

    The result places ``x`` in the leading ``matrix.size(-2)`` coordinates and
    ``y`` in the trailing ``matrix.size(-1)`` coordinates.

    Assuming a ``Gaussian`` encodes the log-density
    ``log_normalizer + v @ info_vec - 0.5 * v @ precision @ v``, the joint
    Gaussian evaluates ``y_gaussian`` at the residual ``y - x @ matrix``.

    Batch dimensions of ``matrix`` and ``y_gaussian`` are broadcast against
    each other.

    :param torch.Tensor matrix: tensor of shape ``batch_shape + (x_dim, y_dim)``.
    :param Gaussian y_gaussian: Gaussian over ``y_dim`` variables.
    :return: a Gaussian over ``x_dim + y_dim`` variables.
    :rtype: Gaussian
    """
    P_yy = y_gaussian.precision
    neg_P_xy = matmul(matrix, P_yy)
    # matmul broadcasts the batch dims of matrix and P_yy, but torch.cat below
    # does not; expand P_yy to that broadcast batch shape (a no-op if equal).
    P_yy = P_yy.expand(neg_P_xy.shape[:-2] + P_yy.shape[-2:])
    P_xy = -neg_P_xy
    P_yx = P_xy.transpose(-1, -2)
    P_xx = matmul(neg_P_xy, matrix.transpose(-1, -2))
    precision = torch.cat([torch.cat([P_xx, P_xy], -1),
                           torch.cat([P_yx, P_yy], -1)], -2)

    info_y = y_gaussian.info_vec
    info_x = -matvecmul(matrix, info_y)
    # Likewise broadcast info_y to info_x's batch shape before concatenating.
    info_y = info_y.expand(info_x.shape[:-1] + info_y.shape[-1:])
    info_vec = torch.cat([info_x, info_y], -1)

    # The change of variables y -> y - x @ matrix has unit Jacobian, so the
    # normalizer carries over unchanged.
    return Gaussian(y_gaussian.log_normalizer, info_vec, precision)
Exemple #3
0
 def condition(self, value):
     """
     Condition on ``value``, returning a ``Gaussian`` over the remaining variables.

     When ``value.size(-1)`` equals ``self.loc.size(-1)``, the result is built
     in closed form from ``self.matrix``, ``self.loc`` and ``self.scale``.
     Assuming the usual ``Gaussian`` convention
     (``log_normalizer + x @ info_vec - 0.5 * x @ precision @ x``), it is, as
     a function of ``x``, the diagonal-normal log-density of ``value`` with
     mean ``x @ self.matrix + self.loc`` and scale ``self.scale``.

     Otherwise conditioning is delegated to ``self.to_gaussian()``.

     :param torch.Tensor value: the observed values.
     :return: the conditioned Gaussian.
     :rtype: Gaussian
     """
     n_obs = self.loc.size(-1)
     if value.size(-1) != n_obs:
         return self.to_gaussian().condition(value)

     # Each column of matrix divided by the matching scale entry; this is a
     # square root of the resulting precision.
     scaled_matrix = self.matrix / self.scale.unsqueeze(-2)
     precision = matmul(scaled_matrix, scaled_matrix.transpose(-1, -2))
     # Residual of value about loc, in units of scale.
     z = (value - self.loc) / self.scale
     info_vec = matvecmul(scaled_matrix, z)
     log_normalizer = (-0.5 * n_obs * math.log(2 * math.pi)
                       - 0.5 * z.pow(2).sum(-1)
                       - self.scale.log().sum(-1))
     return Gaussian(log_normalizer, info_vec, precision)
Exemple #4
0
    def marginalize(self, left=0, right=0):
        """
        Marginalizing out variables on either side of the event dimension::

            g.marginalize(left=n).event_logsumexp() = g.logsumexp()
            g.marginalize(right=n).event_logsumexp() = g.logsumexp()

        and for data ``x``::

            g.condition(x).event_logsumexp()
              = g.marginalize(left=g.dim() - x.size(-1)).log_density(x)

        :param int left: number of leading variables to integrate out.
        :param int right: number of trailing variables to integrate out.
        :return: a Gaussian over the remaining ``self.dim() - left - right``
            variables, or ``self`` itself when ``left == right == 0``.
        :rtype: Gaussian
        :raises NotImplementedError: if both ``left`` and ``right`` are positive.
        :raises ValueError: if ``left`` or ``right`` is negative or larger
            than ``self.dim()``.
        """
        if left == 0 and right == 0:
            return self
        if left > 0 and right > 0:
            raise NotImplementedError
        n = self.dim()
        # n_b below is only the true number of marginalized variables when both
        # counts lie in [0, n]. Out-of-range values would otherwise be clamped
        # (or wrapped, if negative) by slicing and yield a wrong log_normalizer.
        if not (0 <= left <= n and 0 <= right <= n):
            raise ValueError(
                "Expected 0 <= left, right <= {}, but got left={}, right={}".format(
                    n, left, right))
        n_b = left + right
        a = slice(left, n - right)  # preserved
        b = slice(None, left) if left else slice(n - right, None)  # marginalized

        P_aa = self.precision[..., a, a]
        P_ba = self.precision[..., b, a]
        P_bb = self.precision[..., b, b]
        # Presumably the lower Cholesky factor of P_bb, matching upper=False below.
        P_b = cholesky(P_bb)
        P_a = triangular_solve(P_ba, P_b, upper=False)  # P_b^{-1} @ P_ba
        P_at = P_a.transpose(-1, -2)
        # Schur complement: P_aa - P_ba^T @ P_bb^{-1} @ P_ba.
        precision = P_aa - matmul(P_at, P_a)

        info_a = self.info_vec[..., a]
        info_b = self.info_vec[..., b]
        b_tmp = triangular_solve(info_b.unsqueeze(-1), P_b, upper=False)  # P_b^{-1} @ info_b
        info_vec = info_a - matmul(P_at, b_tmp).squeeze(-1)

        # Fold the integral over b into the normalizer:
        # 0.5 * n_b * log(2*pi) - log det(P_b) + 0.5 * |P_b^{-1} @ info_b|^2.
        log_normalizer = (self.log_normalizer +
                          0.5 * n_b * math.log(2 * math.pi) -
                          P_b.diagonal(dim1=-2, dim2=-1).log().sum(-1) +
                          0.5 * b_tmp.squeeze(-1).pow(2).sum(-1))
        return Gaussian(log_normalizer, info_vec, precision)