コード例 #1
0
def _normalize_float(x):
  """Rescale tiny floats so their bit pattern carries a normal exponent.

  Elements whose magnitude is below ``finfo.tiny`` (subnormals, and zero) are
  multiplied by ``2**nmant``. The matching entry of the returned exponent
  offset is ``-nmant`` for those elements and 0 everywhere else, so a caller
  can undo the scaling.

  Returns:
    ``(bits, exponent_offset)``: the possibly-rescaled values bitcast to the
    same-width integer dtype, and an int32 array of exponent offsets.
  """
  finfo = dtypes.finfo(dtypes.dtype(x))
  is_tiny = lax.abs(x) < finfo.tiny
  scale = _lax_const(x, 1 << finfo.nmant)
  rescaled = _where(is_tiny, x * scale, x)
  offset_if_tiny = lax.full_like(x, -finfo.nmant, dtype=np.int32)
  offset_otherwise = lax.full_like(x, 0, dtype=np.int32)
  exponent_offset = _where(is_tiny, offset_if_tiny, offset_otherwise)
  bits = lax.bitcast_convert_type(rescaled, _INT_DTYPES[finfo.bits])
  return bits, exponent_offset
コード例 #2
0
def _sinc_maclaurin(k, x):
  """Return the k-th derivative of ``sin(x)/x`` at zero, shaped like ``x``.

  Used for the monomial term in the JVP rule. The Maclaurin series of
  ``sin(x)/x`` contains only even powers. Odd-order derivatives therefore
  vanish, and the even-order ones equal ``(-1)**(k/2) / (k + 1)``.
  """
  if k % 2 == 0:
    value = (-1) ** (k // 2) / (k + 1)
  else:
    value = 0
  return lax.full_like(x, value)
コード例 #3
0
ファイル: jet.py プロジェクト: MichaelMarien/jax
def _abs_taylor_rule(x, series_in, **params):
  """Taylor (jet) rule for ``abs``.

  Away from zero, ``abs`` is locally linear: ``abs(x) = s * x`` with
  ``s = -1`` where ``x < 0`` and ``s = 1`` otherwise. Every higher-order
  Taylor coefficient is therefore the input coefficient multiplied by ``s``.

  Args:
    x: one-element sequence holding the primal input.
    series_in: one-element sequence holding the list of input Taylor terms.
    **params: forwarded to ``abs_p.bind`` only.

  Returns:
    ``(primal_out, series_out)``, where ``series_out`` has one term per input
    term.
  """
  x, = x
  terms, = series_in
  zero = lax.full_like(x, 0, shape=())
  primal_out = lax.abs_p.bind(x, **params)
  # +1 is used at exactly x == 0, where abs is not differentiable.
  signs = lax.select(lax.lt(x, zero), lax.full_like(x, -1),
                     lax.full_like(x, 1))
  # Scaling a term by the sign is a plain multiply. The primitive's params
  # must not be forwarded here (a one-argument callable cannot take them).
  series_out = [signs * term for term in terms]
  return primal_out, series_out
コード例 #4
0
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute ``log(sum(b * exp(a)))`` over ``axis`` in a numerically stable way.

    Args:
      a: input array; promoted to an inexact dtype.
      axis: axis or axes to reduce over; ``None`` reduces over all axes.
      b: optional scale factors. Entries equal to zero exclude the
        corresponding element of ``a``.
      keepdims: keep the reduced axes as size-1 dimensions.
      return_sign: if True, return ``(out, sign)``, where ``out`` is the log of
        the absolute value of the sum and ``sign`` is its sign.

    Returns:
      ``out``, or ``(out, sign)`` when ``return_sign`` is True. Without
      ``return_sign``, a negative weighted sum yields NaN.
    """
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        # Zero-weight elements must not contribute to the sum.
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        # Integer/bool input must be promoted too: lax.is_finite and lax.exp
        # below only accept inexact dtypes.
        a, = _promote_args_inexact("logsumexp", a)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    # Shift by the max for stability. A non-finite max (e.g. an all -inf
    # slice) is replaced by 0 so that (-inf) - (-inf) = nan cannot occur.
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    if b is None:
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = jnp.sum(lax.mul(lax.exp(lax.sub(a, amax_with_dims)), b),
                         axis=dims,
                         keepdims=keepdims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        # A negative weighted sum has no real logarithm.
        out = jnp.where(sign < 0, np.nan, out)
    return out
コード例 #5
0
ファイル: polynomial.py プロジェクト: frederikwilde/jax
def polyval(p, x, *, unroll=16):
    """Evaluate the polynomial with coefficients ``p`` at ``x`` using Horner's rule.

    ``p`` holds the coefficients along its leading axis, highest degree first.
    Any remaining axes of ``p`` broadcast against ``x``. ``unroll`` is passed
    through to ``lax.scan``.
    """
    _check_arraylike("polyval", p, x)
    coeffs, x = _promote_dtypes_inexact(p, x)
    out_shape = lax.broadcast_shapes(coeffs.shape[1:], x.shape)
    acc = lax.full_like(x, 0, shape=out_shape, dtype=x.dtype)

    def horner_step(carry, coeff):
        return carry * x + coeff, None

    acc, _ = lax.scan(horner_step, acc, coeffs, unroll=unroll)
    return acc
コード例 #6
0
ファイル: lax_linalg.py プロジェクト: tpanthera/jax
def _lu_blocked(a, block_size=32):
    """Blocked LU decomposition, as an unrolled loop.

    Works through the matrix ``block_size`` columns at a time:
      1. Factor the column panel with ``_lu_unblocked``.
      2. Apply the panel's row permutation to the columns left and right of
         the panel.
      3. Update the trailing submatrix with a unit-lower triangular solve and
         a matrix product.

    Args:
      a: 2-D matrix of shape ``(m, n)`` to factor.
      block_size: number of columns per panel. The Python loop is unrolled
        at trace time.

    Returns:
      ``(pivot, a)``:
        * ``pivot`` is an int32 array of length ``min(m, n)``. For each
          panel it holds that panel's ``block_pivot + k`` (presumably the
          panel-local pivot rows from ``_lu_unblocked`` shifted to global row
          indices — confirm).
        * ``a`` holds the packed factors: unit-diagonal L below the diagonal
          and U on and above it. If any panel reported an error, every entry
          of ``a`` is NaN.
    """
    m, n = a.shape
    r = min(m, n)
    pivot = np.zeros((r, ), dtype=np.int32)
    error = np.array(False, np.bool_)
    for k in range(0, r, block_size):
        # Width of this panel; the last one may be narrower.
        b = min(r - k, block_size)
        # Factor the panel a[k:, k:k+b] (all rows from k downwards).
        block_pivot, perm, lu_block, block_error = _lu_unblocked(a[k:,
                                                                   k:k + b])
        error = error | block_error
        a = ops.index_update(a, ops.index[k:, k:k + b], lu_block)

        # Apply the panel's row permutation to the already-factored columns
        # to the left of the panel.
        a = ops.index_update(a, ops.index[k:, :k], a[perm + k, :k])
        pivot = ops.index_update(pivot, ops.index[k:k + b], block_pivot + k)

        if k + b < n:
            # Permute the trailing columns the same way.
            a = ops.index_update(a, ops.index[k:, k + b:], a[perm + k, k + b:])
            # U12 = L11^{-1} @ A12, where L11 is unit lower triangular.
            a = ops.index_update(
                a, ops.index[k:k + b, k + b:],
                triangular_solve(a[k:k + b, k:k + b],
                                 a[k:k + b, k + b:],
                                 left_side=True,
                                 lower=True,
                                 unit_diagonal=True))
            # Schur-complement update of the trailing block: A22 -= L21 @ U12.
            a = ops.index_add(
                a, ops.index[k + b:, k + b:],
                -lax.dot(a[k + b:, k:k + b],
                         a[k:k + b, k + b:],
                         precision=lax.Precision.HIGHEST))
    # If any panel reported an error, poison the whole result with NaN.
    a = np.where(error, lax.full_like(a, np.nan), a)
    return pivot, a
コード例 #7
0
ファイル: special.py プロジェクト: xeransis/jax
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Numerically stable ``log(sum(b * exp(a)))`` reduced over ``axis``.

    When ``b`` is given, ``a`` and ``b`` are first broadcast to a common
    shape. With ``return_sign`` the pair ``(log|sum|, sign(sum))`` is
    returned. Otherwise a negative weighted sum yields NaN. ``keepdims``
    re-inserts the reduced axes as size-1 dimensions on every returned array.
    """
    if b is not None:
        a, b = jnp.broadcast_arrays(a, b)
    dims = _reduction_dims(a, axis)

    def restore_dims(v):
        # Re-insert the reduced axes as size-1 dimensions.
        return lax.expand_dims(v, dims)

    neg_inf = _constant_like(a, -np.inf)
    zero = _constant_like(a, 0)
    row_max = lax.reduce(a, neg_inf, lax.max, dims)
    # A non-finite max is replaced by 0 so the shift below never forms
    # inf - inf.
    finite_max = lax.select(lax.is_finite(row_max), row_max,
                            lax.full_like(row_max, 0))
    shift = lax.stop_gradient(finite_max)
    shifted_exp = lax.exp(lax.sub(a, restore_dims(shift)))

    if b is not None:
        weighted_sum = lax.reduce(lax.mul(shifted_exp, b), zero, lax.add, dims)
        sign = lax.stop_gradient(lax.sign(weighted_sum))
        out = lax.add(lax.log(lax.abs(weighted_sum)), shift)
    else:
        total = lax.reduce(shifted_exp, zero, lax.add, dims)
        out = lax.add(lax.log(total), shift)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)

    if return_sign:
        if keepdims:
            return restore_dims(out), restore_dims(sign)
        return out, sign
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    if keepdims:
        return restore_dims(out)
    return out
コード例 #8
0
def _dynamic_index(x, idx):
    """Index the leading ``len(idx)`` axes of ``x`` with (possibly traced) integers.

    For in-range indices this matches ``x[idx[0], idx[1], ...]``, but it is
    built on ``lax.dynamic_slice`` so the indices may be traced values. An
    empty ``idx`` returns ``x`` unchanged.
    """
    if not idx:
        return x
    num_indexed = len(idx)
    trailing_shape = x.shape[num_indexed:]
    zero = lax.full_like(idx[0], 0, shape=())
    start_indices = list(idx) + [zero for _ in trailing_shape]
    slice_sizes = (1, ) * num_indexed + trailing_shape
    sliced = lax.dynamic_slice(x, start_indices, slice_sizes)
    return sliced.reshape(trailing_shape)
コード例 #9
0
ファイル: discrete.py プロジェクト: ColCarroll/numpyro
 def _checkresult(self, result, cond, bad_value):
     """Replace entries of ``result`` with ``bad_value`` where ``cond`` holds.

     An array-valued ``cond`` is applied element-wise. A scalar ``cond`` that
     is true replaces the whole result: a 0-d ``result`` is answered with
     ``bad_value`` itself (no ``device_put``), and any other ``result`` with a
     filled array of the same shape and dtype. Every other path returns
     ``device_put`` of the (possibly updated) result.
     """
     if cond.ndim == 0:
         if not cond:
             return device_put(result)
         if result.ndim == 0:
             return bad_value
         return device_put(lax.full_like(result, bad_value))
     return device_put(np.where(cond, bad_value, result))
コード例 #10
0
def ldexp(x1, x2):
    """Compute ``x1 * 2**x2`` by manipulating the floating-point exponent bits.

    Args:
      x1: real (non-complex) array-like values.
      x2: integer array-like exponents; inexact dtypes are rejected.

    Returns:
      An array of the inexact dtype corresponding to ``x1``'s dtype. Results
      flagged as underflow become 0, results flagged as overflow become
      ``sign(x1) * inf``, and ``x1`` is returned unchanged where it is inf,
      nan or 0.

    Raises:
      ValueError: if ``x1`` is complex or ``x2`` has an inexact dtype.
    """
    _check_arraylike("ldexp", x1, x2)
    x1_dtype = dtypes.dtype(x1)
    x2_dtype = dtypes.dtype(x2)
    if (dtypes.issubdtype(x1_dtype, np.complexfloating)
            or dtypes.issubdtype(x2_dtype, np.inexact)):
        raise ValueError(
            f"ldexp not supported for input types {(x1_dtype, x2_dtype)}")

    x1, x2 = _promote_shapes("ldexp", x1, x2)

    dtype = dtypes.canonicalize_dtype(dtypes._to_inexact_dtype(x1_dtype))
    info = dtypes.finfo(dtype)
    # Signed integer type of the same bit width as the float dtype, used to
    # read and edit the float's bit pattern.
    int_type = _INT_DTYPES[info.bits]

    x1 = lax.convert_element_type(x1, dtype)
    x2 = lax.convert_element_type(x2, int_type)

    # mask selects the exponent field once shifted down by nmant.
    # bias = 2**(nexp - 1) - 1 is the IEEE exponent bias.
    mask = (1 << info.nexp) - 1
    bias = ((1 << info.nexp) - 1) >> 1
    # x: integer bit pattern of x1, with values below finfo.tiny rescaled by
    # _normalize_float; e: the compensating exponent offset.
    x, e = _normalize_float(x1)
    # Target unbiased exponent = requested shift + x1's own unbiased exponent.
    x2 += e + ((x >> info.nmant) & mask) - bias

    # find underflow/overflow before denormalization
    underflow_cond = x2 < -(bias + info.nmant)
    overflow_cond = x2 > bias

    # Multiplier applied after the bit edit (1, or 2**-nmant for denormals).
    m = lax.full_like(x, 1, dtype=dtype)

    # denormals: a target exponent below the smallest normal exponent is
    # raised by nmant, and m compensates by 2**-nmant.
    cond = x2 < -bias + 1
    x2 = _where(cond, x2 + info.nmant, x2)
    m = _where(cond, m / (1 << info.nmant), m)

    x2 = lax.convert_element_type(x2, np.int32)
    # Clear the exponent field, then write the new biased exponent into it.
    x &= ~(mask << info.nmant)
    x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

    x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

    # underflow
    x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
    # overflow
    x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
    # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
    return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
コード例 #11
0
def isfinite(x):
    """Element-wise test that values are neither infinite nor NaN.

    A complex value is finite when both its real and imaginary parts are.
    Any other dtype (e.g. integer or bool) is reported as finite everywhere.
    """
    _check_arraylike("isfinite", x)
    kind = dtypes.dtype(x)
    if dtypes.issubdtype(kind, np.complexfloating):
        real_ok = lax.is_finite(real(x))
        imag_ok = lax.is_finite(imag(x))
        return lax.bitwise_and(real_ok, imag_ok)
    if dtypes.issubdtype(kind, np.floating):
        return lax.is_finite(x)
    return lax.full_like(x, True, dtype=np.bool_)
コード例 #12
0
def _isposneginf(infinity, x, out):
  """Shared implementation of ``isposinf`` / ``isneginf``.

  Args:
    infinity: the infinity to compare against (presumably ``np.inf`` or
      ``-np.inf``, chosen by the public wrappers).
    x: input array.
    out: unsupported; must be None.

  Returns:
    A boolean array. It is all False for dtypes that are neither floating
    nor complex.

  Raises:
    NotImplementedError: if ``out`` is given.
    ValueError: for complex inputs.
  """
  if out is not None:
    raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.")
  kind = dtypes.dtype(x)
  if dtypes.issubdtype(kind, np.complexfloating):
    raise ValueError("isposinf/isneginf are not well defined for complex types")
  if dtypes.issubdtype(kind, np.floating):
    return lax.eq(x, _constant_like(x, infinity))
  return lax.full_like(x, False, dtype=np.bool_)
コード例 #13
0
ファイル: reductions.py プロジェクト: cloudhan/jax
def _average(a,
             axis: Optional[Union[int, Tuple[int, ...]]] = None,
             weights=None,
             returned=False):
    """Weighted average of ``a`` over ``axis``, mirroring ``numpy.average``.

    Args:
      a: input array.
      axis: ``None`` to average over every element, an int, or a tuple of
        ints.
      weights: optional weights. Either the same shape as ``a``, or (with a
        single int ``axis``) a 1-D array whose length equals
        ``a.shape[axis]``. ``None`` means every weight is 1.
      returned: if True, also return the sum of the weights, broadcast to the
        shape of the average.

    Returns:
      ``avg``, or ``(avg, weights_sum)`` when ``returned`` is True.

    Raises:
      ValueError: if ``a`` and ``weights`` differ in shape and either
        ``axis`` is None or a tuple, ``weights`` is not 1-D, or its length
        does not match ``a.shape[axis]``.
    """
    a = _asarray(a)

    if weights is None:  # Treat all weights as 1
        avg = mean(a, axis=axis)
        if axis is None:
            weights_sum = lax.full((),
                                   core.dimension_as_value(np.size(a)),
                                   dtype=avg.dtype)
        else:
            # With unit weights, the sum of weights is the number of reduced
            # elements. A tuple axis cannot index ``a.shape`` directly, so
            # multiply the sizes of all reduced dimensions instead.
            if isinstance(axis, (tuple, list)):
                count = 1
                for ax in axis:
                    count = count * a.shape[ax]
            else:
                count = a.shape[axis]
            weights_sum = lax.full_like(avg,
                                        core.dimension_as_value(count),
                                        dtype=avg.dtype)
    else:
        weights = _asarray(weights)

        if dtypes.issubdtype(a.dtype, np.inexact):
            out_dtype = dtypes.result_type(a.dtype, weights.dtype)
        else:
            out_dtype = dtypes.result_type(a.dtype, weights.dtype,
                                           dtypes.float_)
        out_dtype = dtypes.canonicalize_dtype(out_dtype)

        a_shape = np.shape(a)
        a_ndim = len(a_shape)
        weights_shape = np.shape(weights)
        # Canonicalize each entry of a tuple axis individually; a tuple
        # cannot be canonicalized as a single axis.
        if isinstance(axis, (tuple, list)):
            axis = tuple(_canonicalize_axis(ax, a_ndim) for ax in axis)
        elif axis is not None:
            axis = _canonicalize_axis(axis, a_ndim)

        if a_shape != weights_shape:
            # Make sure the dimensions work out
            if axis is None:
                raise ValueError("Axis must be specified when shapes of a and "
                                 "weights differ.")
            if isinstance(axis, tuple):
                raise ValueError("Single axis expected when shapes of a and "
                                 "weights differ.")
            if len(weights_shape) != 1:
                raise ValueError("1D weights expected when shapes of a and "
                                 "weights differ.")
            if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
                raise ValueError("Length of weights not "
                                 "compatible with specified axis.")

            # Reshape the 1-D weights so they broadcast along ``axis``.
            weights = _broadcast_to(weights,
                                    (a_ndim - 1) * (1, ) + weights_shape)
            weights = _moveaxis(weights, -1, axis)

        weights_sum = sum(weights, axis=axis, dtype=out_dtype)
        avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum

    if returned:
        if avg.shape != weights_sum.shape:
            weights_sum = _broadcast_to(weights_sum, avg.shape)
        return avg, weights_sum
    return avg
コード例 #14
0
def isinf(x):
    """Element-wise test for positive or negative infinity.

    A complex value counts as infinite when either its real or its imaginary
    part is. Any dtype that is neither floating nor complex yields an
    all-False result.
    """
    _check_arraylike("isinf", x)
    kind = dtypes.dtype(x)

    def abs_is_inf(v):
        return lax.eq(lax.abs(v), _constant_like(v, np.inf))

    if dtypes.issubdtype(kind, np.complexfloating):
        return lax.bitwise_or(abs_is_inf(lax.real(x)), abs_is_inf(lax.imag(x)))
    if dtypes.issubdtype(kind, np.floating):
        return abs_is_inf(x)
    return lax.full_like(x, False, dtype=np.bool_)
コード例 #15
0
def ldexp(x1, x2):
  """Compute ``x1 * 2**x2`` by editing the floating-point exponent bits.

  The result dtype is ``_result_dtype(np.ldexp, x1, x2)`` (presumably
  NumPy's result type for ``ldexp`` — confirm), canonicalized for JAX.
  Results flagged as underflow become 0, results flagged as overflow become
  ``sign(x1) * inf``, and ``x1`` is returned unchanged where it is inf, nan
  or 0.
  """
  _check_arraylike("ldexp", x1, x2)
  dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2))
  x1, x2 = _promote_shapes("ldexp", x1, x2)
  x1 = lax.convert_element_type(x1, dtype)

  # mask selects the exponent field once shifted down by nmant.
  # bias = 2**(nexp - 1) - 1 is the IEEE exponent bias.
  info = dtypes.finfo(dtype)
  mask = (1 << info.nexp) - 1
  bias = ((1 << info.nexp) - 1) >> 1

  # Signed integer type of the same bit width, used to edit the bit pattern.
  int_type = _INT_DTYPES[info.bits]

  # x: integer bit pattern of x1, with values below finfo.tiny rescaled by
  # _normalize_float; e: the compensating exponent offset.
  x, e = _normalize_float(x1)
  # Target unbiased exponent = requested shift + x1's own unbiased exponent.
  # NOTE(review): x2 is not cast to int_type before this update, so its own
  # (promoted) integer dtype takes part in the arithmetic — confirm intended.
  x2 += e + ((x >> info.nmant) & mask) - bias

  # find underflow/overflow before denormalization
  underflow_cond = x2 < -(bias + info.nmant)
  overflow_cond = x2 > bias

  # Multiplier applied after the bit edit (1, or 2**-nmant for denormals).
  m = lax.full_like(x, 1, dtype=dtype)

  # denormals: a target exponent below the smallest normal exponent is
  # raised by nmant, and m compensates by 2**-nmant.
  cond = x2 < -bias + 1
  x2 = _where(cond, x2 + info.nmant, x2)
  m = _where(cond, m / (1 << info.nmant), m)

  x2 = lax.convert_element_type(x2, np.int32)
  # Clear the exponent field, then write the new biased exponent into it.
  x &= ~(mask << info.nmant)
  x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

  x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

  # underflow
  x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
  # overflow
  x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
  # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
  return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
コード例 #16
0
ファイル: special.py プロジェクト: jbampton/jax
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Numerically stable ``log(sum(b * exp(a)))`` over ``axis``.

    Args:
      a: input array (real or complex); promoted to an inexact dtype.
      axis: axis or axes to reduce; ``None`` reduces over all axes.
      b: optional scale factors. Entries equal to zero exclude the
        corresponding element of ``a``.
      keepdims: keep the reduced axes as size-1 dimensions.
      return_sign: if True, return ``(out, sign)`` such that
        ``sign * exp(out)`` equals the sum. ``out`` is then the log of the
        sum's magnitude and ``sign`` is ``jnp.sign`` of the sum.

    Returns:
      ``out`` or ``(out, sign)``. Without ``return_sign``, a negative real
      weighted sum yields NaN, and a complex sum yields its complex log.
    """
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        # Zero-weight elements must not contribute to the sum.
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        a, = _promote_args_inexact("logsumexp", a)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    # Shift by the max for stability; a non-finite max is replaced by 0 so
    # the shift never forms inf - inf.
    amax = lax.stop_gradient(
        lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    # fast path if the result cannot be negative.
    if b is None and not np.issubdtype(a.dtype, np.complexfloating):
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        sign = jnp.where(jnp.isnan(out), out, 1.0)
        sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
    else:
        expsub = lax.exp(lax.sub(a, amax_with_dims))
        if b is not None:
            expsub = lax.mul(expsub, b)
        sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)

        sign = lax.stop_gradient(jnp.sign(sumexp))
        if np.issubdtype(sumexp.dtype, np.complexfloating):
            if return_sign:
                # Divide the phase out of the sum. If sign = z / |z|, then
                # conj(sign) * z = |z|, whereas sign * z = z**2 / |z| is
                # wrong. This way sign * exp(out) reconstructs the sum.
                sumexp = jnp.conj(sign) * sumexp
            out = lax.add(lax.log(sumexp), amax)
        else:
            out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        if not np.issubdtype(out.dtype, np.complexfloating):
            # A negative real sum has no real log. Mask it with NaN without
            # tripping the debug_nans checker on this intentional NaN.
            with jax.debug_nans(False):
                out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype),
                                out)
    return out
コード例 #17
0
def floor_divide(x1, x2):
  """Element-wise ``floor(x1 / x2)`` after promoting the arguments.

  * Integers: the quotient is rounded toward negative infinity.
  * Complex: returns the floor of the real part of ``x1 / x2``, in the
    complex dtype. It is computed by dividing through by the larger of
    ``|x2.real|`` and ``|x2.imag|`` (Smith's method).
  * All other dtypes: defers to ``_float_divmod``.
  """
  x1, x2 = _promote_args("floor_divide", x1, x2)
  dtype = dtypes.dtype(x1)
  if dtypes.issubdtype(dtype, np.integer):
    truncated = lax.div(x1, x2)
    # lax.div truncates toward zero; step down by one when the signs differ
    # and the division is inexact.
    needs_adjust = logical_and(lax.sign(x1) != lax.sign(x2),
                               lax.rem(x1, x2) != 0)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    return _where(needs_adjust, truncated - 1, truncated)
  if not dtypes.issubdtype(dtype, np.complexfloating):
    return _float_divmod(x1, x2)[0]
  x1r, x1i = lax.real(x1), lax.imag(x1)
  x2r, x2i = lax.real(x2), lax.imag(x2)
  real_dominant = lax.ge(lax.abs(x2r), lax.abs(x2i))
  rat1 = _where(real_dominant, lax.full_like(x2i, 1), lax.div(x2r, x2i))
  rat2 = _where(real_dominant, lax.div(x2i, x2r), _lax_const(x2i, 1))
  numerator = lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2))
  denominator = lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2))
  floored = lax.floor(lax.div(numerator, denominator))
  return lax.convert_element_type(floored, dtype)
コード例 #18
0
def signbit(x):
    """Element-wise test of the sign bit.

    * Integers: reports ``x < 0``.
    * Booleans: never negative.
    * Floating dtypes: the sign bit is read from the raw bit pattern, so
      ``-0.0`` reports True. bfloat16 is widened to float32 first.

    Raises:
      ValueError: for dtypes that are not integer, bool or floating.
      NotImplementedError: for floating widths with no same-size entry in
        ``_INT_DTYPES``.
    """
    x, = _promote_args("signbit", x)
    kind = dtypes.dtype(x)
    if dtypes.issubdtype(kind, np.integer):
        return lax.lt(x, _constant_like(x, 0))
    if dtypes.issubdtype(kind, np.bool_):
        return lax.full_like(x, False, dtype=np.bool_)
    if not dtypes.issubdtype(kind, np.floating):
        raise ValueError("jax.numpy.signbit is not well defined for %s" %
                         kind)

    # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to
    # F32.
    if kind == dtypes.bfloat16:
        kind = np.float32
        x = lax.convert_element_type(x, np.float32)

    info = dtypes.finfo(kind)
    if info.bits not in _INT_DTYPES:
        raise NotImplementedError(
            "jax.numpy.signbit only supports 16, 32, and 64-bit types.")
    as_int = lax.bitcast_convert_type(x, _INT_DTYPES[info.bits])
    # Shifting the signed bit pattern right by (nexp + nmant) leaves only the
    # sign bit's contribution: nonzero iff the sign bit was set.
    sign_only = as_int >> (info.nexp + info.nmant)
    return lax.convert_element_type(sign_only, np.bool_)
コード例 #19
0
ファイル: special.py プロジェクト: GregCT/jax
def entr(x):
    """Element-wise entropy ``-xlogy(x, x)``, with ``-inf`` for negative ``x``."""
    x, = _promote_args_inexact("entr", x)
    is_negative = lax.lt(x, _constant_like(x, 0))
    minus_inf = lax.full_like(x, -np.inf)
    entropy = lax.neg(xlogy(x, x))
    return lax.select(is_negative, minus_inf, entropy)
コード例 #20
0
def _dynamic_update_index(x, idx, val):
    """Write ``val`` into ``x`` at the leading (possibly traced) indices ``idx``.

    ``val`` is reshaped to ``(1,) * len(idx) + x.shape[len(idx):]`` and
    written with ``lax.dynamic_update_slice``. The start indices are ``idx``
    padded with zeros for the remaining axes. An empty ``idx`` returns
    ``val`` directly.
    """
    if not idx:
        return val
    depth = len(idx)
    rest = x.shape[depth:]
    pad = lax.full_like(idx[0], 0, shape=())
    start_indices = list(idx)
    start_indices.extend(pad for _ in range(len(x.shape) - depth))
    update = val.reshape((1, ) * depth + rest)
    return lax.dynamic_update_slice(x, update, start_indices)
コード例 #21
0
ファイル: lapax.py プロジェクト: zudehuang/jax
def sqrt(x):
    """Element-wise ``x ** 0.5`` of a LapaxMatrix; ``x.bs`` is carried over."""
    data = x.ndarray
    half = lax.full_like(data, 0.5)
    return LapaxMatrix(lax.pow(data, half), x.bs)
コード例 #22
0
ファイル: lapax.py プロジェクト: zudehuang/jax
def full_like(x, val):
    """Return a LapaxMatrix like ``x`` with every entry set to ``val``.

    The underlying array keeps the shape and dtype of ``x.ndarray``, and
    ``x.bs`` is carried over unchanged.
    """
    filled = lax.full_like(x.ndarray, val)
    return LapaxMatrix(filled, x.bs)
コード例 #23
0
def softmax(attn_weights, norm_dims, dtype, softmax_hparams, quant_context):
    """Normalizes attention.

    The mode is chosen by ``softmax_hparams``:

    * ``None`` or an all-``None`` ``SoftmaxHParams``: exact softmax computed
      in the log domain via ``logsumexp``.
    * ``quant_hparams`` set: ``lax.cond`` on ``quant_context.quantize_acts``
      picks either the exact softmax or a version that quantizes every
      intermediate to the floating-point format ``quant_hparams.prec``.
    * Otherwise: an approximate softmax built from ``exponential`` and
      ``reciprocal``, configured by ``exp_hparams`` / ``reciprocal_hparams``.

    Args:
      attn_weights: attention logits.
      norm_dims: axes to normalize over.
      dtype: dtype of the result; also passed to the quantizers and
        approximations.
      softmax_hparams: ``SoftmaxHParams`` or None.
      quant_context: supplies ``quantize_acts``, which is read only when
        ``quant_hparams`` is set.

    Returns:
      The normalized weights, cast to ``dtype``.
    """
    a = attn_weights

    def unquantized_softmax(a):
        # Exact softmax in the log domain: exp(a - logsumexp(a)).
        a = lax.exp(
            a - jax.scipy.special.logsumexp(a, axis=norm_dims, keepdims=True))
        return a.astype(dtype)

    # Quantize intermediate activations with QuantOps.
    # Currently only supports unscaled floating-point formats.
    def quantized_softmax(a):
        # We compute softmax as exp(x-max(x))/sum_i(exp(x_i-max(x))), quantizing
        # intermediate values. Note this differs from the log-domain
        # implementation of softmax used above.
        quant_hparams = softmax_hparams.quant_hparams
        fp_quant_config = QuantOps.FloatQuant(is_scaled=False,
                                              fp_spec=quant_hparams.prec)
        quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant_config,
                                                 bounds=None)

        a = quant_ops.to_quantized(a, dtype=dtype)
        # Note that the max of a quantized vector is necessarily also quantized to
        # the same precision since the max of a vector must be an existing element
        # of the vector, so we don't need to explicitly insert a quantization
        # operator to the output of the max reduction.
        a_max = jnp.max(a, axis=norm_dims, keepdims=True)
        a_minus_max = quant_ops.to_quantized(a - a_max, dtype=dtype)
        a_exp = quant_ops.to_quantized(jnp.exp(a_minus_max), dtype=dtype)

        # The reduction uses quant_hparams.reduction_prec (presumably the
        # precision of the partial sums — confirm against quantized_sum), and
        # its result is then quantized to the activation format.
        sum_exp_quantized_reduction = quantization.quantized_sum(
            a_exp,
            axis=norm_dims,
            keepdims=True,
            prec=quant_hparams.reduction_prec)
        sum_exp = quant_ops.to_quantized(sum_exp_quantized_reduction,
                                         dtype=dtype)

        inv_sum_exp = quant_ops.to_quantized(jnp.reciprocal(sum_exp),
                                             dtype=dtype)
        a_softmax = quant_ops.to_quantized(a_exp * inv_sum_exp, dtype=dtype)

        return a_softmax.astype(dtype)

    # If no params, return accurate Softmax.
    if softmax_hparams == SoftmaxHParams(None, None,
                                         None) or softmax_hparams is None:
        return unquantized_softmax(a)

    # TODO(shivaniagrawal): Partial sum quantization (if enabled) will happen for
    # the entire training run, even before the global activation start step.
    if softmax_hparams.quant_hparams is not None:
        # lax.cond selects the branch from quant_context.quantize_acts, so
        # the flag may be a traced value.
        return lax.cond(quant_context.quantize_acts, quantized_softmax,
                        unquantized_softmax, a)

    # Approximated Softmax
    exp_hparams = softmax_hparams.exp_hparams
    recip_hparams = softmax_hparams.reciprocal_hparams

    # Subtract max value from dimensions to be normalized.
    # ``shape`` is presumably a's shape with every normalized axis replaced
    # by 1 (jax.util.subvals substitutes the (index, value) pairs), so the
    # reductions below can be reshaped to broadcast back against ``a``.
    shape = jax.util.subvals(onp.shape(a),
                             zip(norm_dims, (1, ) * len(norm_dims)))
    dimadd = lambda x: lax.reshape(x, shape)
    # pylint: disable=protected-access
    amax = lax.reduce(a, lax_numpy._constant_like(a, -onp.inf), lax.max,
                      norm_dims)
    # Non-finite maxima (e.g. an all -inf slice) are replaced by 0.
    amax = lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0))
    amax_singletons = dimadd(amax)
    asubmax = lax.sub(a, amax_singletons)

    # Calculate approximated exponential
    approx_exp = exponential(asubmax, dtype, exp_hparams)

    # If sum_high_bound: Upper clip bound for sum(exp(x-M)).
    asumexp = dimadd(
        lax.reduce(approx_exp, lax_numpy._constant_like(a, 0), lax.add,
                   norm_dims))

    # NOTE(review): exp_hparams is dereferenced here without a None check;
    # confirm this path is never reached with exp_hparams=None.
    if exp_hparams.sum_high_bound is not None and exp_hparams.sum_high_bound != 0:
        sum_low_bound = 1.
        if (exp_hparams.low_bound != 0) and exp_hparams.clip_and_subtract:
            sum_low_bound = 1 - onp.exp(exp_hparams.low_bound)
        asumexp = jnp.clip(asumexp, sum_low_bound, exp_hparams.sum_high_bound)

    # Approximation of reciprocal.
    arecip = reciprocal(asumexp, dtype, recip_hparams)
    return lax.mul(approx_exp, arecip).astype(dtype)
コード例 #24
0
def imag(val):
    """Imaginary part of ``val``; for non-complex input, zeros shaped like ``val``."""
    _check_arraylike("imag", val)
    if np.iscomplexobj(val):
        return lax.imag(val)
    return lax.full_like(val, 0)
コード例 #25
0
@jax.custom_jvp
@jax.jit
def relu(x: Array) -> Array:
    r"""Rectified linear unit activation function.

    Computes the element-wise function:

    .. math::
      \mathrm{relu}(x) = \max(x, 0)

    Its derivative is set by the custom JVP registered below: the tangent
    passes through where ``x > 0`` and is 0 elsewhere, including at
    ``x == 0``.

    Args:
      x : input array
    """
    return jnp.maximum(x, 0)


# ``defjvps`` exists only on ``jax.custom_jvp`` objects. A bare ``jax.jit``
# callable has no such attribute, hence the ``@jax.custom_jvp`` decorator.
relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))


@jax.jit
def softplus(x: Array) -> Array:
    r"""Softplus activation function.

    Element-wise :math:`\mathrm{softplus}(x) = \log(1 + e^x)`, evaluated as
    ``logaddexp(x, 0)``.

    Args:
      x : input array
    """
    zero = 0
    return jnp.logaddexp(x, zero)
コード例 #26
0
 def op(*args):
     """Apply ``bitwise_op`` to the arguments interpreted as booleans.

     Bool arguments pass through unchanged; any other argument becomes
     ``arg != 0``. The results are then promoted together via
     ``_promote_args`` (labelled with ``np_op.__name__``) and handed to
     ``bitwise_op``.
     """
     as_bool = []
     for arg in args:
         if dtypes.issubdtype(dtypes.dtype(arg), np.bool_):
             as_bool.append(arg)
         else:
             zero = lax.full_like(arg, shape=(), fill_value=0)
             as_bool.append(lax.ne(arg, zero))
     return bitwise_op(*_promote_args(np_op.__name__, *as_bool))
コード例 #27
0
def _relu_jvp(primals, tangents):
    """JVP rule for ``relu``: the tangent passes through only where ``x > 0``."""
    (x,), (t,) = primals, tangents
    primal_out = relu(x)
    tangent_out = lax.select(x > 0, t, lax.full_like(t, 0))
    return primal_out, tangent_out