def lowrank_log_likelihood(
    rank: int, mu: Tensor, D: Tensor, W: Tensor, x: Tensor
) -> Tensor:
    r"""
    Gaussian log-density of ``x`` with mean ``mu``.

    The covariance is :math:`D + W W^T` when ``W`` is given; ``D`` holds the
    diagonal entries. When ``W`` is None the covariance is just the diagonal
    :math:`D`. The result is

    .. math::
        -\frac{1}{2}\left(d \log 2\pi + \log|\Sigma|
        + (x - \mu)^T \Sigma^{-1} (x - \mu)\right),

    where :math:`d` is the size of the last axis of ``mu``.

    Parameters
    ----------
    rank
        Forwarded to ``capacitance_tril``; presumably the number of columns
        of ``W`` (TODO confirm). Unused when ``W`` is None.
    mu
        Mean; the last axis is treated as the event dimension.
    D
        Diagonal of the diagonal covariance part (element-wise ``log`` and
        division are applied to it).
    W
        Low-rank factor, or None for a purely diagonal covariance.
    x
        Observation to evaluate.

    Returns
    -------
    Tensor
        Log-likelihood, reduced over the last axis.
    """
    F = getF(mu)

    # Size of the event dimension as a (scalar) tensor, so this works with
    # both imperative and symbolic backends where shapes may be unknown.
    dim = F.ones_like(mu).sum(axis=-1).max()
    const_term = dim * math.log(2 * math.pi)

    centered = x - mu

    if W is None:
        # Purely diagonal covariance: closed forms need no factorisation.
        log_det_term = D.log().sum(axis=-1)
        mahalanobis_term = F.broadcast_div(centered.square(), D).sum(axis=-1)
    else:
        # A single factorisation of the capacitance matrix is shared by the
        # determinant and the quadratic form.
        cap_tril = capacitance_tril(F=F, rank=rank, W=W, D=D)
        log_det_term = log_det(
            F=F, batch_D=D, batch_capacitance_tril=cap_tril
        )
        # NOTE(review): assumes mahalanobis_distance returns the *squared*
        # form (x - mu)^T Sigma^{-1} (x - mu), as the formula above requires.
        mahalanobis_term = mahalanobis_distance(
            F=F, W=W, D=D, capacitance_tril=cap_tril, x=centered
        )

    ll: Tensor = -0.5 * (
        F.broadcast_add(const_term, log_det_term) + mahalanobis_term
    )
    return ll
def log_det(F, batch_D: Tensor, batch_capacitance_tril: Tensor) -> Tensor:
    r"""
    Log-determinant of :math:`D + W W^T` via the matrix determinant lemma.

    .. math::
        \log|D + W W^T| = \log|C| + \log|D|,

    where :math:`C = I + W^T D^{-1} W` is the capacitance matrix.

    Parameters
    ----------
    F
        Backend module providing ``linalg.sumlogdiag``.
    batch_D
        Diagonal entries of :math:`D`. Because :math:`D` is diagonal,
        :math:`\log|D|` is the sum of their logs over the last axis.
    batch_capacitance_tril
        Lower-triangular factor :math:`L` of the capacitance matrix with
        :math:`C = L L^T`. This relation is what makes
        :math:`\log|C| = 2 \sum_i \log L_{ii}` valid.

    Returns
    -------
    Tensor
        :math:`\log|C| + \log|D|`, reduced over the matrix dimensions.
    """
    diag_term = batch_D.log().sum(axis=-1)
    # det(L L^T) = det(L)^2, and det of a triangular matrix is the product
    # of its diagonal, hence the factor of 2 on the summed log-diagonal.
    capacitance_term = 2 * F.linalg.sumlogdiag(batch_capacitance_tril)
    return capacitance_term + diag_term