def body(k, state):
  """Run one step ``k`` of an unblocked, partially pivoted LU loop.

  Args:
    k: the current step; column ``k`` is pivoted and eliminated.
    state: a tuple ``(pivot, perm, a)``.

  Returns:
    The updated ``(pivot, perm, a)`` tuple.

  NOTE(review): ``m`` and ``n`` are neither parameters nor assigned in this
  function, so it only works if they are defined in an enclosing or module
  scope -- TODO confirm. It mirrors the loop body nested in ``_lu_unblocked``
  but does not record zero pivots, so a zero ``a[k, k]`` can silently produce
  inf/NaN entries via the division below.
  """
  pivot, perm, a = state
  m_idx = np.arange(m)
  n_idx = np.arange(n)

  # Magnitude used for pivot selection; for complex inputs it is
  # |re| + |im| rather than the modulus.
  if np.issubdtype(a.dtype, np.complexfloating):
    t = a[:, k]
    magnitude = np.abs(np.real(t)) + np.abs(np.imag(t))
  else:
    magnitude = np.abs(a[:, k])
  # Rows above k are masked with -inf so the pivot is chosen from rows >= k.
  i = np.argmax(np.where(m_idx >= k, magnitude, -np.inf))
  pivot = ops.index_update(pivot, ops.index[k], i)
  # Swap rows k and i of the matrix and of the permutation vector.
  a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])
  perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

  # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
  x = a[k, k]
  a = ops.index_update(a, ops.index[:, k],
                       np.where(m_idx > k, a[:, k] / x, a[:, k]))

  # a[k+1:, k+1:] -= np.outer(a[k+1:, k], a[k, k+1:])
  a = a - np.where(
      (m_idx[:, None] > k) & (n_idx > k),
      np.outer(a[:, k], a[k, :]),
      np.array(0, dtype=a.dtype))
  return pivot, perm, a
def _lu_blocked(a, block_size=32):
  """Blocked LU decomposition with partial pivoting, as an unrolled loop.

  The first ``min(m, n)`` columns are processed in panels at most
  ``block_size`` columns wide. Each panel is factored with ``_lu_unblocked``.
  Its row permutation is then applied to the columns on either side of it.
  The block row to its right is solved against the panel's unit
  lower-triangular diagonal block. Finally, the trailing submatrix receives
  the Schur-complement update.

  Args:
    a: a 2-D array of shape ``(m, n)``.
    block_size: maximum number of columns factored per panel.

  Returns:
    A pair ``(pivot, a)``:

    - ``pivot`` is an int32 vector of length ``min(m, n)``. Entry ``j`` is the
      global row exchanged with row ``j`` at step ``j``.
    - ``a`` holds the L multipliers strictly below the diagonal and U on and
      above it.

    If any panel reported a failure (a zero pivot), every entry of ``a`` is
    NaN.
  """
  m, n = a.shape
  num_steps = min(m, n)
  pivot = np.zeros((num_steps, ), dtype=np.int32)
  failed = np.array(False, np.bool_)
  for start in range(0, num_steps, block_size):
    stop = start + min(num_steps - start, block_size)
    below = slice(start, None)   # rows of the current panel and beneath it
    panel = slice(start, stop)   # rows/columns spanned by the current panel
    before = slice(None, start)  # columns factored by earlier panels
    after = slice(stop, None)    # rows/columns past the current panel

    # Factor the tall panel a[start:, start:stop] in place.
    panel_pivot, panel_perm, panel_lu, panel_failed = _lu_unblocked(
        a[below, panel])
    failed = failed | panel_failed
    a = ops.index_update(a, ops.index[below, panel], panel_lu)

    # Apply the panel's row permutation to the already-factored columns and
    # record its pivots in global row coordinates.
    rows = panel_perm + start
    a = ops.index_update(a, ops.index[below, before], a[rows, before])
    pivot = ops.index_update(pivot, ops.index[panel], panel_pivot + start)

    if stop < n:
      # Permute the trailing columns the same way.
      a = ops.index_update(a, ops.index[below, after], a[rows, after])
      # U12 = inv(L11) @ A12, where L11 is the unit lower-triangular block.
      u12 = triangular_solve(a[panel, panel], a[panel, after],
                             left_side=True, lower=True, unit_diagonal=True)
      a = ops.index_update(a, ops.index[panel, after], u12)
      # Schur-complement update of the trailing block: A22 -= L21 @ U12.
      schur = lax.dot(a[after, panel], a[panel, after],
                      precision=lax.Precision.HIGHEST)
      a = ops.index_add(a, ops.index[after, after], -schur)

  a = np.where(failed, lax.full_like(a, np.nan), a)
  return pivot, a
def lu_jvp_rule(primals, tangents):
  """Forward-mode (JVP) rule for the LU decomposition primitive ``lu_p``.

  Args:
    primals: a 1-tuple ``(a,)``. The last two dimensions of ``a`` are
      ``(m, n)``; any leading dimensions are carried through as batch
      dimensions.
    tangents: a 1-tuple ``(a_dot,)`` holding the tangent of ``a``.

  Returns:
    ``core.pack((lu, pivots))`` together with the tangent tuple
    ``(lu_dot, ad_util.zero)``. The integer pivots receive a zero tangent.
  """
  a, = primals
  a_dot, = tangents
  lu, pivots = lu_p.bind(a)

  a_shape = np.shape(a)
  m, n = a_shape[-2:]
  dtype = lax._dtype(a)
  k = min(m, n)

  # TODO(phawkins): use a gather rather than a matrix multiplication here.
  permutation = lu_pivots_to_permutation(pivots, m)
  # Build P with P[..., i, j] = (permutation[..., i] == j), so that P @ a_dot
  # reorders the rows of a_dot the same way the factorization did. Indexing
  # with `...` rather than `:` keeps P square when `a` has leading batch
  # dimensions. It is identical to `[:, None]` for a single matrix. This
  # assumes lu_pivots_to_permutation keeps the batch dimensions of `pivots`;
  # confirm against its definition.
  p = np.array(permutation[..., None] == np.arange(m), dtype=dtype)
  x = np.matmul(p, a_dot)

  # Differentiation of Matrix Functionals Using Triangular Factorization
  # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
  #
  #     LU = A
  # ==> L'U + LU' = A'
  # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
  # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
  #     U' = triu(inv(L) . A' . inv(U)) . U

  # Expand the packed factors into square triangular matrices so that the
  # solves below are well defined for non-square inputs:
  # - L is m x m unit lower triangular, with identity columns past column k.
  # - U is n x n upper triangular, with an identity in its trailing
  #   (n - k) x (n - k) block.
  ndims = len(a_shape)
  l_padding = [(0, 0, 0)] * ndims
  l_padding[-1] = (0, m - k, 0)
  zero = np._constant_like(lu, 0)
  l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
  l = l + np.eye(m, m, dtype=dtype)

  u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                  ((k, 0, 0), (k, 0, 0)))
  u_padding = [(0, 0, 0)] * ndims
  u_padding[-2] = (0, n - k, 0)
  u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

  # lau = inv(L) . (P A') . inv(U)
  la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True)
  lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                         lower=False)

  l_dot = np.matmul(l, np.tril(lau, -1))
  u_dot = np.matmul(np.triu(lau), u)
  # L' lives strictly below the diagonal and U' on/above it, so their sum is
  # the tangent of the packed LU output.
  lu_dot = l_dot + u_dot
  return core.pack((lu, pivots)), ad.TangentTuple((lu_dot, ad_util.zero))
def _lu_unblocked(a):
  """Unblocked LU decomposition with partial pivoting, as a rolled loop.

  Args:
    a: a 2-D array of shape ``(m, n)``.

  Returns:
    A tuple ``(pivot, perm, a, error)``:
      pivot: int32 vector of length ``min(m, n)``. ``pivot[k]`` is the row
        exchanged with row ``k`` at step ``k``.
      perm: int32 vector of length ``m``. It starts as ``arange(m)`` and has
        the same row swaps applied, giving the accumulated row permutation.
      a: the factored matrix, with the L multipliers strictly below the
        diagonal and U on and above it.
      error: boolean scalar, True if any pivot ``a[k, k]`` was exactly zero.
  """
  m, n = a.shape

  def body(k, state):
    pivot, perm, a, error = state
    m_idx = np.arange(m)
    n_idx = np.arange(n)

    # Magnitude used for pivot selection; for complex inputs it is
    # |re| + |im| rather than the modulus.
    if np.issubdtype(a.dtype, np.complexfloating):
      t = a[:, k]
      magnitude = np.abs(np.real(t)) + np.abs(np.imag(t))
    else:
      magnitude = np.abs(a[:, k])
    # Rows above k are masked with -inf so the pivot is chosen from rows >= k.
    i = np.argmax(np.where(m_idx >= k, magnitude, -np.inf))
    pivot = ops.index_update(pivot, ops.index[k], i)
    # Swap rows k and i of the matrix and of the permutation vector.
    a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])
    perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

    # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
    x = a[k, k]
    # A zero pivot makes the division below produce inf/NaN; record it so the
    # caller can detect the failure.
    error = error | lax.eq(x, np._constant_like(a, 0))
    a = ops.index_update(a, ops.index[:, k],
                         np.where(m_idx > k, a[:, k] / x, a[:, k]))

    # a[k+1:, k+1:] -= np.outer(a[k+1:, k], a[k, k+1:])
    a = a - np.where(
        (m_idx[:, None] > k) & (n_idx > k),
        np.outer(a[:, k], a[k, :]),
        np.array(0, dtype=a.dtype))
    return pivot, perm, a, error

  pivot = np.zeros((min(m, n), ), dtype=np.int32)
  perm = np.arange(m, dtype=np.int32)
  error = np.array(False, np.bool_)
  if m == 0 or n == 0:
    # With either dimension empty the loop runs zero times. Tracing the body
    # can still fail, e.g. argmax over an empty column when m == 0. The loop's
    # result would be this initial state, so return it directly.
    return (pivot, perm, a, error)
  return lax.fori_loop(0, min(m, n), body, (pivot, perm, a, error))
def all_gather(x, axis_name, *, axis_index_groups=None):
  """Gather values of x across all replicas.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  This is equivalent to, but faster than, all_to_all(broadcast(x)).

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would run all gather over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.

  Returns:
    Array(s) representing the result of an all-gather along the axis
    ``axis_name``. Shapes are the same as ``x.shape``, but with a leading
    dimension of the axis_size (the group size when ``axis_index_groups`` is
    given).

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [[0 1 2 3]
   [0 1 2 3]
   [0 1 2 3]
   [0 1 2 3]]

  An example of using axis_index_groups, groups split by even & odd device ids:

  >>> x = np.arange(16).reshape(4, 4)
  >>> print(x)
  [[ 0  1  2  3]
   [ 4  5  6  7]
   [ 8  9 10 11]
   [12 13 14 15]]
  >>> y = jax.pmap(lambda x: jax.lax.all_gather(
  ...     x, 'i', axis_index_groups=[[0, 2], [3, 1]]), axis_name='i')(x)
  >>> print(y)
  [[[ 0  1  2  3]
    [ 8  9 10 11]]
   [[12 13 14 15]
    [ 4  5  6  7]]
   [[ 0  1  2  3]
    [ 8  9 10 11]]
   [[12 13 14 15]
    [ 4  5  6  7]]]
  """
  index = axis_index(axis_name)
  if axis_index_groups is not None:
    # Remap each replica's global axis index to its position within its own
    # group. The flattened groups list every axis index exactly once, so
    # argsort gives each index's offset in that flat list. Because all groups
    # share one size, that offset modulo the group size is the in-group
    # position.
    indices = np.array(axis_index_groups).flatten()
    axis_index_to_group_index = indices.argsort() % len(
        axis_index_groups[0])
    index = lax_numpy.array(axis_index_to_group_index)[index]
  # Summing 1 over the participating replicas counts them; this becomes the
  # size of the gathered leading dimension.
  axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)
  return _allgather(x, 0, axis_size, index, axis_name, axis_index_groups)