def _calc_P_Q(A): A = jnp.asarray(A) if A.ndim != 2 or A.shape[0] != A.shape[1]: raise ValueError('expected A to be a square matrix') A_L1 = np_linalg.norm(A, 1) n_squarings = 0 if A.dtype == 'float64' or A.dtype == 'complex128': U3, V3 = _pade3(A) U5, V5 = _pade5(A) U7, V7 = _pade7(A) U9, V9 = _pade9(A) maxnorm = 5.371920351148152 n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm))) A = A / 2**n_squarings U13, V13 = _pade13(A) conds = jnp.array([ 1.495585217958292e-002, 2.539398330063230e-001, 9.504178996162932e-001, 2.097847961257068e+000 ]) U = jnp.select((A_L1 < conds), (U3, U5, U7, U9), U13) V = jnp.select((A_L1 < conds), (V3, V5, V7, V9), V13) elif A.dtype == 'float32' or A.dtype == 'complex64': U3, V3 = _pade3(A) U5, V5 = _pade5(A) maxnorm = 3.925724783138660 n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm))) A = A / 2**n_squarings U7, V7 = _pade7(A) conds = jnp.array([4.258730016922831e-001, 1.880152677804762e+000]) U = jnp.select((A_L1 < conds), (U3, U5), U7) V = jnp.select((A_L1 < conds), (V3, V5), V7) else: raise TypeError("A.dtype={} is not supported.".format(A.dtype)) P = U + V # p_m(A) : numerator Q = -U + V # q_m(A) : denominator return P, Q, n_squarings
def cdf(k, mu, loc=0):
    """Poisson cumulative distribution function.

    Evaluates ``gammaincc(floor(1 + (k - loc)), mu)`` and returns zero
    wherever ``k - loc`` is negative.

    Args:
      k: value(s) at which to evaluate the CDF.
      mu: Poisson parameter passed as the second argument of ``gammaincc``.
      loc: offset subtracted from ``k`` before evaluation (default 0).

    Returns:
      Array of CDF values with the inexact dtype produced by promoting
      ``k``, ``mu`` and ``loc`` together.
    """
    # The label identifies this function to the promotion helper
    # (presumably for its diagnostics), so it must name ``poisson.cdf``
    # rather than another function in the module.
    k, mu, loc = jnp._promote_args_inexact("poisson.cdf", k, mu, loc)
    zero = jnp._constant_like(k, 0)
    x = lax.sub(k, loc)
    # gammaincc is evaluated for every element, including x < 0; those
    # entries are replaced by zero below.
    p = gammaincc(jnp.floor(1 + x), mu)
    return jnp.where(lax.lt(x, zero), zero, p)
def _linear_indices_and_weights(coordinate): lower = jnp.floor(coordinate) upper_weight = coordinate - lower lower_weight = 1 - upper_weight index = lower.astype(jnp.int32) return [(index, lower_weight), (index + 1, upper_weight)]