コード例 #1
0
ファイル: special.py プロジェクト: jbampton/jax
def expn(n, x):
    n, x = _promote_args_inexact("expn", n, x)
    _c = _lax_const
    zero = _c(x, 0)
    one = _c(x, 1)
    conds = [
        (n < _c(n, 0)) | (x < zero),
        (x == zero) & (n < _c(n, 2)),
        (x == zero) & (n >= _c(n, 2)),
        (n == _c(n, 0)) & (x >= zero),
        (n >= _c(n, 5000)),
        (x > one),
    ]
    n1 = jnp.where(n == _c(n, 1), n + n, n)
    vals = [
        jnp.nan,
        jnp.inf,
        one / n1,  # prevent div by zero
        jnp.exp(-x) / x,
        partial(_expn3, n),
        partial(_expn2, n),
        partial(_expn1, n),
    ]
    ret = jnp.piecewise(x, conds, vals)
    return ret
コード例 #2
0
ファイル: special.py プロジェクト: jbampton/jax
def _expi_pos(x):
    # x > 0
    _c = _lax_const
    conds = [(_c(x, 0) < x) & (x <= _c(x, 2))] + [(_c(x, 2**i) < x) &
                                                  (x <= _c(x, 2**(i + 1)))
                                                  for i in range(1, 6)]
    return jnp.piecewise(
        x,
        conds,
        [_expint1, _expint2, _expint3, _expint4, _expint5, _expint6, _expint7],
    )
コード例 #3
0
ファイル: special.py プロジェクト: jbampton/jax
def expi(x):
    (x, ) = _promote_args_inexact("expi", x)
    ret = jnp.piecewise(x, [x < 0], [lambda x: -exp1(-x), _expi_pos])
    return ret