Example #1
def logaddexp(x1, x2):
    x1, x2 = _promote_to_result_dtype(onp.logaddexp, *_promote_shapes(x1, x2))
    amax = lax.max(x1, x2)
    # Shift by the elementwise max so neither exponential can overflow:
    # log(e^x1 + e^x2) = amax + log(e^(x1 - amax) + e^(x2 - amax)).
    return lax.add(amax, lax.log(lax.add(lax.exp(lax.sub(x1, amax)),
                                         lax.exp(lax.sub(x2, amax)))))
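The shift by `amax` is what keeps the naive formula from overflowing. A minimal
sketch of the difference, using the public jax.numpy API rather than the
private helpers above:

import jax.numpy as jnp

# Naive evaluation overflows: exp(1000.) is inf in floating point.
naive = jnp.log(jnp.exp(1000.) + jnp.exp(1000.))  # inf
# The max-shifted form stays finite: 1000 + log(2) ~= 1000.693.
stable = jnp.logaddexp(1000., 1000.)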
Example #2
def _logaddexp(x1, x2):
  """
  Logaddexp while ignoring the custom_jvp rule.
  """
  amax = lax.max(x1, x2)
  delta = lax.sub(x1, x2)
  return lax.select(jnp.isnan(delta),
                    lax.add(x1, x2),  # NaNs or infinities of the same sign.
                    lax.add(amax, lax.log1p(lax.exp(-lax.abs(delta)))))
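A helper like this is typically paired with jax.custom_jvp, so the primal uses
the stable form while the tangent rule is written by hand. A minimal sketch
(illustrative, not necessarily JAX's exact rule):

import jax
from jax import lax

@jax.custom_jvp
def logaddexp(x1, x2):
  return _logaddexp(x1, x2)

@logaddexp.defjvp
def _logaddexp_jvp(primals, tangents):
  x1, x2 = primals
  t1, t2 = tangents
  primal_out = _logaddexp(x1, x2)
  # d logaddexp / d x1 = sigmoid(x1 - x2), and symmetrically for x2.
  w = jax.nn.sigmoid(lax.sub(x1, x2))
  return primal_out, w * t1 + (1. - w) * t2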
Example #3
def lr_schedule(itr):
    """
    Learning rate schedule.
    Slowly warm up with a small learning rate.
    """
    # Linear warm-up: ramp from ~0 to 1 over the first warmup_itrs steps.
    iter_frac = lax.min((itr.astype(jnp.float32) + 1.) / lax.max(parse_args.warmup_itrs, 1.), 1.)
    _epoch = itr // num_batches
    identity = lambda x: x
    # Legacy five-argument lax.cond: full rate before epoch 80, then a 10x decay.
    return lax.cond(_epoch < 80, parse_args.lr * iter_frac, identity, parse_args.lr / 10, identity)
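This snippet targets the legacy five-argument
lax.cond(pred, true_operand, true_fun, false_operand, false_fun) signature.
Under current JAX the same branch would read roughly as follows (a sketch
reusing the snippet's names):

return lax.cond(_epoch < 80,
                lambda _: parse_args.lr * iter_frac,  # warm-up / full rate
                lambda _: parse_args.lr / 10,         # 10x decay from epoch 80
                None)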
Example #4
def logaddexp(x1, x2):
  x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
  amax = lax.max(x1, x2)
  if dtypes.issubdtype(x1.dtype, np.floating):
    delta = lax.sub(x1, x2)
    return lax.select(lax_internal._isnan(delta),
                      lax.add(x1, x2),  # NaNs or infinities of the same sign.
                      lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta))))))
  else:
    # Complex case: log(e^x1 + e^x2) = amax + log1p(e^(x1 + x2 - 2*amax)),
    # with the imaginary part wrapped back into [-pi, pi].
    delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
    out = lax.add(amax, lax.log1p(lax.exp(delta)))
    return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
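_wrap_between is a private helper that wraps its argument into [-pi, pi] so
the imaginary part stays on the principal branch. A sketch consistent with how
it is used here (illustrative, not necessarily the library's exact code):

def _wrap_between(x, _a):
  """Wraps `x` into the interval [-a, a) (illustrative sketch)."""
  a = _constant_like(x, _a)
  two_a = _constant_like(x, 2 * _a)
  zero = _constant_like(x, 0)
  rem = lax.rem(lax.add(x, a), two_a)
  rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem)
  return lax.sub(rem, a)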
Example #5
def _nonconvex_max(bound_cls, index, lhs, rhs):
    """Propagation of NonConvex bounds through a max.

    Args:
      bound_cls: Bound class to use.
      index: Index of the computation.
      lhs: First input to the max, assumed to be a ReLU input.
      rhs: Second input to the max, assumed to be 0.
    Returns:
      out_bounds: NonConvexBound
    """
    if not (isinstance(lhs, NonConvexBound) and rhs == 0.):
        raise NotImplementedError('Only ReLU is implemented for now.')

    new_concretizer = lhs.concretizer.propagate_concretizer(
        lax.max_p, index, lhs, rhs)

    # Get the input bounds necessary for the relaxation
    inp_lb = lhs.lower
    inp_ub = lhs.upper

    # Get the upper bound
    relu_on = (inp_lb >= 0.)
    relu_amb = jnp.logical_and(inp_lb < 0., inp_ub >= 0.)
    slope = relu_on.astype(jnp.float32)
    slope += jnp.where(relu_amb, inp_ub / jnp.maximum(inp_ub - inp_lb, 1e-12),
                       jnp.zeros_like(inp_lb))
    offset = jnp.where(relu_amb, -slope * inp_lb, jnp.zeros_like(inp_lb))
    broad_slope = jnp.expand_dims(slope, 1)
    broad_offset = jnp.expand_dims(offset, 1)

    # Define the lower and upper bound functions
    lb_fun = lambda x: lax.max(x, 0.)
    ub_fun = lambda x: broad_slope * x + broad_offset

    return _activation_convex_relaxation(bound_cls, index, lhs,
                                         new_concretizer, 'ReLU', lb_fun,
                                         ub_fun)
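The slope/offset pair is the standard chord upper bound for an ambiguous ReLU:
on an interval [l, u] with l < 0 <= u, the line u * (x - l) / (u - l) matches
ReLU at both endpoints. A small hand check:

# Chord upper bound for ReLU on the ambiguous interval [l, u] = [-1., 3.]:
l, u = -1., 3.
slope = u / (u - l)   # 0.75
offset = -slope * l   # 0.75
assert slope * l + offset == 0.  # matches ReLU(-1) = 0
assert slope * u + offset == u   # matches ReLU(3)  = 3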
Example #6
def log_ndtr(x, series_order=3):
    r"""Log Normal distribution function.

  For details of the Normal distribution function see `ndtr`.

  This function calculates :math:`\log(\mathrm{ndtr}(x))` by either taking the
  log of :math:`\mathrm{ndtr}(x)` directly or by using an asymptotic series.
  Specifically:

  - For `x > upper_segment`, use the approximation `-ndtr(-x)` based on
    :math:`\log(1-x) \approx -x, x \ll 1`.
  - For `lower_segment < x <= upper_segment`, use the existing `ndtr` technique
    and take a log.
  - For `x <= lower_segment`, we use the series approximation of `erf` to compute
    the log CDF directly.

  The `lower_segment` is set based on the precision of the input:

  .. math::
    \begin{align}
    \mathit{lower\_segment} =&
      \ \begin{cases}
        -20 &  x.\mathrm{dtype}=\mathit{float64} \\
        -10 &  x.\mathrm{dtype}=\mathit{float32} \\
        \end{cases} \\
    \mathit{upper\_segment} =&
      \ \begin{cases}
        8&  x.\mathrm{dtype}=\mathit{float64} \\
        5&  x.\mathrm{dtype}=\mathit{float32} \\
        \end{cases}
    \end{align}


  When `x < lower_segment`, the `ndtr` asymptotic series approximation is:

  .. math::
    \begin{align}
     \mathrm{ndtr}(x) =&\  \mathit{scale} * (1 + \mathit{sum}) + R_N \\
     \mathit{scale}   =&\  \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\
     \mathit{sum}     =&\  \sum_{n=1}^N (-1)^n (2n-1)!! / (x^2)^n \\
     R_N     =&\  O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3})
    \end{align}

  where :math:`(2n-1)!! = (2n-1) (2n-3) (2n-5) ...  (3) (1)` is a
  `double-factorial
  <https://en.wikipedia.org/wiki/Double_factorial>`_ operator.


  Args:
    x: an array of type `float32`, `float64`.
    series_order: Positive Python integer. Maximum depth to
      evaluate the asymptotic expansion. This is the `N` above.

  Returns:
    an array with `dtype=x.dtype`.

  Raises:
    TypeError: if `x.dtype` is not handled.
    TypeError: if `series_order` is not a Python integer.
    ValueError: if `series_order` is not in `[0, 30]`.
  """
    if not isinstance(series_order, int):
        raise TypeError("series_order must be a Python integer.")
    if series_order < 0:
        raise ValueError("series_order must be non-negative.")
    if series_order > 30:
        raise ValueError("series_order must be <= 30.")

    x = jnp.asarray(x)
    dtype = lax.dtype(x)

    if dtype == jnp.float64:
        lower_segment = _LOGNDTR_FLOAT64_LOWER
        upper_segment = _LOGNDTR_FLOAT64_UPPER
    elif dtype == jnp.float32:
        lower_segment = _LOGNDTR_FLOAT32_LOWER
        upper_segment = _LOGNDTR_FLOAT32_UPPER
    else:
        raise TypeError("x.dtype={} is not supported.".format(np.dtype(dtype)))

    # The basic idea here was ported from:
    #   https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    # We copy the main idea, with a few changes
    # * For x >> 1, and X ~ Normal(0, 1),
    #     Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x],
    #     which extends the range of validity of this function.
    # * We use one fixed series_order for all of 'x', rather than adaptive.
    # * Our docstring properly reflects that this is an asymptotic series, not a
    #   Taylor series. We also provided a correct bound on the remainder.
    # * We need to use the max/min in the _log_ndtr_lower arg to avoid nan when
    #   x=0. This happens even though the branch is unchosen because when x=0
    #   the gradient of a select involves the calculation 1*dy+0*(-inf)=nan
    #   regardless of whether dy is finite. Note that the minimum is a NOP if
    #   the branch is chosen.
    return jnp.where(
        lax.gt(x, upper_segment),
        -_ndtr(-x),  # log(1-x) ~= -x, x << 1
        jnp.where(lax.gt(x, lower_segment),
                  lax.log(_ndtr(lax.max(x, lower_segment))),
                  _log_ndtr_lower(lax.min(x, lower_segment), series_order)))
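A quick way to see the three regimes (thresholds -10 and 5 for float32, per
the docstring), assuming this is the log_ndtr exposed as
jax.scipy.special.log_ndtr:

import jax.numpy as jnp
from jax.scipy.special import log_ndtr

x = jnp.array([-25., 0., 8.], dtype=jnp.float32)
# x = -25 -> asymptotic series; x = 0 -> log(ndtr(x)); x = 8 -> -ndtr(-x)
print(log_ndtr(x))  # roughly [-316.6, -0.693, -6.2e-16]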
Example #7
File: jet.py  Project: 0x0is1/jax
def def_comp(prim, comp):
    """
    Define the jet rule for a primitive in terms of a composition of simpler primitives.
    """
    jet_rules[prim] = partial(jet, comp)


def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x**0.5)
def_comp(lax.rsqrt_p, lambda x: x**-0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))
def_comp(lax.rem_p, lambda x, y: x - y * lax.floor(x / y))
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))
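These registrations let Taylor-mode differentiation reuse the jet rules of the
simpler primitives they are composed from. A minimal sketch of driving one of
them through jax.experimental.jet:

import jax.numpy as jnp
from jax.experimental import jet

# Propagate a first-order series through expm1, which is handled by the
# exp(x) - 1 composition registered above.
primal_out, series_out = jet.jet(jnp.expm1, (1.0,), ((1.0,),))
# primal_out ~= e - 1; series_out[0] ~= e (the derivative of expm1 at 1).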


def _erf_inv_rule(primals_in, series_in):
    x, = primals_in
    series, = series_in

    u = [x] + series
    primal_out = lax.erf_inv(x)
    v = [primal_out] + [None] * len(series)

    # derivative on co-domain for caching purposes
    deriv_const = np.sqrt(np.pi) / 2.
    deriv_y = lambda y: lax.mul(deriv_const, lax.exp(lax.square(y)))

    # manually propagate through deriv_y since we don't have lazy evaluation of sensitivities
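The co-domain trick relies on the identity
d/dx erf_inv(x) = (sqrt(pi) / 2) * exp(erf_inv(x)**2), i.e. the derivative
expressed in terms of y = erf_inv(x), which is exactly what deriv_y computes.
A quick check of that identity:

import jax
import jax.numpy as jnp
from jax import lax

x = 0.3
lhs = jax.grad(lax.erf_inv)(x)
rhs = jnp.sqrt(jnp.pi) / 2. * jnp.exp(lax.erf_inv(x) ** 2)
assert jnp.allclose(lhs, rhs)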
Example #8
def _warmup_sched(itr):
    itr_frac = lax.min(
        (itr.astype(jnp.float32) + 1.) / lax.max(warmup, 1.), 1.)
    _epoch = itr // nb
    identity = lambda x: x
    # Legacy five-argument lax.cond: full rate before epoch 55, then a 10x decay.
    return lax.cond(_epoch < 55, lr * itr_frac, identity, lr / 10, identity)
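For intuition, with assumed values warmup=5., nb=100 and lr=0.1 (these names
come from the enclosing scope and are not shown in the snippet), the schedule
evaluates as:

# itr=0:    itr_frac = min(1/5, 1) = 0.2     -> lr * 0.2 = 0.02
# itr=4:    itr_frac = min(5/5, 1) = 1.0     -> lr       = 0.10
# itr=6000: _epoch = 6000 // 100 = 60 >= 55  -> lr / 10  = 0.01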