Code Example #1
import math

# Assumed context: this function is part of GluonTS's low-rank multivariate
# Gaussian module, where `Tensor`, `getF`, `capacitance_tril`, and
# `mahalanobis_distance` are defined or imported alongside it
# (`log_det` is Code Example #2 below).


def lowrank_log_likelihood(rank: int, mu: Tensor, D: Tensor, W: Tensor,
                           x: Tensor) -> Tensor:
    # Dispatch to the imperative (mx.nd) or symbolic (mx.sym) MXNet API,
    # depending on the type of the input tensors.
    F = getF(mu)

    # Event dimension d, recovered from the trailing axis of `mu`.
    dim = F.ones_like(mu).sum(axis=-1).max()

    # Constant term d * log(2 * pi) of the Gaussian log-density.
    dim_factor = dim * math.log(2 * math.pi)

    if W is not None:
        # Low-rank case: Sigma = diag(D) + W W^T. The Cholesky factor of the
        # (rank x rank) capacitance matrix C = I + W^T D^{-1} W lets both the
        # log-determinant and the Mahalanobis term be computed in O(rank^3)
        # rather than O(dim^3).
        batch_capacitance_tril = capacitance_tril(F=F, rank=rank, W=W, D=D)

        # log|Sigma| via the matrix determinant lemma (see log_det below).
        log_det_factor = log_det(F=F,
                                 batch_D=D,
                                 batch_capacitance_tril=batch_capacitance_tril)

        # (x - mu)^T Sigma^{-1} (x - mu) via the Woodbury identity.
        mahalanobis_factor = mahalanobis_distance(
            F=F, W=W, D=D, capacitance_tril=batch_capacitance_tril, x=x - mu)
    else:
        # Purely diagonal case: Sigma = diag(D), so both terms reduce to
        # element-wise sums over the event dimension.
        log_det_factor = D.log().sum(axis=-1)
        x_centered = x - mu
        mahalanobis_factor = F.broadcast_div(x_centered.square(),
                                             D).sum(axis=-1)

    # log N(x | mu, Sigma) = -0.5 * (d log(2 pi) + log|Sigma| + Mahalanobis).
    ll: Tensor = -0.5 * (F.broadcast_add(dim_factor, log_det_factor) +
                         mahalanobis_factor)

    return ll
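
The function above never materializes the full dim x dim covariance Sigma = diag(D) + W W^T. As a sanity check, here is a minimal NumPy/SciPy sketch (not part of the original module; all names below are illustrative) that mirrors the low-rank computation and compares it against a dense reference:

import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)
dim, rank = 5, 2

# Illustrative test values for the low-rank parametrization.
mu = rng.normal(size=dim)
D = rng.uniform(0.5, 1.5, size=dim)          # positive diagonal
W = rng.normal(size=(dim, rank))
x = rng.normal(size=dim)

# Dense reference: materialize Sigma = diag(D) + W W^T explicitly.
sigma = np.diag(D) + W @ W.T
ll_dense = multivariate_normal(mean=mu, cov=sigma).logpdf(x)

# Low-rank evaluation, mirroring lowrank_log_likelihood in O(rank^3):
C = np.eye(rank) + W.T @ (W / D[:, None])    # capacitance matrix
xc = x - mu
# Woodbury: Sigma^{-1} = D^{-1} - D^{-1} W C^{-1} W^T D^{-1}
maha = xc @ (xc / D) - (xc / D) @ W @ np.linalg.solve(C, W.T @ (xc / D))
log_det = np.linalg.slogdet(C)[1] + np.log(D).sum()
ll_lowrank = -0.5 * (dim * np.log(2 * np.pi) + log_det + maha)

assert np.isclose(ll_dense, ll_lowrank)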
Code Example #2
def log_det(F, batch_D: Tensor, batch_capacitance_tril: Tensor) -> Tensor:
    r"""
    Uses the matrix determinant lemma

    .. math::
        \log|D + W W^T| = \log|C| + \log|D|,

    where :math:`C = I + W^T D^{-1} W` is the capacitance matrix, to compute
    the log-determinant of the low-rank covariance matrix.

    Parameters
    ----------
    F
        MXNet function namespace (``mx.nd`` or ``mx.sym``).
    batch_D
        Diagonal entries of :math:`D`, of shape ``(..., dim)``.
    batch_capacitance_tril
        Lower Cholesky factor :math:`L` of the capacitance matrix
        :math:`C = L L^T`, of shape ``(..., rank, rank)``.

    Returns
    -------
    Tensor
        The log-determinant :math:`\log|D + W W^T|` per batch element.
    """
    # log|D| = sum of the log diagonal entries, since D is diagonal.
    log_D = batch_D.log().sum(axis=-1)
    # log|C| = 2 * sum(log(diag(L))) for the Cholesky factor L of C.
    log_C = 2 * F.linalg.sumlogdiag(batch_capacitance_tril)
    return log_C + log_D
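
The identity that log_det relies on reduces a dim x dim determinant to a rank x rank one, and can be checked numerically in a few lines. A minimal NumPy sketch (illustrative, not part of the original module), where the Cholesky-based log_C plays the role of ``F.linalg.sumlogdiag``:

import numpy as np

rng = np.random.default_rng(1)
dim, rank = 6, 2

# Illustrative inputs: a positive diagonal D and a low-rank factor W.
D = rng.uniform(0.5, 2.0, size=dim)
W = rng.normal(size=(dim, rank))

# Left-hand side: log|D + W W^T| computed densely.
lhs = np.linalg.slogdet(np.diag(D) + W @ W.T)[1]

# Right-hand side: log|C| + log|D| with C = I + W^T D^{-1} W.
C = np.eye(rank) + W.T @ (W / D[:, None])
L = np.linalg.cholesky(C)                      # C = L L^T
log_C = 2 * np.log(np.diag(L)).sum()           # mirrors sumlogdiag
log_D = np.log(D).sum()
rhs = log_C + log_D

assert np.isclose(lhs, rhs)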