Example #1
        def get_tensor_and_scale_for_act(act, hparams, get_bounds_params):
            # We check whether the activation should be quantized based on 'hparams'.
            # If so, we quantize it; if not, we return it unchanged. In either case, we
            # return a scale factor appropriate for unscaling the result of the
            # lax.dot_general.
            if hparams is not None and hparams.prec is not None:
                quant_op = QuantOps.create_input_ops(
                    act, hparams=hparams, get_bounds_params=get_bounds_params)

                scale = quant_op.get_scale_for_aqt(
                    allow_per_channel_scales=False)
                # Since only per-layer scale factors are supported, we assert that the
                # scale factors are scalars.
                shape_utils.assert_shapes_compatible(scale.shape, ())
                # TODO(malmaud): See comment on 'act_op.to_quantized' earlier in this
                # file, which applies here as well.
                act_quantized = quant_op.to_quantized(act, dtype=input_dtype)

                # TODO(shivaniagrawal): See comment in 'dot_general' above on why this
                # logic is duplicated here and in the 'else' block below.
                return lax.cond(
                    quant_op.should_quantize(),  #
                    lambda _: (act_quantized, scale),  #
                    lambda _: (act, jnp.array(1.0, dtype=SCALE_DTYPE)),  #
                    None)
            else:
                # To avoid having a separate code path for every possibility of which of
                # the two input tensors are quantized, we implement not quantizing an
                # activation tensor by simply setting its corresponding scale factor to
                # 1.0.
                return act, jnp.array(1.0, dtype=SCALE_DTYPE)
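
A minimal sketch (not part of the original source) of how the (tensor, scale) pairs returned by `get_tensor_and_scale_for_act` are meant to be used: run `lax.dot_general` on the two returned tensors and divide the result by the product of the two scalar scale factors. The helper is assumed to be in scope; `dimension_numbers` and the argument names are illustrative.

from jax import lax

def quantized_act_act_matmul_sketch(lhs_act, rhs_act, lhs_hparams, rhs_hparams,
                                    lhs_bounds_params, rhs_bounds_params,
                                    dimension_numbers):
    # Quantize each activation (or pass it through with scale 1.0).
    lhs_q, lhs_scale = get_tensor_and_scale_for_act(lhs_act, lhs_hparams,
                                                    lhs_bounds_params)
    rhs_q, rhs_scale = get_tensor_and_scale_for_act(rhs_act, rhs_hparams,
                                                    rhs_bounds_params)
    out = lax.dot_general(lhs_q, rhs_q, dimension_numbers)
    # Undo the scalar (per-layer) scale factors applied during quantization.
    return out / (lhs_scale * rhs_scale)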
Example #2
    def __call__(
        self,
        x,
        *,
        mask,
    ):
        """Applies a tag to track distributions.

    Args:
      x: the array to compute statistics distributions over.
      mask: boolean array indicating which elements of 'x' should be
        included in the stats calculation ('True' means to include).

    Returns:
      x unchanged. The return value can also be ignored.
    """
        if mask is None:
            mask = jnp.full(x.shape, True)
        shape_utils.assert_shapes_compatible(x.shape, mask.shape)
        mask = jnp.broadcast_to(mask, x.shape)
        channel_axis = self.channel_axis
        if channel_axis is not None:
            if not isinstance(channel_axis, Iterable):
                channel_axis = (channel_axis, )
            channel_axis = normalize_axes(channel_axis, x.ndim)
            x = _take_subset_of_axes(
                x,
                axis=channel_axis,
                num_indices_per_ax=self.num_indices_per_ax)
            mask = _take_subset_of_axes(
                mask,
                axis=channel_axis,
                num_indices_per_ax=self.num_indices_per_ax)
            reduction_axis = tuple(
                [ax for ax in range(x.ndim) if ax not in channel_axis])
        else:
            reduction_axis = None

        distr_shape = ()
        if channel_axis:
            distr_shape = tuple(d for i, d in enumerate(x.shape)
                                if i in channel_axis)

        # TODO(wanglisa): Consider adding configurability to specify which
        # statistics are collected.
        init_with_zeros = lambda shape: jnp.zeros(shape, dtype=jnp.float32)
        is_initializing = not self.has_variable('stats_tag', 'min_per_ch')
        min_per_ch = self.variable(
            'stats_tag',
            'min_per_ch',
            init_with_zeros,
            distr_shape,
        )
        max_per_ch = self.variable('stats_tag', 'max_per_ch', init_with_zeros,
                                   distr_shape)
        mean_per_ch = self.variable(
            'stats_tag',
            'mean_per_ch',
            init_with_zeros,
            distr_shape,
        )
        stddev_per_ch = self.variable(
            'stats_tag',
            'stddev_per_ch',
            init_with_zeros,
            distr_shape,
        )
        absdev_per_ch = self.variable(
            'stats_tag',
            'absdev_per_ch',
            init_with_zeros,
            distr_shape,
        )
        stddev_per_ch_uncentered = self.variable(
            'stats_tag',
            'stddev_per_ch_uncentered',
            init_with_zeros,
            distr_shape,
        )
        absdev_per_ch_uncentered = self.variable(
            'stats_tag',
            'absdev_per_ch_uncentered',
            init_with_zeros,
            distr_shape,
        )
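        # Update the per-channel statistics from the masked input: masked-out
        # entries are replaced with +/-inf for the min/max reductions and are
        # excluded from the mean-based statistics via 'stats.masked_mean'.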
        if self.update_stats and not is_initializing:
            min_per_ch.value = jnp.min(jnp.where(mask, x, math.inf),
                                       axis=reduction_axis)
            max_per_ch.value = jnp.max(jnp.where(mask, x, -math.inf),
                                       axis=reduction_axis)
            mean_per_ch_keepdims = stats.masked_mean(x,
                                                     mask=mask,
                                                     axis=reduction_axis,
                                                     paxis_name=None,
                                                     keepdims=True)
            mean_per_ch.value = mean_per_ch_keepdims.squeeze(
                axis=reduction_axis)
            stddev_per_ch.value = jnp.sqrt(
                stats.masked_mean((x - mean_per_ch_keepdims)**2,
                                  mask=mask,
                                  axis=reduction_axis,
                                  paxis_name=None,
                                  keepdims=False))
            absdev_per_ch.value = stats.masked_mean(
                jnp.abs(x - mean_per_ch_keepdims),
                mask=mask,
                axis=reduction_axis,
                paxis_name=None,
                keepdims=False)
            stddev_per_ch_uncentered.value = jnp.sqrt(
                stats.masked_mean(jnp.square(x),
                                  mask=mask,
                                  axis=reduction_axis,
                                  paxis_name=None,
                                  keepdims=False))
            absdev_per_ch_uncentered.value = stats.masked_mean(
                jnp.abs(x),
                mask=mask,
                axis=reduction_axis,
                paxis_name=None,
                keepdims=False)
Example #3
    def __call__(
        self,
        x,
        *,
        bounds_params,
    ):
        """Compute the input batch statistics.

    Args:
      x: the input to get bounds from using statistics.
      bounds_params: parameters to compute input's statistics and bounds.

    Returns:
      Bound value (broadcastable to the shape of the input).
    """

        if bounds_params.mask is not None:
            shape_utils.assert_shapes_compatible(x.shape,
                                                 bounds_params.mask.shape)

        x = jnp.asarray(x, jnp.float32)

        hyper = self.hyper
        is_initializing = not self.has_variable('get_bounds', 'stats')

        if hyper.granularity == quant_config.QuantGranularity.per_tensor:
            # Equivalently, this could be written as
            # quant_axis = tuple(range(x.ndim))
            quant_axis = None
            stats_shape = (1, ) * len(x.shape)
        elif hyper.granularity == quant_config.QuantGranularity.per_channel:
            # Quantize by aggregating activation statistics across all dimensions of
            # the activation tensor EXCEPT the last dimension, which we interpret as
            # the channel dimension. For example, in a transformer context, x might
            # have a shape corresponding to [example, token, channel], in which case
            # this aggregates activation statistics separately for each feature, where
            # for each feature it aggregates over all unmasked tokens in all examples.
            quant_axis = tuple(range(x.ndim - 1))
            stats_shape = (1, ) * (x.ndim - 1) + (x.shape[-1], )
        else:
            raise ValueError(f'Unknown granularity {hyper.granularity}')

        stats_state = self.variable('get_bounds', 'stats',
                                    Stats.stats_initializer, stats_shape)

        def bound_initializer(shape):
            return hyper.initial_bound * jnp.ones(shape)

        bounds = self.variable('get_bounds', 'bounds', bound_initializer,
                               stats_shape)

        if bounds_params.update_stats and not is_initializing:
            stats_state.value = Stats.create_updated_stats(
                stats_state.value,
                x,
                mask=bounds_params.mask,
                axis=quant_axis,
                paxis_name=bounds_params.paxis_name,
                alpha=hyper.ema_coeff,
                exclude_zeros=hyper.exclude_zeros)

        if bounds_params.update_bounds and not is_initializing:
            bounds.value = self._stats_to_bounds(stats_state.value)
            if hyper.reset_stats:
                stats_state.value = Stats.stats_initializer(stats_shape)
        return bounds.value
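
Because the running statistics and bounds live in the 'get_bounds' variable collection, a caller has to mark that collection mutable for the updates above to take effect. A rough usage sketch, assuming the surrounding class is a Flax linen module named `GetBounds` and reusing the `get_bounds.GetBounds.Params` container that appears in the attention example below; the argument values are placeholders.

def update_bounds_sketch(bounds_module, variables, x, mask=None):
    bounds_params = get_bounds.GetBounds.Params(
        update_bounds=True,
        update_stats=True,
        paxis_name=None,
        mask=mask,
        module_name='example')
    # 'mutable' allows the 'stats' and 'bounds' variables written above to change.
    bounds, updated_vars = bounds_module.apply(
        variables, x, bounds_params=bounds_params, mutable=['get_bounds'])
    return bounds, updated_vars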
Example #4
    def __call__(self,
                 inputs_q,
                 inputs_kv,
                 *,
                 padding_mask,
                 key_padding_mask,
                 segmentation=None,
                 key_segmentation=None):
        """Applies multi-head dot product attention on the input data.

    If weight_prec is not None, scales and quantizes weights to signed int with
    weight_prec bits.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output vector.

    This can be used for encoder-decoder attention by specifying both `inputs_q`
    and `inputs_kv` or for self-attention by only specifying `inputs_q` and
    setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]` or
        None for self-attention, in which case key/values will be derived from
        inputs_q.
      padding_mask: boolean tensor specifying query tokens that are pad tokens.
      key_padding_mask: boolean tensor specifying key-value tokens that are pad
        tokens.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """
        batch_size, query_sequence_length, channel_size = inputs_q.shape
        hparams = self.hparams
        if inputs_kv is None:
            inputs_kv = inputs_q
            key_sequence_length = inputs_q.shape[1]
        else:
            key_sequence_length = inputs_kv.shape[1]
            shape_utils.assert_shapes_equal(
                inputs_kv.shape,
                (batch_size, key_sequence_length, channel_size))

        jax_precision = jax.lax.Precision.DEFAULT

        if padding_mask is not None:
            shape_utils.assert_shapes_equal(
                padding_mask.shape, (batch_size, query_sequence_length, 1))
        if key_padding_mask is None:
            key_padding_mask = padding_mask
        else:
            shape_utils.assert_shapes_equal(
                key_padding_mask.shape, (batch_size, key_sequence_length, 1))
        attention_axis = self.attention_axis
        if attention_axis is None:
            attention_axis = tuple(range(1, inputs_q.ndim - 1))

        qkv_features = self.qkv_features
        qkv_features = qkv_features or inputs_q.shape[-1]

        num_heads = self.num_heads
        assert qkv_features % num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // num_heads

        paxis_name = self.paxis_name
        train = self.train
        kernel_init = self.kernel_init
        bias_init = self.bias_init
        use_bias = self.use_bias
        dtype = self.dtype

        def multi_batch_dense_aqt(inputs, *, name, padding_mask):
            batch_size, sequence_length, channel_size = inputs.shape
            inputs = inputs.reshape(batch_size * sequence_length, channel_size)
            if padding_mask is not None:
                padding_mask = padding_mask.reshape(
                    batch_size * sequence_length, 1)
            out = flax_layers.DenseAqt(name=name,
                                       features=num_heads * head_dim,
                                       paxis_name=paxis_name,
                                       train=train,
                                       quant_context=self.quant_context,
                                       hparams=hparams.dense_kqv,
                                       kernel_init=kernel_init,
                                       bias_init=bias_init,
                                       use_bias=use_bias,
                                       dtype=dtype)(inputs,
                                                    padding_mask=padding_mask)
            return out.reshape(batch_size, sequence_length, num_heads,
                               head_dim)

        # project inputs_q to multi-headed q/k/v
        # dimensions are then [bs, sequence_length, n_heads, n_features_per_head]
        query = multi_batch_dense_aqt(inputs_q,
                                      name='query',
                                      padding_mask=padding_mask)
        key = multi_batch_dense_aqt(inputs_kv,
                                    name='key',
                                    padding_mask=key_padding_mask)
        value = multi_batch_dense_aqt(inputs_kv,
                                      name='value',
                                      padding_mask=key_padding_mask)
        is_cache_initialized = False
        if self.decode:
            is_cache_initialized = self.has_variable('cache', 'cached_key')
            cached_key = self.variable('cache', 'cached_key', jnp.zeros,
                                       key.shape, key.dtype)
            cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                         value.shape, value.dtype)
            cache_index = self.variable('cache', 'cache_index',
                                        lambda: jnp.array(0, dtype=jnp.int32))
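            # Autoregressive decoding: write the key/value computed for the
            # current position into the cache at offset 'cache_index', then
            # advance the index by one so subsequent steps append after it.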
            if is_cache_initialized:
                expected_shape = list(cached_key.value.shape[:-2])
                for attn_dim in attention_axis:
                    expected_shape[attn_dim] = 1
                expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
                if expected_shape != inputs_q.shape:
                    raise ValueError('Invalid shape provided, '
                                     'expected shape %s instead got %s.' %
                                     (expected_shape, inputs_q.shape))

                cshape = cached_key.value.shape
                indices = [0] * len(cshape)
                i = cache_index.value
                attn_size = onp.prod(onp.take(cshape, attention_axis))

                *batch_dims, max_length, num_heads, depth_per_head = (  # pylint: disable=unused-variable
                    cached_key.value.shape)
                indices = (0, ) * len(batch_dims) + (i, 0, 0)

                key = lax.dynamic_update_slice(cached_key.value, key, indices)
                value = lax.dynamic_update_slice(cached_value.value, value,
                                                 indices)
                one = jnp.array(1, jnp.int32)
                cache_index.value = cache_index.value + one
                cached_key.value = key
                cached_value.value = value

                # TODO(levskaya): verify this is still needed in translation decoding.
                key_padding_mask = jnp.broadcast_to(
                    (jnp.arange(max_length) < cache_index.value), cshape[:2])
                key_padding_mask = key_padding_mask.astype(
                    jnp.float32)[Ellipsis, None]

        # create attention masks
        mask_components = []
        if self.causal_mask:
            if self.decode and is_cache_initialized:
                bias_pre_shape = (1, ) * (key.ndim - 1)
                attn_shape = tuple(onp.take(key.shape, attention_axis))
                attn_size = onp.prod(attn_shape)
                ii = jnp.arange(attn_size, dtype=jnp.int32)
                mask = ii < cache_index.value
                mask_components.append(
                    mask.reshape(bias_pre_shape + attn_shape))
            else:
                mask_components.append(_make_causal_mask(key, attention_axis))
        if padding_mask is not None:
            if key_padding_mask is None:
                key_padding_mask = padding_mask
            attn_padding_mask = make_padding_mask(
                padding_mask_query=padding_mask,
                padding_mask_key=key_padding_mask,
                query_shape=query.shape,
                key_shape=key.shape,
                attention_axis=attention_axis)
            mask_components.append(attn_padding_mask)
        if segmentation is not None:
            if key_segmentation is None:
                key_segmentation = segmentation
            segmentation_mask = make_padding_mask(
                padding_mask_query=segmentation,
                padding_mask_key=key_segmentation,
                query_shape=query.shape,
                key_shape=key.shape,
                attention_axis=attention_axis,
                segmentation_mask=True)
            mask_components.append(segmentation_mask)
        attention_mask = None
        if mask_components:
            attention_mask = mask_components[0]
            for component in mask_components[1:]:
                attention_mask = jnp.logical_and(attention_mask, component)
            attention_mask = attention_mask.astype(jnp.bool_)

            # attention mask in the form of attention bias
            attention_bias = jnp.where(
                attention_mask,
                jnp.full(attention_mask.shape, 0.).astype(dtype),
                jnp.full(attention_mask.shape, -1e10).astype(dtype))
        else:
            attention_bias = None

        # Add an extra dimension to the mask corresponding to the head
        # dimension. eg, if inputs_q has shape [batch_size, sequence_length,
        # n_features], then padding_mask will have a shape
        # [batch_size, sequence_length, 1] and query will have shape
        # [batch_size, sequence_length, n_heads, n_features_per_head].
        # We create query_padding_mask with shape [batch_size, sequence_length,
        # 1, 1] to be broadcast-compatible with 'query'.
        if padding_mask is not None:
            padding_mask = padding_mask[Ellipsis, None]
            shape_utils.assert_shapes_equal(
                padding_mask.shape, (batch_size, query_sequence_length, 1, 1))
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[Ellipsis, None]
            # During prediction, the key padding mask is only going to be
            # broadcast-compatible with the key.
            shape_utils.assert_shapes_compatible(
                key_padding_mask.shape,
                (batch_size, key_sequence_length, 1, 1))

        # apply attention
        attention_fn = self.attention_fn
        dropout_rate = self.dropout_rate
        broadcast_dropout = self.broadcast_dropout
        deterministic = self.deterministic
        if not deterministic and self.dropout_rate > 0.0:
            dropout_rng = self.make_rng('dropout')
        else:
            dropout_rng = None
        x = attention_fn(  # pylint: disable=redundant-keyword-arg
            query=query,
            key=key,
            value=value,
            hparams=hparams.attn_acts,
            paxis_name=paxis_name,
            train=train,
            quant_context=self.quant_context,
            dtype=dtype,
            axis=attention_axis,
            bias=attention_bias,
            precision=jax_precision,
            dropout_rng=dropout_rng,
            dropout_rate=dropout_rate,
            broadcast_dropout=broadcast_dropout,
            deterministic=deterministic,
            query_padding_mask=padding_mask,
            key_padding_mask=key_padding_mask,
            attn_mask=attention_mask)
        shape_utils.assert_shapes_equal(
            x.shape, (batch_size, query_sequence_length, num_heads, head_dim))
        x = x.reshape(batch_size * query_sequence_length, num_heads * head_dim)
        if padding_mask is not None:
            padding_mask = padding_mask.reshape(
                batch_size * query_sequence_length, 1)
        # back to the original inputs dimensions
        out = flax_layers.DenseAqt(features=channel_size,
                                   hparams=hparams.dense_out,
                                   quant_context=self.quant_context,
                                   paxis_name=paxis_name,
                                   train=train,
                                   kernel_init=kernel_init,
                                   bias_init=bias_init,
                                   use_bias=use_bias,
                                   dtype=dtype,
                                   name='dense_out')(x,
                                                     padding_mask=padding_mask)
        shape_utils.assert_shapes_equal(
            out.shape, (batch_size * query_sequence_length, channel_size))
        out = out.reshape(batch_size, query_sequence_length, channel_size)
        return out
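
A small, self-contained sketch (illustrative names, not the module's own helper) of the mask-to-bias pattern used near the end of the call above: boolean mask components are AND-ed together and converted into an additive bias of 0 for allowed positions and -1e10 for masked ones, which the softmax then drives toward zero probability.

import jax.numpy as jnp

def masks_to_attention_bias_sketch(mask_components, dtype=jnp.float32):
    # Combine all boolean masks; True means the position may be attended to.
    attention_mask = mask_components[0]
    for component in mask_components[1:]:
        attention_mask = jnp.logical_and(attention_mask, component)
    # 0.0 for allowed positions, -1e10 for masked positions.
    return jnp.where(attention_mask,
                     jnp.zeros(attention_mask.shape, dtype),
                     jnp.full(attention_mask.shape, -1e10, dtype))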
Example #5
def dot_product_attention(query,
                          key,
                          value,
                          hparams,
                          quant_context,
                          paxis_name,
                          train,
                          key_padding_mask,
                          query_padding_mask,
                          attn_mask,
                          dtype=jnp.float32,
                          bias=None,
                          axis=None,
                          broadcast_dropout=True,
                          dropout_rng=None,
                          dropout_rate=0.,
                          deterministic=False,
                          precision=None):
    """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights. This
  function supports multi-dimensional inputs.


  Args:
    query: queries for calculating attention with shape of `[batch_size,
      sequence_length, num_heads, mem_channels]`.
    key: keys for calculating attention with shape of `[batch_size,
      sequence_length, num_heads, mem_channels]`.
    value: values to be used in attention with shape of `[batch_size,
      sequence_length, num_heads, value_channels]`.
    hparams: hyperparameters used for quantization.
    quant_context: context for quantization.
    paxis_name: axis_name to which a user `pmaps` the parent module (model);
      refer to jax.pmap() for more documentation. This arg is used for
      get_bounds acts quantization (QuantOps.create_input_fake_quant).
    train: Whether model is training.
    key_padding_mask: boolean mask indicating which elements in 'key' and
      'value' are padding. Must have a shape compatible with 'key' and 'value'.
    query_padding_mask: boolean mask indicating which elements in `query` are
      padding (True means not padding).
    attn_mask: boolean mask indicating which elements of the calculated
      attention weight matrix should be used for collecting activation
      statistics. Must have a shape broadcast-compatible with
      '[bs, sequence_length, sequence_length]'.
    dtype: the dtype of the computation (default: float32)
    bias: bias for the attention weights. This can be used for incorporating
      autoregressive mask, padding mask, proximity bias.
    axis: axes over which the attention is applied.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool; if True, dropout is not applied.
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.

  Returns:
    Output of shape `[bs, sequence_length, num_heads, value_channels]`.
  """
    batch_size, query_sequence_length, num_heads, channel_size = query.shape
    key_sequence_length = key.shape[1]
    shape_utils.assert_shapes_equal(
        key.shape, (batch_size, key_sequence_length, num_heads, channel_size))
    shape_utils.assert_shapes_equal(
        value.shape,
        (batch_size, key_sequence_length, num_heads, channel_size))
    if key_padding_mask is not None:
        shape_utils.assert_shapes_equal(
            key_padding_mask.shape, (batch_size, key_sequence_length, 1, 1))
    if query_padding_mask is not None:
        shape_utils.assert_shapes_equal(
            query_padding_mask.shape,
            (batch_size, query_sequence_length, 1, 1))

    if attn_mask is not None:
        shape_utils.assert_shapes_compatible(
            attn_mask.shape,
            (batch_size, 1, query_sequence_length, key_sequence_length))

    if axis is None:
        axis = tuple(range(1, key.ndim - 2))
    if not isinstance(axis, Iterable):
        axis = (axis, )

    for ax in axis:
        if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
            raise ValueError('Attention axis must be between the batch '
                             'axis and the last-two axes.')
    depth = query.shape[-1]
    n = key.ndim
    # batch_dims is <bs, <non-attention dims>, num_heads>
    batch_dims = tuple(onp.delete(range(n), axis + (n - 1, )))
    # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)

    qk_perm = batch_dims + axis + (n - 1, )
    key = key.transpose(qk_perm)
    shape_utils.assert_shapes_equal(
        key.shape, (batch_size, num_heads, key_sequence_length, channel_size))

    key_padding_mask_transposed = None
    query_padding_mask_transposed = None
    if key_padding_mask is not None:
        key_padding_mask_transposed = key_padding_mask.transpose(qk_perm)
        shape_utils.assert_shapes_equal(
            key_padding_mask_transposed.shape,
            (batch_size, 1, key_sequence_length, 1))

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(channel_axis=None,
                           name='attn_act_k',
                           update_stats=train)(
                               key, mask=key_padding_mask_transposed)

    if query_padding_mask is not None:
        query_padding_mask_transposed = query_padding_mask.transpose(qk_perm)
        shape_utils.assert_shapes_equal(
            query_padding_mask_transposed.shape,
            (batch_size, 1, query_sequence_length, 1))

    key_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=key_padding_mask_transposed,
        module_name='K')

    # v -> (bs, <non-attention dims>, num_heads, channels, <attention dims>)
    v_perm = batch_dims + (n - 1, ) + axis
    value = value.transpose(v_perm)
    shape_utils.assert_shapes_equal(
        value.shape,
        (batch_size, num_heads, channel_size, key_sequence_length))
    value_padding_mask_transposed = None
    if key_padding_mask is not None:
        value_padding_mask_transposed = key_padding_mask.transpose(v_perm)
        shape_utils.assert_shapes_equal(
            value_padding_mask_transposed.shape,
            (batch_size, 1, 1, key_sequence_length))

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(channel_axis=None,
                           name='attn_act_v',
                           update_stats=train)(
                               value, mask=value_padding_mask_transposed)

    value_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=value_padding_mask_transposed,
        module_name='V')

    query = query / jnp.sqrt(depth).astype(dtype)
    query = query.transpose(qk_perm)
    shape_utils.assert_shapes_equal(
        query.shape,
        (batch_size, num_heads, query_sequence_length, channel_size))

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(channel_axis=None,
                           name='attn_act_q',
                           update_stats=train)(
                               query, mask=query_padding_mask_transposed)

    query_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=query_padding_mask_transposed,
        module_name='Q')

    batch_dims_t = tuple(range(len(batch_dims)))
    attn_weights = quantized_dynamic_dot_general(
        lhs_act=query,
        rhs_act=key,
        dot_dimension_numbers=(((n - 1, ), (n - 1, )), (batch_dims_t,
                                                        batch_dims_t)),
        dot_precision=precision,
        quant_type=hparams.quant_type,
        lhs_act_hparams=hparams.attn_act_q,
        lhs_get_bounds_params=query_get_bounds_params,
        rhs_act_hparams=hparams.attn_act_k,
        rhs_get_bounds_params=key_get_bounds_params,
    )
    # NOTE(shivaniagrawal): we do per-layer quantization here since that's the
    # only way for activation*activation matmuls to be aqt compatible since we use
    # static scaling factors for activations.

    shape_utils.assert_shapes_equal(
        attn_weights.shape,
        (batch_size, num_heads, query_sequence_length, key_sequence_length))

    # apply attention bias: masking, dropout, proximity bias, etc.
    if bias is not None:
        attn_weights = attn_weights + bias

    # normalize the attention weights
    norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim))
    attn_weights = softmax(attn_weights,
                           norm_dims,
                           dtype,
                           hparams.softmax,
                           quant_context=quant_context)

    # apply dropout
    if not deterministic and dropout_rate > 0.0:
        if dropout_rng is None:
            raise ValueError(
                'dropout_rng cannot be None if dropout is requested.')
        keep_prob = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate)
        if broadcast_dropout:
            # dropout is broadcast across the batch+head+non-attention dimension
            dropout_dims = attn_weights.shape[-(2 * len(axis)):]
            dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims)
            keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        else:
            keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
        multiplier = (keep.astype(attn_weights.dtype) /
                      jnp.asarray(keep_prob, dtype=dtype))
        attn_weights = attn_weights * multiplier

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(channel_axis=None,
                           name='attn_act_probs',
                           update_stats=train)(attn_weights, mask=attn_mask)

    if hparams.attn_act_probs is not None:
        assert hparams.attn_act_probs.bounds == 1.0, (
            'act quantization bounds should '
            'be set to fix value 1.0 to '
            'match Softmax range.')
    probs_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=attn_mask,
        module_name='attn_probs')

    # compute the new values given the attention weights
    wv_contracting_dims = (norm_dims, range(value.ndim - len(axis),
                                            value.ndim))
    y = quantized_dynamic_dot_general(
        lhs_act=attn_weights,
        rhs_act=value,
        dot_dimension_numbers=(wv_contracting_dims, (batch_dims_t,
                                                     batch_dims_t)),
        dot_precision=precision,
        quant_type=hparams.quant_type,
        lhs_act_hparams=hparams.attn_act_probs,
        lhs_get_bounds_params=probs_get_bounds_params,
        rhs_act_hparams=hparams.attn_act_v,
        rhs_get_bounds_params=value_get_bounds_params,
    )
    # NOTE(shivaniagrawal): we do per-layer quantization here since that's the
    # only way for activation*activation matmuls to be aqt compatible since we
    # use static scaling factors for activations.

    shape_utils.assert_shapes_equal(
        y.shape, (batch_size, num_heads, query_sequence_length, channel_size))
    # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
    perm_inv = _invert_perm(qk_perm)
    y = y.transpose(perm_inv)
    shape_utils.assert_shapes_equal(
        y.shape, (batch_size, query_sequence_length, num_heads, channel_size))
    return y
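
To make the permutation bookkeeping concrete, the common 4D case from the docstring (query/key of shape [bs, sequence_length, num_heads, channels], attention over the sequence axis) works out as follows; this is a worked example, not additional library code.

import numpy as onp

n = 4                 # query.ndim for [bs, seq, heads, channels]
axis = (1,)           # the sequence (attention) axis
batch_dims = tuple(onp.delete(range(n), axis + (n - 1,)))  # (0, 2)
qk_perm = batch_dims + axis + (n - 1,)   # (0, 2, 1, 3): [bs, heads, seq, channels]
v_perm = batch_dims + (n - 1,) + axis    # (0, 2, 3, 1): [bs, heads, channels, seq]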
Example #6
    def create_updated_stats(cls,
                             stats,
                             samples,
                             *,
                             axis=None,
                             paxis_name=None,
                             alpha=None,
                             mask=None,
                             exclude_zeros=False):
        """Create a new Stats instance that represents the updated statistics.

    Since flax.struct.dataclass objects are frozen, this method creates a new
    instance of Stats with updated stats and returns it.

    Args:
      stats: A Stats dataclass object to be updated.
      samples: An array to update the current statistics with.
      axis: axis to average input samples over, e.g. to calculate stats per
        channel.
      paxis_name: the axis name used to combine batch statistics from multiple
        devices. See `jax.pmap` for a description of axis names.
      alpha: Smoothing parameter to use for moving average. If None, will use
        1/n, where n is the stat count.
      mask: Optional boolean tensor of the same shape as 'samples' specifying
        which values of 'samples' to use as part of the bounds calculation.
        'True' indicates the corresponding value from 'samples' should be used.
        If None, all values are used.
      exclude_zeros: Whether to exclude zeros in samples when computing
        statistics, e.g. when calculating mean absolute values.

    Returns:
      A new Stats instance with updated stats and count.
    """

        if mask is None:
            mask = jnp.full(samples.shape, True)
        shape_utils.assert_shapes_compatible(samples.shape, mask.shape)
        mask = jnp.broadcast_to(mask, samples.shape)
        if exclude_zeros:
            # Where samples are zero, set mask to False. This way they won't be
            # included in statistics.
            mask = mask & (samples != 0)

        def _moving_avg(old_avg, new_val, masked_reduction_fn):
            masked_new_val_reduced = masked_reduction_fn(new_val,
                                                         mask=mask,
                                                         axis=axis,
                                                         paxis_name=paxis_name,
                                                         keepdims=True)
            valid_mask = jnp.isfinite(masked_new_val_reduced)
            # Only update average where means are valid, so set deltas corresponding
            # to invalid entries to 0.
            delta = jnp.where(valid_mask, masked_new_val_reduced - old_avg, 0)
            # TODO(lew): This is slightly incorrect, alpha should be proportional to
            # the mask size.
            new_avg = old_avg + alpha * delta
            return new_avg

        new_n = stats.n + 1
        if alpha is None:
            alpha = 1. / new_n

        new_mean = _moving_avg(stats.mean,
                               samples,
                               masked_reduction_fn=masked_mean)
        new_mean_abs = _moving_avg(stats.mean_abs,
                                   jnp.abs(samples),
                                   masked_reduction_fn=masked_mean)
        new_mean_sq = _moving_avg(stats.mean_sq,
                                  jnp.square(samples),
                                  masked_reduction_fn=masked_mean)
        new_mean_batch_minimum = _moving_avg(
            stats.mean_batch_minimum,
            samples,
            masked_reduction_fn=masked_mean_of_min)
        new_mean_batch_maximum = _moving_avg(
            stats.mean_batch_maximum,
            samples,
            masked_reduction_fn=masked_mean_of_max)
        return cls(n=new_n,
                   mean=new_mean,
                   mean_abs=new_mean_abs,
                   mean_sq=new_mean_sq,
                   mean_batch_minimum=new_mean_batch_minimum,
                   mean_batch_maximum=new_mean_batch_maximum)
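
A rough usage sketch (hypothetical shapes and values, and assuming `Stats.stats_initializer` returns a zero-initialized Stats of the given shape, as in the earlier get-bounds example): with `alpha=None` the smoothing factor defaults to 1/n, so after two batches the running mean is the average of the two per-batch means.

import jax.numpy as jnp

stats = Stats.stats_initializer((1, 1))          # zero-initialized running stats
for batch in (jnp.ones((4, 8)), 3.0 * jnp.ones((4, 8))):
    stats = Stats.create_updated_stats(stats, batch)
# With alpha=None the smoothing factor is 1/n, so stats.mean is now ~2.0,
# the average of the two per-batch means (1.0 and 3.0).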