def _update_T_Z(m, T, Z):
    """Apply one Givens-rotation step of the real-to-complex Schur conversion.

    Builds a 2x2 matrix ``G`` from an eigenvalue (``mu[0]``, shifted by
    ``-T[m, m]``) of the diagonal block ``T[m-1:m+1, m-1:m+1]`` and the
    subdiagonal entry ``T[m, m-1]``. It then applies:

      * ``G`` from the left to rows ``m-1:m+1`` of ``T`` (columns ``>= m-1``),
      * ``G^H`` from the right to columns ``m-1:m+1`` of ``T`` (rows ``< m+1``),
      * ``G^H`` from the right to columns ``m-1:m+1`` of ``Z``.

    In exact arithmetic this similarity makes ``T[m, m-1]`` zero. This function
    does not itself write a zero there.

    NOTE(review): ``N`` is not defined in this function. It presumably is the
    size of ``T`` captured from an enclosing scope -- confirm.

    Args:
      m: index of the 2x2 block's lower-right corner. It is used in
        ``lax.dynamic_slice`` calls, so it may be traced. Presumably
        ``1 <= m < N``.
      T: square matrix being reduced. Presumably complex-valued, since ``G``
        is built with ``dtype=T.dtype`` from complex eigenvalues.
      Z: accumulated unitary factor, updated on columns ``m-1:m+1``.

    Returns:
      The updated ``(T, Z)``.
    """
    # Eigenvalues of the 2x2 diagonal block, shifted by the trailing diagonal entry.
    mu = np_linalg.eigvals(lax.dynamic_slice(T, (m - 1, m - 1), (2, 2))) - T[m, m]
    # Normalizer so that (c, s) below has unit 2-norm.
    r = np_linalg.norm(jnp.array([mu[0], T[m, m - 1]])).astype(T.dtype)
    c = mu[0] / r
    s = T[m, m - 1] / r
    G = jnp.array([[c.conj(), s], [-s, c]], dtype=T.dtype)
    # T[m-1:m+1, m-1:] = G @ T[m-1:m+1, m-1:]
    # The column range depends on m, so full rows are sliced. Columns left of
    # m-1 are masked to zero before the product and kept unchanged afterwards.
    T_rows = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=0)
    col_mask = jnp.arange(N) >= m - 1
    G_dot_T_zeroed_cols = G @ jnp.where(col_mask, T_rows, 0)
    T_rows_new = jnp.where(~col_mask, T_rows, G_dot_T_zeroed_cols)
    T = lax.dynamic_update_slice_in_dim(T, T_rows_new, m - 1, axis=0)
    # T[:m+1, m-1:m+1] = T[:m+1, m-1:m+1] @ G.conj().T
    # Same trick along rows: rows >= m+1 are excluded from the update.
    T_cols = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=1)
    row_mask = jnp.arange(N)[:, jnp.newaxis] < m + 1
    T_zeroed_rows_dot_GH = jnp.where(row_mask, T_cols, 0) @ G.conj().T
    T_cols_new = jnp.where(~row_mask, T_cols, T_zeroed_rows_dot_GH)
    T = lax.dynamic_update_slice_in_dim(T, T_cols_new, m - 1, axis=1)
    # Z[:, m-1:m+1] = Z[:, m-1:m+1] @ G.conj().T
    Z_cols = lax.dynamic_slice_in_dim(Z, m - 1, 2, axis=1)
    Z = lax.dynamic_update_slice_in_dim(Z, Z_cols @ G.conj().T, m - 1, axis=1)
    return T, Z
def funm(A, func, disp=True):
    """Evaluate a matrix function of ``A`` through its complex Schur form.

    ``A`` is reduced with ``schur`` and converted with ``rsf2csf`` to
    ``A = Z T Z^H``. ``func`` is applied to the diagonal of ``T``, and
    ``_algorithm_11_1_1`` fills in the rest of ``F(T)``. The result is then
    transformed back as ``Z F Z^H``.

    Args:
      A: square 2-D array-like.
      func: callable applied to the vector of diagonal entries of ``T``.
      disp: if true, return only ``F``. Otherwise also return an error
        estimate.

    Returns:
      ``F`` when ``disp`` is true, else ``(F, err)``. ``err`` is ``inf`` if
      ``F`` contains an infinite entry, and otherwise lies in ``[tol, 1]``.

    Raises:
      ValueError: if ``A`` is not a square 2-D array.
    """
    A = jnp.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError('expected square array_like input')

    T, Z = schur(A)
    T, Z = rsf2csf(T, Z)

    F = jnp.diag(func(jnp.diag(T)))
    F = F.astype(T.dtype.char)
    F, minden = _algorithm_11_1_1(F, T)
    F = Z @ F @ Z.conj().T

    if disp:
        return F

    # Pick the machine epsilon matching F's precision. The branches must be
    # mutually exclusive. The previous `if ... if ... else` chain overwrote
    # the float16 tolerance with float64's.
    kind = F.dtype.char.lower()
    if kind == 'e':
        tol = jnp.finfo(jnp.float16).eps
    elif kind == 'f':
        tol = jnp.finfo(jnp.float32).eps
    else:
        tol = jnp.finfo(jnp.float64).eps

    # Avoid dividing by zero in the estimate below. `minden` comes from
    # _algorithm_11_1_1 (presumably the smallest denominator it used).
    minden = jnp.where(minden == 0.0, tol, minden)
    err = jnp.where(
        jnp.any(jnp.isinf(F)),
        jnp.inf,
        jnp.minimum(1, jnp.maximum(tol, (tol / minden) * norm(jnp.triu(T, 1), 1))),
    )
    return F, err
def _calc_P_Q(A):
    """Compute the Padé numerator and denominator for scaling-and-squaring expm.

    The 1-norm of ``A`` selects the Padé degree: 3/5/7/9/13 for
    float64/complex128, or 3/5/7 for float32/complex64. If the norm exceeds
    ``maxnorm``, ``A`` is first divided by ``2**n_squarings`` so that the
    scaled norm is at most ``maxnorm``. The caller presumably squares the
    resulting approximant ``n_squarings`` times to undo this.

    Args:
      A: square matrix of dtype float32, float64, complex64 or complex128.

    Returns:
      ``(P, Q, n_squarings)``, where ``P = U + V`` is the numerator and
      ``Q = -U + V`` is the denominator of the chosen approximant.

    Raises:
      ValueError: if ``A`` is not a square 2-D matrix.
      TypeError: if ``A`` has any other dtype.
    """
    A = jnp.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError('expected A to be a square matrix')
    A_L1 = np_linalg.norm(A, 1)

    # Degree thresholds. These match the theta_m values of Higham (2005).
    if A.dtype == 'float64' or A.dtype == 'complex128':
        maxnorm = 5.371920351148152
        conds = jnp.array([
            1.495585217958292e-002,
            2.539398330063230e-001,
            9.504178996162932e-001,
            2.097847961257068e+000,
        ], dtype=A_L1.dtype)
        pade_fns = [_pade3, _pade5, _pade7, _pade9, _pade13]
    elif A.dtype == 'float32' or A.dtype == 'complex64':
        maxnorm = 3.925724783138660
        conds = jnp.array([4.258730016922831e-001, 1.880152677804762e+000],
                          dtype=A_L1.dtype)
        pade_fns = [_pade3, _pade5, _pade7]
    else:
        raise TypeError(f"A.dtype={A.dtype} is not supported.")

    # Use the smallest s with ||A / 2**s||_1 <= maxnorm, i.e.
    # s = ceil(log2(||A||_1 / maxnorm)). A floor here could leave the scaled
    # norm as large as 2 * maxnorm, outside the range where the top-degree
    # approximant meets its accuracy bound.
    n_squarings = jnp.maximum(0, jnp.ceil(jnp.log2(A_L1 / maxnorm)))
    A = A / 2**n_squarings.astype(A.dtype)

    # The degree is selected from the unscaled norm. Scaling only occurs when
    # A_L1 > maxnorm, which exceeds every entry of `conds`, so the highest
    # degree is then chosen.
    idx = jnp.digitize(A_L1, conds)
    U, V = lax.switch(idx, pade_fns, A)
    P = U + V  # p_m(A) : numerator
    Q = -U + V  # q_m(A) : denominator
    return P, Q, n_squarings
def _calc_P_Q(A):
    """Compute the Padé numerator and denominator for scaling-and-squaring expm.

    The 1-norm of ``A`` selects the Padé degree: 3/5/7/9/13 for
    float64/complex128, or 3/5/7 for float32/complex64. Only the selected
    approximant is evaluated, via ``lax.switch``. The previous version
    evaluated every degree and chose with ``jnp.select``. The low-degree
    approximants were computed on the *unscaled* matrix, which for large
    norms overflows and poisons gradients through the unselected branches.

    If the norm exceeds ``maxnorm``, ``A`` is divided by ``2**n_squarings``
    so that the scaled norm is at most ``maxnorm``. The caller presumably
    squares the result ``n_squarings`` times.

    Args:
      A: square matrix of dtype float32, float64, complex64 or complex128.

    Returns:
      ``(P, Q, n_squarings)``, where ``P = U + V`` is the numerator and
      ``Q = -U + V`` is the denominator of the chosen approximant.

    Raises:
      ValueError: if ``A`` is not a square 2-D matrix.
      TypeError: if ``A`` has any other dtype.
    """
    A = jnp.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError('expected A to be a square matrix')
    A_L1 = np_linalg.norm(A, 1)

    # Degree thresholds. These match the theta_m values of Higham (2005).
    if A.dtype == 'float64' or A.dtype == 'complex128':
        maxnorm = 5.371920351148152
        conds = jnp.array([
            1.495585217958292e-002,
            2.539398330063230e-001,
            9.504178996162932e-001,
            2.097847961257068e+000,
        ], dtype=A_L1.dtype)
        pade_fns = [_pade3, _pade5, _pade7, _pade9, _pade13]
    elif A.dtype == 'float32' or A.dtype == 'complex64':
        maxnorm = 3.925724783138660
        conds = jnp.array([4.258730016922831e-001, 1.880152677804762e+000],
                          dtype=A_L1.dtype)
        pade_fns = [_pade3, _pade5, _pade7]
    else:
        raise TypeError("A.dtype={} is not supported.".format(A.dtype))

    # Smallest s with ||A / 2**s||_1 <= maxnorm (ceil, not floor: a floor
    # could leave the scaled norm as large as 2 * maxnorm).
    n_squarings = jnp.maximum(0, jnp.ceil(jnp.log2(A_L1 / maxnorm)))
    A = A / 2**n_squarings.astype(A.dtype)

    # Scaling only happens when A_L1 > maxnorm, which is above every entry of
    # `conds`. So every lower-degree branch only ever receives an unscaled A,
    # exactly as before.
    idx = jnp.digitize(A_L1, conds)
    U, V = lax.switch(idx, pade_fns, A)
    P = U + V  # p_m(A) : numerator
    Q = -U + V  # q_m(A) : denominator
    return P, Q, n_squarings
def cond(x, p=None):
    """Condition number of ``x`` (or of each matrix in a stack), like NumPy.

    Args:
      x: array whose trailing two axes hold the matrices.
      p: ``None`` or ``2`` gives the ratio of the first to the last singular
        value returned by ``la.svd``. ``-2`` gives the reciprocal ratio. Any
        other value gives ``norm(x, p) * norm(inv(x), p)``.

    Returns:
      The condition number(s). A NaN produced by the computation (for example
      ``0 / 0`` for an all-zero matrix, or a singular inverse) is reported as
      ``inf``, unless the corresponding input matrix itself contains NaN.
    """
    _assertNoEmpty2d(x)
    if p in (None, 2):
        s = la.svd(x, compute_uv=False)
        r = s[..., 0] / s[..., -1]
    elif p == -2:
        s = la.svd(x, compute_uv=False)
        r = s[..., -1] / s[..., 0]
    else:
        _assertRankAtLeast2(x)
        _assertNdSquareness(x)
        invx = la.inv(x)
        r = la.norm(x, ord=p, axis=(-2, -1)) * la.norm(
            invx, ord=p, axis=(-2, -1))

    # Convert NaNs to infs unless the original matrix had NaN entries. This
    # runs for every `p`. The previous version returned early for p in
    # (None, 2), and its outer `where` only applied the mask when `r` had *no*
    # NaNs, so the conversion never took effect.
    nan_mask = jnp.logical_and(jnp.isnan(r), ~jnp.isnan(x).any(axis=(-2, -1)))
    return jnp.where(nan_mask, jnp.inf, r)
def body_f_after_matmul(X):
    """Orthonormalize ``X`` and measure how far the column split is from invariant.

    The names ``n``, ``rank``, ``N``, ``H``, ``H_norm``, ``_mask`` and
    ``_slice`` are not defined here. They must be provided by an enclosing
    scope.

    Returns:
      ``(V1, V2, error)``: the leading ``rank`` columns of the complete QR
      factor, the remaining columns, and ``||V2^H H V1|| / H_norm``.
    """
    Q, _ = jnp_linalg.qr(X, mode="complete")
    # Counterpart of Q[:, :rank], restricted to the valid (n, rank) region.
    leading = _mask(Q, (n, rank))
    # Counterpart of Q[:, rank:], taken from the valid region and padded to (N, N).
    trailing = _slice(Q, (0, rank), (n, n - rank), (N, N))
    # TODO: might be able to get away with lower precision here
    coupling = jnp.dot(jnp.dot(trailing.conj().T, H), leading)
    return leading, trailing, jnp_linalg.norm(coupling) / H_norm
def vq(obs, code_book, check_finite=True):
    """Assign every observation to its nearest code-book entry.

    Args:
      obs: observations of rank 1 or 2. Must have the same rank as
        ``code_book``. Rank-1 inputs are treated as 1-dimensional points.
      code_book: code-book entries, with the same rank as ``obs``.
      check_finite: accepted for signature compatibility. It is not read.

    Returns:
      ``(code, dist_min)``: for each observation, the index of the closest
      code-book row and that row's distance. Distance is ``norm`` over the
      last axis (presumably Euclidean).

    Raises:
      ValueError: if the ranks differ, or are neither 1 nor 2.
    """
    _check_arraylike("scipy.cluster.vq.vq", obs, code_book)
    if obs.ndim != code_book.ndim:
        raise ValueError("Observation and code_book should have the same rank")
    obs, code_book = _promote_dtypes_inexact(obs, code_book)
    if obs.ndim == 1:
        obs = obs[..., None]
        code_book = code_book[..., None]
    if obs.ndim != 2:
        raise ValueError("ndim different than 1 or 2 are not supported")

    def distances_to_book(point):
        # point[None] broadcasts against every code-book row explicitly.
        return norm(point[None] - code_book, axis=-1)

    dist = vmap(distances_to_book)(obs)
    code = argmin(dist, axis=-1)
    dist_min = vmap(lambda row, idx: row[idx])(dist, code)
    return code, dist_min
def _projector_subspace(P, H, n, rank, maxiter=2):
    """Decomposes the `n x n` rank-`rank` Hermitian projector `P` into
    an `n x rank` isometry `V_minus` such that `P = V_minus @ V_minus.conj().T`
    and an `n x (n - rank)` isometry `V_plus` such that
    `I - P = V_plus @ V_plus.conj().T`.

    The subspaces are computed using the naive QR eigendecomposition
    algorithm (subspace iteration). It converges very quickly because of the
    sharp separation between the relevant eigenvalues of the projector.

    Args:
      P: A rank-`rank` Hermitian projector into the space of `H`'s first
        `rank` eigenpairs. `P` is padded to NxN.
      H: The aforementioned Hermitian matrix, which is used to track
        convergence.
      n: the true (dynamic) shape of `P`.
      rank: Rank of `P`.
      maxiter: Maximum number of iterations.

    Returns:
      V_minus, V_plus: Isometries into the eigenspaces described above.
    """
    # Choose an initial guess: the `rank` largest-norm columns of P.
    N, _ = P.shape
    # P is Hermitian, so its row norms (axis=1) equal its column norms.
    # Negate them so the ascending argsort below puts the *largest*-norm
    # columns first. Sorting the raw norms would pick the smallest ones.
    negative_column_norms = -jnp_linalg.norm(P, axis=1)
    # `jnp.argsort` ensures NaNs sort last, so set masked-out column norms to NaN.
    negative_column_norms = _mask(negative_column_norms, (n,), jnp.nan)
    sort_idxs = jnp.argsort(negative_column_norms)
    X = P[:, sort_idxs]
    # X = X[:, :rank]
    X = _mask(X, (n, rank))

    H_norm = jnp_linalg.norm(H)
    # NOTE(review): `error` below is already divided by H_norm, yet this
    # threshold is also scaled by H_norm. Confirm the intended normalization.
    thresh = 10 * jnp.finfo(X.dtype).eps * H_norm

    def body_f_after_matmul(X):
        """QR-orthonormalize X and return (V1, V2, relative off-block error)."""
        Q, _ = jnp_linalg.qr(X, mode="complete")
        # V1 = Q[:, :rank]
        # V2 = Q[:, rank:]
        V1 = _mask(Q, (n, rank))
        V2 = _slice(Q, (0, rank), (n, n - rank), (N, N))

        # TODO: might be able to get away with lower precision here
        error_matrix = jnp.dot(V2.conj().T, H)
        error_matrix = jnp.dot(error_matrix, V1)
        error = jnp_linalg.norm(error_matrix) / H_norm
        return V1, V2, error

    def cond_f(args):
        _, _, j, error = args
        still_counting = j < maxiter
        unconverged = error > thresh
        # `j` is a length-1 array, so index out a scalar predicate.
        return jnp.logical_and(still_counting, unconverged)[0]

    def body_f(args):
        V1, _, j, _ = args
        X = jnp.dot(P, V1)
        V1, V2, error = body_f_after_matmul(X)
        return V1, V2, j + 1, error

    # First iteration skips the matmul.
    V1, V2, error = body_f_after_matmul(X)
    one = jnp.ones(1, dtype=jnp.int32)
    V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error))
    return V1, V2