    def call_decoder_predict(self, inputs):
        """Inputs will be passed to this method when is_training = False and is_decoder = True. # noqa
        Caching the past `key` and `value` tensors of the decoder is
        necessary while predicting, to make inference/NLG
        faster in case of AutoRegressive Decoding.

        """

        input_ids = inputs["input_ids"]
        encoder_hidden_state = inputs["encoder_hidden_states"]
        decoder_encoder_mask = inputs["decoder_encoder_mask"]
        all_cache_key = inputs["all_cache_key"]
        all_cache_value = inputs["all_cache_value"]
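        # Cache layout (matches the stacking at the end of this method):
        # `all_cache_key` / `all_cache_value` are
        # num_layers x batch_size x num_heads x cached_length x (hidden_dimension/num_heads).
        # An all-zero cache signals step 0 (nothing decoded yet).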

        # The decoder does not need an input mask here

        # # When `mask_mode` is `causal` , input_mask is not required
        # if self.mask_mode in ['user_defined']:
        #     input_mask     = inputs['input_mask']

        if self.use_type_embeddings:
            input_type_ids = inputs["input_type_ids"]

        # cache_length = tf.constant(0, dtype=tf.int32)

        def step_0_cache_length(_):
            return tf.constant(0, dtype=tf.int32)

        def step_other_cache_length(all_cache_key):
            past_length = tf.shape(all_cache_key)[3]
            # Subtract 1 so that at iteration 2 the positional
            # index is 1 (not 2), and so on
            sequence_length = tf.shape(input_ids)[1] + past_length - 1
            return sequence_length

        sequence_length = tf.cond(
            tf.equal(tf.reduce_sum(all_cache_key), 0),
            lambda: step_0_cache_length(all_cache_key),
            lambda: step_other_cache_length(all_cache_key),
        )
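        # Worked example: at step 0 the cache sums to 0, so the positional index is 0.
        # At a later step the cache holds `past_length` entries on axis 3 and
        # `input_ids` typically carries one new token, so the index is 1 + past_length - 1 = past_length.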

        all_cache_key = [
            tf.squeeze(item, axis=0)
            for item in tf.split(all_cache_key, num_or_size_splits=self.num_hidden_layers, axis=0)
        ]
        all_cache_value = [
            tf.squeeze(item, axis=0)
            for item in tf.split(all_cache_value, num_or_size_splits=self.num_hidden_layers, axis=0)
        ]
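        # After the split, each per-layer cache entry has shape
        # batch_size x num_heads x cached_length x (hidden_dimension/num_heads)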

        # If decoder is not sharing embeddings
        word_embeddings = self._embedding_layer(input_ids)
        embeddings = word_embeddings
        # Add word_embeddings + position_embeddings + type_embeddings
        if self.use_type_embeddings:
            type_embeddings = self._type_embeddings(input_type_ids)
            embeddings = embeddings + type_embeddings
        if self.use_positonal_embeddings:
            positional_embeddings = self._position_embedding_layer(sequence_length)
            # Make it 3D for the sum (the decoder decodes one token at a time)
            positional_embeddings = tf.expand_dims(positional_embeddings, 0)
            embeddings = embeddings + positional_embeddings
        # Norm + dropout
        embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)

        # Initialize `attention_mask` as empty list
        attention_mask = []
        if self.mask_mode == "causal":
            attention_mask = CausalMask()(embeddings)
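        # CausalMask yields a lower-triangular attention mask, so position i
        # can attend only to positions <= i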

        decoder_outputs = []
        for i in range(self.num_hidden_layers):
            layer = self._transformer_layers[i]
            # Fetching
            cache_value = all_cache_value[i]
            cache_key = all_cache_key[i]

            embeddings, cache_key, cache_value = layer(
                [
                    embeddings,
                    attention_mask,
                    encoder_hidden_state,
                    decoder_encoder_mask,
                ],
                cache_key=cache_key,
                cache_value=cache_value,
            )

            # Updating
            all_cache_key[i] = cache_key
            all_cache_value[i] = cache_value

            decoder_outputs.append(embeddings)

        # Stack all layers key and value together
        # num_layers x batch_size x num_heads x sequence_length x (hidden_dimension/num_heads)
        all_cache_key = tf.stack(all_cache_key, axis=0, name="all_cache_key")
        all_cache_value = tf.stack(all_cache_value, axis=0, name="all_cache_value")

        # batch_size x sequence_length x embedding_size
        decoder_outputs[-1] = self._last_layer_norm(decoder_outputs[-1])
        token_embeddings = decoder_outputs[-1]

        # token --> vocab ( batch_size x sequence_length x vocab_size)
        token_logits = tf.matmul(
            token_embeddings,
            self.get_embedding_table(),
            transpose_b=True,
            name="token_logits",
        )
        last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_logits)
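        # Logits use weights tied to the input embedding table
        # (token_embeddings @ embedding_table^T); `last_token_logits` keeps only
        # the final position, shape batch_size x vocab_size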

        return {
            "all_cache_key": all_cache_key,
            "all_cache_value": all_cache_value,
            "token_embeddings": token_embeddings,
            "last_token_logits": last_token_logits,
        }

    def call_decoder(self, inputs):
        """Forward Pass for Decoder

        Args:
            inputs: dict
            inputs is a dict with keys [`input_ids`, `input_mask`,
            `input_type_ids`, `encoder_hidden_states`, `decoder_encoder_mask`].
            These keys might or might not be present based on `mask_mode` and other criteria

        """
        input_ids = inputs["input_ids"]
        encoder_output = inputs["encoder_hidden_states"]
        decoder_encoder_mask = inputs["decoder_encoder_mask"]

        if self.mask_mode in ["user_defined"]:
            input_mask = inputs["input_mask"]

        if self.use_type_embeddings:
            input_type_ids = inputs["input_type_ids"]

        sequence_length = tf.shape(input_ids)[1]

        # If decoder is not sharing embeddings
        word_embeddings = self._embedding_layer(input_ids)
        embeddings = word_embeddings
        # Add word_embeddings + position_embeddings + type_embeddings
        if self.use_type_embeddings:
            type_embeddings = self._type_embeddings(input_type_ids)
            embeddings = embeddings + type_embeddings
        if self.use_positonal_embeddings:
            positional_embeddings = self._position_embedding_layer(tf.range(sequence_length))
            embeddings = embeddings + positional_embeddings
        # Norm + dropout
        embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)
        # Initialize `attention_mask` as empty list
        attention_mask = []

        if self.mask_mode == "user_defined":
            attention_mask = SelfAttentionMask()([embeddings, input_mask])
        if self.mask_mode == "causal":
            attention_mask = CausalMask()(embeddings)
        decoder_outputs = []
        for i in range(self.num_hidden_layers):
            layer = self._transformer_layers[i]
            embeddings, _key, _value = layer([embeddings, attention_mask, encoder_output, decoder_encoder_mask])

            decoder_outputs.append(embeddings)

        # batch_size x sequence_length x embedding_size
        decoder_outputs[-1] = self._last_layer_norm(decoder_outputs[-1])
        token_embeddings = decoder_outputs[-1]

        # token --> vocab ( batch_size x sequence_length x vocab_size)
        token_logits = tf.matmul(
            token_embeddings,
            self.get_embedding_table(),
            transpose_b=True,
            name="token_logits",
        )
        last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_logits)

        result = {
            "token_embeddings": token_embeddings,
            "token_logits": token_logits,
            "last_token_logits": last_token_logits,
        }

        if self.return_all_layer_token_embeddings:
            result["all_layer_token_embeddings"] = decoder_outputs
        return result

    def call_cross_attention_encoder(self, inputs):
        """Forward pass (training) for the shared cross-attention encoder:
        the same transformer layers are run first in `encoder` mode and then
        in `decoder` mode with cross attention over the encoder outputs.

        Args:
            inputs: dict with `encoder_input_ids` and `decoder_input_ids`, plus
            `encoder_input_type_ids`/`decoder_input_type_ids` and `encoder_input_mask`
            depending on `use_type_embeddings` and `mask_mode`.
        """
        encoder_input_ids = inputs["encoder_input_ids"]
        decoder_input_ids = inputs["decoder_input_ids"]
        encoder_input_type_ids = None
        decoder_input_type_ids = None

        if self.use_type_embeddings:
            encoder_input_type_ids = inputs["encoder_input_type_ids"]
            decoder_input_type_ids = inputs["decoder_input_type_ids"]
        encoder_input_mask = None
        if self.mask_mode in ["user_defined", "prefix"]:
            encoder_input_mask = inputs["encoder_input_mask"]

        def get_embeddings(input_ids, input_type_ids):
            """Compute input embeddings (word + type + position) for the encoder
            as well as the decoder.

            Args:
                input_ids: batch_size x sequence_length token ids
                input_type_ids: matching segment ids (may be None)
            """

            embeddings = self._embedding_layer(input_ids)
            sequence_length = tf.shape(input_ids)[1]
            # Add word_embeddings + position_embeddings + type_embeddings
            if self.use_type_embeddings:
                type_embeddings = self._type_embeddings(input_type_ids)
                embeddings = embeddings + type_embeddings
            if self.use_positonal_embeddings:
                positional_embeddings = self._position_embedding_layer(tf.range(sequence_length))
                embeddings = embeddings + positional_embeddings
            # Norm + dropout
            embeddings = self._embedding_norm(embeddings)
            embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)
            return embeddings

        encoder_embeddings = get_embeddings(encoder_input_ids, encoder_input_type_ids)
        decoder_embeddings = get_embeddings(decoder_input_ids, decoder_input_type_ids)

        # Initialize `encoder_attention_mask` as empty list
        encoder_attention_mask = []
        if self.mask_mode == "user_defined":
            encoder_attention_mask = SelfAttentionMask()([encoder_embeddings, encoder_input_mask])
        if self.mask_mode == "prefix":
            encoder_attention_mask = tf.map_fn(prefix_mask, encoder_input_mask, dtype=tf.float32)
        if self.mask_mode == "causal":
            encoder_attention_mask = CausalMask()(encoder_embeddings)

        # Decoder self-attention mask is always causal
        decoder_attention_mask = CausalMask()(decoder_embeddings)
        decoder_encoder_mask = CrossAttentionMask()([decoder_input_ids, encoder_input_mask])
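        # CrossAttentionMask builds the decoder-to-encoder mask from `decoder_input_ids`
        # and `encoder_input_mask`, presumably letting every decoder position attend to
        # the unmasked encoder positions; decoder self-attention (above) is always causal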
        decoder_outputs = []
        encoder_outputs = []

        # Encoder Layer
        for i in range(self.num_hidden_layers):
            layer = self._transformer_layers[i]
            encoder_embeddings, _, _ = layer(
                [
                    encoder_embeddings,
                    encoder_attention_mask,
                    decoder_encoder_mask,  # dummy decoder_encoder_mask
                    encoder_embeddings,  # dummy encoder_hidden_states
                ],
                mode="encoder",
            )
            encoder_outputs.append(encoder_embeddings)

        # Decoder Layer
        encoder_hidden_states = encoder_outputs[-1]
        for i in range(self.num_hidden_layers):
            layer = self._transformer_layers[i]
            decoder_embeddings, _, _ = layer(
                [decoder_embeddings, decoder_attention_mask, decoder_encoder_mask, encoder_hidden_states],
                mode="decoder",
            )
            decoder_outputs.append(decoder_embeddings)

        decoder_outputs[-1] = self._last_layer_norm(decoder_outputs[-1])
        # First word of last layer outputs [CLS]
        # cls_token_tensor = tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(decoder_outputs[-1])
        # batch_size x embedding_size
        # cls_output = self._pooler_layer(cls_token_tensor)
        # batch_size x sequence_length x embedding_size
        token_embeddings = decoder_outputs[-1]

        # MLM Projection
        if self.use_mlm_layer:
            token_embeddings = self.mlm_layer(token_embeddings)
            # token --> vocab ( batch_size x sequence_length x vocab_size)
            token_logits = (
                tf.matmul(
                    token_embeddings,
                    self.get_embedding_table(),
                    transpose_b=True,
                    name="token_logits",
                )
                + self._last_logits_bias
            )
        else:

            # token --> vocab ( batch_size x sequence_length x vocab_size)
            token_logits = tf.matmul(
                token_embeddings,
                self.get_embedding_table(),
                transpose_b=True,
                name="token_logits",
            )

        last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_logits)

        result = {
            "token_embeddings": token_embeddings,
            "token_logits": token_logits,
            "last_token_logits": last_token_logits,
        }

        if self.return_all_layer_token_embeddings:
            result["all_layer_token_embeddings"] = decoder_outputs
        return result

    def call_cross_attention_encoder_predict(self, inputs):
        """Forward pass (prediction) for the shared cross-attention encoder:
        encoder hidden states and the decoder `key`/`value` caches are reused
        across auto-regressive decoding steps.

        Args:
            inputs: dict with encoder/decoder input ids (plus type ids and mask when
            enabled), `encoder_hidden_states`, `decoder_all_cache_key` and
            `decoder_all_cache_value` carried over from the previous step.
        """

        encoder_input_ids = inputs["encoder_input_ids"]
        decoder_input_ids = inputs["decoder_input_ids"]
        encoder_input_type_ids = None
        decoder_input_type_ids = None

        if self.use_type_embeddings:
            encoder_input_type_ids = inputs["encoder_input_type_ids"]
            decoder_input_type_ids = inputs["decoder_input_type_ids"]
        encoder_input_mask = None
        if self.mask_mode in ["user_defined", "prefix"]:
            encoder_input_mask = inputs["encoder_input_mask"]

        # self.num_hidden_layers, batch_size, sequence_length, embedding_dimension
        encoder_hidden_states = inputs["encoder_hidden_states"]
        all_cache_key = inputs["decoder_all_cache_key"]
        all_cache_value = inputs["decoder_all_cache_value"]

        def get_encoder_embeddings(input_ids, input_type_ids):
            """Compute input embeddings (word + type + position).

            Used for the encoder, and for the decoder at step 0.

            Args:
                input_ids: batch_size x sequence_length token ids
                input_type_ids: matching segment ids (may be None)
            """
            embeddings = self._embedding_layer(input_ids)
            sequence_length = tf.shape(input_ids)[1]
            # Add word_embeddings + position_embeddings + type_embeddings
            if self.use_type_embeddings:
                type_embeddings = self._type_embeddings(input_type_ids)
                embeddings = embeddings + type_embeddings
            if self.use_positonal_embeddings:
                positional_embeddings = self._position_embedding_layer(tf.range(sequence_length))
                embeddings = embeddings + positional_embeddings
            # Norm + dropout
            embeddings = self._embedding_norm(embeddings)
            embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)
            return embeddings

        # This function is slightly different from the one above:
        # we do not use tf.range(sequence_length) here, because from step 1 onwards
        # we decode word by word and only need the positional index of the new token.
        # So we use all_cache_key to recover the past_length.

        def get_decoder_embeddings_step_other(input_ids, input_type_ids):
            """Compute decoder embeddings from step 1 onwards, deriving the
            positional index of the current token from the cache.

            Args:
                input_ids: token ids of the current step (typically a single token)
                input_type_ids: matching segment ids (may be None)
            """

            def step_0_cache_length(_):
                return tf.constant(0, dtype=tf.int32)

            def step_other_cache_length(all_cache_key):
                past_length = tf.shape(all_cache_key)[3]
                # Subtract 1 so that at iteration 2 the positional
                # index is 1 (not 2), and so on
                sequence_length = tf.shape(input_ids)[1] + past_length - 1
                return sequence_length

            sequence_length = tf.cond(
                tf.equal(tf.reduce_sum(all_cache_key), 0),
                lambda: step_0_cache_length(all_cache_key),
                lambda: step_other_cache_length(all_cache_key),
            )

            embeddings = self._embedding_layer(input_ids)
            # Add word_embeddings + position_embeddings + type_embeddings
            if self.use_type_embeddings:
                type_embeddings = self._type_embeddings(input_type_ids)
                embeddings = embeddings + type_embeddings
            if self.use_positonal_embeddings:
                positional_embeddings = self._position_embedding_layer(sequence_length)
                # Make it 3D for the sum (the decoder decodes one token at a time)
                positional_embeddings = tf.expand_dims(positional_embeddings, 0)
                embeddings = embeddings + positional_embeddings
            # Norm + dropout
            embeddings = self._embedding_norm(embeddings)
            embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)
            return embeddings

        # Encoder embeddings stay the same throughout the decoding process,
        # so we have to calculate them only once.
        # We check if the cache sums to 0: if it is 0, this is step 0,
        # else we pass a dummy encoder_embeddings, as we don't have to use it from step 1
        # because what we need from the encoder is the cached encoder_hidden_states

        encoder_embeddings = tf.cond(
            tf.equal(tf.reduce_sum(all_cache_key), 0.0),
            lambda: get_encoder_embeddings(encoder_input_ids, encoder_input_type_ids),
            lambda: tf.zeros_like(encoder_hidden_states),  # dummy
        )

        decoder_embeddings = tf.cond(
            tf.equal(tf.reduce_sum(all_cache_key), 0.0),
            lambda: get_encoder_embeddings(decoder_input_ids, decoder_input_type_ids),
            lambda: get_decoder_embeddings_step_other(decoder_input_ids, decoder_input_type_ids),
        )
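        # At step 0 (all_cache_key sums to 0) both encoder and decoder embeddings are
        # computed from the input ids; on later steps a zero dummy stands in for the
        # encoder embeddings, since only the cached `encoder_hidden_states` are needed
        # (see the encoder cond below)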

        # Initialize `encoder_attention_mask` as empty list
        encoder_attention_mask = []
        if self.mask_mode == "user_defined":
            encoder_attention_mask = SelfAttentionMask()([encoder_embeddings, encoder_input_mask])
        if self.mask_mode == "prefix":
            encoder_attention_mask = tf.map_fn(prefix_mask, encoder_input_mask, dtype=tf.float32)
        if self.mask_mode == "causal":
            encoder_attention_mask = CausalMask()(encoder_embeddings)

        # Decoder self-attention mask is always causal
        decoder_attention_mask = CausalMask()(decoder_embeddings)
        decoder_encoder_mask = CrossAttentionMask()([decoder_input_ids, encoder_input_mask])

        all_cache_key = [
            tf.squeeze(item, axis=0)
            for item in tf.split(all_cache_key, num_or_size_splits=self.num_hidden_layers, axis=0)
        ]
        all_cache_value = [
            tf.squeeze(item, axis=0)
            for item in tf.split(all_cache_value, num_or_size_splits=self.num_hidden_layers, axis=0)
        ]

        def calculate_encoder_hidden_state(encoder_embeddings):
            # Encoder Layer
            encoder_outputs = []
            for i in range(self.num_hidden_layers):
                layer = self._transformer_layers[i]
                cache_key = all_cache_key[i]
                cache_value = all_cache_value[i]
                encoder_embeddings, _, _ = layer(
                    [
                        encoder_embeddings,
                        encoder_attention_mask,
                        decoder_encoder_mask,  # decoder_encoder_mask
                        encoder_embeddings,
                    ],
                    mode="encoder",
                    cache_key=cache_key,
                    cache_value=cache_value,
                )
                encoder_outputs.append(encoder_embeddings)
            encoder_hidden_states = encoder_outputs[-1]
            return encoder_hidden_states

        # While decoding, encoder_hidden_states is calculated only once and reused afterwards
        def use_cache_encoder():
            return tf.identity(inputs["encoder_hidden_states"])

        encoder_hidden_states = tf.cond(
            tf.equal(tf.reduce_sum(inputs["encoder_hidden_states"]), 0.0),
            lambda: calculate_encoder_hidden_state(encoder_embeddings),
            lambda: use_cache_encoder(),
        )
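        # The encoder stack therefore runs only when the passed-in `encoder_hidden_states`
        # is all zeros (step 0); on later steps the cached value is returned unchanged
        # via `use_cache_encoder`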
        # Decoder layer
        decoder_outputs = []
        for i in range(self.num_hidden_layers):
            layer = self._transformer_layers[i]
            # Fetching
            cache_value = all_cache_value[i]
            cache_key = all_cache_key[i]
            decoder_embeddings, cache_key, cache_value = layer(
                [
                    decoder_embeddings,
                    decoder_attention_mask,
                    decoder_encoder_mask,
                    encoder_hidden_states,
                ],
                mode="decoder",
                cache_key=cache_key,
                cache_value=cache_value,
            )
            # Updating
            all_cache_key[i] = cache_key
            all_cache_value[i] = cache_value
            decoder_outputs.append(decoder_embeddings)

        # Stack all layers key and value together
        # num_layers x batch_size x num_heads x sequence_length x
        # (hidden_dimension/num_heads) # noqa
        all_cache_key = tf.stack(all_cache_key, axis=0, name="decoder_all_cache_key")
        all_cache_value = tf.stack(all_cache_value, axis=0, name="decoder_all_cache_value")
        # First word of last layer outputs [CLS]
        # cls_token_tensor = tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(decoder_outputs[-1])
        # batch_size x embedding_size
        # cls_output = self._pooler_layer(cls_token_tensor)
        # batch_size x sequence_length x embedding_size
        token_embeddings = decoder_outputs[-1]

        # MLM Projection
        if self.use_mlm_layer:
            token_embeddings = self.mlm_layer(token_embeddings)
            # token --> vocab ( batch_size x sequence_length x vocab_size)
            token_logits = (
                tf.matmul(
                    token_embeddings,
                    self.get_embedding_table(),
                    transpose_b=True,
                    name="token_logits",
                )
                + self._last_logits_bias
            )
        else:

            # token --> vocab ( batch_size x sequence_length x vocab_size)
            token_logits = tf.matmul(
                token_embeddings,
                self.get_embedding_table(),
                transpose_b=True,
                name="token_logits",
            )

        last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_logits)
        return {
            "encoder_hidden_states": encoder_hidden_states,
            "decoder_all_cache_key": all_cache_key,
            "decoder_all_cache_value": all_cache_value,
            "token_embeddings": token_embeddings,
            "token_logits": token_logits,
            "last_token_logits": last_token_logits,
        }

    def call_training(self, inputs):
        """Forward Pass for BERT

        Args:
            inputs: dict
            inputs is a dict with keys [`input_ids`, `input_mask`, `input_type_ids`].
            These keys might or might not be present based on `mask_mode` and other criteria

        """
        input_ids = inputs["input_ids"]
        # When `mask_mode` is `causal` , input_mask is not required
        if self.mask_mode in ["user_defined", "prefix"]:
            input_mask = inputs["input_mask"]
        # Default True in BERT
        if self.use_type_embeddings:
            input_type_ids = inputs["input_type_ids"]

        sequence_length = tf.shape(input_ids)[1]
        word_embeddings = self._embedding_layer(input_ids)
        embeddings = word_embeddings
        # Add word_embeddings + position_embeddings + type_embeddings
        if self.use_type_embeddings:
            type_embeddings = self._type_embeddings(input_type_ids)
            embeddings = embeddings + type_embeddings
        if self.use_positonal_embeddings:
            positional_embeddings = self._position_embedding_layer(tf.range(sequence_length))
            embeddings = embeddings + positional_embeddings

        # Norm + dropout
        embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)

        # Initialize `attention_mask` as empty list
        attention_mask = []
        if self.mask_mode == "user_defined":
            attention_mask = SelfAttentionMask()([embeddings, input_mask])
        if self.mask_mode == "prefix":
            attention_mask = tf.map_fn(prefix_mask, input_mask, dtype=tf.float32)
        if self.mask_mode == "causal":
            attention_mask = CausalMask()(embeddings)
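        # Mask modes: `user_defined` expands the 2D padding mask into a self-attention
        # mask, `prefix` builds a per-example prefix-LM mask via tf.map_fn (assumed to be
        # bidirectional over the prefix and causal afterwards), and `causal` is a plain
        # lower-triangular mask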

        encoder_outputs = []
        for i in range(self.num_hidden_layers):
            layer = self._transformer_layers[i]
            embeddings, _, _ = layer([embeddings, attention_mask])
            embeddings = tf.identity(embeddings, name="token_embeddings_layer_{}".format(i))
            encoder_outputs.append(embeddings)

        # Last layer output has to be normalized in GPT2
        encoder_outputs[-1] = self._last_layer_norm(encoder_outputs[-1])
        # batch_size x sequence_length x embedding_size
        token_embeddings = encoder_outputs[-1]

        # token --> vocab ( batch_size x sequence_length x vocab_size)
        token_logits = tf.matmul(
            token_embeddings,
            self.get_embedding_table(),
            transpose_b=True,
            name="token_logits",
        )
        last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_logits)
        last_token_logits = tf.identity(last_token_logits, name="last_token_logits")

        result = {
            "token_embeddings": token_embeddings,
            "token_logits": token_logits,
            "last_token_logits": last_token_logits,
        }
        if self.return_all_layer_token_embeddings:
            result["all_layer_token_embeddings"] = encoder_outputs
        return result

    def call_predict(self, inputs):
        """Inputs will be passed to this method when is_training = False.
        Caching the past `key` and `value` tensors is
        necessary while predicting, to make inference/NLG
        faster in case of AutoRegressive Decoding.

        """
        input_ids_mod = inputs["input_ids"]
        all_cache_key = inputs["all_cache_key"]
        all_cache_value = inputs["all_cache_value"]
        past_length = inputs["past_length"]

        # Come from kwargs
        if self.mask_mode in ["user_defined", "prefix"]:
            input_mask = inputs["input_mask"]
        if self.use_type_embeddings:
            input_type_ids = inputs["input_type_ids"]

        # Convert past_length 2D to 1D
        past_length = tf.squeeze(past_length, 0)

        # In case of variable batch decoding, we will pad the inputs with -1.
        # So we replace -1 with 0, because -1 is not a valid index
        # into the word embeddings.
        # >> input_ids_mod = [[1, 5, 7,  8, 10],
        #                     [2, 3, -1, -1, -1]]
        #
        # >> input_ids     = [[1, 5, 7, 8, 10],
        #                     [2, 3, 0, 0, 0]]

        input_ids = input_ids_mod * tf.cast(tf.not_equal(input_ids_mod, -1), tf.int32)
        sequence_length = tf.shape(input_ids)[1]

        # Asserting
        tf.assert_equal(tf.shape(all_cache_value)[0], self.num_hidden_layers)

        # Step 0 of inference. At step 0 we do not have a valid cache, so we pass a zero tensor
        def step_0(input_ids):
            sequence_length = tf.shape(input_ids)[1]
            position_embeddings = self._position_embedding_layer(tf.range(sequence_length))
            return sequence_length, position_embeddings

        # From step 1 onwards (once autoregressive decoding starts), we need to account
        # for the `past_length` of previous words (inputs + generated). Due to our logic,
        # we need to take a transpose of `position_embeddings` in this specific setting
        def step_other(input_ids):
            sequence_length = tf.shape(input_ids)[1]
            # Because past_length varies with batch
            position_embeddings = self._position_embedding_layer(past_length + sequence_length)
            position_embeddings = tf.transpose(position_embeddings, [1, 0, 2])
            return sequence_length, position_embeddings

        # Condition to switch functions:
        # if `sum(past_length) == 0`, no outputs have been generated yet and
        # the given inputs are the first input
        sequence_length, positional_embeddings = tf.cond(
            tf.equal(tf.reduce_sum(past_length), 0),
            lambda: step_0(input_ids),
            lambda: step_other(input_ids),
        )
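        # `past_length` is tracked per batch entry, so in variable-length batch decoding
        # each row gets its own positional index. E.g. for the padded example above,
        # past_length becomes [4, 1] after step 0, and (assuming one new token per row
        # at each later step) the next tokens use positions 5 and 2.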
        all_cache_key = [
            tf.squeeze(item, axis=0)
            for item in tf.split(all_cache_key, num_or_size_splits=self.num_hidden_layers, axis=0)
        ]
        all_cache_value = [
            tf.squeeze(item, axis=0)
            for item in tf.split(all_cache_value, num_or_size_splits=self.num_hidden_layers, axis=0)
        ]

        word_embeddings = self._embedding_layer(input_ids)
        embeddings = word_embeddings
        # Add word_embeddings + position_embeddings + type_embeddings
        if self.use_type_embeddings:
            type_embeddings = self._type_embeddings(input_type_ids)
            embeddings = embeddings + type_embeddings
        if self.use_positonal_embeddings:
            embeddings = embeddings + positional_embeddings

        # Norm + dropout
        embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)

        # Initialize `attention_mask` as empty list
        attention_mask = []
        if self.mask_mode == "user_defined":
            attention_mask = SelfAttentionMask()([embeddings, input_mask])
        if self.mask_mode == "prefix":
            attention_mask = tf.map_fn(prefix_mask, input_mask, fn_output_signature=tf.float32)
        if self.mask_mode == "causal":
            attention_mask = CausalMask()(embeddings)

        encoder_outputs = []
        # Build a mask that is 0 at the padded (-1) positions and 1 elsewhere
        mask_values = tf.cast(tf.not_equal(input_ids_mod, -1), tf.float32)
        # Zero out the embeddings at the padded positions
        # by multiplying with the mask
        embeddings = embeddings * tf.expand_dims(mask_values, -1)
        for i in range(self.num_hidden_layers):

            layer = self._transformer_layers[i]
            # Fetching
            cache_value = all_cache_value[i]
            cache_key = all_cache_key[i]

            embeddings, cache_key, cache_value = layer(
                [embeddings, attention_mask],
                cache_key=cache_key,
                cache_value=cache_value,
            )
            # Updating
            all_cache_key[i] = cache_key
            all_cache_value[i] = cache_value

            # Mask next layer embedding (PAD positions to 0)
            embeddings = tf.identity(
                embeddings * tf.expand_dims(mask_values, -1),
                name="encoder_outputs_{}".format(i),
            )
            encoder_outputs.append(embeddings)

        def step_0_gather(past_length, token_embeddings):
            cache_length = tf.reduce_sum(tf.cast(tf.not_equal(input_ids_mod, -1), tf.int32), axis=1) - 1
            # Getting corresponding last token tensor and last token logits
            last_token_tensor = tf.gather_nd(token_embeddings, tf.expand_dims(cache_length, axis=1), batch_dims=1)
            past_length = past_length + cache_length
            return past_length, last_token_tensor

        def step_other_gather(past_length, token_embeddings):
            past_length = past_length + sequence_length
            last_token_tensor = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_embeddings)
            return past_length, last_token_tensor

        # batch_size x sequence_length x embedding_size
        token_embeddings = self._last_layer_norm(encoder_outputs[-1])

        # Condition to switch functions (when batch_size > 1,
        # past_length will be different for each entry):
        # if `sum(past_length) == 0`, no outputs have been generated yet and
        # the given inputs are the first input
        past_length, last_token_tensor = tf.cond(
            tf.equal(tf.reduce_sum(past_length), 0),
            lambda: step_0_gather(past_length, token_embeddings),
            lambda: step_other_gather(past_length, token_embeddings),
        )
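        # Worked example with the padded inputs above ([[1, 5, 7, 8, 10], [2, 3, -1, -1, -1]]):
        # at step 0, cache_length = [4, 1], so `last_token_tensor` gathers
        # token_embeddings[0, 4] and token_embeddings[1, 1] (the last real token of each row)
        # and past_length becomes [4, 1]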

        # token --> vocab ( batch_size x sequence_length x vocab_size)
        last_token_logits = tf.matmul(
            last_token_tensor,
            self.get_embedding_table(),
            transpose_b=True,
            name="token_logits",
        )
        # last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_logits)

        # Expand dims of past_length back to 2D
        past_length = tf.expand_dims(past_length, 0, name="past_length")
        # Stack all layers key and value together
        # num_layers x batch_size x num_heads x sequence_length x (hidden_dimension/num_heads)
        all_cache_key = tf.stack(all_cache_key, axis=0, name="all_cache_key")
        all_cache_value = tf.stack(all_cache_value, axis=0, name="all_cache_value")

        return {
            "token_embeddings": token_embeddings,
            "last_token_logits": last_token_logits,
            "past_length": past_length,
            "all_cache_key": all_cache_key,
            "all_cache_value": all_cache_value,
        }
Example #7
    def call_decoder_predict(self, inputs):
        """Inputs will be passed to this method
        when is_training = False and is_decoder = True.
        Caching the past `key` and `value` tensors of the decoder is
        necessary while predicting, to make inference/NLG
        faster in case of AutoRegressive Decoding.

        """

        input_ids = inputs["input_ids"]
        encoder_hidden_state = inputs["encoder_hidden_states"]
        decoder_encoder_mask = inputs["decoder_encoder_mask"]
        all_cache_key = inputs["all_cache_key"]
        all_cache_value = inputs["all_cache_value"]

        # When `mask_mode` is `causal` , input_mask is not required
        # if self.mask_mode in ["user_defined"]:
        #     input_mask = inputs["input_mask"]

        if self.use_type_embeddings:
            input_type_ids = inputs["input_type_ids"]

        # sequence_length = tf.shape(input_ids)[1]

        all_cache_key = [
            tf.squeeze(item, axis=0)
            for item in tf.split(all_cache_key,
                                 num_or_size_splits=self.num_hidden_layers,
                                 axis=0)
        ]
        all_cache_value = [
            tf.squeeze(item, axis=0)
            for item in tf.split(all_cache_value,
                                 num_or_size_splits=self.num_hidden_layers,
                                 axis=0)
        ]

        # If decoder is not sharing embeddings
        word_embeddings = self._embedding_layer(input_ids)
        embeddings = word_embeddings
        # Add word_embeddings + position_embeddings + type_embeddings
        if self.use_type_embeddings:
            type_embeddings = self._type_embeddings(input_type_ids)
            embeddings = embeddings + type_embeddings
        if self.use_positonal_embeddings:
            positional_embeddings = self._position_embedding_layer(
                input_type_ids)
            embeddings = embeddings + positional_embeddings

        # Norm + dropout
        embeddings = self._embedding_dropout(embeddings,
                                             training=self.use_dropout)

        # Initialize `attention_mask` as empty list
        attention_mask = []
        if self.mask_mode == "causal":
            attention_mask = CausalMask()(embeddings)

        decoder_outputs = []
        position_bias = None
        decoder_encoder_position_bias = None
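        # `position_bias` and `decoder_encoder_position_bias` are produced by the first
        # layer call and then fed back into every subsequent layer, so the relative
        # attention bias (as in T5-style layers) is computed once and shared across layers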
        for i in range(self.num_hidden_layers):
            layer = self._transformer_layers[i]
            # Fetching
            cache_value = all_cache_value[i]
            cache_key = all_cache_key[i]

            (
                embeddings,
                position_bias,
                decoder_encoder_position_bias,
                cache_key,
                cache_value,
            ) = layer(
                [
                    embeddings,
                    attention_mask,
                    encoder_hidden_state,
                    decoder_encoder_mask,
                ],
                position_bias=position_bias,
                decoder_encoder_position_bias=decoder_encoder_position_bias,
                cache_key=cache_key,
                cache_value=cache_value,
            )

            # Updating
            all_cache_key[i] = cache_key
            all_cache_value[i] = cache_value

            decoder_outputs.append(embeddings)

        # Stack all layers key and value together
        # num_layers x batch_size x num_heads x sequence_length x (hidden_dimension/num_heads) # noqa
        all_cache_key = tf.stack(all_cache_key, axis=0, name="all_cache_key")
        all_cache_value = tf.stack(all_cache_value,
                                   axis=0,
                                   name="all_cache_value")

        decoder_outputs[-1] = self._last_layer_norm(decoder_outputs[-1])
        # batch_size x sequence_length x embedding_size
        token_embeddings = self._last_layer_dropout(decoder_outputs[-1])

        # token --> vocab ( batch_size x sequence_length x vocab_size)
        token_logits = tf.matmul(
            token_embeddings,
            self.get_embedding_table(),
            transpose_b=True,
            name="token_logits",
        )
        last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(
            token_logits)

        return {
            "all_cache_key": all_cache_key,
            "all_cache_value": all_cache_value,
            "token_embeddings": token_embeddings,
            "token_logits": token_logits,
            "last_token_logits": last_token_logits,
        }

    def call_decoder(self, inputs):
        """Forward Pass for Decoder

        Args:
            inputs: dict
            inputs is a dict with keys [`input_ids`, `input_mask`, `input_type_ids`,
            `encoder_hidden_states`, `decoder_encoder_mask`].
            These keys might or might not be present based on `mask_mode` and other criteria

        """
        input_ids = inputs["input_ids"]
        encoder_output = inputs["encoder_hidden_states"]
        decoder_encoder_mask = inputs["decoder_encoder_mask"]

        if self.mask_mode in ["user_defined"]:
            input_mask = inputs["input_mask"]

        if self.use_type_embeddings:
            input_type_ids = inputs["input_type_ids"]

        sequence_length = tf.shape(input_ids)[1]

        # If decoder is not sharing embeddings
        if self.initialize_embeddings:
            word_embeddings = self._embedding_layer(input_ids)
            embeddings = word_embeddings
            # Add word_embeddings + position_embeddings + type_embeddings
            if self.use_type_embeddings:
                type_embeddings = self._type_embeddings(input_type_ids)
                embeddings = embeddings + type_embeddings
            if self.use_positonal_embeddings:
                positional_embeddings = self._position_embedding_layer(
                    tf.range(sequence_length))
                embeddings = embeddings + positional_embeddings
        else:
            embeddings = inputs["decoder_embeddings"]
        # Norm + dropout
        embeddings = self._embedding_dropout(embeddings,
                                             training=self.use_dropout)
        # Initialize `attention_mask` as empty list
        attention_mask = []

        if self.mask_mode == "user_defined":
            attention_mask = SelfAttentionMask()([embeddings, input_mask])
        if self.mask_mode == "causal":
            attention_mask = CausalMask()(embeddings)
        decoder_outputs = []
        for i in range(self.num_hidden_layers):
            layer = self._transformer_layers[i]
            embeddings, _key, _value = layer([
                embeddings, attention_mask, encoder_output,
                decoder_encoder_mask
            ])

            decoder_outputs.append(embeddings)

        # batch_size x sequence_length x embedding_size
        token_embeddings = decoder_outputs[-1]
        return {
            "token_embeddings": token_embeddings,
            "all_layer_token_embeddings": decoder_outputs,
        }

    def call_training(self, inputs):
        """Forward Pass for BERT

        Args:
            inputs: dict
            inputs is a dict with keys [`input_ids`, `input_mask`, `input_type_ids`].
            These keys might or might not be present based on `mask_mode` and other criteria

        """
        input_ids = inputs["input_ids"]
        # When `mask_mode` is `causal` , input_mask is not required
        if self.mask_mode in ["user_defined", "prefix"]:
            input_mask = inputs["input_mask"]
        # Default True in BERT
        if self.use_type_embeddings:
            input_type_ids = inputs["input_type_ids"]

        sequence_length = tf.shape(input_ids)[1]
        word_embeddings = self._embedding_layer(input_ids)
        embeddings = word_embeddings
        # Add word_embeddings + position_embeddings + type_embeddings
        if self.use_type_embeddings:
            type_embeddings = self._type_embeddings(input_type_ids)
            embeddings = embeddings + type_embeddings
        if self.use_positonal_embeddings:
            positional_embeddings = self._position_embedding_layer(
                tf.range(sequence_length))
            embeddings = embeddings + positional_embeddings

        # Norm + dropout
        embeddings = self._embedding_norm(embeddings)
        embeddings = self._embedding_dropout(embeddings,
                                             training=self.use_dropout)

        # Initialize `attention_mask` as empty list
        attention_mask = []
        if self.mask_mode == "user_defined":
            attention_mask = SelfAttentionMask()([embeddings, input_mask])
        if self.mask_mode == "prefix":
            attention_mask = tf.map_fn(prefix_mask,
                                       input_mask,
                                       dtype=tf.float32)
        if self.mask_mode == "causal":
            attention_mask = CausalMask()(embeddings)

        encoder_outputs = []
        for i in range(self.num_hidden_layers):
            layer = self._transformer_layers[i]
            embeddings, _, _ = layer([embeddings, attention_mask])
            encoder_outputs.append(embeddings)

        # First token of the last layer output corresponds to [CLS]
        cls_token_tensor = tf.keras.layers.Lambda(
            lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(encoder_outputs[-1])
        # batch_size x embedding_size
        cls_output = self._pooler_layer(cls_token_tensor)
        # batch_size x sequence_length x embedding_size
        token_embeddings = encoder_outputs[-1]

        # UniLM has an extra MLM projection layer on top of the token embeddings
        token_embeddings_extra = self.mlm_layer(token_embeddings)

        # token --> vocab ( batch_size x sequence_length x vocab_size)
        token_logits = (tf.matmul(
            token_embeddings_extra,
            self.get_embedding_table(),
            transpose_b=True,
            name="token_logits",
        ) + self._last_logits_bias)

        last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(
            token_logits)

        return {
            "cls_output": cls_output,
            "token_embeddings": token_embeddings_extra,
            "all_layer_token_embeddings": encoder_outputs,
            "token_logits": token_logits,
            "last_token_logits": last_token_logits,
        }