コード例 #1
0
ファイル: special.py プロジェクト: zizai/jax
def _log_ndtr_lower(x, series_order):
  """Evaluate `log(ndtr(x))` via its asymptotic expansion (meant for `x << -1`).

  The result is `log(scale) + log(series)` where
  `log(scale) = -x**2 / 2 - log(-x) - log(2 * pi) / 2` and `series` is whatever
  `_log_ndtr_asymptotic_series(x, series_order)` returns.
  """
  to_dtype = lax.dtype(x).type
  half = to_dtype(0.5)
  half_log_two_pi = to_dtype(0.5 * np.log(2. * np.pi))
  # Logarithm of the prefactor that multiplies the series.
  log_prefactor = -half * lax.square(x) - lax.log(-x) - half_log_two_pi
  series = _log_ndtr_asymptotic_series(x, series_order)
  return log_prefactor + lax.log(series)
コード例 #2
0
ファイル: special.py プロジェクト: xeransis/jax
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute `log(sum(b * exp(a)))` over `axis` with max-shifting for stability.

    Args:
        a: array of exponents.
        axis: axis or axes to reduce over (resolved by `_reduction_dims`).
        b: optional weights, broadcast against `a`.
        keepdims: if true, the reduced axes are re-inserted with size one.
        return_sign: if true, return `(out, sign)` where `out` is the log of
            the absolute value of the sum and `sign` is the sign of the sum.

    Returns:
        The reduced array, or an `(out, sign)` pair when `return_sign` is set.
        With weights and `return_sign=False`, negative sums map to NaN.
    """
    if b is not None:
        a, b = jnp.broadcast_arrays(a, b)
    dims = _reduction_dims(a, axis)

    def restore_dims(arr):
        return lax.expand_dims(arr, dims)

    shift = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
    # A non-finite maximum would poison the subtraction below; shift by 0 instead.
    finite_shift = lax.select(lax.is_finite(shift), shift,
                              lax.full_like(shift, 0))
    shift = lax.stop_gradient(finite_shift)
    shifted_exp = lax.exp(lax.sub(a, restore_dims(shift)))
    zero = _constant_like(a, 0)

    if b is None:
        total = lax.reduce(shifted_exp, zero, lax.add, dims)
        out = lax.add(lax.log(total), shift)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        total = lax.reduce(lax.mul(shifted_exp, b), zero, lax.add, dims)
        sign = lax.stop_gradient(lax.sign(total))
        out = lax.add(lax.log(lax.abs(total)), shift)

    if return_sign:
        if keepdims:
            return (restore_dims(out), restore_dims(sign))
        return (out, sign)
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    if keepdims:
        return restore_dims(out)
    return out
コード例 #3
0
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute `log(sum(b * exp(a)))` over `axis` with max-shifting for stability.

    Args:
        a: array-like of exponents; promoted to an inexact dtype.
        axis: axis or axes to reduce over (resolved by `_reduction_dims`).
        b: optional weights; promoted together with `a`.
        keepdims: forwarded to the reductions so reduced axes keep size one.
        return_sign: if true, return `(out, sign)` where `out` is the log of
            the absolute value of the sum and `sign` is the sign of the sum.

    Returns:
        The reduced array, or an `(out, sign)` pair when `return_sign` is set.
        With weights and `return_sign=False`, negative sums map to NaN.
    """
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        # Zero-weight entries must not contribute, whatever value `a` has there.
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        # Promote unweighted input too, so integer/bool/list arguments reach
        # the floating-point operations below with an inexact dtype.
        a, = _promote_args_inexact("logsumexp", a)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    # A non-finite maximum would poison the subtraction below; shift by 0 instead.
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    if b is None:
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = jnp.sum(lax.mul(lax.exp(lax.sub(a, amax_with_dims)), b),
                         axis=dims,
                         keepdims=keepdims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    return out
コード例 #4
0
def log1m_exp(val):
    """Numerically stable implementation of `log(1 - exp(val))`.

    Args:
        val: scalar (the predicate of `lax.cond` must be a scalar); the result
            is only real-valued for `val <= 0`.

    Returns:
        `log(1 - exp(val))`, computed as `log(-expm1(val))` when
        `val > -log(2)` and as `log1p(-exp(val))` otherwise.
    """
    return lax.cond(
        # The switch-over point is -log(2): above it exp(val) is close to 1, so
        # 1 - exp(val) must come from expm1 to avoid cancellation; below it
        # exp(val) is small and log1p keeps full precision.
        lax.gt(val, lax.neg(lax.log(2.0))),
        lambda _: lax.log(-lax.expm1(val)),
        lambda _: lax.log1p(-lax.exp(val)),
        operand=None,
    )
コード例 #5
0
ファイル: gamma.py プロジェクト: jamestwebber/jax
def logpdf(x, a, loc=0, scale=1):
  """Log-density of the gamma distribution with shape `a`, evaluated at `x`.

  With `y = (x - loc) / scale` the value is
  `(a - 1) * log(y) - y - gammaln(a) - log(scale)`; entries with `x < loc`
  are replaced by `-inf`.
  """
  x, a, loc, scale = _promote_args_inexact("gamma.logpdf", x, a, loc, scale)
  standardized = lax.div(lax.sub(x, loc), scale)
  a_minus_one = lax.sub(a, _constant_like(x, 1))
  kernel = lax.sub(lax.mul(a_minus_one, lax.log(standardized)), standardized)
  normalizer = lax.add(gammaln(a), lax.log(scale))
  below_support = lax.lt(x, loc)
  return where(below_support, -inf, lax.sub(kernel, normalizer))
コード例 #6
0
ファイル: pareto.py プロジェクト: yashk2810/jax
def logpdf(x, b, loc=0, scale=1):
    """Log-density of the Pareto distribution with shape `b`, evaluated at `x`.

    With `y = (x - loc) / scale` the value is
    `-(log(scale / b) + (b + 1) * log(y))`; entries with `x < loc + scale`
    are replaced by `-inf`.
    """
    x, b, loc, scale = _promote_args_inexact("pareto.logpdf", x, b, loc, scale)
    standardized = lax.div(lax.sub(x, loc), scale)
    b_plus_one = lax.add(b, _constant_like(x, 1))
    log_scale_over_b = lax.log(lax.div(scale, b))
    power_term = lax.mul(b_plus_one, lax.log(standardized))
    density = lax.neg(lax.add(log_scale_over_b, power_term))
    outside_support = lax.lt(x, lax.add(loc, scale))
    return where(outside_support, -inf, density)
コード例 #7
0
def logpdf(x, a, b, loc=0, scale=1):
    """Log-density of the beta distribution with shapes `a`, `b`, evaluated at `x`.

    With `y = (x - loc) / scale` the value is
    `-betaln(a, b) + (a - 1) * log(y) + (b - 1) * log1p(-y) - log(scale)`;
    entries with `x < loc` or `x > loc + scale` are replaced by `-inf`.
    """
    x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc,
                                                scale)
    unit = _constant_like(x, 1)
    standardized = lax.div(lax.sub(x, loc), scale)
    left_term = lax.mul(lax.sub(a, unit), lax.log(standardized))
    right_term = lax.mul(lax.sub(b, unit), lax.log1p(lax.neg(standardized)))
    kernel = lax.add(left_term, right_term)
    density = lax.sub(lax.add(lax.neg(betaln(a, b)), kernel), lax.log(scale))
    above_support = lax.gt(x, lax.add(loc, scale))
    below_support = lax.lt(x, loc)
    return where(logical_or(above_support, below_support), -inf, density)
コード例 #8
0
def logpdf(x, df, loc=0, scale=1):
    """Log-density of the chi-squared distribution with `df` degrees of freedom.

    With `y = (x - loc) / scale` the value is
    `(df / 2 - 1) * log(y) - y / 2 - lgamma(df / 2) - df * log(2) / 2 - log(scale)`;
    entries with `x < loc` are replaced by `-inf`.
    """
    x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale)
    unit = _constant_like(x, 1)
    pair = _constant_like(x, 2)
    standardized = lax.div(lax.sub(x, loc), scale)
    half_df = lax.div(df, pair)

    # (df / 2 - 1) * log(y) - y / 2
    power_part = lax.mul(lax.sub(half_df, unit), lax.log(standardized))
    kernel = lax.sub(power_part, lax.div(standardized, pair))

    # -(lgamma(df / 2) + df * log(2) / 2)
    log_two_part = lax.div(lax.mul(lax.log(pair), df), pair)
    normalizer = lax.neg(lax.add(lax.lgamma(half_df), log_two_part))

    density = lax.add(lax.sub(normalizer, lax.log(scale)), kernel)
    return where(lax.lt(x, loc), -inf, density)
コード例 #9
0
def logpmf(k, p, loc=0):
    """Log-probability mass of the geometric distribution at `k`.

    With `x = k - loc` the value is `xlog1py(x - 1, -p) + log(p)`
    (presumably `(x - 1) * log1p(-p) + log(p)`; confirm against `xlog1py`);
    entries with `x <= 0` are replaced by `-inf`.
    """
    k, p, loc = jnp._promote_args_inexact("geom.logpmf", k, p, loc)
    shifted = lax.sub(k, loc)
    shifted_minus_one = lax.sub(shifted, lax._const(k, 1))
    mass = xlog1py(shifted_minus_one, -p) + lax.log(p)
    outside_support = lax.le(shifted, lax._const(k, 0))
    return jnp.where(outside_support, -jnp.inf, mass)
コード例 #10
0
def logpdf(x, loc=0, scale=1):
    """Log-density of the normal distribution with mean `loc` and std `scale`.

    Computed as `-(log(2 * pi * scale**2) + (x - loc)**2 / scale**2) / 2`.
    """
    x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale)
    two = _constant_like(x, 2)
    variance = lax.pow(scale, two)
    log_two_pi_variance = lax.log(
        lax.mul(_constant_like(x, 2 * np.pi), variance))
    scaled_sq_deviation = lax.div(lax.pow(lax.sub(x, loc), two), variance)
    total = lax.add(log_two_pi_variance, scaled_sq_deviation)
    return lax.div(lax.neg(total), two)
コード例 #11
0
def tridiagonal_pos_def_log_det(a, b):
    """Log-determinant of a symmetric tridiagonal positive-definite matrix.

    The matrix has main diagonal `b` (all positive) and identical lower and
    upper diagonals `a`, so the result corresponds to

        np.linalg.slogdet(diag(a, -1) + diag(b) + diag(a, 1))[1]

    The determinant is accumulated in log space with the three-term
    recursion `f[i+1] = b[i+1] * f[i] - a[i]**2 * f[i-1]`, evaluated through
    `log_diff_exp` (presumably `log(exp(u) - exp(v))`; confirm against its
    definition).

    Args:
        a (array): lower/upper diagonal of matrix, shape `(dim - 1,)`.
        b (array): main diagonal of matrix, shape `(dim,)`, all positive.

    Returns:
        Scalar corresponding to log-determinant.
    """
    def step(carry, diag_entries):
        log_curr, log_prev = carry
        off_diag, next_diag = diag_entries
        growth = lax.log(next_diag) + log_curr
        correction = 2 * lax.log(abs(off_diag)) + log_prev
        return (log_diff_exp(growth, correction), log_curr), None

    # Start from the 1x1 leading minor (log b[0]) and the empty minor (log 1).
    initial = (lax.log(b[0]), 0)
    final_carry, _ = lax.scan(step, initial, (a, b[1:]))
    log_det, _ = final_carry
    return log_det
コード例 #12
0
def logaddexp(x1, x2):
    """Compute `log(exp(x1) + exp(x2))` elementwise without overflow.

    Uses `max(x1, x2) + log1p(exp(-|x1 - x2|))`. When `x1 - x2` is NaN (a NaN
    input, or two infinities of the same sign) the result falls back to
    `x1 + x2`, which yields `-inf` for two `-inf` inputs, `+inf` for two
    `+inf` inputs, and propagates NaN.

    Args:
        x1: first operand.
        x2: second operand.

    Returns:
        Array holding the elementwise result.
    """
    x1, x2 = _promote_to_result_dtype(onp.logaddexp, *_promote_shapes(x1, x2))
    amax = lax.max(x1, x2)
    delta = lax.sub(x1, x2)
    # exp(-|delta|) lies in [0, 1], so this can neither overflow nor produce
    # inf - inf, and log1p keeps precision when the smaller term is tiny.
    stable = lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta)))))
    # NaN is the only value that differs from itself.
    delta_is_nan = lax.ne(delta, delta)
    return lax.select(delta_is_nan, lax.add(x1, x2), stable)
コード例 #13
0
ファイル: cauchy.py プロジェクト: yashk2810/jax
def logpdf(x, loc=0, scale=1):
    """Log-density of the Cauchy distribution with location `loc` and `scale`.

    With `y = (x - loc) / scale` the value is `-(log(pi * scale) + log1p(y**2))`.
    """
    x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale)
    standardized = lax.div(lax.sub(x, loc), scale)
    log_pi_scale = lax.log(lax.mul(_constant_like(x, np.pi), scale))
    log_one_plus_square = lax.log1p(lax.mul(standardized, standardized))
    return lax.neg(lax.add(log_pi_scale, log_one_plus_square))
コード例 #14
0
ファイル: jet.py プロジェクト: 0x0is1/jax
def _log_taylor(primals_in, series_in):
    """Propagate a truncated Taylor series through `log`.

    Args:
        primals_in: one-element sequence holding the primal input `x`.
        series_in: one-element sequence holding the list of higher-order
            input coefficients.

    Returns:
        `(primal_out, series_out)`: `log(x)` and the list of output
        coefficients, one per input coefficient.
    """
    (x,) = primals_in
    (series,) = series_in
    in_coeffs = [x] + series
    out_coeffs = [lax.log(x)] + [None] * len(series)
    for order in range(1, len(out_coeffs)):
        # Cross terms between lower-order outputs and inputs (empty at order 1).
        cross_terms = sum([
            _scale(order, j) * out_coeffs[j] * in_coeffs[order - j]
            for j in range(1, order)
        ])
        numerator = in_coeffs[order] - fact(order - 1) * cross_terms
        out_coeffs[order] = numerator / in_coeffs[0]
    primal_out, *series_out = out_coeffs
    return primal_out, series_out
コード例 #15
0
ファイル: special.py プロジェクト: zizai/jax
def multigammaln(a, d):
  """Log of the multivariate gamma function of dimension `d`, evaluated at `a`.

  Computes `d * (d - 1) / 4 * log(pi) + sum_{j=0}^{d-1} gammaln(a - j / 2)`.

  Args:
    a: array-like argument; promoted to an inexact dtype.
    d: dimension; must be a concrete value convertible to `int`.

  Returns:
    Array with the shape of `a`.
  """
  d = core.concrete_or_error(int, d, "d argument of multigammaln")
  # Keep the Python integer `d` for `jnp.arange` (which needs a concrete
  # length) and use the promoted copy `d_` only for arithmetic.
  a, d_ = _promote_args_inexact("multigammaln", a, d)

  constant = lax.mul(lax.mul(lax.mul(_constant_like(a, 0.25), d_),
                             lax.sub(d_, _constant_like(a, 1))),
                     lax.log(_constant_like(a, np.pi)))
  half_steps = lax.div(jnp.arange(d, dtype=d_.dtype), _constant_like(a, 2))
  res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) - half_steps), axis=-1)
  return res + constant
コード例 #16
0
ファイル: special.py プロジェクト: jbampton/jax
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute `log(sum(b * exp(a)))` over `axis` with max-shifting for stability.

    Args:
        a: array-like of exponents; promoted to an inexact dtype.
        axis: axis or axes to reduce over (resolved by `_reduction_dims`).
        b: optional weights; entries of `a` with zero weight are ignored.
        keepdims: forwarded to the reductions so reduced axes keep size one.
        return_sign: if true, return `(out, sign)`.

    Returns:
        The reduced array, or an `(out, sign)` pair when `return_sign` is set.
        For real weighted sums with `return_sign=False`, negative sums map to
        NaN.
    """
    weighted = b is not None
    if weighted:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        a, = _promote_args_inexact("logsumexp", a)
    pos_dims, dims = _reduction_dims(a, axis)

    shift = jnp.max(a, axis=dims, keepdims=keepdims)
    # A non-finite maximum would poison the subtraction below; shift by 0 instead.
    shift = lax.stop_gradient(
        lax.select(jnp.isfinite(shift), shift, lax.full_like(shift, 0)))
    if keepdims:
        shift_with_dims = shift
    else:
        shift_with_dims = lax.expand_dims(shift, pos_dims)
    shifted_exp = lax.exp(lax.sub(a, shift_with_dims))

    if not weighted and not np.issubdtype(a.dtype, np.complexfloating):
        # Fast path: an unweighted real sum of exponentials is never negative.
        total = jnp.sum(shifted_exp, axis=dims, keepdims=keepdims)
        out = lax.add(lax.log(total), shift)
        sign = jnp.where(jnp.isnan(out), out, 1.0)
        sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
    else:
        if weighted:
            shifted_exp = lax.mul(shifted_exp, b)
        total = jnp.sum(shifted_exp, axis=dims, keepdims=keepdims)

        sign = lax.stop_gradient(jnp.sign(total))
        total_is_complex = np.issubdtype(total.dtype, np.complexfloating)
        if total_is_complex and return_sign:
            total = sign * total
        log_argument = total if total_is_complex else lax.abs(total)
        out = lax.add(lax.log(log_argument), shift)

    if return_sign:
        return (out, sign)
    if weighted and not np.issubdtype(out.dtype, np.complexfloating):
        # The NaN is intentional here, so NaN debugging is switched off.
        with jax.debug_nans(False):
            out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
    return out
コード例 #17
0
ファイル: jet.py プロジェクト: MichaelMarien/jax
def _pow_taylor(primals_in, series_in):
  """Propagate a truncated Taylor series through `u ** r`.

  The series of `r * log(u)` is obtained first via `jet`; the output
  coefficients are then built recursively from it.

  Args:
    primals_in: pair `(u, r)` of primal inputs.
    series_in: the corresponding higher-order input coefficients.

  Returns:
    `(primal_out, series_out)`: `u ** r` and the list of output coefficients.
  """
  base, exponent = primals_in

  log_primal, log_series = jet(lambda x, y: lax.mul(y, lax.log(x)),
                               primals_in, series_in)

  log_coeffs = [log_primal] + log_series
  out_coeffs = [base ** exponent] + [None] * len(log_series)
  for k in range(1, len(out_coeffs)):
    weighted_sum = sum([
        _scale(k, j) * out_coeffs[k - j] * log_coeffs[j]
        for j in range(1, k + 1)
    ])
    out_coeffs[k] = fact(k - 1) * weighted_sum
  primal_out, *series_out = out_coeffs

  return primal_out, series_out
コード例 #18
0
def logpdf(x, df, loc=0, scale=1):
  """Log-density of Student's t distribution with `df` degrees of freedom.

  With `y = (x - loc) / scale` the value is
  `-(lgamma(df / 2) + log(scale**2 * pi * df) / 2 - lgamma((df + 1) / 2)
     + (df + 1) / 2 * log1p(y**2 / df))`.
  """
  x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale)
  two = _lax_const(x, 2)
  standardized = lax.div(lax.sub(x, loc), scale)
  half_df = lax.div(df, two)
  half_df_plus_half = lax.add(half_df, _lax_const(x, 0.5))
  # log(scale**2 * pi * df) / 2
  scale_sq_pi = lax.mul(lax.mul(scale, scale), _lax_const(x, np.pi))
  half_log_term = lax.div(lax.log(lax.mul(scale_sq_pi, df)), two)
  normalizer = lax.sub(lax.add(lax.lgamma(half_df), half_log_term),
                       lax.lgamma(half_df_plus_half))
  sq_over_df = lax.div(lax.mul(standardized, standardized), df)
  tail = lax.mul(half_df_plus_half, lax.log1p(sq_over_df))
  return lax.neg(lax.add(normalizer, tail))
コード例 #19
0
ファイル: special.py プロジェクト: xeransis/jax
def multigammaln(a, d):
    """Log of the multivariate gamma function of dimension `d`, evaluated at `a`.

    Computes `d * (d - 1) / 4 * log(pi) + sum_{j=0}^{d-1} gammaln(a - j / 2)`.

    Args:
        a: array-like argument; promoted to an inexact dtype.
        d: dimension; it is also used as the length of a `jnp.arange`, so it
            must be a concrete value.

    Returns:
        Array with the shape of `a`.
    """
    a, = _promote_args_inexact("multigammaln", a)
    # Use the converted copy only for arithmetic; `jnp.arange` receives the
    # original `d` so its length stays concrete.
    d_ = lax.convert_element_type(d, lax.dtype(a))
    constant = lax.mul(
        lax.mul(lax.mul(_constant_like(a, 0.25), d_),
                lax.sub(d_, _constant_like(a, 1))),
        lax.log(_constant_like(a, np.pi)))
    half_steps = lax.div(jnp.arange(d, dtype=lax.dtype(a)),
                         _constant_like(a, 2))
    res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) - half_steps), axis=-1)
    return res + constant
コード例 #20
0
ファイル: special.py プロジェクト: jbampton/jax
def multigammaln(a, d):
    """Log of the multivariate gamma function of dimension `d`, evaluated at `a`.

    Computes `d * (d - 1) / 4 * log(pi) + sum_{j=0}^{d-1} gammaln(a - j / 2)`.

    Args:
        a: array-like argument; promoted to an inexact dtype.
        d: dimension; must be a concrete value convertible to `int`.

    Returns:
        Array with the shape of `a`.
    """
    d = core.concrete_or_error(int, d, "d argument of multigammaln")
    a, d_ = _promote_args_inexact("multigammaln", a, d)

    quarter_d = lax.mul(_lax_const(a, 0.25), d_)
    d_minus_one = lax.sub(d_, _lax_const(a, 1))
    log_pi = lax.log(_lax_const(a, np.pi))
    constant = lax.mul(lax.mul(quarter_d, d_minus_one), log_pi)

    half_steps = lax.div(jnp.arange(d, dtype=d_.dtype), _lax_const(a, 2))
    # Give the offsets one leading singleton axis per axis of `a`.
    leading_axes = tuple(range(a.ndim))
    shifted = (jnp.expand_dims(a, axis=-1) -
               jnp.expand_dims(half_steps, axis=leading_axes))
    return jnp.sum(gammaln(shifted), axis=-1) + constant
コード例 #21
0
ファイル: directional.py プロジェクト: ahmadsalim/numpyro
def log_I1(orders: int, value, terms=250):
    r"""Compute the log of the modified Bessel function of the first kind for
    orders `0, ..., orders`, using a truncated power series.

    .. math ::

        \log I_v(z) = v \log(z/2) + \log \sum_{k=0}^{\mathrm{terms}-1}
        \exp\left[2 k \log(z/2) - \log k! - \log\Gamma(v + k + 1)\right]

    :param orders: highest order to compute; orders 0 through `orders` are
        returned.
    :param value: array of values at which to evaluate the function (it must
        provide `ndim`, `shape` and `reshape`).
    :param terms: number of series terms kept in the summation.
    :return: array of shape `(orders + 1, *value.shape)` (a scalar `value` is
        treated as having shape `(1,)`).
    """
    # From here on `orders` is the number of orders computed (0..orders).
    orders = orders + 1
    if value.ndim == 0:
        vshape = jnp.shape([1])
    else:
        vshape = value.shape
    # Flatten to a column so the series terms can live on the last axis.
    value = value.reshape(-1, 1)
    flat_vshape = _numel(vshape)

    k = jnp.arange(terms)
    lgammas_all = lax.lgamma(jnp.arange(1.0, terms + orders + 1))
    assert lgammas_all.shape == (orders + terms,
                                 )  # lgamma(0) = inf => start from 1

    # lvalues[i, k] = k * log(value_i / 2)
    lvalues = lax.log(value / 2) * k.reshape(1, -1)
    assert lvalues.shape == (flat_vshape, terms)

    # lgamma(k + 1) = log(k!) for k = 0..terms-1
    lfactorials = lgammas_all[:terms]
    assert lfactorials.shape == (terms, )

    # One copy of the lgamma table per order.
    lgammas = lgammas_all.tile(orders).reshape((orders, -1))
    assert lgammas.shape == (orders, terms + orders
                             )  # lgamma(0) = inf => start from 1

    # indices[v, k] = v + k selects lgamma(v + k + 1) from the table.
    indices = k[:orders].reshape(-1, 1) + k.reshape(1, -1)
    assert indices.shape == (orders, terms)

    # Log of the series sum, reduced over the term axis.
    seqs = logsumexp(
        2 * lvalues[None, :, :] - lfactorials[None, None, :] -
        jnp.take_along_axis(lgammas, indices, axis=1)[:, None, :],
        -1,
    )
    assert seqs.shape == (orders, flat_vshape)

    # Add the leading v * log(value / 2) term (column v of lvalues).
    i1s = lvalues[..., :orders].T + seqs
    assert i1s.shape == (orders, flat_vshape)
    return i1s.reshape(-1, *vshape)
コード例 #22
0
def logpdf(x, p):
    """Log-density of the generalized normal distribution with shape `p`.

    Computed as `log(p / 2) - lgamma(1 / p) - |x|**p`.
    """
    x, p = _promote_args_inexact("gennorm.logpdf", x, p)
    log_half_p = lax.log(.5 * p)
    log_gamma_inv_p = lax.lgamma(1 / p)
    abs_x_pow_p = lax.abs(x)**p
    return log_half_p - log_gamma_inv_p - abs_x_pow_p
コード例 #23
0
ファイル: special.py プロジェクト: GregCT/jax
def log_ndtr(x, series_order=3):
    r"""Log Normal distribution function.

  For details of the Normal distribution function see `ndtr`.

  This function calculates :math:`\log(\mathrm{ndtr}(x))` piecewise, picking
  one of three formulas depending on where `x` falls:

  - For `x > upper_segment`, use the approximation `-ndtr(-x)` based on
    :math:`\log(1-x) \approx -x, x \ll 1`.
  - For `lower_segment < x <= upper_segment`, evaluate `ndtr` and take its
    log.
  - For `x <= lower_segment`, use an asymptotic series to compute the log CDF
    directly.

  The segment boundaries depend on the dtype of the input and are read from
  the module constants `_LOGNDTR_FLOAT64_LOWER` / `_LOGNDTR_FLOAT64_UPPER` and
  `_LOGNDTR_FLOAT32_LOWER` / `_LOGNDTR_FLOAT32_UPPER` (documented as:

  .. math::
    \begin{align}
    \mathit{lower\_segment} =&
      \ \begin{cases}
        -20 &  x.\mathrm{dtype}=\mathit{float64} \\
        -10 &  x.\mathrm{dtype}=\mathit{float32} \\
        \end{cases} \\
    \mathit{upper\_segment} =&
      \ \begin{cases}
        8&  x.\mathrm{dtype}=\mathit{float64} \\
        5&  x.\mathrm{dtype}=\mathit{float32} \\
        \end{cases}
    \end{align}

  -- confirm against the constants' definitions).

  When `x <= lower_segment`, the `ndtr` asymptotic series approximation is:

  .. math::
    \begin{align}
     \mathrm{ndtr}(x) =&\  \mathit{scale} * (1 + \mathit{sum}) + R_N \\
     \mathit{scale}   =&\  \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\
     \mathit{sum}     =&\  \sum_{n=1}^N (-1)^n (2n-1)!! / (x^2)^n \\
     R_N     =&\  O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3})
    \end{align}

  where :math:`(2n-1)!! = (2n-1) (2n-3) (2n-5) ...  (3) (1)` is a
  `double-factorial
  <https://en.wikipedia.org/wiki/Double_factorial>`_ operator.


  Args:
    x: an array of type `float32`, `float64`.
    series_order: Non-negative Python integer, at most 30. Maximum depth to
      evaluate the asymptotic expansion. This is the `N` above.

  Returns:
    an array with `dtype=x.dtype`.

  Raises:
    TypeError: if `x.dtype` is not handled.
    TypeError: if `series_order` is not a Python `integer`.
    ValueError:  if `series_order` is not in `[0, 30]`.
  """
    if not isinstance(series_order, int):
        raise TypeError("series_order must be a Python integer.")
    if series_order < 0:
        raise ValueError("series_order must be non-negative.")
    if series_order > 30:
        raise ValueError("series_order must be <= 30.")

    x = jnp.asarray(x)
    x_dtype = lax.dtype(x)

    if x_dtype == jnp.float64:
        lower_cut, upper_cut = _LOGNDTR_FLOAT64_LOWER, _LOGNDTR_FLOAT64_UPPER
    elif x_dtype == jnp.float32:
        lower_cut, upper_cut = _LOGNDTR_FLOAT32_LOWER, _LOGNDTR_FLOAT32_UPPER
    else:
        raise TypeError("x.dtype={} is not supported.".format(
            np.dtype(x_dtype)))

    # The basic idea here was ported from:
    #   https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    # with a few changes:
    # * For x >> 1, and X ~ Normal(0, 1),
    #     Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x],
    #   which extends the range of validity of this function.
    # * One fixed series_order is used for all of 'x', rather than an
    #   adaptive one.
    # * The series is treated as an asymptotic series, not a Taylor series.
    # * The arguments of the middle and lower formulas are clamped with
    #   max/min so that the unchosen branch never sees values that give nan.
    #   This matters even though the branch is unchosen: the gradient of a
    #   select involves the calculation 1*dy+0*(-inf)=nan regardless of
    #   whether dy is finite. The clamp is a no-op on the chosen branch.
    in_upper_tail = lax.gt(x, upper_cut)
    above_lower_cut = lax.gt(x, lower_cut)
    upper_value = -_ndtr(-x)  # log(1-x) ~= -x, x << 1
    middle_value = lax.log(_ndtr(lax.max(x, lower_cut)))
    lower_value = _log_ndtr_lower(lax.min(x, lower_cut), series_order)
    return jnp.where(in_upper_tail, upper_value,
                     jnp.where(above_lower_cut, middle_value, lower_value))
コード例 #24
0
ファイル: special.py プロジェクト: GregCT/jax
def xlogy(x, y):
    """Compute `x * log(y)`, returning 0 wherever `x == 0`.

    Where `x` is zero both operands are replaced by 1 before the product is
    formed, so the discarded branch never evaluates `log` at the actual `y`.
    """
    x, y = _promote_args_inexact("xlogy", x, y)
    nonzero_x = x != 0.
    x_or_one = jnp.where(nonzero_x, x, 1.)
    y_or_one = jnp.where(nonzero_x, y, 1.)
    product = lax.mul(x_or_one, lax.log(y_or_one))
    return jnp.where(nonzero_x, product, jnp.zeros_like(x))
コード例 #25
0
def xlogy_jvp_lhs(g, x, y, jaxpr, aval, consts):
    """Return `g * log(y)`, the tangent of `xlogy` with respect to `x`.

    `jaxpr`, `aval` and `consts` are accepted but not used here.
    """
    x, y = _promote_args_like(osp_special.xlogy, x, y)
    g, y = _promote_args_like(osp_special.xlogy, g, y)
    tangent = lax._brcast(g, y)
    log_y = lax._brcast(lax.log(y), g)
    return lax._safe_mul(tangent, log_y)
コード例 #26
0
ファイル: jet.py プロジェクト: 0x0is1/jax

# Derivative passed to def_deriv for erf:
#   d/dx erf(x) = 2 / sqrt(pi) * exp(-x**2)
def_deriv(
    lax.erf_p, lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)),
                                 lax.exp(lax.neg(lax.square(x)))))


def def_comp(prim, comp):
    """Register a jet rule for `prim` expressed through simpler primitives.

    The stored rule is `jet` with `comp` bound as its first argument, where
    `comp` is a callable that computes the primitive as a composition of
    other operations.
    """
    composed_rule = partial(jet, comp)
    jet_rules[prim] = composed_rule


# Jet rules registered through def_comp: each primitive is rewritten as a
# composition of simpler operations.
def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)  # expm1(x) = exp(x) - 1
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))  # log1p(x) = log(1 + x)
def_comp(lax.sqrt_p, lambda x: x**0.5)
def_comp(lax.rsqrt_p, lambda x: x**-0.5)
# Inverse hyperbolic functions written through log and sqrt.
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))  # erfc(x) = 1 - erf(x)
# Remainder written as x - y * floor(x / y).
def_comp(lax.rem_p, lambda x, y: x - y * lax.floor(x / y))
# clamp(a, x, b): raise x to at least a, then cap it at b.
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))


def _erf_inv_rule(primals_in, series_in):
    x, = primals_in
    series, = series_in

    u = [x] + series
コード例 #27
0
ファイル: special.py プロジェクト: GregCT/jax
def _ndtri(p):
    """Implements ndtri core logic.

    Evaluates piecewise rational approximations in `p` and selects between
    them elementwise. Inputs with `p <= 0` map to `-inf` and inputs with
    `p >= 1` map to `+inf`.

    Args:
        p: array of probabilities; its dtype is used for all constants.

    Returns:
        Array with the shape of `p`.
    """

    # Constants used in piece-wise rational approximations. Taken from the cephes
    # library:
    # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    # Each list is reversed so that index 0 holds the last coefficient listed.
    p0 = list(
        reversed([
            -5.99633501014107895267E1, 9.80010754185999661536E1,
            -5.66762857469070293439E1, 1.39312609387279679503E1,
            -1.23916583867381258016E0
        ]))
    q0 = list(
        reversed([
            1.0, 1.95448858338141759834E0, 4.67627912898881538453E0,
            8.63602421390890590575E1, -2.25462687854119370527E2,
            2.00260212380060660359E2, -8.20372256168333339912E1,
            1.59056225126211695515E1, -1.18331621121330003142E0
        ]))
    p1 = list(
        reversed([
            4.05544892305962419923E0, 3.15251094599893866154E1,
            5.71628192246421288162E1, 4.40805073893200834700E1,
            1.46849561928858024014E1, 2.18663306850790267539E0,
            -1.40256079171354495875E-1, -3.50424626827848203418E-2,
            -8.57456785154685413611E-4
        ]))
    q1 = list(
        reversed([
            1.0, 1.57799883256466749731E1, 4.53907635128879210584E1,
            4.13172038254672030440E1, 1.50425385692907503408E1,
            2.50464946208309415979E0, -1.42182922854787788574E-1,
            -3.80806407691578277194E-2, -9.33259480895457427372E-4
        ]))
    p2 = list(
        reversed([
            3.23774891776946035970E0, 6.91522889068984211695E0,
            3.93881025292474443415E0, 1.33303460815807542389E0,
            2.01485389549179081538E-1, 1.23716634817820021358E-2,
            3.01581553508235416007E-4, 2.65806974686737550832E-6,
            6.23974539184983293730E-9
        ]))
    q2 = list(
        reversed([
            1.0, 6.02427039364742014255E0, 3.67983563856160859403E0,
            1.37702099489081330271E0, 2.16236993594496635890E-1,
            1.34204006088543189037E-2, 3.28014464682127739104E-4,
            2.89247864745380683936E-6, 6.79019408009981274425E-9
        ]))

    dtype = lax.dtype(p).type
    shape = jnp.shape(p)

    def _create_polynomial(var, coeffs):
        """Compute n_th order polynomial via Horner's method."""
        # coeffs[0] is the constant term; the recursion ends on an empty array.
        coeffs = np.array(coeffs, dtype)
        if not coeffs.size:
            return jnp.zeros_like(var)
        return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var

    # For p above 1 - exp(-2) work with the complement 1 - p; the sign of the
    # result is adjusted for this at the end.
    maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.)), dtype(1.) - p, p)
    # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
    # later on. The result from the computation when p == 0 is not used so any
    # number that doesn't result in NaNs is fine.
    sanitized_mcp = jnp.where(maybe_complement_p <= dtype(0.),
                              jnp.full(shape, dtype(0.5)), maybe_complement_p)

    # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
    w = sanitized_mcp - dtype(0.5)
    ww = lax.square(w)
    x_for_big_p = w + w * ww * (_create_polynomial(ww, p0) /
                                _create_polynomial(ww, q0))
    x_for_big_p *= -dtype(np.sqrt(2. * np.pi))

    # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
    # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
    # arrays based on whether p < exp(-32).
    z = lax.sqrt(dtype(-2.) * lax.log(sanitized_mcp))
    first_term = z - lax.log(z) / z
    second_term_small_p = (_create_polynomial(dtype(1.) / z, p2) /
                           _create_polynomial(dtype(1.) / z, q2) / z)
    second_term_otherwise = (_create_polynomial(dtype(1.) / z, p1) /
                             _create_polynomial(dtype(1.) / z, q1) / z)
    x_for_small_p = first_term - second_term_small_p
    x_otherwise = first_term - second_term_otherwise

    # z >= 8 is the same condition as p <= exp(-32).
    x = jnp.where(sanitized_mcp > dtype(np.exp(-2.)), x_for_big_p,
                  jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise))

    # Keep x where the complement was used; negate it everywhere else.
    x = jnp.where(p > dtype(1. - np.exp(-2.)), x, -x)
    infinity = jnp.full(shape, dtype(np.inf))
    # Overwrite the placeholder results at the boundaries with -inf / +inf.
    x_nan_replaced = jnp.where(p <= dtype(0.0), -infinity,
                               jnp.where(p >= dtype(1.0), infinity, x))
    return x_nan_replaced
コード例 #28
0
def exp2(x):
    """Compute `2 ** x` elementwise as `exp(log(2) * x)`."""
    x, = _promote_args_inexact("exp2", x)
    log_two = lax.log(_constant_like(x, 2))
    return lax.exp(lax.mul(log_two, x))
コード例 #29
0
ファイル: special.py プロジェクト: GregCT/jax
def logit(x):
    """Compute the logit `log(x / (1 - x))` elementwise."""
    x = asarray(x)
    complement = lax.sub(lax._const(x, 1), x)
    return lax.log(lax.div(x, complement))
コード例 #30
0
def log10(x):
    """Compute the base-10 logarithm elementwise as `log(x) / log(10)`."""
    x, = _promote_args_inexact("log10", x)
    natural_log = lax.log(x)
    log_of_ten = lax.log(_constant_like(x, 10))
    return lax.div(natural_log, log_of_ten)