def gaussian_tensordot(x, y, dims=0):
    """
    Computes the integral over two gaussians:

        `(x @ y)(a,c) = log(integral(exp(x(a,b) + y(b,c)), b))`,

    where `x` is a gaussian over variables (a,b),
    `y` is a gaussian over variables (b,c),
    (a,b,c) can each be sets of zero or more variables,
    and `dims` is the size of b.

    :param x: a Gaussian instance
    :param y: a Gaussian instance
    :param dims: number of variables to contract
    """
    assert isinstance(x, Gaussian)
    assert isinstance(y, Gaussian)
    na = x.dim() - dims
    nb = dims
    nc = y.dim() - dims
    assert na >= 0
    assert nb >= 0
    assert nc >= 0

    # Block-decompose both precision matrices around the shared variables b:
    # x is ordered (a, b) and y is ordered (b, c).
    Paa, Pba, Pbb = (
        x.precision[..., :na, :na],
        x.precision[..., na:, :na],
        x.precision[..., na:, na:],
    )
    Qbb, Qbc, Qcc = (
        y.precision[..., :nb, :nb],
        y.precision[..., :nb, nb:],
        y.precision[..., nb:, nb:],
    )
    xa, xb = x.info_vec[..., :na], x.info_vec[..., na:]  # x.precision @ x.mean
    yb, yc = y.info_vec[..., :nb], y.info_vec[..., nb:]  # y.precision @ y.mean

    # Embed the (a,a) and (c,c) blocks into an (a,c)-sized joint via zero padding.
    precision = pad(Paa, (0, nc, 0, nc)) + pad(Qcc, (na, 0, na, 0))
    info_vec = pad(xa, (0, nc)) + pad(yc, (na, 0))
    log_normalizer = x.log_normalizer + y.log_normalizer
    if nb > 0:
        # Cross block coupling b to (a,c), and the combined info vector over b.
        B = pad(Pba, (0, nc)) + pad(Qbc, (na, 0))
        b = xb + yb
        # Pbb + Qbb needs to be positive definite, so that we can marginalize out `b` (to have a finite integral)
        L = cholesky(Pbb + Qbb)
        LinvB = triangular_solve(B, L, upper=False)
        LinvBt = LinvB.transpose(-2, -1)
        Linvb = triangular_solve(b.unsqueeze(-1), L, upper=False)
        # Schur complement: subtract the part of the joint explained by b.
        precision = precision - matmul(LinvBt, LinvB)
        # NB: precision might not be invertible for getting mean = precision^-1 @ info_vec
        if na + nc > 0:
            info_vec = info_vec - matmul(LinvBt, Linvb).squeeze(-1)
        # Gaussian integral over b contributes a normalization correction.
        logdet = torch.diagonal(L, dim1=-2, dim2=-1).log().sum(-1)
        diff = (0.5 * nb * math.log(2 * math.pi) + 0.5 * Linvb.squeeze(-1).pow(2).sum(-1) - logdet)
        log_normalizer = log_normalizer + diff

    return Gaussian(log_normalizer, info_vec, precision)
def event_logsumexp(self):
    """
    Integrates out all latent state (i.e. operating on event dimensions).
    """
    dim = self.dim()
    # With precision P = L L^T and info vector u, the integral of
    # exp(-0.5 x^T P x + u^T x) over x equals
    # exp(0.5*dim*log(2*pi) + 0.5 * u^T P^-1 u) / |L|.
    chol = cholesky(self.precision)
    whitened_u = triangular_solve(self.info_vec.unsqueeze(-1), chol, upper=False).squeeze(-1)
    mahalanobis = whitened_u.pow(2).sum(-1)
    half_log_det = chol.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    return (
        self.log_normalizer
        + 0.5 * dim * math.log(2 * math.pi)
        + 0.5 * mahalanobis
        - half_log_det
    )
def rsample(self, sample_shape=torch.Size()):
    """
    Reparameterized sampler.
    """
    chol = cholesky(self.precision)
    # mean = precision^-1 @ info_vec, solved via the Cholesky factor.
    mean = self.info_vec.unsqueeze(-1).cholesky_solve(chol).squeeze(-1)
    full_shape = sample_shape + self.batch_shape + (self.dim(), 1)
    eps = torch.randn(full_shape, dtype=mean.dtype, device=mean.device)
    # Solving L^T z = eps gives z with covariance (L L^T)^-1 = precision^-1.
    eps = triangular_solve(eps, chol, upper=False, transpose=True).squeeze(-1)
    return mean + eps
def marginalize(self, left=0, right=0):
    """
    Marginalizing out variables on either side of the event dimension::

        g.marginalize(left=n).event_logsumexp() = g.logsumexp()
        g.marginalize(right=n).event_logsumexp() = g.logsumexp()

    and for data ``x``:

        g.condition(x).event_logsumexp()
          = g.marginalize(left=g.dim() - x.size(-1)).log_density(x)
    """
    if left == 0 and right == 0:
        return self
    if left > 0 and right > 0:
        raise NotImplementedError
    total = self.dim()
    num_dropped = left + right
    keep = slice(left, total - right)  # variables that survive
    drop = slice(None, left) if left else slice(total - right, None)

    prec_kk = self.precision[..., keep, keep]
    prec_dk = self.precision[..., drop, keep]
    prec_dd = self.precision[..., drop, drop]
    chol_dd = cholesky(prec_dd)
    half_dk = triangular_solve(prec_dk, chol_dd, upper=False)
    half_kd = half_dk.transpose(-1, -2)
    # Schur complement: marginal precision over the kept variables.
    precision = prec_kk - matmul(half_kd, half_dk)

    info_keep = self.info_vec[..., keep]
    info_drop = self.info_vec[..., drop]
    white_drop = triangular_solve(info_drop.unsqueeze(-1), chol_dd, upper=False)
    info_vec = info_keep - matmul(half_kd, white_drop).squeeze(-1)

    # Normalization correction from integrating out the dropped block.
    log_normalizer = (
        self.log_normalizer
        + 0.5 * num_dropped * math.log(2 * math.pi)
        - chol_dd.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        + 0.5 * white_drop.squeeze(-1).pow(2).sum(-1)
    )
    return Gaussian(log_normalizer, info_vec, precision)