示例#1
0
def gelu(x: Array, approximate: bool = True) -> Array:
    r"""Gaussian error linear unit activation function.

    If ``approximate=False``, computes the element-wise function:

    .. math::
      \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
        \frac{x}{\sqrt{2}} \right) \right)

    If ``approximate=True``, uses the approximate formulation of GELU:

    .. math::
      \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
        \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)

    For more information, see `Gaussian Error Linear Units (GELUs)
    <https://arxiv.org/abs/1606.08415>`_, section 2.

    Args:
      x: input array. Array-like values without a ``dtype`` attribute
        (Python scalars, lists) are converted with ``jnp.asarray``; integer
        and boolean inputs are promoted to the default floating-point dtype.
      approximate: whether to use the approximate or exact formulation.

    Returns:
      An array with the shape of ``x`` and the (promoted) dtype of ``x``.
    """
    # Coerce first so that Python scalars/lists expose ``.dtype`` below.
    x = jnp.asarray(x)
    # Casting the constants to an integer dtype would truncate sqrt(2/pi) to 0
    # and silently turn the approximate branch into ``x / 2``; promote instead.
    if not jnp.issubdtype(x.dtype, jnp.inexact):
        x = x.astype(jnp.result_type(float))

    if approximate:
        # Keep the constant in x's dtype so low-precision inputs are not upcast.
        sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
        cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x**3))))
        return x * cdf

    # np.sqrt(2) is a NumPy scalar and may widen the intermediate result, so
    # the output is cast back to the input dtype.
    exact = x * (lax.erf(x / np.sqrt(2)) + 1) / 2
    return jnp.asarray(exact, dtype=x.dtype)
示例#2
0
文件: special.py 项目: GregCT/jax
def _ndtr(x):
    """Core of the normal CDF: evaluates ``0.5 * (1 + erf(x / sqrt(2)))``.

    Inputs whose scaled magnitude is below ``sqrt(2) / 2`` use ``erf``
    directly; the rest are expressed through ``erfc`` of the magnitude, with
    the sign of the scaled input choosing between ``2 - erfc`` and ``erfc``.
    All constants are built in the scalar type of ``x``.
    """
    scalar = lax.dtype(x).type
    inv_sqrt_2 = scalar(0.5) * np.sqrt(2., dtype=scalar)

    scaled = x * inv_sqrt_2
    magnitude = lax.abs(scaled)

    # Central region: 1 + erf(w).
    central = scalar(1.) + lax.erf(scaled)

    # Outer region: pick the erfc form matching the sign of the scaled input.
    is_positive = lax.gt(scaled, scalar(0.))
    outer = lax.select(is_positive,
                       scalar(2.) - lax.erfc(magnitude),
                       lax.erfc(magnitude))

    in_central_region = lax.lt(magnitude, inv_sqrt_2)
    doubled_cdf = lax.select(in_central_region, central, outer)
    return scalar(0.5) * doubled_cdf
示例#3
0
文件: special.py 项目: GregCT/jax
def erf(x):
    """Return ``lax.erf`` applied to ``x`` after ``_promote_args_inexact``.

    ``_promote_args_inexact`` is called with the function name ``"erf"`` and
    returns a one-element sequence holding the processed argument.
    """
    (promoted,) = _promote_args_inexact("erf", x)
    return lax.erf(promoted)
示例#4
0
文件: jet.py 项目: 0x0is1/jax
def def_comp(prim, comp):
    """Register a jet rule for ``prim`` as a composition of simpler primitives.

    Args:
      prim: the primitive used as the key in ``jet_rules``.
      comp: a callable expressing ``prim`` through other operations; the
        stored rule is ``jet`` with ``comp`` bound as its first argument.
    """
    composed_rule = partial(jet, comp)
    jet_rules[prim] = composed_rule


# Primitives whose jet rule is obtained by rewriting them in terms of other
# lax operations. Entries are registered in the order listed.
for _prim, _comp in (
    (lax.expm1_p, lambda x: lax.exp(x) - 1),
    (lax.log1p_p, lambda x: lax.log(1 + x)),
    (lax.sqrt_p, lambda x: x**0.5),
    (lax.rsqrt_p, lambda x: x**-0.5),
    (lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1))),
    (lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1))),
    (lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x))),
    (lax.erfc_p, lambda x: 1 - lax.erf(x)),
    # NOTE(review): lax.rem is documented as taking the sign of the dividend,
    # whereas x - y * floor(x / y) follows the divisor; the two appear to
    # differ when the operands have opposite signs -- confirm this is intended.
    (lax.rem_p, lambda x, y: x - y * lax.floor(x / y)),
    (lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b)),
):
    def_comp(_prim, _comp)
# Do not leave the loop variables behind in the module namespace.
del _prim, _comp


def _erf_inv_rule(primals_in, series_in):
    """Jet rule for ``lax.erf_inv`` (only the setup portion is visible here).

    Args:
      primals_in: one-element sequence holding the primal input ``x``.
      series_in: one-element sequence holding the list of series terms for
        ``x``.

    NOTE(review): this excerpt ends after ``deriv_y`` is defined; the part
    that fills in ``v`` and returns a result is not shown -- confirm against
    the full source.
    """
    # Both arguments are expected to contain exactly one entry.
    x, = primals_in
    series, = series_in

    # Input coefficients: the primal followed by its series terms.
    u = [x] + series
    primal_out = lax.erf_inv(x)
    # Output coefficients: the primal output followed by one placeholder per
    # input series term.
    v = [primal_out] + [None] * len(series)

    # derivative on co-domain for caching purposes
    # deriv_y(y) evaluates sqrt(pi) / 2 * exp(y ** 2), i.e. it takes the
    # output value y rather than the input x.
    deriv_const = np.sqrt(np.pi) / 2.
    deriv_y = lambda y: lax.mul(deriv_const, lax.exp(lax.square(y)))
示例#5
0
文件: erf.py 项目: gglin001/onnx-jax
 def _erf(x):
     """Apply ``lax.erf`` to ``x`` and return the result."""
     erf_of_x = lax.erf(x)
     return erf_of_x
示例#6
0
文件: functions.py 项目: wade1990/jax
def gelu(x):
  """Gaussian error linear unit: ``x * (erf(x / sqrt(2)) + 1) / 2``."""
  erf_plus_one = lax.erf(x / np.sqrt(2)) + 1
  scaled = x * erf_plus_one
  return scaled / 2