예제 #1
파일: linalg.py, 프로젝트: xueeinstein/jax
def funm(A, func, disp=True):
  """Evaluate a matrix function of the square matrix ``A``.

  ``A`` is reduced to complex Schur form ``A = Z T Z^H`` (``schur`` followed
  by ``rsf2csf``). ``func`` is applied to the diagonal of ``T`` to seed the
  diagonal of ``F``. The rest of ``F`` is then filled in by
  ``_algorithm_11_1_1``, presumably Parlett's recurrence (Golub & Van Loan,
  Alg. 11.1.1) -- confirm against that helper. The result is mapped back with
  ``Z F Z^H``.

  Args:
    A: square 2-D array_like.
    func: callable applied to the 1-D array of diagonal entries of ``T``.
      Its result is placed on a diagonal via ``jnp.diag``, so it should
      return a 1-D array of the same length.
    disp: if true (default), return only ``F``. No error estimate is
      computed and nothing is printed.

  Returns:
    ``F`` when ``disp`` is true. Otherwise a tuple ``(F, err)``, where
    ``err`` is an error estimate built from machine epsilon for ``F``'s
    precision, the ``minden`` value returned by ``_algorithm_11_1_1`` and
    ``norm(triu(T, 1), 1)``. It is clipped to ``[tol, 1]`` and set to
    ``inf`` when ``F`` has any infinite entry.

  Raises:
    ValueError: if ``A`` is not a square 2-D array.
  """
  A = jnp.asarray(A)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected square array_like input')

  T, Z = schur(A)
  T, Z = rsf2csf(T, Z)

  F = jnp.diag(func(jnp.diag(T)))
  F = F.astype(T.dtype.char)

  F, minden = _algorithm_11_1_1(F, T)
  F = Z @ F @ Z.conj().T

  if disp:
    return F

  # Choose machine epsilon by precision. NumPy dtype chars: 'e' = float16,
  # 'f'/'F' = float32/complex64; everything else falls back to float64.
  # This must be a single if/elif/else chain. With two independent `if`s,
  # a float16 result gets float16 eps and then the float64 `else` branch
  # overwrites it.
  kind = F.dtype.char.lower()
  if kind == 'e':
    tol = jnp.finfo(jnp.float16).eps
  elif kind == 'f':
    tol = jnp.finfo(jnp.float32).eps
  else:
    tol = jnp.finfo(jnp.float64).eps

  # Keep the division below finite when the minimum denominator is zero.
  minden = jnp.where(minden == 0.0, tol, minden)
  err = jnp.where(jnp.any(jnp.isinf(F)), jnp.inf, jnp.minimum(1, jnp.maximum(
          tol, (tol / minden) * norm(jnp.triu(T, 1), 1))))

  return F, err
예제 #2
파일: linalg.py, 프로젝트: xueeinstein/jax
 def _inner_loop(i, p_F_minden):
   """Fill one entry ``F[i-1, j-1]`` (with ``j = i + p``) of the result.

   Indices ``i`` and ``j`` are 1-based; every array access subtracts 1.
   ``T`` and ``N`` are read from the enclosing scope.

   Args:
     i: 1-based row index of the entry being computed.
     p_F_minden: loop carry ``(p, F, minden)``. ``p`` is the offset of the
       superdiagonal being filled (``j = i + p``), ``F`` is the partially
       computed result, and ``minden`` is the smallest
       ``|T[j-1, j-1] - T[i-1, i-1]|`` seen so far.

   Returns:
     The updated carry ``(p, F, minden)``.
   """
   p, F, minden = p_F_minden
   j = i + p
   s = T[i-1, j-1] * (F[j-1, j-1] - F[i-1, i-1])
   T_row, T_col = T[i-1], T[:, j-1]
   F_row, F_col = F[i-1], F[:, j-1]
   # Select 0-based k with i <= k < j-1 (1-based k in [i+1, j-1]). A masked
   # full-length dot product replaces a slice whose size would depend on the
   # loop indices.
   k = jnp.arange(N)
   ind = (k >= i) & (k < j-1)
   val = (jnp.where(ind, T_row, 0) @ jnp.where(ind, F_col, 0) -
          jnp.where(ind, F_row, 0) @ jnp.where(ind, T_col, 0))
   s = s + val
   den = T[j-1, j-1] - T[i-1, i-1]
   # When den == 0, s is stored undivided, exactly as before. Dividing by a
   # substituted 1 avoids evaluating s / 0 in the discarded branch of
   # jnp.where, which would produce inf/nan. Those values poison
   # reverse-mode gradients and trip jax_debug_nans.
   safe_den = jnp.where(den != 0, den, 1)
   s = s / safe_den
   F = F.at[i-1, j-1].set(s)
   minden = jnp.minimum(minden, jnp.abs(den))
   return p, F, minden
예제 #3
파일: linalg.py, 프로젝트: xueeinstein/jax
 def sturm_step(i, q, count):
     """Advance the Sturm sequence recurrence by one step.

     ``alpha``, ``beta_sq``, ``x`` and ``pivmin`` are read from the enclosing
     scope. The next term is ``alpha[i] - beta_sq[i - 1] / q - x``. Wherever
     that term is ``<= pivmin``, ``count`` is incremented and the term is
     replaced by ``min(term, -pivmin)``.

     Returns:
       The pair ``(q, count)`` to feed into the next step.
     """
     ratio = beta_sq[i - 1] / q
     q_next = alpha[i] - ratio - x
     # Tally terms at or below the pivot threshold, then clamp them to at
     # most -pivmin (presumably to keep the next division away from zero;
     # confirm with the caller).
     at_or_below = q_next <= pivmin
     count = jnp.where(at_or_below, count + 1, count)
     q_next = jnp.where(at_or_below, jnp.minimum(q_next, -pivmin), q_next)
     return q_next, count