示例#1
0
def trace(A, F=None):
    """
    Compute the trace of the matrix A.
    This will broadcast over up to two leading batch dimensions if necessary,
    i.e. A may have shape (n, n), (b, n, n) or (b1, b2, n, n).

    :param A: input matrix, or batch of matrices, with 2 to 4 dimensions
    :param F: the MXNet computation mode (mxnet.symbol or mxnet.ndarray).
    :return: trace of the matrix; for batched input an array of shape A.shape[:-2]
        holding the trace of each matrix in the batch
    :raises ValueError: if A has fewer than two or more than four dimensions
    """
    F = get_default_MXNet_mode() if F is None else F

    if A.ndim < 2:
        raise ValueError("Cannot compute the trace of an array with fewer than two dimensions")
    if A.ndim > 4:
        raise ValueError("Cannot broadcast more than two dimensions")

    if A.ndim == 2:
        # Single matrix: sum of the main diagonal.
        return F.sum(F.diag(A))

    # TODO: make use of the nd-array support for diag when it's available
    # see: https://github.com/apache/incubator-mxnet/pull/12430

    result = F.zeros(A.shape[:-2])

    if A.ndim == 3:
        # One batch axis: iterate over axis 0 (the batch), not over matrix rows.
        for i in range(A.shape[0]):
            result[i] = F.sum(F.diag(A[i, :, :]))
    else:
        # Two batch axes (A.ndim == 4).
        for i in range(A.shape[0]):
            for j in range(A.shape[1]):
                result[i, j] = F.sum(F.diag(A[i, j, :, :]))

    return result
示例#2
0
def log_multivariate_gamma(x, p, F=None):
    """
    Compute the log of the multivariate gamma function of dimension p, elementwise over x.
    See https://en.wikipedia.org/wiki/Multivariate_gamma_function

        log Gamma_p(x) = p (p - 1) / 4 * log(pi) + sum_{k=1}^{p} log Gamma(x + (1 - k) / 2)

    The sum is evaluated on the whole array at once (one gammaln call per k),
    so x may have any shape. Values of x should exceed (p - 1) / 2 for the
    function to be defined.

    :param x: input variable
    :param p: dimension (a positive integer)
    :param F: the MXNet computation mode (mxnet.symbol or mxnet.ndarray).
    :return: log of multivariate gamma function, evaluated elementwise on x
    :raises ValueError: if p < 1
    """
    F = get_default_MXNet_mode() if F is None else F

    if p < 1:
        raise ValueError("The dimension p must be a positive integer")

    # Sum over univariate log gamma functions. The k = 1 term is gammaln(x);
    # for p == 1, Gamma_1(x) reduces to the ordinary gamma function.
    result = F.gammaln(x)
    for k in range(2, p + 1):
        result = result + F.gammaln(x + ((1. - k) / 2))

    p_mx = F.array([p], dtype=x.dtype)

    # leading constant p (p - 1) / 4 * log(pi)
    c = p_mx * (p_mx - 1) / 4 * np.log(np.pi)

    return result + c
示例#3
0
def log_determinant(A, F=None):
    """
    Compute the log determinant of the symmetric positive-definite matrix A
    from its Cholesky factor L (A = LL'). This uses the fact that the
    determinant of a triangular matrix is the product of its diagonal:

    log|A| = log(|L||L'|) = log|L|^2 = 2 log|L| = 2 * sum_i log(L_ii)

    :param A: symmetric positive-definite matrix. The Cholesky factorisation
        requires positive definiteness; a merely positive semi-definite
        (singular) matrix is not supported.
    :param F: the MXNet computation mode (mxnet.symbol or mxnet.ndarray).
    :return: Log determinant
    """
    F = get_default_MXNet_mode() if F is None else F

    # Cholesky factor of A, then 2 * sum of the log of its diagonal.
    L = F.linalg.potrf(A)
    return 2 * F.linalg.sumlogdiag(L)
示例#4
0
def solve(A, B, F=None):
    """
    Solve the linear system AX = B, i.e. compute X = A^{-1} B, for a symmetric
    positive-definite matrix A without forming A^{-1} explicitly.

    First compute the Cholesky decomposition A = LL', so that AX = B becomes
    LL'X = B. Then solve two triangular linear systems involving L:
        L y = B      (solve for the intermediate y)
        L'X = y      (solve for X, using the transposed factor)
    which gives X = L'^{-1} L^{-1} B = (LL')^{-1} B = A^{-1} B.

    :param A: symmetric positive-definite matrix (required by the Cholesky factorisation)
    :param B: right-hand side matrix. It is passed straight to F.linalg.trsm;
        a single vector presumably has to be given as an (n, 1) matrix — TODO confirm.
    :param F: the MXNet computation mode (mxnet.symbol or mxnet.ndarray).
    :return: the solution X = A^{-1} B
    """
    F = get_default_MXNet_mode() if F is None else F

    L = F.linalg.potrf(A)
    # y = L^{-1} B
    y = F.linalg.trsm(L, B)
    # X = (L')^{-1} y
    X = F.linalg.trsm(L, y, transpose=1)
    return X