Example #1
    def __call__(
        self,
        inputs,
        *,
        padding_mask,
    ):
        """Applies Transformer MlpBlock module."""

        batch_size, sequence_length, channel_size = inputs.shape
        inputs = inputs.reshape((batch_size * sequence_length, channel_size))
        shape_utils.assert_shapes_equal(padding_mask.shape,
                                        (batch_size, sequence_length, 1))
        padding_mask = padding_mask.reshape((batch_size * sequence_length, 1))
        x = aqt_flax_layers.DenseAqt(features=self.mlp_dim,
                                     dtype=self.dtype,
                                     paxis_name='batch',
                                     train=self.train,
                                     quant_context=self.quant_context,
                                     hparams=self.hparams.dense_1,
                                     kernel_init=self.kernel_init,
                                     bias_init=self.bias_init,
                                     name='dense_1')(inputs,
                                                     padding_mask=padding_mask)

        x = nn.relu(x)
        x = nn.Dropout(rate=self.dropout_rate)(
            x, deterministic=self.deterministic)

        output = aqt_flax_layers.DenseAqt(
            # ReLU is applied before this layer, so x only contains positive values.
            features=channel_size,
            dtype=self.dtype,
            paxis_name='batch',
            train=self.train,
            quant_context=self.quant_context,
            hparams=self.hparams.dense_2,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            name='dense_2')(x, padding_mask=padding_mask)

        output = nn.Dropout(rate=self.dropout_rate)(
            output, deterministic=self.deterministic)
        output = output.reshape((batch_size, sequence_length, channel_size))
        return output
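
A minimal plain-JAX sketch of the reshape pattern this block uses (flatten the batch and sequence dims, apply the dense layers, restore the original shape); the dense matmuls below are simple stand-ins for DenseAqt and all shapes are illustrative:

import jax.numpy as jnp

batch_size, sequence_length, channel_size, mlp_dim = 2, 5, 8, 16
inputs = jnp.ones((batch_size, sequence_length, channel_size))
padding_mask = jnp.ones((batch_size, sequence_length, 1), dtype=bool)

# Flatten batch and sequence dims so the dense layers see 2D matrices.
x = inputs.reshape((batch_size * sequence_length, channel_size))
# In the real DenseAqt the mask is only used for bounds statistics.
mask = padding_mask.reshape((batch_size * sequence_length, 1))

kernel_1 = jnp.full((channel_size, mlp_dim), 0.1)   # stand-in for dense_1
x = jnp.maximum(x @ kernel_1, 0.0)                  # ReLU, as above
kernel_2 = jnp.full((mlp_dim, channel_size), 0.1)   # stand-in for dense_2
output = (x @ kernel_2).reshape((batch_size, sequence_length, channel_size))
assert output.shape == inputs.shape
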
Example #2
  def attend(self, query, padding_mask,
             **unused_kwargs):
    """Attend over the embedding using a query array.

    Args:
      query: array with last dimension equal to the feature depth `features` of the
        embedding.
      padding_mask: boolean mask indicating which elements of 'query' are
        padding. Used for calculating activation statistics for the dynamic
        bounds quantization algorithm.
      **unused_kwargs: unused arguments passed from the apply method.

    Returns:
      An array with final dim `num_embeddings` corresponding to the batched
      inner-product of the array of query vectors against each embedding.
      Commonly used for weight-sharing between embeddings and logit transform
      in NLP models.
    """
    del unused_kwargs

    batch_size, channel_size = query.shape  # pylint: disable=unused-variable

    if padding_mask is not None:
      shape_utils.assert_shapes_equal(padding_mask.shape, (batch_size, 1))

    embedding = self.embedding
    embedding = jnp.asarray(embedding, self.dtype)

    # TODO(malmaud): Remove the 'mask' field from this struct so we can
    # make this struct a hyperparameter of the EncoderAqt class.
    get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=self.quant_context.update_bounds,
        update_stats=self.train,
        paxis_name=self.paxis_name,
        mask=padding_mask,
        module_name='logits')

    out = self.quantized_dot(
        act=query,
        w=jnp.transpose(embedding),
        get_bounds_params=get_bounds_params)

    return out
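
A small sketch of the weight-sharing logit transform that attend implements, stripped of quantization and bounds collection; shapes are illustrative:

import jax.numpy as jnp

num_embeddings, features, batch_size = 10, 4, 3
embedding = jnp.arange(num_embeddings * features, dtype=jnp.float32).reshape(
    (num_embeddings, features))
query = jnp.ones((batch_size, features))

# Batched inner product of each query vector against every embedding row.
logits = query @ jnp.transpose(embedding)
assert logits.shape == (batch_size, num_embeddings)
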
Example #3
    def __call__(
        self,
        inputs,
        targets,
        inputs_positions=None,
        targets_positions=None,
        inputs_segmentation=None,
        targets_segmentation=None,
    ):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      targets: target data
      inputs_positions: input subsequence positions for packed examples.
      targets_positions: target subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.

    Returns:
      output of a transformer decoder.

    """
        batch_size, sequence_length = inputs.shape
        assert batch_size == targets.shape[0], (
            'Inputs and targets must have the same batch size')

        src_padding_mask = (inputs > 0)[Ellipsis, None]
        shape_utils.assert_shapes_equal(src_padding_mask.shape,
                                        (batch_size, sequence_length, 1))

        encoded = self.encode(inputs,
                              inputs_positions=inputs_positions,
                              inputs_segmentation=inputs_segmentation)

        logits = self.decode(encoded,
                             src_padding_mask,
                             targets,
                             targets_positions=targets_positions,
                             inputs_segmentation=inputs_segmentation,
                             targets_segmentation=targets_segmentation,
                             tgt_padding_mask=None)
        return logits.astype(jnp.float32) if self.use_bfloat16 else logits
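
A tiny sketch of how the source padding mask is built here, assuming token id 0 denotes padding:

import jax.numpy as jnp

inputs = jnp.array([[5, 7, 0, 0],
                    [3, 0, 0, 0]])                  # (batch_size, sequence_length)
src_padding_mask = (inputs > 0)[Ellipsis, None]     # (batch_size, sequence_length, 1)
assert src_padding_mask.shape == (2, 4, 1)
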
Example #4
    def __call__(
        self,
        inputs,
    ):
        """Embeds the inputs along the last dimension.

    Args:
      inputs: input data, all dimensions are considered batch dimensions.

    Returns:
      Output which is embedded input data.  The output shape follows the input,
      with an additional `features` dimension appended.
    """
        batch_size, sequence_length = inputs.shape
        if inputs.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]:
            raise ValueError(
                'Input type must be an integer or unsigned integer.')
        embedding = self.embedding

        embedding = jnp.asarray(embedding, self.dtype)

        hparams = self.hparams
        # Initialize state for stats and bounds; this is required for the
        # logits computed later by the `attend` method.
        if hparams.quant_act is not None and isinstance(
                hparams.quant_act.bounds, get_bounds.GetBounds.Hyper):
            self.get_bounds_logits(
                inputs,
                bounds_params=get_bounds.GetBounds.Params(update_stats=False,
                                                          update_bounds=False,
                                                          paxis_name=None),
            )

        weight_prec = hparams.weight_prec
        weight_half_shift = hparams.weight_half_shift
        if weight_prec is not None:
            quantized_type = hparams.quant_type.to_jax_type()
            # In contrast to all other scale factor calculations in this module, we
            # compute per-row instead of per-column (i.e., per-output-channel) scale
            # factors here. This is because the embedding matrix might be shared with
            # the output (logit) layer of the transformer, in which case the
            # *transpose* of the embedding matrix will be used as the weight matrix in
            # a matmul. The per-row scale factors used here then correspond to
            # per-column scale factors (because of the transpose) for the weight
            # matrix in the logits layer, which is what we need for AQT.
            embedding_quant_ops = QuantOps.create_weights_ops(
                embedding,
                weight_params=QuantOps.WeightParams(
                    prec=weight_prec, axis=(1, ),
                    half_shift=weight_half_shift))
            embedding_quant_ops.assert_scale_shape_is(
                shape=(self.num_embeddings, 1))

            quantized_embedding = embedding_quant_ops.to_quantized(
                embedding, dtype=quantized_type)
            quantized_embedded_inputs = quantized_embedding[inputs]
            # Since the embedding matrix 'quantized_embedding' is gathered based on
            # 'inputs' to produce the embedding tensor, we apply the same gathering to
            # the per-row scale factors of the embedding matrix so the scale factors
            # will broadcast appropriately in the dequantization (the division by
            # 'scale' below).
            # TODO(malmaud): As part of quantization.py refactor, change
            # 'get_scale_for_aqt' to cleanly support this and hence avoid the need to
            # directly access a protected member of QuantOps.
            scale = embedding_quant_ops._scale[inputs]  # pylint: disable=protected-access
            shape_utils.assert_shapes_equal(scale.shape,
                                            (batch_size, sequence_length, 1))
            shape_utils.assert_shapes_equal(
                quantized_embedded_inputs.shape,
                (batch_size, sequence_length, self.features))
            embedded_inputs = (quantized_embedded_inputs / scale).astype(
                self.dtype)
        else:
            embedded_inputs = embedding[inputs]
        shape_utils.assert_shapes_equal(
            embedded_inputs.shape,
            (batch_size, sequence_length, self.features))
        return embedded_inputs
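
A simplified sketch of the per-row (axis=1) quantization and scale gathering described in the comments above; the max-abs scale below is a stand-in for the library's QuantOps scale computation:

import jax.numpy as jnp

prec = 8
embedding = jnp.array([[0.5, -1.0, 2.0],
                       [4.0, 0.0, -2.0]])            # (num_embeddings, features)
# One scale per row, so the transposed matrix has one scale per column.
max_per_row = jnp.max(jnp.abs(embedding), axis=1, keepdims=True)   # (num_embeddings, 1)
scale = (2.0**(prec - 1) - 1) / max_per_row
quantized_embedding = jnp.round(embedding * scale)

inputs = jnp.array([[0, 1, 1]])                           # token ids, (batch, seq)
quantized_embedded_inputs = quantized_embedding[inputs]   # (1, 3, features)
gathered_scale = scale[inputs]                            # (1, 3, 1)
embedded_inputs = quantized_embedded_inputs / gathered_scale   # dequantize
assert embedded_inputs.shape == (1, 3, 3)
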
Example #5
    def __call__(
        self,
        inputs,
        *,
        padding_mask,
    ):
        """Applies a linear transformation to the inputs with optional quantization.

    If weight_prec is not None, scales and quantizes weights to signed int with
    weight_prec bits.

    Args:
      inputs: The nd-array to be transformed.
      padding_mask: boolean tensor of the same shape as 'inputs' specifying
        which values of 'inputs' to use as part of the bounds calculation.
        'True' indicates the corresponding value from 'inputs' should be used.
        If None, all values are used.

    Returns:
      The transformed input.
    """
        batch_size = inputs.shape[0]
        if padding_mask is not None:
            shape_utils.assert_shapes_equal(padding_mask.shape,
                                            (batch_size, 1))
        # TODO(wanglisa): Replace fake quant with AQT.

        if self.quant_context.collect_acts_stats:
            stats_tag.StatsTag(channel_axis=-1,
                               name='inputs',
                               update_stats=self.train)(inputs,
                                                        mask=padding_mask)
        hparams = self.hparams
        if (hparams.weight_prec is not None
                and isinstance(hparams.weight_prec, int)
                and hparams.weight_prec > 8):
            raise NotImplementedError(
                'If you want to use more than 8bits for quantization, please revisit '
                'jax.lax.Precision.DEFAULT to determine whether it is still sufficient.'
            )

        kernel = self.param('kernel', self.kernel_init,
                            (inputs.shape[-1], self.features))

        inputs = jnp.asarray(inputs, self.dtype)
        kernel = jnp.asarray(kernel, self.dtype)

        get_bounds_params = get_bounds.GetBounds.Params(
            update_bounds=self.quant_context.update_bounds,
            update_stats=self.train,
            paxis_name=self.paxis_name,
            mask=padding_mask)

        weight_quant_granularity = hparams.weight_quant_granularity
        # kernel.shape = (channels_in, channels_out)
        if weight_quant_granularity == quant_config.QuantGranularity.per_channel:
            # Compute scale factors by reducing over the rows of the weight matrix,
            # resulting in one scale factor per column. This results in one scale
            # factor per output channel.
            expected_scale_shape = (1, self.features)
            weight_quant_axis = (0, )
        elif weight_quant_granularity == quant_config.QuantGranularity.per_tensor:
            # Compute a single scale factor for the entire weight matrix.
            expected_scale_shape = (1, 1)
            weight_quant_axis = None
        else:
            raise ValueError(
                f'Invalid quantization granularity {weight_quant_granularity}.'
            )

        weight_params = QuantOps.WeightParams(
            prec=hparams.weight_prec,
            half_shift=hparams.weight_half_shift,
            axis=weight_quant_axis,
            expected_scale_shape=expected_scale_shape)

        # TODO(wanglisa): add option to control when scale is being recomputed

        # matmul
        contracting_dims = ((inputs.ndim - 1, ), (0, ))
        # `((lhs_contracting_dims, rhs_contracting_dims),
        batch_dims = ((), ())  # (lhs_batch_dims, rhs_batch_dims))`
        y = quantization.quantized_dot_general(
            act=inputs,
            w=kernel,
            quant_type=hparams.quant_type,
            weight_params=weight_params,
            act_hparams=hparams.quant_act,
            get_bounds_params=get_bounds_params,
            dimension_numbers=(contracting_dims, batch_dims),
            dot_precision=self.precision,
            prefer_int8_to_int32_dot=self.quant_context.
            prefer_int8_to_int32_dot)

        # bias
        if self.use_bias:
            bias = self.param('bias', self.bias_init, (self.features, ))
            # (batch_size, features)
            y = y + bias[jnp.newaxis, :]
        return y
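
A short sketch contrasting the two weight-quantization granularities handled above; the max-abs reductions are simplified stand-ins for QuantOps, shown only to illustrate the scale shapes and reduction axes:

import jax.numpy as jnp

kernel = jnp.array([[1.0, -4.0],
                    [2.0, 0.5],
                    [-3.0, 1.5]])      # (channels_in, channels_out)

# per_channel: reduce over rows (axis 0) -> one scale per output channel.
per_channel_scale = jnp.max(jnp.abs(kernel), axis=0, keepdims=True)
assert per_channel_scale.shape == (1, kernel.shape[1])

# per_tensor: reduce over the whole matrix -> a single scale factor.
per_tensor_scale = jnp.max(jnp.abs(kernel), keepdims=True)
assert per_tensor_scale.shape == (1, 1)
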
Example #6
def quantized_dot(*,
                  w,
                  act,
                  quant_type,
                  weight_params,
                  act_hparams,
                  get_bounds_params,
                  prefer_int8_to_int32_dot,
                  dot_precision=None):
    """LAX dot with optionally quantized weights and activations.

  Wraps LAX's `Dot
  <https://github.com/google/jax/blob/f65a327c764406db45e95048dfe09209d8ef6d37/jax/_src/lax/lax.py#L632>`_
  operator.

  Args:
    w: an array representing weights
    act: an array representing activations
    quant_type: quantization strategy
    weight_params: QuantOps.WeightParams instance describing weight
      quantization.
    act_hparams: Optional activation quantization hyperparameters; an instance
      of QuantOps.ActHParams. None means no activation quantization.
    get_bounds_params: Optional get bounds params for auto activation
      quantization; instance of GetBounds.Params.
    prefer_int8_to_int32_dot: Whether to feed lax.dot inputs with an int8
      dtype and accumulate to an int32 dtype when quantizing to 8 bits or 4
      bits. If False, inputs are always floating-point.
    dot_precision: Optional. Either ``None``, which means the default precision
      for the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``).

  Returns:
    An array containing the result with the same dtype as 'w' and 'act'.

  Raises:
    RuntimeError: 'quant_type' had an unrecognized value.
    TypeError: 'act' and 'w' have different input types.
    ValueError: Shapes of 'act' and 'w' not compatible with quant_type.
  """
    # This code was initially expanded from
    # https://github.com/google/jax/blob/f65a327c764406db45e95048dfe09209d8ef6d37/jax/_src/lax/lax.py#L632
    # We keep the original return-value semantics of lax.dot, which this wraps. In
    # particular, the type of the return value of quantized_dot is the same as the
    # type of the inputs. That means that if the inputs are bfloat16, then the
    # return type of this function will also be bfloat16 even though on current
    # TPUs the underlying bf16*bf16 matrix-multiplication accumulates results to
    # float32. This is potentially undesirable since the user might want the raw
    # float32 result, but it ultimately stems from a limitation of the HLO 'dot'
    # instruction. If that instruction updates to support user-specified output
    # types, we could update quantized_dot accordingly to take a dtype argument to
    # control the return value type. This applies equally to
    # quantized_dynamic_dot_general.
    if not (1 <= act.ndim <= 2 and 1 <= w.ndim <= 2
            and act.shape[-1] == w.shape[0]):
        raise ValueError('Incompatible shapes for dot: got {} and {}.'.format(
            act.shape, w.shape))
    dot_dimension_numbers = (((act.ndim - 1, ), (0, )), ((), ()))
    if quant_type == QuantType.aqt:
        # Let 's' be activation scales and 't' be weight scales. We implement
        # matmul(RoundAndClip(act*s), RoundAndClip(s^-1 * w * t)) * t^-1. In the
        # comments below, we refer to this terminology.
        # lax.dot accepts any combination of 1d and 2d arguments for its lhs and rhs
        # input. To simplify the AQT implementation, we only accept 2d arguments for
        # now.
        if w.ndim != 2 or act.ndim != 2:
            raise ValueError(
                'AQT is currently only implemented for matrix*matrix operations'
            )
        num_input_channels = act.shape[1]
        num_output_channels = w.shape[1]

        # The ValueError raised in the guard at the beginning of this function
        # should have already checked that the weight matrix has a number of rows
        # equal to the number of channels in the activation.
        assert w.shape[0] == num_input_channels

        # We carry out all intermediate calculations using the same dtype as the
        # inputs. We want to be careful to not take a model configured to be trained
        # in bf16 and accidentally train it in fp32 by virtue of the scale dtype
        # being fp32.
        if act.dtype != w.dtype:
            raise TypeError(
                f'Activations and weight must have the same dtype, but got {act.dtype} and {w.dtype}'
            )
        input_dtype = act.dtype

        is_act_quantized = False
        # In this case, activations will be quantized at some point during training
        # (either now or later) and so we need to gather activation statistics by
        # calling 'QuantOps.create_input_ops', even if activations are not being
        # quantized on this particular training step (see b/174516400).
        if act_hparams is not None and act_hparams.prec is not None:
            # Calculate 's', the per-column scale factor on activations.
            act_op = QuantOps.create_input_ops(
                act, hparams=act_hparams, get_bounds_params=get_bounds_params)
            is_act_quantized = act_op.should_quantize()
            # Quantize activation matrix by computing RoundAndClip(act*s)

            # TODO(malmaud): We have to cast quantized activations to an fp format
            # instead of int8 since int8 matmul with int32 accumulation is not yet
            # supported in XLA (and therefore in Jax). See b/170293520. We keep
            # 'act_quantized' in whatever its original fp format was, typically bf16
            # or fp32, to follow what Fakequant does (see the type cast at the end of
            # QuantOps.fake_quant).
            act_quantized = act_op.to_quantized(act, dtype=input_dtype)

            # Now calculate s^-1.  First we extract s, the activation scale factor,
            # into a  variable called 'act_scale'. We extract it from 'act_op', the
            # QuantOps instance that calculated the scale factors for the activation
            # matrix.
            act_scale = act_op.get_scale_for_aqt(allow_per_channel_scales=True)
            # act_scale should either be a scalar, corresponding to per-layer
            # quantization, or a matrix with shape (1, num_input_channels),
            # corresponding to per-activation-channel scale factors.
            if act_scale.ndim != 0:
                shape_utils.assert_shapes_equal(act_scale.shape,
                                                (1, num_input_channels))
                # 'w' has one row per column of 'act_scale'. To scale each row of 'w' by
                # the inverse of the corresponding column in 'act_scale', we first have
                # to reshape 'act_scale' from (1, num_input_channels) to
                # (num_input_channels, 1) so the scale factors will broadcast
                # appropriately across the columns of 'w'.
                act_scale = act_scale.reshape(num_input_channels, 1)
            # Now we calculate s^-1 * w.
            w_scaled_rows = ((1 / act_scale) * w).astype(input_dtype)

            # TODO(shivaniagrawal): This section repeats code from the 'else' block.
            # The code is repeated twice because quantization can either be disabled
            # dynamically by setting the clipping bound to -1 (see comments on
            # 'should_quantize'), or statically by setting the 'prec' hyperparameter
            # to None. This block deals with the dynamic case (hence necessitating the
            # use of the dynamic 'lax.cond') while the 'else' block handles the static
            # case. Ultimately, we should unify them.
            act_quantized, w_scaled_rows = lax.cond(
                is_act_quantized, lambda _: (act_quantized, w_scaled_rows),
                lambda _: (act, w), None)
        else:
            # In this case, activations are not being quantized; only weights. There
            # is no need to absorb activation scales into the rows of the weight
            # matrix so 'w_scaled_rows' can just be set to the original weight matrix.
            act_quantized = act
            w_scaled_rows = w

        is_weight_quantized = False
        if weight_params is not None and weight_params.prec is not None:
            is_weight_quantized = True
            # Calculate 't', the weight scale, from (s^-1) * w
            weight_op = QuantOps.create_weights_ops(
                w_scaled_rows, weight_params=weight_params)
            weight_scale = weight_op.get_scale_for_aqt(
                allow_per_channel_scales=True)
            # Similar to 'act_scale' above, the weight_scale can either be a single
            # scalar or be a matrix with shape (1, num_output_channels), corresponding
            # to a per-channel scale factor for the weight matrix. We verify it here.
            if weight_scale.ndim != 0:
                shape_utils.assert_shapes_equal(weight_scale.shape,
                                                (1, num_output_channels))

            # Quantize weight matrix by calculating RoundAndClip(s^-1 * w * t)
            # TODO(malmaud): See comment on 'act_op.to_quantized' above, which applies
            # here as well.
            weight_quantized = weight_op.to_quantized(w_scaled_rows,
                                                      dtype=input_dtype)
        else:
            weight_quantized = w_scaled_rows
            weight_scale = jnp.array(1.0, dtype=SCALE_DTYPE)

        # Use metadata context to annotate op metadata with quantization info
        lhs_prec = None if act_hparams is None else act_hparams.prec
        rhs_prec = None if weight_params is None else weight_params.prec

        # To decide whether to use an integer-domain dot operation, we first check
        # if the static quantization parameters are compatible with it by seeing if
        # they request that both inputs be quantized to 8 bits or less. Then we
        # check if the dynamic parameters are compatible with it, i.e., in a
        # training run with quantization enabled, whether we are past the
        # activation start step yet.
        if lhs_prec is None or rhs_prec is None or lhs_prec > 8 or rhs_prec > 8:
            use_int8_to_int32_dot = False
        else:
            # is_act_quantized might be an instance of a Jax tracer instead of a
            # Python boolean since it is generally computed from a dynamic input to a
            # JITted Jax function. Thus we use '&' instead of 'and'.
            use_int8_to_int32_dot = prefer_int8_to_int32_dot & is_weight_quantized & is_act_quantized

        metadata_context = contextlib.suppress()
        with metadata_context:
            # Calculate matmul(...)
            out_quantized = dot_general_aqt(
                act_quantized,
                weight_quantized,
                dimension_numbers=dot_dimension_numbers,
                dot_precision=dot_precision,
                use_int8_to_int32_dot=use_int8_to_int32_dot)

        # Scale the columns of the matmul output by computing `matmul(...) * t^-1`
        # TODO(malmaud): Make it possible to return an unquantized matmul to support
        # disabling quantization during initial phase of training.
        #
        # We convert the return value back to input_dtype to ensure the output
        # tensor of quantized_dot has the same dtype as the input tensors to
        # quantized_dot. This explicit cast is necessary since if the inputs are
        # bf16, 'weight_scale' will still be fp32 and so multiplying out_quantized by
        # (1/weight_scale) will result in a fp32 tensor. We want to convert that
        # back to bf16.
        return (out_quantized * (1 / weight_scale)).astype(input_dtype)

    elif quant_type in (QuantType.fake_quant, QuantType.fake_quant_with_int):
        if quant_type == QuantType.fake_quant_with_int:
            # create a dependency on fake input to control constant folding
            fake_dependency = act
        else:
            fake_dependency = None

        quantized_type = quant_type.to_jax_type()
        w = QuantOps.create_weights_fake_quant(w,
                                               weight_params=weight_params,
                                               quantized_type=quantized_type,
                                               fake_dependency=fake_dependency)

        # TODO(shivaniagrawal): HParams currently allows act_hparams to be None.
        # Going forward we can make act_hparams a required field, where setting
        # either `prec` or `bounds` to None results in no activation
        # quantization.
        if act_hparams:
            act = QuantOps.create_inputs_fake_quant(
                act, hparams=act_hparams, get_bounds_params=get_bounds_params)

        metadata_context = contextlib.suppress()
        with metadata_context:
            out_quantized = lax.dot_general(
                act,
                w,
                dimension_numbers=dot_dimension_numbers,
                precision=dot_precision)
        return out_quantized
    else:
        raise RuntimeError(f'Unsupported quant_type {quant_type}')
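
A numeric sketch of the AQT identity this function relies on, matmul(RoundAndClip(act*s), RoundAndClip(s^-1 * w * t)) * t^-1 ≈ matmul(act, w), using simple max-abs scales and omitting clipping:

import jax.numpy as jnp

act = jnp.array([[0.2, -1.3], [0.7, 0.4]])
w = jnp.array([[1.5, -0.2], [0.3, 0.8]])

prec = 8
s = (2.0**(prec - 1) - 1) / jnp.max(jnp.abs(act))           # per-tensor act scale
t = (2.0**(prec - 1) - 1) / jnp.max(jnp.abs(w), axis=0)     # per-column weight scale

act_q = jnp.round(act * s)             # RoundAndClip(act * s), clipping omitted
w_q = jnp.round((1.0 / s) * w * t)     # RoundAndClip(s^-1 * w * t)
approx = (act_q @ w_q) * (1.0 / t)     # matmul(...) * t^-1
exact = act @ w
# approx matches exact up to the rounding error introduced by quantization.
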
Example #7
    def __call__(self,
                 inputs_q,
                 inputs_kv,
                 *,
                 padding_mask,
                 key_padding_mask,
                 segmentation=None,
                 key_segmentation=None):
        """Applies multi-head dot product attention on the input data.

    If weight_prec is not None, scales and quantizes weights to signed int with
    weight_prec bits.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output vector.

    This can be used for encoder-decoder attention by specifying both `inputs_q`
    and `inputs_kv` or for self-attention by only specifying `inputs_q` and
    setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]` or
        None for self-attention, in which case key/values will be derived from
        inputs_q.
      padding_mask: boolean tensor specifying which query tokens are padding
        tokens.
      key_padding_mask: boolean tensor specifying which key/value tokens are
        padding tokens.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """
        batch_size, query_sequence_length, channel_size = inputs_q.shape
        hparams = self.hparams
        if inputs_kv is None:
            inputs_kv = inputs_q
            key_sequence_length = inputs_q.shape[1]
        else:
            key_sequence_length = inputs_kv.shape[1]
            shape_utils.assert_shapes_equal(
                inputs_kv.shape,
                (batch_size, key_sequence_length, channel_size))

        jax_precision = jax.lax.Precision.DEFAULT

        if padding_mask is not None:
            shape_utils.assert_shapes_equal(
                padding_mask.shape, (batch_size, query_sequence_length, 1))
        if key_padding_mask is None:
            key_padding_mask = padding_mask
        else:
            shape_utils.assert_shapes_equal(
                key_padding_mask.shape, (batch_size, key_sequence_length, 1))
        attention_axis = self.attention_axis
        if attention_axis is None:
            attention_axis = tuple(range(1, inputs_q.ndim - 1))

        qkv_features = self.qkv_features
        qkv_features = qkv_features or inputs_q.shape[-1]

        num_heads = self.num_heads
        assert qkv_features % num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // num_heads

        paxis_name = self.paxis_name
        train = self.train
        kernel_init = self.kernel_init
        bias_init = self.bias_init
        use_bias = self.use_bias
        dtype = self.dtype

        def multi_batch_dense_aqt(inputs, *, name, padding_mask):
            batch_size, sequence_length, channel_size = inputs.shape
            inputs = inputs.reshape(batch_size * sequence_length, channel_size)
            if padding_mask is not None:
                padding_mask = padding_mask.reshape(
                    batch_size * sequence_length, 1)
            out = flax_layers.DenseAqt(name=name,
                                       features=num_heads * head_dim,
                                       paxis_name=paxis_name,
                                       train=train,
                                       quant_context=self.quant_context,
                                       hparams=hparams.dense_kqv,
                                       kernel_init=kernel_init,
                                       bias_init=bias_init,
                                       use_bias=use_bias,
                                       dtype=dtype)(inputs,
                                                    padding_mask=padding_mask)
            return out.reshape(batch_size, sequence_length, num_heads,
                               head_dim)

        # project inputs_q to multi-headed q/k/v
        # dimensions are then [bs, sequence_length, n_heads, n_features_per_head]
        query = multi_batch_dense_aqt(inputs_q,
                                      name='query',
                                      padding_mask=padding_mask)
        key = multi_batch_dense_aqt(inputs_kv,
                                    name='key',
                                    padding_mask=key_padding_mask)
        value = multi_batch_dense_aqt(inputs_kv,
                                      name='value',
                                      padding_mask=key_padding_mask)
        is_cache_initialized = False
        if self.decode:
            is_cache_initialized = self.has_variable('cache', 'cached_key')
            cached_key = self.variable('cache', 'cached_key', jnp.zeros,
                                       key.shape, key.dtype)
            cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                         value.shape, value.dtype)
            cache_index = self.variable('cache', 'cache_index',
                                        lambda: jnp.array(0, dtype=jnp.int32))
            if is_cache_initialized:
                expected_shape = list(cached_key.value.shape[:-2])
                for attn_dim in attention_axis:
                    expected_shape[attn_dim] = 1
                expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
                if expected_shape != inputs_q.shape:
                    raise ValueError('Invalid shape provided, '
                                     'expected shape %s instead got %s.' %
                                     (expected_shape, inputs_q.shape))

                cshape = cached_key.value.shape
                indices = [0] * len(cshape)
                i = cache_index.value
                attn_size = onp.prod(onp.take(cshape, attention_axis))

                *batch_dims, max_length, num_heads, depth_per_head = (  # pylint: disable=unused-variable
                    cached_key.value.shape)
                indices = (0, ) * len(batch_dims) + (i, 0, 0)

                key = lax.dynamic_update_slice(cached_key.value, key, indices)
                value = lax.dynamic_update_slice(cached_value.value, value,
                                                 indices)
                one = jnp.array(1, jnp.int32)
                cache_index.value = cache_index.value + one
                cached_key.value = key
                cached_value.value = value

                # TODO(levskaya): verify this is still needed in translation decoding.
                key_padding_mask = jnp.broadcast_to(
                    (jnp.arange(max_length) < cache_index.value), cshape[:2])
                key_padding_mask = key_padding_mask.astype(
                    jnp.float32)[Ellipsis, None]

        # create attention masks
        mask_components = []
        if self.causal_mask:
            if self.decode and is_cache_initialized:
                bias_pre_shape = (1, ) * (key.ndim - 1)
                attn_shape = tuple(onp.take(key.shape, attention_axis))
                attn_size = onp.prod(attn_shape)
                ii = jnp.arange(attn_size, dtype=jnp.int32)
                mask = ii < cache_index.value
                mask_components.append(
                    mask.reshape(bias_pre_shape + attn_shape))
            else:
                mask_components.append(_make_causal_mask(key, attention_axis))
        if padding_mask is not None:
            if key_padding_mask is None:
                key_padding_mask = padding_mask
            attn_padding_mask = make_padding_mask(
                padding_mask_query=padding_mask,
                padding_mask_key=key_padding_mask,
                query_shape=query.shape,
                key_shape=key.shape,
                attention_axis=attention_axis)
            mask_components.append(attn_padding_mask)
        if segmentation is not None:
            if key_segmentation is None:
                key_segmentation = segmentation
            segmentation_mask = make_padding_mask(
                padding_mask_query=segmentation,
                padding_mask_key=key_segmentation,
                query_shape=query.shape,
                key_shape=key.shape,
                attention_axis=attention_axis,
                segmentation_mask=True)
            mask_components.append(segmentation_mask)
        attention_mask = None
        if mask_components:
            attention_mask = mask_components[0]
            for component in mask_components[1:]:
                attention_mask = jnp.logical_and(attention_mask, component)
            attention_mask = attention_mask.astype(jnp.bool_)

            # attention mask in the form of attention bias
            attention_bias = jnp.where(
                attention_mask,
                jnp.full(attention_mask.shape, 0.).astype(dtype),
                jnp.full(attention_mask.shape, -1e10).astype(dtype))
        else:
            attention_bias = None

        # Add an extra dimension to the mask corresponding to the head
        # dimension. eg, if inputs_q has shape [batch_size, sequence_length,
        # n_features], then padding_mask will have a shape
        # [batch_size, sequence_length, 1] and query will have shape
        # [batch_size, sequence_length, n_heads, n_features_per_head].
        # We create query_padding_mask with shape [batch_size, sequence_length,
        # 1, 1] to be broadcast-compatible with 'query'.
        if padding_mask is not None:
            padding_mask = padding_mask[Ellipsis, None]
            shape_utils.assert_shapes_equal(
                padding_mask.shape, (batch_size, query_sequence_length, 1, 1))
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[Ellipsis, None]
            # During prediction, the key padding mask is only going to be
            # broadcast-compatible with the key.
            shape_utils.assert_shapes_compatible(
                key_padding_mask.shape,
                (batch_size, key_sequence_length, 1, 1))

        # apply attention
        attention_fn = self.attention_fn
        dropout_rate = self.dropout_rate
        broadcast_dropout = self.broadcast_dropout
        deterministic = self.deterministic
        if not deterministic and self.dropout_rate > 0.0:
            dropout_rng = self.make_rng('dropout')
        else:
            dropout_rng = None
        x = attention_fn(  # pylint: disable=redundant-keyword-arg
            query=query,
            key=key,
            value=value,
            hparams=hparams.attn_acts,
            paxis_name=paxis_name,
            train=train,
            quant_context=self.quant_context,
            dtype=dtype,
            axis=attention_axis,
            bias=attention_bias,
            precision=jax_precision,
            dropout_rng=dropout_rng,
            dropout_rate=dropout_rate,
            broadcast_dropout=broadcast_dropout,
            deterministic=deterministic,
            query_padding_mask=padding_mask,
            key_padding_mask=key_padding_mask,
            attn_mask=attention_mask)
        shape_utils.assert_shapes_equal(
            x.shape, (batch_size, query_sequence_length, num_heads, head_dim))
        x = x.reshape(batch_size * query_sequence_length, num_heads * head_dim)
        if padding_mask is not None:
            padding_mask = padding_mask.reshape(
                batch_size * query_sequence_length, 1)
        # back to the original inputs dimensions
        out = flax_layers.DenseAqt(features=channel_size,
                                   hparams=hparams.dense_out,
                                   quant_context=self.quant_context,
                                   paxis_name=paxis_name,
                                   train=train,
                                   kernel_init=kernel_init,
                                   bias_init=bias_init,
                                   use_bias=use_bias,
                                   dtype=dtype,
                                   name='dense_out')(x,
                                                     padding_mask=padding_mask)
        shape_utils.assert_shapes_equal(
            out.shape, (batch_size * query_sequence_length, channel_size))
        out = out.reshape(batch_size, query_sequence_length, channel_size)
        return out
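
A small sketch of how the boolean attention mask built above is turned into an additive bias before softmax (taken directly from the jnp.where pattern in this method):

import jax.numpy as jnp

dtype = jnp.float32
attention_mask = jnp.array([[True, True, False],
                            [True, False, False]])
attention_bias = jnp.where(
    attention_mask,
    jnp.full(attention_mask.shape, 0.).astype(dtype),
    jnp.full(attention_mask.shape, -1e10).astype(dtype))
# The bias is added to the raw attention logits; masked positions end up with
# negligible weight after the softmax.
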
Example #8
def dot_product_attention(query,
                          key,
                          value,
                          hparams,
                          quant_context,
                          paxis_name,
                          train,
                          key_padding_mask,
                          query_padding_mask,
                          attn_mask,
                          dtype=jnp.float32,
                          bias=None,
                          axis=None,
                          broadcast_dropout=True,
                          dropout_rng=None,
                          dropout_rate=0.,
                          deterministic=False,
                          precision=None):
    """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights. This
  function supports multi-dimensional inputs.


  Args:
    query: queries for calculating attention with shape of `[batch_size,
      sequence_length, num_heads, mem_channels]`.
    key: keys for calculating attention with shape of `[batch_size,
      sequence_length, num_heads, mem_channels]`.
    value: values to be used in attention with shape of `[batch_size,
      sequence_length, num_heads, value_channels]`.
    hparams: hyperparameters used for quantization.
    quant_context: context for quantization.
    paxis_name: axis_name to which a user `pmaps` the parent module (model);
      refer to jax.pmap() for more documentation. This arg is used for
      get_bounds activation quantization (QuantOps.create_input_fake_quant).
    train: Whether model is training.
    key_padding_mask: boolean mask indicating which elements in 'key' and
      'value' are padding. Must have a shape compatible with 'key' and 'value'.
    query_padding_mask: boolean mask indicating which elements in `query` are
      padding (True means not padding).
    attn_mask: boolean mask indicating which elements of the calculated
      attention weight matrix should be used for collecting activation
      statistics. Should have a shape broadcast-compatible with '[bs,
      sequence_length, sequence_length]'. Must have a shape broadcast-compatible
      with 'query'.
    dtype: the dtype of the computation (default: float32)
    bias: bias for the attention weights. This can be used for incorporating
      autoregressive mask, padding mask, proximity bias.
    axis: axes over which the attention is applied.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout.
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.

  Returns:
    Output of shape `[bs, sequence_length, num_heads, value_channels]`.
  """
    batch_size, query_sequence_length, num_heads, channel_size = query.shape
    key_sequence_length = key.shape[1]
    shape_utils.assert_shapes_equal(
        key.shape, (batch_size, key_sequence_length, num_heads, channel_size))
    shape_utils.assert_shapes_equal(
        value.shape,
        (batch_size, key_sequence_length, num_heads, channel_size))
    if key_padding_mask is not None:
        shape_utils.assert_shapes_equal(
            key_padding_mask.shape, (batch_size, key_sequence_length, 1, 1))
    if query_padding_mask is not None:
        shape_utils.assert_shapes_equal(
            query_padding_mask.shape,
            (batch_size, query_sequence_length, 1, 1))

    if attn_mask is not None:
        shape_utils.assert_shapes_compatible(
            attn_mask.shape,
            (batch_size, 1, query_sequence_length, key_sequence_length))

    if axis is None:
        axis = tuple(range(1, key.ndim - 2))
    if not isinstance(axis, Iterable):
        axis = (axis, )

    for ax in axis:
        if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
            raise ValueError('Attention axis must be between the batch '
                             'axis and the last-two axes.')
    depth = query.shape[-1]
    n = key.ndim
    # batch_dims is  <bs, <non-attention dims>, num_heads>
    batch_dims = tuple(onp.delete(range(n), axis + (n - 1, )))
    # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)

    qk_perm = batch_dims + axis + (n - 1, )
    key = key.transpose(qk_perm)
    shape_utils.assert_shapes_equal(
        key.shape, (batch_size, num_heads, key_sequence_length, channel_size))

    key_padding_mask_transposed = None
    query_padding_mask_transposed = None
    if key_padding_mask is not None:
        key_padding_mask_transposed = key_padding_mask.transpose(qk_perm)
        shape_utils.assert_shapes_equal(
            key_padding_mask_transposed.shape,
            (batch_size, 1, key_sequence_length, 1))

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(channel_axis=None,
                           name='attn_act_k',
                           update_stats=train)(
                               key, mask=key_padding_mask_transposed)

    if query_padding_mask is not None:
        query_padding_mask_transposed = query_padding_mask.transpose(qk_perm)
        shape_utils.assert_shapes_equal(
            query_padding_mask_transposed.shape,
            (batch_size, 1, query_sequence_length, 1))

    key_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=key_padding_mask_transposed,
        module_name='K')

    # v -> (bs, <non-attention dims>, num_heads, channels, <attention dims>)
    v_perm = batch_dims + (n - 1, ) + axis
    value = value.transpose(v_perm)
    shape_utils.assert_shapes_equal(
        value.shape,
        (batch_size, num_heads, channel_size, key_sequence_length))
    value_padding_mask_transposed = None
    if key_padding_mask is not None:
        value_padding_mask_transposed = key_padding_mask.transpose(v_perm)
        shape_utils.assert_shapes_equal(
            value_padding_mask_transposed.shape,
            (batch_size, 1, 1, key_sequence_length))

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(channel_axis=None,
                           name='attn_act_v',
                           update_stats=train)(
                               value, mask=value_padding_mask_transposed)

    value_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=value_padding_mask_transposed,
        module_name='V')

    query = query / jnp.sqrt(depth).astype(dtype)
    query = query.transpose(qk_perm)
    shape_utils.assert_shapes_equal(
        query.shape,
        (batch_size, num_heads, query_sequence_length, channel_size))

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(channel_axis=None,
                           name='attn_act_q',
                           update_stats=train)(
                               query, mask=query_padding_mask_transposed)

    query_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=query_padding_mask_transposed,
        module_name='Q')

    batch_dims_t = tuple(range(len(batch_dims)))
    attn_weights = quantized_dynamic_dot_general(
        lhs_act=query,
        rhs_act=key,
        dot_dimension_numbers=(((n - 1, ), (n - 1, )), (batch_dims_t,
                                                        batch_dims_t)),
        dot_precision=precision,
        quant_type=hparams.quant_type,
        lhs_act_hparams=hparams.attn_act_q,
        lhs_get_bounds_params=query_get_bounds_params,
        rhs_act_hparams=hparams.attn_act_k,
        rhs_get_bounds_params=key_get_bounds_params,
    )
    # NOTE(shivaniagrawal): we do per-layer quantization here since that's the
    # only way for activation*activation matmuls to be aqt compatible since we use
    # static scaling factors for activations.

    shape_utils.assert_shapes_equal(
        attn_weights.shape,
        (batch_size, num_heads, query_sequence_length, key_sequence_length))

    # apply attention bias: masking, dropout, proximity bias, etc.
    if bias is not None:
        attn_weights = attn_weights + bias

    # normalize the attention weights
    norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim))
    attn_weights = softmax(attn_weights,
                           norm_dims,
                           dtype,
                           hparams.softmax,
                           quant_context=quant_context)

    # apply dropout
    if not deterministic and dropout_rate > 0.0:
        if dropout_rng is None:
            raise ValueError(
                'dropout_rng cannot be None if dropout is requested.')
        keep_prob = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate)
        if broadcast_dropout:
            # dropout is broadcast across the batch+head+non-attention dimension
            dropout_dims = attn_weights.shape[-(2 * len(axis)):]
            dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims)
            keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        else:
            keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
        multiplier = (keep.astype(attn_weights.dtype) /
                      jnp.asarray(keep_prob, dtype=dtype))
        attn_weights = attn_weights * multiplier

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(channel_axis=None,
                           name='attn_act_probs',
                           update_stats=train)(attn_weights, mask=attn_mask)

    if hparams.attn_act_probs is not None:
        assert hparams.attn_act_probs.bounds == 1.0, (
            'act quantization bounds should be set to the fixed value 1.0 to '
            'match the softmax output range.')
    probs_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=attn_mask,
        module_name='attn_probs')

    # compute the new values given the attention weights
    wv_contracting_dims = (norm_dims, range(value.ndim - len(axis),
                                            value.ndim))
    y = quantized_dynamic_dot_general(
        lhs_act=attn_weights,
        rhs_act=value,
        dot_dimension_numbers=(wv_contracting_dims, (batch_dims_t,
                                                     batch_dims_t)),
        dot_precision=precision,
        quant_type=hparams.quant_type,
        lhs_act_hparams=hparams.attn_act_probs,
        lhs_get_bounds_params=probs_get_bounds_params,
        rhs_act_hparams=hparams.attn_act_v,
        rhs_get_bounds_params=value_get_bounds_params,
    )
    # NOTE(shivaniagrawal): we do per-layer quantization here since that's the
    # only way for activation*activation matmuls to be aqt compatible since we
    # use static scaling factors for activations.

    shape_utils.assert_shapes_equal(
        y.shape, (batch_size, num_heads, query_sequence_length, channel_size))
    # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
    perm_inv = _invert_perm(qk_perm)
    y = y.transpose(perm_inv)
    shape_utils.assert_shapes_equal(
        y.shape, (batch_size, query_sequence_length, num_heads, channel_size))
    return y
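
A reference sketch of unquantized scaled dot-product attention, mirroring the two quantized_dynamic_dot_general calls above (QK^T, softmax, then weights times V) but without masks, dropout, or quantization; shapes are illustrative:

import jax
import jax.numpy as jnp

batch_size, num_heads, q_len, k_len, depth = 2, 4, 5, 6, 8
q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(q_rng, (batch_size, num_heads, q_len, depth))
key = jax.random.normal(k_rng, (batch_size, num_heads, k_len, depth))
value = jax.random.normal(v_rng, (batch_size, num_heads, k_len, depth))

query = query / jnp.sqrt(depth)                              # scale queries
attn_weights = jnp.einsum('bhqd,bhkd->bhqk', query, key)     # QK^T
attn_weights = jax.nn.softmax(attn_weights, axis=-1)         # normalize
y = jnp.einsum('bhqk,bhkd->bhqd', attn_weights, value)       # weights * V
assert y.shape == (batch_size, num_heads, q_len, depth)
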
Example #9
    def __call__(
        self,
        encoded,
        src_padding_mask,
        targets,
        targets_positions=None,
        inputs_segmentation=None,
        targets_segmentation=None,
        tgt_padding_mask=None,
    ):
        """Applies Transformer model on the inputs.

    Args:
      encoded: encoded input data from encoder.
      src_padding_mask: padding mask for inputs.
      targets: target inputs.
      targets_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.
      tgt_padding_mask: target tokens padding mask.

    Returns:
      output of a transformer decoder.

    """
        batch_size, sequence_length, channel_size = encoded.shape  # pylint: disable=unused-variable
        target_batch_size, target_sequence_length = targets.shape  # pylint: disable=unused-variable
        shape_utils.assert_shapes_equal(targets.shape,
                                        (batch_size, target_sequence_length))

        # Padding Masks
        if tgt_padding_mask is None:
            tgt_padding_mask = (targets > 0)[Ellipsis, None]
        shape_utils.assert_shapes_equal(
            tgt_padding_mask.shape, (batch_size, target_sequence_length, 1))

        if self.use_bfloat16:
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Target Embedding
        if self.shared_embedding is None:
            output_embed = aqt_flax_layers.EmbedAqt(
                num_embeddings=self.output_vocab_size,
                features=self.emb_dim,
                hparams=self.hparams.embedding,
                embedding_init=nn.initializers.normal(
                    stddev=self.emb_dim**-0.5),
                dtype=dtype,
                name='target_embed',
                train=self.train,
                quant_context=self.quant_context,
                paxis_name='batch')
        else:
            output_embed = self.shared_embedding

        y = targets.astype('int32')
        if not self.decode:
            y = shift_right(y)
        y = output_embed(y) * jnp.sqrt(self.emb_dim)
        y = AddPositionEmbs(name='posembed_targets',
                            max_len=self.max_len,
                            decode=self.decode,
                            min_timescale=1.0,
                            max_timescale=10000.0)(
                                y, inputs_positions=targets_positions)
        y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not self.train)

        if self.use_bfloat16:
            y = y.astype(jnp.bfloat16)

        # Target-Input Decoder
        num_layers = len(self.hparams.encoder_decoder_1d_blocks)
        for lyr in range(num_layers):
            y = EncoderDecoder1DBlock(
                train=self.train,
                quant_context=self.quant_context,
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                hparams=self.hparams.encoder_decoder_1d_blocks[lyr],
                dtype=dtype,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                deterministic=not self.train,
                name=f'encoderdecoderblock_{lyr}',
                decode=self.decode)(y,
                                    encoded,
                                    padding_mask=tgt_padding_mask,
                                    key_padding_mask=src_padding_mask,
                                    inputs_segmentation=inputs_segmentation,
                                    targets_segmentation=targets_segmentation)
        y = aqt_flax_layers.LayerNormAqt(dtype=dtype,
                                         name='encoderdecoder_norm',
                                         hparams=self.hparams.layer_norm,
                                         quant_context=self.quant_context)(y)
        y = y.reshape((batch_size * target_sequence_length, channel_size))
        tgt_padding_mask = tgt_padding_mask.reshape(
            (batch_size * target_sequence_length, 1))
        # Decoded Logits
        if self.logits_via_embedding:
            # Use the transpose of embedding matrix for logit transform.
            logits = output_embed.attend(query=y,
                                         padding_mask=tgt_padding_mask,
                                         paxis_name=self.paxis_name,
                                         train=self.train)
        else:
            if self.hparams.logits is None:
                raise ValueError(
                    'If logits_via_embedding is False, then the hparams '
                    'for the logits layer have to be provided.')
            logits = aqt_flax_layers.DenseAqt(
                features=self.output_vocab_size,
                dtype=dtype,
                paxis_name='batch',
                train=self.train,
                quant_context=self.quant_context,
                hparams=self.hparams.logits,
                kernel_init=nn.initializers.xavier_uniform(),
                bias_init=nn.initializers.normal(stddev=1e-6),
                name='logits_dense')(y, padding_mask=tgt_padding_mask)
        return logits
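
A hypothetical sketch of the shift_right used above for teacher forcing (prepend a zero token and drop the last position); the repo's own helper may differ in details:

import jax.numpy as jnp

def shift_right_sketch(targets):
    """Prepends a 0 token and drops the last position of each sequence."""
    padded = jnp.pad(targets, ((0, 0), (1, 0)), constant_values=0)
    return padded[:, :-1]

targets = jnp.array([[11, 12, 13, 0]])
shifted = shift_right_sketch(targets)    # [[ 0 11 12 13]]
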
Example #10
    def __call__(self,
                 inputs,
                 inputs_positions=None,
                 inputs_segmentation=None):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.

    Returns:
      output of a transformer encoder.

    """
        batch_size, sequence_length = inputs.shape

        # Padding Masks
        src_padding_mask = (inputs > 0)[Ellipsis, None]
        shape_utils.assert_shapes_equal(src_padding_mask.shape,
                                        (batch_size, sequence_length, 1))

        if self.use_bfloat16:
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Input Embedding
        if self.shared_embedding is None:
            input_embed = aqt_flax_layers.EmbedAqt(
                num_embeddings=self.vocab_size,
                features=self.emb_dim,
                hparams=self.hparams.embedding,
                embedding_init=nn.initializers.normal(
                    stddev=self.emb_dim**-0.5),
                dtype=dtype,
                name='input_embed',
                paxis_name='batch',
                train=self.train,
                quant_context=self.quant_context)
        else:
            input_embed = self.shared_embedding
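        # Convert token ids to int32, embed them, and scale by sqrt(emb_dim),
        # as in the original Transformer.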
        x = inputs.astype('int32')
        x = input_embed(x) * jnp.sqrt(self.emb_dim)
        x = AddPositionEmbs(name='posembed_input',
                            max_len=self.max_len,
                            min_timescale=1.0,
                            max_timescale=10000.0,
                            decode=False)(x, inputs_positions=inputs_positions)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not self.train)

        if self.use_bfloat16:
            x = x.astype(jnp.bfloat16)

        # Input Encoder
        num_layers = len(self.hparams.encoder_1d_blocks)
        for lyr in range(num_layers):
            x = Encoder1DBlock(
                train=self.train,
                quant_context=self.quant_context,
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                hparams=self.hparams.encoder_1d_blocks[lyr],
                dtype=dtype,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                deterministic=not self.train,
                name=f'encoderblock_{lyr}')(
                    x,
                    padding_mask=src_padding_mask,
                    inputs_segmentation=inputs_segmentation)
        encoded = aqt_flax_layers.LayerNormAqt(
            dtype=dtype,
            name='encoder_norm',
            hparams=self.hparams.layer_norm,
            quant_context=self.quant_context)(x)
        shape_utils.assert_shapes_equal(
            encoded.shape, (batch_size, sequence_length, self.emb_dim))
        return encoded
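As a standalone illustration of the padding-mask construction at the top of this method, here is a minimal sketch assuming, as the code does, that token id 0 is padding:

import jax.numpy as jnp

inputs = jnp.array([[5, 9, 2, 0, 0],
                    [7, 1, 0, 0, 0]])          # (batch_size, sequence_length) token ids
src_padding_mask = (inputs > 0)[..., None]     # (batch_size, sequence_length, 1), True at real tokens
assert src_padding_mask.shape == (2, 5, 1)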
Example No. 11
    def __call__(
        self,
        targets,
        encoded,
        padding_mask,
        key_padding_mask,
        inputs_segmentation=None,
        targets_segmentation=None,
    ):
        """Applies EncoderDecoder1DBlock module.

    Args:
      targets: input data for decoder
      encoded: input data from encoder
      padding_mask: bool, mask padding tokens
      key_padding_mask: bool, mask padding tokens
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.

    Returns:
      output after transformer block.
    """

        # Decoder block.
        batch_size, sequence_length, num_channels = targets.shape
        encoded_sequence_length = encoded.shape[1]
        shape_utils.assert_shapes_equal(padding_mask.shape,
                                        (batch_size, sequence_length, 1))
        shape_utils.assert_shapes_equal(
            encoded.shape, (batch_size, encoded_sequence_length, num_channels))
        shape_utils.assert_shapes_equal(
            key_padding_mask.shape, (batch_size, encoded_sequence_length, 1))

        x = aqt_flax_layers.LayerNormAqt(
            dtype=self.dtype,
            hparams=self.hparams.layer_norm,
            quant_context=self.quant_context)(targets)
        x = aqt_flax_attention.SelfAttentionAqt(
            hparams=self.hparams.self_attention,
            num_heads=self.num_heads,
            dtype=self.dtype,
            qkv_features=self.qkv_dim,
            attention_axis=(1, ),
            paxis_name='batch',
            train=self.train,
            quant_context=self.quant_context,
            causal_mask=True,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=self.attention_dropout_rate,
            deterministic=self.deterministic,
            name='dec_self_att',
            decode=self.decode)(
                x,
                padding_mask=padding_mask,
                segmentation=targets_segmentation,
            )
        x = nn.Dropout(rate=self.dropout_rate)(
            x, deterministic=self.deterministic)
        x = x + targets

        # Encoder-Decoder block.
        y = aqt_flax_layers.LayerNormAqt(dtype=self.dtype,
                                         hparams=self.hparams.layer_norm,
                                         quant_context=self.quant_context)(x)
        y = aqt_flax_attention.MultiHeadDotProductAttentionAqt(
            hparams=self.hparams.enc_dec_attention,
            num_heads=self.num_heads,
            dtype=self.dtype,
            qkv_features=self.qkv_dim,
            attention_axis=(1, ),
            paxis_name='batch',
            train=self.train,
            quant_context=self.quant_context,
            causal_mask=False,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=self.attention_dropout_rate,
            deterministic=self.deterministic,
            decode=False,
            name='dec_enc_att')(inputs_q=y,
                                inputs_kv=encoded,
                                padding_mask=padding_mask,
                                key_padding_mask=key_padding_mask,
                                segmentation=targets_segmentation,
                                key_segmentation=inputs_segmentation)
        y = nn.Dropout(rate=self.dropout_rate)(
            y, deterministic=self.deterministic)
        y = y + x

        # MLP block.
        z = aqt_flax_layers.LayerNormAqt(dtype=self.dtype,
                                         hparams=self.hparams.layer_norm,
                                         quant_context=self.quant_context)(y)
        z = MlpBlock(
            mlp_dim=self.mlp_dim,
            dtype=self.dtype,
            hparams=self.hparams.mlp_block,
            # Inputs can be signed here because this block follows the attention layer.
            train=self.train,
            quant_context=self.quant_context,
            dropout_rate=self.dropout_rate,
            deterministic=self.deterministic,
            name='mlp_block')(z, padding_mask=padding_mask)

        return y + z
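Stripped of quantization, dropout, and masking, the block above is a standard pre-LayerNorm residual stack. A structural sketch with the sublayers abstracted as hypothetical callables (the real block uses a separate LayerNormAqt per sublayer):

def encoder_decoder_block(targets, encoded, *, norm, self_attn, cross_attn, mlp):
    """Pre-LayerNorm decoder block wiring: three normalized sublayers with residuals."""
    x = targets + self_attn(norm(targets))   # masked self-attention + residual
    y = x + cross_attn(norm(x), encoded)     # encoder-decoder attention + residual
    return y + mlp(norm(y))                  # MLP block + residual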
Example No. 12
    def __call__(self, inputs, padding_mask, inputs_segmentation=None):
        """Applies Encoder1DBlock module.

    Args:
      inputs: input data
      padding_mask: bool, mask padding tokens
      inputs_segmentation: input segmentation info for packed examples.

    Returns:
      output after transformer block.
    """

        # Attention block.
        batch_size, sequence_length, channel_size = inputs.shape
        shape_utils.assert_shapes_equal(padding_mask.shape,
                                        (batch_size, sequence_length, 1))
        x = aqt_flax_layers.LayerNormAqt(
            dtype=self.dtype,
            hparams=self.hparams.layer_norm,
            quant_context=self.quant_context)(inputs)
        x = aqt_flax_attention.SelfAttentionAqt(
            hparams=self.hparams.attention,
            num_heads=self.num_heads,
            dtype=self.dtype,
            qkv_features=self.qkv_dim,
            attention_axis=(1, ),
            paxis_name='batch',
            train=self.train,
            quant_context=self.quant_context,
            causal_mask=False,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=self.attention_dropout_rate,
            deterministic=self.deterministic,
            decode=False,
            name='enc_self_att')(x,
                                 padding_mask=padding_mask,
                                 segmentation=inputs_segmentation)

        x = nn.Dropout(rate=self.dropout_rate)(
            x, deterministic=self.deterministic)
        x = x + inputs

        # MLP block.
        y = aqt_flax_layers.LayerNormAqt(dtype=self.dtype,
                                         hparams=self.hparams.layer_norm,
                                         quant_context=self.quant_context)(x)
        y = MlpBlock(
            mlp_dim=self.mlp_dim,
            hparams=self.hparams.mlp_block,
            # Inputs can be signed here because this block follows the attention layer.
            train=self.train,
            quant_context=self.quant_context,
            dtype=self.dtype,
            dropout_rate=self.dropout_rate,
            deterministic=self.deterministic,
            name='mlp_block')(y, padding_mask=padding_mask)
        out = x + y
        shape_utils.assert_shapes_equal(
            out.shape, (batch_size, sequence_length, channel_size))
        return out
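For comparison, a minimal unquantized Flax sketch of the same pre-LayerNorm encoder block, dropping the AQT layers, padding masks, and segmentation; this module is hypothetical and not part of the codebase above:

import flax.linen as nn

class PlainEncoder1DBlock(nn.Module):
    """Unquantized sketch: pre-LN self-attention and MLP sublayers with residuals."""
    num_heads: int
    mlp_dim: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, inputs, *, deterministic):
        # Self-attention sublayer with a residual connection.
        x = nn.LayerNorm()(inputs)
        x = nn.SelfAttention(num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate,
                             deterministic=deterministic)(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        x = x + inputs

        # MLP sublayer with a residual connection.
        y = nn.LayerNorm()(x)
        y = nn.Dense(self.mlp_dim)(y)
        y = nn.relu(y)
        y = nn.Dense(inputs.shape[-1])(y)
        y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)
        return x + y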