def _maybe_bool_binop(numpy_fn, lax_fn, bool_lax_fn, lax_doc=False):
  """Build a jitted NumPy-style binary op that special-cases boolean inputs.

  Both arguments are promoted with ``_promote_args``. ``bool_lax_fn`` is then
  applied when the promoted dtype is ``np.bool_``, and ``lax_fn`` otherwise.

  Args:
    numpy_fn: the NumPy function being mirrored. Its ``__name__`` is passed to
      ``_promote_args`` and the function itself is passed to ``_wraps``.
    lax_fn: lax implementation used for non-boolean inputs.
    bool_lax_fn: lax implementation used for boolean inputs.
    lax_doc: if True, everything after the first paragraph of ``lax_fn``'s
      docstring is dedented and passed to ``_wraps`` as ``lax_description``.

  Returns:
    The jitted function, decorated with ``_wraps(numpy_fn, ...)``.
  """
  def fn(x1, x2):
    x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
    if x1.dtype == np.bool_:
      return bool_lax_fn(x1, x2)
    return lax_fn(x1, x2)

  fn = jit(fn, inline=True)
  if lax_doc:
    # ``__doc__`` is None when docstrings are stripped (``python -OO``); fall
    # back to an empty description instead of raising AttributeError while the
    # module is being imported.
    paragraphs = (lax_fn.__doc__ or '').split('\n\n')
    doc = dedent('\n\n'.join(paragraphs[1:])).strip()
    return _wraps(numpy_fn, lax_description=doc)(fn)
  return _wraps(numpy_fn)(fn)
def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False,
                      lax_doc=False):
  """Build a jitted NumPy-style binary op that forwards directly to ``lax_fn``.

  Args:
    numpy_fn: the NumPy function being mirrored. Its ``__name__`` is passed to
      the promotion helper and the function itself is passed to ``_wraps``.
    lax_fn: lax implementation applied to the promoted arguments.
    promote_to_inexact: if True, promote with ``_promote_args_inexact``;
      otherwise promote with ``_promote_args``.
    lax_doc: if True, everything after the first paragraph of ``lax_fn``'s
      docstring is dedented and passed to ``_wraps`` as ``lax_description``.

  Returns:
    The jitted function, decorated with ``_wraps(numpy_fn, ...)``.
  """
  promote = _promote_args_inexact if promote_to_inexact else _promote_args

  def fn(x1, x2):
    return lax_fn(*promote(numpy_fn.__name__, x1, x2))

  fn = jit(fn, inline=True)
  if lax_doc:
    # ``__doc__`` is None when docstrings are stripped (``python -OO``); fall
    # back to an empty description instead of raising AttributeError while the
    # module is being imported.
    paragraphs = (lax_fn.__doc__ or '').split('\n\n')
    doc = dedent('\n\n'.join(paragraphs[1:])).strip()
    return _wraps(numpy_fn, lax_description=doc)(fn)
  return _wraps(numpy_fn)(fn)
def _comparison_op(numpy_fn, lax_fn):
  """Wrap ``lax_fn`` as a NumPy-compatible elementwise comparison.

  Operands are promoted to a common dtype with ``_promote_args``. Complex
  operands are compared lexicographically on their (real, imag) pair. Where the
  real parts are equal, the result of ``lax_fn`` on the imaginary parts is
  used; elsewhere, the result of ``lax_fn`` on the real parts is used.
  """
  # TODO(https://github.com/google/jax/issues/6713): decorate this function with
  # jit, after fixing a surprising interaction with remat(..., concrete=True).
  def fn(x1, x2):
    lhs, rhs = _promote_args(numpy_fn.__name__, x1, x2)
    if not dtypes.issubdtype(dtypes.dtype(lhs), np.complexfloating):
      return lax_fn(lhs, rhs)
    # Comparisons on complex types are defined as a lexicographic ordering on
    # the (real, imag) pair.
    lhs_real, rhs_real = lax.real(lhs), lax.real(rhs)
    real_parts_equal = lax.eq(lhs_real, rhs_real)
    by_imag = lax_fn(lax.imag(lhs), lax.imag(rhs))
    by_real = lax_fn(lhs_real, rhs_real)
    return lax.select(real_parts_equal, by_imag, by_real)

  return _wraps(numpy_fn)(fn)
def logpmf(k, n, a, b, loc=0):
  """JAX implementation of scipy.stats.betabinom.logpmf."""
  k, n, a, b, loc = _promote_args_inexact("betabinom.logpmf", k, n, a, b, loc)
  # Undo the location shift; the unshifted support is y in {0, ..., n}.
  y = lax.sub(lax.floor(k), loc)
  one = _lax_const(y, 1)
  zero = _lax_const(y, 0)
  # log C(n, y) = -log(n + 1) - betaln(n - y + 1, y + 1)
  combiln = lax.neg(
      lax.add(lax.log1p(n),
              betaln(lax.add(lax.sub(n, y), one), lax.add(y, one))))
  # log B(y + a, n - y + b) - log B(a, b)
  beta_lns = lax.sub(betaln(lax.add(y, a), lax.add(lax.sub(n, y), b)),
                     betaln(a, b))
  log_probs = lax.add(combiln, beta_lns)
  # ``y`` already has ``loc`` subtracted, so compare it against [0, n].
  # Comparing against [-loc, n - loc] would subtract ``loc`` a second time and
  # mis-mask every query whenever loc != 0.
  y_cond = logical_or(lax.lt(y, zero), lax.gt(y, n))
  log_probs = where(y_cond, -inf, log_probs)
  # Parameters outside scipy's betabinom domain (n >= 0, a > 0, b > 0)
  # produce nan.
  n_a_b_cond = logical_or(logical_or(lax.lt(n, zero), lax.le(a, zero)),
                          lax.le(b, zero))
  return where(n_a_b_cond, nan, log_probs)


def pmf(k, n, a, b, loc=0):
  """JAX implementation of scipy.stats.betabinom.pmf."""
  # Resolves the module-level ``logpmf`` at call time, i.e. the wrapped
  # version when the scipy-version branch below has rebound it.
  return lax.exp(logpmf(k, n, a, b, loc))


# betabinom was added in scipy 1.4.0
if scipy_version >= (1, 4):
  logpmf = _wraps(osp_stats.betabinom.logpmf, update_doc=False)(logpmf)
  pmf = _wraps(osp_stats.betabinom.pmf, update_doc=False)(pmf)
x1, x2 = _promote_args(np.right_shift.__name__, x1, x2) lax_fn = lax.shift_right_logical if \ np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic return lax_fn(x1, x2) @_wraps(np.absolute) @partial(jit, inline=True) def absolute(x): _check_arraylike('absolute', x) dt = dtypes.dtype(x) return x if dt == np.bool_ or dtypes.issubdtype( dt, np.unsignedinteger) else lax.abs(x) abs = _wraps(np.abs)(absolute) @_wraps(np.rint) @jit def rint(x): _check_arraylike('rint', x) dtype = dtypes.dtype(x) if dtypes.issubdtype(dtype, np.integer): return lax.convert_element_type(x, dtypes.float_) if dtypes.issubdtype(dtype, np.complexfloating): return lax.complex(rint(lax.real(x)), rint(lax.imag(x))) return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) @_wraps(np.sign)