def test_lax_dot_has_integer_inputs_in_dynamic_dot_general(
         self, mock_dot_general, lhs_distribution, rhs_distribution):
     lhs_params = QuantOps.ActHParams(input_distribution=lhs_distribution,
                                      bounds=2.0,
                                      prec=4)
     rhs_params = QuantOps.ActHParams(input_distribution=rhs_distribution,
                                      bounds=1.5,
                                      prec=4)
     lhs_act = self.lhs
     if lhs_distribution == 'positive':
         lhs_act = jnp.abs(lhs_act)
     rhs_act = self.rhs
     if rhs_distribution == 'positive':
         rhs_act = jnp.abs(rhs_act)
     quantization.quantized_dynamic_dot_general(
         lhs_act=lhs_act,
         rhs_act=rhs_act,
         lhs_act_hparams=lhs_params,
         rhs_act_hparams=rhs_params,
         lhs_get_bounds_params=None,
         rhs_get_bounds_params=None,
         dot_dimension_numbers=(((1, ), (0, )), ((), ())),
         quant_type=QuantType.aqt)
     lhs_inputs, rhs_inputs = mock_dot_general.call_args[0]
     self.assert_is_integer_in_range(lhs_inputs,
                                     prec=4,
                                     distribution=lhs_distribution)
     self.assert_is_integer_in_range(rhs_inputs,
                                     prec=4,
                                     distribution=rhs_distribution)
Example 2
 def __call__(self, lhs_act, rhs_act, lhs_prec, rhs_prec):
     get_bounds_hyper = get_bounds.GetBounds.Hyper(
         initial_bound=10.0,
         stddev_coeff=0,
         absdev_coeff=0,
         mix_coeff=0,
         granularity=quant_config.QuantGranularity.per_tensor)
     lhs_act_hparams = QuantOps.ActHParams(
         input_distribution='symmetric',
         bounds=get_bounds_hyper,
         prec=lhs_prec,
         half_shift=False)
     rhs_act_hparams = QuantOps.ActHParams(
         input_distribution='symmetric',
         bounds=get_bounds_hyper,
         prec=rhs_prec,
         half_shift=False)
     lhs_get_bounds_params = get_bounds.GetBounds.Params(
         update_stats=False, update_bounds=False, module_name='lhs')
     rhs_get_bounds_params = get_bounds.GetBounds.Params(
         update_stats=False, update_bounds=False, module_name='rhs')
     output = quantization.quantized_dynamic_dot_general(
         lhs_act=lhs_act,
         rhs_act=rhs_act,
         lhs_act_hparams=lhs_act_hparams,
         rhs_act_hparams=rhs_act_hparams,
         dot_dimension_numbers=(((1, ), (0, )), ((), ())),
         quant_type=QuantType.aqt,
         lhs_get_bounds_params=lhs_get_bounds_params,
         rhs_get_bounds_params=rhs_get_bounds_params)
     return output
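
Here all three statistics coefficients are 0 and the Params disable updates, so the clipping bound presumably stays at initial_bound for this test. A variant that would let GetBounds derive the bound from activation statistics might look like the sketch below; the coefficient values are illustrative assumptions, not taken from the library:

# Illustrative only: nonzero coefficients would let GetBounds derive the bound
# from running activation statistics instead of the fixed initial_bound above.
get_bounds_hyper = get_bounds.GetBounds.Hyper(
    initial_bound=10.0,
    stddev_coeff=3.0,   # assumed value
    absdev_coeff=0.0,   # assumed value
    mix_coeff=0.5,      # assumed value
    granularity=quant_config.QuantGranularity.per_tensor)
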
    def test_quantized_dynamic_dot_general_should_call_inputs_quantization(
            self,
            mock_act_fq,
            lhs_act_prec,
            rhs_act_prec,
            strategy=QuantType.fake_quant):
        mock_act_fq.side_effect = lambda inputs, hparams, get_bounds_params: inputs

        # pylint: disable=g-long-ternary
        lhs_act_hparams = QuantOps.ActHParams(
            bounds=6.,
            prec=lhs_act_prec,
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            half_shift=False) if lhs_act_prec else None
        rhs_act_hparams = QuantOps.ActHParams(
            bounds=6.,
            prec=rhs_act_prec,
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            half_shift=False) if rhs_act_prec else None
        # pylint: enable=g-long-ternary

        get_bounds_params = GetBounds.Params(update_stats=False,
                                             update_bounds=False)

        quantization.quantized_dynamic_dot_general(
            lhs_act=self.lhs_act,
            rhs_act=self.rhs_act,
            quant_type=strategy,
            dot_dimension_numbers=self.dimension_numbers,
            lhs_act_hparams=lhs_act_hparams,
            lhs_get_bounds_params=get_bounds_params,
            rhs_act_hparams=rhs_act_hparams,
            rhs_get_bounds_params=get_bounds_params,
        )
        calls = []
        for prec in [lhs_act_prec, rhs_act_prec]:
            if prec is not None:
                act_hparams = QuantOps.ActHParams(bounds=6.,
                                                  prec=prec,
                                                  input_distribution=mock.ANY,
                                                  half_shift=False)
                calls.append(
                    mock.call(mock.ANY,
                              hparams=act_hparams,
                              get_bounds_params=get_bounds_params))
        self.assertLen(calls, mock_act_fq.call_count)
        mock_act_fq.assert_has_calls(calls, any_order=True)
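
The mock_act_fq argument above is injected by a mock.patch decorator that is not part of this excerpt. A hypothetical setup, assuming the patched fake-quant entry point is QuantOps.create_input_fake_quant (the name mentioned in the attention docstring further down; the real test's target may differ):

from unittest import mock

# Hypothetical: patch the fake-quant entry point so the test can inspect how it
# was called without actually quantizing the inputs.
with mock.patch.object(QuantOps, 'create_input_fake_quant') as mock_act_fq:
    mock_act_fq.side_effect = lambda inputs, hparams, get_bounds_params: inputs
    # ... run the quantized_dynamic_dot_general call, then inspect
    # mock_act_fq.call_args_list as the test above does.
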
Example 4
 def test_dynamic_quantized_dot_general_raises_with_mixed_dtype(self):
   lhs_params = QuantOps.ActHParams(
       input_distribution='symmetric', bounds=2.0, prec=4, half_shift=False)
   rhs_params = QuantOps.ActHParams(
       input_distribution='symmetric', bounds=1.5, prec=4, half_shift=False)
   lhs_act = self.lhs.astype(jnp.bfloat16)
   rhs_act = self.rhs.astype(jnp.float32)
   with self.assertRaises(TypeError):
     quantization.quantized_dynamic_dot_general(
         lhs_act=lhs_act,
         rhs_act=rhs_act,
         lhs_act_hparams=lhs_params,
         rhs_act_hparams=rhs_params,
         lhs_get_bounds_params=None,
         rhs_get_bounds_params=None,
         dot_dimension_numbers=(((1,), (0,)), ((), ())),
         quant_type=QuantType.aqt)
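
The TypeError above is avoided by giving both operands a common dtype before the call; a minimal sketch:

# Casting both operands to the same dtype avoids the mixed-dtype TypeError.
lhs_act = self.lhs.astype(jnp.bfloat16)
rhs_act = self.rhs.astype(jnp.bfloat16)
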
 def quantized_matmul(quant_type):
     return quantization.quantized_dynamic_dot_general(
         lhs_act=self.lhs,
         rhs_act=self.rhs,
         lhs_act_hparams=lhs_params,
         rhs_act_hparams=rhs_params,
         lhs_get_bounds_params=None,
         rhs_get_bounds_params=None,
         dot_dimension_numbers=(((1, ), (0, )), ((), ())),
         quant_type=quant_type)
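
A plausible (hypothetical, not shown in this excerpt) use of this helper is to run the same matmul under both quantization strategies that appear in these examples and compare the results:

# Hypothetical usage of the helper above; the tolerance is illustrative.
aqt_result = quantized_matmul(QuantType.aqt)
fake_quant_result = quantized_matmul(QuantType.fake_quant)
onp.testing.assert_allclose(aqt_result, fake_quant_result, rtol=1e-2)
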
Example 6
 def test_quantized_dynamic_dot_general_no_quant(self):
   act_hparams = QuantOps.ActHParams(
       input_distribution='symmetric', bounds=-1.0, prec=4, half_shift=False)
   lhs_act = jnp.array([[-5.0]])
   rhs_act = jnp.array([[-4.99]])
   res = quantization.quantized_dynamic_dot_general(
       lhs_act=lhs_act,
       rhs_act=rhs_act,
       quant_type=quantization.QuantType.aqt,
       lhs_act_hparams=act_hparams,
       rhs_act_hparams=act_hparams,
       lhs_get_bounds_params=None,
       rhs_get_bounds_params=None,
       dot_dimension_numbers=(((1,), (0,)), ((), ())))
   onp.testing.assert_allclose(res, lhs_act * rhs_act)
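
Since the preceding test expects the unquantized product, the reference value for these 1x1 operands can be written out directly; a small worked check (tolerance chosen to allow float32 rounding):

# The exact product of the 1x1 operands above: (-5.0) * (-4.99) = 24.95.
expected = jnp.array([[24.95]])
onp.testing.assert_allclose(res, expected, rtol=1e-6)
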
 def test_dynamic_quantized_dot_general_has_correct_dtype(
         self, input_dtype, act_prec, quant_type):
     lhs_params = QuantOps.ActHParams(input_distribution='symmetric',
                                      bounds=2.0,
                                      prec=act_prec)
     rhs_params = QuantOps.ActHParams(input_distribution='symmetric',
                                      bounds=1.5,
                                      prec=act_prec)
     lhs_act = self.lhs.astype(input_dtype)
     rhs_act = self.rhs.astype(input_dtype)
     output = quantization.quantized_dynamic_dot_general(
         lhs_act=lhs_act,
         rhs_act=rhs_act,
         lhs_act_hparams=lhs_params,
         rhs_act_hparams=rhs_params,
         lhs_get_bounds_params=None,
         rhs_get_bounds_params=None,
         dot_dimension_numbers=(((1, ), (0, )), ((), ())),
         quant_type=quant_type)
     self.assertEqual(output.dtype, input_dtype)
Example 8
def dot_product_attention(query,
                          key,
                          value,
                          hparams,
                          quant_context,
                          paxis_name,
                          train,
                          key_padding_mask,
                          query_padding_mask,
                          attn_mask,
                          dtype=jnp.float32,
                          bias=None,
                          axis=None,
                          broadcast_dropout=True,
                          dropout_rng=None,
                          dropout_rate=0.,
                          deterministic=False,
                          precision=None):
    """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights. This
  function supports multi-dimensional inputs.


  Args:
    query: queries for calculating attention with shape of `[batch_size,
      sequence_length, num_heads, mem_channels]`.
    key: keys for calculating attention with shape of `[batch_size,
      sequence_length, num_heads, mem_channels]`.
    value: values to be used in attention with shape of `[batch_size,
      sequence_length, num_heads, value_channels]`.
    hparams: hyperparameters used for quantization.
    quant_context: context for quantization.
    paxis_name: axis_name to which a user `pmaps` the parent module (model);
      refer to jax.pmap() for more documentation. This arg is used for
      get_bounds activation quantization (QuantOps.create_input_fake_quant).
    train: Whether model is training.
    key_padding_mask: boolean mask indicating which elements in 'key' and
      'value' are padding. Must have a shape compatible with 'key' and 'value'.
    query_padding_mask: boolean padding mask for `query`, where True marks
      elements that are not padding. Must have a shape compatible with `query`.
    attn_mask: boolean mask indicating which elements of the calculated
      attention weight matrix should be used for collecting activation
      statistics. Should have a shape broadcast-compatible with '[bs,
      sequence_length, sequence_length]'.
    dtype: the dtype of the computation (default: float32)
    bias: bias for the attention weights. This can be used for incorporating
      autoregressive mask, padding mask, proximity bias.
    axis: axes over which the attention is applied.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool; if True, dropout is not applied.
    precision: numerical precision of the computation; see `jax.lax.Precision`
      for details.

  Returns:
    Output of shape `[bs, sequence_length, num_heads, value_channels]`.
  """
    batch_size, query_sequence_length, num_heads, channel_size = query.shape
    key_sequence_length = key.shape[1]
    shape_utils.assert_shapes_equal(
        key.shape, (batch_size, key_sequence_length, num_heads, channel_size))
    shape_utils.assert_shapes_equal(
        value.shape,
        (batch_size, key_sequence_length, num_heads, channel_size))
    if key_padding_mask is not None:
        shape_utils.assert_shapes_equal(
            key_padding_mask.shape, (batch_size, key_sequence_length, 1, 1))
    if query_padding_mask is not None:
        shape_utils.assert_shapes_equal(
            query_padding_mask.shape,
            (batch_size, query_sequence_length, 1, 1))

    if attn_mask is not None:
        shape_utils.assert_shapes_compatible(
            attn_mask.shape,
            (batch_size, 1, query_sequence_length, key_sequence_length))

    if axis is None:
        axis = tuple(range(1, key.ndim - 2))
    if not isinstance(axis, Iterable):
        axis = (axis, )

    for ax in axis:
        if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
            raise ValueError('Attention axis must be between the batch '
                             'axis and the last-two axes.')
    depth = query.shape[-1]
    n = key.ndim
    # batch_dims is  <bs, <non-attention dims>, num_heads>
    batch_dims = tuple(onp.delete(range(n), axis + (n - 1, )))
    # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)

    qk_perm = batch_dims + axis + (n - 1, )
    key = key.transpose(qk_perm)
    shape_utils.assert_shapes_equal(
        key.shape, (batch_size, num_heads, key_sequence_length, channel_size))

    key_padding_mask_transposed = None
    query_padding_mask_transposed = None
    if key_padding_mask is not None:
        key_padding_mask_transposed = key_padding_mask.transpose(qk_perm)
        shape_utils.assert_shapes_equal(
            key_padding_mask_transposed.shape,
            (batch_size, 1, key_sequence_length, 1))

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(channel_axis=None,
                           name='attn_act_k',
                           update_stats=train)(
                               key, mask=key_padding_mask_transposed)

    if query_padding_mask is not None:
        query_padding_mask_transposed = query_padding_mask.transpose(qk_perm)
        shape_utils.assert_shapes_equal(
            query_padding_mask_transposed.shape,
            (batch_size, 1, query_sequence_length, 1))

    key_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=key_padding_mask_transposed,
        module_name='K')

    # v -> (bs, <non-attention dims>, num_heads, channels, <attention dims>)
    v_perm = batch_dims + (n - 1, ) + axis
    value = value.transpose(v_perm)
    shape_utils.assert_shapes_equal(
        value.shape,
        (batch_size, num_heads, channel_size, key_sequence_length))
    value_padding_mask_transposed = None
    if key_padding_mask is not None:
        value_padding_mask_transposed = key_padding_mask.transpose(v_perm)
        shape_utils.assert_shapes_equal(
            value_padding_mask_transposed.shape,
            (batch_size, 1, 1, key_sequence_length))

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(channel_axis=None,
                           name='attn_act_v',
                           update_stats=train)(
                               value, mask=value_padding_mask_transposed)

    value_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=value_padding_mask_transposed,
        module_name='V')

    query = query / jnp.sqrt(depth).astype(dtype)
    query = query.transpose(qk_perm)
    shape_utils.assert_shapes_equal(
        query.shape,
        (batch_size, num_heads, query_sequence_length, channel_size))

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(channel_axis=None,
                           name='attn_act_q',
                           update_stats=train)(
                               query, mask=query_padding_mask_transposed)

    query_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=query_padding_mask_transposed,
        module_name='Q')

    batch_dims_t = tuple(range(len(batch_dims)))
    attn_weights = quantized_dynamic_dot_general(
        lhs_act=query,
        rhs_act=key,
        dot_dimension_numbers=(((n - 1, ), (n - 1, )), (batch_dims_t,
                                                        batch_dims_t)),
        dot_precision=precision,
        quant_type=hparams.quant_type,
        lhs_act_hparams=hparams.attn_act_q,
        lhs_get_bounds_params=query_get_bounds_params,
        rhs_act_hparams=hparams.attn_act_k,
        rhs_get_bounds_params=key_get_bounds_params,
    )
    # NOTE(shivaniagrawal): we do per-layer quantization here, since that is the
    # only way for activation*activation matmuls to be AQT-compatible given that
    # we use static scaling factors for activations.

    shape_utils.assert_shapes_equal(
        attn_weights.shape,
        (batch_size, num_heads, query_sequence_length, key_sequence_length))

    # apply attention bias: masking, dropout, proximity bias, etc.
    if bias is not None:
        attn_weights = attn_weights + bias

    # normalize the attention weights
    norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim))
    attn_weights = softmax(attn_weights,
                           norm_dims,
                           dtype,
                           hparams.softmax,
                           quant_context=quant_context)

    # apply dropout
    if not deterministic and dropout_rate > 0.0:
        if dropout_rng is None:
            raise ValueError(
                'dropout_rng cannot be None if dropout is requested.')
        # lax.tie_in is a no-op under omnistaging, so use the value directly.
        keep_prob = 1.0 - dropout_rate
        if broadcast_dropout:
            # dropout is broadcast across the batch+head+non-attention dimension
            dropout_dims = attn_weights.shape[-(2 * len(axis)):]
            dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims)
            keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        else:
            keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
        multiplier = (keep.astype(attn_weights.dtype) /
                      jnp.asarray(keep_prob, dtype=dtype))
        attn_weights = attn_weights * multiplier

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(channel_axis=None,
                           name='attn_act_probs',
                           update_stats=train)(attn_weights, mask=attn_mask)

    if hparams.attn_act_probs is not None:
        assert hparams.attn_act_probs.bounds == 1.0, (
            'act quantization bounds should '
            'be set to the fixed value 1.0 to '
            'match the softmax range.')
    probs_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=attn_mask,
        module_name='attn_probs')

    # compute the new values given the attention weights
    wv_contracting_dims = (norm_dims, range(value.ndim - len(axis),
                                            value.ndim))
    y = quantized_dynamic_dot_general(
        lhs_act=attn_weights,
        rhs_act=value,
        dot_dimension_numbers=(wv_contracting_dims, (batch_dims_t,
                                                     batch_dims_t)),
        dot_precision=precision,
        quant_type=hparams.quant_type,
        lhs_act_hparams=hparams.attn_act_probs,
        lhs_get_bounds_params=probs_get_bounds_params,
        rhs_act_hparams=hparams.attn_act_v,
        rhs_get_bounds_params=value_get_bounds_params,
    )
    # NOTE(shivaniagrawal): we do per-layer quantization here, since that is the
    # only way for activation*activation matmuls to be AQT-compatible given that
    # we use static scaling factors for activations.

    shape_utils.assert_shapes_equal(
        y.shape, (batch_size, num_heads, query_sequence_length, channel_size))
    # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
    perm_inv = _invert_perm(qk_perm)
    y = y.transpose(perm_inv)
    shape_utils.assert_shapes_equal(
        y.shape, (batch_size, query_sequence_length, num_heads, channel_size))
    return y
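
For orientation, the unquantized computation that the two quantized_dynamic_dot_general calls above replace is the standard softmax(QK^T / sqrt(d)) V described in the docstring. A minimal reference sketch for the common 4-D case, with masks, bias, dropout, and statistics collection omitted (not part of the library):

import jax
import jax.numpy as jnp


def reference_attention(query, key, value):
    """Unquantized softmax(Q K^T / sqrt(d)) V for [batch, seq, heads, channels] inputs."""
    depth = query.shape[-1]
    logits = jnp.einsum('bqhc,bkhc->bhqk', query, key) / jnp.sqrt(depth)
    weights = jax.nn.softmax(logits, axis=-1)
    # Back to [batch, query_seq_len, num_heads, channels], matching the return shape above.
    return jnp.einsum('bhqk,bkhc->bqhc', weights, value)
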