Example #1
def logaddexp(x1, x2):
    x1, x2 = _promote_to_result_dtype(onp.logaddexp, *_promote_shapes(x1, x2))
    amax = lax.max(x1, x2)
    return lax.add(amax,
                   lax.log(lax.add(lax.exp(lax.sub(x1, amax)),
                                   lax.exp(lax.sub(x2, amax)))))
Example #2
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = jnp.broadcast_arrays(a, b)
    dims = _reduction_dims(a, axis)
    dimadd = lambda x: lax.expand_dims(x, dims)
    amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_singletons = dimadd(amax)
    if b is None:
        out = lax.add(
            lax.log(
                lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
                           _constant_like(a, 0), lax.add, dims)), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = lax.reduce(lax.mul(lax.exp(lax.sub(a, amax_singletons)), b),
                            _constant_like(a, 0), lax.add, dims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (dimadd(out), dimadd(sign)) if keepdims else (out, sign)
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    return dimadd(out) if keepdims else out
Example #3
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    if b is None:
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = jnp.sum(lax.mul(lax.exp(lax.sub(a, amax_with_dims)), b),
                         axis=dims,
                         keepdims=keepdims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    return out
Example #4
def cdf(x, loc=0, scale=1):
    x, loc, scale = _promote_args_inexact("laplace.cdf", x, loc, scale)
    half = _constant_like(x, 0.5)
    one = _constant_like(x, 1)
    zero = _constant_like(x, 0)
    diff = lax.div(lax.sub(x, loc), scale)
    return lax.select(lax.le(diff, zero), lax.mul(half, lax.exp(diff)),
                      lax.sub(one, lax.mul(half, lax.exp(lax.neg(diff)))))
Example #5
def logaddexp(x1, x2):
  x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
  amax = lax.max(x1, x2)
  if dtypes.issubdtype(x1.dtype, np.floating):
    delta = lax.sub(x1, x2)
    return lax.select(lax_internal._isnan(delta),
                      lax.add(x1, x2),  # NaNs or infinities of the same sign.
                      lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta))))))
  else:
    delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
    out = lax.add(amax, lax.log1p(lax.exp(delta)))
    return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
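A quick numeric check (not part of the source above, assuming only jax is installed): the naive log(exp(x1) + exp(x2)) overflows in the default float32, while jnp.logaddexp, which applies the same max-shift/log1p rewrite, stays finite.

import jax.numpy as jnp

x1, x2 = 1000.0, 1001.0
print(jnp.log(jnp.exp(x1) + jnp.exp(x2)))   # inf: exp overflows before the log
print(jnp.logaddexp(x1, x2))                # ~1001.3133: amax + log1p(exp(-|delta|))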
Example #6
def exponential(tensor, dtype, exp_hparams):
    """Calculates an exponential approximation based on exp hyper params."""
    # If low_bound is defined, clip x - M to the range [low_bound, 0].
    if exp_hparams.low_bound != 0:
        tensor = jnp.clip(tensor, exp_hparams.low_bound, 0.)

    # TODO(luispazos) Use standard calls to top level jnp functions.
    # pylint: disable=protected-access
    def make_constant(c):
        return lax_numpy._constant_like(tensor, c).astype(dtype)

    # If clip_and_subtract, replace the exp(clip(x - M, low_bound)) term with
    # exp(clip(x - M, low_bound)) - exp(low_bound).
    if exp_hparams.clip_and_subtract:
        tensor = lax.sub(tensor, make_constant(onp.exp(exp_hparams.low_bound)))
    # If linear_gradient is set, use it as the slope of a linear approximation
    # of the exponential.
    if exp_hparams.linear_gradient is not None and exp_hparams.linear_gradient != 0:
        # Want: max(0, a*x+b) such that a*x+b goes through (0, 1).
        #
        # This comes out to: max(0, a*x+1), for arbitrary a>0.
        one = jnp.full(tensor.shape, 1.).astype(dtype)
        gradient = jnp.full(tensor.shape,
                            exp_hparams.linear_gradient).astype(dtype)
        approx_exp = jnp.clip(lax.add(lax.mul(tensor, gradient), one), 0, 1)

    else:
        approx_exp = lax.exp(tensor)

    return approx_exp
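The linear_gradient branch above approximates exp(x) on x <= 0 by the clipped line a*x + 1, which passes through (0, 1) like the exponential. A hedged standalone sketch of just that approximation (the name linear_exp_approx and the plumbing-free form are illustrative, not part of the project):

import jax.numpy as jnp

def linear_exp_approx(x, gradient):
    # clip(a*x + 1, 0, 1): equals exp(x) at x = 0 and decays linearly for x < 0.
    return jnp.clip(gradient * x + 1.0, 0.0, 1.0)

x = jnp.linspace(-3.0, 0.0, 5)
print(linear_exp_approx(x, 0.5))
print(jnp.exp(x))   # the curve being approximated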
Example #7
def log1m_exp(val):
    """Numerically stable implementation of `log(1 - exp(val))`."""
    return lax.cond(
        lax.gt(val, -lax.log(2.0)),  # split at -log(2): expm1 branch near zero, log1p branch otherwise
        lambda _: lax.log(-lax.expm1(val)),
        lambda _: lax.log1p(-lax.exp(val)),
        operand=None,
    )
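A hedged, self-contained restatement of the split above (same -log(2) threshold, written with jnp ops), showing where each branch wins numerically:

import jax.numpy as jnp
from jax import lax

def log1m_exp(val):
    # log(1 - exp(val)) for val <= 0, split at -log(2) (Maechler, 2012).
    return lax.cond(
        val > -jnp.log(2.0),
        lambda v: jnp.log(-jnp.expm1(v)),   # accurate for val near zero
        lambda v: jnp.log1p(-jnp.exp(v)),   # accurate for very negative val
        val,
    )

print(log1m_exp(-1e-8))                  # ~ -18.42
print(jnp.log(1.0 - jnp.exp(-1e-8)))     # -inf in float32: exp(-1e-8) rounds to 1
print(log1m_exp(-50.0))                  # ~ -1.9e-22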
Example #8
def _logaddexp(x1, x2):
  """
  Logaddexp while ignoring the custom_jvp rule.
  """
  amax = lax.max(x1, x2)
  delta = lax.sub(x1, x2)
  return lax.select(jnp.isnan(delta),
                    lax.add(x1, x2),  # NaNs or infinities of the same sign.
                    lax.add(amax, lax.log1p(lax.exp(-lax.abs(delta)))))
Example #9
def _exp_taylor(primals_in, series_in):
  x, = primals_in
  series, = series_in
  u = [x] + series
  v = [lax.exp(x)] + [None] * len(series)
  for k in range(1,len(v)):
    v[k] = fact(k-1) * sum([_scale(k, j) * v[k-j] * u[j] for j in range(1, k+1)])
  primal_out, *series_out = v
  return primal_out, series_out
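The loop above propagates a truncated Taylor series through exp (differentiating exp(x(t)) gives exp(x) * x', which the convolution with the fact/_scale bookkeeping implements). A hedged sketch of exercising this rule through the public jax.experimental.jet API, which dispatches to per-primitive rules like this one:

import jax.numpy as jnp
from jax.experimental.jet import jet

# Third-order Taylor-mode differentiation of exp at x0 in the direction (1, 0, 0).
x0 = 1.0
primal_out, series_out = jet(jnp.exp, (x0,), ((1.0, 0.0, 0.0),))
print(primal_out)    # exp(1.0)
print(series_out)    # higher-order coefficients produced by the rule above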
Example #10
File: jet.py Project: 0x0is1/jax
def _erf_inv_rule(primals_in, series_in):
    x, = primals_in
    series, = series_in

    u = [x] + series
    primal_out = lax.erf_inv(x)
    v = [primal_out] + [None] * len(series)

    # derivative on co-domain for caching purposes
    deriv_const = np.sqrt(np.pi) / 2.
    deriv_y = lambda y: lax.mul(deriv_const, lax.exp(lax.square(y)))

    # manually propagate through deriv_y since we don't have lazy evaluation of sensitivities

    c = [deriv_y(primal_out)] + [None] * (len(series) - 1)
    tmp_sq = [lax.square(v[0])] + [None] * (len(series) - 1)
    tmp_exp = [lax.exp(tmp_sq[0])] + [None] * (len(series) - 1)
    for k in range(1, len(series)):
        # we know c[:k], we compute c[k]

        # propagate c to get v
        v[k] = fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))

        # propagate v to get next c

        # square
        tmp_sq[k] = fact(k) * sum(
            _scale2(k, j) * v[k - j] * v[j] for j in range(k + 1))

        # exp
        tmp_exp[k] = fact(k - 1) * sum(
            _scale(k, j) * tmp_exp[k - j] * tmp_sq[j] for j in range(1, k + 1))

        # const
        c[k] = deriv_const * tmp_exp[k]

    # we can't, and don't need to, compute c[k+1]; we just need the last v[k]
    k = len(series)
    v[k] = fact(k - 1) * sum(
        _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))

    primal_out, *series_out = v
    return primal_out, series_out
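For reference, deriv_y above is the inverse-function rule: since erf'(y) = (2 / sqrt(pi)) * exp(-y^2), the derivative of erf_inv is (sqrt(pi) / 2) * exp(erf_inv(x)^2). A small hedged check against jax.grad:

import jax
import jax.numpy as jnp
from jax import lax

x = 0.3
autodiff = jax.grad(lax.erf_inv)(x)
closed_form = jnp.sqrt(jnp.pi) / 2.0 * jnp.exp(jnp.square(lax.erf_inv(x)))
print(autodiff, closed_form)   # the two should agree to float precision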
Example #11
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        a, = _promote_args_inexact("logsumexp", a)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    amax = lax.stop_gradient(
        lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    # fast path if the result cannot be negative.
    if b is None and not np.issubdtype(a.dtype, np.complexfloating):
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        sign = jnp.where(jnp.isnan(out), out, 1.0)
        sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
    else:
        expsub = lax.exp(lax.sub(a, amax_with_dims))
        if b is not None:
            expsub = lax.mul(expsub, b)
        sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)

        sign = lax.stop_gradient(jnp.sign(sumexp))
        if np.issubdtype(sumexp.dtype, np.complexfloating):
            if return_sign:
                sumexp = sign * sumexp
            out = lax.add(lax.log(sumexp), amax)
        else:
            out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        if not np.issubdtype(out.dtype, np.complexfloating):
            with jax.debug_nans(False):
                out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype),
                                out)
    return out
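A hedged usage sketch of the sign handling above, via the public jax.scipy.special.logsumexp (whose behavior matches the code shown here): when the weights b make the weighted sum of exponentials negative, the default call returns NaN, while return_sign=True returns log|sum| together with its sign.

import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([1.0, 1.0, -5.0])      # sum_i b_i * exp(a_i) is negative

out, sign = logsumexp(a, b=b, return_sign=True)
print(out, sign)                     # log|sum|, sign == -1.0
print(logsumexp(a, b=b))             # nan: the unsigned log of a negative sum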
Example #12
File: poisson.py Project: lmmx/mcx
def _random_poisson(rng_key, lmbda, shape):
    """
    References
    ----------
    .. [1] Knuth, Donald E. Art of computer programming, volume 2:
           Seminumerical algorithms. Addison-Wesley Professional, 2014 (p 137).
    """
    L = lax.exp(lax.neg(lmbda))
    k = np.zeros(shape=shape)
    p = np.ones(shape=shape)

    is_done = p < L
    while not is_done.all():
        _, rng_key = random.split(rng_key)
        u = random.uniform(rng_key, shape=shape)
        p = np.where(is_done, p, u * p)
        k = np.where(is_done, k, k + 1)
        is_done = p < L

    return k - 1  # Knuth's algorithm returns one less than the number of uniforms drawn
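A hedged usage sketch, assuming the module imports used by the function above (`import jax.numpy as np`, `from jax import lax, random`); called eagerly, the Python while loop runs on concrete arrays:

import jax.numpy as np
from jax import random

rng_key = random.PRNGKey(0)
samples = _random_poisson(rng_key, np.asarray(3.0), shape=(5,))
print(samples)   # five draws whose long-run mean is lmbda = 3.0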
Example #13
File: jet.py Project: 0x0is1/jax
def def_deriv(prim, deriv):
    """
    Define the jet rule for a primitive in terms of its first derivative.
    """
    jet_rules[prim] = partial(deriv_prop, prim, deriv)


def deriv_prop(prim, deriv, primals_in, series_in):
    x, = primals_in
    series, = series_in
    primal_out = prim.bind(x)
    c0, cs = jet(deriv, primals_in, series_in)
    c = [c0] + cs
    u = [x] + series
    v = [primal_out] + [None] * len(series)
    for k in range(1, len(v)):
        v[k] = fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))
    primal_out, *series_out = v
    return primal_out, series_out


def_deriv(
    lax.erf_p, lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)),
                                 lax.exp(lax.neg(lax.square(x)))))


def def_comp(prim, comp):
    """
    Define the jet rule for a primitive in terms of a composition of simpler primitives.
    """
    jet_rules[prim] = partial(jet, comp)


def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x**0.5)
def_comp(lax.rsqrt_p, lambda x: x**-0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
Example #14
def pdf(x, mean, cov):
    return lax.exp(logpdf(x, mean, cov))
Example #15
def dot_product_attention(query,
                          key,
                          value,
                          dtype=jnp.float32,
                          bias=None,
                          axis=None,
                          broadcast_dropout=True,
                          dropout_rng=None,
                          dropout_rate=0.,
                          deterministic=False,
                          precision=None):
    """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights. This
  function supports multi-dimensional inputs. This version is modified to
  move the softmax division after the dot product.


  Args:
    query: queries for calculating attention with shape of `[batch_size, dim1,
      dim2, ..., dimN, num_heads, mem_channels]`.
    key: keys for calculating attention with shape of `[batch_size, dim1, dim2,
      ..., dimN, num_heads, mem_channels]`.
    value: values to be used in attention with shape of `[batch_size, dim1,
      dim2,..., dimN, num_heads, value_channels]`.
    dtype: the dtype of the computation (default: float32)
    bias: bias for the attention weights. This can be used for incorporating
      autoregressive mask, padding mask, proximity bias.
    axis: axes over which the attention is applied.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool, if True, do not apply dropout.
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.

  Returns:
    Output of shape `[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`.
  """
    assert key.shape[:-1] == value.shape[:-1]
    assert (query.shape[0:1] == key.shape[0:1]
            and query.shape[-1] == key.shape[-1])

    if axis is None:
        axis = tuple(range(1, key.ndim - 2))
    if not isinstance(axis, Iterable):
        axis = (axis, )
    assert key.ndim == query.ndim
    assert key.ndim == value.ndim
    for ax in axis:
        if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
            raise ValueError('Attention axis must be between the batch '
                             'axis and the last-two axes.')
    depth = query.shape[-1]
    n = key.ndim
    # batch_dims is <bs, <non-attention dims>, num_heads>
    batch_dims = tuple(np.delete(range(n), axis + (n - 1, )))
    # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
    qk_perm = batch_dims + axis + (n - 1, )
    key = key.transpose(qk_perm)
    query = query.transpose(qk_perm)
    # v -> (bs, <non-attention dims>, num_heads, channels, <attention dims>)
    v_perm = batch_dims + (n - 1, ) + axis
    value = value.transpose(v_perm)

    query = query / jnp.sqrt(depth).astype(dtype)
    batch_dims_t = tuple(range(len(batch_dims)))
    attn_weights = lax.dot_general(query,
                                   key, (((n - 1, ), (n - 1, )),
                                         (batch_dims_t, batch_dims_t)),
                                   precision=precision)

    # apply attention bias: masking, dropout, proximity bias, etc.
    if bias is not None:
        attn_weights = attn_weights + bias

    # normalize the attention weights
    norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim))
    decoding = attn_weights.shape[-2] != 256
    if decoding:
        attn_weights = lax.exp(attn_weights - jax.scipy.special.logsumexp(
            attn_weights, axis=norm_dims, keepdims=True))
    else:
        # move the division by the softmax denominator to after the dot product
        attn_weights = jnp.exp(attn_weights - lax.stop_gradient(
            jnp.max(attn_weights, axis=norm_dims, keepdims=True)))
        softmax_denominator = jnp.sum(attn_weights,
                                      axis=norm_dims,
                                      keepdims=False)
    attn_weights = attn_weights.astype(dtype)

    # apply dropout
    if not deterministic and dropout_rate > 0.:
        if dropout_rng is None:
            dropout_rng = nn.make_rng()
        keep_prob = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate)
        if broadcast_dropout:
            # dropout is broadcast across the batch+head+non-attention dimension
            dropout_dims = attn_weights.shape[-(2 * len(axis)):]
            dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims)
            keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        else:
            keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
        multiplier = (keep.astype(attn_weights.dtype) /
                      jnp.asarray(keep_prob, dtype=dtype))
        attn_weights = attn_weights * multiplier

    # compute the new values given the attention weights
    wv_contracting_dims = (norm_dims, range(value.ndim - len(axis),
                                            value.ndim))
    y = lax.dot_general(attn_weights,
                        value,
                        (wv_contracting_dims, (batch_dims_t, batch_dims_t)),
                        precision=precision)
    if not decoding:
        # divide by the denominator of the attention softmax now, when the array is
        # O(N*H) rather than O(N^2)
        y = y / jnp.expand_dims(softmax_denominator, -1)

    # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
    perm_inv = _invert_perm(qk_perm)
    y = y.transpose(perm_inv)
    return y
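A hedged usage sketch. It assumes the module-level imports used above (jax, numpy as np, jax.numpy as jnp, and from jax import lax, random) and supplies a stand-in for the `_invert_perm` helper the function calls at the end; dropout stays off (deterministic=True), so the nn.make_rng / lax.tie_in branch is never reached.

import numpy as np
import jax.numpy as jnp
from jax import random

def _invert_perm(perm):
    # Stand-in helper: inverse of an axis permutation.
    return tuple(np.argsort(perm))

q_key, k_key, v_key = random.split(random.PRNGKey(0), 3)
batch, length, heads, channels = 2, 16, 4, 32
query = random.normal(q_key, (batch, length, heads, channels))
key = random.normal(k_key, (batch, length, heads, channels))
value = random.normal(v_key, (batch, length, heads, channels))

out = dot_product_attention(query, key, value, deterministic=True)
print(out.shape)   # (2, 16, 4, 32)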
Example #16
def i1(x):
    x, = _promote_args_inexact("i1", x)
    return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i1e(x))
Example #17
def pdf(x, b, loc=0, scale=1):
    return lax.exp(logpdf(x, b, loc, scale))
Example #18
def pdf(x, p):
    return lax.exp(logpdf(x, p))
Example #19
def cosh(x):
    x, = _promote_to_result_dtype(onp.cosh, x)
    return lax.div(lax.add(lax.exp(x), lax.exp(lax.neg(x))),
                   _constant_like(x, 2))
Example #20
def sinh(x):
    x, = _promote_to_result_dtype(onp.sinh, x)
    return lax.div(lax.sub(lax.exp(x), lax.exp(lax.neg(x))),
                   _constant_like(x, 2))
Example #21
def _exp(x):
    return lax.exp(x)
Example #22
def unquantized_softmax(a):
    a = lax.exp(
        a - jax.scipy.special.logsumexp(a, axis=norm_dims, keepdims=True))
    return a.astype(dtype)
Example #23
File: jet.py Project: 0x0is1/jax
def fact(n):
    return lax.exp(lax.lgamma(n + 1.))
Example #24
def exp2(x):
    x, = _promote_args_inexact("exp2", x)
    return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x))
Example #25
def expit(x):
    x, = _promote_args_inexact("expit", x)
    one = _lax_const(x, 1)
    return lax.div(one, lax.add(one, lax.exp(lax.neg(x))))
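expit is the logistic sigmoid 1 / (1 + exp(-x)); a small hedged check against jax.nn.sigmoid, which computes the same function and should agree closely:

import jax.numpy as jnp
from jax import nn
from jax.scipy.special import expit

x = jnp.linspace(-5.0, 5.0, 5)
print(expit(x))
print(nn.sigmoid(x))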
Example #26
def pmf(k, n, a, b, loc=0):
    """JAX implementation of scipy.stats.betabinom.pmf."""
    return lax.exp(logpmf(k, n, a, b, loc))
Example #27
def _log_ndtr_jvp(series_order, primals, tangents):
    (x, ), (t, ) = primals, tangents
    ans = log_ndtr(x, series_order=series_order)
    t_out = lax.mul(t, lax.exp(lax.sub(_norm_logpdf(x), ans)))
    return ans, t_out
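The tangent above is t * d/dx log_ndtr(x), using d/dx log Phi(x) = phi(x) / Phi(x) = exp(norm_logpdf(x) - log_ndtr(x)). A hedged check of that identity against jax.grad on the public jax.scipy.special.log_ndtr:

import jax
import jax.numpy as jnp
from jax.scipy.special import log_ndtr
from jax.scipy.stats import norm

x = 0.5
autodiff = jax.grad(log_ndtr)(x)
closed_form = jnp.exp(norm.logpdf(x) - log_ndtr(x))
print(autodiff, closed_form)   # the two should agree to float precision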
Example #28
def pdf(x, alpha):
  return lax.exp(logpdf(x, alpha))
Example #29
def expit(x):
    x = asarray(x)
    one = lax._const(x, 1)
    return lax.div(one, lax.add(one, lax.exp(lax.neg(x))))
Example #30
def pdf(x):
    return lax.exp(logpdf(x))