Example #1
    def __init__(
        self,
        num_heads,
        head_size,
        dropout_rate=0.0,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        name="attention",
        **kwargs,
    ):
        """
        Args:
            num_heads: Number of attention heads.
            head_size: Size of each attention head.
            dropout_rate: Dropout probability.
            kernel_initializer: Initializer for dense layer kernels.
            bias_initializer: Initializer for dense layer biases.
            kernel_regularizer: Regularizer for dense layer kernels.
            bias_regularizer: Regularizer for dense layer biases.
            activity_regularizer: Regularizer for dense layer activity.
            kernel_constraint: Constraint for dense layer kernels.
            bias_constraint: Constraint for dense layer biases.
        """
        kwargs["name"] = name
        super(GPT2Attention, self).__init__(**kwargs)
        self._num_heads = num_heads
        self._head_size = head_size
        self._dropout_rate = dropout_rate
        self._kernel_initializer = tf.keras.initializers.get(
            kernel_initializer)
        self._bias_initializer = tf.keras.initializers.get(bias_initializer)
        self._kernel_regularizer = tf.keras.regularizers.get(
            kernel_regularizer)
        self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self._activity_regularizer = tf.keras.regularizers.get(
            activity_regularizer)
        self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
        self._bias_constraint = tf.keras.constraints.get(bias_constraint)

        # GPT2 projects [batch x sequence x embedding] ---> [batch x sequence x embedding * 3];
        # the factor of 3 covers Q, K and V in a single fused projection.

        self._project_qkv = dense_einsum.DenseEinsum(
            output_shape=(3 * self._num_heads * self._head_size),
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="qkv",
        )

        self._masked_softmax = masked_softmax.MaskedSoftmax(
            mask_expansion_axes=[1])
        self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
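The comment above describes a single fused QKV projection. The following is a minimal, self-contained sketch (not taken from this module; a plain tf.keras.layers.Dense stands in for dense_einsum.DenseEinsum, and all sizes are hypothetical) of how such a fused output is typically split back into per-head query, key and value tensors:

import tensorflow as tf

# Hypothetical sizes, chosen only for illustration.
batch, seq, num_heads, head_size = 2, 8, 4, 16

x = tf.random.normal([batch, seq, num_heads * head_size])

# Stand-in for DenseEinsum(output_shape=3 * num_heads * head_size):
# one projection that emits Q, K and V in a single tensor.
qkv = tf.keras.layers.Dense(3 * num_heads * head_size, name="qkv")(x)

# Split the last axis into Q, K, V and separate the heads.
q, k, v = tf.split(qkv, num_or_size_splits=3, axis=-1)
q = tf.reshape(q, [batch, seq, num_heads, head_size])

print(qkv.shape, q.shape)  # (2, 8, 192) (2, 8, 4, 16)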
Example #2
    def __init__(
        self,
        num_heads,
        head_size,
        dropout_rate=0.0,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        name="attention",
        **kwargs,
    ):
        """
        Args:
            num_heads: Number of attention heads.
            head_size: Size of each attention head.
            dropout_rate: Dropout probability.
            kernel_initializer: Initializer for dense layer kernels.
            bias_initializer: Initializer for dense layer biases.
            kernel_regularizer: Regularizer for dense layer kernels.
            bias_regularizer: Regularizer for dense layer biases.
            activity_regularizer: Regularizer for dense layer activity.
            kernel_constraint: Constraint for dense layer kernels.
            bias_constraint: Constraint for dense layer biases.
        """
        kwargs["name"] = name
        super(BlockMultiHeadAttention, self).__init__(**kwargs)
        self._num_heads = num_heads
        self._head_size = head_size
        self._dropout_rate = dropout_rate
        self._kernel_initializer = tf.keras.initializers.get(
            kernel_initializer)
        self._bias_initializer = tf.keras.initializers.get(bias_initializer)
        self._kernel_regularizer = tf.keras.regularizers.get(
            kernel_regularizer)
        self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self._activity_regularizer = tf.keras.regularizers.get(
            activity_regularizer)
        self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
        self._bias_constraint = tf.keras.constraints.get(bias_constraint)

        self._query_dense = dense_einsum.DenseEinsum(
            output_shape=(self._num_heads, self._head_size),
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="query",
        )

        self._key_dense = dense_einsum.DenseEinsum(
            output_shape=(self._num_heads, self._head_size),
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="key",
        )

        self._value_dense = dense_einsum.DenseEinsum(
            output_shape=(self._num_heads, self._head_size),
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="value",
        )

        self._masked_softmax = masked_softmax.MaskedSoftmax(
            mask_expansion_axes=[1])
        self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
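Each of the three projections above is built with output_shape=(num_heads, head_size). A minimal sketch of the shape transformation this is assumed to perform (DenseEinsum itself is not shown; a raw tf.einsum with an explicit kernel stands in for it, and all sizes are hypothetical):

import tensorflow as tf

# Hypothetical sizes, chosen only for illustration.
batch, seq, hidden, num_heads, head_size = 2, 8, 64, 4, 16

x = tf.random.normal([batch, seq, hidden])

# Assumed equivalent of DenseEinsum(output_shape=(num_heads, head_size)):
# a [hidden, heads, head_size] kernel applied with an einsum, plus a bias.
kernel = tf.random.normal([hidden, num_heads, head_size])
bias = tf.zeros([num_heads, head_size])

query = tf.einsum("bsh,hnd->bsnd", x, kernel) + bias
print(query.shape)  # (2, 8, 4, 16): [batch, sequence, heads, head_size]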
Example #3
    def __init__(
        self,
        num_heads,
        head_size,
        bidirectional,
        create_positonal_embedding=True,
        positional_buckets=32,
        dropout_rate=0.0,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        use_bias=True,
        name="attention",
        is_cross_attention=False,
        **kwargs,
    ):
        """
        Args:
            num_heads: Number of attention heads.
            head_size: Size of each attention head.
            bidirectional: bool, whether attention is bidirectional
                           (affects how positions are masked/bucketed).
            create_positonal_embedding: bool, whether to create the relative
                                        positional embedding (T5 creates it
                                        only in the first layer).
            positional_buckets: Number of relative-position buckets.
            dropout_rate: Dropout probability.
            kernel_initializer: Initializer for dense layer kernels.
            bias_initializer: Initializer for dense layer biases.
            kernel_regularizer: Regularizer for dense layer kernels.
            bias_regularizer: Regularizer for dense layer biases.
            activity_regularizer: Regularizer for dense layer activity.
            kernel_constraint: Constraint for dense layer kernels.
            bias_constraint: Constraint for dense layer biases.
            use_bias: bool, whether the dense projections use a bias vector.
            name: Name of the layer.
            is_cross_attention: bool, whether this layer is used as cross-attention
                                (its relative position embedding is then created as
                                non-trainable).
        """
        kwargs["name"] = name
        super(T5Attention, self).__init__(**kwargs)
        self._num_heads = num_heads
        self._head_size = head_size
        self._bidirectional = bidirectional
        self._create_positonal_embedding = create_positonal_embedding
        self._positional_buckets = positional_buckets
        self._dropout_rate = dropout_rate
        self._kernel_initializer = tf.keras.initializers.get(
            kernel_initializer)
        self._bias_initializer = tf.keras.initializers.get(bias_initializer)
        self._kernel_regularizer = tf.keras.regularizers.get(
            kernel_regularizer)
        self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self._activity_regularizer = tf.keras.regularizers.get(
            activity_regularizer)
        self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
        self._bias_constraint = tf.keras.constraints.get(bias_constraint)
        self._use_bias = use_bias
        self._cross_layer = is_cross_attention

        if self._create_positonal_embedding:
            # self._relative_embedding = tf.keras.layers.Embedding(
            #             self._positional_buckets, self._num_heads, name="relative_attention_bias")

            if self._cross_layer:
                self._relative_embedding = OnDeviceEmbedding(
                    self._positional_buckets,
                    self._num_heads,
                    trainable=False,
                    name="relative_attention_bias",
                )
            else:
                self._relative_embedding = OnDeviceEmbedding(
                    self._positional_buckets,
                    self._num_heads,
                    name="relative_attention_bias",
                )

        self._query_dense = dense_einsum.DenseEinsum(
            output_shape=(self._num_heads, self._head_size),
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="query",
            use_bias=use_bias,
        )

        self._key_dense = dense_einsum.DenseEinsum(
            output_shape=(self._num_heads, self._head_size),
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="key",
            use_bias=use_bias,
        )

        self._value_dense = dense_einsum.DenseEinsum(
            output_shape=(self._num_heads, self._head_size),
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="value",
            use_bias=use_bias,
        )

        self._masked_softmax = masked_softmax.MaskedSoftmax(
            mask_expansion_axes=[1])
        self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
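The relative_attention_bias table above maps position buckets to one scalar per head. A minimal sketch of the shapes involved, assuming OnDeviceEmbedding behaves like a standard embedding lookup (the T5 bucketing function that produces the bucket ids is not reproduced here, and all sizes are hypothetical):

import tensorflow as tf

# Hypothetical sizes, chosen only for illustration.
num_heads, positional_buckets = 4, 32
q_len, k_len = 8, 8

# Stand-in for the relative_attention_bias table: [buckets, num_heads].
relative_embedding = tf.keras.layers.Embedding(positional_buckets, num_heads)

# Bucket ids for every (query, key) position pair; the T5 bucketing
# function that would produce these ids is not reproduced here.
bucket_ids = tf.random.uniform([q_len, k_len], maxval=positional_buckets, dtype=tf.int32)

bias = relative_embedding(bucket_ids)               # [q_len, k_len, num_heads]
bias = tf.transpose(bias, [2, 0, 1])[tf.newaxis]    # [1, num_heads, q_len, k_len]
print(bias.shape)  # (1, 4, 8, 8) -- broadcastable over the attention scores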
Example #4
    def build(self, input_shape):
        """
        Args:
            input_shape: [word_embeddings (3D), attention_mask(3D)]
        """
        input_tensor = input_shape[0]
        input_tensor_shape = tf.TensorShape(input_tensor)
        if len(input_tensor_shape) != 3:
            raise ValueError("TransformerBERT expects a three-dimensional input of " "shape [batch, sequence, width].")
        batch_size, sequence_length, hidden_size = input_tensor_shape

        if not self._attention_head_size:
            if hidden_size % self._num_heads != 0:
                raise ValueError(
                    "The input size (%d) is not a multiple of the number of attention "
                    "heads (%d)" % (hidden_size, self._num_heads)
                )
            self._attention_head_size = int(hidden_size // self._num_heads)

        if self._attention_type == "full_attention":

            self._attention_layer = MultiHeadAttention(
                num_heads=self._num_heads,
                head_size=self._attention_head_size,
                dropout_rate=self._attention_dropout_rate,
                kernel_initializer=self._kernel_initializer,
                bias_initializer=self._bias_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer,
                activity_regularizer=self._activity_regularizer,
                kernel_constraint=self._kernel_constraint,
                bias_constraint=self._bias_constraint,
                is_training=self.is_training,
                use_dropout=self.use_dropout,
                name="self_attention",
            )

        if self._attention_type == "block_attention":

            self._attention_layer = BlockMultiHeadAttention(
                num_heads=self._num_heads,
                head_size=self._attention_head_size,
                dropout_rate=self._attention_dropout_rate,
                kernel_initializer=self._kernel_initializer,
                bias_initializer=self._bias_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer,
                activity_regularizer=self._activity_regularizer,
                kernel_constraint=self._kernel_constraint,
                bias_constraint=self._bias_constraint,
                is_training=self.is_training,
                use_dropout=self.use_dropout,
                name="self_attention",
            )

        if self._attention_type == "bigbird":
            self._attention_layer = BigBirdAttention(
                num_heads=self._num_heads,
                head_size=self._attention_head_size,
                num_rand_blocks=self._num_rand_blocks,
                from_block_size=self._from_block_size,
                to_block_size=self._to_block_size,
                dropout_rate=self._attention_dropout_rate,
                kernel_initializer=self._kernel_initializer,
                bias_initializer=self._bias_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer,
                activity_regularizer=self._activity_regularizer,
                kernel_constraint=self._kernel_constraint,
                bias_constraint=self._bias_constraint,
                is_training=self.is_training,
                use_dropout=self.use_dropout,
                name="self_attention",
            )

        self._attention_output_dense = dense_einsum.DenseEinsum(
            output_shape=hidden_size,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="self_attention_output",
        )

        self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)

        self._attention_layer_norm = tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._layer_norm_epsilon,
            dtype=tf.float32,
        )

        # If we have cross attention inside encoder
        if self._cross_attention_inside_encoder:

            # Cross-attention does not need the decoding cache here, so it can
            # always behave as if is_training were True.
            self._cross_attention_layer = MultiHeadAttention(
                num_heads=self._num_heads,
                head_size=self._attention_head_size,
                dropout_rate=self._attention_dropout_rate,
                kernel_initializer=self._kernel_initializer,
                bias_initializer=self._bias_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer,
                activity_regularizer=self._activity_regularizer,
                kernel_constraint=self._kernel_constraint,
                bias_constraint=self._bias_constraint,
                is_training=self.is_training,
                use_dropout=self.use_dropout,
                name="cross_attention",
            )

            self._cross_attention_output_dense = dense_einsum.DenseEinsum(
                output_shape=hidden_size,
                kernel_initializer=self._kernel_initializer,
                bias_initializer=self._bias_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer,
                activity_regularizer=self._activity_regularizer,
                kernel_constraint=self._kernel_constraint,
                bias_constraint=self._bias_constraint,
                name="cross_attention_output",
            )

            self._cross_attention_layer_norm = tf.keras.layers.LayerNormalization(
                name="cross_attention_layer_norm",
                axis=-1,
                epsilon=self._layer_norm_epsilon,
                dtype=tf.float32,
            )

        if self._is_decoder:
            if self._share_attention_layers:
                self._cross_attention_layer = self._attention_layer
                self._cross_attention_output_dense = self._attention_output_dense
                self._cross_attention_layer_norm = self._attention_layer_norm

            else:
                # Cross-attention does not need the decoding cache here, so it can
                # always behave as if is_training were True.
                self._cross_attention_layer = MultiHeadAttention(
                    num_heads=self._num_heads,
                    head_size=self._attention_head_size,
                    dropout_rate=self._attention_dropout_rate,
                    kernel_initializer=self._kernel_initializer,
                    bias_initializer=self._bias_initializer,
                    kernel_regularizer=self._kernel_regularizer,
                    bias_regularizer=self._bias_regularizer,
                    activity_regularizer=self._activity_regularizer,
                    kernel_constraint=self._kernel_constraint,
                    bias_constraint=self._bias_constraint,
                    is_training=self.is_training,
                    use_dropout=self.use_dropout,
                    name="cross_attention",
                )

                self._cross_attention_output_dense = dense_einsum.DenseEinsum(
                    output_shape=hidden_size,
                    kernel_initializer=self._kernel_initializer,
                    bias_initializer=self._bias_initializer,
                    kernel_regularizer=self._kernel_regularizer,
                    bias_regularizer=self._bias_regularizer,
                    activity_regularizer=self._activity_regularizer,
                    kernel_constraint=self._kernel_constraint,
                    bias_constraint=self._bias_constraint,
                    name="cross_attention_output",
                )

                self._cross_attention_layer_norm = tf.keras.layers.LayerNormalization(
                    name="cross_attention_layer_norm",
                    axis=-1,
                    epsilon=self._layer_norm_epsilon,
                    dtype=tf.float32,
                )

        self._intermediate_dense = dense_einsum.DenseEinsum(
            output_shape=self._intermediate_size,
            activation=self._intermediate_activation,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            # This layer is always float32 for numeric stability.
            dtype=tf.float32,
            name="intermediate",
        )

        self._output_dense = dense_einsum.DenseEinsum(
            output_shape=hidden_size,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="output",
        )
        self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)

        # Use float32 in layernorm for numeric stability.
        self._output_layer_norm = tf.keras.layers.LayerNormalization(
            name="output_layer_norm",
            axis=-1,
            epsilon=self._layer_norm_epsilon,
            dtype=tf.float32,
        )
        super(TransformerBERT, self).build(input_shape)
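A small worked example of the head-size check at the top of build(), with hypothetical sizes:

# Hypothetical sizes, for illustration only.
hidden_size, num_heads = 768, 12

# build() raises ValueError unless hidden_size is a multiple of num_heads.
assert hidden_size % num_heads == 0
attention_head_size = hidden_size // num_heads
print(attention_head_size)  # 64 units per head (12 heads x 64 = 768)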
Example #5
    def build(self, input_shape):
        """
        Args:
            input_shape: [word_embeddings (3D), attention_mask(3D)]
        """
        input_tensor = input_shape[0]
        input_tensor_shape = tf.TensorShape(input_tensor)
        if len(input_tensor_shape) != 3:
            raise ValueError(
                "TransformerGPT2 expects a three-dimensional input of "
                "shape [batch, sequence, width].")
        batch_size, sequence_length, hidden_size = input_tensor_shape

        if hidden_size % self._num_heads != 0:
            raise ValueError(
                "The input size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, self._num_heads))
        self._attention_head_size = int(hidden_size // self._num_heads)

        self._pre_attention_norm = layer_normalization.GPT2LayerNormalization(
            name="ln_1/layer_norm", axis=-1, epsilon=1e-5, dtype=tf.float32)

        self._attention_layer = GPT2Attention(
            num_heads=self._num_heads,
            head_size=self._attention_head_size,
            dropout_rate=self._attention_dropout_rate,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            is_training=self.is_training,
            use_dropout=self.use_dropout,
            name="self_attention",
        )

        self._attention_output_dense = dense_einsum.DenseEinsum(
            output_shape=hidden_size,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="self_attention_output",
        )
        self._attention_dropout = tf.keras.layers.Dropout(
            rate=self._dropout_rate)
        self._attention_layer_norm = layer_normalization.GPT2LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)

        # If we have cross attention inside encoder
        if self._cross_attention_inside_encoder:

            self._pre_cross_attention_norm = layer_normalization.GPT2LayerNormalization(
                name="ln_1/pre_cross_layer_norm",
                axis=-1,
                epsilon=1e-5,
                dtype=tf.float32,
            )
            # Cross-attention does not need the decoding cache here, so it can
            # always behave as if is_training were True.
            self._cross_attention_layer = GPT2Attention(
                num_heads=self._num_heads,
                head_size=self._attention_head_size,
                dropout_rate=self._attention_dropout_rate,
                kernel_initializer=self._kernel_initializer,
                bias_initializer=self._bias_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer,
                activity_regularizer=self._activity_regularizer,
                kernel_constraint=self._kernel_constraint,
                bias_constraint=self._bias_constraint,
                is_training=self.is_training,
                use_dropout=self.use_dropout,
                name="cross_attention",
            )

            self._cross_attention_output_dense = dense_einsum.DenseEinsum(
                output_shape=hidden_size,
                kernel_initializer=self._kernel_initializer,
                bias_initializer=self._bias_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer,
                activity_regularizer=self._activity_regularizer,
                kernel_constraint=self._kernel_constraint,
                bias_constraint=self._bias_constraint,
                name="cross_attention_output",
            )
        if self._is_decoder:
            if self._share_attention_layers:
                self._pre_cross_attention_norm = self._pre_attention_norm
                self._cross_attention_layer = self._attention_layer
                self._cross_attention_output_dense = self._attention_output_dense
            else:
                self._pre_cross_attention_norm = layer_normalization.GPT2LayerNormalization(
                    name="ln_1/pre_cross_layer_norm",
                    axis=-1,
                    epsilon=1e-5,
                    dtype=tf.float32,
                )
                # Cross-attention does not need the decoding cache here, so it can
                # always behave as if is_training were True.
                self._cross_attention_layer = GPT2Attention(
                    num_heads=self._num_heads,
                    head_size=self._attention_head_size,
                    dropout_rate=self._attention_dropout_rate,
                    kernel_initializer=self._kernel_initializer,
                    bias_initializer=self._bias_initializer,
                    kernel_regularizer=self._kernel_regularizer,
                    bias_regularizer=self._bias_regularizer,
                    activity_regularizer=self._activity_regularizer,
                    kernel_constraint=self._kernel_constraint,
                    bias_constraint=self._bias_constraint,
                    is_training=self.is_training,
                    use_dropout=self.use_dropout,
                    name="cross_attention",
                )

                self._cross_attention_output_dense = dense_einsum.DenseEinsum(
                    output_shape=hidden_size,
                    kernel_initializer=self._kernel_initializer,
                    bias_initializer=self._bias_initializer,
                    kernel_regularizer=self._kernel_regularizer,
                    bias_regularizer=self._bias_regularizer,
                    activity_regularizer=self._activity_regularizer,
                    kernel_constraint=self._kernel_constraint,
                    bias_constraint=self._bias_constraint,
                    name="cross_attention_output",
                )

        self._intermediate_dense = dense_einsum.DenseEinsum(
            output_shape=self._intermediate_size,
            activation=self._intermediate_activation,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            # This layer is always float32 for numeric stability.
            dtype=tf.float32,
            name="intermediate",
        )
        self._output_dense = dense_einsum.DenseEinsum(
            output_shape=hidden_size,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="output",
        )
        self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
        super(TransformerGPT2, self).build(input_shape)
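The call() method is not included in these excerpts; the sub-layers built above (a pre-attention norm, self-attention, a second norm, and an MLP) match the standard GPT-2 pre-norm block. The following is only an assumed sketch of how such sub-layers typically compose, with illustrative names and the attention mask omitted; it is not this module's actual call():

def gpt2_block_sketch(x, pre_norm, attention, attention_output_dense,
                      attention_dropout, attention_norm,
                      intermediate_dense, output_dense, output_dropout):
    # Assumed GPT-2 style pre-norm ordering; not taken from TransformerGPT2.
    h = pre_norm(x)                              # ln_1 before self-attention
    h = attention_output_dense(attention(h))     # attention + output projection
    x = x + attention_dropout(h)                 # first residual connection
    h = attention_norm(x)                        # norm before the feed-forward block
    h = output_dense(intermediate_dense(h))      # intermediate -> output projection
    return x + output_dropout(h)                 # second residual connection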
Example #6
    def build(self, input_shape):
        """
        Args:
            input_shape: [word_embeddings (3D), attention_mask(3D)]
        """
        input_tensor = input_shape[0]
        input_tensor_shape = tf.TensorShape(input_tensor)
        if len(input_tensor_shape) != 3:
            raise ValueError(
                "TransformerT5 expects a three-dimensional input of "
                "shape [batch, sequence, width].")
        batch_size, sequence_length, hidden_size = input_tensor_shape

        if not self._attention_head_size:
            if hidden_size % self._num_heads != 0:
                raise ValueError(
                    "The input size (%d) is not a multiple of the number of attention "
                    "heads (%d)" % (hidden_size, self._num_heads))
            self._attention_head_size = int(hidden_size // self._num_heads)

        self._pre_attention_norm = T5LayerNormalization(
            name="pre_attention_norm",
            axis=-1,
            epsilon=self._layer_norm_epsilon,
            dtype=tf.float32)

        self._attention_layer = T5Attention(
            num_heads=self._num_heads,
            head_size=self._attention_head_size,
            bidirectional=self._bidirectional,
            create_positonal_embedding=self._create_positonal_embedding,
            positional_buckets=self._positional_buckets,
            dropout_rate=self._attention_dropout_rate,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            is_training=self.is_training,
            use_dropout=self.use_dropout,
            name="self_attention",
            use_bias=self._use_bias,
        )

        self._attention_output_dense = dense_einsum.DenseEinsum(
            output_shape=hidden_size,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="self_attention_output",
            use_bias=self._use_bias,
        )

        if self._is_decoder:

            if self._share_attention_layers:
                self._pre_cross_attention_norm = self._pre_attention_norm
                self._cross_attention_layer = self._attention_layer
                self._cross_attention_output_dense = self._attention_output_dense
            else:
                self._pre_cross_attention_norm = T5LayerNormalization(
                    name="pre_cross_attention_norm",
                    axis=-1,
                    epsilon=self._layer_norm_epsilon,
                    dtype=tf.float32,
                )
                # The cross-attention layer should always work in one mode
                # (`is_training = True`), because encoder_output is fixed and there
                # is technically nothing to concatenate (no (K, V) cache as in GPT2).
                self._cross_attention_layer = T5Attention(
                    num_heads=self._num_heads,
                    head_size=self._attention_head_size,
                    bidirectional=self._bidirectional,
                    create_positonal_embedding=self._create_positonal_embedding,
                    positional_buckets=self._positional_buckets,
                    dropout_rate=self._attention_dropout_rate,
                    kernel_initializer=self._kernel_initializer,
                    bias_initializer=self._bias_initializer,
                    kernel_regularizer=self._kernel_regularizer,
                    bias_regularizer=self._bias_regularizer,
                    activity_regularizer=self._activity_regularizer,
                    kernel_constraint=self._kernel_constraint,
                    bias_constraint=self._bias_constraint,
                    is_training=self.is_training,
                    use_dropout=self.use_dropout,
                    name="cross_attention",
                    use_bias=self._use_bias,
                    is_cross_attention=True,  # hard code
                )

                self._cross_attention_output_dense = dense_einsum.DenseEinsum(
                    output_shape=hidden_size,
                    kernel_initializer=self._kernel_initializer,
                    bias_initializer=self._bias_initializer,
                    kernel_regularizer=self._kernel_regularizer,
                    bias_regularizer=self._bias_regularizer,
                    activity_regularizer=self._activity_regularizer,
                    kernel_constraint=self._kernel_constraint,
                    bias_constraint=self._bias_constraint,
                    name="cross_attention_output",
                    use_bias=self._use_bias,
                )

        self._attention_dropout = tf.keras.layers.Dropout(
            rate=self._dropout_rate)
        self._attention_layer_norm = T5LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._layer_norm_epsilon,
            dtype=tf.float32,
        )

        self._intermediate_dense = dense_einsum.DenseEinsum(
            output_shape=self._intermediate_size,
            activation=self._intermediate_activation,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            # This layer is always float32 for numeric stability.
            dtype=tf.float32,
            name="intermediate",
            use_bias=self._use_bias,
        )

        self._output_dense = dense_einsum.DenseEinsum(
            output_shape=hidden_size,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="output",
            use_bias=self._use_bias,
        )
        self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
        super(TransformerT5, self).build(input_shape)
Example #7
    def __init__(
        self,
        num_heads,
        head_size,
        num_rand_blocks=3,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        from_block_size=64,
        to_block_size=64,
        dropout_rate=0.0,
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        name="attention",
        **kwargs,
    ):
        """Constructor for a multi-headed attention layer.

        Args:
          attention_type: Type of attention, needs to be one of ['original_full',
            'simulated_sparse', 'block_sparse'].
          num_attention_heads: (optional) int. Number of attention heads.
          num_rand_blocks: (optional) int. Number of random chunks per row.
          size_per_head: (optional) int. Size of each attention head.
          initializer_range: (optional) float. Range of the weight initializer.
          from_block_size: (optional) int. size of block in from sequence.
          to_block_size: (optional) int. size of block in to sequence.
          attention_probs_dropout_prob: (optional) float. Dropout probability of the
            attention probabilities.
          use_bias: Whether the layer uses a bias vector.
          seed: (Optional) int. Reandom seed for generating random mask.
          query_act: (optional) Activation function for the query transform.
          key_act: (optional) Activation function for the key transform.
          value_act: (optional) Activation function for the value transform.
          name: The name scope of this layer.
          **kwargs: others
        """
        kwargs["name"] = name
        super(BigBirdAttention, self).__init__(**kwargs)

        self._dropout_rate = dropout_rate
        self._kernel_initializer = tf.keras.initializers.get(
            kernel_initializer)
        self._bias_initializer = tf.keras.initializers.get(bias_initializer)
        self._kernel_regularizer = tf.keras.regularizers.get(
            kernel_regularizer)
        self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self._activity_regularizer = tf.keras.regularizers.get(
            activity_regularizer)
        self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
        self._bias_constraint = tf.keras.constraints.get(bias_constraint)

        self.from_block_size = from_block_size
        self.to_block_size = to_block_size
        self._num_heads = num_heads
        self.num_rand_blocks = num_rand_blocks
        self._head_size = head_size

        self._query_dense = dense_einsum.DenseEinsum(
            output_shape=(self._num_heads, self._head_size),
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="query",
        )

        self._key_dense = dense_einsum.DenseEinsum(
            output_shape=(self._num_heads, self._head_size),
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="key",
        )

        self._value_dense = dense_einsum.DenseEinsum(
            output_shape=(self._num_heads, self._head_size),
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="value",
        )
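A small illustration of the block bookkeeping implied by from_block_size, to_block_size and num_rand_blocks, with hypothetical sequence lengths (the block-sparse attention computation itself is not shown):

# Hypothetical sizes, for illustration only.
from_seq_len, to_seq_len = 512, 512
from_block_size, to_block_size = 64, 64
num_rand_blocks = 3

# Sequence lengths are assumed to be multiples of the block sizes.
assert from_seq_len % from_block_size == 0 and to_seq_len % to_block_size == 0
num_from_blocks = from_seq_len // from_block_size   # 8 query-side blocks
num_to_blocks = to_seq_len // to_block_size         # 8 key-side blocks

# In BigBird, each query block attends to its own and neighbouring blocks,
# a few global blocks, and num_rand_blocks randomly chosen key blocks,
# rather than to the full to-sequence.
print(num_from_blocks, num_to_blocks, num_rand_blocks)  # 8 8 3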