def _sqrtm_triu(T):
    """Compute a square root of the triangular matrix ``T``.

    Follows Björck, Å., & Hammarling, S. (1983), "A Schur method for the
    square root of a matrix", Linear Algebra and its Applications, 52,
    127-140.

    The diagonal of the result is the elementwise square root of the diagonal
    of ``T``. The strictly upper-triangular entries are then filled one column
    at a time (left to right), and within each column from the entry just
    above the diagonal up to row 0, so every product used by the recurrence
    has already been written. Only ``T[i, j]`` with ``i <= j`` is read.

    Args:
      T: square matrix; presumably upper triangular (entries below the
        diagonal are never read).

    Returns:
      Matrix of the same shape whose lower triangle is zero.
    """
    sqrt_diag = jnp.sqrt(jnp.diag(T))
    size = sqrt_diag.size

    def fill_entry(step, carry):
        col, root = carry
        # Walk upwards from the super-diagonal: step 0 -> row col - 1.
        row = col - 1 - step

        def add_product(k, partial):
            return partial + root[row, k] * root[k, col]

        # Sum of root[row, k] * root[k, col] for row < k < col.
        cross = lax.fori_loop(row + 1, col, add_product, 0.0)
        target = T[row, col]
        # When the numerator is exactly zero the entry is set to zero rather
        # than taken from the division (which could be 0/0 if both diagonal
        # roots are zero).
        entry = jnp.where(
            target == cross,
            0.0,
            (target - cross) / (sqrt_diag[row] + sqrt_diag[col]),
        )
        return col, root.at[row, col].set(entry)

    def fill_column(col, root):
        return lax.fori_loop(0, col, fill_entry, (col, root))[1]

    return lax.fori_loop(0, size, fill_column, jnp.diag(sqrt_diag))
def funm(A, func, disp=True):
    """Evaluate a matrix function of a square matrix via its Schur form.

    The input is reduced with ``schur`` and converted to complex Schur form
    with ``rsf2csf``. ``func`` is applied to the diagonal of the triangular
    factor, the off-diagonal part is filled in by ``_algorithm_11_1_1``
    (defined elsewhere in this module), and the result is transformed back
    with the unitary factor.

    Args:
      A: square 2-D array-like.
      func: callable applied to the 1-D array of diagonal entries of the
        triangular Schur factor; must return an array of the same length.
      disp: if true (default) only the matrix is returned; otherwise an error
        estimate is returned as well.

    Returns:
      ``F`` when ``disp`` is true, else the tuple ``(F, err)``. ``err`` is
      ``inf`` if ``F`` contains an infinity, and otherwise
      ``min(1, max(tol, (tol / minden) * ||triu(T, 1)||_1))`` where ``minden``
      is the second value returned by ``_algorithm_11_1_1`` (zeros replaced by
      ``tol``).

    Raises:
      ValueError: if ``A`` is not a square 2-D array.
    """
    A = jnp.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError('expected square array_like input')

    T, Z = schur(A)
    T, Z = rsf2csf(T, Z)

    F = jnp.diag(func(jnp.diag(T)))
    # func may return a real array while T is complex; match T's dtype.
    F = F.astype(T.dtype.char)

    F, minden = _algorithm_11_1_1(F, T)
    F = Z @ F @ Z.conj().T

    if disp:
        return F

    # Pick the machine epsilon matching the precision of F. This must be a
    # single if/elif/else chain: with two independent ``if`` statements the
    # half-precision choice was always overwritten by the final ``else``.
    precision = F.dtype.char.lower()
    if precision == 'e':
        tol = jnp.finfo(jnp.float16).eps
    elif precision == 'f':
        tol = jnp.finfo(jnp.float32).eps
    else:
        tol = jnp.finfo(jnp.float64).eps

    # Avoid dividing by zero in the estimate below.
    minden = jnp.where(minden == 0.0, tol, minden)
    estimate = jnp.minimum(
        1, jnp.maximum(tol, (tol / minden) * norm(jnp.triu(T, 1), 1)))
    err = jnp.where(jnp.any(jnp.isinf(F)), jnp.inf, estimate)
    return F, err
def _roots_no_zeros(p):
    """Return the eigenvalues of the companion matrix built from ``p``.

    ``p`` is a 1-D coefficient array whose first entry is used as the
    normalising (leading) coefficient. With fewer than two coefficients there
    is no companion matrix, and an empty array of the complex dtype matching
    ``p`` is returned.
    """
    n_coeffs = p.size
    if n_coeffs < 2:
        return array([], dtype=dtypes._to_complex_dtype(p.dtype))

    # Companion matrix: ones on the first sub-diagonal, and the negated,
    # normalised trailing coefficients in the first row.
    sub_diagonal = ones((n_coeffs - 2, ), p.dtype)
    companion = diag(sub_diagonal, -1)
    first_row = -p[1:] / p[0]
    companion = companion.at[0, :].set(first_row)
    return linalg.eigvals(companion)
def recursive_case(B, offset, b, agenda, blocks, eigenvectors):
    """Split one subproblem in two and queue both halves.

    Specialised to a static block size ``B``; ``b`` is the actual size of the
    subproblem starting at ``offset``. ``n`` and ``N`` are taken from the
    enclosing scope (presumably the dynamic and padded row counts of
    ``eigenvectors`` -- confirm against the caller).

    Returns the updated ``(agenda, blocks, eigenvectors)``.
    """
    block = _slice(blocks, (offset, 0), (b, b), (B, B))
    basis = _slice(eigenvectors, (0, offset), (n, b), (N, B))

    # Split at the median of the real diagonal; entries beyond the first ``b``
    # are replaced by NaN via ``_mask`` so ``nanmedian`` skips them.
    # TODO: Improve this?
    masked_diagonal = _mask(jnp.diag(jnp.real(block)), (b, ), jnp.nan)
    pivot = jnp.nanmedian(masked_diagonal)

    low_block, low_vecs, high_block, high_vecs, rank = split_spectrum(
        block, b, pivot, V0=basis)

    high_offset = offset + rank
    high_size = b - rank

    # Write the two sub-blocks back, the "minus" part first, then the "plus"
    # part directly after it.
    blocks = _update_slice(blocks, low_block, (offset, 0), (rank, rank))
    blocks = _update_slice(blocks, high_block, (high_offset, 0),
                           (high_size, high_size))
    eigenvectors = _update_slice(eigenvectors, low_vecs, (0, offset),
                                 (n, rank))
    eigenvectors = _update_slice(eigenvectors, high_vecs, (0, high_offset),
                                 (n, high_size))

    # Push the "plus" half first and the "minus" half second, keeping the
    # original push order.
    agenda = agenda.push(_Subproblem(high_offset, high_size))
    agenda = agenda.push(_Subproblem(offset, rank))
    return agenda, blocks, eigenvectors
def _roots_no_zeros(p):
    """Return the eigenvalues of the companion matrix built from ``p``.

    Assumes ``p`` has no leading zeros and has length > 1; neither condition
    is checked here. ``p`` is first promoted to an inexact dtype.
    """
    (p,) = _promote_dtypes_inexact(p)

    # Companion matrix: ones on the first sub-diagonal, and the negated
    # coefficients (normalised by the leading one) in the first row.
    sub_diagonal = ones((p.size - 2, ), p.dtype)
    companion = diag(sub_diagonal, -1)
    first_row = -p[1:] / p[0]
    companion = companion.at[0, :].set(first_row)
    return linalg.eigvals(companion)