Example #1
    def sum_first_n(arr, num):
      """Add up the leading ``num`` entries of ``arr`` using ``lax.fori_loop``.

      The loop carry is the 3-tuple ``(arr, running_total, ())``; the trailing
      empty tuple is passed through every iteration unchanged.  The trip count
      is ``lax.min(arr.shape[0], num)``, so a ``num`` larger than the leading
      dimension just sums the whole array.
      """
      def accumulate(idx, carry):
        values, running, _ = carry
        # Read element `idx` along axis 0 and drop the indexed dimension.
        element = lax.dynamic_index_in_dim(values, idx, axis=0, keepdims=False)
        return (values, lax.add(running, element), ())

      trip_count = lax.min(arr.shape[0], num)
      carry_out = lax.fori_loop(0, trip_count, accumulate, (arr, 0., ()))
      return carry_out[1]
Example #2
    def sum_first_n(arr, num):
      """Add up the leading ``num`` entries of ``arr`` using ``lax.fori_loop``.

      The loop carry is a dict with the keys ``'arr'`` (the input, passed
      through unchanged) and ``'total'`` (the running sum).  The trip count is
      ``lax.min(arr.shape[0], num)``, so a ``num`` larger than the leading
      dimension just sums the whole array.
      """
      def accumulate(idx, carry):
        values = carry['arr']
        # Read element `idx` along axis 0 and drop the indexed dimension.
        element = lax.dynamic_index_in_dim(values, idx, axis=0, keepdims=False)
        return {'arr': values, 'total': lax.add(carry['total'], element)}

      trip_count = lax.min(arr.shape[0], num)
      final_carry = lax.fori_loop(
          0, trip_count, accumulate, {'arr': arr, 'total': 0.})
      return final_carry['total']
Example #3
 def lr_schedule(itr):
     """Learning rate schedule.

     Slowly warm-up with a small learning rate.

     While the epoch (``itr // num_batches``) is below 80 the result is
     ``parse_args.lr`` scaled by the warm-up fraction; afterwards it is
     ``parse_args.lr / 10``.

     Args:
       itr: iteration counter; must provide ``.astype`` and support ``//``.

     Returns:
       The learning rate selected by ``lax.cond`` for this iteration.
     """
     # Warm-up fraction (itr + 1) / warmup_itrs, capped at 1.  The lax.max
     # keeps the divisor at 1 or above.
     iter_frac = lax.min((itr.astype(jnp.float32) + 1.) / lax.max(parse_args.warmup_itrs, 1.), 1.)
     _epoch = itr // num_batches

     def _identity(x):
         return x

     # This uses the five-argument calling convention
     # cond(pred, true_operand, true_fun, false_operand, false_fun).
     # NOTE(review): newer JAX releases document
     # cond(pred, true_fun, false_fun, *operands) instead -- confirm that the
     # pinned JAX version still accepts this form.
     return lax.cond(_epoch < 80, parse_args.lr * iter_frac, _identity, parse_args.lr / 10, _identity)
Example #4
def log_ndtr(x, series_order=3):
    r"""Log Normal distribution function.

    For details of the Normal distribution function see `ndtr`.

    This function calculates :math:`\log(\mathrm{ndtr}(x))` by either calling
    :math:`\log(\mathrm{ndtr}(x))` or using an asymptotic series. Specifically:

    - For `x > upper_segment`, use the approximation `-ndtr(-x)` based on
      :math:`\log(1-x) \approx -x, x \ll 1`.
    - For `lower_segment < x <= upper_segment`, use the existing `ndtr` technique
      and take a log.
    - For `x <= lower_segment`, we use the series approximation of `erf` to compute
      the log CDF directly.

    The `lower_segment` and `upper_segment` are set based on the precision of
    the input:

    .. math::
      \mathit{lower\_segment} =&
        \ \begin{cases}
          -20 &  x.\mathrm{dtype}=\mathit{float64} \\
          -10 &  x.\mathrm{dtype}=\mathit{float32} \\
          \end{cases} \\
      \mathit{upper\_segment} =&
        \ \begin{cases}
          8&  x.\mathrm{dtype}=\mathit{float64} \\
          5&  x.\mathrm{dtype}=\mathit{float32} \\
          \end{cases}

    When `x < lower_segment`, the `ndtr` asymptotic series approximation is:

    .. math::
       \mathrm{ndtr}(x) =&\  \mathit{scale} * (1 + \mathit{sum}) + R_N \\
       \mathit{scale}   =&\  \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\
       \mathit{sum}     =&\  \sum_{n=1}^N (-1)^n (2n-1)!! / (x^2)^n \\
       R_N     =&\  O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3})

    where :math:`(2n-1)!! = (2n-1) (2n-3) (2n-5) ...  (3) (1)` is a
    `double-factorial
    <https://en.wikipedia.org/wiki/Double_factorial>`_ operator.

    Args:
      x: an array of type `float32`, `float64`.
      series_order: Positive Python integer. Maximum depth to
        evaluate the asymptotic expansion. This is the `N` above.

    Returns:
      an array with `dtype=x.dtype`.

    Raises:
      TypeError: if `x.dtype` is not handled.
      TypeError: if `series_order` is not a Python `integer`.
      ValueError:  if `series_order` is not in `[0, 30]`.
    """
    if not isinstance(series_order, int):
        raise TypeError("series_order must be a Python integer.")
    if series_order < 0:
        raise ValueError("series_order must be non-negative.")
    if series_order > 30:
        raise ValueError("series_order must be <= 30.")

    x = jnp.asarray(x)
    dtype = lax.dtype(x)

    # Pick the segment boundaries that match the input precision; any other
    # dtype is rejected explicitly.
    if dtype == jnp.float64:
        lower_segment = _LOGNDTR_FLOAT64_LOWER
        upper_segment = _LOGNDTR_FLOAT64_UPPER
    elif dtype == jnp.float32:
        lower_segment = _LOGNDTR_FLOAT32_LOWER
        upper_segment = _LOGNDTR_FLOAT32_UPPER
    else:
        raise TypeError("x.dtype={} is not supported.".format(np.dtype(dtype)))

    # The basic idea here was ported from:
    #   https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    # We copy the main idea, with a few changes
    # * For x >> 1, and X ~ Normal(0, 1),
    #     Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x],
    #     which extends the range of validity of this function.
    # * We use one fixed series_order for all of 'x', rather than adaptive.
    # * Our docstring properly reflects that this is an asymptotic series, not a
    #   Taylor series. We also provided a correct bound on the remainder.
    # * We need to use the max/min in the _log_ndtr_lower arg to avoid nan when
    #   x=0. This happens even though the branch is unchosen because when x=0
    #   the gradient of a select involves the calculation 1*dy+0*(-inf)=nan
    #   regardless of whether dy is finite. Note that the minimum is a NOP if
    #   the branch is chosen.
    return jnp.where(
        lax.gt(x, upper_segment),
        -_ndtr(-x),  # log(1-x) ~= -x, x << 1
        jnp.where(lax.gt(x, lower_segment),
                  lax.log(_ndtr(lax.max(x, lower_segment))),
                  _log_ndtr_lower(lax.min(x, lower_segment), series_order)))
Example #5
# File: jet.py Project: 0x0is1/jax
def def_comp(prim, comp):
    """Define the jet rule for a primitive in terms of a composition of simpler primitives.

    Args:
      prim: the primitive whose jet rule is being registered.
      comp: callable that is bound as the first argument of ``jet`` via
        ``functools.partial``.
    """
    jet_rules[prim] = partial(jet, comp)

# Register jet rules for primitives by rewriting each one as a composition of
# other operations.  Each lambda takes the same positional arguments as the
# primitive it stands in for.

# expm1(x) = exp(x) - 1
def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
# log1p(x) = log(1 + x)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
# sqrt(x) = x ** 0.5
def_comp(lax.sqrt_p, lambda x: x**0.5)
# rsqrt(x) = x ** -0.5
def_comp(lax.rsqrt_p, lambda x: x**-0.5)
# asinh(x) = log(x + sqrt(x^2 + 1))
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
# acosh(x) = log(x + sqrt(x^2 - 1))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
# atanh(x) = 0.5 * log((1 + x) / (1 - x))
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
# erfc(x) = 1 - erf(x)
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))
# Remainder written as x - y * floor(x / y).
# NOTE(review): this is the floored-division remainder, whereas lax.rem is
# documented as taking the sign of the dividend; the two may differ when
# x / y < 0 -- confirm this is intended.
def_comp(lax.rem_p, lambda x, y: x - y * lax.floor(x / y))
# clamp(a, x, b) = min(max(a, x), b)
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))

def _erf_inv_rule(primals_in, series_in):
    """Set up the series propagation for ``lax.erf_inv``.

    NOTE(review): only the setup portion of this function is visible here; the
    propagation step announced by the final comment and any return statement
    are not shown.  Presumably this is the jet rule registered for
    ``lax.erf_inv_p`` -- confirm against the full file.

    Args:
      primals_in: one-element sequence holding the primal input ``x``.
      series_in: one-element sequence holding the list of series terms for
        ``x``.
    """
    # Both inputs are unpacked as single-element sequences.
    x, = primals_in
    series, = series_in

    # u: the primal input followed by its series terms.
    u = [x] + series
    primal_out = lax.erf_inv(x)
    # v: the primal output followed by one None placeholder per series term.
    v = [primal_out] + [None] * len(series)

    # derivative on co-domain for caching purposes
    # deriv_y(y) evaluates sqrt(pi) / 2 * exp(y ** 2).
    deriv_const = np.sqrt(np.pi) / 2.
    deriv_y = lambda y: lax.mul(deriv_const, lax.exp(lax.square(y)))

    # manually propagate through deriv_y since we don't have lazy evaluation of sensitivities
Example #6
def linear_warmup(warmup_iters):
    """Build a linear warm-up schedule.

    Args:
      warmup_iters: divisor applied to the iteration index.

    Returns:
      A function of the iteration index ``i`` that evaluates
      ``lax.min(1., i / warmup_iters)``, i.e. a ramp capped at 1.
    """
    def schedule(i):
        progress = i / warmup_iters
        return lax.min(1., progress)

    return schedule
Example #7
 def _warmup_sched(itr):
     """Warm-up learning-rate schedule followed by a step down.

     While the epoch (``itr // nb``) is below 55 the result is ``lr`` scaled by
     the warm-up fraction ``(itr + 1) / max(warmup, 1)``, capped at 1;
     afterwards it is ``lr / 10``.  ``warmup``, ``nb`` and ``lr`` are read from
     the enclosing scope.
     """
     step_number = itr.astype(jnp.float32) + 1.
     itr_frac = lax.min(step_number / lax.max(warmup, 1.), 1.)
     in_initial_phase = (itr // nb) < 55

     def passthrough(value):
         return value

     return lax.cond(
         in_initial_phase,
         lr * itr_frac, passthrough,
         lr / 10, passthrough)