Beispiel #1
0
def _calc_P_Q(A):
    """Return the Padé numerator/denominator for a (scaled) matrix ``A``.

    This is the scaling and approximant-selection step of a scaling-and-
    squaring scheme. The threshold constants below match the theta_m values
    tabulated in Higham (2005), "The scaling and squaring method for the
    matrix exponential revisited": m = 3, 5, 7, 9, 13 for double precision
    and m = 3, 5, 7 for single precision.

    Args:
      A: a square 2-D matrix (anything ``jnp.asarray`` accepts) with dtype
        float32, float64, complex64 or complex128.

    Returns:
      ``(P, Q, n_squarings)``.
      - ``P = U + V`` is the numerator p_m.
      - ``Q = -U + V`` is the denominator q_m.
      - ``(U, V)`` come from the selected ``_padeM`` helper applied to
        ``A / 2**n_squarings``.
      - ``n_squarings`` is a floating-point scalar array (it comes from
        ``jnp.ceil``).

      Presumably the caller solves ``Q X = P`` and squares the result
      ``n_squarings`` times; that code is not shown here.

    Raises:
      ValueError: if ``A`` is not a 2-D square matrix.
      TypeError: if ``A.dtype`` is not one of the supported dtypes.
    """
    A = jnp.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError('expected A to be a square matrix')
    A_L1 = np_linalg.norm(A, 1)
    if A.dtype == 'float64' or A.dtype == 'complex128':
        maxnorm = 5.371920351148152
        conds = [1.495585217958292e-002, 2.539398330063230e-001,
                 9.504178996162932e-001, 2.097847961257068e+000]
        pades = [_pade3, _pade5, _pade7, _pade9, _pade13]
    elif A.dtype == 'float32' or A.dtype == 'complex64':
        maxnorm = 3.925724783138660
        conds = [4.258730016922831e-001, 1.880152677804762e+000]
        pades = [_pade3, _pade5, _pade7]
    else:
        raise TypeError(f"A.dtype={A.dtype} is not supported.")
    # Smallest s >= 0 with ||A / 2**s||_1 <= maxnorm. This must be ceil, not
    # floor: floor leaves the scaled norm anywhere in [maxnorm, 2*maxnorm),
    # where the highest-order approximant is no longer within its bound.
    n_squarings = jnp.maximum(0, jnp.ceil(jnp.log2(A_L1 / maxnorm)))
    # n_squarings is 0 whenever A_L1 <= maxnorm. Since conds[-1] < maxnorm,
    # the lower-order approximants always act on the unscaled matrix.
    A = A / 2**n_squarings.astype(A.dtype)
    # digitize returns how many thresholds are <= A_L1. That is the index of
    # the lowest-order approximant whose threshold exceeds A_L1. Norms at or
    # above the last threshold select the highest-order approximant.
    idx = jnp.digitize(A_L1, jnp.array(conds, dtype=A_L1.dtype))
    U, V = lax.switch(idx, pades, A)
    P = U + V  # p_m(A) : numerator
    Q = -U + V  # q_m(A) : denominator
    return P, Q, n_squarings
Beispiel #2
0
def _calc_P_Q(A):
  """Return the Padé numerator/denominator for a (scaled) matrix ``A``.

  This is the scaling and approximant-selection step of a scaling-and-
  squaring scheme. The threshold constants below match the theta_m values
  tabulated in Higham (2005), "The scaling and squaring method for the
  matrix exponential revisited": m = 3, 5, 7, 9, 13 for double precision
  and m = 3, 5, 7 for single precision.

  Every candidate approximant is evaluated. ``jnp.select`` then keeps the
  first one whose threshold exceeds ``A_L1``, or the highest-order one
  (applied to the scaled matrix) if none does.

  Args:
    A: a square 2-D matrix (anything ``jnp.asarray`` accepts) with dtype
      float32, float64, complex64 or complex128.

  Returns:
    ``(P, Q, n_squarings)``.
    - ``P = U + V`` is the numerator p_m.
    - ``Q = -U + V`` is the denominator q_m.
    - ``n_squarings`` is a floating-point scalar array (it comes from
      ``jnp.ceil``) giving the scaling exponent applied before the
      highest-order approximant.

    Presumably the caller solves ``Q X = P`` and squares the result
    ``n_squarings`` times; that code is not shown here.

  Raises:
    ValueError: if ``A`` is not a 2-D square matrix.
    TypeError: if ``A.dtype`` is not one of the supported dtypes.
  """
  A = jnp.asarray(A)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected A to be a square matrix')
  A_L1 = np_linalg.norm(A, 1)
  if A.dtype == 'float64' or A.dtype == 'complex128':
    # The lower-order approximants are only selected when A_L1 < conds[-1],
    # which is below maxnorm, so they use the unscaled matrix.
    U3, V3 = _pade3(A)
    U5, V5 = _pade5(A)
    U7, V7 = _pade7(A)
    U9, V9 = _pade9(A)
    maxnorm = 5.371920351148152
    # Smallest s >= 0 with ||A / 2**s||_1 <= maxnorm. This must be ceil, not
    # floor: floor leaves the scaled norm in [maxnorm, 2*maxnorm).
    n_squarings = jnp.maximum(0, jnp.ceil(jnp.log2(A_L1 / maxnorm)))
    U13, V13 = _pade13(A / 2**n_squarings)
    conds = jnp.array([1.495585217958292e-002, 2.539398330063230e-001,
                       9.504178996162932e-001, 2.097847961257068e+000])
    U = jnp.select((A_L1 < conds), (U3, U5, U7, U9), U13)
    V = jnp.select((A_L1 < conds), (V3, V5, V7, V9), V13)
  elif A.dtype == 'float32' or A.dtype == 'complex64':
    U3, V3 = _pade3(A)
    U5, V5 = _pade5(A)
    maxnorm = 3.925724783138660
    # Same ceil-based scaling rule as the double-precision branch.
    n_squarings = jnp.maximum(0, jnp.ceil(jnp.log2(A_L1 / maxnorm)))
    U7, V7 = _pade7(A / 2**n_squarings)
    conds = jnp.array([4.258730016922831e-001, 1.880152677804762e+000])
    U = jnp.select((A_L1 < conds), (U3, U5), U7)
    V = jnp.select((A_L1 < conds), (V3, V5), V7)
  else:
    raise TypeError("A.dtype={} is not supported.".format(A.dtype))
  P = U + V  # p_m(A) : numerator
  Q = -U + V  # q_m(A) : denominator
  return P, Q, n_squarings