Beispiel #1
    def body(k, state):
        """Run step ``k`` of an unblocked LU factorization with partial pivoting.

        Shapes stay loop-invariant (updates are masked with ``np.where``
        rather than applied to shrinking slices), so this can serve as a
        traced loop body. ``m`` and ``n`` are read from the enclosing scope
        and are expected to be the row and column counts of ``a``.

        Args:
          k: index of the pivot column being eliminated.
          state: tuple ``(pivot, perm, a)``: pivot row indices chosen so far,
            the accumulated row permutation, and the matrix being factored.

        Returns:
          The updated ``(pivot, perm, a)`` tuple.
        """
        pivot, perm, a = state
        m_idx = np.arange(m)
        n_idx = np.arange(n)

        # Pivot selection: largest magnitude in column k among rows >= k.
        # Complex inputs use |re| + |im| (no square root); real inputs |x|.
        if np.issubdtype(a.dtype, np.complexfloating):
            t = a[:, k]
            magnitude = np.abs(np.real(t)) + np.abs(np.imag(t))
        else:
            magnitude = np.abs(a[:, k])
        i = np.argmax(np.where(m_idx >= k, magnitude, -np.inf))
        pivot = ops.index_update(pivot, ops.index[k], i)

        # Swap rows k and i of the matrix and of the permutation.
        a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])

        perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

        # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
        x = a[k, k]
        a = ops.index_update(a, ops.index[:, k],
                             np.where(m_idx > k, a[:, k] / x, a[:, k]))

        # a[k+1:, k+1:] -= np.outer(a[k+1:, k], a[k, k+1:])
        a = a - np.where(
            (m_idx[:, None] > k) & (n_idx > k), np.outer(a[:, k], a[k, :]),
            np.array(0, dtype=a.dtype))
        return pivot, perm, a
Beispiel #2
def _lu_blocked(a, block_size=32):
    """Blocked LU decomposition with partial pivoting, as an unrolled loop.

    Columns are processed in panels of at most ``block_size``. For each panel:

    - ``_lu_unblocked`` factors it.
    - Its row permutation is applied to the columns on both sides of the panel.
    - The block row to the right of the panel is solved against the panel's
      unit-lower-triangular factor.
    - The trailing submatrix receives the Schur-complement update.

    Args:
      a: matrix of shape ``(m, n)``.
      block_size: maximum number of columns factored per panel.

    Returns:
      ``(pivot, a)``:

      - ``pivot`` holds ``min(m, n)`` int32 pivot row indices.
      - ``a`` holds the packed factors: L strictly below the diagonal (its
        unit diagonal is implicit) and U on and above it.

      If any panel factorization reports an error, every entry of ``a`` is
      replaced with NaN.
    """
    m, n = a.shape
    r = min(m, n)
    pivot = np.zeros((r, ), dtype=np.int32)
    error = np.array(False, np.bool_)
    for k in range(0, r, block_size):
        b = min(r - k, block_size)
        # Factor the current panel (rows k:, columns k:k+b).
        block_pivot, perm, lu_block, block_error = _lu_unblocked(
            a[k:, k:k + b])
        error = error | block_error
        a = ops.index_update(a, ops.index[k:, k:k + b], lu_block)

        # Apply the panel's row permutation to the columns on its left, and
        # record the pivots in global row coordinates.
        a = ops.index_update(a, ops.index[k:, :k], a[perm + k, :k])
        pivot = ops.index_update(pivot, ops.index[k:k + b], block_pivot + k)

        if k + b < n:
            # Permute the trailing columns the same way.
            a = ops.index_update(a, ops.index[k:, k + b:], a[perm + k, k + b:])
            # U12 = inv(L11) @ A12. L11 is the panel's unit lower triangle;
            # its stored diagonal belongs to U, hence unit_diagonal=True.
            a = ops.index_update(
                a, ops.index[k:k + b, k + b:],
                triangular_solve(a[k:k + b, k:k + b],
                                 a[k:k + b, k + b:],
                                 left_side=True, lower=True,
                                 unit_diagonal=True))
            # Schur complement update: A22 -= L21 @ U12.
            a = ops.index_add(
                a, ops.index[k + b:, k + b:],
                -np.matmul(a[k + b:, k:k + b],
                           a[k:k + b, k + b:],
                           precision=lax.Precision.HIGHEST))
    a = np.where(error, lax.full_like(a, np.nan), a)
    return pivot, a
Beispiel #3
def lu_jvp_rule(primals, tangents):
    """Forward-mode (JVP) rule for the LU primitive ``lu_p``.

    Args:
      primals: one-element sequence ``(a,)`` holding the input matrix. The
        trailing two axes are ``(m, n)``; any leading axes are treated as
        batch axes.
      tangents: one-element sequence ``(a_dot,)`` holding the tangent of ``a``.

    Returns:
      The packed primal outputs ``(lu, pivots)`` and a tangent tuple
      ``(lu_dot, zero)``. The integer pivots get a symbolic zero tangent.
    """
    from jax import ad_util  # provides the symbolic zero tangent

    a, = primals
    a_dot, = tangents
    lu, pivots = lu_p.bind(a)

    a_shape = np.shape(a)
    m, n = a_shape[-2:]
    dtype = lax._dtype(a)
    k = min(m, n)

    # TODO(phawkins): use a gather rather than a matrix multiplication here.
    # p[..., i, j] = (perm[i] == j), so x = p @ a_dot applies the same row
    # permutation as the factorization. ``[..., None]`` (rather than
    # ``[:, None]``) keeps any leading batch axes aligned.
    permutation = lu_pivots_to_permutation(pivots, m)
    p = np.array(permutation[..., None] == np.arange(m), dtype=dtype)
    x = np.matmul(p, a_dot)

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    ndims = len(a_shape)
    # L (m x m): strictly-lower part of lu's first k columns, zero-padded to
    # m columns, plus an explicit unit diagonal.
    l_padding = [(0, 0, 0)] * ndims
    l_padding[-1] = (0, m - k, 0)
    zero = np._constant_like(lu, 0)
    l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
    l = l + np.eye(m, m, dtype=dtype)

    # U (n x n): upper part of lu's first k rows, zero-padded to n rows, with
    # an identity in the padded bottom-right block so U is square and
    # invertible whenever its leading k x k block is.
    u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                    ((k, 0, 0), (k, 0, 0)))
    u_padding = [(0, 0, 0)] * ndims
    u_padding[-2] = (0, n - k, 0)
    u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

    # lau = inv(L) @ x @ inv(U): a left lower solve, then a right upper solve.
    la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True)
    lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                           lower=False)

    l_dot = np.matmul(l, np.tril(lau, -1))
    u_dot = np.matmul(np.triu(lau), u)
    lu_dot = l_dot + u_dot
    return core.pack((lu, pivots)), ad.TangentTuple((lu_dot, ad_util.zero))
Beispiel #4
def _lu_unblocked(a):
    """Unblocked LU decomposition with partial pivoting, as a rolled loop.

    Args:
      a: matrix of shape ``(m, n)``.

    Returns:
      ``(pivot, perm, a, error)``:

      - ``pivot`` holds the ``min(m, n)`` int32 pivot row indices.
      - ``perm`` is the resulting int32 row permutation of length ``m``.
      - ``a`` holds the packed factors: L multipliers below the diagonal and
        U on and above it.
      - ``error`` is a boolean scalar that is True if any pivot was exactly
        zero.
    """
    m, n = a.shape

    def body(k, state):
        # One elimination step for column k, written with loop-invariant
        # shapes (masks instead of shrinking slices) so it can be traced.
        pivot, perm, a, error = state
        m_idx = np.arange(m)
        n_idx = np.arange(n)

        # Pivot selection: largest magnitude in column k among rows >= k.
        # Complex inputs use |re| + |im| (no square root); real inputs |x|.
        if np.issubdtype(a.dtype, np.complexfloating):
            t = a[:, k]
            magnitude = np.abs(np.real(t)) + np.abs(np.imag(t))
        else:
            magnitude = np.abs(a[:, k])
        i = np.argmax(np.where(m_idx >= k, magnitude, -np.inf))
        pivot = ops.index_update(pivot, ops.index[k], i)

        # Swap rows k and i of the matrix and of the permutation.
        a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])

        perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

        # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
        x = a[k, k]
        # A zero pivot makes the division below produce inf/nan; flag it.
        error = error | lax.eq(x, np._constant_like(a, 0))
        a = ops.index_update(a, ops.index[:, k],
                             np.where(m_idx > k, a[:, k] / x, a[:, k]))

        # a[k+1:, k+1:] -= np.outer(a[k+1:, k], a[k, k+1:])
        a = a - np.where(
            (m_idx[:, None] > k) & (n_idx > k), np.outer(a[:, k], a[k, :]),
            np.array(0, dtype=a.dtype))
        return pivot, perm, a, error

    pivot = np.zeros((min(m, n), ), dtype=np.int32)
    perm = np.arange(m, dtype=np.int32)
    error = np.array(False, np.bool_)
    if min(m, n) == 0:
        # With an empty dimension the loop runs zero iterations, but tracing
        # its body still fails because indexing an empty axis cannot succeed.
        # Return the initial state, which is what the empty loop would yield.
        return (pivot, perm, a, error)
    return lax.fori_loop(0, min(m, n), body, (pivot, perm, a, error))
Beispiel #5
def all_gather(x, axis_name, *, axis_index_groups=None):
    """Gather values of x across all replicas.

    If ``x`` is a pytree then the result is equivalent to mapping this function
    to each leaf in the tree.

    This is equivalent to, but faster than, all_to_all(broadcast(x)).

    Args:
      x: array(s) with a mapped axis named ``axis_name``.
      axis_name: hashable Python object used to name a pmapped axis (see the
        :func:`jax.pmap` documentation for more details).
      axis_index_groups: optional list of lists containing axis indices (e.g.
        for an axis of size 4, [[0, 1], [2, 3]] would run all gather over the
        first two and last two replicas). Groups must cover all axis indices
        exactly once, and all groups must be the same size.

    Returns:
      Array(s) representing the result of an all-gather along the axis
      ``axis_name``. Shapes are the same as ``x.shape``, but with a leading
      dimension of the axis_size.

    For example, with 4 XLA devices available:

    >>> x = np.arange(4)
    >>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x)
    >>> print(y)
    [[0 1 2 3]
     [0 1 2 3]
     [0 1 2 3]
     [0 1 2 3]]

    An example of using axis_index_groups, groups split by even & odd device
    ids:

    >>> x = np.arange(16.).reshape(4, 4)
    >>> print(x)
    [[ 0.  1.  2.  3.]
     [ 4.  5.  6.  7.]
     [ 8.  9. 10. 11.]
     [12. 13. 14. 15.]]
    >>> y = jax.pmap(lambda x: jax.lax.all_gather(
    ... x, 'i', axis_index_groups=[[0, 2], [3, 1]]), axis_name='i')(x)
    >>> print(y)
    [[[ 0.  1.  2.  3.]
      [ 8.  9. 10. 11.]]
     [[12. 13. 14. 15.]
      [ 4.  5.  6.  7.]]
     [[ 0.  1.  2.  3.]
      [ 8.  9. 10. 11.]]
     [[12. 13. 14. 15.]
      [ 4.  5.  6.  7.]]]
    """
    index = axis_index(axis_name)
    if axis_index_groups is not None:
        # Map each axis index to its position within its own group. The
        # flattened groups are a permutation of the axis indices, so argsort
        # gives each axis index's position in the flattened list. Taking that
        # modulo the (uniform) group size gives the position inside the group.
        indices = np.array(axis_index_groups).flatten()
        group_size = len(axis_index_groups[0])
        axis_index_to_group_index = indices.argsort() % group_size
        index = lax_numpy.array(axis_index_to_group_index)[index]

    axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)

    return _allgather(x, 0, axis_size, index, axis_name, axis_index_groups)