Esempio n. 1
0
    def body(k, state):
        """Perform elimination step ``k`` of an LU factorization with row pivoting.

        ``state`` is the loop carry ``(pivot, perm, a)``; the updated carry is
        returned in the same order. ``m`` and ``n`` are read from the enclosing
        scope (not visible here; presumably the row and column counts of ``a``).
        """
        pivot, perm, a = state
        row_ids = np.arange(m)
        col_ids = np.arange(n)

        # Score every entry of column k; complex entries use |re| + |im|.
        column = a[:, k]
        if np.issubdtype(a.dtype, np.complexfloating):
            score = np.abs(np.real(column)) + np.abs(np.imag(column))
        else:
            score = np.abs(column)

        # Rows above k are replaced by -inf so argmax can only pick rows >= k.
        eligible = np.where(row_ids >= k, score, -np.inf)
        i = np.argmax(eligible)
        pivot = ops.index_update(pivot, ops.index[k], i)

        # Exchange positions k and i, both in the permutation and in the matrix.
        perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])
        a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])

        # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
        diag = a[k, k]
        column = a[:, k]
        scaled = np.where(row_ids > k, column / diag, column)
        a = ops.index_update(a, ops.index[:, k], scaled)

        # a[k+1:, k+1:] -= np.outer(a[k+1:, k], a[k, k+1:])
        trailing = (row_ids[:, None] > k) & (col_ids > k)
        rank_one = np.outer(a[:, k], a[k, :])
        unchanged = np.array(0, dtype=a.dtype)
        a = a - np.where(trailing, rank_one, unchanged)
        return pivot, perm, a
Esempio n. 2
0
def _lu_blocked(a, block_size=128):
    """Blocked LU decomposition, as an unrolled loop.

    Walks the leading ``min(m, n)`` columns of ``a`` in panels of at most
    ``block_size`` columns. Each panel is factored by ``_lu_unblocked``; the
    resulting row permutation is applied to the rows from the panel start
    downwards, and the columns to the right of the panel are then updated
    with a unit-lower triangular solve followed by a matrix product.

    Args:
      a: two-dimensional array of shape ``(m, n)``.
      block_size: maximum number of columns per panel.

    Returns:
      A tuple ``(a, pivot, perm)``: the updated matrix, an int32 array of
      ``min(m, n)`` pivot rows, and an int32 array of length ``m`` holding
      the accumulated row permutation.
    """
    num_rows, num_cols = a.shape
    rank = min(num_rows, num_cols)
    pivot = jnp.zeros((rank, ), dtype=jnp.int32)
    perm = jnp.arange(num_rows, dtype=jnp.int32)

    for start in range(0, rank, block_size):
        width = min(rank - start, block_size)
        stop = start + width
        panel_pivot, panel_perm, panel_lu = _lu_unblocked(a[start:, start:stop])

        # Panel results are relative to row `start`; shift them to rows of `a`.
        source_rows = panel_perm + start
        pivot = ops.index_update(pivot, ops.index[start:stop],
                                 panel_pivot + start)
        perm = ops.index_update(perm, ops.index[start:], perm[source_rows])
        a = ops.index_update(a, ops.index[start:, :], a[source_rows, :])
        a = ops.index_update(a, ops.index[start:, start:stop], panel_lu)

        if stop >= num_cols:
            # No columns to the right of this panel, nothing left to update.
            continue

        # Rows of the panel, columns right of it: solve with the panel's
        # unit-diagonal lower triangle.
        solved = triangular_solve(a[start:stop, start:stop],
                                  a[start:stop, stop:],
                                  left_side=True,
                                  lower=True,
                                  unit_diagonal=True)
        a = ops.index_update(a, ops.index[start:stop, stop:], solved)

        # Trailing block: subtract (below-panel columns) @ (solved rows).
        product = lax.dot(a[stop:, start:stop],
                          a[start:stop, stop:],
                          precision=lax.Precision.HIGHEST)
        a = ops.index_add(a, ops.index[stop:, stop:], -product)
    return a, pivot, perm
Esempio n. 3
0
def lu_jvp_rule(primals, tangents):
    """Forward-mode derivative rule for the LU primitive ``lu_p``.

    Args:
      primals: 1-tuple ``(a,)`` holding the matrix passed to ``lu_p``; its
        last two dimensions are ``(m, n)``.
      tangents: 1-tuple ``(a_dot,)`` holding the tangent of ``a``.

    Returns:
      A pair ``(primal_out, tangent_out)`` where ``primal_out`` is
      ``core.pack((lu, pivots))`` and ``tangent_out`` is
      ``ad.TangentTuple((lu_dot, ad_util.zero))``; the pivots get a zero
      tangent.
    """
    a, = primals
    a_dot, = tangents
    lu, pivots = lu_p.bind(a)

    a_shape = np.shape(a)
    m, n = a_shape[-2:]
    dtype = lax._dtype(a)
    k = min(m, n)

    # Apply the pivoting to the tangent: row i of `x` is row permutation[i]
    # of `a_dot`. A gather does this directly; multiplying by a dense m x m
    # 0/1 matrix gives the same rows but costs O(m^2 n) and turns any
    # non-finite entry of `a_dot` into NaN through 0 * inf.
    permutation = lu_pivots_to_permutation(pivots, m)
    x = a_dot[..., permutation, :]

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    # Square (m, m) L: strict lower triangle of the first k columns of `lu`,
    # zero-padded on the right, plus the identity for the unit diagonal.
    ndims = len(a_shape)
    l_padding = [(0, 0, 0)] * ndims
    l_padding[-1] = (0, m - k, 0)
    zero = np._constant_like(lu, 0)
    l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
    l = l + np.eye(m, m, dtype=dtype)

    # Square (n, n) U: upper triangle of the first k rows of `lu`, zero-padded
    # at the bottom, with an identity block in the bottom-right corner so the
    # padded rows do not make the triangular solve singular.
    u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                    ((k, 0, 0), (k, 0, 0)))
    u_padding = [(0, 0, 0)] * ndims
    u_padding[-2] = (0, n - k, 0)
    u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

    # lau = inv(L) . x . inv(U), computed as two triangular solves.
    la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True)
    lau = triangular_solve(u,
                           la,
                           left_side=False,
                           transpose_a=False,
                           lower=False)

    l_dot = np.matmul(l, np.tril(lau, -1))
    u_dot = np.matmul(np.triu(lau), u)
    lu_dot = l_dot + u_dot
    return core.pack((lu, pivots)), ad.TangentTuple((lu_dot, ad_util.zero))
Esempio n. 4
0
def lu_pivots_to_permutation(swaps, k):
  """Turn LU's sequence of row swaps into the equivalent permutation.

  Args:
    swaps: one-dimensional array of pivot indices; step `i` exchanges
      entries `i` and `swaps[i]`.
    k: length of the permutation to build.

  Returns:
    `arange(k)` after applying every swap in order.
  """
  (num_swaps,) = np.shape(swaps)

  def swap_step(row, carry):
    pivots, perm = carry
    target = pivots[row]
    # Read both entries before overwriting either of them.
    at_row = np.ravel(perm[row])
    at_target = np.ravel(perm[target])
    perm = lax.dynamic_update_index_in_dim(perm, at_target, row, axis=0)
    perm = lax.dynamic_update_index_in_dim(perm, at_row, target, axis=0)
    return pivots, perm

  first = onp.array(0, onp.int32)
  last = onp.array(num_swaps, onp.int32)
  identity = np.arange(k)
  _, result = lax.fori_loop(first, last, swap_step, (swaps, identity))
  return result
Esempio n. 5
0
def _lu_unblocked(a):
    """Unblocked LU decomposition, as a rolled loop.

    Runs ``min(m, n)`` elimination steps with row pivoting over the
    ``(m, n)`` matrix ``a`` inside a single ``lax.fori_loop``.

    Args:
      a: two-dimensional array of shape ``(m, n)``.

    Returns:
      A tuple ``(pivot, perm, a, error)``:
        pivot: int32 array of length ``min(m, n)``; entry ``k`` is the row
          selected as pivot in step ``k``.
        perm: int32 array of length ``m``; ``arange(m)`` with the same row
          exchanges applied.
        a: the matrix after all elimination steps.
        error: boolean scalar, true if any pivot element compared equal to
          zero.
      When either dimension is zero the initial values are returned
      unchanged.
    """
    m, n = a.shape

    def body(k, state):
        pivot, perm, a, error = state
        m_idx = np.arange(m)
        n_idx = np.arange(n)

        # Pivot score for column k; complex entries use |re| + |im|.
        if np.issubdtype(a.dtype, np.complexfloating):
            t = a[:, k]
            magnitude = np.abs(np.real(t)) + np.abs(np.imag(t))
        else:
            magnitude = np.abs(a[:, k])
        # Rows above k are masked with -inf so only rows >= k can be chosen.
        i = np.argmax(np.where(m_idx >= k, magnitude, -np.inf))
        pivot = ops.index_update(pivot, ops.index[k], i)

        # Exchange rows k and i of the matrix ...
        a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])

        # ... and the matching entries of the permutation.
        perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

        # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
        x = a[k, k]
        # Record an exactly-zero pivot; the division below still runs.
        error = error | lax.eq(x, np._constant_like(a, 0))
        a = ops.index_update(a, ops.index[:, k],
                             np.where(m_idx > k, a[:, k] / x, a[:, k]))

        # a[k+1:, k+1:] -= np.outer(a[k+1:, k], a[k, k+1:])
        a = a - np.where(
            (m_idx[:, None] > k) & (n_idx > k), np.outer(a[:, k], a[k, :]),
            np.array(0, dtype=a.dtype))
        return pivot, perm, a, error

    pivot = np.zeros((min(m, n), ), dtype=np.int32)
    perm = np.arange(m, dtype=np.int32)
    error = np.array(False, np.bool_)
    if m == 0 or n == 0:
        # If either dimension is empty the loop has zero iterations, so the
        # initial carry is already the result. Return it directly: the loop
        # body is still traced even when it never runs, and tracing fails
        # because the indexing and argmax cannot succeed on an empty axis.
        return (pivot, perm, a, error)
    return lax.fori_loop(0, min(m, n), body, (pivot, perm, a, error))
Esempio n. 6
0
def _process_axis_index(self, frame):
    """Return a ``BatchTracer`` over the int32 values ``0 .. frame.size - 1``.

    The tracer is built for this trace (``self``) with ``0`` as its third
    argument (presumably the batch dimension; confirm against
    ``batching.BatchTracer``).
    """
    positions = lax_numpy.arange(frame.size, dtype=np.int32)
    return batching.BatchTracer(self, positions, 0)