示例#1
0
def _calc_P_Q(A):
    """Build the Pade numerator/denominator pair for a square matrix ``A``.

    The lowest-degree Pade approximant whose threshold exceeds the 1-norm of
    ``A`` is selected.  When no threshold is exceeded, the highest-degree
    approximant for the dtype is used instead, evaluated on
    ``A / 2**n_squarings``.  Presumably the caller solves ``Q X = P`` and then
    squares the result ``n_squarings`` times (scaling and squaring) -- TODO
    confirm against the caller.

    Args:
      A: square matrix; dtype must be float32, float64, complex64 or
        complex128.

    Returns:
      ``(P, Q, n_squarings)`` where ``P = U + V`` is the numerator p_m(A) and
      ``Q = V - U`` is the denominator q_m(A).

    Raises:
      ValueError: if ``A`` is not a 2-D square matrix.
      TypeError: if ``A`` has an unsupported dtype.
    """
    A = jnp.asarray(A)
    if not (A.ndim == 2 and A.shape[0] == A.shape[1]):
        raise ValueError('expected A to be a square matrix')
    norm_1 = np_linalg.norm(A, 1)

    # Per-precision configuration: candidate low-degree approximants, the
    # fallback high-degree approximant, the norm bound used for scaling, and
    # the 1-norm thresholds under which each low-degree candidate is chosen.
    if A.dtype in ('float64', 'complex128'):
        low_order = (_pade3, _pade5, _pade7, _pade9)
        high_order = _pade13
        max_norm = 5.371920351148152
        thresholds = [
            1.495585217958292e-002, 2.539398330063230e-001,
            9.504178996162932e-001, 2.097847961257068e+000,
        ]
    elif A.dtype in ('float32', 'complex64'):
        low_order = (_pade3, _pade5)
        high_order = _pade7
        max_norm = 3.925724783138660
        thresholds = [4.258730016922831e-001, 1.880152677804762e+000]
    else:
        raise TypeError("A.dtype={} is not supported.".format(A.dtype))

    # A low-degree candidate is only selected when the norm is below its
    # threshold, and every threshold is smaller than max_norm.  No scaling is
    # needed in that case, so these candidates are evaluated on unscaled A.
    low_pairs = [pade(A) for pade in low_order]
    n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(norm_1 / max_norm)))
    U_high, V_high = high_order(A / 2**n_squarings)

    # jnp.select picks the first candidate whose condition holds, i.e. the
    # lowest degree that is accurate enough; otherwise it uses the fallback.
    picks = norm_1 < jnp.array(thresholds)
    U = jnp.select(picks, tuple(u for u, _ in low_pairs), U_high)
    V = jnp.select(picks, tuple(v for _, v in low_pairs), V_high)
    return U + V, V - U, n_squarings
示例#2
0
文件: poisson.py 项目: yashk2810/jax
def cdf(k, mu, loc=0):
    """Poisson cumulative distribution function ``P(X <= k)``.

    Uses the identity ``P(X <= k) = Q(floor(k - loc) + 1, mu)``, where ``Q`` is
    the regularized upper incomplete gamma function ``gammaincc``.  Points with
    ``k < loc`` get probability 0.

    Args:
      k: quantile(s) at which to evaluate the CDF.
      mu: Poisson mean parameter.
      loc: location shift applied to ``k`` (default 0).

    Returns:
      CDF values, with arguments promoted to a common inexact dtype.
    """
    # The first argument labels this function in promotion error messages;
    # it previously said "poisson.logpmf" (copy-paste from logpmf).
    k, mu, loc = jnp._promote_args_inexact("poisson.cdf", k, mu, loc)
    zero = jnp._constant_like(k, 0)
    x = lax.sub(k, loc)
    p = gammaincc(jnp.floor(1 + x), mu)
    # Nothing is supported to the left of loc, so mask those points to 0.
    return jnp.where(lax.lt(x, zero), zero, p)
示例#3
0
文件: ndimage.py 项目: gnecula/jax
def _linear_indices_and_weights(coordinate):
    lower = jnp.floor(coordinate)
    upper_weight = coordinate - lower
    lower_weight = 1 - upper_weight
    index = lower.astype(jnp.int32)
    return [(index, lower_weight), (index + 1, upper_weight)]