def gaussian_tensordot(x, y, dims=0):
    """
    Computes the integral over two Gaussians:

        `(x @ y)(a,c) = log(integral(exp(x(a,b) + y(b,c)), b))`,

    where `x` is a Gaussian over variables (a,b),
    `y` is a Gaussian over variables (b,c),
    (a,b,c) can each be sets of zero or more variables,
    and `dims` is the size of b.

    :param x: a Gaussian instance
    :param y: a Gaussian instance
    :param dims: number of variables to contract
    """
    assert isinstance(x, Gaussian)
    assert isinstance(y, Gaussian)
    na = x.dim() - dims
    nb = dims
    nc = y.dim() - dims
    assert na >= 0
    assert nb >= 0
    assert nc >= 0

    Paa, Pba, Pbb = (
        x.precision[..., :na, :na],
        x.precision[..., na:, :na],
        x.precision[..., na:, na:],
    )
    Qbb, Qbc, Qcc = (
        y.precision[..., :nb, :nb],
        y.precision[..., :nb, nb:],
        y.precision[..., nb:, nb:],
    )
    xa, xb = x.info_vec[..., :na], x.info_vec[..., na:]  # x.precision @ x.mean
    yb, yc = y.info_vec[..., :nb], y.info_vec[..., nb:]  # y.precision @ y.mean

    precision = pad(Paa, (0, nc, 0, nc)) + pad(Qcc, (na, 0, na, 0))
    info_vec = pad(xa, (0, nc)) + pad(yc, (na, 0))
    log_normalizer = x.log_normalizer + y.log_normalizer
    if nb > 0:
        B = pad(Pba, (0, nc)) + pad(Qbc, (na, 0))
        b = xb + yb
        # Pbb + Qbb needs to be positive definite, so that we can marginalize
        # out `b` (to have a finite integral).
        L = cholesky(Pbb + Qbb)
        LinvB = triangular_solve(B, L, upper=False)
        LinvBt = LinvB.transpose(-2, -1)
        Linvb = triangular_solve(b.unsqueeze(-1), L, upper=False)
        precision = precision - matmul(LinvBt, LinvB)
        # NB: precision might not be invertible for getting
        # mean = precision^-1 @ info_vec
        if na + nc > 0:
            info_vec = info_vec - matmul(LinvBt, Linvb).squeeze(-1)
        logdet = torch.diagonal(L, dim1=-2, dim2=-1).log().sum(-1)
        diff = (
            0.5 * nb * math.log(2 * math.pi)
            + 0.5 * Linvb.squeeze(-1).pow(2).sum(-1)
            - logdet
        )
        log_normalizer = log_normalizer + diff
    return Gaussian(log_normalizer, info_vec, precision)
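# A minimal usage sketch (illustrative, not part of the original module):
# contract a Gaussian x over (a, b) with a Gaussian y over (b, c) along the
# shared variable b. The helper ``_random_gaussian`` is hypothetical; shifting
# the precision by ``dim * I`` keeps it positive definite so the integral over
# b is finite.
def _demo_gaussian_tensordot():
    import torch

    def _random_gaussian(dim):
        m = torch.randn(dim, dim)
        return Gaussian(torch.zeros(()), torch.randn(dim), m @ m.T + dim * torch.eye(dim))

    na, nb, nc = 2, 1, 3
    x = _random_gaussian(na + nb)  # Gaussian over (a, b)
    y = _random_gaussian(nb + nc)  # Gaussian over (b, c)
    xy = gaussian_tensordot(x, y, dims=nb)
    assert xy.dim() == na + nc  # b has been integrated out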
def _matrix_and_gaussian_to_gaussian(matrix, y_gaussian):
    # Builds a Gaussian over the concatenated variables (x, y) whose
    # log-density is y_gaussian evaluated at the residual y - x @ matrix,
    # i.e. the conditional of y = x @ matrix + noise with noise ~ y_gaussian.
    P_yy = y_gaussian.precision
    neg_P_xy = matmul(matrix, P_yy)
    P_xy = -neg_P_xy
    P_yx = P_xy.transpose(-1, -2)
    P_xx = matmul(neg_P_xy, matrix.transpose(-1, -2))
    precision = torch.cat(
        [torch.cat([P_xx, P_xy], -1), torch.cat([P_yx, P_yy], -1)], -2
    )
    info_y = y_gaussian.info_vec
    info_x = -matvecmul(matrix, info_y)
    info_vec = torch.cat([info_x, info_y], -1)
    log_normalizer = y_gaussian.log_normalizer
    result = Gaussian(log_normalizer, info_vec, precision)
    return result
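# A hedged numerical check (illustrative, not part of the original module):
# the joint built above should agree with ``y_gaussian`` evaluated at the
# residual ``y - x @ matrix``. This assumes ``Gaussian.log_density(v)``
# computes ``-0.5 * v @ P @ v + info_vec @ v + log_normalizer``, the reading
# consistent with ``marginalize``'s docstring below.
def _demo_matrix_and_gaussian_to_gaussian():
    import torch
    x_dim, y_dim = 2, 3
    matrix = torch.randn(x_dim, y_dim)
    m = torch.randn(y_dim, y_dim)
    y_gaussian = Gaussian(
        torch.zeros(()), torch.randn(y_dim), m @ m.T + y_dim * torch.eye(y_dim)
    )
    g = _matrix_and_gaussian_to_gaussian(matrix, y_gaussian)
    x, y = torch.randn(x_dim), torch.randn(y_dim)
    assert torch.allclose(
        g.log_density(torch.cat([x, y])),
        y_gaussian.log_density(y - x @ matrix),
        atol=1e-5,
    )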
def condition(self, value):
    if value.size(-1) == self.loc.size(-1):
        # Fast path: the whole event is observed, so condition directly
        # without materializing the joint Gaussian over (x, y).
        prec_sqrt = self.matrix / self.scale.unsqueeze(-2)
        precision = matmul(prec_sqrt, prec_sqrt.transpose(-1, -2))
        delta = (value - self.loc) / self.scale
        info_vec = matvecmul(prec_sqrt, delta)
        log_normalizer = (
            -0.5 * self.loc.size(-1) * math.log(2 * math.pi)
            - 0.5 * delta.pow(2).sum(-1)
            - self.scale.log().sum(-1)
        )
        return Gaussian(log_normalizer, info_vec, precision)
    else:
        return self.to_gaussian().condition(value)
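# A hedged consistency check (illustrative, not part of the original module):
# when the full event is observed, the specialized branch above must agree
# with the generic ``to_gaussian().condition(...)`` fallback. ``AffineNormal``
# is a stand-in name for the enclosing class, assumed constructible from the
# ``matrix``, ``loc``, and ``scale`` attributes used above.
def _demo_condition():
    import torch
    x_dim, y_dim = 2, 3
    affine = AffineNormal(
        torch.randn(x_dim, y_dim),  # matrix
        torch.randn(y_dim),         # loc
        torch.rand(y_dim) + 0.5,    # scale, kept strictly positive
    )
    y = torch.randn(y_dim)
    fast = affine.condition(y)
    slow = affine.to_gaussian().condition(y)
    assert torch.allclose(fast.precision, slow.precision, atol=1e-5)
    assert torch.allclose(fast.info_vec, slow.info_vec, atol=1e-5)
    assert torch.allclose(fast.log_normalizer, slow.log_normalizer, atol=1e-5)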
def marginalize(self, left=0, right=0):
    """
    Marginalizing out variables on either side of the event dimension::

        g.marginalize(left=n).event_logsumexp() = g.event_logsumexp()
        g.marginalize(right=n).event_logsumexp() = g.event_logsumexp()

    and for data ``x``::

        g.condition(x).event_logsumexp()
          = g.marginalize(left=g.dim() - x.size(-1)).log_density(x)
    """
    if left == 0 and right == 0:
        return self
    if left > 0 and right > 0:
        raise NotImplementedError
    n = self.dim()
    n_b = left + right
    a = slice(left, n - right)  # preserved
    b = slice(None, left) if left else slice(n - right, None)  # marginalized

    P_aa = self.precision[..., a, a]
    P_ba = self.precision[..., b, a]
    P_bb = self.precision[..., b, b]
    P_b = cholesky(P_bb)
    P_a = triangular_solve(P_ba, P_b, upper=False)
    P_at = P_a.transpose(-1, -2)
    precision = P_aa - matmul(P_at, P_a)

    info_a = self.info_vec[..., a]
    info_b = self.info_vec[..., b]
    b_tmp = triangular_solve(info_b.unsqueeze(-1), P_b, upper=False)
    info_vec = info_a - matmul(P_at, b_tmp).squeeze(-1)

    log_normalizer = (
        self.log_normalizer
        + 0.5 * n_b * math.log(2 * math.pi)
        - P_b.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        + 0.5 * b_tmp.squeeze(-1).pow(2).sum(-1)
    )
    return Gaussian(log_normalizer, info_vec, precision)
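# A minimal numerical sketch (illustrative, not part of the original module):
# exercise the docstring identities on a random well-conditioned Gaussian.
def _demo_marginalize():
    import torch
    dim = 4
    m = torch.randn(dim, dim)
    g = Gaussian(torch.zeros(()), torch.randn(dim), m @ m.T + dim * torch.eye(dim))
    assert torch.allclose(
        g.marginalize(left=2).event_logsumexp(), g.event_logsumexp(), atol=1e-4
    )
    x = torch.randn(1)
    assert torch.allclose(
        g.condition(x).event_logsumexp(),
        g.marginalize(left=g.dim() - x.size(-1)).log_density(x),
        atol=1e-4,
    )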