def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  if b is not None:
    a, b = jnp.broadcast_arrays(a, b)
  dims = _reduction_dims(a, axis)
  dimadd = lambda x: lax.expand_dims(x, dims)
  # Shift by the max over the reduced dims for numerical stability; non-finite
  # maxima are replaced with 0 so the subtraction stays well-defined.
  amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
  amax = lax.stop_gradient(
      lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
  amax_singletons = dimadd(amax)
  if b is None:
    out = lax.add(
        lax.log(
            lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
                       _constant_like(a, 0), lax.add, dims)), amax)
    sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
    sign = jnp.where(out == -np.inf, 0.0, sign)
  else:
    # With weights b, the sum can be negative; take log(abs(.)) and track the
    # sign separately.
    sumexp = lax.reduce(lax.mul(lax.exp(lax.sub(a, amax_singletons)), b),
                        _constant_like(a, 0), lax.add, dims)
    sign = lax.stop_gradient(lax.sign(sumexp))
    out = lax.add(lax.log(lax.abs(sumexp)), amax)
  if return_sign:
    return (dimadd(out), dimadd(sign)) if keepdims else (out, sign)
  if b is not None:
    out = jnp.where(sign < 0, np.nan, out)
  return dimadd(out) if keepdims else out
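# Usage sketch (illustrative, not part of the implementation above): a minimal
# example of the public jax.scipy.special.logsumexp API, which the function
# above corresponds to; shapes and values here are arbitrary.
import jax.numpy as jnp
from jax.scipy.special import logsumexp as jsp_logsumexp

x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
w = jnp.array([[1.0, -1.0, 1.0], [1.0, 1.0, -1.0]])

# Plain reduction over the last axis.
print(jsp_logsumexp(x, axis=-1))
# Weighted reduction: return_sign recovers the sign lost by taking log(abs(.)).
val, sign = jsp_logsumexp(x, axis=-1, b=w, return_sign=True)
print(val, sign)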
def testReducePairGrad(self, shape, dtype, dims):
  rng = jtu.rand_default(self.rng(), scale=1)
  tol = {np.float32: 1e-2, np.float64: 1e-4}
  operands = (rng(shape, dtype), rng(shape, dtype))
  init_vals = (np.array(0, dtype), np.array(1, dtype))
  def op(xs, ys):
    return (xs[0] + ys[0], xs[1] * ys[1])
  reduce = lambda xs, ys: lax.reduce((xs, ys), init_vals, op, dims)
  check_grads(reduce, operands, 2, ["fwd", "rev"], tol, tol)
def testVariadicReduce(self, shape, dtype, dims, bdims):
  def op(a, b):
    x1, y1 = a
    x2, y2 = b
    return x1 + x2, y1 * y2
  rng = jtu.rand_small(self.rng())
  init_val = tuple(np.asarray([0, 1], dtype=dtype))
  fun = lambda x, y: lax.reduce((x, y), init_val, op, dims)
  self._CheckBatching(fun, 5, bdims, (shape, shape), (dtype, dtype), rng,
                      multiple_results=True)
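# Usage sketch (illustrative, not from the test suite above): the variadic
# form of lax.reduce exercised by the two tests above, computing a sum and a
# product in a single pass over the data.
import numpy as np
from jax import lax

def sum_and_prod(xs, ys):
  # xs and ys are pairs of scalars; combine them elementwise.
  return xs[0] + ys[0], xs[1] * ys[1]

x = np.arange(1.0, 5.0, dtype=np.float32)   # [1, 2, 3, 4]
inits = (np.float32(0), np.float32(1))      # identities for (+, *)
total, product = lax.reduce((x, x), inits, sum_and_prod, dimensions=(0,))
print(total, product)  # 10.0 24.0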
def quantized_sum(
    x,
    axis,
    keepdims,
    prec):
  """Sums a tensor while quantizing intermediate accumulations.

  This is almost a drop-in replacement for jnp.sum. It only differs in that it
  takes in a 'prec' parameter that controls the quantization of intermediate
  accumulations during the reduction.

  Arguments:
    x: Input, a Jax array
    axis: Which axes to reduce over (see jnp.sum docs)
    keepdims: Whether to keep or drop axes that are reduced (see jnp.sum docs)
    prec: Precision to quantize intermediates to. Currently this can only be an
      instance of QuantOps.FloatQuant.FloatPrec, corresponding to an unscaled
      floating-point format, or it can be None to indicate no quantization
      should be applied.

  Returns:
    A Jax array with the quantized sum of 'x'.
  """
  # Don't quantize. In this case, this function just wraps jnp.sum.
  if prec is None:
    return jnp.sum(x, axis=axis, keepdims=keepdims)

  # We bypass QuantOps.create_input_ops and directly call
  # QuantOps.create_symmetric_fp because the former creates an instance of
  # GetBounds, which in turn creates state variables to store activation
  # statistics. We do not want to compute statistics for each individual
  # addition within the sum reduction.
  fp_quant = QuantOps.FloatQuant(is_scaled=False, fp_spec=prec)
  quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant, bounds=None)

  if not isinstance(axis, Iterable):
    axis = (axis,)
  axis = utils.normalize_axes(axis, x.ndim)

  dtype = x.dtype
  zero = jnp.zeros((), dtype=dtype)
  x_quantized_sum = lax.reduce(
      x,
      init_values=zero,
      computation=lambda a, b: quant_ops.to_quantized(a + b, dtype=dtype),
      dimensions=axis)
  if keepdims:
    x_quantized_sum = jnp.expand_dims(x_quantized_sum, axis)
  return x_quantized_sum
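# Self-contained sketch of the same idea without the aqt dependency:
# quantized_sum relies on QuantOps from its surrounding library, so the
# stand-in fake_quant_bf16 below is hypothetical and only mimics
# quant_ops.to_quantized by rounding each intermediate accumulation to
# bfloat16 and back.
import jax.numpy as jnp
from jax import lax

def fake_quant_bf16(x):
  # Round the accumulator to bfloat16 and back to emulate a low-precision
  # intermediate format.
  return x.astype(jnp.bfloat16).astype(x.dtype)

def quantized_sum_sketch(x, axis):
  zero = jnp.zeros((), dtype=x.dtype)
  return lax.reduce(
      x,
      init_values=zero,
      computation=lambda a, b: fake_quant_bf16(a + b),
      dimensions=(axis,))

print(quantized_sum_sketch(jnp.arange(8, dtype=jnp.float32), axis=0))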
def _variadic_reduce_no_grad(operands, inits, axis, reducer):
  if JAX_MODE:
    from jax import lax  # pylint: disable=g-import-not-at-top
    return lax.reduce(
        operands, init_values=inits, dimensions=axis, computation=reducer)
  elif (tf.executing_eagerly() or
        not control_flow_util.GraphOrParentsInXlaContext(
            tf1.get_default_graph())):
    return _variadic_reduce(
        operands, init=inits, axis=axis, reducer=reducer)
  else:
    return _xla_reduce(operands, inits, axis)
def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
  rng = rng_factory(self.rng())
  if jtu.device_under_test() == "tpu" and op is lax.mul:
    raise SkipTest("unimplemented case")
  tol = {dtypes.bfloat16: 2e-1, onp.float16: 1e-1, onp.float32: 1e-1,
         onp.float64: 1e-3, onp.complex64: 1e-1}
  operand = rng(shape, dtype)
  init_val = onp.asarray(init_val, dtype=dtype)
  reduce = lambda operand: lax.reduce(operand, init_val, op, dims)
  eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else
         1e-1 if dtype == dtypes.bfloat16 else
         1e-2 if dtypes.finfo(dtype).bits == 32 else None)
  check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)
def reduction(a, axis=None, dtype=None, out=None, keepdims=False):
  if out is not None:
    raise ValueError("reduction does not support `out` argument.")
  a = a if isinstance(a, ndarray) else asarray(a)
  dims = _reduction_dims(a, axis)
  result_dtype = _dtype(np_fun(onp.ones((), dtype=_dtype(a))))
  if _dtype(a) != result_dtype:
    a = lax.convert_element_type(a, result_dtype)
  result = lax.reduce(a, _reduction_init_val(a, init_val), op, dims)
  if keepdims:
    shape_with_singletons = lax.subvals(shape(a), zip(dims, (1,) * len(dims)))
    result = lax.reshape(result, shape_with_singletons)
  if dtype and onp.dtype(dtype) != onp.dtype(result_dtype):
    result = lax.convert_element_type(result, dtype)
  return result
def testReduce(self, op, init_val, shape, dtype, dims, bdims):
  rng = jtu.rand_small(self.rng())
  init_val = np.asarray(init_val, dtype=dtype)
  fun = lambda operand: lax.reduce(operand, init_val, op, dims)
  self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)
def _reduction(a, name, np_fun, op, init_val, has_identity=True,
               preproc=None, bool_op=None, upcast_f16_for_computation=False,
               axis=None, dtype=None, out=None, keepdims=False, initial=None,
               where_=None, parallel_reduce=None):
  bool_op = bool_op or op
  # Note: we must accept out=None as an argument, because numpy reductions
  # delegate to object methods. For example `np.sum(x)` will call `x.sum()`
  # if the `sum()` method exists, passing along all its arguments.
  if out is not None:
    raise NotImplementedError(
        f"The 'out' argument to jnp.{name} is not supported.")
  _check_arraylike(name, a)
  lax_internal._check_user_dtype_supported(dtype, name)
  axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")

  if initial is None and not has_identity and where_ is not None:
    raise ValueError(
        f"reduction operation {name} does not have an identity, so to use a "
        f"where mask one has to specify 'initial'")

  a = a if isinstance(a, ndarray) else _asarray(a)
  a = preproc(a) if preproc else a
  pos_dims, dims = _reduction_dims(a, axis)

  if initial is None and not has_identity:
    shape = np.shape(a)
    if not _all(core.greater_equal_dim(shape[d], 1) for d in pos_dims):
      raise ValueError(
          f"zero-size array to reduction operation {name} which has no "
          f"identity")

  result_dtype = dtypes.canonicalize_dtype(
      dtype or dtypes.dtype(np_fun(np.ones((), dtype=dtypes.dtype(a)))))
  if upcast_f16_for_computation and dtypes.issubdtype(result_dtype,
                                                      np.inexact):
    computation_dtype = _upcast_f16(result_dtype)
  else:
    computation_dtype = result_dtype
  a = lax.convert_element_type(a, computation_dtype)
  op = op if computation_dtype != np.bool_ else bool_op
  # NB: in XLA, init_val must be an identity for the op, so the user-specified
  # initial value must be applied afterward.
  init_val = _reduction_init_val(a, init_val)
  if where_ is not None:
    a = _where(where_, a, init_val)
  if pos_dims is not dims:
    if parallel_reduce is None:
      raise NotImplementedError(
          f"Named reductions not implemented for jnp.{name}()")
    result = parallel_reduce(a, dims)
  else:
    result = lax.reduce(a, init_val, op, dims)
  if initial is not None:
    result = op(lax.convert_element_type(initial, a.dtype), result)
  if keepdims:
    result = lax.expand_dims(result, pos_dims)
  return lax.convert_element_type(result, dtype or result_dtype)
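# Usage sketch of the public front end (values illustrative): jnp.sum is built
# on a helper like _reduction above; `where` substitutes the identity for
# masked-out elements, and `initial` is folded in after the lax.reduce call,
# as the comments in _reduction describe.
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
mask = jnp.array([[True, False], [True, True]])
print(jnp.sum(x, axis=0))                           # [4. 6.]
print(jnp.sum(x, axis=0, where=mask, initial=10.))  # [14. 14.]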
def fun(x):
  # lax.reduce is unlikely to ever be convertible with enable_xla=False
  return lax.reduce(x, np.float32(0), lambda v, acc: v + acc,
                    dimensions=(0, 1))
def softmax(attn_weights, norm_dims, dtype, softmax_hparams, quant_context):
  """Normalizes attention."""
  a = attn_weights

  def unquantized_softmax(a):
    a = lax.exp(
        a - jax.scipy.special.logsumexp(a, axis=norm_dims, keepdims=True))
    return a.astype(dtype)

  # Quantize intermediate activations with QuantOps.
  # Currently only supports unscaled floating-point formats.
  def quantized_softmax(a):
    # We compute softmax as exp(x-max(x))/sum_i(exp(x_i-max(x))), quantizing
    # intermediate values. Note this differs from the log-domain
    # implementation of softmax used above.
    quant_hparams = softmax_hparams.quant_hparams
    fp_quant_config = QuantOps.FloatQuant(
        is_scaled=False, fp_spec=quant_hparams.prec)
    quant_ops = QuantOps.create_symmetric_fp(
        fp_quant=fp_quant_config, bounds=None)

    a = quant_ops.to_quantized(a, dtype=dtype)
    # Note that the max of a quantized vector is necessarily also quantized to
    # the same precision since the max of a vector must be an existing element
    # of the vector, so we don't need to explicitly insert a quantization
    # operator to the output of the max reduction.
    a_max = jnp.max(a, axis=norm_dims, keepdims=True)
    a_minus_max = quant_ops.to_quantized(a - a_max, dtype=dtype)
    a_exp = quant_ops.to_quantized(jnp.exp(a_minus_max), dtype=dtype)

    sum_exp_quantized_reduction = quantization.quantized_sum(
        a_exp, axis=norm_dims, keepdims=True,
        prec=quant_hparams.reduction_prec)
    sum_exp = quant_ops.to_quantized(sum_exp_quantized_reduction, dtype=dtype)

    inv_sum_exp = quant_ops.to_quantized(jnp.reciprocal(sum_exp), dtype=dtype)
    a_softmax = quant_ops.to_quantized(a_exp * inv_sum_exp, dtype=dtype)

    return a_softmax.astype(dtype)

  # If no params, return accurate Softmax.
  if softmax_hparams == SoftmaxHParams(None, None,
                                       None) or softmax_hparams is None:
    return unquantized_softmax(a)

  # TODO(shivaniagrawal): Partial sum quantization (if enabled) will happen for
  # the entire training run, even before the global activation start step.
  if softmax_hparams.quant_hparams is not None:
    return lax.cond(quant_context.quantize_acts, quantized_softmax,
                    unquantized_softmax, a)

  # Approximated Softmax
  exp_hparams = softmax_hparams.exp_hparams
  recip_hparams = softmax_hparams.reciprocal_hparams

  # Subtract max value from dimensions to be normalized.
  shape = jax.util.subvals(onp.shape(a),
                           zip(norm_dims, (1,) * len(norm_dims)))
  dimadd = lambda x: lax.reshape(x, shape)
  # pylint: disable=protected-access
  amax = lax.reduce(a, lax_numpy._constant_like(a, -onp.inf), lax.max,
                    norm_dims)
  amax = lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0))
  amax_singletons = dimadd(amax)
  asubmax = lax.sub(a, amax_singletons)

  # Calculate approximated exponential
  approx_exp = exponential(asubmax, dtype, exp_hparams)

  # If sum_high_bound is set: upper clip bound for sum(exp(x-M)).
  asumexp = dimadd(
      lax.reduce(approx_exp, lax_numpy._constant_like(a, 0), lax.add,
                 norm_dims))

  if exp_hparams.sum_high_bound is not None and exp_hparams.sum_high_bound != 0:
    sum_low_bound = 1.
    if (exp_hparams.low_bound != 0) and exp_hparams.clip_and_subtract:
      sum_low_bound = 1 - onp.exp(exp_hparams.low_bound)
    asumexp = jnp.clip(asumexp, sum_low_bound, exp_hparams.sum_high_bound)

  # Approximation of reciprocal.
  arecip = reciprocal(asumexp, dtype, recip_hparams)
  return lax.mul(approx_exp, arecip).astype(dtype)
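# Numerical sanity check (illustrative, not from the library above): the
# log-domain form used in unquantized_softmax, exp(x - logsumexp(x)), matches
# jax.nn.softmax up to floating-point tolerance.
import jax
import jax.numpy as jnp

a = jnp.array([[0.5, -1.0, 2.0]])
log_domain = jnp.exp(
    a - jax.scipy.special.logsumexp(a, axis=-1, keepdims=True))
print(jnp.allclose(log_domain, jax.nn.softmax(a, axis=-1)))  # True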
def reduce_fn(operands, inits, axis=None, keepdims=False):
  """Applies `reducer` to the given operands along the given axes.

  Args:
    operands: tuple of tensors, all having the same shape.
    inits: tuple of scalar tensors, with dtypes aligned to those of operands.
    axis: The axis or axes to reduce. One of `None`, an `int` or a sequence of
      `int`. `None` is taken to mean "reduce all axes".
    keepdims: When `True`, we do not squeeze away the reduced dims, instead
      returning values with singleton dims in those axes.

  Returns:
    reduced: A tuple of the reduced operands.
  """
  # Static shape consistency checks.
  args_shape = operands[0].shape
  for arg in operands[1:]:
    args_shape = tensorshape_util.merge_with(args_shape, arg.shape)
  ndims = tensorshape_util.rank(args_shape)
  if ndims is None:
    raise ValueError(
        'Rank of at least one of `operands` must be known statically.')
  # Ensure the 'axis' arg is a tuple of non-negative ints.
  axis = np.arange(ndims) if axis is None else np.array(axis)
  if axis.ndim > 1:
    raise ValueError('`axis` must be `None`, an `int`, or a sequence of '
                     '`int`, but got {}'.format(axis))
  axis = np.reshape(axis, [-1])
  axis = np.where(axis < 0, axis + ndims, axis)
  axis = tuple(int(ax) for ax in axis)

  if JAX_MODE:
    from jax import lax  # pylint: disable=g-import-not-at-top
    result = lax.reduce(
        operands, init_values=inits, dimensions=axis, computation=reducer)
  elif (tf.executing_eagerly() or
        not control_flow_util.GraphOrParentsInXlaContext(
            tf1.get_default_graph())):
    result = _variadic_reduce(
        operands, init=inits, axis=axis, reducer=reducer)
  else:
    result = _xla_reduce(operands, inits, axis)

  if keepdims:
    axis_nhot = ps.reduce_sum(
        ps.one_hot(axis, depth=ndims, on_value=True, off_value=False,
                   dtype=tf.bool),
        axis=0)
    in_shape = args_shape
    if not tensorshape_util.is_fully_defined(in_shape):
      in_shape = tf.shape(operands[0])
    final_shape = ps.where(axis_nhot, 1, in_shape)
    result = tf.nest.map_structure(
        lambda t: tf.reshape(t, final_shape), result)
  return result
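# JAX-only sketch of what reduce_fn computes on its JAX branch (assumptions
# noted): simple_reduce is a hypothetical, self-contained stand-in that
# normalizes `axis`, calls the variadic lax.reduce, and emulates keepdims with
# a plain reshape instead of the TF shape utilities above.
import numpy as np
from jax import lax

def simple_reduce(operands, inits, reducer, axis, keepdims=False):
  ndims = operands[0].ndim
  if axis is None:
    axis = tuple(range(ndims))
  else:
    axis = tuple(int(ax) % ndims for ax in np.atleast_1d(axis))
  result = lax.reduce(operands, init_values=inits,
                      dimensions=axis, computation=reducer)
  if keepdims:
    kept_shape = tuple(1 if d in axis else s
                       for d, s in enumerate(operands[0].shape))
    result = tuple(r.reshape(kept_shape) for r in result)
  return result

x = np.ones((2, 3), np.float32)
out = simple_reduce((x, x), (np.float32(0), np.float32(1)),
                    lambda a, b: (a[0] + b[0], a[1] * b[1]),
                    axis=1, keepdims=True)
print(out[0].shape, out[1].shape)  # (2, 1) (2, 1)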