Ejemplo n.º 1
0
def _lu_blocked(a, block_size=32):
    """Blocked LU decomposition with partial pivoting, as an unrolled loop.

    The matrix is processed ``block_size`` columns (a "panel") at a time. For
    each panel:

    1. The panel (from its diagonal row down) is factored by
       ``_lu_unblocked``.
    2. The panel's row permutation is applied to the columns left and right
       of it.
    3. The block row to the right of the panel is replaced by a unit-lower
       triangular solve against the panel's diagonal block.
    4. The trailing submatrix is updated by subtracting a matrix product.

    The Python ``for`` loop is unrolled when traced.

    Args:
      a: 2-D array of shape ``(m, n)``.
      block_size: number of columns factored per panel.

    Returns:
      ``(pivot, lu)``:

      - ``pivot`` is an int32 array of shape ``(min(m, n),)`` holding pivot
        row indices offset into full-matrix coordinates.
      - ``lu`` is the factored matrix.

      If any panel's ``_lu_unblocked`` call reported an error, every entry
      of ``lu`` is replaced by NaN.
    """
    num_rows, num_cols = a.shape
    rank_bound = min(num_rows, num_cols)
    pivot = np.zeros((rank_bound, ), dtype=np.int32)
    failed = np.array(False, np.bool_)

    for start in range(0, rank_bound, block_size):
        stop = start + min(rank_bound - start, block_size)

        panel_pivot, panel_perm, panel_lu, panel_failed = _lu_unblocked(
            a[start:, start:stop])
        failed = failed | panel_failed

        # panel_perm indexes rows relative to `start`; shift to global rows.
        src_rows = panel_perm + start
        a = ops.index_update(a, ops.index[start:, start:stop], panel_lu)
        a = ops.index_update(a, ops.index[start:, :start],
                             a[src_rows, :start])
        pivot = ops.index_update(pivot, ops.index[start:stop],
                                 panel_pivot + start)

        # Nothing remains to the right of this panel.
        if stop >= num_cols:
            continue

        a = ops.index_update(a, ops.index[start:, stop:], a[src_rows, stop:])

        # Block row of U: solve L11 @ U12 = A12 with L11 unit lower triangular.
        diag_block = a[start:stop, start:stop]
        u_row = triangular_solve(diag_block,
                                 a[start:stop, stop:],
                                 left_side=True,
                                 lower=True,
                                 unit_diagonal=True)
        a = ops.index_update(a, ops.index[start:stop, stop:], u_row)

        # Trailing update: A22 -= L21 @ U12.
        correction = lax.dot(a[stop:, start:stop],
                             a[start:stop, stop:],
                             precision=lax.Precision.HIGHEST)
        a = ops.index_add(a, ops.index[stop:, stop:], -correction)

    a = np.where(failed, lax.full_like(a, np.nan), a)
    return pivot, a
Ejemplo n.º 2
0
def _lu_blocked(a, block_size=128):
    """Blocked LU decomposition with partial pivoting, as an unrolled loop.

    Columns are handled ``block_size`` at a time:

    1. Each panel (from its diagonal row down) is factored by
       ``_lu_unblocked``.
    2. The panel's row permutation is applied across all columns of the
       affected rows.
    3. The panel's factors are written back.
    4. If columns remain to the right, the block row of U is obtained with a
       unit-lower triangular solve, and the trailing submatrix is reduced by
       a matrix product.

    The Python ``for`` loop is unrolled when traced.

    Args:
      a: 2-D array of shape ``(m, n)``.
      block_size: number of columns factored per panel.

    Returns:
      ``(pivot, lu)``:

      - ``pivot`` is an int32 array of shape ``(min(m, n),)`` holding pivot
        row indices offset into full-matrix coordinates.
      - ``lu`` is the factored matrix.
    """
    num_rows, num_cols = a.shape
    rank_bound = min(num_rows, num_cols)
    pivot = np.zeros((rank_bound, ), dtype=np.int32)

    for start in range(0, rank_bound, block_size):
        stop = start + min(rank_bound - start, block_size)

        # NOTE(review): this expects _lu_unblocked to return exactly three
        # values (pivot, perm, lu); a variant that also returns an error flag
        # would fail to unpack here -- confirm which version is paired.
        panel_pivot, panel_perm, panel_lu = _lu_unblocked(
            a[start:, start:stop])

        # Permute every column of the affected rows, then overwrite the panel
        # columns with their factored values.
        a = ops.index_update(a, ops.index[start:, :],
                             a[panel_perm + start, :])
        a = ops.index_update(a, ops.index[start:, start:stop], panel_lu)
        pivot = ops.index_update(pivot, ops.index[start:stop],
                                 panel_pivot + start)

        # Nothing remains to the right of this panel.
        if stop >= num_cols:
            continue

        # Block row of U: solve L11 @ U12 = A12 with L11 unit lower triangular.
        u_row = triangular_solve(a[start:stop, start:stop],
                                 a[start:stop, stop:],
                                 left_side=True,
                                 lower=True,
                                 unit_diagonal=True)
        a = ops.index_update(a, ops.index[start:stop, stop:], u_row)

        # Trailing update: A22 -= L21 @ U12.
        correction = lax.dot(a[stop:, start:stop],
                             a[start:stop, stop:],
                             precision=lax.Precision.HIGHEST)
        a = ops.index_add(a, ops.index[stop:, stop:], -correction)

    return pivot, a
Ejemplo n.º 3
0
def _lu_unblocked(a):
    """Unblocked LU decomposition with partial pivoting, as a rolled loop.

    Runs Gaussian elimination one column at a time inside ``lax.fori_loop``.
    At step ``k``:

    1. The row (among rows ``k`` and below) with the largest-magnitude entry
       in column ``k`` is swapped into row ``k``.
    2. The entries below the diagonal in column ``k`` are divided by the
       pivot.
    3. The trailing submatrix receives a rank-1 update.

    Every update is masked over whole rows/columns, so the loop-carried
    shapes never change between iterations.

    Args:
      a: 2-D array of shape ``(m, n)``.

    Returns:
      A tuple ``(pivot, perm, lu, error)``:

      - pivot: int32 array of shape ``(min(m, n),)``; ``pivot[k]`` is the row
        swapped with row ``k`` at step ``k``.
      - perm: int32 array of shape ``(m,)``; the row permutation obtained by
        applying those swaps in order to ``arange(m)``.
      - lu: the factored matrix. Entries below the diagonal hold the
        multipliers (unit-diagonal L); the rest hold U.
      - error: boolean scalar; True if any pivot was exactly zero.
    """
    m, n = a.shape

    def body(k, state):
        pivot, perm, a, error = state
        m_idx = np.arange(m)
        n_idx = np.arange(n)

        # For complex inputs rank candidates by |re| + |im|, a cheaper proxy
        # for the modulus that avoids a square root.
        if np.issubdtype(a.dtype, np.complexfloating):
            t = a[:, k]
            magnitude = np.abs(np.real(t)) + np.abs(np.imag(t))
        else:
            magnitude = np.abs(a[:, k])
        # Only rows k and below are eligible as pivots.
        i = np.argmax(np.where(m_idx >= k, magnitude, -np.inf))
        pivot = ops.index_update(pivot, ops.index[k], i)

        # Swap rows k and i of the matrix and of the running permutation.
        a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])
        perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

        # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
        x = a[k, k]
        # A zero pivot makes the division below yield inf/nan; record it.
        error = error | lax.eq(x, np._constant_like(a, 0))
        a = ops.index_update(a, ops.index[:, k],
                             np.where(m_idx > k, a[:, k] / x, a[:, k]))

        # a[k+1:, k+1:] -= np.outer(a[k+1:, k], a[k, k+1:])
        a = a - np.where(
            (m_idx[:, None] > k) & (n_idx > k), np.outer(a[:, k], a[k, :]),
            np.array(0, dtype=a.dtype))
        return pivot, perm, a, error

    pivot = np.zeros((min(m, n), ), dtype=np.int32)
    perm = np.arange(m, dtype=np.int32)
    error = np.array(False, np.bool_)
    if m == 0 or n == 0:
        # If the array is empty (either dimension is zero), min(m, n) == 0 and
        # the loop body never executes. Tracing it to a jaxpr would still fail,
        # because the argmax/indexing over a zero-size dimension cannot
        # succeed. The initial state is exactly what a zero-trip loop returns.
        return (pivot, perm, a, error)
    return lax.fori_loop(0, min(m, n), body, (pivot, perm, a, error))