Example #1
0
File: qr.py  Project: ithanlevin/dfact
def __block_svd_iteration(A, V, s):
    """
    Perform one block-power iteration refining an approximate SVD of A.

    Alternately orthonormalizes A @ V and dag(A) @ U via reduced QR,
    keeping the leading s columns each time, and estimates convergence
    by comparing A @ V against U scaled by the singular-value estimates.

    Arguments
    ---------
    A : (m x n) matrix being factorized.
    V : (n x k) current right-singular-subspace estimate, k >= s.
    s : number of singular triples to keep.

    Returns
    -------
    List [U, Sigma, V, err] where U is (m x s), Sigma is the length-s
    vector of singular-value estimates (diagonal of the second R factor),
    V is (n x s), and err is the Frobenius norm of A @ V - U * Sigma.
    """
    # Left update: orthonormal basis for the range of A @ V.
    Ql, _ = jnp.linalg.qr(A @ V, mode="reduced")
    U = Ql[:, :s]

    # Right update: orthonormal basis for the range of dag(A) @ U; the
    # diagonal of R carries the current singular-value estimates.
    Qr, Rr = jnp.linalg.qr(dag(A) @ U, mode="reduced")
    Sigma = jnp.diag(Rr[:s, :s])
    V = Qr[:, :s]

    # U * Sigma broadcasts Sigma across columns, i.e. U @ diag(Sigma).
    # The Frobenius norm already sums squared moduli, so no jnp.abs is
    # needed before taking the norm of the residual.
    err_l = A @ V
    err_r = U * Sigma
    err = jnp.linalg.norm(err_l - err_r)
    return [U, Sigma, V, err]
Example #2
0
def randSVD(A, k=None, p=5, n_iter=2):
    """
    Approximate the SVD of A by randomized range-finding ('randSVD'),
    following the version presented in the randUTV talk.

    Arguments
    ---------
    A (numpy array): The m x n matrix to be factorized.
    k (int)        : Rank of the output. k=n if unspecified or if k>n.
    p (int)        : Oversampling parameter. In the throughput, we take
                     k -> k + p. Larger values entail better, but
                     slower, results.
    n_iter (int)   : Number of power iterations

    Exceptions
    ----------
    ValueError when A is not a two-dimensional matrix.
    AssertionError unless k > 0, p>=0, n_iter>=0.

    Returns
    -------
    List [U (m x k), S (k), Vs (k x n)] where
        A ~ U * diag(S) * Vs after padding k back to n.
    """
    try:
        m, n = A.shape
    except ValueError:
        raise ValueError("A had invalid shape: ", A.shape)

    # Clamp the requested rank to the number of columns.
    if k is None or k > n:
        k = n

    assert k > 0
    assert p >= 0
    assert n_iter >= 0

    # Oversampled orthonormal range approximation, then the small SVD of
    # the projected matrix B = dag(Q) @ A.
    Q = rand_range_col(A, b=k + p, n_iter=n_iter)
    projected = jnp.dot(dag(Q), A)
    U_small, svals, Vh = jnp.linalg.svd(projected, full_matrices=False)

    # Lift the left factor back to m rows and truncate everything to rank k.
    return [jnp.dot(Q, U_small)[:, :k], svals[:k], Vh[:k, :]]
Example #3
0
def randUTV_slow(A, b, q):
    """
    Compute a full UTV decomposition of A blockwise using the 'slow'
    stepUTV routine, so that (approximately) A = U @ T @ dag(V).

    Arguments
    ---------
    A : (m x n) matrix to be factorized.
    b (int): block size; ceil(n / b) diagonal blocks are processed.
    q (int): number of power iterations forwarded to stepUTV_slow.

    Returns
    -------
    List [U (m x m), T (m x n), V (n x n)].
    """
    T = A
    m, n = A.shape
    U = jnp.eye(m, dtype=A.dtype)
    V = jnp.eye(n, dtype=A.dtype)
    for i in range(math.ceil(n / b)):
        bidx = b * i
        # Trailing principal submatrix not yet brought to UTV form.
        Tb = T[bidx:, bidx:]
        if n - bidx > b:
            # Full-size block: one randomized UTV step.
            UU, TT, VV = stepUTV_slow(Tb, b=b, n_iter=q)

        else:
            # Last (possibly ragged) block: finish with an exact SVD,
            # embedding the singular values into a block of Tb's shape.
            UU, TTs, VVh = jnp.linalg.svd(Tb, full_matrices=True)
            VV = dag(VVh)
            TT = jnp.zeros(Tb.shape, A.dtype)
            TTd = jnp.diag(TTs)
            TT = index_update(TT, index[:TTd.shape[0], :TTd.shape[1]], TTd)
        # Absorb the block transformations into the running factors; the
        # final update keeps the already-processed rows of T consistent
        # with the right-multiplication by VV.
        U = index_update(U, index[:, bidx:], jnp.dot(U[:, bidx:], UU))
        V = index_update(V, index[:, bidx:], jnp.dot(V[:, bidx:], VV))
        T = index_update(T, index[bidx:, bidx:], TT)
        T = index_update(T, index[:bidx, bidx:], jnp.dot(T[:bidx, bidx:], VV))
    return [U, T, V]
Example #4
0
File: qr.py  Project: ithanlevin/dfact
def house_rightmult(A, v, beta):
    """
    Right-multiply A by a Householder reflector.

    Given the m x n matrix A, the length-n Householder vector v, and its
    normalization beta, the reflector is P = I - beta v otimes dag(v).
    This routine forms A @ P without ever building P explicitly, via a
    single rank-one update.

    Parameters
    ----------
    A:  array_like, shape(M, N)
        Matrix to be multiplied by H.

    v:  array_like, shape(N).
        Householder vector.

    beta: float
        Householder normalization.

    Returns
    -------
    C = AP
    """
    # A @ P = A - beta * (A @ v) otimes dag(v).
    Av = A @ v
    return A - jnp.outer(Av, beta * dag(v))
Example #5
0
def __randUTV_work(A, b, q, p):
    """
    Performs the "optimized" randUTV in Figure4 of the paper.

    Arguments
    ---------
    A: (m x n) matrix to be factorized.
    b (int): block size
    q (int): Number of power iterations, a hyperparameter.
    p (int): Amount of oversampling, a hyperparameter.

    Returns
    -------
    List [U (m x m), T (m x n), V (n x n)] accumulated so that the
    transformations applied to T are reversed by U @ T @ dag(V).
    """

    m, n = A.shape

    # Initialize output variables:
    U = jnp.eye(m, dtype=A.dtype)
    T = A
    V = jnp.eye(n, dtype=A.dtype)

    # Precomputed index slices: B1 is above the active block, B2 is the
    # active block, B3 is everything after it, B2B3 their union.
    B1s, B2s, B3s, B2B3s = initialize_slices(T, b)
    mindim = jnp.min(T.shape)
    bj0 = 0  # Passes final value to next for loop.
    for bj in range(0, mindim - b, b):
        bj0 = bj
        # During this for loop, we create and apply transformation matrices
        # bringing the j'th b x b diagonal block of T to diagonal form.
        # The loop terminates when the next diagonal block would either be
        # empty or smaller than b x b, in which case we execute the code
        # within the next for loop. We use a pair of for loops to avoid
        # the awkward interplay between conditionals and jit.
        j = bj // b
        B1, B2, B3, B2B3 = [B1s[j], B2s[j], B3s[j], B2B3s[j]]

        thisblock = T[B2B3, B2B3]

        # Use randomized sampling methods to generate a unitary matrix Vj
        # whose columns form an approximate orthonormal basis for those of
        # T([I2, J3], [J2, J3]); that is, the portion of A which is not
        # yet diagonal-ish. Vj is in its WY QR representation,
        # that is, as two matrices Vj_W and Vj_YH.
        Vj_W, Vj_YH, _ = rand_range_row_jit(thisblock, b, q, p)

        # Compute T = T @ Vj and V = V @ Vj using the function
        # qr.B_times_Q_WY, which does B @ Q with Q in the WY representation.
        # Since V is initially the identity, this builds up
        # V=V0@V0s@V1@V0s@V2... ,
        # so that V inherits the unitarity of its constituents. T@dag(V)
        # then reverses the procedure. V0s, which is also unitary,  is computed
        # in the final step of the for loop.
        T = index_update(T, index[:, B2B3],
                         qr.B_times_Q_WY(T[:, B2B3], Vj_W, Vj_YH))
        V = index_update(V, index[:, B2B3],
                         qr.B_times_Q_WY(V[:, B2B3], Vj_W, Vj_YH))

        # Build an orthonormal/unitary matrix Uj in similar fashion, and
        # compute U = U@Uj, T = dag(Uj)@T. Thus, U @ T again reverses the
        # procedure, while U remains unitary. Uj is also in its WY
        # representation. This time, we hang onto the matrix R in the QR
        # decomposition for later use.
        Uj_W, Uj_YH, Uj_R = qr.house_qr(T[B2B3, B2], mode="WY")
        U = index_update(U, index[:, B2B3],
                         qr.B_times_Q_WY(U[:, B2B3], Uj_W, Uj_YH))
        T = index_update(T, index[B2B3, B3],
                         qr.Qdag_WY_times_B(T[B2B3, B3], Uj_W, Uj_YH))
        # Zero out entries of T beneath the current block diagonal.
        T = index_update(T, index[B3, B2], 0.)

        # Uj_R[:b, :b] is now the portion of the active diagonal block which
        # we have not yet absorbed into U, T, or V. Diagonalize it with
        # an SVD to yield 'small' matrices Us@Ds@Vsh = svd(Uj_R[:b, :b].
        # T[I2, J2] = Ds thus diagonalizes the active block. Absorb
        # the unitary matrices Us and Vsh into U, T, and V so that the
        # transformation is reversed during A = U @ T @ dag(V).
        Us, Ds, Vsh = jnp.linalg.svd(Uj_R[:b, :b])
        Vs = dag(Vsh)
        T = index_update(T, index[B2, B2], jnp.diag(Ds))
        T = index_update(T, index[B2, B3], dag(Us) @ T[B2, B3])
        U = index_update(U, index[:, B2], U[:, B2] @ Us)
        T = index_update(T, index[B1, B2], T[B1, B2] @ Vs)
        V = index_update(V, index[:, B2], V[:, B2] @ Vs)

    for bj in range(bj0 + b, mindim, b):
        # This 'loop' operates on the last diagonal block in the case that
        # b did not divide either m or n evenly. It performs the SVD
        # step at the end of the 'main' block, accomodating the relevant
        # matrix dimensions. This loop should only ever increment either
        # never or once and
        # would more naturally be an if statement, but Jit doesn't like that.
        B1 = B1s[-1]
        B2B3 = B2B3s[-1]
        thisblock = T[B2B3, B2B3]

        Us, Dvals, Vsh = jnp.linalg.svd(thisblock, full_matrices=True)
        Vs = dag(Vsh)

        U = index_update(U, index[:, B2B3], U[:, B2B3] @ Us)
        V = index_update(V, index[:, B2B3], V[:, B2B3] @ Vs)

        # Scatter the singular values onto the main diagonal of the final
        # block, padding with zeros where the block is rectangular.
        idxs = matutils.subblock_main_diagonal(T, bi=bj)
        allDs = jnp.zeros(idxs[0].size)
        allDs = index_update(allDs, index[:Dvals.size], Dvals)
        T = index_update(T, index[B2B3, B2B3], 0.)
        T = index_update(T, idxs, allDs)
        T = index_update(T, index[B1, B2B3], T[B1, B2B3] @ Vs)
    return [U, T, V]
Example #6
0
def __randUTV_workforjit(A, b, q, p):
    """
    Performs the "optimized" randUTV in Figure4 of the paper, with the
    per-block work factored into __randUTV_block_step1/2/3 helpers (with
    subarrays extracted up front) so the body is jit-friendly.

    Arguments
    ---------
    A: (m x n) matrix to be factorized.
    b (int): block size
    q (int): Number of power iterations, a hyperparameter.
    p (int): Amount of oversampling, a hyperparameter.

    Returns
    -------
    List [U (m x m), T (m x n), V (n x n)].
    """

    m, n = A.shape

    # Initialize output variables:
    U = jnp.eye(m, dtype=A.dtype)
    T = A
    V = jnp.eye(n, dtype=A.dtype)

    # Precomputed index slices: B1 is above the active block, B2 is the
    # active block, B3 is everything after it, B2B3 their union.
    B1s, B2s, B3s, B2B3s = initialize_slices(T, b)
    mindim = jnp.min(T.shape)
    bj0 = 0  # Passes final value to next for loop.
    for bj in range(0, mindim - b, b):
        bj0 = bj
        # During this for loop, we create and apply transformation matrices
        # bringing the j'th b x b diagonal block of T to diagonal form.
        # The loop terminates when the next diagonal block would either be
        # empty or smaller than b x b, in which case we execute the code
        # within the next for loop. We use a pair of for loops to avoid
        # the awkward interplay between conditionals and jit.
        j = bj // b
        B1, B2, B3, B2B3 = [B1s[j], B2s[j], B3s[j], B2B3s[j]]

        thisblock = T[B2B3, B2B3]

        # Step 1: randomized right transformation, applied to T and V.
        T_B2B3 = T[:, B2B3]
        V_B2B3 = V[:, B2B3]
        T, V = __randUTV_block_step1(bj, b, q, p, T, V, thisblock, T_B2B3,
                                     V_B2B3)

        # Step 2: left (QR-based) transformation; also zeroes T beneath
        # the block diagonal using the preallocated Tzeros.
        T_B2B3_B2 = T[B2B3, B2]
        U_B2B3 = U[:, B2B3]
        T_B2B3_B3 = T[B2B3, B3]
        T_B3_B2 = T[B3, B2]
        Tzeros = jnp.zeros(T[B3, B2].shape, dtype=A.dtype)
        Us, Vs, U, T, V = __randUTV_block_step2(bj, b, U, T, V, T_B2B3_B2,
                                                U_B2B3, T_B2B3_B3, T_B3_B2,
                                                Tzeros)
        # Step 3: absorb the small SVD factors Us, Vs into U, T, and V.
        T_B2_B2 = T[B2, B2]
        T_B2_B3 = T[B2, B3]
        U_B2 = U[:, B2]
        T_B1_B2 = T[B1, B2]
        V_B2 = V[:, B2]
        U, T, V = __randUTV_block_step3(bj, b, Us, Vs, U, T, V, T_B2_B2,
                                        T_B2_B3, U_B2, T_B1_B2, V_B2)

    for bj in range(bj0 + b, mindim, b):
        # This 'loop' operates on the last diagonal block in the case that
        # b did not divide either m or n evenly. It performs the SVD
        # step at the end of the 'main' block, accomodating the relevant
        # matrix dimensions. This loop should only ever increment either
        # never or once and
        # would more naturally be an if statement, but Jit doesn't like that.
        B1 = B1s[-1]
        B2B3 = B2B3s[-1]
        thisblock = T[B2B3, B2B3]

        Us, Dvals, Vsh = jnp.linalg.svd(thisblock, full_matrices=True)
        Vs = dag(Vsh)

        U = index_update(U, index[:, B2B3], U[:, B2B3] @ Us)
        V = index_update(V, index[:, B2B3], V[:, B2B3] @ Vs)

        # Scatter the singular values onto the main diagonal of the final
        # block, padding with zeros where the block is rectangular.
        idxs = matutils.subblock_main_diagonal(T, bi=bj)
        allDs = jnp.zeros(idxs[0].size)
        allDs = index_update(allDs, index[:Dvals.size], Dvals)
        T = index_update(T, index[B2B3, B2B3], 0.)
        T = index_update(T, idxs, allDs)
        T = index_update(T, index[B1, B2B3], T[B1, B2B3] @ Vs)
    return [U, T, V]
Example #7
0
def stepUTV_slow(A, b=None, p=5, n_iter=1, verbose=False):
    """
    Perfoms one step of the randUTV algorithm using the 'slow' method
    of Figure 3.

    This algorithm applies the UTV decomposition to one block of size
    b. If b is None, the entire matrix is decomposed.

    Arguments
    ---------
    A (numpy array): The m x n matrix to be factorized.
    b (int)        : Block size of the output.
    p (int)        : Oversampling parameter.
    n_iter (int)   : Number of power iterations for the Gaussian sampling.
    verbose (bool) : If True, print the intermediate quantities.


    Exceptions
    ----------
    ValueError when A is not a two-dimensional matrix.
    AssertionError unless b > 0, p>=0, n_iter>=0.
    NotImplementedError when m < n.

    Returns
    -------
    List [U (m x m), T (m x n), dag(V) (n x n)] where
        A = U @ T @ dag(V) .
    """

    try:
        m, n = A.shape
    except ValueError:
        raise ValueError("A had invalid shape: ", A.shape)

    # Clamp the block size to the number of columns.
    if b is None or b > n:
        b = n

    if m < n:
        raise NotImplementedError(
            "m < n case of stepUTV_slow not implemented.")
    assert b > 0
    assert p >= 0
    assert n_iter >= 0

    # Unitary V whose first b + p columns approximately span the dominant
    # right singular subspace of A.
    V, _ = rand_range_row(A, b=b + p, n_iter=n_iter,
                          mode="complete")  # (n x n)

    # First b columns approximately span the singular value space of
    # A.
    AV = jnp.dot(A, V)
    AV1 = jnp.dot(A, V[:, :b])
    AV2 = jnp.dot(A, V[:, b:])
    U, T11, Vsmall_dH = jnp.linalg.svd(AV1)  # (m x m), min(m, b), (b x b)
    Vsmall_d = dag(Vsmall_dH)

    # Assemble T: diagonalized leading block, then the remaining columns
    # rotated into U's basis.
    Tright = jnp.dot(dag(U), AV2)
    T = jnp.zeros((m, n), dtype=A.dtype)
    T = jax.ops.index_update(T, jax.ops.index[:b, :b], jnp.diag(T11))
    T = jax.ops.index_update(T, jax.ops.index[:, b:], Tright)
    # Absorb the small right singular factor into V's leading columns so
    # that A = U @ T @ dag(V) still holds.
    V = jax.ops.index_update(V, jax.ops.index[:, :b],
                             jnp.dot(V[:, :b], Vsmall_d))

    if verbose:
        print("*************")
        print("AV:", AV)
        print("AV1:", AV1, "AV2:", AV2)
        print("U:", U, "T11:", T11, "Vs:", Vsmall_d)
        print("Tright:", Tright)
        print("T:", T)
        print("V:", V)
        print("*************")
    output = [U, T, V]
    return output