def _log_ndtr_lower(x, series_order):
    """Asymptotic expansion version of `Log[cdf(x)]`, appropriate for `x<<-1`.

    Returns ``log_scale + log(series)``. Here ``log_scale`` is the log of
    ``exp(-x**2 / 2) / (-x * sqrt(2 * pi))`` (the term multiplying
    ``1 + sum``), and ``series`` is whatever
    ``_log_ndtr_asymptotic_series(x, series_order)`` returns.
    """
    cast = lax.dtype(x).type
    half = cast(0.5)
    # log(sqrt(2 * pi)): computed by NumPy, then cast to x's scalar type.
    log_sqrt_two_pi = cast(0.5 * np.log(2. * np.pi))
    log_scale = -half * lax.square(x) - lax.log(-x) - log_sqrt_two_pi
    series = _log_ndtr_asymptotic_series(x, series_order)
    return log_scale + lax.log(series)
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute ``log(sum(b * exp(a)))`` over ``axis`` without overflow.

    Args:
      a: input array.
      axis: axis or axes to reduce over (resolved by ``_reduction_dims``).
      b: optional weights, broadcast against ``a``. Entries where ``b == 0``
        contribute nothing to the sum, even if the matching ``a`` is ``inf``.
      keepdims: if True, reduced axes are kept with size 1.
      return_sign: if True, also return the sign of the (weighted) sum.

    Returns:
      ``out``, or ``(out, sign)`` when ``return_sign`` is True. Without
      ``return_sign`` and with weights ``b``, a negative sum yields NaN.
    """
    if b is not None:
        a, b = jnp.broadcast_arrays(a, b)
        # Mask zero-weight terms to -inf so they drop out of both the max
        # used for the shift and the sum. Otherwise a=inf, b=0 produces
        # exp(inf) * 0 = nan.
        a = jnp.where(b != 0, a, -jnp.inf)
    dims = _reduction_dims(a, axis)

    def expand_reduced(t):
        # Re-insert the reduced axes as size-1 dimensions.
        return lax.expand_dims(t, dims)

    amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
    # Shift by the max for stability; a non-finite max is replaced by 0.
    # The shift cancels analytically, hence stop_gradient.
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_singletons = expand_reduced(amax)
    if b is None:
        out = lax.add(
            lax.log(
                lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
                           _constant_like(a, 0), lax.add, dims)),
            amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = lax.reduce(lax.mul(lax.exp(lax.sub(a, amax_singletons)), b),
                            _constant_like(a, 0), lax.add, dims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        if keepdims:
            return (expand_reduced(out), expand_reduced(sign))
        return (out, sign)
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    return expand_reduced(out) if keepdims else out
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute ``log(sum(b * exp(a)))`` over ``axis`` without overflow.

    Args:
      a: input array; integer input is promoted to an inexact dtype.
      axis: axis or axes to reduce over (resolved by ``_reduction_dims``).
      b: optional weights; entries where ``b == 0`` are ignored.
      keepdims: if True, reduced axes are kept with size 1.
      return_sign: if True, also return the sign of the (weighted) sum.

    Returns:
      ``out``, or ``(out, sign)`` when ``return_sign`` is True. Without
      ``return_sign`` and with weights ``b``, a negative sum yields NaN.
    """
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        # Zero-weight terms must not contribute (avoids exp(inf) * 0 = nan).
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        # Promote here too: lax.is_finite / lax.exp reject integer operands.
        a, = _promote_args_inexact("logsumexp", a)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    # Shift by the max for stability; a non-finite max is replaced by 0.
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    if b is None:
        out = lax.add(
            lax.log(jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                            axis=dims, keepdims=keepdims)),
            amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = jnp.sum(lax.mul(lax.exp(lax.sub(a, amax_with_dims)), b),
                         axis=dims, keepdims=keepdims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    return out
def log1m_exp(val):
    """Numerically stable implementation of `log(1 - exp(val))`.

    Meaningful for ``val <= 0``; positive inputs give NaN. Works elementwise
    on scalars and arrays.

    The formula is chosen by comparing ``val`` with ``-log(2)`` (Maechler, 2012):

    * ``val > -log(2)``: ``exp(val)`` is close to 1, so ``1 - exp(val)``
      cancels catastrophically; use ``log(-expm1(val))``.
    * otherwise: ``exp(val) <= 1/2``, and ``log1p(-exp(val))`` is accurate.
    """
    import jax.numpy as jnp  # local import keeps this helper self-contained

    val = jnp.asarray(val)
    # Both branches are finite for val < 0, so evaluating both under `where`
    # introduces no NaNs there (in the value or its gradient).
    return jnp.where(
        val > -jnp.log(2.0),
        jnp.log(-jnp.expm1(val)),
        jnp.log1p(-jnp.exp(val)),
    )
def logpdf(x, a, loc=0, scale=1):
    """Log-density of the gamma distribution with shape ``a``.

    With ``y = (x - loc) / scale`` this evaluates
    ``(a - 1) * log(y) - y - gammaln(a) - log(scale)``, and returns ``-inf``
    for ``x < loc``.

    The ``(a - 1) * log(y)`` term uses ``xlogy``, so it is 0 when ``a == 1``
    and ``y == 0``. A plain product gives ``0 * -inf = nan`` there, e.g. for
    the exponential density at ``x == loc``.
    """
    from jax.scipy.special import xlogy  # xlogy(0, 0) == 0

    x, a, loc, scale = _promote_args_inexact("gamma.logpdf", x, a, loc, scale)
    one = _constant_like(x, 1)
    y = lax.div(lax.sub(x, loc), scale)
    log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y)
    shape_terms = lax.add(gammaln(a), lax.log(scale))
    log_probs = lax.sub(log_linear_term, shape_terms)
    return where(lax.lt(x, loc), -inf, log_probs)
def logpdf(x, b, loc=0, scale=1):
    """Log-density of the Pareto distribution with shape ``b``.

    With ``z = (x - loc) / scale`` this evaluates
    ``-(log(scale / b) + (b + 1) * log(z))``. It returns ``-inf`` wherever
    ``x < loc + scale``, i.e. outside the support.
    """
    x, b, loc, scale = _promote_args_inexact("pareto.logpdf", x, b, loc, scale)
    b_plus_one = lax.add(b, _constant_like(x, 1))
    z = lax.div(lax.sub(x, loc), scale)
    log_norm = lax.log(lax.div(scale, b))
    value = lax.neg(lax.add(log_norm, lax.mul(b_plus_one, lax.log(z))))
    outside_support = lax.lt(x, lax.add(loc, scale))
    return where(outside_support, -inf, value)
def logpdf(x, a, b, loc=0, scale=1):
    """Log-density of the beta distribution with shapes ``a`` and ``b``.

    With ``y = (x - loc) / scale`` this evaluates
    ``(a - 1) * log(y) + (b - 1) * log1p(-y) - betaln(a, b) - log(scale)``.
    It returns ``-inf`` outside ``[loc, loc + scale]``.

    The two power terms use ``xlogy`` / ``xlog1py``, so they are 0 when the
    exponent is 0 at a support edge (``a == 1`` at ``y == 0``, ``b == 1`` at
    ``y == 1``). A plain product gives ``0 * -inf = nan`` there.
    """
    from jax.scipy.special import xlog1py, xlogy  # 0 * log(0) -> 0

    x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc, scale)
    one = _constant_like(x, 1)
    shape_term = lax.neg(betaln(a, b))
    y = lax.div(lax.sub(x, loc), scale)
    log_linear_term = lax.add(xlogy(lax.sub(a, one), y),
                              xlog1py(lax.sub(b, one), lax.neg(y)))
    log_probs = lax.sub(lax.add(shape_term, log_linear_term), lax.log(scale))
    outside_support = logical_or(lax.gt(x, lax.add(loc, scale)), lax.lt(x, loc))
    return where(outside_support, -inf, log_probs)
def logpdf(x, df, loc=0, scale=1):
    """Log-density of the chi-squared distribution with ``df`` degrees of freedom.

    With ``y = (x - loc) / scale`` this evaluates
    ``(df/2 - 1) * log(y) - y/2 - lgamma(df/2) - log(2) * df/2 - log(scale)``.
    It returns ``-inf`` for ``x < loc``.

    The ``(df/2 - 1) * log(y)`` term uses ``xlogy``, so it is 0 for
    ``df == 2`` at ``y == 0``. A plain product gives ``0 * -inf = nan``.
    """
    from jax.scipy.special import xlogy  # xlogy(0, 0) == 0

    x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale)
    one = _constant_like(x, 1)
    two = _constant_like(x, 2)
    y = lax.div(lax.sub(x, loc), scale)
    df_on_two = lax.div(df, two)
    kernel = lax.sub(xlogy(lax.sub(df_on_two, one), y), lax.div(y, two))
    nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),
                                lax.div(lax.mul(lax.log(two), df), two)))
    log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
    return where(lax.lt(x, loc), -inf, log_probs)
def logpmf(k, p, loc=0):
    """Log-PMF of the geometric distribution.

    With ``n = k - loc`` this is ``(n - 1) * log1p(-p) + log(p)`` for
    ``n > 0``, and ``-inf`` for ``n <= 0``.
    """
    k, p, loc = jnp._promote_args_inexact("geom.logpmf", k, p, loc)
    n = lax.sub(k, loc)
    failures = lax.sub(n, lax._const(k, 1))
    log_probs = xlog1py(failures, -p) + lax.log(p)
    below_support = lax.le(n, lax._const(k, 0))
    return jnp.where(below_support, -jnp.inf, log_probs)
def logpdf(x, loc=0, scale=1):
    """Log-density of the normal distribution with mean ``loc`` and std ``scale``.

    Evaluates ``-(log(2 * pi * scale**2) + (x - loc)**2 / scale**2) / 2``.
    """
    x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale)
    two = _constant_like(x, 2)
    variance = lax.pow(scale, two)
    two_pi = _constant_like(x, 2 * np.pi)
    log_normalizer = lax.log(lax.mul(two_pi, variance))
    squared_deviation = lax.pow(lax.sub(x, loc), two)
    quadratic = lax.div(squared_deviation, variance)
    return lax.div(lax.neg(lax.add(log_normalizer, quadratic)), two)
def tridiagonal_pos_def_log_det(a, b):
    """Compute the log-determinant of a tridiagonal positive-definite matrix.

    The matrix has main diagonal `b` (all positive) and lower/upper diagonal
    `a`, i.e. the result equals
    np.linalg.slogdet(diag(a, -1) + diag(b) + diag(a, 1))[1].

    The leading-minor continuant recurrence
    ``f[i+1] = b[i+1] * f[i] - a[i]**2 * f[i-1]`` (``f[-1] = 1``,
    ``f[0] = b[0]``) is run in log space with ``lax.scan``. The subtraction
    goes through ``log_diff_exp``, presumably ``log(exp(u) - exp(v))``.

    Args:
      a (array): lower/upper diagonal of matrix, shape `(dim - 1,)`.
      b (array): main diagonal of matrix, shape `(dim,)`, all positive.

    Returns:
      Scalar corresponding to log-determinant.
    """
    def step(carry, off_and_next_diag):
        log_f_curr, log_f_prev = carry
        off_diag, diag_next = off_and_next_diag
        log_f_next = log_diff_exp(
            lax.log(diag_next) + log_f_curr,
            2 * lax.log(abs(off_diag)) + log_f_prev)
        return (log_f_next, log_f_curr), None

    # log f[0] = log b[0]; log f[-1] = log 1 = 0.
    init = (lax.log(b[0]), 0)
    (log_det, _), _ = lax.scan(step, init, (a, b[1:]))
    return log_det
def logaddexp(x1, x2):
    """Elementwise ``log(exp(x1) + exp(x2))`` computed without overflow.

    Uses ``max(x1, x2) + log1p(exp(-|x1 - x2|))``. When ``x1 - x2`` is NaN
    (both inputs are the same infinity, or either is NaN), the result is
    ``x1 + x2`` instead. That gives ``-inf`` for two ``-inf``, ``inf`` for
    two ``inf``, and NaN for NaN input. The previous ``x - amax`` form
    produced NaN for the two-infinity cases.
    """
    x1, x2 = _promote_to_result_dtype(onp.logaddexp, *_promote_shapes(x1, x2))
    amax = lax.max(x1, x2)
    delta = lax.sub(x1, x2)
    is_nan_delta = lax.ne(delta, delta)  # NaN is the only value != itself
    regular = lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta)))))
    return lax.select(is_nan_delta, lax.add(x1, x2), regular)
def logpdf(x, loc=0, scale=1):
    """Log-density of the Cauchy distribution.

    With ``z = (x - loc) / scale`` this is ``-(log(pi * scale) + log1p(z * z))``.
    """
    x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)
    log_pi_scale = lax.log(lax.mul(_constant_like(x, np.pi), scale))
    log_kernel = lax.log1p(lax.mul(z, z))
    return lax.neg(lax.add(log_pi_scale, log_kernel))
def _log_taylor(primals_in, series_in):
    """Taylor-mode (jet) rule for ``lax.log``.

    ``v[0] = log(x)``. Each higher coefficient comes from the recurrence
    ``v[k] = (u[k] - fact(k - 1) * sum_{j=1}^{k-1} _scale(k, j) v[j] u[k-j]) / u[0]``,
    where ``u`` is the input series with ``u[0] = x``.

    Returns:
      ``(primal_out, series_out)`` with ``series_out`` a list.
    """
    (x,) = primals_in
    (series,) = series_in
    u = [x] + series
    v = [lax.log(x)]
    for k in range(1, len(u)):
        conv = sum([_scale(k, j) * v[j] * u[k - j] for j in range(1, k)])
        v.append((u[k] - fact(k - 1) * conv) / u[0])
    return v[0], v[1:]
def multigammaln(a, d):
    """Log of the multivariate gamma function of dimension ``d``.

    Computes ``d * (d - 1) / 4 * log(pi) + sum_{j=0}^{d-1} gammaln(a - j / 2)``.
    The sum runs over a new trailing axis.

    Args:
      a: array of arguments.
      d: dimension; must be a concrete Python integer (not traced).
    """
    d = core.concrete_or_error(int, d, "d argument of multigammaln")
    a, d_ = _promote_args_inexact("multigammaln", a, d)
    constant = lax.mul(
        lax.mul(lax.mul(_constant_like(a, 0.25), d_),
                lax.sub(d_, _constant_like(a, 1))),
        lax.log(_constant_like(a, np.pi)))
    # Use the concrete int `d` for arange: the promoted `d_` is an array
    # (a tracer under jit), and jnp.arange needs a concrete bound.
    offsets = lax.div(jnp.arange(d, dtype=d_.dtype), _constant_like(a, 2))
    res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) - offsets), axis=-1)
    return res + constant
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute ``log(sum(b * exp(a)))`` over ``axis`` without overflow.

    Args:
      a: input array (real or complex); promoted to an inexact dtype.
      axis: axis or axes to reduce over (resolved by ``_reduction_dims``).
      b: optional weights; entries where ``b == 0`` are ignored.
      keepdims: if True, reduced axes are kept with size 1.
      return_sign: if True, also return the sign (phase, for complex input)
        of the summed term.

    Returns:
      ``out``, or ``(out, sign)`` when ``return_sign`` is True. For real input
      with weights and no ``return_sign``, a negative sum yields NaN.
    """
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        a, = _promote_args_inexact("logsumexp", a)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    amax = lax.stop_gradient(
        lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    # fast path if the result cannot be negative.
    if b is None and not np.issubdtype(a.dtype, np.complexfloating):
        out = lax.add(
            lax.log(jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                            axis=dims, keepdims=keepdims)),
            amax)
        sign = jnp.where(jnp.isnan(out), out, 1.0)
        sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
    else:
        expsub = lax.exp(lax.sub(a, amax_with_dims))
        if b is not None:
            expsub = lax.mul(expsub, b)
        sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)
        sign = lax.stop_gradient(jnp.sign(sumexp))
        if np.issubdtype(sumexp.dtype, np.complexfloating):
            if return_sign:
                # Divide out the phase so exp(out) * sign == sumexp * exp(amax).
                # When jnp.sign returns z/|z|, conj(sign) * z == |z|; the old
                # `sign * sumexp` squared the phase instead. Under a real +-1
                # sign convention conj(sign) == sign, and sign == 0 stays 0
                # (log -> -inf) rather than 0/0.
                sumexp = jnp.conj(sign) * sumexp
            out = lax.add(lax.log(sumexp), amax)
        else:
            out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        if not np.issubdtype(out.dtype, np.complexfloating):
            with jax.debug_nans(False):
                out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
    return out
def _pow_taylor(primals_in, series_in):
    """Taylor-mode (jet) rule for ``u ** r``.

    Writes ``u ** r = exp(w)`` with ``w = r * log(u)``. A nested ``jet`` gives
    the series of ``w``. The series of ``exp(w)`` then follows from
    ``out[k] = fact(k - 1) * sum_{j=1}^{k} _scale(k, j) out[k-j] w[j]``, seeded
    with the primal ``u ** r``.
    """
    base, exponent = primals_in
    w0, w_series = jet(lambda x, y: lax.mul(y, lax.log(x)), primals_in, series_in)
    w = [w0] + w_series
    out = [base ** exponent]
    for k in range(1, len(w)):
        acc = sum([_scale(k, j) * out[k - j] * w[j] for j in range(1, k + 1)])
        out.append(fact(k - 1) * acc)
    return out[0], out[1:]
def logpdf(x, df, loc=0, scale=1):
    """Log-density of Student's t distribution with ``df`` degrees of freedom.

    With ``z = (x - loc) / scale`` and ``h = df / 2`` this evaluates:

    - ``normalize = lgamma(h) + log(scale**2 * pi * df) / 2 - lgamma(h + 1/2)``
    - result ``= -(normalize + (h + 1/2) * log1p(z**2 / df))``
    """
    x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale)
    two = _lax_const(x, 2)
    half_df = lax.div(df, two)
    half_df_plus_half = lax.add(half_df, _lax_const(x, 0.5))
    z = lax.div(lax.sub(x, loc), scale)
    scale_sq_pi = lax.mul(lax.mul(scale, scale), _lax_const(x, np.pi))
    half_log_term = lax.div(lax.log(lax.mul(scale_sq_pi, df)), two)
    normalize_term = lax.sub(lax.add(lax.lgamma(half_df), half_log_term),
                             lax.lgamma(half_df_plus_half))
    quadratic = lax.div(lax.mul(z, z), df)
    tail = lax.mul(half_df_plus_half, lax.log1p(quadratic))
    return lax.neg(lax.add(normalize_term, tail))
def multigammaln(a, d):
    """Log of the multivariate gamma function of dimension ``d``.

    Computes ``d * (d - 1) / 4 * log(pi) + sum_{j=0}^{d-1} gammaln(a - j / 2)``.
    The sum runs over a new trailing axis.

    Args:
      a: array of arguments; promoted to an inexact dtype.
      d: dimension; ``jnp.arange(d)`` requires it to be concrete (normally a
        Python int).
    """
    a, = _promote_args_inexact("multigammaln", a)
    d_ = lax.convert_element_type(d, lax.dtype(a))
    constant = lax.mul(
        lax.mul(lax.mul(_constant_like(a, 0.25), d_),
                lax.sub(d_, _constant_like(a, 1))),
        lax.log(_constant_like(a, np.pi)))
    # Build the offsets from the caller's `d`, not the converted `d_`: under
    # jit `d_` is a traced value and jnp.arange needs a concrete bound.
    offsets = lax.div(jnp.arange(d, dtype=lax.dtype(a)), _constant_like(a, 2))
    res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) - offsets), axis=-1)
    return res + constant
def multigammaln(a, d):
    """Log of the multivariate gamma function of dimension ``d``.

    Evaluates ``d * (d - 1) / 4 * log(pi) + sum_{j=0}^{d-1} gammaln(a - j / 2)``.
    The half-integer offsets are laid out on a new trailing axis, which is
    then summed away.

    Args:
      a: array of arguments.
      d: dimension; must be concrete (checked via ``core.concrete_or_error``).
    """
    d = core.concrete_or_error(int, d, "d argument of multigammaln")
    a, d_ = _promote_args_inexact("multigammaln", a, d)
    quarter_d_dm1 = lax.mul(lax.mul(_lax_const(a, 0.25), d_),
                            lax.sub(d_, _lax_const(a, 1)))
    constant = lax.mul(quarter_d_dm1, lax.log(_lax_const(a, np.pi)))
    half_steps = lax.div(jnp.arange(d, dtype=d_.dtype), _lax_const(a, 2))
    lead_axes = tuple(range(a.ndim))
    shifted = (jnp.expand_dims(a, axis=-1)
               - jnp.expand_dims(half_steps, axis=lead_axes))
    return jnp.sum(gammaln(shifted), axis=-1) + constant
def log_I1(orders: int, value, terms=250):
    r"""Compute log modified Bessel functions of the first kind for orders 0..``orders``.

    Uses the power series, truncated after ``terms`` summands:

    .. math ::
        \log(I_v(z)) = v \log(z/2) + \log\left(\sum_{k=0}^{\mathrm{terms}-1}
            \exp\left[2k\log(z/2) - \log(k!) - \log\Gamma(v + k + 1)\right]\right)

    :param orders: highest order to compute; orders ``0, 1, ..., orders`` are
        returned.
    :param value: points ``z`` to evaluate at; a scalar or an array (anything
        accepted by ``jnp.asarray``).
    :param terms: truncation of the summation (number of series terms kept).
    :return: ``log I_v(value)`` for ``v = 0..orders``, with shape
        ``(orders + 1, *value.shape)``; a scalar ``value`` is treated as shape
        ``(1,)``.
    :raises ValueError: if ``terms < orders + 1``. The ``v * log(z/2)`` factors
        are read from the first ``orders + 1`` series columns.
    """
    value = jnp.asarray(value)  # also accept Python scalars / array-likes
    orders = orders + 1  # now the number of orders: v = 0 .. orders - 1
    if orders > terms:
        raise ValueError(
            f"terms ({terms}) must be at least orders + 1 ({orders})")
    if value.ndim == 0:
        vshape = jnp.shape([1])
    else:
        vshape = value.shape
    value = value.reshape(-1, 1)
    flat_vshape = _numel(vshape)

    k = jnp.arange(terms)
    # lgamma(n + 1) = log(n!) for n = 0 .. terms + orders - 1
    # (starting at 1 avoids lgamma(0) = inf).
    lgammas_all = lax.lgamma(jnp.arange(1.0, terms + orders + 1))
    assert lgammas_all.shape == (orders + terms,)

    # k * log(z / 2) for every value (rows) and series index k (columns).
    lvalues = lax.log(value / 2) * k.reshape(1, -1)
    assert lvalues.shape == (flat_vshape, terms)

    lfactorials = lgammas_all[:terms]  # log(k!)
    assert lfactorials.shape == (terms,)

    lgammas = lgammas_all.tile(orders).reshape((orders, -1))
    assert lgammas.shape == (orders, terms + orders)

    # indices[v, k] = v + k, so the gather picks log((v + k)!) = lgamma(v + k + 1).
    indices = k[:orders].reshape(-1, 1) + k.reshape(1, -1)
    assert indices.shape == (orders, terms)

    seqs = logsumexp(
        2 * lvalues[None, :, :]
        - lfactorials[None, None, :]
        - jnp.take_along_axis(lgammas, indices, axis=1)[:, None, :],
        -1,
    )
    assert seqs.shape == (orders, flat_vshape)

    # Add v * log(z / 2), which is column v of lvalues.
    i1s = lvalues[..., :orders].T + seqs
    assert i1s.shape == (orders, flat_vshape)
    return i1s.reshape(-1, *vshape)
def logpdf(x, p):
    """Log-density of the generalized normal distribution with shape ``p``.

    Evaluates ``log(p / 2) - lgamma(1 / p) - |x| ** p``.
    """
    x, p = _promote_args_inexact("gennorm.logpdf", x, p)
    log_norm = lax.log(.5 * p) - lax.lgamma(1 / p)
    return log_norm - lax.abs(x) ** p
def log_ndtr(x, series_order=3):
    r"""Log Normal distribution function.

    For details of the Normal distribution function see `ndtr`.

    This function calculates :math:`\log(\mathrm{ndtr}(x))` in one of three
    ways, depending on `x`:

    - For `x > upper_segment`, use the approximation `-ndtr(-x)`, based on
      :math:`\log(1-t) \approx -t` for :math:`t \ll 1`.
    - For `lower_segment < x <= upper_segment`, evaluate `ndtr` and take a log.
    - For `x <= lower_segment`, use the asymptotic series below to compute the
      log CDF directly.

    The segment boundaries depend on the precision of the input:

    .. math::
      \begin{align}
      \mathit{lower\_segment} =& \ \begin{cases}
        -20 & x.\mathrm{dtype}=\mathit{float64} \\
        -10 & x.\mathrm{dtype}=\mathit{float32} \\
        \end{cases} \\
      \mathit{upper\_segment} =& \ \begin{cases}
        8&  x.\mathrm{dtype}=\mathit{float64} \\
        5&  x.\mathrm{dtype}=\mathit{float32} \\
        \end{cases}
      \end{align}

    For `x <= lower_segment`, the `ndtr` asymptotic series is:

    .. math::
      \begin{align}
      \mathrm{ndtr}(x) =&\  \mathit{scale} * (1 + \mathit{sum}) + R_N \\
      \mathit{scale}   =&\  \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\
      \mathit{sum}     =&\  \sum_{n=1}^N (-1)^n (2n-1)!! / (x^2)^n \\
      R_N     =&\  O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3})
      \end{align}

    where :math:`(2n-1)!! = (2n-1) (2n-3) (2n-5) ...  (3) (1)` is a
    `double-factorial
    <https://en.wikipedia.org/wiki/Double_factorial>`_ operator.

    Args:
      x: an array of type `float32`, `float64`.
      series_order: Positive Python integer. Maximum depth to
        evaluate the asymptotic expansion. This is the `N` above.

    Returns:
      an array with `dtype=x.dtype`.

    Raises:
      TypeError: if `x.dtype` is not float32 or float64.
      TypeError: if `series_order` is not a Python integer.
      ValueError: if `series_order` is not in `[0, 30]`.
    """
    if not isinstance(series_order, int):
        raise TypeError("series_order must be a Python integer.")
    if series_order < 0:
        raise ValueError("series_order must be non-negative.")
    if series_order > 30:
        raise ValueError("series_order must be <= 30.")

    x = jnp.asarray(x)
    dtype = lax.dtype(x)
    segments_by_dtype = {
        np.dtype(np.float64): (_LOGNDTR_FLOAT64_LOWER, _LOGNDTR_FLOAT64_UPPER),
        np.dtype(np.float32): (_LOGNDTR_FLOAT32_LOWER, _LOGNDTR_FLOAT32_UPPER),
    }
    segments = segments_by_dtype.get(np.dtype(dtype))
    if segments is None:
        raise TypeError("x.dtype={} is not supported.".format(np.dtype(dtype)))
    lower_segment, upper_segment = segments

    # Idea ported from
    # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html, with
    # these changes:
    #  * For x >> 1 and X ~ Normal(0, 1),
    #    Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x],
    #    which extends the range of validity of this function.
    #  * One fixed series_order is used for all of `x` rather than an adaptive
    #    one.
    #  * Each branch is fed a clamped input (max/min against lower_segment) so
    #    the unchosen branch stays finite. Otherwise the gradient of the select
    #    can hit 1*dy + 0*(-inf) = nan even though that branch is not selected.
    #    The clamp is a no-op wherever its branch is the one chosen.
    above_upper = lax.gt(x, upper_segment)
    above_lower = lax.gt(x, lower_segment)
    upper_tail = -_ndtr(-x)  # log(1-x) ~= -x, x << 1
    middle = lax.log(_ndtr(lax.max(x, lower_segment)))
    lower_tail = _log_ndtr_lower(lax.min(x, lower_segment), series_order)
    return jnp.where(above_upper, upper_tail,
                     jnp.where(above_lower, middle, lower_tail))
def xlogy(x, y):
    """Compute ``x * log(y)``, defined as 0 wherever ``x == 0``.

    At ``x == 0``, both operands are replaced by 1 before the log is taken.
    The product computed there is therefore finite, and the outer ``where``
    selects 0. This is the "double where" pattern, so the masked positions
    never evaluate ``0 * log(0)``.
    """
    x, y = _promote_args_inexact("xlogy", x, y)
    x_is_zero = x == 0.
    x_safe = jnp.where(x_is_zero, 1., x)
    y_safe = jnp.where(x_is_zero, 1., y)
    product = lax.mul(x_safe, lax.log(y_safe))
    return jnp.where(x_is_zero, jnp.zeros_like(x), product)
def xlogy_jvp_lhs(g, x, y, jaxpr, aval, consts):
    """JVP of ``xlogy(x, y)`` with respect to ``x``: the tangent ``g * log(y)``.

    ``jaxpr``, ``aval`` and ``consts`` are accepted for the caller's calling
    convention but not used here. ``x`` is only used to promote ``y``. The
    product goes through ``lax._safe_mul``, presumably so that a zero tangent
    times ``log(0) = -inf`` stays 0 rather than NaN (confirm).
    """
    x, y = _promote_args_like(osp_special.xlogy, x, y)
    g, y = _promote_args_like(osp_special.xlogy, g, y)
    log_y = lax.log(y)
    return lax._safe_mul(lax._brcast(g, y), lax._brcast(log_y, g))
# d/dx erf(x) = 2 / sqrt(pi) * exp(-x**2).
def_deriv(
    lax.erf_p,
    lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)),
                      lax.exp(lax.neg(lax.square(x)))))


def def_comp(prim, comp):
    """
    Define the jet rule for a primitive in terms of a composition of simpler
    primitives.

    Registers ``partial(jet, comp)`` as the rule for ``prim``, so both the
    primal output and its Taylor series come from running ``comp`` under
    ``jet``.
    """
    jet_rules[prim] = partial(jet, comp)


# NOTE(review): because the primal output is also computed by `comp`, the
# expm1/log1p compositions below lose precision near 0 compared with the
# dedicated primitives. Confirm this is acceptable for jet users.
def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x**0.5)
def_comp(lax.rsqrt_p, lambda x: x**-0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))
# NOTE(review): x - y * floor(x / y) is the floored (Python-style) modulo.
# If lax.rem follows C fmod semantics (result takes the sign of x), the primal
# returned by this rule differs from lax.rem whenever x / y < 0 -- confirm.
def_comp(lax.rem_p, lambda x, y: x - y * lax.floor(x / y))
# Presumably lax.clamp's argument order is (min, operand, max) -- confirm.
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))


def _erf_inv_rule(primals_in, series_in):
    # NOTE(review): as visible here, this rule stops after building `u` and
    # implicitly returns None; the remainder of the rule appears to be missing.
    x, = primals_in
    series, = series_in
    u = [x] + series
def _ndtri(p):
    """Implements ndtri core logic.

    Returns x such that ``ndtr(x) == p``, using the piecewise rational
    approximations from Cephes (coefficient tables below). Values with
    ``p <= 0`` map to ``-inf`` and values with ``p >= 1`` map to ``+inf``.
    """
    # Constants used in piece-wise rational approximations. Taken from the cephes
    # library:
    # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    # The lists are reversed so that _create_polynomial (Horner, lowest-order
    # coefficient first) sees them in ascending order.
    p0 = list(
        reversed([
            -5.99633501014107895267E1, 9.80010754185999661536E1,
            -5.66762857469070293439E1, 1.39312609387279679503E1,
            -1.23916583867381258016E0
        ]))
    q0 = list(
        reversed([
            1.0, 1.95448858338141759834E0, 4.67627912898881538453E0,
            8.63602421390890590575E1, -2.25462687854119370527E2,
            2.00260212380060660359E2, -8.20372256168333339912E1,
            1.59056225126211695515E1, -1.18331621121330003142E0
        ]))
    p1 = list(
        reversed([
            4.05544892305962419923E0, 3.15251094599893866154E1,
            5.71628192246421288162E1, 4.40805073893200834700E1,
            1.46849561928858024014E1, 2.18663306850790267539E0,
            -1.40256079171354495875E-1, -3.50424626827848203418E-2,
            -8.57456785154685413611E-4
        ]))
    q1 = list(
        reversed([
            1.0, 1.57799883256466749731E1, 4.53907635128879210584E1,
            4.13172038254672030440E1, 1.50425385692907503408E1,
            2.50464946208309415979E0, -1.42182922854787788574E-1,
            -3.80806407691578277194E-2, -9.33259480895457427372E-4
        ]))
    p2 = list(
        reversed([
            3.23774891776946035970E0, 6.91522889068984211695E0,
            3.93881025292474443415E0, 1.33303460815807542389E0,
            2.01485389549179081538E-1, 1.23716634817820021358E-2,
            3.01581553508235416007E-4, 2.65806974686737550832E-6,
            6.23974539184983293730E-9
        ]))
    q2 = list(
        reversed([
            1.0, 6.02427039364742014255E0, 3.67983563856160859403E0,
            1.37702099489081330271E0, 2.16236993594496635890E-1,
            1.34204006088543189037E-2, 3.28014464682127739104E-4,
            2.89247864745380683936E-6, 6.79019408009981274425E-9
        ]))

    dtype = lax.dtype(p).type
    shape = jnp.shape(p)

    def _create_polynomial(var, coeffs):
        """Compute n_th order polynomial via Horner's method."""
        coeffs = np.array(coeffs, dtype)
        if not coeffs.size:
            return jnp.zeros_like(var)
        return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var

    # Fold the upper tail onto the lower one: for p > 1 - exp(-2) work with
    # 1 - p. The sign is restored further down.
    maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.)), dtype(1.) - p, p)
    # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
    # later on. The result from the computation when p == 0 is not used so any
    # number that doesn't result in NaNs is fine.
    sanitized_mcp = jnp.where(
        maybe_complement_p <= dtype(0.),
        jnp.full(shape, dtype(0.5)),
        maybe_complement_p)

    # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
    w = sanitized_mcp - dtype(0.5)
    ww = lax.square(w)
    x_for_big_p = w + w * ww * (_create_polynomial(ww, p0)
                                / _create_polynomial(ww, q0))
    x_for_big_p *= -dtype(np.sqrt(2. * np.pi))

    # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
    # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
    # arrays based on whether p < exp(-32).
    z = lax.sqrt(dtype(-2.) * lax.log(sanitized_mcp))
    first_term = z - lax.log(z) / z
    second_term_small_p = (
        _create_polynomial(dtype(1.) / z, p2)
        / _create_polynomial(dtype(1.) / z, q2) / z)
    second_term_otherwise = (
        _create_polynomial(dtype(1.) / z, p1)
        / _create_polynomial(dtype(1.) / z, q1) / z)
    x_for_small_p = first_term - second_term_small_p
    x_otherwise = first_term - second_term_otherwise

    # z >= 8 is equivalent to sanitized_mcp <= exp(-32) (since z**2 = -2 log p).
    x = jnp.where(sanitized_mcp > dtype(np.exp(-2.)),
                  x_for_big_p,
                  jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise))

    # Undo the sign convention of the folded computation above.
    x = jnp.where(p > dtype(1. - np.exp(-2.)), x, -x)
    infinity = jnp.full(shape, dtype(np.inf))
    # Saturate the endpoints: p <= 0 -> -inf, p >= 1 -> +inf.
    x_nan_replaced = jnp.where(
        p <= dtype(0.0), -infinity, jnp.where(p >= dtype(1.0), infinity, x))
    return x_nan_replaced
def exp2(x):
    """Elementwise ``2 ** x``, computed as ``exp(log(2) * x)``."""
    x, = _promote_args_inexact("exp2", x)
    ln2 = lax.log(_constant_like(x, 2))
    return lax.exp(lax.mul(ln2, x))
def logit(x):
    """Inverse of the logistic sigmoid: ``log(x / (1 - x))``, elementwise.

    Boolean and integer inputs are first converted to the default floating
    dtype. Floating and complex inputs keep their dtype.
    """
    import jax.numpy as jnp  # local import keeps the block self-contained

    x = jnp.asarray(x)
    if not jnp.issubdtype(x.dtype, jnp.inexact):
        # lax.div truncates on integers and lax.log rejects integer input, so
        # non-floating input must be promoted before the arithmetic.
        x = x.astype(jnp.result_type(float))
    one = jnp.ones_like(x)
    return lax.log(lax.div(x, lax.sub(one, x)))
def log10(x):
    """Elementwise base-10 logarithm, computed as ``log(x) / log(10)``."""
    x, = _promote_args_inexact("log10", x)
    ln10 = lax.log(_constant_like(x, 10))
    return lax.div(lax.log(x), ln10)