def pd_inv(a: Woodbury):
    """Invert a positive-definite Woodbury matrix via the Woodbury identity.

    Assuming `a` represents `diag + left @ middle @ right^T`, the inverse is
    `diag^{-1} - (diag^{-1} left) schur^{-1} (diag^{-1} right)^T`, where
    `schur` is the Schur complement returned by `B.pd_schur(a)`.

    Args:
        a (:class:`.Woodbury`): Positive-definite matrix to invert.

    Returns:
        matrix: Inverse of `a`, expressed as a diagonal minus a low-rank term.
    """
    # See comment in `inv`.
    d_inv = B.inv(a.diag)
    scaled_left = B.matmul(d_inv, a.lr.left)
    scaled_right = B.matmul(d_inv, a.lr.right)
    schur_inv = B.pd_inv(B.pd_schur(a))
    correction = LowRank(scaled_left, scaled_right, schur_inv)
    return B.subtract(d_inv, correction)
def from_normal(cls, dist):
    """Construct from a normal distribution.

    The first natural parameter is the variance-solved mean `var^{-1} mean`,
    obtained through a Cholesky solve. The second is the precision `var^{-1}`.

    Args:
        dist (distribution): Normal distribution to construct from.

    Returns:
        :class:`.NaturalNormal`: Normal distribution parametrised by the
            natural parameters of `dist`.
    """
    var_chol = B.chol(dist.var)
    weighted_mean = B.cholsolve(var_chol, dist.mean)
    precision = B.pd_inv(dist.var)
    return cls(weighted_mean, precision)
def pd_schur(a: Woodbury):
    """Compute the Schur complement associated to a positive-definite matrix.

    A Schur complement will need to make sense for the type of `a`. For a
    Woodbury matrix, this is `middle^{-1} + right^T diag^{-1} left`. The
    result is stored on `a.schur` and reused on later calls.

    Args:
        a (matrix): Matrix to compute Schur complement of.

    Returns:
        matrix: Schur complement.
    """
    if a.schur is None:
        d_inv = B.inv(a.diag)
        projected = B.mm(a.lr.right, d_inv, a.lr.left, tr_a=True)
        a.schur = B.add(B.pd_inv(a.lr.middle), projected)
    return a.schur
def _project_pattern(self, x, y, pattern):
    """Project observations that share a single missing-data pattern.

    Args:
        x (matrix): Inputs. The number of rows is taken as the number of data
            points.
        y (matrix): Observations, with outputs along the second axis.
        pattern (sequence): Per-output availability flags. These are presumably
            booleans: `all(pattern)` detects whether any output is missing, and
            `sum(pattern)` counts the available outputs.

    Returns:
        tuple: `x`, the projected observations, weights of shape
            `(n, self.m)` for the projected observations, and a regularising
            term.
    """
    # Check whether all data is available.
    no_missing = all(pattern)

    if no_missing:
        # All data is available. Nothing to be done.
        u = self.u
    else:
        # Data is missing. Pick the available entries.
        y = B.take(y, pattern, axis=1)
        # Ensure that `u` remains a structured matrix.
        u = Dense(B.take(self.u, pattern))

    # Get number of data points and outputs in this part of the data.
    n = B.shape(x)[0]
    p = sum(pattern)

    # Perform projection.
    proj_y_partial = B.matmul(y, B.pinv(u), tr_b=True)
    proj_y = B.matmul(proj_y_partial, B.inv(self.s_sqrt), tr_b=True)

    # Compute projected noise. `u_square = u^T u` is also needed for the
    # Frobenius norm and the log-determinant below, so compute it only once.
    u_square = B.matmul(u, u, tr_a=True)
    proj_noise = (
        self.noise_obs / B.diag(self.s_sqrt) ** 2 * B.diag(B.pd_inv(u_square))
    )

    # Convert projected noise to weights.
    noises = self.model.noises
    weights = noises / (noises + proj_noise)
    proj_w = B.ones(B.dtype(weights), n, self.m) * weights[None, :]

    # Compute Frobenius norm: `||y||^2` minus the squared norm of the part of
    # `y` captured by the projection. This assumes `u` has full column rank,
    # so that `pinv(u) = (u^T u)^{-1} u^T`.
    frob = B.sum(y ** 2)
    frob = frob - B.sum(proj_y_partial * B.matmul(proj_y_partial, u_square))

    # Compute regularising term.
    reg = 0.5 * (
        n * (p - self.m) * B.log(2 * B.pi * self.noise_obs)
        + frob / self.noise_obs
        + n * B.logdet(u_square)
        + n * 2 * B.logdet(self.s_sqrt)
    )

    return x, proj_y, proj_w, reg
def test_pd_inv_correctness(dense_pd):
    # The positive-definite inverse must agree with the general inverse.
    result = B.pd_inv(dense_pd)
    reference = B.inv(dense_pd)
    approx(result, reference)
def var(self):
    """matrix: Variance, computed by inverting `prec` on first access and cached."""
    cached = self._var
    if cached is None:
        cached = B.pd_inv(self.prec)
        self._var = cached
    return cached
def iqf(a: Woodbury, b, c):
    """Compute the inverse quadratic form `b^T a^{-1} c`.

    Args:
        a (:class:`.Woodbury`): Positive-definite matrix.
        b (matrix): Left factor, used transposed.
        c (matrix): Right factor.

    Returns:
        matrix: `b^T a^{-1} c`.
    """
    a_inv = B.pd_inv(a)
    return B.mm(b, a_inv, c, tr_a=True)
def pd_inv(a: Kronecker):
    """Invert a positive-definite Kronecker product factor by factor.

    Uses `kron(L, R)^{-1} = kron(L^{-1}, R^{-1})`, so only the two smaller
    factors are inverted.

    Args:
        a (:class:`.Kronecker`): Positive-definite Kronecker product.

    Returns:
        :class:`.Kronecker`: Inverse of `a`.
    """
    left_inv = B.pd_inv(a.left)
    right_inv = B.pd_inv(a.right)
    return Kronecker(left_inv, right_inv)