def _normalize_float(x):
  info = dtypes.finfo(dtypes.dtype(x))
  cond = lax.abs(x) < info.tiny
  x1 = _where(cond, x * _lax_const(x, 1 << info.nmant), x)
  x2 = _where(cond, lax.full_like(x, -info.nmant, dtype=np.int32),
              lax.full_like(x, 0, dtype=np.int32))
  int_type = _INT_DTYPES[info.bits]
  return lax.bitcast_convert_type(x1, int_type), x2
def _sinc_maclaurin(k, x):
  # compute the kth derivative of x -> sin(x)/x evaluated at zero (since we
  # compute the monomial term in the jvp rule)
  if k % 2:
    return lax.full_like(x, 0)
  else:
    return lax.full_like(x, (-1) ** (k // 2) / (k + 1))
def _abs_taylor_rule(x, series_in, **params):
  x, = x
  zero = lax.full_like(x, 0, shape=())
  primal_out = lax.abs_p.bind(x, **params)
  negs = lax.select(lax.lt(x, zero), lax.full_like(x, -1), lax.full_like(x, 1.0))
  fix_sign = lambda y: negs * y
  series_out = [fix_sign(*terms_in, **params) for terms_in in zip(*series_in)]
  return primal_out, series_out
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  if b is not None:
    a, b = _promote_args_inexact("logsumexp", a, b)
    a = jnp.where(b != 0, a, -jnp.inf)
  pos_dims, dims = _reduction_dims(a, axis)
  amax = jnp.max(a, axis=dims, keepdims=keepdims)
  amax = lax.stop_gradient(
      lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
  amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
  if b is None:
    out = lax.add(lax.log(jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                                  axis=dims, keepdims=keepdims)),
                  amax)
    sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
    sign = jnp.where(out == -np.inf, 0.0, sign)
  else:
    sumexp = jnp.sum(lax.mul(lax.exp(lax.sub(a, amax_with_dims)), b),
                     axis=dims, keepdims=keepdims)
    sign = lax.stop_gradient(lax.sign(sumexp))
    out = lax.add(lax.log(lax.abs(sumexp)), amax)
  if return_sign:
    return (out, sign)
  if b is not None:
    out = jnp.where(sign < 0, np.nan, out)
  return out
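# Usage sketch (not part of the library code above; assumes jax is installed
# and that jax.scipy.special.logsumexp is the public entry point for this
# function). It shows why the amax shift matters: a naive log(sum(exp(a)))
# overflows in float32 for large inputs, while the shifted form stays finite.
import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([1000.0, 1000.5, 999.0])
naive = jnp.log(jnp.sum(jnp.exp(a)))   # inf: exp(1000.) overflows float32
stable = logsumexp(a)                  # finite, roughly 1001.1
print(naive, stable)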
def polyval(p, x, *, unroll=16):
  _check_arraylike("polyval", p, x)
  p, x = _promote_dtypes_inexact(p, x)
  shape = lax.broadcast_shapes(p.shape[1:], x.shape)
  y = lax.full_like(x, 0, shape=shape, dtype=x.dtype)
  y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll)
  return y
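# Usage sketch (assumes jnp.polyval is the public wrapper for the function
# above). The lax.scan loop is Horner's rule, y <- y * x + p_i, with p[0] the
# highest-order coefficient as in NumPy.
import jax.numpy as jnp

p = jnp.array([2.0, 3.0, 1.0])   # 2*x**2 + 3*x + 1
x = jnp.array([0.0, 1.0, 2.0])
print(jnp.polyval(p, x))         # expected: [ 1.  6. 15.]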
def _lu_blocked(a, block_size=32):
  """Blocked LU decomposition, as an unrolled loop."""
  m, n = a.shape
  r = min(m, n)
  pivot = np.zeros((r,), dtype=np.int32)
  error = np.array(False, np.bool_)
  for k in range(0, r, block_size):
    b = min(r - k, block_size)
    block_pivot, perm, lu_block, block_error = _lu_unblocked(a[k:, k:k + b])
    error = error | block_error
    a = ops.index_update(a, ops.index[k:, k:k + b], lu_block)
    a = ops.index_update(a, ops.index[k:, :k], a[perm + k, :k])
    pivot = ops.index_update(pivot, ops.index[k:k + b], block_pivot + k)

    if k + b < n:
      a = ops.index_update(a, ops.index[k:, k + b:], a[perm + k, k + b:])
      a = ops.index_update(
          a, ops.index[k:k + b, k + b:],
          triangular_solve(a[k:k + b, k:k + b], a[k:k + b, k + b:],
                           left_side=True, lower=True, unit_diagonal=True))
      a = ops.index_add(
          a, ops.index[k + b:, k + b:],
          -lax.dot(a[k + b:, k:k + b], a[k:k + b, k + b:],
                   precision=lax.Precision.HIGHEST))
  a = np.where(error, lax.full_like(a, np.nan), a)
  return pivot, a
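# Usage sketch (an assumption, not code taken from the snippet above): blocked
# LU with partial pivoting is the kind of kernel that backs
# jax.scipy.linalg.lu, so the factorization can be sanity-checked by
# reassembling P @ L @ U.
import jax.numpy as jnp
from jax.scipy.linalg import lu

a = jnp.array([[4.0, 3.0],
               [6.0, 3.0]])
p, l, u = lu(a)
print(jnp.allclose(p @ l @ u, a))   # expected: True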
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  if b is not None:
    a, b = jnp.broadcast_arrays(a, b)
  dims = _reduction_dims(a, axis)
  dimadd = lambda x: lax.expand_dims(x, dims)
  amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
  amax = lax.stop_gradient(
      lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
  amax_singletons = dimadd(amax)
  if b is None:
    out = lax.add(lax.log(lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
                                     _constant_like(a, 0), lax.add, dims)),
                  amax)
    sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
    sign = jnp.where(out == -np.inf, 0.0, sign)
  else:
    sumexp = lax.reduce(lax.mul(lax.exp(lax.sub(a, amax_singletons)), b),
                        _constant_like(a, 0), lax.add, dims)
    sign = lax.stop_gradient(lax.sign(sumexp))
    out = lax.add(lax.log(lax.abs(sumexp)), amax)
  if return_sign:
    return (dimadd(out), dimadd(sign)) if keepdims else (out, sign)
  if b is not None:
    out = jnp.where(sign < 0, np.nan, out)
  return dimadd(out) if keepdims else out
def _dynamic_index(x, idx):
  if not idx:
    return x
  ndim = len(x.shape)
  starts = [*idx] + [lax.full_like(idx[0], 0, shape=())] * (ndim - len(idx))
  sizes = (1,) * len(idx) + x.shape[len(idx):]
  out = lax.dynamic_slice(x, starts, sizes)
  return out.reshape(x.shape[len(idx):])
def _checkresult(self, result, cond, bad_value):
  if cond.ndim != 0:
    result = np.where(cond, bad_value, result)
  elif cond:
    if result.ndim == 0:
      return bad_value
    result = lax.full_like(result, bad_value)
  return device_put(result)
def ldexp(x1, x2):
  _check_arraylike("ldexp", x1, x2)
  x1_dtype = dtypes.dtype(x1)
  x2_dtype = dtypes.dtype(x2)
  if (dtypes.issubdtype(x1_dtype, np.complexfloating)
      or dtypes.issubdtype(x2_dtype, np.inexact)):
    raise ValueError(
        f"ldexp not supported for input types {(x1_dtype, x2_dtype)}")

  x1, x2 = _promote_shapes("ldexp", x1, x2)

  dtype = dtypes.canonicalize_dtype(dtypes._to_inexact_dtype(x1_dtype))
  info = dtypes.finfo(dtype)
  int_type = _INT_DTYPES[info.bits]

  x1 = lax.convert_element_type(x1, dtype)
  x2 = lax.convert_element_type(x2, int_type)

  mask = (1 << info.nexp) - 1
  bias = ((1 << info.nexp) - 1) >> 1
  x, e = _normalize_float(x1)
  x2 += e + ((x >> info.nmant) & mask) - bias

  # find underflow/overflow before denormalization
  underflow_cond = x2 < -(bias + info.nmant)
  overflow_cond = x2 > bias

  m = lax.full_like(x, 1, dtype=dtype)

  # denormals
  cond = x2 < -bias + 1
  x2 = _where(cond, x2 + info.nmant, x2)
  m = _where(cond, m / (1 << info.nmant), m)

  x2 = lax.convert_element_type(x2, np.int32)
  x &= ~(mask << info.nmant)
  x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

  x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

  # underflow
  x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
  # overflow
  x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
  # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
  return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
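# Usage sketch (assumes jnp.ldexp is the public entry point). The function
# builds x1 * 2**x2 by writing the exponent bits directly, which is why it has
# to handle denormals, underflow and overflow explicitly above.
import jax.numpy as jnp

x = jnp.array([1.5, -0.75, 3.0])
e = jnp.array([2, -1, 10])
print(jnp.ldexp(x, e))   # expected: [ 6.  -0.375  3072.]
print(x * 2.0 ** e)      # same values, computed the naive way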
def isfinite(x):
  _check_arraylike("isfinite", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.floating):
    return lax.is_finite(x)
  elif dtypes.issubdtype(dtype, np.complexfloating):
    return lax.bitwise_and(lax.is_finite(real(x)), lax.is_finite(imag(x)))
  else:
    return lax.full_like(x, True, dtype=np.bool_)
def _isposneginf(infinity, x, out):
  if out is not None:
    raise NotImplementedError(
        "The 'out' argument to isneginf/isposinf is not supported.")
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.floating):
    return lax.eq(x, _constant_like(x, infinity))
  elif dtypes.issubdtype(dtype, np.complexfloating):
    raise ValueError("isposinf/isneginf are not well defined for complex types")
  else:
    return lax.full_like(x, False, dtype=np.bool_)
def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
             weights=None, returned=False):
  a = _asarray(a)

  if weights is None:  # Treat all weights as 1
    avg = mean(a, axis=axis)
    if axis is None:
      weights_sum = lax.full((), core.dimension_as_value(np.size(a)),
                             dtype=avg.dtype)
    else:
      weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis]),
                                  dtype=avg.dtype)
  else:
    weights = _asarray(weights)

    if dtypes.issubdtype(a.dtype, np.inexact):
      out_dtype = dtypes.result_type(a.dtype, weights.dtype)
    else:
      out_dtype = dtypes.result_type(a.dtype, weights.dtype, dtypes.float_)
    out_dtype = dtypes.canonicalize_dtype(out_dtype)

    a_shape = np.shape(a)
    a_ndim = len(a_shape)
    weights_shape = np.shape(weights)
    axis = None if axis is None else _canonicalize_axis(axis, a_ndim)

    if a_shape != weights_shape:
      # Make sure the dimensions work out
      if axis is None:
        raise ValueError("Axis must be specified when shapes of a and "
                         "weights differ.")
      if len(weights_shape) != 1:
        raise ValueError("1D weights expected when shapes of a and "
                         "weights differ.")
      if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
        raise ValueError("Length of weights not "
                         "compatible with specified axis.")

      weights = _broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape)
      weights = _moveaxis(weights, -1, axis)

    weights_sum = sum(weights, axis=axis, dtype=out_dtype)
    avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum

  if returned:
    if avg.shape != weights_sum.shape:
      weights_sum = _broadcast_to(weights_sum, avg.shape)
    return avg, weights_sum
  return avg
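# Usage sketch (assumes jnp.average is the public wrapper for _average). The
# weighted branch computes sum(a * w) / sum(w) along the reduction axis, and
# returned=True also hands back the sum of weights, as in numpy.average.
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0, 4.0])
w = jnp.array([4.0, 3.0, 2.0, 1.0])
avg, wsum = jnp.average(a, weights=w, returned=True)
print(avg, wsum)   # expected: 2.0 10.0  (sum(a*w) = 20, sum(w) = 10)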
def isinf(x):
  _check_arraylike("isinf", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.floating):
    return lax.eq(lax.abs(x), _constant_like(x, np.inf))
  elif dtypes.issubdtype(dtype, np.complexfloating):
    re = lax.real(x)
    im = lax.imag(x)
    return lax.bitwise_or(lax.eq(lax.abs(re), _constant_like(re, np.inf)),
                          lax.eq(lax.abs(im), _constant_like(im, np.inf)))
  else:
    return lax.full_like(x, False, dtype=np.bool_)
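# Usage sketch: for complex inputs, isinf/isfinite inspect the real and
# imaginary parts separately, so a value counts as infinite if either part is
# infinite and as finite only if both parts are.
import numpy as np
import jax.numpy as jnp

z = jnp.array([complex(1, 2), complex(np.inf, 1), complex(1, np.inf),
               complex(np.nan, 0)])
print(jnp.isinf(z))      # expected: [False  True  True False]
print(jnp.isfinite(z))   # expected: [ True False False False]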
def ldexp(x1, x2):
  _check_arraylike("ldexp", x1, x2)
  dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2))
  x1, x2 = _promote_shapes("ldexp", x1, x2)
  x1 = lax.convert_element_type(x1, dtype)

  info = dtypes.finfo(dtype)
  mask = (1 << info.nexp) - 1
  bias = ((1 << info.nexp) - 1) >> 1

  int_type = _INT_DTYPES[info.bits]

  x, e = _normalize_float(x1)
  x2 += e + ((x >> info.nmant) & mask) - bias

  # find underflow/overflow before denormalization
  underflow_cond = x2 < -(bias + info.nmant)
  overflow_cond = x2 > bias

  m = lax.full_like(x, 1, dtype=dtype)

  # denormals
  cond = x2 < -bias + 1
  x2 = _where(cond, x2 + info.nmant, x2)
  m = _where(cond, m / (1 << info.nmant), m)

  x2 = lax.convert_element_type(x2, np.int32)
  x &= ~(mask << info.nmant)
  x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

  x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

  # underflow
  x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
  # overflow
  x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
  # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
  return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  if b is not None:
    a, b = _promote_args_inexact("logsumexp", a, b)
    a = jnp.where(b != 0, a, -jnp.inf)
  else:
    a, = _promote_args_inexact("logsumexp", a)
  pos_dims, dims = _reduction_dims(a, axis)
  amax = jnp.max(a, axis=dims, keepdims=keepdims)
  amax = lax.stop_gradient(
      lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
  amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)

  # fast path if the result cannot be negative.
  if b is None and not np.issubdtype(a.dtype, np.complexfloating):
    out = lax.add(lax.log(jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                                  axis=dims, keepdims=keepdims)),
                  amax)
    sign = jnp.where(jnp.isnan(out), out, 1.0)
    sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
  else:
    expsub = lax.exp(lax.sub(a, amax_with_dims))
    if b is not None:
      expsub = lax.mul(expsub, b)
    sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)

    sign = lax.stop_gradient(jnp.sign(sumexp))
    if np.issubdtype(sumexp.dtype, np.complexfloating):
      if return_sign:
        sumexp = sign * sumexp
      out = lax.add(lax.log(sumexp), amax)
    else:
      out = lax.add(lax.log(lax.abs(sumexp)), amax)
  if return_sign:
    return (out, sign)
  if b is not None:
    if not np.issubdtype(out.dtype, np.complexfloating):
      with jax.debug_nans(False):
        out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
  return out
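# Usage sketch: with weights b the reduced sum can be negative, so
# return_sign=True returns (log(|S|), sign(S)) and the caller can recover
# S = sign * exp(out). Assumes jax.scipy.special.logsumexp exposes this code.
import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([1.0, 2.0])
b = jnp.array([1.0, -1.0])
out, sign = logsumexp(a, b=b, return_sign=True)
print(out, sign)   # log|e**1 - e**2| is roughly 1.5413, sign = -1.0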
def floor_divide(x1, x2):
  x1, x2 = _promote_args("floor_divide", x1, x2)
  dtype = dtypes.dtype(x1)
  if dtypes.issubdtype(dtype, np.integer):
    quotient = lax.div(x1, x2)
    select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    return _where(select, quotient - 1, quotient)
  elif dtypes.issubdtype(dtype, np.complexfloating):
    x1r = lax.real(x1)
    x1i = lax.imag(x1)
    x2r = lax.real(x2)
    x2i = lax.imag(x2)
    which = lax.ge(lax.abs(x2r), lax.abs(x2i))
    rat1 = _where(which, lax.full_like(x2i, 1), lax.div(x2r, x2i))
    rat2 = _where(which, lax.div(x2i, x2r), _lax_const(x2i, 1))
    out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)),
                            lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2))))
    return lax.convert_element_type(out, dtype)
  else:
    return _float_divmod(x1, x2)[0]
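# Usage sketch: the integer branch corrects lax.div (which truncates toward
# zero) so the quotient rounds toward -inf whenever the signs differ and the
# remainder is nonzero, matching Python's // and numpy.floor_divide.
import jax.numpy as jnp

x = jnp.array([7, -7, 7, -7])
y = jnp.array([2, 2, -2, -2])
print(jnp.floor_divide(x, y))   # expected: [ 3 -4 -4  3]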
def signbit(x):
  x, = _promote_args("signbit", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.integer):
    return lax.lt(x, _constant_like(x, 0))
  elif dtypes.issubdtype(dtype, np.bool_):
    return lax.full_like(x, False, dtype=np.bool_)
  elif not dtypes.issubdtype(dtype, np.floating):
    raise ValueError("jax.numpy.signbit is not well defined for %s" % dtype)

  # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to
  # F32.
  if dtype == dtypes.bfloat16:
    dtype = np.float32
    x = lax.convert_element_type(x, np.float32)

  info = dtypes.finfo(dtype)
  if info.bits not in _INT_DTYPES:
    raise NotImplementedError(
        "jax.numpy.signbit only supports 16, 32, and 64-bit types.")
  int_type = _INT_DTYPES[info.bits]
  x = lax.bitcast_convert_type(x, int_type)
  return lax.convert_element_type(x >> (info.nexp + info.nmant), np.bool_)
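# Usage sketch: signbit reads the raw sign bit from the bitcast integer, so it
# distinguishes -0.0 from +0.0, which an x < 0 comparison cannot do.
import jax.numpy as jnp

x = jnp.array([-3.0, 0.0, -0.0, jnp.inf, -jnp.inf])
print(jnp.signbit(x))   # expected: [ True False  True False  True]
print(x < 0)            # expected: [ True False False False  True]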
def entr(x):
  x, = _promote_args_inexact("entr", x)
  return lax.select(lax.lt(x, _constant_like(x, 0)),
                    lax.full_like(x, -np.inf),
                    lax.neg(xlogy(x, x)))
def _dynamic_update_index(x, idx, val):
  if not idx:
    return val
  ndim = len(x.shape)
  starts = [*idx] + [lax.full_like(idx[0], 0, shape=())] * (ndim - len(idx))
  update = val.reshape((1,) * len(idx) + x.shape[len(idx):])
  return lax.dynamic_update_slice(x, update, starts)
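# Usage sketch of the primitives used by _dynamic_index/_dynamic_update_index
# (an illustration, not code from the snippets above): lax.dynamic_slice reads
# a block given per-dimension start indices and static sizes, and
# lax.dynamic_update_slice writes one back at the given start indices.
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 4)
row = lax.dynamic_slice(x, (1, 0), (1, 4)).reshape(4)
y = lax.dynamic_update_slice(x, jnp.zeros((1, 4), x.dtype), (1, 0))
print(row)    # expected: [4 5 6 7]
print(y[1])   # expected: [0 0 0 0]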
def sqrt(x):
  return LapaxMatrix(lax.pow(x.ndarray, lax.full_like(x.ndarray, 0.5)), x.bs)
def full_like(x, val):
  return LapaxMatrix(lax.full_like(x.ndarray, val), x.bs)
def softmax(attn_weights, norm_dims, dtype, softmax_hparams, quant_context):
  """Normalizes attention."""
  a = attn_weights

  def unquantized_softmax(a):
    a = lax.exp(
        a - jax.scipy.special.logsumexp(a, axis=norm_dims, keepdims=True))
    return a.astype(dtype)

  # Quantize intermediate activations with QuantOps.
  # Currently only supports unscaled floating-point formats.
  def quantized_softmax(a):
    # We compute softmax as exp(x-max(x))/sum_i(exp(x_i-max(x))), quantizing
    # intermediate values. Note this differs from the log-domain
    # implementation of softmax used above.
    quant_hparams = softmax_hparams.quant_hparams
    fp_quant_config = QuantOps.FloatQuant(is_scaled=False,
                                          fp_spec=quant_hparams.prec)
    quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant_config,
                                             bounds=None)

    a = quant_ops.to_quantized(a, dtype=dtype)
    # Note that the max of a quantized vector is necessarily also quantized to
    # the same precision since the max of a vector must be an existing element
    # of the vector, so we don't need to explicitly insert a quantization
    # operator to the output of the max reduction.
    a_max = jnp.max(a, axis=norm_dims, keepdims=True)
    a_minus_max = quant_ops.to_quantized(a - a_max, dtype=dtype)
    a_exp = quant_ops.to_quantized(jnp.exp(a_minus_max), dtype=dtype)

    sum_exp_quantized_reduction = quantization.quantized_sum(
        a_exp, axis=norm_dims, keepdims=True,
        prec=quant_hparams.reduction_prec)
    sum_exp = quant_ops.to_quantized(sum_exp_quantized_reduction, dtype=dtype)

    inv_sum_exp = quant_ops.to_quantized(jnp.reciprocal(sum_exp), dtype=dtype)
    a_softmax = quant_ops.to_quantized(a_exp * inv_sum_exp, dtype=dtype)

    return a_softmax.astype(dtype)

  # If no params, return accurate Softmax.
  if softmax_hparams == SoftmaxHParams(None, None,
                                       None) or softmax_hparams is None:
    return unquantized_softmax(a)

  # TODO(shivaniagrawal): Partial sum quantization (if enabled) will happen for
  # the entire training run, even before the global activation start step.
  if softmax_hparams.quant_hparams is not None:
    return lax.cond(quant_context.quantize_acts, quantized_softmax,
                    unquantized_softmax, a)

  # Approximated Softmax
  exp_hparams = softmax_hparams.exp_hparams
  recip_hparams = softmax_hparams.reciprocal_hparams

  # Subtract max value from dimensions to be normalized.
  shape = jax.util.subvals(onp.shape(a), zip(norm_dims, (1,) * len(norm_dims)))
  dimadd = lambda x: lax.reshape(x, shape)
  # pylint: disable=protected-access
  amax = lax.reduce(a, lax_numpy._constant_like(a, -onp.inf), lax.max,
                    norm_dims)
  amax = lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0))
  amax_singletons = dimadd(amax)
  asubmax = lax.sub(a, amax_singletons)

  # Calculate approximated exponential
  approx_exp = exponential(asubmax, dtype, exp_hparams)

  # If sum_high_bound: Upper clip bound for sum(exp(x-M)).
  asumexp = dimadd(
      lax.reduce(approx_exp, lax_numpy._constant_like(a, 0), lax.add,
                 norm_dims))

  if exp_hparams.sum_high_bound is not None and exp_hparams.sum_high_bound != 0:
    sum_low_bound = 1.
    if (exp_hparams.low_bound != 0) and exp_hparams.clip_and_subtract:
      sum_low_bound = 1 - onp.exp(exp_hparams.low_bound)
    asumexp = jnp.clip(asumexp, sum_low_bound, exp_hparams.sum_high_bound)

  # Approximation of reciprocal.
  arecip = reciprocal(asumexp, dtype, recip_hparams)
  return lax.mul(approx_exp, arecip).astype(dtype)
def imag(val):
  _check_arraylike("imag", val)
  return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0)
@jax.custom_jvp  # needed so that relu.defjvps below can attach a custom JVP
@jax.jit
def relu(x: Array) -> Array:
  r"""Rectified linear unit activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{relu}(x) = \max(x, 0)

  Args:
    x : input array
  """
  return jnp.maximum(x, 0)
relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))

@jax.jit
def softplus(x: Array) -> Array:
  r"""Softplus activation function.

  Computes the element-wise function

  .. math::
    \mathrm{softplus}(x) = \log(1 + e^x)

  Args:
    x : input array
  """
  return jnp.logaddexp(x, 0)
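# Usage sketch: the custom JVP above makes the relu derivative at exactly
# x == 0 equal to 0 (the select condition x > 0 is False there), so jax.grad
# is well defined at the kink; softplus is its smooth counterpart.
import jax
import jax.numpy as jnp

print(jax.grad(jax.nn.relu)(0.0))        # expected: 0.0
print(jax.grad(jax.nn.relu)(2.0))        # expected: 1.0
print(jax.nn.softplus(jnp.array(0.0)))   # log(2), roughly 0.693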
def op(*args):
  zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
  args = (x if dtypes.issubdtype(dtypes.dtype(x), np.bool_)
          else lax.ne(x, zero(x)) for x in args)
  return bitwise_op(*_promote_args(np_op.__name__, *args))
def _relu_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return relu(x), lax.select(x > 0, t, lax.full_like(t, 0))