def funm(A, func, disp=True):
    """Evaluate a matrix function of ``A`` through its complex Schur form.

    ``A`` is factored as ``A = Z T Z^H`` (``schur`` followed by ``rsf2csf``).
    ``func`` is applied to the diagonal of ``T``. The remaining entries of
    ``F = f(T)`` are then filled in by ``_algorithm_11_1_1`` (presumably the
    Parlett recurrence, Golub & Van Loan Alg. 11.1.1). Finally the result is
    mapped back as ``Z F Z^H``.

    Args:
      A: square 2-D array_like.
      func: callable applied to the 1-D array of diagonal entries of ``T``.
        Its result is placed on the diagonal of ``F`` via ``jnp.diag``.
      disp: if True, return only ``F``. If False, also return an error
        estimate.

    Returns:
      ``F`` when ``disp`` is True. Otherwise returns the tuple ``(F, err)``.
      ``err`` is ``inf`` if ``F`` contains an infinity. In all other cases it
      is ``min(1, max(tol, tol / minden * ||triu(T, 1)||_1))``, where ``tol``
      is the machine epsilon for ``F``'s precision.

    Raises:
      ValueError: if ``A`` is not a square 2-D array.
    """
    A = jnp.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError('expected square array_like input')

    T, Z = schur(A)
    # Convert to the complex Schur form so that T is upper triangular.
    T, Z = rsf2csf(T, Z)

    F = jnp.diag(func(jnp.diag(T)))
    F = F.astype(T.dtype.char)

    F, minden = _algorithm_11_1_1(F, T)
    F = Z @ F @ Z.conj().T

    if disp:
        return F

    # Choose the epsilon that matches F's precision. ``.lower()`` folds the
    # complex type codes ('F', 'D') onto their real counterparts ('f', 'd').
    # This must be a single if/elif chain. With two independent ``if``
    # statements, a float16 ('e') result would also reach the second
    # statement's ``else`` and be given float64 epsilon.
    precision = F.dtype.char.lower()
    if precision == 'e':
        tol = jnp.finfo(jnp.float16).eps
    elif precision == 'f':
        tol = jnp.finfo(jnp.float32).eps
    else:
        tol = jnp.finfo(jnp.float64).eps

    # A zero minden (presumably from repeated diagonal entries of T) would
    # otherwise cause a division by zero below.
    minden = jnp.where(minden == 0.0, tol, minden)

    err = jnp.where(jnp.any(jnp.isinf(F)), jnp.inf,
                    jnp.minimum(1, jnp.maximum(
                        tol, (tol / minden) * norm(jnp.triu(T, 1), 1))))

    return F, err
def _inner_loop(i, p_F_minden):
    """Fill one superdiagonal entry of ``F`` from the recurrence.

    With ``j = i + p`` (1-based), this stores the following value at
    ``F[i, j]``:

        (T[i, j] * (F[j, j] - F[i, i])
         + sum_{i < k < j} (T[i, k] * F[k, j] - F[i, k] * T[k, j]))
        / (T[j, j] - T[i, i])

    When the denominator is exactly zero, the numerator is stored undivided.
    The ``(i, carry) -> carry`` signature looks like a ``lax.fori_loop`` body
    over ``i``; confirm this against the caller.

    ``T`` and ``N`` are not parameters; they are read from the enclosing
    scope. Presumably they are the triangular Schur factor and its
    dimension.

    Args:
      i: 1-based row index of the entry to fill.
      p_F_minden: carry tuple ``(p, F, minden)``, where ``p`` is the
        superdiagonal offset, ``F`` is the matrix being filled, and
        ``minden`` is the running minimum of ``|T[j, j] - T[i, i]|``.

    Returns:
      The updated carry ``(p, F, minden)``. ``p`` is passed through unchanged.
    """
    p, F, minden = p_F_minden

    # Zero-based coordinates of the entry being written.
    r = i - 1
    c = i + p - 1

    # Restrict the dot products to the interior indices r < k < c, which is
    # the 1-based range i+1 .. j-1. The code uses a mask rather than a slice,
    # presumably because i and p may be traced values.
    k = jnp.arange(N)
    interior = (k > r) & (k < c)

    def masked(v):
        return jnp.where(interior, v, 0)

    numer = T[r, c] * (F[c, c] - F[r, r])
    numer = numer + (masked(T[r]) @ masked(F[:, c])
                     - masked(F[r]) @ masked(T[:, c]))

    denom = T[c, c] - T[r, r]
    entry = jnp.where(denom != 0, numer / denom, numer)
    return p, F.at[r, c].set(entry), jnp.minimum(minden, jnp.abs(denom))
def sturm_step(i, q, count):
    """Advance the Sturm-sequence pivot recurrence by one index.

    Computes the next pivot as ``alpha[i] - beta_sq[i - 1] / q - x``. When
    that pivot is ``<= pivmin``, ``count`` is incremented and the pivot is
    clamped to at most ``-pivmin``. Presumably the clamp keeps the next
    step's division by ``q`` away from zero.

    ``alpha``, ``beta_sq``, ``x`` and ``pivmin`` are read from the enclosing
    scope. Presumably they are the diagonal and squared off-diagonal of a
    symmetric tridiagonal matrix, the shift being tested, and a small
    positive pivot threshold. Confirm this against the caller.

    Returns:
      ``(q, count)``: the possibly clamped new pivot and the updated counter.
    """
    pivot = alpha[i] - beta_sq[i - 1] / q - x
    at_or_below = pivot <= pivmin
    new_count = jnp.where(at_or_below, count + 1, count)
    new_pivot = jnp.where(at_or_below, jnp.minimum(pivot, -pivmin), pivot)
    return new_pivot, new_count