def _calc_P_Q(A):
    """Return the Padé numerator, denominator and squaring count for ``A``.

    Setup step of a scaling-and-squaring scheme (presumably for the matrix
    exponential after Higham 2005 -- the thresholds below match that paper's
    theta_m values; confirm against the caller).  A diagonal Padé
    approximant of degree m is chosen from the 1-norm of ``A``; for the
    highest degree, ``A`` is first divided by ``2**n_squarings`` so that its
    1-norm does not exceed ``maxnorm``.

    NOTE(review): this file defines ``_calc_P_Q`` more than once; only the
    last definition is bound at import time -- consider deleting the
    duplicate.

    Args:
      A: square 2-D array-like of dtype float32, float64, complex64 or
        complex128.

    Returns:
      ``(P, Q, n_squarings)`` where ``P = U + V`` is the numerator p_m(A),
      ``Q = -U + V`` is the denominator q_m(A) (both evaluated on the scaled
      matrix) and ``n_squarings`` is a scalar float array, 0 when no scaling
      was applied.

    Raises:
      ValueError: if ``A`` is not a square matrix.
      TypeError: if ``A`` has an unsupported dtype.
    """
    A = jnp.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError('expected A to be a square matrix')
    A_L1 = np_linalg.norm(A, 1)

    if A.dtype == 'float64' or A.dtype == 'complex128':
        maxnorm = 5.371920351148152
        conds = [
            1.495585217958292e-002,
            2.539398330063230e-001,
            9.504178996162932e-001,
            2.097847961257068e+000,
        ]
        branches = [_pade3, _pade5, _pade7, _pade9, _pade13]
    elif A.dtype == 'float32' or A.dtype == 'complex64':
        maxnorm = 3.925724783138660
        conds = [4.258730016922831e-001, 1.880152677804762e+000]
        branches = [_pade3, _pade5, _pade7]
    else:
        raise TypeError(f"A.dtype={A.dtype} is not supported.")

    # Minimal s >= 0 with ||A / 2**s||_1 <= maxnorm.  It must round *up*:
    # rounding down would leave the scaled norm anywhere in
    # [maxnorm, 2 * maxnorm), beyond the range maxnorm was chosen for.
    # Clamping before ceil avoids producing -0.0; log2(0) = -inf gives 0.
    n_squarings = jnp.ceil(jnp.maximum(0, jnp.log2(A_L1 / maxnorm)))
    A = A / 2**n_squarings.astype(A.dtype)

    # The degree is picked from the *unscaled* norm.  Every threshold is
    # below maxnorm, so n_squarings is 0 whenever a lower degree is chosen
    # and only the highest degree ever receives a scaled A.
    idx = jnp.digitize(A_L1, jnp.array(conds, dtype=A_L1.dtype))
    # lax.switch executes only the selected approximant.
    U, V = lax.switch(idx, branches, A)

    P = U + V  # p_m(A) : numerator
    Q = -U + V  # q_m(A) : denominator
    return P, Q, n_squarings
def _calc_P_Q(A):
    """Return the Padé numerator, denominator and squaring count for ``A``.

    Setup step of a scaling-and-squaring scheme (presumably for the matrix
    exponential after Higham 2005 -- the thresholds below match that paper's
    theta_m values; confirm against the caller).  A diagonal Padé
    approximant of degree m is chosen from the 1-norm of ``A``; for the
    highest degree, ``A`` is first divided by ``2**n_squarings`` so that its
    1-norm does not exceed ``maxnorm``.

    NOTE(review): this file defines ``_calc_P_Q`` more than once; only the
    last definition is bound at import time -- consider deleting the
    duplicate.

    Args:
      A: square 2-D array-like of dtype float32, float64, complex64 or
        complex128.

    Returns:
      ``(P, Q, n_squarings)`` where ``P = U + V`` is the numerator p_m(A),
      ``Q = -U + V`` is the denominator q_m(A) (both evaluated on the scaled
      matrix) and ``n_squarings`` is a scalar float array, 0 when no scaling
      was applied.

    Raises:
      ValueError: if ``A`` is not a square matrix.
      TypeError: if ``A`` has an unsupported dtype.
    """
    from jax import lax  # function-scope so this block does not rely on a module import

    A = jnp.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError('expected A to be a square matrix')
    A_L1 = np_linalg.norm(A, 1)

    if A.dtype == 'float64' or A.dtype == 'complex128':
        maxnorm = 5.371920351148152
        conds = [
            1.495585217958292e-002,
            2.539398330063230e-001,
            9.504178996162932e-001,
            2.097847961257068e+000,
        ]
        branches = [_pade3, _pade5, _pade7, _pade9, _pade13]
    elif A.dtype == 'float32' or A.dtype == 'complex64':
        maxnorm = 3.925724783138660
        conds = [4.258730016922831e-001, 1.880152677804762e+000]
        branches = [_pade3, _pade5, _pade7]
    else:
        raise TypeError("A.dtype={} is not supported.".format(A.dtype))

    # Minimal s >= 0 with ||A / 2**s||_1 <= maxnorm.  It must round *up*:
    # rounding down would leave the scaled norm anywhere in
    # [maxnorm, 2 * maxnorm), beyond the range maxnorm was chosen for.
    # Clamping before ceil avoids producing -0.0; log2(0) = -inf gives 0.
    n_squarings = jnp.ceil(jnp.maximum(0, jnp.log2(A_L1 / maxnorm)))
    A = A / 2**n_squarings.astype(A.dtype)

    # The degree is picked from the *unscaled* norm.  Every threshold is
    # below maxnorm, so n_squarings is 0 whenever a lower degree is chosen
    # and only the highest degree ever receives a scaled A.
    idx = jnp.digitize(A_L1, jnp.array(conds, dtype=A_L1.dtype))
    # Evaluate only the selected approximant.  Computing every degree and
    # choosing with jnp.select wastes matrix products.  For large norms the
    # unselected low-degree approximants of the unscaled A can also overflow
    # to inf, which makes gradients through select/where NaN.
    U, V = lax.switch(idx, branches, A)

    P = U + V  # p_m(A) : numerator
    Q = -U + V  # q_m(A) : denominator
    return P, Q, n_squarings