Esempio n. 1
0
def logpmf(k, p, loc=0):
  """Log of the Bernoulli probability mass function.

  JAX counterpart of ``scipy.stats.bernoulli.logpmf``.

  Args:
    k: observed outcome(s).
    p: probability of success.
    loc: shift subtracted from ``k`` before evaluation (default 0).

  Returns:
    With ``x = k - loc``: ``x * log(p) + (1 - x) * log1p(-p)`` where ``x`` is
    exactly 0 or 1, NaN where ``x`` is NaN, and ``-inf`` everywhere else
    (including fractional ``x`` in (0, 1), which has zero probability).
  """
  k, p, loc = jnp._promote_args_inexact("bernoulli.logpmf", k, p, loc)
  zero = lax._const(k, 0)
  one = lax._const(k, 1)
  x = lax.sub(k, loc)
  # xlogy / xlog1py treat 0 * log(0) as 0, so p in {0, 1} is handled exactly.
  log_probs = xlogy(x, p) + xlog1py(lax.sub(one, x), -p)
  # The support is exactly {0, 1}. A plain range check (x < 0 or x > 1) would
  # let fractional values such as x = 0.5 through with a meaningless finite
  # value. NaN is kept so that it propagates instead of becoming -inf.
  in_support = jnp.logical_or(lax.eq(x, zero), lax.eq(x, one))
  return jnp.where(jnp.logical_or(in_support, jnp.isnan(x)),
                   log_probs, -jnp.inf)
Esempio n. 2
0
def log1m_exp(val):
    """Numerically stable implementation of `log(1 - exp(val))`.

    Follows Maechler, "Accurately Computing log(1 - exp(-|a|))".

    - Close to zero, ``exp(val)`` rounds to 1 and ``log1p(-exp(val))``
      collapses to ``-inf``, so ``log(-expm1(val))`` is used there.
    - Further below zero, ``log1p(-exp(val))`` is the accurate form.
    - The crossover point is ``val = -log(2)``.

    Args:
        val: scalar input. The result is finite only for ``val < 0``; it is
            ``-inf`` at 0 and NaN for positive values. ``lax.cond`` needs a
            scalar predicate, so batch inputs via ``vmap`` rather than passing
            an array.

    Returns:
        ``log(1 - exp(val))``.
    """
    return lax.cond(
        # Threshold must be -log(2): on the valid domain val <= 0, comparing
        # against +log(2) would never select the expm1 branch.
        lax.gt(val, lax.neg(lax.log(2.0))),
        lambda _: lax.log(-lax.expm1(val)),
        lambda _: lax.log1p(-lax.exp(val)),
        operand=None,
    )
Esempio n. 3
0
File: beta.py Progetto: 0x0is1/jax
def logpdf(x, a, b, loc=0, scale=1):
    """Log density of the (shifted and scaled) beta distribution.

    With ``y = (x - loc) / scale`` this evaluates
    ``(a - 1) * log(y) + (b - 1) * log(1 - y) - log B(a, b) - log(scale)``
    and returns ``-inf`` for ``x`` outside ``[loc, loc + scale]``.
    """
    x, a, b, loc, scale = _promote_args_inexact(
        "beta.logpdf", x, a, b, loc, scale)
    one = lax._const(x, 1)
    neg_log_beta = lax.neg(betaln(a, b))

    y = lax.div(lax.sub(x, loc), scale)
    # xlogy / xlog1py make the 0 * log(0) terms vanish when a == 1 or b == 1.
    kernel = lax.add(xlogy(lax.sub(a, one), y),
                     xlog1py(lax.sub(b, one), lax.neg(y)))
    density = lax.sub(lax.add(neg_log_beta, kernel), lax.log(scale))

    outside = logical_or(lax.gt(x, lax.add(loc, scale)), lax.lt(x, loc))
    return where(outside, -inf, density)
Esempio n. 4
0
def _ndtr(x):
    """Core of the standard normal CDF, ``0.5 * (1 + erf(x / sqrt(2)))``.

    Near zero (``|w| < 1/sqrt(2)``), ``1 + erf(w)`` is used directly. Elsewhere
    the value is rebuilt from ``erfc`` to avoid cancellation in the tails.
    """
    dt = lax.dtype(x).type
    inv_sqrt2 = dt(0.5) * np.sqrt(2., dtype=dt)  # sqrt(2) / 2 == 1 / sqrt(2)
    w = x * inv_sqrt2
    abs_w = lax.abs(w)
    central = dt(1.) + lax.erf(w)
    tail = lax.select(lax.gt(w, dt(0.)),
                      dt(2.) - lax.erfc(abs_w),
                      lax.erfc(abs_w))
    twice_cdf = lax.select(lax.lt(abs_w, inv_sqrt2), central, tail)
    return dt(0.5) * twice_cdf
Esempio n. 5
0
def logpmf(k, n, a, b, loc=0):
    """JAX implementation of scipy.stats.betabinom.logpmf.

    With ``y = floor(k) - loc``, the beta-binomial log-pmf is

        log C(n, y) + log B(y + a, n - y + b) - log B(a, b),

    where ``log C(n, y) = -log(n + 1) - log B(n - y + 1, y + 1)``.

    Returns ``-inf`` where ``y`` lies outside ``[0, n]``. Returns NaN where
    the shape parameters fall outside SciPy's domain: ``n < 0``, ``a <= 0``
    or ``b <= 0``. Non-integer ``n`` is not rejected here.
    """
    k, n, a, b, loc = _promote_args_inexact("betabinom.logpmf", k, n, a, b,
                                            loc)
    y = lax.sub(lax.floor(k), loc)
    one = _lax_const(y, 1)
    zero = _lax_const(y, 0)
    combiln = lax.neg(
        lax.add(lax.log1p(n),
                betaln(lax.add(lax.sub(n, y), one), lax.add(y, one))))
    beta_lns = lax.sub(betaln(lax.add(y, a), lax.add(lax.sub(n, y), b)),
                       betaln(a, b))
    log_probs = lax.add(combiln, beta_lns)
    # y already has loc subtracted, so the support is simply 0 <= y <= n.
    # Comparing against -loc / n - loc would apply the shift twice and wrongly
    # reject valid points (e.g. k = n + loc) whenever loc != 0.
    y_cond = logical_or(lax.lt(y, zero), lax.gt(y, n))
    log_probs = where(y_cond, -inf, log_probs)
    # SciPy's parameter domain is n >= 0, a > 0, b > 0 (n == 0 is the
    # degenerate distribution with all mass at y == 0).
    invalid_params = logical_or(logical_or(lax.lt(n, zero), lax.le(a, zero)),
                                lax.le(b, zero))
    return where(invalid_params, nan, log_probs)
Esempio n. 6
0
def log_ndtr(x, series_order=3):
    r"""Log Normal distribution function.

    For details of the Normal distribution function see `ndtr`.

    This function calculates :math:`\log(\mathrm{ndtr}(x))` with one of three
    methods, chosen by where `x` falls:

    - For `x > upper_segment`, use the approximation `-ndtr(-x)` based on
      :math:`\log(1-x) \approx -x, x \ll 1`.
    - For `lower_segment < x <= upper_segment`, use the existing `ndtr` technique
      and take a log.
    - For `x <= lower_segment`, use the series approximation of `erf` to compute
      the log CDF directly.

    The `lower_segment` and `upper_segment` thresholds are set based on the
    precision of the input:

    .. math::
      \begin{align}
      \mathit{lower\_segment} =&
        \ \begin{cases}
          -20 &  x.\mathrm{dtype}=\mathit{float64} \\
          -10 &  x.\mathrm{dtype}=\mathit{float32} \\
          \end{cases} \\
      \mathit{upper\_segment} =&
        \ \begin{cases}
          8&  x.\mathrm{dtype}=\mathit{float64} \\
          5&  x.\mathrm{dtype}=\mathit{float32} \\
          \end{cases}
      \end{align}


    When `x <= lower_segment`, the `ndtr` asymptotic series approximation is:

    .. math::
      \begin{align}
       \mathrm{ndtr}(x) =&\  \mathit{scale} * (1 + \mathit{sum}) + R_N \\
       \mathit{scale}   =&\  \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\
       \mathit{sum}     =&\  \sum_{n=1}^N (-1)^n (2n-1)!! / (x^2)^n \\
       R_N     =&\  O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3})
      \end{align}

    where :math:`(2n-1)!! = (2n-1) (2n-3) (2n-5) ...  (3) (1)` is a
    `double-factorial
    <https://en.wikipedia.org/wiki/Double_factorial>`_ operator.


    Args:
      x: an array of type `float32` or `float64`.
      series_order: Non-negative Python integer. Maximum depth to
        evaluate the asymptotic expansion. This is the `N` above.

    Returns:
      an array with `dtype=x.dtype`.

    Raises:
      TypeError: if `x.dtype` is not handled.
      TypeError: if `series_order` is not a Python integer.
      ValueError: if `series_order` is not in `[0, 30]`.
    """
    if not isinstance(series_order, int):
        raise TypeError("series_order must be a Python integer.")
    if series_order < 0:
        raise ValueError("series_order must be non-negative.")
    if series_order > 30:
        raise ValueError("series_order must be <= 30.")

    x = jnp.asarray(x)
    dtype = lax.dtype(x)

    if dtype == jnp.float64:
        lower_segment = _LOGNDTR_FLOAT64_LOWER
        upper_segment = _LOGNDTR_FLOAT64_UPPER
    elif dtype == jnp.float32:
        lower_segment = _LOGNDTR_FLOAT32_LOWER
        upper_segment = _LOGNDTR_FLOAT32_UPPER
    else:
        raise TypeError("x.dtype={} is not supported.".format(np.dtype(dtype)))

    # The basic idea here was ported from:
    #   https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    # We copy the main idea, with a few changes
    # * For x >> 1, and X ~ Normal(0, 1),
    #     Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x],
    #     which extends the range of validity of this function.
    # * We use one fixed series_order for all of 'x', rather than adaptive.
    # * Our docstring properly reflects that this is an asymptotic series, not a
    #   Taylor series. We also provided a correct bound on the remainder.
    # * We need to use the max/min in the _log_ndtr_lower arg to avoid nan when
    #   x=0. This happens even though the branch is unchosen because when x=0
    #   the gradient of a select involves the calculation 1*dy+0*(-inf)=nan
    #   regardless of whether dy is finite. Note that the minimum is a NOP if
    #   the branch is chosen.
    return jnp.where(
        lax.gt(x, upper_segment),
        -_ndtr(-x),  # log(1-x) ~= -x, x << 1
        jnp.where(lax.gt(x, lower_segment),
                  lax.log(_ndtr(lax.max(x, lower_segment))),
                  _log_ndtr_lower(lax.min(x, lower_segment), series_order)))
Esempio n. 7
0
def heaviside(x1, x2):
    """Compute the Heaviside step function, following ``numpy.heaviside``.

    Returns:
      - 0 where ``x1 < 0``;
      - ``x2`` where ``x1 == 0``;
      - 1 where ``x1 > 0``;
      - NaN where ``x1`` is NaN.

    Both inputs are promoted with ``_promote_dtypes_inexact`` before comparison.
    """
    _check_arraylike("heaviside", x1, x2)
    x1, x2 = _promote_dtypes_inexact(x1, x2)
    zero = _lax_const(x1, 0)
    # Pick x2 only on an exact zero. A NaN x1 fails every comparison, so fall
    # back to x1 itself to propagate the NaN (as NumPy does) instead of x2.
    at_zero_or_nan = _where(lax.eq(x1, zero), x2, x1)
    return _where(lax.lt(x1, zero), zero,
                  _where(lax.gt(x1, zero), _lax_const(x1, 1), at_zero_or_nan))
Esempio n. 8
0
def logpdf(x, loc=0, scale=1):
    """Log density of the continuous uniform distribution on [loc, loc + scale].

    Returns ``-log(scale)`` inside the support, ``-inf`` outside it, and NaN
    where ``x`` is NaN (as ``scipy.stats.uniform.logpdf`` does).
    """
    x, loc, scale = _promote_args_inexact("uniform.logpdf", x, loc, scale)
    log_probs = lax.neg(lax.log(scale))
    outside = logical_or(lax.gt(x, lax.add(loc, scale)), lax.lt(x, loc))
    log_probs = where(outside, -inf, log_probs)
    # The density does not depend on x, so a NaN x would otherwise fail both
    # range comparisons and yield a finite value. x != x holds only for NaN,
    # and returning x there propagates the NaN.
    return where(lax.ne(x, x), x, log_probs)