def expn(n, x):
  """Generalized exponential integral ``E_n(x)``.

  ``E_n(x) = integral_1^inf exp(-x*t) / t**n dt``.

  The branch layout mirrors the Cephes ``expn`` routine (as used by
  ``scipy.special.expn``). Branches are checked in the order listed below and
  the FIRST one that applies is used:

  1. ``n < 0`` or ``x < 0``:   NaN (outside the domain).
  2. ``x == 0`` and ``n < 2``:  ``+inf`` (the integral diverges).
  3. ``x == 0`` and ``n >= 2``: ``1 / (n - 1)``.
  4. ``n == 0`` and ``x >= 0``: ``exp(-x) / x``.
  5. ``n >= 5000``:             ``_expn3(n, x)``.
  6. ``x > 1``:                 ``_expn2(n, x)``.
  7. otherwise:                 ``_expn1(n, x)``.

  Args:
    n: order of the integral.
    x: argument of the integral.

  Returns:
    ``E_n(x)``.
  """
  n, x = _promote_args_inexact("expn", n, x)
  _c = _lax_const
  zero = _c(x, 0)
  one = _c(x, 1)
  prioritized = [
    (n < _c(n, 0)) | (x < zero),
    (x == zero) & (n < _c(n, 2)),
    (x == zero) & (n >= _c(n, 2)),
    (n == _c(n, 0)) & (x >= zero),
    (n >= _c(n, 5000)),
    (x > one),
  ]
  # jnp.piecewise resolves overlapping conditions in favour of the LAST true
  # entry, but the list above is written in first-match priority order and
  # several entries overlap (e.g. n < 0 with x > 1, or x == 0 with
  # n >= 5000). Mask each condition with the union of the earlier ones so
  # exactly one branch is selected per element.
  conds = []
  claimed = jnp.zeros((), dtype=bool)
  for cond in prioritized:
    conds.append(cond & ~claimed)
    claimed = claimed | cond
  # E_n(0) = 1 / (n - 1) for n >= 2. The constant below is computed for every
  # element before dispatch, so n == 1 is replaced by 2 to keep n1 - 1 away
  # from zero (that element is served by the +inf branch anyway).
  n1 = jnp.where(n == _c(n, 1), n + n, n)
  vals = [
    jnp.nan,
    jnp.inf,
    one / (n1 - one),
    jnp.exp(-x) / x,
    partial(_expn3, n),
    partial(_expn2, n),
    partial(_expn1, n),  # trailing entry: jnp.piecewise's default branch
  ]
  return jnp.piecewise(x, conds, vals)
def _expi_pos(x):
  """Positive-argument branch of ``expi``.

  The range ``(0, 64]`` is split into the half-open intervals
  ``(0, 2], (2, 4], (4, 8], (8, 16], (16, 32], (32, 64]``, each served by its
  own ``_expint*`` helper. Every other value of ``x`` falls through to the
  default helper ``_expint7``.
  """
  _c = _lax_const
  edges = (0, 2, 4, 8, 16, 32, 64)
  # One interval (lo, hi] per consecutive pair of edges.
  conds = [
    (_c(x, lo) < x) & (x <= _c(x, hi))
    for lo, hi in zip(edges[:-1], edges[1:])
  ]
  # Seven helpers for six conditions: the last one is piecewise's default.
  helpers = [
    _expint1,
    _expint2,
    _expint3,
    _expint4,
    _expint5,
    _expint6,
    _expint7,
  ]
  return jnp.piecewise(x, conds, helpers)
def expi(x):
  """Exponential integral ``Ei(x)``.

  Negative arguments use the identity ``Ei(x) = -E_1(-x)`` via ``exp1``.
  Positive arguments are delegated to ``_expi_pos``. ``Ei`` has a logarithmic
  singularity at the origin, so ``x == 0`` is mapped to ``-inf`` explicitly.
  Otherwise it would fall into ``_expi_pos``, whose intervals all start
  strictly above zero, and land in that function's default (``_expint7``)
  branch.

  Args:
    x: argument.

  Returns:
    ``Ei(x)``.
  """
  (x,) = _promote_args_inexact("expi", x)
  # The two conditions are disjoint; everything else (x > 0, NaN) goes to
  # the trailing default function.
  conds = [x < 0, x == 0]
  funcs = [lambda v: -exp1(-v), -jnp.inf, _expi_pos]
  return jnp.piecewise(x, conds, funcs)