def body(k, state):
    """Perform elimination step `k` of an unblocked, partially pivoted LU.

    `m` and `n` are read from the enclosing scope (presumably the row and
    column counts of the matrix being factored -- confirm at the definition
    site).

    Args:
      k: index of the current pivot column / step.
      state: tuple `(pivot, perm, a)`:
        pivot -- per-step record of the chosen pivot row (entry `k` is set here),
        perm  -- row permutation vector; its entries `k` and the pivot row are swapped,
        a     -- working matrix, factored in place column by column.

    Returns:
      The updated `(pivot, perm, a)` tuple.
    """
    pivot, perm, a = state
    rows = np.arange(m)
    cols = np.arange(n)

    # Pivot search: |Re| + |Im| for complex input, plain |.| otherwise.
    column = a[:, k]
    if np.issubdtype(a.dtype, np.complexfloating):
        score = np.abs(np.real(column)) + np.abs(np.imag(column))
    else:
        score = np.abs(column)
    # Rows above k are already eliminated; -inf keeps argmax from picking them.
    best = np.argmax(np.where(rows >= k, score, -np.inf))
    pivot = ops.index_update(pivot, ops.index[k], best)

    # Exchange rows k and `best` in both the matrix and the permutation.
    a = ops.index_update(a, ops.index[[k, best], ], a[[best, k], ])
    perm = ops.index_update(perm, ops.index[[best, k], ], perm[[k, best], ])

    # Equivalent of a[k+1:, k] /= a[k, k], written with a full-length mask so
    # every array keeps the same shape on every step (loop-invariant shapes).
    diagonal = a[k, k]
    scaled = np.where(rows > k, a[:, k] / diagonal, a[:, k])
    a = ops.index_update(a, ops.index[:, k], scaled)

    # Equivalent of a[k+1:, k+1:] -= outer(a[k+1:, k], a[k, k+1:]), again
    # via a mask instead of a shrinking slice.
    trailing = (rows[:, None] > k) & (cols > k)
    correction = np.where(trailing, np.outer(a[:, k], a[k, :]),
                          np.array(0, dtype=a.dtype))
    return pivot, perm, a - correction
def _lu_blocked(a, block_size=32):
    """Blocked LU decomposition, as an unrolled loop.

    The columns are processed in panels of at most `block_size`. Each panel is
    factored by `_lu_unblocked`, its row permutation is applied to the rest of
    the row range, and the trailing submatrix is updated with a unit-lower
    triangular solve followed by a matrix product.

    Args:
      a: 2-D matrix of shape (m, n).
      block_size: maximum panel width.

    Returns:
      `(pivot, a)`. `pivot` is an int32 vector of length min(m, n) holding the
      global pivot row chosen at each step. `a` holds the packed factors: the
      strictly lower part is L, whose unit diagonal is implied, as the
      `unit_diagonal=True` solve below assumes, and the upper part is U. If any
      panel reports an error, every entry of `a` is replaced by NaN.
    """
    m, n = a.shape
    steps = min(m, n)
    pivot = np.zeros((steps, ), dtype=np.int32)
    failed = np.array(False, np.bool_)

    for start in range(0, steps, block_size):
        stop = start + min(steps - start, block_size)
        panel = a[start:, start:stop]
        panel_pivot, panel_perm, panel_lu, panel_error = _lu_unblocked(panel)
        failed = failed | panel_error
        # Global row order of rows start: after this panel's pivoting.
        reordered = panel_perm + start

        # Store the factored panel, then permute the columns to its left.
        a = ops.index_update(a, ops.index[start:, start:stop], panel_lu)
        a = ops.index_update(a, ops.index[start:, :start],
                             a[reordered, :start])
        pivot = ops.index_update(pivot, ops.index[start:stop],
                                 panel_pivot + start)

        if stop >= n:
            continue  # no trailing columns to update

        # Apply the same permutation to the columns on the right.
        a = ops.index_update(a, ops.index[start:, stop:],
                             a[reordered, stop:])
        # U12 = L11^{-1} A12 (L11 is unit lower triangular).
        u12 = triangular_solve(a[start:stop, start:stop],
                               a[start:stop, stop:],
                               left_side=True, lower=True,
                               unit_diagonal=True)
        a = ops.index_update(a, ops.index[start:stop, stop:], u12)
        # Schur complement: A22 -= L21 @ U12.
        schur = lax.dot(a[stop:, start:stop], a[start:stop, stop:],
                        precision=lax.Precision.HIGHEST)
        a = ops.index_add(a, ops.index[stop:, stop:], -schur)

    return pivot, np.where(failed, lax.full_like(a, np.nan), a)
def cond(x, p=None):
    """Compute the condition number of a matrix (or stack of matrices).

    Mirrors ``numpy.linalg.cond``.

    Args:
      x: matrix, or stack of matrices in the last two axes; must not be empty.
      p: norm order.
        - None or 2: largest / smallest singular value.
        - -2: smallest / largest singular value.
        - Any other value: ``norm(x, p) * norm(inv(x), p)``; this requires
          square matrices.

    Returns:
      The condition number of each matrix. A NaN result, such as 0/0 from an
      all-zero matrix, is reported as +inf, unless the corresponding input
      matrix itself contains NaN. In that case the NaN is kept.
    """
    _assertNoEmpty2d(x)
    if p in (None, 2):
        s = la.svd(x, compute_uv=False)
        r = s[..., 0] / s[..., -1]
    elif p == -2:
        s = la.svd(x, compute_uv=False)
        r = s[..., -1] / s[..., 0]
    else:
        _assertRankAtLeast2(x)
        _assertNdSquareness(x)
        invx = la.inv(x)
        r = (la.norm(x, ord=p, axis=(-2, -1))
             * la.norm(invx, ord=p, axis=(-2, -1)))

    # Convert NaNs to infs unless the original matrix had NaN entries.
    # The old code gated this on "r has no NaNs at all", which made the
    # conversion unreachable; the per-matrix mask alone is the right test.
    nan_mask = np.logical_and(np.isnan(r),
                              ~np.isnan(x).any(axis=(-2, -1)))
    return np.where(nan_mask, np.inf, r)