Ejemplo n.º 1
0
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Numerically stable ``log(sum(b * exp(a)))`` over ``axis``.

    The per-slice maximum of ``a`` along the reduced dimensions is subtracted
    before exponentiating and added back after taking the log. That shift is
    wrapped in ``stop_gradient`` and replaced by 0 wherever it is not finite.

    Args:
      a: input array.
      axis: axis or axes to reduce, resolved by ``_reduction_dims``.
      b: optional weights broadcast against ``a``. With weights the sum can be
        zero or negative, so the log of its absolute value is taken.
      keepdims: if True, reduced dimensions are kept with size 1.
      return_sign: if True, return ``(out, sign)`` instead of ``out``.

    Returns:
      ``out`` or ``(out, sign)``. Without ``return_sign``, entries whose
      weighted sum is negative are NaN.
    """
    if b is not None:
        a, b = jnp.broadcast_arrays(a, b)
    reduce_dims = _reduction_dims(a, axis)

    def restore_dims(x):
        return lax.expand_dims(x, reduce_dims)

    # Stabilizing shift: the maximum along the reduced dims, with non-finite
    # maxima replaced by 0 so that `a - shift` stays well defined.
    raw_max = lax.reduce(a, _constant_like(a, -np.inf), lax.max, reduce_dims)
    shift = lax.stop_gradient(
        lax.select(lax.is_finite(raw_max), raw_max, lax.full_like(raw_max, 0)))
    shifted_exp = lax.exp(lax.sub(a, restore_dims(shift)))
    zero = _constant_like(a, 0)

    if b is not None:
        total = lax.reduce(lax.mul(shifted_exp, b), zero, lax.add, reduce_dims)
        sign = lax.stop_gradient(lax.sign(total))
        out = lax.add(lax.log(lax.abs(total)), shift)
    else:
        total = lax.reduce(shifted_exp, zero, lax.add, reduce_dims)
        out = lax.add(lax.log(total), shift)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)

    if return_sign:
        if keepdims:
            return restore_dims(out), restore_dims(sign)
        return out, sign
    if b is not None:
        # A negative weighted sum has no real logarithm.
        out = jnp.where(sign < 0, np.nan, out)
    return restore_dims(out) if keepdims else out
Ejemplo n.º 2
0
 def testReducePairGrad(self, shape, dtype, dims):
   """Checks 2nd-order fwd/rev gradients of a variadic (sum, product) reduce."""
   rng = jtu.rand_default(self.rng(), scale=1)
   first = rng(shape, dtype)
   second = rng(shape, dtype)
   tolerance = {np.float32: 1e-2, np.float64: 1e-4}
   # Initial values: 0 for the summed operand, 1 for the multiplied one.
   inits = (np.array(0, dtype), np.array(1, dtype))

   def pair_op(lhs, rhs):
     lhs_sum, lhs_prod = lhs
     rhs_sum, rhs_prod = rhs
     return (lhs_sum + rhs_sum, lhs_prod * rhs_prod)

   def reduce_pair(xs, ys):
     return lax.reduce((xs, ys), inits, pair_op, dims)

   check_grads(reduce_pair, (first, second), 2, ["fwd", "rev"],
               tolerance, tolerance)
Ejemplo n.º 3
0
 def testVariadicReduce(self, shape, dtype, dims, bdims):
   """Checks batching of a variadic reduce producing (sum, product) pairs."""
   def combine(lhs, rhs):
     return lhs[0] + rhs[0], lhs[1] * rhs[1]

   rng = jtu.rand_small(self.rng())
   # Scalars of `dtype`: 0 seeds the sum, 1 seeds the product.
   sum_init, prod_init = np.asarray([0, 1], dtype=dtype)
   inits = (sum_init, prod_init)

   def fun(x, y):
     return lax.reduce((x, y), inits, combine, dims)

   self._CheckBatching(fun, 5, bdims, (shape, shape), (dtype, dtype), rng,
                       multiple_results=True)
Ejemplo n.º 4
0
def quantized_sum(x, axis, keepdims, prec):
    """Sums a tensor while quantizing intermediate accumulations.

    This is almost a drop-in replacement for jnp.sum. It only differs in that
    it takes a ``prec`` parameter that controls the quantization of
    intermediate accumulations during the reduction.

    Arguments:
      x: Input, a Jax array.
      axis: Which axes to reduce over (see jnp.sum docs). ``None`` reduces
        over all axes; a single int or an iterable of ints is also accepted.
      keepdims: Whether to keep or drop the reduced axes (see jnp.sum docs).
      prec: Precision to quantize intermediate values to. Currently it can
        only be an instance of QuantOps.FloatQuant.FloatPrec, corresponding to
        an unscaled floating-point format, or None to indicate that no
        quantization should be applied.

    Returns:
      A Jax array with the quantized sum of 'x'.
    """

    # Don't quantize. In this case, this function just wraps jnp.sum.
    if prec is None:
        return jnp.sum(x, axis=axis, keepdims=keepdims)

    # We bypass QuantOps.create_input_ops and directly call
    # QuantOps.create_symmetric_fp because the former creates an instance of
    # GetBounds, which in turn creates state variables to store activation
    # statistics. We do not want to compute statistics for each individual
    # addition within the sum reduction.
    fp_quant = QuantOps.FloatQuant(is_scaled=False, fp_spec=prec)
    quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant, bounds=None)

    if axis is None:
        # Match jnp.sum: None means "reduce every axis". Without this branch
        # None would be wrapped as (None,) and handed to normalize_axes.
        axis = tuple(range(x.ndim))
    elif not isinstance(axis, Iterable):
        axis = (axis, )
    axis = utils.normalize_axes(axis, x.ndim)
    dtype = x.dtype

    zero = jnp.zeros((), dtype=dtype)
    # Every pairwise addition performed by the reduction is re-quantized.
    x_quantized_sum = lax.reduce(
        x,
        init_values=zero,
        computation=lambda a, b: quant_ops.to_quantized(a + b, dtype=dtype),
        dimensions=axis)

    if keepdims:
        x_quantized_sum = jnp.expand_dims(x_quantized_sum, axis)

    return x_quantized_sum
 def _variadic_reduce_no_grad(operands, inits, axis, reducer):
   """Dispatches a variadic reduction to the backend-specific implementation.

   Under JAX this is `lax.reduce`. In TF, eager execution or a graph outside
   any XLA context uses `_variadic_reduce`; otherwise `_xla_reduce` is used.
   NOTE(review): `reducer` is not forwarded to `_xla_reduce` here; presumably
   that helper obtains it another way -- confirm.
   """
   if JAX_MODE:
     from jax import lax  # pylint: disable=g-import-not-at-top
     return lax.reduce(
         operands, init_values=inits, dimensions=axis, computation=reducer)
   # The XLA-context query only runs when not executing eagerly.
   use_xla = (not tf.executing_eagerly() and
              control_flow_util.GraphOrParentsInXlaContext(
                  tf1.get_default_graph()))
   if use_xla:
     return _xla_reduce(operands, inits, axis)
   return _variadic_reduce(operands, init=inits, axis=axis, reducer=reducer)
Ejemplo n.º 6
0
 def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
   """Checks 2nd-order fwd/rev gradients of a single-operand lax.reduce."""
   rng = rng_factory(self.rng())
   if jtu.device_under_test() == "tpu" and op is lax.mul:
     raise SkipTest("unimplemented case")
   tolerance = {dtypes.bfloat16: 2e-1, onp.float16: 1e-1, onp.float32: 1e-1,
                onp.float64: 1e-3, onp.complex64: 1e-1}
   operand = rng(shape, dtype)
   init = onp.asarray(init_val, dtype=dtype)

   def reduce_fn(x):
     return lax.reduce(x, init, op, dims)

   # Step size handed to check_grads as `eps`; larger for narrow dtypes.
   bits = dtypes.finfo(dtype).bits
   if bits == 16 and op is lax.add:
     eps = 1.0
   elif dtype == dtypes.bfloat16:
     eps = 1e-1
   elif bits == 32:
     eps = 1e-2
   else:
     eps = None
   check_grads(reduce_fn, (operand,), 2, ["fwd", "rev"], tolerance, tolerance,
               eps)
Ejemplo n.º 7
0
  def reduction(a, axis=None, dtype=None, out=None, keepdims=False):
    """Reduces ``a`` over ``axis`` using ``op`` from the enclosing scope.

    Args:
      a: array or array-like input; non-``ndarray`` values go through
        ``asarray``.
      axis: axis or axes to reduce, resolved by ``_reduction_dims``.
      dtype: optional output dtype. The reduction runs in the dtype that
        ``np_fun`` produces for ``a``'s dtype; the result is converted to
        ``dtype`` afterwards if the two differ.
      out: unsupported; must be None.
      keepdims: if True, reduced axes are kept with size 1.

    Returns:
      The reduced array.

    Raises:
      ValueError: if ``out`` is given.
    """
    if out is not None:
      raise ValueError("reduction does not support `out` argument.")

    a = a if isinstance(a, ndarray) else asarray(a)
    dims = _reduction_dims(a, axis)
    result_dtype = _dtype(np_fun(onp.ones((), dtype=_dtype(a))))
    if _dtype(a) != result_dtype:
      a = lax.convert_element_type(a, result_dtype)
    result = lax.reduce(a, _reduction_init_val(a, init_val), op, dims)
    if keepdims:
      shape_with_singletons = lax.subvals(shape(a), zip(dims, (1,) * len(dims)))
      result = lax.reshape(result, shape_with_singletons)
    # Compare against None instead of testing truthiness: older NumPy derives
    # bool(np.dtype(...)) from len(), which is 0 for non-structured dtypes, so
    # an explicit dtype such as np.dtype('int32') would be silently ignored.
    if dtype is not None and onp.dtype(dtype) != onp.dtype(result_dtype):
      result = lax.convert_element_type(result, dtype)
    return result
Ejemplo n.º 8
0
 def testReduce(self, op, init_val, shape, dtype, dims, bdims):
   """Checks batching of lax.reduce over `dims` with the given op/init."""
   rng = jtu.rand_small(self.rng())
   init = np.asarray(init_val, dtype=dtype)

   def reduce_fn(operand):
     return lax.reduce(operand, init, op, dims)

   self._CheckBatching(reduce_fn, 5, bdims, (shape, ), (dtype, ), rng)
Ejemplo n.º 9
0
def _reduction(a,
               name,
               np_fun,
               op,
               init_val,
               has_identity=True,
               preproc=None,
               bool_op=None,
               upcast_f16_for_computation=False,
               axis=None,
               dtype=None,
               out=None,
               keepdims=False,
               initial=None,
               where_=None,
               parallel_reduce=None):
    """Shared implementation behind the jnp.<name> reduction functions.

    Validates the arguments, casts ``a`` to the computation dtype, applies the
    optional ``where_`` mask, reduces with ``op`` (or ``bool_op`` when the
    computation dtype is bool), folds in ``initial``, optionally restores the
    reduced dims, and casts to the result dtype.

    Args:
      a: array-like input.
      name: public function name, used in error messages as ``jnp.{name}``.
      np_fun: NumPy counterpart; evaluated on a scalar of ``a``'s dtype to infer
        the default result dtype.
      op: binary reducer passed to ``lax.reduce``.
      init_val: value for ``op`` resolved through ``_reduction_init_val``; per the
        NB comment below it must be an identity for ``op``.
      has_identity: if False, a ``where_`` mask or a zero-size reduced axis
        requires ``initial``.
      preproc: optional function applied to ``a`` before reducing.
      bool_op: reducer used instead of ``op`` for bool computations; defaults
        to ``op``.
      upcast_f16_for_computation: if True and the result dtype is inexact,
        compute in the dtype returned by ``_upcast_f16``.
      axis, dtype, out, keepdims, initial, where_: NumPy-style reduction
        arguments; ``out`` must be None.
      parallel_reduce: callable used for named reductions (when
        ``_reduction_dims`` returns distinct ``pos_dims`` and ``dims``
        objects); required in that case.

    Returns:
      The reduced array converted to ``dtype`` or the inferred result dtype.

    Raises:
      NotImplementedError: if ``out`` is given, or a named reduction is
        requested without ``parallel_reduce``.
      ValueError: if ``where_`` is given without ``initial`` for an op with no
        identity, or a zero-size axis is reduced by such an op.
    """
    bool_op = bool_op or op
    # Note: we must accept out=None as an argument, because numpy reductions delegate to
    # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
    # exists, passing along all its arguments.
    if out is not None:
        raise NotImplementedError(
            f"The 'out' argument to jnp.{name} is not supported.")
    _check_arraylike(name, a)
    lax_internal._check_user_dtype_supported(dtype, name)
    # `axis` must be a concrete (non-traced) value.
    axis = core.concrete_or_error(None, axis,
                                  f"axis argument to jnp.{name}().")

    if initial is None and not has_identity and where_ is not None:
        raise ValueError(
            f"reduction operation {name} does not have an identity, so to use a "
            f"where mask one has to specify 'initial'")

    a = a if isinstance(a, ndarray) else _asarray(a)
    a = preproc(a) if preproc else a
    pos_dims, dims = _reduction_dims(a, axis)

    if initial is None and not has_identity:
        # Without an identity or `initial`, reducing a zero-length axis has no
        # defined result.
        shape = np.shape(a)
        if not _all(core.greater_equal_dim(shape[d], 1) for d in pos_dims):
            raise ValueError(
                f"zero-size array to reduction operation {name} which has no identity"
            )

    # NOTE(review): `dtype or ...` relies on the truthiness of `dtype`;
    # `dtype is not None` would be more robust -- confirm intended behavior.
    result_dtype = dtypes.canonicalize_dtype(
        dtype or dtypes.dtype(np_fun(np.ones((), dtype=dtypes.dtype(a)))))
    if upcast_f16_for_computation and dtypes.issubdtype(
            result_dtype, np.inexact):
        computation_dtype = _upcast_f16(result_dtype)
    else:
        computation_dtype = result_dtype
    a = lax.convert_element_type(a, computation_dtype)
    op = op if computation_dtype != np.bool_ else bool_op
    # NB: in XLA, init_val must be an identity for the op, so the user-specified
    # initial value must be applied afterward.
    init_val = _reduction_init_val(a, init_val)
    if where_ is not None:
        # Masked-out positions are filled with init_val so they do not
        # contribute to the reduction.
        a = _where(where_, a, init_val)
    # Identity (not equality) comparison: presumably _reduction_dims returns
    # the same object for both when no named axes are involved -- confirm.
    if pos_dims is not dims:
        if parallel_reduce is None:
            raise NotImplementedError(
                f"Named reductions not implemented for jnp.{name}()")
        result = parallel_reduce(a, dims)
    else:
        result = lax.reduce(a, init_val, op, dims)
    if initial is not None:
        # Fold in the user-supplied initial value after the reduction.
        result = op(lax.convert_element_type(initial, a.dtype), result)
    if keepdims:
        result = lax.expand_dims(result, pos_dims)
    return lax.convert_element_type(result, dtype or result_dtype)
Ejemplo n.º 10
0
 def fun(x):
   """Sums ``x`` over axes 0 and 1 with a Python-level reducer (not lax.add)."""
   # lax.reduce is unlikely to ever be convertible with enable_xla=False
   def add(value, acc):
     return value + acc

   init = np.float32(0)
   return lax.reduce(x, init, add, dimensions=(0, 1))
Ejemplo n.º 11
0
def softmax(attn_weights, norm_dims, dtype, softmax_hparams, quant_context):
    """Normalizes attention.

    The mode is chosen from ``softmax_hparams``:

    * ``None`` or ``SoftmaxHParams(None, None, None)``: exact softmax computed
      in the log domain with ``jax.scipy.special.logsumexp``.
    * ``quant_hparams`` set: ``lax.cond`` on ``quant_context.quantize_acts``
      selects between a quantized softmax and the exact one.
    * otherwise: softmax assembled from the ``exponential`` and ``reciprocal``
      approximations configured by ``exp_hparams`` / ``reciprocal_hparams``.

    Args:
      attn_weights: attention logits.
      norm_dims: sequence of axes to normalize over.
      dtype: dtype of the returned array (also passed to the quantizers and
        approximations).
      softmax_hparams: ``SoftmaxHParams`` or None.
      quant_context: object whose ``quantize_acts`` is the predicate for the
        quantized/exact ``lax.cond`` when ``quant_hparams`` is set.

    Returns:
      The normalized weights cast to ``dtype``.
    """
    a = attn_weights

    def unquantized_softmax(a):
        a = lax.exp(
            a - jax.scipy.special.logsumexp(a, axis=norm_dims, keepdims=True))
        return a.astype(dtype)

    # Quantize intermediate activations with QuantOps.
    # Currently only supports unscaled floating-point formats.
    def quantized_softmax(a):
        # We compute softmax as exp(x-max(x))/sum_i(exp(x_i-max(x))), quantizing
        # intermediate values. Note this differs from the log-domain
        # implementation of softmax used above.
        quant_hparams = softmax_hparams.quant_hparams
        fp_quant_config = QuantOps.FloatQuant(is_scaled=False,
                                              fp_spec=quant_hparams.prec)
        quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant_config,
                                                 bounds=None)

        a = quant_ops.to_quantized(a, dtype=dtype)
        # Note that the max of a quantized vector is necessarily also quantized to
        # the same precision since the max of a vector must be an existing element
        # of the vector, so we don't need to explicitly insert a quantization
        # operator to the output of the max reduction.
        a_max = jnp.max(a, axis=norm_dims, keepdims=True)
        a_minus_max = quant_ops.to_quantized(a - a_max, dtype=dtype)
        a_exp = quant_ops.to_quantized(jnp.exp(a_minus_max), dtype=dtype)

        sum_exp_quantized_reduction = quantization.quantized_sum(
            a_exp,
            axis=norm_dims,
            keepdims=True,
            prec=quant_hparams.reduction_prec)
        sum_exp = quant_ops.to_quantized(sum_exp_quantized_reduction,
                                         dtype=dtype)

        inv_sum_exp = quant_ops.to_quantized(jnp.reciprocal(sum_exp),
                                             dtype=dtype)
        a_softmax = quant_ops.to_quantized(a_exp * inv_sum_exp, dtype=dtype)

        return a_softmax.astype(dtype)

    # If no params, return accurate Softmax. Test for None first so the
    # SoftmaxHParams comparison only runs on real hparams.
    if softmax_hparams is None or softmax_hparams == SoftmaxHParams(
            None, None, None):
        return unquantized_softmax(a)

    # TODO(shivaniagrawal): Partial sum quantization (if enabled) will happen for
    # the entire training run, even before the global activation start step.
    if softmax_hparams.quant_hparams is not None:
        return lax.cond(quant_context.quantize_acts, quantized_softmax,
                        unquantized_softmax, a)

    # Approximated Softmax
    exp_hparams = softmax_hparams.exp_hparams
    recip_hparams = softmax_hparams.reciprocal_hparams

    # Subtract max value from dimensions to be normalized.
    shape = jax.util.subvals(onp.shape(a),
                             zip(norm_dims, (1, ) * len(norm_dims)))
    dimadd = lambda x: lax.reshape(x, shape)
    # pylint: disable=protected-access
    amax = lax.reduce(a, lax_numpy._constant_like(a, -onp.inf), lax.max,
                      norm_dims)
    # Non-finite maxima are replaced by 0 so the subtraction stays defined.
    amax = lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0))
    amax_singletons = dimadd(amax)
    asubmax = lax.sub(a, amax_singletons)

    # Calculate approximated exponential
    approx_exp = exponential(asubmax, dtype, exp_hparams)

    # If sum_high_bound: Upper clip bound for sum(exp(x-M)).
    asumexp = dimadd(
        lax.reduce(approx_exp, lax_numpy._constant_like(a, 0), lax.add,
                   norm_dims))

    # SoftmaxHParams fields are optional, so exp_hparams may be None when only
    # the reciprocal approximation is configured; skip the clipping then
    # instead of raising AttributeError.
    if (exp_hparams is not None and exp_hparams.sum_high_bound is not None
            and exp_hparams.sum_high_bound != 0):
        sum_low_bound = 1.
        if (exp_hparams.low_bound != 0) and exp_hparams.clip_and_subtract:
            sum_low_bound = 1 - onp.exp(exp_hparams.low_bound)
        asumexp = jnp.clip(asumexp, sum_low_bound, exp_hparams.sum_high_bound)

    # Approximation of reciprocal.
    arecip = reciprocal(asumexp, dtype, recip_hparams)
    return lax.mul(approx_exp, arecip).astype(dtype)
Ejemplo n.º 12
0
    def reduce_fn(operands, inits, axis=None, keepdims=False):
        """Applies the enclosing `reducer` to `operands` along `axis`.

        Args:
          operands: tuple of tensors, all having the same shape.
          inits: tuple of scalar tensors, with dtypes aligned to those of operands.
          axis: The axis or axes to reduce. One of `None`, an `int` or a sequence of
            `int`. `None` is taken to mean "reduce all axes".
          keepdims: When `True`, we do not squeeze away the reduced dims, instead
            returning values with singleton dims in those axes.

        Returns:
          reduced: A tuple of the reduced operands.

        Raises:
          ValueError: if no operand has a statically known rank, or `axis` has
            more than one dimension.
        """
        # Merge static shapes so incompatible operands are rejected up front.
        static_shape = operands[0].shape
        for other in operands[1:]:
            static_shape = tensorshape_util.merge_with(static_shape, other.shape)
        rank = tensorshape_util.rank(static_shape)
        if rank is None:
            raise ValueError(
                'Rank of at least one of `operands` must be known statically.')

        # Canonicalize `axis` into a tuple of non-negative Python ints.
        axis_arr = np.arange(rank) if axis is None else np.array(axis)
        if axis_arr.ndim > 1:
            raise ValueError(
                '`axis` must be `None`, an `int`, or a sequence of '
                '`int`, but got {}'.format(axis_arr))
        flat_axes = np.reshape(axis_arr, [-1])
        axis = tuple(
            int(ax) for ax in np.where(flat_axes < 0, flat_axes + rank, flat_axes))

        if JAX_MODE:
            from jax import lax  # pylint: disable=g-import-not-at-top
            reduced = lax.reduce(operands, init_values=inits, dimensions=axis,
                                 computation=reducer)
        elif (tf.executing_eagerly() or
              not control_flow_util.GraphOrParentsInXlaContext(
                  tf1.get_default_graph())):
            reduced = _variadic_reduce(operands, init=inits, axis=axis,
                                       reducer=reducer)
        else:
            reduced = _xla_reduce(operands, inits, axis)

        if not keepdims:
            return reduced

        # Mark reduced axes by summing one-hot rows, then put 1 in those slots
        # of the input shape (the dynamic shape if the static one is partial).
        reduced_mask = ps.reduce_sum(
            ps.one_hot(axis, depth=rank, on_value=True, off_value=False,
                       dtype=tf.bool),
            axis=0)
        full_shape = (static_shape
                      if tensorshape_util.is_fully_defined(static_shape)
                      else tf.shape(operands[0]))
        kept_shape = ps.where(reduced_mask, 1, full_shape)
        return tf.nest.map_structure(
            lambda t: tf.reshape(t, kept_shape), reduced)