def logpmf(k, p, loc=0):
  """Log of the Bernoulli probability mass function.

  Evaluates ``k' * log(p) + (1 - k') * log(1 - p)`` at the shifted value
  ``k' = k - loc`` and returns ``-inf`` wherever ``k'`` lies outside ``[0, 1]``.

  NOTE: there is no integrality check on ``k'``. A non-integer value strictly
  inside ``(0, 1)`` yields a finite result rather than ``-inf``.

  Args:
    k: value(s) at which to evaluate the log-PMF.
    p: success probability.
    loc: shift applied to ``k`` before evaluation.

  Returns:
    The log-PMF, broadcast over the promoted inputs.
  """
  k, p, loc = jnp._promote_args_inexact("bernoulli.logpmf", k, p, loc)
  shifted = lax.sub(k, loc)
  outside_support = jnp.logical_or(lax.lt(shifted, lax._const(k, 0)),
                                   lax.gt(shifted, lax._const(k, 1)))
  # xlogy / xlog1py are used so that the 0 * log(0) terms at the endpoints
  # do not produce NaN (assumes the usual scipy.special conventions).
  log_mass = xlogy(shifted, p) + xlog1py(lax.sub(lax._const(k, 1), shifted), -p)
  return jnp.where(outside_support, -jnp.inf, log_mass)
def log1m_exp(val):
  """Numerically stable implementation of `log(1 - exp(val))`.

  The result is only real-valued for ``val <= 0``. Two formulas are used
  (Maechler, "Accurately Computing log(1 - exp(-|a|))"):

  * ``-log(2) < val <= 0``: ``exp(val)`` is close to 1, so ``1 - exp(val)``
    cancels catastrophically. ``log(-expm1(val))`` is used instead.
  * ``val <= -log(2)``: ``exp(val) <= 1/2``, so ``log1p(-exp(val))`` is
    accurate.

  The previous threshold was ``+log(2)``. That sent every valid input to the
  ``log1p`` branch, so for example ``val = -1e-10`` in float32 gave ``-inf``.

  Args:
    val: a scalar. ``lax.cond`` needs a scalar predicate.

  Returns:
    ``log(1 - exp(val))``.
  """
  return lax.cond(
      # log(0.5) == -log(2): switch point between the two formulas.
      lax.gt(val, lax.log(0.5)),
      lambda _: lax.log(-lax.expm1(val)),
      lambda _: lax.log1p(-lax.exp(val)),
      operand=None,
  )
def logpdf(x, a, b, loc=0, scale=1):
  """Log-density of the (shifted, scaled) beta distribution.

  With ``y = (x - loc) / scale`` this evaluates
  ``(a - 1) * log(y) + (b - 1) * log(1 - y) - betaln(a, b) - log(scale)``
  and returns ``-inf`` for ``x`` outside ``[loc, loc + scale]``.

  Args:
    x: points at which to evaluate the log-density.
    a, b: shape parameters.
    loc: lower end of the support.
    scale: width of the support.

  Returns:
    The log-density, broadcast over the promoted inputs.
  """
  x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc, scale)
  unit = lax._const(x, 1)
  standardized = lax.div(lax.sub(x, loc), scale)
  # xlogy / xlog1py presumably follow the scipy.special convention
  # (0 * log(0) == 0), which keeps a == 1 or b == 1 finite at the endpoints.
  kernel = lax.add(xlogy(lax.sub(a, unit), standardized),
                   xlog1py(lax.sub(b, unit), lax.neg(standardized)))
  unnormalized = lax.add(kernel, lax.neg(betaln(a, b)))
  density = lax.sub(unnormalized, lax.log(scale))
  outside_support = logical_or(lax.lt(x, loc), lax.gt(x, lax.add(loc, scale)))
  return where(outside_support, -inf, density)
def _ndtr(x):
  """Implements ndtr core logic: the standard normal CDF ``0.5 * erfc(-x / sqrt(2))``.

  The formula is chosen per element to avoid cancellation:

  * ``|x / sqrt(2)| < 1 / sqrt(2)``: ``1 + erf(t)``.
  * ``t > 0`` otherwise: ``2 - erfc(|t|)``.
  * ``t < 0`` otherwise: ``erfc(|t|)``.
  """
  scalar = lax.dtype(x).type
  inv_sqrt2 = scalar(0.5) * np.sqrt(2., dtype=scalar)  # == 1 / sqrt(2)
  t = x * inv_sqrt2
  abs_t = lax.abs(t)
  tail = lax.erfc(abs_t)
  far = lax.select(lax.gt(t, scalar(0.)), scalar(2.) - tail, tail)
  near = scalar(1.) + lax.erf(t)
  return scalar(0.5) * lax.select(lax.lt(abs_t, inv_sqrt2), near, far)
def logpmf(k, n, a, b, loc=0):
  """JAX implementation of scipy.stats.betabinom.logpmf.

  ``k`` is floored and then shifted by ``loc``. The shifted count
  ``y = floor(k) - loc`` has support ``0 <= y <= n``; outside it the result
  is ``-inf``. The support test previously compared the already-shifted ``y``
  against ``-loc`` and ``n - loc``, which applied ``loc`` twice. For nonzero
  ``loc`` that rejected valid counts and accepted invalid ones.

  Args:
    k: count(s) at which to evaluate the log-PMF (floored).
    n: number of trials.
    a, b: beta shape parameters.
    loc: integer shift of the support.

  Returns:
    The log-PMF. It is ``-inf`` outside the support and NaN where
    ``n < 1``, ``a < 0`` or ``b < 0``.
  """
  k, n, a, b, loc = _promote_args_inexact("betabinom.logpmf", k, n, a, b, loc)
  y = lax.sub(lax.floor(k), loc)
  one = _lax_const(y, 1)
  zero = _lax_const(y, 0)
  # log C(n, y) = -(log(n + 1) + log B(n - y + 1, y + 1))
  combiln = lax.neg(
      lax.add(lax.log1p(n), betaln(lax.add(lax.sub(n, y), one), lax.add(y, one))))
  # log B(y + a, n - y + b) - log B(a, b)
  beta_lns = lax.sub(betaln(lax.add(y, a), lax.add(lax.sub(n, y), b)), betaln(a, b))
  log_probs = lax.add(combiln, beta_lns)
  # y already includes the loc shift, so the support is simply [0, n].
  y_cond = logical_or(lax.lt(y, zero), lax.gt(y, n))
  log_probs = where(y_cond, -inf, log_probs)
  # NOTE(review): scipy's betabinom appears to accept n == 0 (pmf(0) == 1);
  # this keeps the existing n < 1 -> NaN behaviour. Confirm before changing.
  n_a_b_cond = logical_or(logical_or(lax.lt(n, one), lax.lt(a, zero)), lax.lt(b, zero))
  return where(n_a_b_cond, nan, log_probs)
def log_ndtr(x, series_order=3):
  r"""Log Normal distribution function.

  For details of the Normal distribution function see `ndtr`.

  This function calculates :math:`\log(\mathrm{ndtr}(x))` either by taking the
  log of :math:`\mathrm{ndtr}(x)` or by using an asymptotic series.
  Specifically:

  - For `x > upper_segment`, use the approximation `-ndtr(-x)` based on
    :math:`\log(1-x) \approx -x, x \ll 1`.
  - For `lower_segment < x <= upper_segment`, use the existing `ndtr` technique
    and take a log.
  - For `x <= lower_segment`, use the series approximation of `erf` to compute
    the log CDF directly.

  The `lower_segment` and `upper_segment` are set based on the precision of
  the input:

  .. math::
    \begin{align}
    \mathit{lower\_segment} =&
      \ \begin{cases}
        -20 &  x.\mathrm{dtype}=\mathit{float64} \\
        -10 &  x.\mathrm{dtype}=\mathit{float32} \\
        \end{cases} \\
    \mathit{upper\_segment} =&
      \ \begin{cases}
        8&  x.\mathrm{dtype}=\mathit{float64} \\
        5&  x.\mathrm{dtype}=\mathit{float32} \\
        \end{cases}
    \end{align}

  When `x <= lower_segment`, the `ndtr` asymptotic series approximation is:

  .. math::
    \begin{align}
     \mathrm{ndtr}(x) =&\ \mathit{scale} * (1 + \mathit{sum}) + R_N \\
     \mathit{scale}   =&\ \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\
     \mathit{sum}     =&\ \sum_{n=1}^N (-1)^n (2n-1)!! / (x^2)^n \\
     R_N     =&\ O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3})
    \end{align}

  where :math:`(2n-1)!! = (2n-1) (2n-3) (2n-5) ...  (3) (1)` is a
  `double-factorial <https://en.wikipedia.org/wiki/Double_factorial>`_
  operator.

  Args:
    x: an array of type `float32`, `float64`.
    series_order: Non-negative Python integer. Maximum depth to
      evaluate the asymptotic expansion. This is the `N` above.

  Returns:
    an array with `dtype=x.dtype`.

  Raises:
    TypeError: if `x.dtype` is not handled.
    TypeError: if `series_order` is not a Python integer.
    ValueError: if `series_order` is not in `[0, 30]`.
  """
  if not isinstance(series_order, int):
    raise TypeError("series_order must be a Python integer.")
  if series_order < 0:
    raise ValueError("series_order must be non-negative.")
  if series_order > 30:
    raise ValueError("series_order must be <= 30.")

  values = jnp.asarray(x)
  values_dtype = lax.dtype(values)
  if values_dtype == jnp.float32:
    lo, hi = _LOGNDTR_FLOAT32_LOWER, _LOGNDTR_FLOAT32_UPPER
  elif values_dtype == jnp.float64:
    lo, hi = _LOGNDTR_FLOAT64_LOWER, _LOGNDTR_FLOAT64_UPPER
  else:
    raise TypeError("x.dtype={} is not supported.".format(np.dtype(values_dtype)))

  # The basic idea here was ported from:
  #   https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
  # with these changes:
  # * For x >> 1 and X ~ Normal(0, 1),
  #   Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x],
  #   which extends the range of validity of this function.
  # * One fixed series_order is used for all of 'x', rather than an adaptive one.
  # * The docstring states that this is an asymptotic series, not a Taylor
  #   series, and gives a correct bound on the remainder.
  # * The max/min clamps below keep each unchosen branch finite. Otherwise the
  #   gradient of the select computes 1*dy + 0*(-inf) = nan (e.g. at x = 0).
  #   Each clamp is a no-op wherever its branch is actually chosen.
  right_tail = -_ndtr(-values)  # log(1 - q) ~= -q for small q
  centre = lax.log(_ndtr(lax.max(values, lo)))
  left_tail = _log_ndtr_lower(lax.min(values, lo), series_order)
  return jnp.where(lax.gt(values, hi), right_tail,
                   jnp.where(lax.gt(values, lo), centre, left_tail))
def heaviside(x1, x2):
  """Heaviside step function, mirroring ``numpy.heaviside``.

  Args:
    x1: input values.
    x2: value returned where ``x1 == 0``.

  Returns:
    The following values, broadcast over ``x1`` and ``x2``:

    * ``0`` where ``x1 < 0``;
    * ``x2`` where ``x1 == 0``;
    * ``1`` where ``x1 > 0``;
    * NaN where ``x1`` is NaN.

    NaN previously fell through both comparisons and returned ``x2``, which
    disagreed with NumPy.
  """
  _check_arraylike("heaviside", x1, x2)
  x1, x2 = _promote_dtypes_inexact(x1, x2)
  zero = _lax_const(x1, 0)
  one = _lax_const(x1, 1)
  # A NaN fails every ordered/equality comparison, so it reaches the innermost
  # fallback and is propagated by returning x1 itself.
  return _where(lax.lt(x1, zero), zero,
                _where(lax.gt(x1, zero), one,
                       _where(lax.eq(x1, zero), x2, x1)))
def logpdf(x, loc=0, scale=1):
  """Log-density of the uniform distribution on ``[loc, loc + scale]``.

  Args:
    x: points at which to evaluate the log-density.
    loc: lower end of the support.
    scale: width of the support.

  Returns:
    The following values, broadcast over the promoted inputs:

    * ``-log(scale)`` for ``loc <= x <= loc + scale``;
    * ``-inf`` outside that interval;
    * NaN where ``x`` is NaN, as ``scipy.stats.uniform.logpdf`` returns.
  """
  x, loc, scale = _promote_args_inexact("uniform.logpdf", x, loc, scale)
  log_probs = lax.neg(lax.log(scale))
  outside = logical_or(lax.gt(x, lax.add(loc, scale)), lax.lt(x, loc))
  log_probs = where(outside, -inf, log_probs)
  # NaN compares false against both bounds, so it would otherwise be treated
  # as lying inside the support and get the finite density; propagate it.
  return where(lax.ne(x, x), x, log_probs)