Ejemplo n.º 1
def _sqrtm_triu(T):
  Implements Björck, Å., & Hammarling, S. (1983).
      "A Schur method for the square root of a matrix". Linear algebra and
      its applications", 52, 127-140.
    diag = jnp.sqrt(jnp.diag(T))
    n = diag.size
    U = jnp.diag(diag)

    def i_loop(l, data):
        j, U = data
        i = j - 1 - l
        s = lax.fori_loop(i + 1, j, lambda k, val: val + U[i, k] * U[k, j],
        value = jnp.where(T[i, j] == s, 0.0,
                          (T[i, j] - s) / (diag[i] + diag[j]))
        return j, U.at[i, j].set(value)

    def j_loop(j, U):
        _, U = lax.fori_loop(0, j, i_loop, (j, U))
        return U

    U = lax.fori_loop(0, n, j_loop, U)
    return U
Ejemplo n.º 2
def funm(A, func, disp=True):
  """Evaluate a matrix function defined by a scalar callable.

  ``A`` is reduced with ``schur`` followed by ``rsf2csf`` to an upper
  triangular ``T`` and unitary ``Z``. ``func`` is applied to the diagonal of
  ``T`` (its eigenvalues), ``_algorithm_11_1_1`` fills in the rest of
  ``F = f(T)`` (presumably Parlett's recurrence, Golub & Van Loan
  Algorithm 11.1.1 -- confirm), and the result is mapped back as
  ``Z @ F @ Z^H``.

  Args:
    A: square 2-D array_like.
    func: callable applied to the 1-D array of diagonal entries of ``T``.
    disp: if true, return only ``F``; otherwise also return an error
      estimate.

  Returns:
    ``F`` when ``disp`` is true, else ``(F, err)``. ``err`` is ``inf`` if
    ``F`` contains an infinity, otherwise a heuristic estimate clipped to
    ``[eps, 1]`` where ``eps`` is the machine epsilon of ``F``'s precision.

  Raises:
    ValueError: if ``A`` is not a square 2-D array.
  """
  A = jnp.asarray(A)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected square array_like input')

  T, Z = schur(A)
  T, Z = rsf2csf(T, Z)

  # Apply func to the eigenvalues on T's diagonal, kept in T's dtype.
  F = jnp.diag(func(jnp.diag(T)))
  F = F.astype(T.dtype.char)

  F, minden = _algorithm_11_1_1(F, T)
  F = Z @ F @ Z.conj().T

  if disp:
    return F

  # Machine epsilon matching F's precision. For a complex F, finfo reports
  # the eps of its real component type (complex64 -> float32 eps). Deriving
  # it from the dtype avoids a hand-written char table that must cover every
  # case.
  tol = jnp.finfo(F.dtype).eps

  # ``minden`` is presumably the smallest denominator met in the recurrence
  # (confirm in _algorithm_11_1_1); replace an exact zero by eps so the
  # division below stays finite.
  minden = jnp.where(minden == 0.0, tol, minden)
  err = jnp.where(
      jnp.any(jnp.isinf(F)), jnp.inf,
      jnp.minimum(1, jnp.maximum(
          tol, (tol / minden) * norm(jnp.triu(T, 1), 1))))

  return F, err
Ejemplo n.º 3
def _roots_no_zeros(p):
    # build companion matrix and find its eigenvalues (the roots)
    if p.size < 2:
        return array([], dtype=dtypes._to_complex_dtype(p.dtype))
    A = diag(ones((p.size - 2, ), p.dtype), -1)
    A = A.at[0, :].set(-p[1:] / p[0])
    return linalg.eigvals(A)
Ejemplo n.º 4
    def recursive_case(B, offset, b, agenda, blocks, eigenvectors):
        """Split one pending block into two smaller subproblems.

        The recursive case of the algorithm, specialized to a static block
        size of ``B``. The active ``b x b`` block ``H`` (read from ``blocks``
        at ``(offset, 0)``) and its eigenvector columns ``V`` (read from
        ``eigenvectors`` at column ``offset``) are split by ``split_spectrum``
        around the median of ``H``'s real diagonal. Both halves are written
        back in place, and two ``_Subproblem``s are pushed onto ``agenda``.

        Args:
          B: static block size used for the slices (presumably ``b <= B``,
            with ``B`` a padded bucket size -- confirm).
          offset: start row of the active block in ``blocks`` and start
            column of its vectors in ``eigenvectors``.
          b: size of the active block.
          agenda: stack of pending ``_Subproblem``s.
          blocks: storage for the pending diagonal blocks.
          eigenvectors: storage for the accumulated eigenvector columns.

        Returns:
          The updated ``(agenda, blocks, eigenvectors)``.
        """
        H = _slice(blocks, (offset, 0), (b, b), (B, B))
        V = _slice(eigenvectors, (0, offset), (n, b), (N, B))

        # ``_mask`` presumably replaces diagonal entries past ``b`` (padding
        # up to ``B``) with NaN so that ``nanmedian`` ignores them -- confirm.
        split_point = jnp.nanmedian(
            _mask(jnp.diag(jnp.real(H)), (b,),
                  jnp.nan))  # TODO: Improve this?
        # ``rank`` is the size of the "minus" half: it occupies rows/columns
        # offset .. offset + rank, and the "plus" half the remaining b - rank.
        H_minus, V_minus, H_plus, V_plus, rank = split_spectrum(
            H, b, split_point, V0=V)

        blocks = _update_slice(blocks, H_minus, (offset, 0), (rank, rank))
        blocks = _update_slice(blocks, H_plus, (offset + rank, 0),
                               (b - rank, b - rank))
        eigenvectors = _update_slice(eigenvectors, V_minus, (0, offset),
                                     (n, rank))
        eigenvectors = _update_slice(eigenvectors, V_plus, (0, offset + rank),
                                     (n, b - rank))

        # The "plus" subproblem is pushed first, so if ``agenda`` is LIFO the
        # "minus" subproblem is processed next.
        agenda = agenda.push(_Subproblem(offset + rank, (b - rank)))
        agenda = agenda.push(_Subproblem(offset, rank))
        return agenda, blocks, eigenvectors
Ejemplo n.º 5
def _roots_no_zeros(p):
    """Find the roots of ``p`` as the eigenvalues of its companion matrix.

    Precondition (not checked here): ``p`` has no leading zeros and holds at
    least two coefficients.
    """
    # Promote integer/bool coefficients to an inexact dtype first.
    (coeffs,) = _promote_dtypes_inexact(p)
    degree = coeffs.size - 1
    # Companion matrix: a sub-diagonal of ones, with the negated trailing
    # coefficients (normalized by the leading one) across the first row.
    companion = diag(ones((degree - 1,), coeffs.dtype), -1)
    first_row = -coeffs[1:] / coeffs[0]
    companion = companion.at[0, :].set(first_row)
    return linalg.eigvals(companion)