def body(k, state):
  """Performs elimination step ``k`` of an LU factorization with row pivoting.

  The row and column extents ``m`` and ``n`` are read from the enclosing
  scope. All updates are expressed with masks over full-size index vectors so
  that array shapes do not depend on ``k``.

  Args:
    k: index of the column being eliminated.
    state: tuple ``(pivot, perm, a)`` carried between steps.

  Returns:
    The updated ``(pivot, perm, a)`` tuple.
  """
  pivot, perm, a = state
  row_ids = np.arange(m)
  col_ids = np.arange(n)

  # Pivot search: for complex inputs the magnitude is |re| + |im|; otherwise
  # it is the absolute value.
  if np.issubdtype(a.dtype, np.complexfloating):
    column = a[:, k]
    magnitude = np.abs(np.real(column)) + np.abs(np.imag(column))
  else:
    magnitude = np.abs(a[:, k])
  # Rows above k are masked out with -inf so they can never be selected.
  candidates = np.where(row_ids >= k, magnitude, -np.inf)
  i = np.argmax(candidates)

  # Record the pivot, then exchange rows k and i of both `a` and `perm`.
  pivot = ops.index_update(pivot, ops.index[k], i)
  a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])
  perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

  # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
  diagonal = a[k, k]
  scaled_column = np.where(row_ids > k, a[:, k] / diagonal, a[:, k])
  a = ops.index_update(a, ops.index[:, k], scaled_column)

  # a[k+1:, k+1:] -= np.outer(a[k+1:, k], a[k, k+1:])
  trailing = (row_ids[:, None] > k) & (col_ids > k)
  rank_one = np.outer(a[:, k], a[k, :])
  no_change = np.array(0, dtype=a.dtype)
  a = a - np.where(trailing, rank_one, no_change)
  return pivot, perm, a
def _lu_blocked(a, block_size=128):
  """Blocked LU decomposition, as an unrolled loop.

  The matrix is processed in column panels of at most ``block_size`` columns.
  Each panel is factored with `_lu_unblocked`; the resulting row permutation
  is applied to the trailing rows, and the remaining columns are updated with
  a unit-lower triangular solve followed by a matrix product.

  Args:
    a: two-dimensional array of shape ``(m, n)``.
    block_size: maximum number of columns factored per panel.

  Returns:
    A tuple ``(a, pivot, perm)``: the updated matrix, the pivot indices of
    length ``min(m, n)``, and the row permutation of length ``m``.
  """
  m, n = a.shape
  r = min(m, n)
  pivot = jnp.zeros((r,), dtype=jnp.int32)
  perm = jnp.arange(m, dtype=jnp.int32)
  for k in range(0, r, block_size):
    b = min(r - k, block_size)
    # `_lu_unblocked` returns (pivot, perm, lu, error); unpacking its full
    # result into three names raises ValueError. Only the first three entries
    # are consumed here.
    # NOTE(review): the trailing singularity flag is discarded because this
    # function's return value has no slot for it -- confirm callers do not
    # need it.
    block_pivot, block_perm, lu_block = _lu_unblocked(a[k:, k:k + b])[:3]

    # Panel results are relative to row k; shift them to global row indices.
    pivot = ops.index_update(pivot, ops.index[k:k + b], block_pivot + k)
    perm = ops.index_update(perm, ops.index[k:], perm[block_perm + k])
    a = ops.index_update(a, ops.index[k:, :], a[block_perm + k, :])
    a = ops.index_update(a, ops.index[k:, k:k + b], lu_block)

    if k + b < n:
      # Solve against the unit-lower triangle of the panel to update the
      # block row to its right.
      a = ops.index_update(
          a, ops.index[k:k + b, k + b:],
          triangular_solve(a[k:k + b, k:k + b], a[k:k + b, k + b:],
                           left_side=True, lower=True, unit_diagonal=True))
      # Subtract the panel's contribution from the trailing submatrix.
      a = ops.index_add(
          a, ops.index[k + b:, k + b:],
          -lax.dot(a[k + b:, k:k + b], a[k:k + b, k + b:],
                   precision=lax.Precision.HIGHEST))
  return a, pivot, perm
def lu_jvp_rule(primals, tangents):
  """Forward-mode differentiation rule for the LU primitive.

  Args:
    primals: one-element sequence holding the matrix ``a``.
    tangents: one-element sequence holding the tangent of ``a``.

  Returns:
    A pair of the packed primal outputs ``(lu, pivots)`` and a tangent tuple
    whose first entry is the tangent of ``lu`` and whose second entry is
    ``ad_util.zero``.
  """
  (a,) = primals
  (a_dot,) = tangents
  lu, pivots = lu_p.bind(a)

  a_shape = np.shape(a)
  ndims = len(a_shape)
  m, n = a_shape[-2:]
  k = min(m, n)
  dtype = lax._dtype(a)

  # Apply the row permutation to the input tangent via a 0/1 matrix.
  # TODO(phawkins): use a gather rather than a matrix multiplication here.
  permutation = lu_pivots_to_permutation(pivots, m)
  perm_matrix = np.array(permutation[:, None] == np.arange(m), dtype=dtype)
  permuted_a_dot = np.matmul(perm_matrix, a_dot)

  # Differentiation of Matrix Functionals Using Triangular Factorization
  # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
  #
  #     LU = A
  # ==> L'U + LU' = A'
  # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
  # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
  #     U' = triu(inv(L) . A' . inv(U)) . U
  zero = np._constant_like(lu, 0)

  # L: strict lower triangle of the first k columns, widened to m columns,
  # plus the identity.
  l_padding = [(0, 0, 0)] * ndims
  l_padding[-1] = (0, m - k, 0)
  strict_lower = np.tril(lu[..., :, :k], -1)
  l = lax.pad(strict_lower, zero, l_padding) + np.eye(m, m, dtype=dtype)

  # U: upper triangle of the first k rows, extended to n rows, plus an
  # identity block placed at offset (k, k).
  u_padding = [(0, 0, 0)] * ndims
  u_padding[-2] = (0, n - k, 0)
  u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                  ((k, 0, 0), (k, 0, 0)))
  upper = np.triu(lu[..., :k, :])
  u = lax.pad(upper, zero, u_padding) + u_eye

  # lau = inv(L) . A' . inv(U)
  la = triangular_solve(l, permuted_a_dot, left_side=True, transpose_a=False,
                        lower=True)
  lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                         lower=False)

  l_dot = np.matmul(l, np.tril(lau, -1))
  u_dot = np.matmul(np.triu(lau), u)
  lu_dot = l_dot + u_dot

  primal_out = core.pack((lu, pivots))
  tangent_out = ad.TangentTuple((lu_dot, ad_util.zero))
  return primal_out, tangent_out
def lu_pivots_to_permutation(swaps, k):
  """Builds the permutation produced by a sequence of LU row swaps.

  Starting from ``arange(k)``, entry ``i`` is exchanged with entry
  ``swaps[i]`` for each ``i`` in increasing order.

  Args:
    swaps: one-dimensional array of pivot (row swap) indices.
    k: length of the permutation to build.

  Returns:
    The length-``k`` permutation after every swap has been applied.
  """
  num_swaps, = np.shape(swaps)

  def apply_swap(step, carry):
    pivots, perm = carry
    target = pivots[step]
    # Read both entries before writing either one so the exchange is a swap.
    at_step = np.ravel(perm[step])
    at_target = np.ravel(perm[target])
    perm = lax.dynamic_update_index_in_dim(perm, at_target, step, axis=0)
    perm = lax.dynamic_update_index_in_dim(perm, at_step, target, axis=0)
    return pivots, perm

  lower = onp.array(0, onp.int32)
  upper = onp.array(num_swaps, onp.int32)
  initial = (swaps, np.arange(k))
  final_state = lax.fori_loop(lower, upper, apply_swap, initial)
  return final_state[1]
def _lu_unblocked(a):
  """Unblocked LU decomposition, as a rolled loop.

  Args:
    a: two-dimensional array of shape ``(m, n)``.

  Returns:
    A tuple ``(pivot, perm, a, error)``: the pivot row chosen at each of the
    ``min(m, n)`` steps, the row permutation of length ``m``, the updated
    matrix, and a boolean scalar that is set when any diagonal entry used as
    a divisor compared equal to zero.
  """
  m, n = a.shape

  def body(k, state):
    pivot, perm, a, error = state
    m_idx = np.arange(m)
    n_idx = np.arange(n)

    # Pivot search: for complex inputs the magnitude is |re| + |im|.
    if np.issubdtype(a.dtype, np.complexfloating):
      t = a[:, k]
      magnitude = np.abs(np.real(t)) + np.abs(np.imag(t))
    else:
      magnitude = np.abs(a[:, k])
    # Rows above k are masked with -inf so they are never selected.
    i = np.argmax(np.where(m_idx >= k, magnitude, -np.inf))
    pivot = ops.index_update(pivot, ops.index[k], i)
    a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])
    perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

    # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
    x = a[k, k]
    # Remember whether any divisor was exactly zero.
    error = error | lax.eq(x, np._constant_like(a, 0))
    a = ops.index_update(a, ops.index[:, k],
                         np.where(m_idx > k, a[:, k] / x, a[:, k]))

    # a[k+1:, k+1:] -= np.outer(a[k+1:, k], a[k, k+1:])
    a = a - np.where(
        (m_idx[:, None] > k) & (n_idx > k),
        np.outer(a[:, k], a[k, :]),
        np.array(0, dtype=a.dtype))
    return pivot, perm, a, error

  pivot = np.zeros((min(m, n), ), dtype=np.int32)
  perm = np.arange(m, dtype=np.int32)
  error = np.array(False, np.bool_)
  if m == 0 or n == 0:
    # With an empty row or column axis the loop has zero trips, but tracing
    # the body to a jaxpr still fails because the indexing cannot succeed.
    # Returning the initial state is exactly what a zero-trip loop yields.
    return (pivot, perm, a, error)
  return lax.fori_loop(0, min(m, n), body, (pivot, perm, a, error))
def _process_axis_index(self, frame):
  """Wraps the index range of ``frame`` in a batch tracer.

  Args:
    frame: object exposing a ``size`` attribute.

  Returns:
    A ``batching.BatchTracer`` built from ``self``, the int32 range
    ``0 .. frame.size - 1``, and ``0``.
  """
  positions = lax_numpy.arange(frame.size, dtype=np.int32)
  # NOTE(review): the trailing 0 is presumably the batched axis of
  # `positions` -- confirm against BatchTracer's signature.
  return batching.BatchTracer(self, positions, 0)