def test_activation(self):
        # Create a model that does not use an activation.
        no_activation_layer = dense_einsum.DenseEinsum(output_shape=64,
                                                       num_summed_dimensions=1,
                                                       activation=None)
        input_tensor = tf.keras.Input(shape=(None, 80))
        output_tensor = no_activation_layer(input_tensor)
        no_activation_model = tf.keras.Model(input_tensor, output_tensor)

        # Create a model that uses a softmax activation.
        activation_layer = dense_einsum.DenseEinsum(output_shape=64,
                                                    num_summed_dimensions=1,
                                                    activation="softmax")
        input_tensor = tf.keras.Input(shape=(None, 80))
        output_tensor = activation_layer(input_tensor)
        activation_model = tf.keras.Model(input_tensor, output_tensor)

        # Make sure the models' weights are identical.
        activation_model.set_weights(no_activation_model.get_weights())

        # Predict using each model on the same input data. The output should be
        # different, since one is using a softmax - even though the models' weights
        # are the same.
        input_values = 10 * np.random.random_sample((10, 4, 80))
        non_activated_data = no_activation_model.predict(input_values)
        activated_data = activation_model.predict(input_values)
        self.assertNotAllClose(activated_data, non_activated_data)
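As a sanity check, one could also assert the converse: with identical weights the two models differ only by the softmax, so applying it manually to the non-activated output should recover the activated one. A hedged sketch of two lines that could be appended to the test above (softmax over the last axis is the Keras default):

        # Hedged follow-up check: manually applying the softmax to the
        # non-activated output should reproduce the activated model's output.
        self.assertAllClose(tf.nn.softmax(non_activated_data, axis=-1),
                            activated_data)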
Example #2
  def __init__(self,
               num_heads,
               head_size,
               dropout_rate=0.0,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(MultiHeadAttention, self).__init__(**kwargs)
    self._num_heads = num_heads
    self._head_size = head_size
    self._dropout_rate = dropout_rate
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)

    self._query_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._head_size),
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="query")

    self._key_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._head_size),
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="key")

    self._value_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._head_size),
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="value")

    self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])

    self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
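For intuition about the query/key/value projections above: with output_shape=(num_heads, head_size), DenseEinsum maps [batch, seq, hidden] to [batch, seq, num_heads, head_size] (the einsum pattern "abc,cde->abde"). A minimal standalone sketch with made-up sizes; the import path is an assumption based on the TF Model Garden layout:

import tensorflow as tf
from official.nlp.modeling.layers import dense_einsum  # assumed import path

# 8 heads of size 64, hidden width 512 are arbitrary illustration values.
query_dense = dense_einsum.DenseEinsum(output_shape=(8, 64),
                                       num_summed_dimensions=1)
x = tf.zeros((2, 16, 512))        # [batch, seq, hidden]
print(query_dense(x).shape)       # (2, 16, 8, 64) -> [batch, seq, heads, head_size]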
Example #3
    def build(self, input_shape):
        # Self attention.
        self.self_attention = attention.CachedAttention(
            num_heads=self.num_attention_heads,
            key_size=self.attention_head_size,
            dropout=self.attention_probs_dropout_prob,
            kernel_initializer=self._kernel_initializer,
            name="self_attention")
        self.self_attention_output_dense = dense_einsum.DenseEinsum(
            output_shape=self.hidden_size,
            num_summed_dimensions=2,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            name="self_attention_output")
        self.self_attention_dropout = tf.keras.layers.Dropout(
            rate=self.hidden_dropout_prob)
        self.self_attention_layer_norm = (tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
        # Encoder-decoder attention.
        self.encdec_attention = self._cross_attention_cls(
            num_heads=self.num_attention_heads,
            key_size=self.attention_head_size,
            dropout=self.attention_probs_dropout_prob,
            output_shape=self.hidden_size,
            kernel_initializer=self._kernel_initializer,
            name="attention/encdec")

        self.encdec_attention_dropout = tf.keras.layers.Dropout(
            rate=self.hidden_dropout_prob)
        self.encdec_attention_layer_norm = (tf.keras.layers.LayerNormalization(
            name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))

        # Feed-forward projection.
        self.intermediate_dense = dense_einsum.DenseEinsum(
            output_shape=self.intermediate_size,
            activation=None,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            name="intermediate")
        self.intermediate_activation_layer = tf.keras.layers.Activation(
            self.intermediate_activation)
        self.output_dense = dense_einsum.DenseEinsum(
            output_shape=self.hidden_size,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            name="output")
        self.output_dropout = tf.keras.layers.Dropout(
            rate=self.hidden_dropout_prob)
        self.output_layer_norm = tf.keras.layers.LayerNormalization(
            name="output_layer_norm", axis=-1, epsilon=1e-12)
        super(TransformerDecoderLayer, self).build(input_shape)
Example #4
 def build(self, input_shape):
     if self._output_shape:
         output_shape = self._output_shape
     else:
         input_shape = tf.TensorShape(input_shape[0])
         output_shape = input_shape[-1]
     self._output_dense = dense_einsum.DenseEinsum(
         output_shape=output_shape,
         num_summed_dimensions=2,
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint,
         name="attention_output")
     self._pre_softmax_weight = self.add_weight(
         "pre_softmax_weight",
         shape=(self._num_heads, self._num_heads),
         initializer=self._kernel_initializer,
         regularizer=self._kernel_regularizer,
         constraint=self._kernel_constraint,
         dtype=self.dtype,
         trainable=True)
     self._post_softmax_weight = self.add_weight(
         "post_softmax_weight",
         shape=(self._num_heads, self._num_heads),
         initializer=self._kernel_initializer,
         regularizer=self._kernel_regularizer,
         constraint=self._kernel_constraint,
         dtype=self.dtype,
         trainable=True)
     super(TalkingHeadsAttention, self).build(input_shape)
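The two (num_heads, num_heads) weights created above implement the head-mixing step of talking-heads attention: attention scores are linearly mixed across the head dimension before and after the softmax. A standalone sketch of the idea, assuming scores of shape [batch, num_heads, query_len, key_len]; the exact layout used inside this layer is not shown in this snippet:

import tensorflow as tf

num_heads = 4
scores = tf.random.normal((2, num_heads, 16, 16))   # [batch, heads, query, key] (assumed layout)
pre_softmax_weight = tf.random.normal((num_heads, num_heads))
post_softmax_weight = tf.random.normal((num_heads, num_heads))

# Mix information across heads before the softmax...
mixed = tf.einsum("BNFT,NM->BMFT", scores, pre_softmax_weight)
probs = tf.nn.softmax(mixed, axis=-1)
# ...and again after it.
probs = tf.einsum("BNFT,NM->BMFT", probs, post_softmax_weight)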
Example #5
    def test_bias_term_can_be_disabled(self):
        # A layer created using the bias should have two weights.
        test_layer = dense_einsum.DenseEinsum(output_shape=64,
                                              num_summed_dimensions=1,
                                              use_bias=True)
        input_tensor = tf.keras.Input(shape=(None, 80))
        _ = test_layer(input_tensor)
        self.assertEqual(2, len(test_layer.get_weights()))

        # A layer created without the bias should have only one weight.
        test_layer = dense_einsum.DenseEinsum(output_shape=64,
                                              num_summed_dimensions=1,
                                              use_bias=False)
        input_tensor = tf.keras.Input(shape=(None, 80))
        _ = test_layer(input_tensor)
        self.assertEqual(1, len(test_layer.get_weights()))
Example #6
 def test_non_iterable_output_shape(self):
     test_layer = dense_einsum.DenseEinsum(output_shape=64,
                                           num_summed_dimensions=1)
     # Create a 3-dimensional input (the first dimension is implicit).
     input_tensor = tf.keras.Input(shape=(None, 80))
     _ = test_layer(input_tensor)
     self.assertEqual(test_layer._einsum_string, "abc,cd->abd")
     self.assertEqual(test_layer._kernel_shape, (80, 64))
Example #7
 def test_3D_einsum_with_one_bound_dimensions(self):
     test_layer = dense_einsum.DenseEinsum(output_shape=(64, 32),
                                           num_summed_dimensions=1)
     # Create a 3-dimensional input (the first dimension is implicit).
     input_tensor = tf.keras.Input(shape=(None, 80))
     _ = test_layer(input_tensor)
     self.assertEqual(test_layer._einsum_string, "abc,cde->abde")
     self.assertEqual(test_layer._kernel_shape, (80, 64, 32))
Example #8
 def test_with_explicit_initializer(self):
     test_layer = dense_einsum.DenseEinsum(
         output_shape=(64, ),
         num_summed_dimensions=2,
         kernel_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=0.02))
     # Create a 4-dimensional input (the first dimension is implicit).
     input_tensor = tf.keras.Input(shape=(None, 40, 80))
     _ = test_layer(input_tensor)
     self.assertEqual(test_layer._einsum_string, "abcd,cde->abe")
     self.assertEqual(test_layer._kernel_shape, (40, 80, 64))
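Taken together, the three tests above suggest the rule DenseEinsum appears to follow when building its einsum string: keep the leading free dimensions of the input, contract the last num_summed_dimensions input dimensions against the kernel, and append the output dimensions. A hypothetical helper (not part of the library) that reproduces the strings asserted above:

import string

def guess_einsum_string(input_rank, num_summed, output_rank):
    # Hypothetical, for illustration only: free dims, then summed dims,
    # contracted against a kernel of shape (summed dims..., output dims...).
    letters = iter(string.ascii_lowercase)
    free = "".join(next(letters) for _ in range(input_rank - num_summed))
    summed = "".join(next(letters) for _ in range(num_summed))
    out = "".join(next(letters) for _ in range(output_rank))
    return "%s%s,%s%s->%s%s" % (free, summed, summed, out, free, out)

assert guess_einsum_string(3, 1, 1) == "abc,cd->abd"      # test_non_iterable_output_shape
assert guess_einsum_string(3, 1, 2) == "abc,cde->abde"    # test_3D_einsum_with_one_bound_dimensions
assert guess_einsum_string(4, 2, 1) == "abcd,cde->abe"    # test_with_explicit_initializer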
Example #9
 def build(self, unused_input_shapes):
   self._query_dense = dense_einsum.DenseEinsum(
       output_shape=(self._num_heads, self._head_size),
       kernel_initializer=self._kernel_initializer,
       bias_initializer=self._bias_initializer,
       kernel_regularizer=self._kernel_regularizer,
       bias_regularizer=self._bias_regularizer,
       activity_regularizer=self._activity_regularizer,
       kernel_constraint=self._kernel_constraint,
       bias_constraint=self._bias_constraint,
       dtype=self.dtype,
       name="encdocatt_query")
   self._key_dense = dense_einsum.DenseEinsum(
       output_shape=(self._num_heads, self._head_size),
       kernel_initializer=self._kernel_initializer,
       bias_initializer=self._bias_initializer,
       kernel_regularizer=self._kernel_regularizer,
       bias_regularizer=self._bias_regularizer,
       activity_regularizer=self._activity_regularizer,
       kernel_constraint=self._kernel_constraint,
       bias_constraint=self._bias_constraint,
       dtype=self.dtype,
       name="encdocatt_key")
   super(VotingAttention, self).build(unused_input_shapes)
Example #10
 def build(self, input_shape):
     if self._output_shape:
         output_shape = self._output_shape
     else:
         input_shape = tf.TensorShape(input_shape[0])
         output_shape = input_shape[-1]
     self._output_dense = dense_einsum.DenseEinsum(
         output_shape=output_shape,
         num_summed_dimensions=2,
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint,
         name="attention_output")
     super(MultiHeadAttention, self).build(input_shape)
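The attention_output projection above is the inverse of the per-head query/key/value projections: with num_summed_dimensions=2 it contracts the [num_heads, head_size] dimensions back into a single hidden dimension ("abcd,cde->abe"). A sketch with made-up sizes, using the same assumed import path as before:

import tensorflow as tf
from official.nlp.modeling.layers import dense_einsum  # assumed import path

attn = tf.zeros((2, 16, 8, 64))     # [batch, seq, num_heads, head_size]
output_dense = dense_einsum.DenseEinsum(output_shape=512, num_summed_dimensions=2)
print(output_dense(attn).shape)     # (2, 16, 512) -> [batch, seq, hidden]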
Example #11
    def build(self, input_shape):
        input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
        input_tensor_shape = tf.TensorShape(input_tensor)
        if len(input_tensor_shape) != 3:
            raise ValueError(
                "TransformerLayer expects a three-dimensional input of "
                "shape [batch, sequence, width].")
        batch_size, sequence_length, hidden_size = input_tensor_shape

        if len(input_shape) == 2:
            mask_tensor_shape = tf.TensorShape(input_shape[1])
            expected_mask_tensor_shape = tf.TensorShape(
                [batch_size, sequence_length, sequence_length])
            if not expected_mask_tensor_shape.is_compatible_with(
                    mask_tensor_shape):
                raise ValueError(
                    "When passing a mask tensor to TransformerLayer, the "
                    "mask tensor must be of shape [batch, "
                    "sequence_length, sequence_length] (here %s). Got a "
                    "mask tensor of shape %s." %
                    (expected_mask_tensor_shape, mask_tensor_shape))
        if hidden_size % self._num_heads != 0:
            raise ValueError(
                "The input size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, self._num_heads))
        self._attention_head_size = int(hidden_size // self._num_heads)

        self._attention_layer = attention.MultiHeadAttention(
            num_heads=self._num_heads,
            key_size=self._attention_head_size,
            dropout=self._attention_dropout_rate,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="self_attention")
        self._attention_dropout = tf.keras.layers.Dropout(
            rate=self._dropout_rate)
        if self._use_layer_norm:
            # Use float32 in layernorm for numeric stability.
            # It is probably safe in mixed_float16, but we haven't validated this yet.
            self._attention_layer_norm = (tf.keras.layers.LayerNormalization(
                name="self_attention_layer_norm",
                axis=-1,
                epsilon=1e-12,
                dtype=tf.float32))
        self._intermediate_dense = dense_einsum.DenseEinsum(
            output_shape=self._intermediate_size,
            activation=None,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="intermediate")
        self._intermediate_activation_layer = tf.keras.layers.Activation(
            self._intermediate_activation)
        self._output_dense = dense_einsum.DenseEinsum(
            output_shape=hidden_size,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="output")
        self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
        if self._use_layer_norm:
            # Use float32 in layernorm for numeric stability.
            self._output_layer_norm = tf.keras.layers.LayerNormalization(
                name="output_layer_norm",
                axis=-1,
                epsilon=1e-12,
                dtype=tf.float32)

        self._rezero_a = self.add_weight(
            name="rezero_alpha",
            initializer=tf.keras.initializers.Zeros(),
            trainable=True,
            dtype=tf.float32)

        super(ReZeroTransformer, self).build(input_shape)
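For context on the scalar rezero_alpha weight created above: in a ReZero block the residual branch is scaled by this learnable alpha, which is initialized to zero so the layer starts out as the identity mapping. A hedged sketch of the usual formulation (the actual call() of this layer is not shown in this example):

import tensorflow as tf

rezero_alpha = tf.Variable(0.0, trainable=True, dtype=tf.float32)
inputs = tf.random.normal((2, 16, 512))
sublayer_output = tf.random.normal((2, 16, 512))  # e.g. attention or FFN output
# ReZero residual: equal to `inputs` at initialization because alpha == 0.
outputs = inputs + rezero_alpha * sublayer_output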
Example #12
  def build(self, input_shape):
    input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
    input_tensor_shape = tf.TensorShape(input_tensor)
    if len(input_tensor_shape) != 3:
      raise ValueError(
          "TransformerScaffold expects a three-dimensional input of "
          "shape [batch, sequence, width].")
    batch_size, sequence_length, hidden_size = input_tensor_shape

    if len(input_shape) == 2:
      mask_tensor_shape = tf.TensorShape(input_shape[1])
      expected_mask_tensor_shape = tf.TensorShape(
          [batch_size, sequence_length, sequence_length])
      if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
        raise ValueError("When passing a mask tensor to TransformerLayer, the "
                         "mask tensor must be of shape [batch, "
                         "sequence_length, sequence_length] (here %s). Got a "
                         "mask tensor of shape %s." %
                         (expected_mask_tensor_shape, mask_tensor_shape))
    if hidden_size % self._num_heads != 0:
      raise ValueError(
          "The input size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self._num_heads))
    self._attention_head_size = int(hidden_size // self._num_heads)

    if isinstance(self._attention_cls, tf.keras.layers.Layer):
      self._attention_layer = self._attention_cls
    else:
      if self._attention_cfg is None:
        attention_cfg = {
            "num_heads": self._num_heads,
            "head_size": self._attention_head_size,
            "dropout_rate": self._attention_dropout_rate,
            "kernel_initializer": self._kernel_initializer,
            "bias_initializer": self._bias_initializer,
            "kernel_regularizer": self._kernel_regularizer,
            "bias_regularizer": self._bias_regularizer,
            "activity_regularizer": self._activity_regularizer,
            "kernel_constraint": self._kernel_constraint,
            "bias_constraint": self._bias_constraint,
            "name": "self_attention"
        }
      else:
        attention_cfg = self._attention_cfg
      self._attention_layer = self._attention_cls(**attention_cfg)

    self._attention_output_dense = dense_einsum.DenseEinsum(
        output_shape=hidden_size,
        num_summed_dimensions=2,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="self_attention_output")
    self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
    self._attention_layer_norm = (
        tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm", axis=-1, epsilon=1e-12,
            dtype=tf.float32))
    self._intermediate_dense = dense_einsum.DenseEinsum(
        output_shape=self._intermediate_size,
        activation=self._intermediate_activation,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        dtype=tf.float32,  # This layer is always float32 for numeric stability.
        name="intermediate")
    self._output_dense = dense_einsum.DenseEinsum(
        output_shape=hidden_size,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="output")
    self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
    self._output_layer_norm = tf.keras.layers.LayerNormalization(
        name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)

    super(TransformerScaffold, self).build(input_shape)
Example #13
    def build(self, input_shape):
        input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
        input_tensor_shape = tf.TensorShape(input_tensor)
        if len(input_tensor_shape) != 3:
            raise ValueError(
                "TransformerLayer expects a three-dimensional input of "
                "shape [batch, sequence, width].")
        batch_size, sequence_length, hidden_size = input_tensor_shape

        if len(input_shape) == 2:
            mask_tensor_shape = tf.TensorShape(input_shape[1])
            expected_mask_tensor_shape = tf.TensorShape(
                [batch_size, sequence_length, sequence_length])
            if not expected_mask_tensor_shape.is_compatible_with(
                    mask_tensor_shape):
                raise ValueError(
                    "When passing a mask tensor to TransformerLayer, the "
                    "mask tensor must be of shape [batch, "
                    "sequence_length, sequence_length] (here %s). Got a "
                    "mask tensor of shape %s." %
                    (expected_mask_tensor_shape, mask_tensor_shape))
        if hidden_size % self._num_heads != 0:
            raise ValueError(
                "The input size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, self._num_heads))
        self._attention_head_size = int(hidden_size // self._num_heads)

        self._attention_layer = attention.MultiHeadAttention(
            num_heads=self._num_heads,
            key_size=self._attention_head_size,
            dropout=self._attention_dropout_rate,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="self_attention")
        # pylint: disable=protected-access
        self._attention_layer.build([input_tensor_shape] * 3)
        self._attention_output_dense = self._attention_layer._output_dense
        # pylint: enable=protected-access
        self._attention_dropout = tf.keras.layers.Dropout(
            rate=self._dropout_rate)
        # Use float32 in layernorm for numeric stability.
        # It is probably safe in mixed_float16, but we haven't validated this yet.
        self._attention_layer_norm = (tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32))
        self._intermediate_dense = dense_einsum.DenseEinsum(
            output_shape=self._intermediate_size,
            activation=None,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="intermediate")
        policy = tf.keras.mixed_precision.experimental.global_policy()
        if policy.name == "mixed_bfloat16":
            # bfloat16 causes BERT with the LAMB optimizer to not converge
            # as well, so we use float32.
            # TODO(b/154538392): Investigate this.
            policy = tf.float32
        self._intermediate_activation_layer = tf.keras.layers.Activation(
            self._intermediate_activation, dtype=policy)
        self._output_dense = dense_einsum.DenseEinsum(
            output_shape=hidden_size,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="output")
        self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
        # Use float32 in layernorm for numeric stability.
        self._output_layer_norm = tf.keras.layers.LayerNormalization(
            name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)

        super(Transformer, self).build(input_shape)
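The policy check above only matters once a global mixed-precision policy has been set. A minimal sketch using the TF 2.x experimental API that this snippet already relies on for global_policy(); exact availability of set_policy depends on the TensorFlow version:

import tensorflow as tf

# With this policy active, most layers compute in bfloat16 while the layer
# norms and the intermediate activation above are pinned to float32.
tf.keras.mixed_precision.experimental.set_policy("mixed_bfloat16")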
Example #14
    def build(self, input_shape):
        target_tensor_shape = tf.TensorShape(input_shape[0])
        if len(target_tensor_shape) != 3:
            raise ValueError(
                "TransformerLayer expects a three-dimensional input of "
                "shape [batch, sequence, width].")
        hidden_size = target_tensor_shape[2]
        if hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, self.num_attention_heads))
        self.attention_head_size = int(hidden_size / self.num_attention_heads)
        # Self attention.
        self.self_attention = attention.CachedAttention(
            num_heads=self.num_attention_heads,
            key_size=self.attention_head_size,
            dropout=self.attention_dropout_rate,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="self_attention")
        self.self_attention_output_dense = dense_einsum.DenseEinsum(
            output_shape=hidden_size,
            num_summed_dimensions=2,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="self_attention_output")
        self.self_attention_dropout = tf.keras.layers.Dropout(
            rate=self.dropout_rate)
        self.self_attention_layer_norm = (tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
        # Encoder-decoder attention.
        self.encdec_attention = self._cross_attention_cls(
            num_heads=self.num_attention_heads,
            key_size=self.attention_head_size,
            dropout=self.attention_dropout_rate,
            output_shape=hidden_size,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="attention/encdec")

        self.encdec_attention_dropout = tf.keras.layers.Dropout(
            rate=self.dropout_rate)
        self.encdec_attention_layer_norm = (tf.keras.layers.LayerNormalization(
            name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))

        # Feed-forward projection.
        self.intermediate_dense = dense_einsum.DenseEinsum(
            output_shape=self.intermediate_size,
            activation=None,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="intermediate")
        self.intermediate_activation_layer = tf.keras.layers.Activation(
            self.intermediate_activation)
        self.output_dense = dense_einsum.DenseEinsum(
            output_shape=hidden_size,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="output")
        self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
        self.output_layer_norm = tf.keras.layers.LayerNormalization(
            name="output_layer_norm", axis=-1, epsilon=1e-12)
        super(TransformerDecoderLayer, self).build(input_shape)