def __call__(
        self,
        inputs,
        *,
        padding_mask,
    ):
        """Applies a linear transformation to the inputs with optional quantization.

    If weight_prec is not None, scales and quantizes weights to signed int with
    weight_prec bits.

    Args:
      inputs: The nd-array to be transformed.
      padding_mask: boolean tensor of the same shape as 'inputs' specifying
        which values of 'inputs' to use as part of the bounds calculation.
        'True' indicates the corresponding value from 'inputs' should be used.
        If None, all values are used.

    Returns:
      The transformed input.
    """
        batch_size = inputs.shape[0]
        if padding_mask is not None:
            shape_utils.assert_shapes_equal(padding_mask.shape,
                                            (batch_size, 1))
        # TODO(wanglisa): Replace fake quant with AQT.

        if self.quant_context.collect_acts_stats:
            stats_tag.StatsTag(channel_axis=-1,
                               name='inputs',
                               update_stats=self.train)(inputs,
                                                        mask=padding_mask)
        hparams = self.hparams
        if (hparams.weight_prec is not None
                and isinstance(hparams.weight_prec, int)
                and hparams.weight_prec > 8):
            raise NotImplementedError(
                'If you want to use more than 8 bits for quantization, please revisit '
                'jax.lax.Precision.DEFAULT to determine whether it is still sufficient.'
            )

        kernel = self.param('kernel', self.kernel_init,
                            (inputs.shape[-1], self.features))

        inputs = jnp.asarray(inputs, self.dtype)
        kernel = jnp.asarray(kernel, self.dtype)

        get_bounds_params = get_bounds.GetBounds.Params(
            update_bounds=self.quant_context.update_bounds,
            update_stats=self.train,
            paxis_name=self.paxis_name,
            mask=padding_mask)

        weight_quant_granularity = hparams.weight_quant_granularity
        # kernel.shape = (channels_in, channels_out)
        if weight_quant_granularity == quant_config.QuantGranularity.per_channel:
            # Compute scale factors by reducing over the rows of the weight matrix,
            # resulting in one scale factor per column. This results in one scale
            # factor per output channel.
            expected_scale_shape = (1, self.features)
            weight_quant_axis = (0, )
        elif weight_quant_granularity == quant_config.QuantGranularity.per_tensor:
            # Compute a single scale factor for the entire weight matrix.
            expected_scale_shape = (1, 1)
            weight_quant_axis = None
        else:
            raise ValueError(
                f'Invalid quantization granularity {weight_quant_granularity}.'
            )
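        # Illustrative example (hypothetical sizes, not from the original code):
        # for a (512, 1024) kernel, per_channel yields a (1, 1024) scale (one
        # per output channel), while per_tensor yields a single (1, 1) scale.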

        weight_params = QuantOps.WeightParams(
            prec=hparams.weight_prec,
            half_shift=hparams.weight_half_shift,
            axis=weight_quant_axis,
            expected_scale_shape=expected_scale_shape)

        # TODO(wanglisa): add option to control when scale is being recomputed

        # matmul: `dimension_numbers` has the form
        # `((lhs_contracting_dims, rhs_contracting_dims),
        #   (lhs_batch_dims, rhs_batch_dims))`.
        contracting_dims = ((inputs.ndim - 1, ), (0, ))
        batch_dims = ((), ())
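        # For 2-D `inputs` of shape (batch_size, channels_in), this contracts
        # the last axis of the activations with the first axis of the kernel,
        # i.e. the same contraction as a plain `inputs @ kernel`.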
        y = quantization.quantized_dot_general(
            act=inputs,
            w=kernel,
            quant_type=hparams.quant_type,
            weight_params=weight_params,
            act_hparams=hparams.quant_act,
            get_bounds_params=get_bounds_params,
            dimension_numbers=(contracting_dims, batch_dims),
            dot_precision=self.precision,
            prefer_int8_to_int32_dot=self.quant_context.prefer_int8_to_int32_dot)

        # bias
        if self.use_bias:
            bias = self.param('bias', self.bias_init, (self.features, ))
            # (batch_size, features)
            y = y + bias[jnp.newaxis, :]
        return y
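

# Illustrative sketch (not part of the original module): the unquantized float
# computation that the quantized __call__ above approximates. The helper name
# and the arrays passed to it are hypothetical example values, not the module's
# parameters.
import jax.numpy as jnp
from jax import lax


def _dense_reference(inputs, kernel, bias=None):
    # Contract the last axis of the activations with the first axis of the
    # kernel: for 2-D `inputs` this is simply `inputs @ kernel`.
    contracting_dims = ((inputs.ndim - 1,), (0,))
    y = lax.dot_general(inputs, kernel, (contracting_dims, ((), ())))
    if bias is not None:
        y = y + bias[jnp.newaxis, :]
    return y
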
def dot_product_attention(query,
                          key,
                          value,
                          hparams,
                          quant_context,
                          paxis_name,
                          train,
                          key_padding_mask,
                          query_padding_mask,
                          attn_mask,
                          dtype=jnp.float32,
                          bias=None,
                          axis=None,
                          broadcast_dropout=True,
                          dropout_rng=None,
                          dropout_rate=0.,
                          deterministic=False,
                          precision=None):
    """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights. This
  function supports multi-dimensional inputs.


  Args:
    query: queries for calculating attention with shape of `[batch_size,
      sequence_length, num_heads, mem_channels]`.
    key: keys for calculating attention with shape of `[batch_size,
      sequence_length, num_heads, mem_channels]`.
    value: values to be used in attention with shape of `[batch_size,
      sequence_length, num_heads, value_channels]`.
    hparams: hyperparameters used for quantization.
    quant_context: context for quantization.
    paxis_name: axis_name to which a user `pmaps` the parent module (model),
      refer to jax.pmap() for more documentation. This arg is used for
      get_bounds acts quantization (QuantOps.create_input_fake_quant)
    train: Whether model is training.
    key_padding_mask: boolean mask indicating which elements in 'key' and
      'value' are padding. Must have a shape compatible with 'key' and 'value'.
    query_padding_mask: boolean mask indicating which elements in `query` are
      padding (True means not padding).
    attn_mask: boolean mask indicating which elements of the calculated
      attention weight matrix should be used for collecting activation
      statistics. Should have a shape broadcast-compatible with '[bs,
      sequence_length, sequence_length]'. Must have a shape broadcast-compatible
      'query'.
    dtype: the dtype of the computation (default: float32)
    bias: bias for the attention weights. This can be used for incorporating
      autoregressive mask, padding mask, proximity bias.
    axis: axises over which the attention is applied.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout.
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.

  Returns:
    Output of shape `[bs, sequence_length, num_heads, value_channels]`.
  """
    batch_size, query_sequence_length, num_heads, channel_size = query.shape
    key_sequence_length = key.shape[1]
    shape_utils.assert_shapes_equal(
        key.shape, (batch_size, key_sequence_length, num_heads, channel_size))
    shape_utils.assert_shapes_equal(
        value.shape,
        (batch_size, key_sequence_length, num_heads, channel_size))
    if key_padding_mask is not None:
        shape_utils.assert_shapes_equal(
            key_padding_mask.shape, (batch_size, key_sequence_length, 1, 1))
    if query_padding_mask is not None:
        shape_utils.assert_shapes_equal(
            query_padding_mask.shape,
            (batch_size, query_sequence_length, 1, 1))

    if attn_mask is not None:
        shape_utils.assert_shapes_compatible(
            attn_mask.shape,
            (batch_size, 1, query_sequence_length, key_sequence_length))

    if axis is None:
        axis = tuple(range(1, key.ndim - 2))
    if not isinstance(axis, Iterable):
        axis = (axis, )

    for ax in axis:
        if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
            raise ValueError('Attention axis must be between the batch '
                             'axis and the last-two axes.')
    depth = query.shape[-1]
    n = key.ndim
    # batch_dims is  <bs, <non-attention dims>, num_heads>
    batch_dims = tuple(onp.delete(range(n), axis + (n - 1, )))
    # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)

    qk_perm = batch_dims + axis + (n - 1, )
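    # For the default 4-D case (axis == (1,)), batch_dims == (0, 2) and
    # qk_perm == (0, 2, 1, 3), giving (bs, num_heads, seq_len, channels).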
    key = key.transpose(qk_perm)
    shape_utils.assert_shapes_equal(
        key.shape, (batch_size, num_heads, key_sequence_length, channel_size))

    key_padding_mask_transposed = None
    query_padding_mask_transposed = None
    if key_padding_mask is not None:
        key_padding_mask_transposed = key_padding_mask.transpose(qk_perm)
        shape_utils.assert_shapes_equal(
            key_padding_mask_transposed.shape,
            (batch_size, 1, key_sequence_length, 1))

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(channel_axis=None,
                           name='attn_act_k',
                           update_stats=train)(
                               key, mask=key_padding_mask_transposed)

    if query_padding_mask is not None:
        query_padding_mask_transposed = query_padding_mask.transpose(qk_perm)
        shape_utils.assert_shapes_equal(
            query_padding_mask_transposed.shape,
            (batch_size, 1, query_sequence_length, 1))

    key_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=key_padding_mask_transposed,
        module_name='K')

    # v -> (bs, <non-attention dims>, num_heads, channels, <attention dims>)
    v_perm = batch_dims + (n - 1, ) + axis
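    # For the default 4-D case, v_perm == (0, 2, 3, 1), giving
    # (bs, num_heads, channels, seq_len).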
    value = value.transpose(v_perm)
    shape_utils.assert_shapes_equal(
        value.shape,
        (batch_size, num_heads, channel_size, key_sequence_length))
    value_padding_mask_transposed = None
    if key_padding_mask is not None:
        value_padding_mask_transposed = key_padding_mask.transpose(v_perm)
        shape_utils.assert_shapes_equal(
            value_padding_mask_transposed.shape,
            (batch_size, 1, 1, key_sequence_length))

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(channel_axis=None,
                           name='attn_act_v',
                           update_stats=train)(
                               value, mask=value_padding_mask_transposed)

    value_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=value_padding_mask_transposed,
        module_name='V')

    query = query / jnp.sqrt(depth).astype(dtype)
    query = query.transpose(qk_perm)
    shape_utils.assert_shapes_equal(
        query.shape,
        (batch_size, num_heads, query_sequence_length, channel_size))

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(channel_axis=None,
                           name='attn_act_q',
                           update_stats=train)(
                               query, mask=query_padding_mask_transposed)

    query_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=query_padding_mask_transposed,
        module_name='Q')

    batch_dims_t = tuple(range(len(batch_dims)))
    attn_weights = quantized_dynamic_dot_general(
        lhs_act=query,
        rhs_act=key,
        dot_dimension_numbers=(((n - 1, ), (n - 1, )), (batch_dims_t,
                                                        batch_dims_t)),
        dot_precision=precision,
        quant_type=hparams.quant_type,
        lhs_act_hparams=hparams.attn_act_q,
        lhs_get_bounds_params=query_get_bounds_params,
        rhs_act_hparams=hparams.attn_act_k,
        rhs_get_bounds_params=key_get_bounds_params,
    )
    # NOTE(shivaniagrawal): we do per-layer quantization here since that's the
    # only way for activation*activation matmuls to be aqt compatible since we use
    # static scaling factors for activations.

    shape_utils.assert_shapes_equal(
        attn_weights.shape,
        (batch_size, num_heads, query_sequence_length, key_sequence_length))

    # apply attention bias: masking, dropout, proximity bias, etc.
    if bias is not None:
        attn_weights = attn_weights + bias

    # normalize the attention weights
    norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim))
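    # For the default 4-D case (len(axis) == 1), norm_dims == (3,): the softmax
    # normalizes over the key_sequence_length axis.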
    attn_weights = softmax(attn_weights,
                           norm_dims,
                           dtype,
                           hparams.softmax,
                           quant_context=quant_context)

    # apply dropout
    if not deterministic and dropout_rate > 0.0:
        if dropout_rng is None:
            raise ValueError(
                'dropout_rng cannot be None if dropout is requested.')
        # jax.lax.tie_in is a no-op under omnistaging and was removed from
        # newer JAX releases, so compute keep_prob directly.
        keep_prob = 1.0 - dropout_rate
        if broadcast_dropout:
            # dropout is broadcast across the batch+head+non-attention dimension
            dropout_dims = attn_weights.shape[-(2 * len(axis)):]
            dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims)
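            # For the default 4-D case, dropout_shape == (1, 1, q_seq_len,
            # k_seq_len), so the same dropout mask is reused across the batch
            # and head dimensions.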
            keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        else:
            keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
        multiplier = (keep.astype(attn_weights.dtype) /
                      jnp.asarray(keep_prob, dtype=dtype))
        attn_weights = attn_weights * multiplier

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(channel_axis=None,
                           name='attn_act_probs',
                           update_stats=train)(attn_weights, mask=attn_mask)

    if hparams.attn_act_probs is not None:
        assert hparams.attn_act_probs.bounds == 1.0, (
            'act quantization bounds should be set to the fixed value 1.0 to '
            'match the softmax range.')
    probs_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=attn_mask,
        module_name='attn_probs')

    # compute the new values given the attention weights
    wv_contracting_dims = (norm_dims, range(value.ndim - len(axis),
                                            value.ndim))
    y = quantized_dynamic_dot_general(
        lhs_act=attn_weights,
        rhs_act=value,
        dot_dimension_numbers=(wv_contracting_dims, (batch_dims_t,
                                                     batch_dims_t)),
        dot_precision=precision,
        quant_type=hparams.quant_type,
        lhs_act_hparams=hparams.attn_act_probs,
        lhs_get_bounds_params=probs_get_bounds_params,
        rhs_act_hparams=hparams.attn_act_v,
        rhs_get_bounds_params=value_get_bounds_params,
    )
    # NOTE(shivaniagrawal): we do per-layer quantization here since that's the
    # only way for activation*activation matmuls to be aqt compatible since we
    # use static scaling factors for activations.

    shape_utils.assert_shapes_equal(
        y.shape, (batch_size, num_heads, query_sequence_length, channel_size))
    # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
    perm_inv = _invert_perm(qk_perm)
    y = y.transpose(perm_inv)
    shape_utils.assert_shapes_equal(
        y.shape, (batch_size, query_sequence_length, num_heads, channel_size))
    return y
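

# Illustrative sketch (not part of the original module): the same attention
# math as dot_product_attention above for the common 4-D case, written with
# jnp.einsum instead of the quantized dot_general helpers and omitting
# quantization, masking, bias, and dropout. The function and argument names
# are made up for this example.
import jax
import jax.numpy as jnp


def _attention_reference(query, key, value):
    # query, key, value: [batch_size, seq_len, num_heads, channels]
    depth = query.shape[-1]
    query = query / jnp.sqrt(depth).astype(query.dtype)
    # Attention weights: [batch_size, num_heads, q_seq_len, k_seq_len].
    weights = jnp.einsum('bqhc,bkhc->bhqk', query, key)
    weights = jax.nn.softmax(weights, axis=-1)
    # Weighted values, back to [batch_size, q_seq_len, num_heads, channels].
    return jnp.einsum('bhqk,bkhc->bqhc', weights, value)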