def call_decoder_predict(self, inputs):
    """Inputs will be passed to this method, when is_training = False and is_decoder = True. # noqa
    Caching the past `key` and `value` tensors for decoders is
    necessary while predicting, to make the inference/NLG faster
    in the case of AutoRegressive Decoding.
    """
    input_ids = inputs["input_ids"]
    encoder_hidden_state = inputs["encoder_hidden_states"]
    decoder_encoder_mask = inputs["decoder_encoder_mask"]

    all_cache_key = inputs["all_cache_key"]
    all_cache_value = inputs["all_cache_value"]

    # Decoder doesn't need this
    # # When `mask_mode` is `causal`, input_mask is not required
    # if self.mask_mode in ['user_defined']:
    #     input_mask = inputs['input_mask']
    if self.use_type_embeddings:
        input_type_ids = inputs["input_type_ids"]

    # cache_length = tf.constant(0, dtype=tf.int32)
    def step_0_cache_length(_):
        return tf.constant(0, dtype=tf.int32)

    def step_other_cache_length(all_cache_key):
        past_length = tf.shape(all_cache_key)[3]
        # Why -1? Because at iteration 2 our positional
        # embedding should be 1, not 2, and so on.
        sequence_length = tf.shape(input_ids)[1] + past_length - 1
        return sequence_length

    sequence_length = tf.cond(
        tf.equal(tf.reduce_sum(all_cache_key), 0),
        lambda: step_0_cache_length(all_cache_key),
        lambda: step_other_cache_length(all_cache_key),
    )

    all_cache_key = [
        tf.squeeze(item, axis=0)
        for item in tf.split(all_cache_key, num_or_size_splits=self.num_hidden_layers, axis=0)
    ]
    all_cache_value = [
        tf.squeeze(item, axis=0)
        for item in tf.split(all_cache_value, num_or_size_splits=self.num_hidden_layers, axis=0)
    ]

    # If decoder is not sharing embeddings
    word_embeddings = self._embedding_layer(input_ids)
    embeddings = word_embeddings
    # Add word_embeddings + position_embeddings + type_embeddings
    if self.use_type_embeddings:
        type_embeddings = self._type_embeddings(input_type_ids)
        embeddings = embeddings + type_embeddings
    if self.use_positonal_embeddings:
        positional_embeddings = self._position_embedding_layer(sequence_length)
        # Make it 3D for the sum (for the decoder we decode one token at a time)
        positional_embeddings = tf.expand_dims(positional_embeddings, 0)
        embeddings = embeddings + positional_embeddings

    # Norm + dropout
    embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)
    # Initialize `attention_mask` as empty list
    attention_mask = []
    if self.mask_mode == "causal":
        attention_mask = CausalMask()(embeddings)

    decoder_outputs = []
    for i in range(self.num_hidden_layers):
        layer = self._transformer_layers[i]
        # Fetching
        cache_value = all_cache_value[i]
        cache_key = all_cache_key[i]

        embeddings, cache_key, cache_value = layer(
            [
                embeddings,
                attention_mask,
                encoder_hidden_state,
                decoder_encoder_mask,
            ],
            cache_key=cache_key,
            cache_value=cache_value,
        )

        # Updating
        all_cache_key[i] = cache_key
        all_cache_value[i] = cache_value

        decoder_outputs.append(embeddings)

    # Stack all layers' key and value tensors together
    # num_layers x batch_size x num_heads x sequence_length x (hidden_dimension/num_heads)
    all_cache_key = tf.stack(all_cache_key, axis=0, name="all_cache_key")
    all_cache_value = tf.stack(all_cache_value, axis=0, name="all_cache_value")

    # batch_size x sequence_length x embedding_size
    decoder_outputs[-1] = self._last_layer_norm(decoder_outputs[-1])
    token_embeddings = decoder_outputs[-1]

    # token --> vocab (batch_size x sequence_length x vocab_size)
    token_logits = tf.matmul(
        token_embeddings,
        self.get_embedding_table(),
        transpose_b=True,
        name="token_logits",
    )
    last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_logits)

    return {
        "all_cache_key": all_cache_key,
        "all_cache_value": all_cache_value,
        "token_embeddings": token_embeddings,
        "last_token_logits": last_token_logits,
    }
def call_decoder(self, inputs):
    """Forward Pass for the Decoder

    Args:
        inputs: dict
        inputs is a dict with keys [`input_ids`, `input_mask`, `input_type_ids`,
        `encoder_hidden_states`, `decoder_encoder_mask`].
        These keys might or might not be present based on `mask_mode` and other criteria.
    """
    input_ids = inputs["input_ids"]
    encoder_output = inputs["encoder_hidden_states"]
    decoder_encoder_mask = inputs["decoder_encoder_mask"]

    if self.mask_mode in ["user_defined"]:
        input_mask = inputs["input_mask"]
    if self.use_type_embeddings:
        input_type_ids = inputs["input_type_ids"]

    sequence_length = tf.shape(input_ids)[1]
    # If decoder is not sharing embeddings
    word_embeddings = self._embedding_layer(input_ids)
    embeddings = word_embeddings
    # Add word_embeddings + position_embeddings + type_embeddings
    if self.use_type_embeddings:
        type_embeddings = self._type_embeddings(input_type_ids)
        embeddings = embeddings + type_embeddings
    if self.use_positonal_embeddings:
        positional_embeddings = self._position_embedding_layer(tf.range(sequence_length))
        embeddings = embeddings + positional_embeddings

    # Norm + dropout
    embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)
    # Initialize `attention_mask` as empty list
    attention_mask = []
    if self.mask_mode == "user_defined":
        attention_mask = SelfAttentionMask()([embeddings, input_mask])
    if self.mask_mode == "causal":
        attention_mask = CausalMask()(embeddings)

    decoder_outputs = []
    for i in range(self.num_hidden_layers):
        layer = self._transformer_layers[i]
        embeddings, _key, _value = layer([embeddings, attention_mask, encoder_output, decoder_encoder_mask])
        decoder_outputs.append(embeddings)

    # batch_size x sequence_length x embedding_size
    decoder_outputs[-1] = self._last_layer_norm(decoder_outputs[-1])
    token_embeddings = decoder_outputs[-1]

    # token --> vocab (batch_size x sequence_length x vocab_size)
    token_logits = tf.matmul(
        token_embeddings,
        self.get_embedding_table(),
        transpose_b=True,
        name="token_logits",
    )
    last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_logits)
    result = {
        "token_embeddings": token_embeddings,
        "token_logits": token_logits,
        "last_token_logits": last_token_logits,
    }
    if self.return_all_layer_token_embeddings:
        result["all_layer_token_embeddings"] = decoder_outputs
    return result
def call_cross_attention_encoder(self, inputs):
    """Forward Pass for the cross-attention encoder-decoder (training mode).

    Args:
        inputs: dict
        inputs is a dict with keys [`encoder_input_ids`, `decoder_input_ids`,
        `encoder_input_mask`, `encoder_input_type_ids`, `decoder_input_type_ids`].
        These keys might or might not be present based on `mask_mode` and other criteria.
    """
    encoder_input_ids = inputs["encoder_input_ids"]
    decoder_input_ids = inputs["decoder_input_ids"]

    encoder_input_type_ids = None
    decoder_input_type_ids = None
    if self.use_type_embeddings:
        encoder_input_type_ids = inputs["encoder_input_type_ids"]
        decoder_input_type_ids = inputs["decoder_input_type_ids"]

    encoder_input_mask = None
    if self.mask_mode in ["user_defined", "prefix"]:
        encoder_input_mask = inputs["encoder_input_mask"]

    def get_embeddings(input_ids, input_type_ids):
        """Get embeddings for the encoder as well as the decoder.

        Args:
            input_ids: input token ids
            input_type_ids: token type (segment) ids
        """
        embeddings = self._embedding_layer(input_ids)
        sequence_length = tf.shape(input_ids)[1]
        # Add word_embeddings + position_embeddings + type_embeddings
        if self.use_type_embeddings:
            type_embeddings = self._type_embeddings(input_type_ids)
            embeddings = embeddings + type_embeddings
        if self.use_positonal_embeddings:
            positional_embeddings = self._position_embedding_layer(tf.range(sequence_length))
            embeddings = embeddings + positional_embeddings
        # Norm + dropout
        embeddings = self._embedding_norm(embeddings)
        embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)
        return embeddings

    encoder_embeddings = get_embeddings(encoder_input_ids, encoder_input_type_ids)
    decoder_embeddings = get_embeddings(decoder_input_ids, decoder_input_type_ids)

    # Initialize `encoder_attention_mask` as empty list
    encoder_attention_mask = []
    if self.mask_mode == "user_defined":
        encoder_attention_mask = SelfAttentionMask()([encoder_embeddings, encoder_input_mask])
    if self.mask_mode == "prefix":
        encoder_attention_mask = tf.map_fn(prefix_mask, encoder_input_mask, dtype=tf.float32)
    if self.mask_mode == "causal":
        encoder_attention_mask = CausalMask()(encoder_embeddings)

    # Decoder mask is always causal
    decoder_attention_mask = CausalMask()(decoder_embeddings)
    decoder_encoder_mask = CrossAttentionMask()([decoder_input_ids, encoder_input_mask])

    decoder_outputs = []
    encoder_outputs = []
    # Encoder layers
    for i in range(self.num_hidden_layers):
        layer = self._transformer_layers[i]
        encoder_embeddings, _, _ = layer(
            [
                encoder_embeddings,
                encoder_attention_mask,
                decoder_encoder_mask,  # dummy decoder_encoder_mask
                encoder_embeddings,  # dummy encoder_hidden_states
            ],
            mode="encoder",
        )
        encoder_outputs.append(encoder_embeddings)

    # Decoder layers
    encoder_hidden_states = encoder_outputs[-1]
    for i in range(self.num_hidden_layers):
        layer = self._transformer_layers[i]
        decoder_embeddings, _, _ = layer(
            [decoder_embeddings, decoder_attention_mask, decoder_encoder_mask, encoder_hidden_states],
            mode="decoder",
        )
        decoder_outputs.append(decoder_embeddings)

    decoder_outputs[-1] = self._last_layer_norm(decoder_outputs[-1])

    # First word of last layer outputs [CLS]
    # cls_token_tensor = tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(decoder_outputs[-1])
    # batch_size x embedding_size
    # cls_output = self._pooler_layer(cls_token_tensor)

    # batch_size x sequence_length x embedding_size
    token_embeddings = decoder_outputs[-1]
    # MLM projection
    if self.use_mlm_layer:
        token_embeddings = self.mlm_layer(token_embeddings)
        # token --> vocab (batch_size x sequence_length x vocab_size)
        token_logits = (
            tf.matmul(
                token_embeddings,
                self.get_embedding_table(),
                transpose_b=True,
                name="token_logits",
            )
            + self._last_logits_bias
        )
    else:
        # token --> vocab (batch_size x sequence_length x vocab_size)
        token_logits = tf.matmul(
            token_embeddings,
            self.get_embedding_table(),
            transpose_b=True,
            name="token_logits",
        )

    last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_logits)
    result = {
        "token_embeddings": token_embeddings,
        "token_logits": token_logits,
        "last_token_logits": last_token_logits,
    }
    if self.return_all_layer_token_embeddings:
        result["all_layer_token_embeddings"] = decoder_outputs
    return result
def call_cross_attention_encoder_predict(self, inputs):
    """Forward Pass for the cross-attention encoder-decoder (prediction mode).

    Args:
        inputs: dict
        inputs is a dict with keys [`encoder_input_ids`, `decoder_input_ids`,
        `encoder_input_mask`, `encoder_input_type_ids`, `decoder_input_type_ids`,
        `encoder_hidden_states`, `decoder_all_cache_key`, `decoder_all_cache_value`].
        These keys might or might not be present based on `mask_mode` and other criteria.
    """
    encoder_input_ids = inputs["encoder_input_ids"]
    decoder_input_ids = inputs["decoder_input_ids"]

    encoder_input_type_ids = None
    decoder_input_type_ids = None
    if self.use_type_embeddings:
        encoder_input_type_ids = inputs["encoder_input_type_ids"]
        decoder_input_type_ids = inputs["decoder_input_type_ids"]

    encoder_input_mask = None
    if self.mask_mode in ["user_defined", "prefix"]:
        encoder_input_mask = inputs["encoder_input_mask"]

    # self.num_hidden_layers x batch_size x sequence_length x embedding_dimension
    encoder_hidden_states = inputs["encoder_hidden_states"]
    all_cache_key = inputs["decoder_all_cache_key"]
    all_cache_value = inputs["decoder_all_cache_value"]

    def get_encoder_embeddings(input_ids, input_type_ids):
        """Get embeddings for the encoder (and the decoder at step 0).

        Args:
            input_ids: input token ids
            input_type_ids: token type (segment) ids
        """
        embeddings = self._embedding_layer(input_ids)
        sequence_length = tf.shape(input_ids)[1]
        # Add word_embeddings + position_embeddings + type_embeddings
        if self.use_type_embeddings:
            type_embeddings = self._type_embeddings(input_type_ids)
            embeddings = embeddings + type_embeddings
        if self.use_positonal_embeddings:
            positional_embeddings = self._position_embedding_layer(tf.range(sequence_length))
            embeddings = embeddings + positional_embeddings
        # Norm + dropout
        embeddings = self._embedding_norm(embeddings)
        embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)
        return embeddings

    # This function is slightly different from the one above because we do not
    # need tf.range(sequence_length); we need the position of a single word from
    # step 1 onwards, as we decode word by word. So we use all_cache_key to get
    # the past_length.
    def get_decoder_embeddings_step_other(input_ids, input_type_ids):
        """Get embeddings for the decoder from step 1 onwards.

        Args:
            input_ids: input token ids
            input_type_ids: token type (segment) ids
        """

        def step_0_cache_length(_):
            return tf.constant(0, dtype=tf.int32)

        def step_other_cache_length(all_cache_key):
            past_length = tf.shape(all_cache_key)[3]
            # Why -1? Because at iteration 2 our positional
            # embedding should be 1, not 2, and so on.
            sequence_length = tf.shape(input_ids)[1] + past_length - 1
            return sequence_length

        sequence_length = tf.cond(
            tf.equal(tf.reduce_sum(all_cache_key), 0),
            lambda: step_0_cache_length(all_cache_key),
            lambda: step_other_cache_length(all_cache_key),
        )
        embeddings = self._embedding_layer(input_ids)
        # Add word_embeddings + position_embeddings + type_embeddings
        if self.use_type_embeddings:
            type_embeddings = self._type_embeddings(input_type_ids)
            embeddings = embeddings + type_embeddings
        if self.use_positonal_embeddings:
            positional_embeddings = self._position_embedding_layer(sequence_length)
            # Make it 3D for the sum (for the decoder we decode one token at a time)
            positional_embeddings = tf.expand_dims(positional_embeddings, 0)
            embeddings = embeddings + positional_embeddings
        # Norm + dropout
        embeddings = self._embedding_norm(embeddings)
        embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)
        return embeddings

    # Encoder embeddings remain the same throughout the decoding process,
    # so we have to calculate them only once.
    # We check if cache_key == 0; if it is 0, it is step 0.
    # Otherwise, pass a dummy encoder_embeddings, as we don't have to use it from step 1,
    # because what we need from the encoder is encoder_hidden_states.
    encoder_embeddings = tf.cond(
        tf.equal(tf.reduce_sum(all_cache_key), 0.0),
        lambda: get_encoder_embeddings(encoder_input_ids, encoder_input_type_ids),
        lambda: tf.zeros_like(encoder_hidden_states),  # dummy
    )
    decoder_embeddings = tf.cond(
        tf.equal(tf.reduce_sum(all_cache_key), 0.0),
        lambda: get_encoder_embeddings(decoder_input_ids, decoder_input_type_ids),
        lambda: get_decoder_embeddings_step_other(decoder_input_ids, decoder_input_type_ids),
    )

    # Initialize `encoder_attention_mask` as empty list
    encoder_attention_mask = []
    if self.mask_mode == "user_defined":
        encoder_attention_mask = SelfAttentionMask()([encoder_embeddings, encoder_input_mask])
    if self.mask_mode == "prefix":
        encoder_attention_mask = tf.map_fn(prefix_mask, encoder_input_mask, dtype=tf.float32)
    if self.mask_mode == "causal":
        encoder_attention_mask = CausalMask()(encoder_embeddings)

    # Decoder mask is always causal
    decoder_attention_mask = CausalMask()(decoder_embeddings)
    decoder_encoder_mask = CrossAttentionMask()([decoder_input_ids, encoder_input_mask])

    all_cache_key = [
        tf.squeeze(item, axis=0)
        for item in tf.split(all_cache_key, num_or_size_splits=self.num_hidden_layers, axis=0)
    ]
    all_cache_value = [
        tf.squeeze(item, axis=0)
        for item in tf.split(all_cache_value, num_or_size_splits=self.num_hidden_layers, axis=0)
    ]

    def calculate_encoder_hidden_state(encoder_embeddings):
        # Encoder layers
        encoder_outputs = []
        for i in range(self.num_hidden_layers):
            layer = self._transformer_layers[i]
            cache_key = all_cache_key[i]
            cache_value = all_cache_value[i]
            encoder_embeddings, _, _ = layer(
                [
                    encoder_embeddings,
                    encoder_attention_mask,
                    decoder_encoder_mask,  # decoder_encoder_mask
                    encoder_embeddings,
                ],
                mode="encoder",
                cache_key=cache_key,
                cache_value=cache_value,
            )
            encoder_outputs.append(encoder_embeddings)
        encoder_hidden_states = encoder_outputs[-1]
        return encoder_hidden_states

    # While decoding, we have to calculate it only once
    def use_cache_encoder():
        return tf.identity(inputs["encoder_hidden_states"])

    encoder_hidden_states = tf.cond(
        tf.equal(tf.reduce_sum(inputs["encoder_hidden_states"]), 0.0),
        lambda: calculate_encoder_hidden_state(encoder_embeddings),
        lambda: use_cache_encoder(),
    )

    # Decoder layers
    decoder_outputs = []
    for i in range(self.num_hidden_layers):
        layer = self._transformer_layers[i]
        # Fetching
        cache_value = all_cache_value[i]
        cache_key = all_cache_key[i]

        decoder_embeddings, cache_key, cache_value = layer(
            [
                decoder_embeddings,
                decoder_attention_mask,
                decoder_encoder_mask,
                encoder_hidden_states,
            ],
            mode="decoder",
            cache_key=cache_key,
            cache_value=cache_value,
        )

        # Updating
        all_cache_key[i] = cache_key
        all_cache_value[i] = cache_value

        decoder_outputs.append(decoder_embeddings)

    # Stack all layers' key and value tensors together
    # num_layers x batch_size x num_heads x sequence_length x (hidden_dimension/num_heads)
    all_cache_key = tf.stack(all_cache_key, axis=0, name="decoder_all_cache_key")
    all_cache_value = tf.stack(all_cache_value, axis=0, name="decoder_all_cache_value")

    # First word of last layer outputs [CLS]
    # cls_token_tensor = tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(decoder_outputs[-1])
    # batch_size x embedding_size
    # cls_output = self._pooler_layer(cls_token_tensor)

    # batch_size x sequence_length x embedding_size
    token_embeddings = decoder_outputs[-1]
    # MLM projection
    if self.use_mlm_layer:
        token_embeddings = self.mlm_layer(token_embeddings)
        # token --> vocab (batch_size x sequence_length x vocab_size)
        token_logits = (
            tf.matmul(
                token_embeddings,
                self.get_embedding_table(),
                transpose_b=True,
                name="token_logits",
            )
            + self._last_logits_bias
        )
    else:
        # token --> vocab (batch_size x sequence_length x vocab_size)
        token_logits = tf.matmul(
            token_embeddings,
            self.get_embedding_table(),
            transpose_b=True,
            name="token_logits",
        )

    last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_logits)
    return {
        "encoder_hidden_states": encoder_hidden_states,
        "decoder_all_cache_key": all_cache_key,
        "decoder_all_cache_value": all_cache_value,
        "token_embeddings": token_embeddings,
        "token_logits": token_logits,
        "last_token_logits": last_token_logits,
    }
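# A minimal sketch (illustrative only) of how the step-0 / step-N switch above is exercised:
# at step 0 the caller passes zero tensors for `encoder_hidden_states` and the decoder caches,
# so the `tf.cond` branches run the encoder once; afterwards the returned tensors are fed back
# unchanged and only the decoder layers execute. Names, shapes, and the greedy-argmax step are
# assumptions, not the library's public API; mask/type-id keys depend on `mask_mode` and
# `use_type_embeddings`.
#
#     encoder_hidden_states = tf.zeros((batch_size, encoder_len, embedding_size))
#     cache_shape = (num_hidden_layers, batch_size, num_heads, 1, head_dim)
#     decoder_all_cache_key = tf.zeros(cache_shape)
#     decoder_all_cache_value = tf.zeros(cache_shape)
#     for _ in range(max_new_tokens):
#         outputs = model.call_cross_attention_encoder_predict(
#             {
#                 "encoder_input_ids": encoder_input_ids,
#                 "decoder_input_ids": decoder_input_ids,
#                 "encoder_input_mask": encoder_input_mask,
#                 "encoder_hidden_states": encoder_hidden_states,
#                 "decoder_all_cache_key": decoder_all_cache_key,
#                 "decoder_all_cache_value": decoder_all_cache_value,
#             }
#         )
#         encoder_hidden_states = outputs["encoder_hidden_states"]
#         decoder_all_cache_key = outputs["decoder_all_cache_key"]
#         decoder_all_cache_value = outputs["decoder_all_cache_value"]
#         next_ids = tf.argmax(outputs["last_token_logits"], axis=-1, output_type=tf.int32)
#         decoder_input_ids = tf.expand_dims(next_ids, axis=1)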
def call_training(self, inputs):
    """Forward Pass (training mode)

    Args:
        inputs: dict
        inputs is a dict with keys [`input_ids`, `input_mask`, `input_type_ids`].
        These keys might or might not be present based on `mask_mode` and other criteria.
    """
    input_ids = inputs["input_ids"]
    # When `mask_mode` is `causal`, input_mask is not required
    if self.mask_mode in ["user_defined", "prefix"]:
        input_mask = inputs["input_mask"]
    # Default True in BERT
    if self.use_type_embeddings:
        input_type_ids = inputs["input_type_ids"]

    sequence_length = tf.shape(input_ids)[1]
    word_embeddings = self._embedding_layer(input_ids)
    embeddings = word_embeddings
    # Add word_embeddings + position_embeddings + type_embeddings
    if self.use_type_embeddings:
        type_embeddings = self._type_embeddings(input_type_ids)
        embeddings = embeddings + type_embeddings
    if self.use_positonal_embeddings:
        positional_embeddings = self._position_embedding_layer(tf.range(sequence_length))
        embeddings = embeddings + positional_embeddings

    # Norm + dropout
    embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)
    # Initialize `attention_mask` as empty list
    attention_mask = []
    if self.mask_mode == "user_defined":
        attention_mask = SelfAttentionMask()([embeddings, input_mask])
    if self.mask_mode == "prefix":
        attention_mask = tf.map_fn(prefix_mask, input_mask, dtype=tf.float32)
    if self.mask_mode == "causal":
        attention_mask = CausalMask()(embeddings)

    encoder_outputs = []
    for i in range(self.num_hidden_layers):
        layer = self._transformer_layers[i]
        embeddings, _, _ = layer([embeddings, attention_mask])
        embeddings = tf.identity(embeddings, name="token_embeddings_layer_{}".format(i))
        encoder_outputs.append(embeddings)

    # The last layer output has to be normalized in GPT2
    encoder_outputs[-1] = self._last_layer_norm(encoder_outputs[-1])

    # batch_size x sequence_length x embedding_size
    token_embeddings = encoder_outputs[-1]
    # token --> vocab (batch_size x sequence_length x vocab_size)
    token_logits = tf.matmul(
        token_embeddings,
        self.get_embedding_table(),
        transpose_b=True,
        name="token_logits",
    )
    last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_logits)
    last_token_logits = tf.identity(last_token_logits, name="last_token_logits")
    result = {
        "token_embeddings": token_embeddings,
        "token_logits": token_logits,
        "last_token_logits": last_token_logits,
    }
    if self.return_all_layer_token_embeddings:
        result["all_layer_token_embeddings"] = encoder_outputs
    return result
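# The logits above reuse the input embedding table as the output projection (weight tying):
# (batch, seq, hidden) @ (vocab, hidden)^T -> (batch, seq, vocab). A standalone shape check,
# independent of this class (illustrative sizes only):
#
#     token_embeddings = tf.random.normal((2, 5, 8))    # batch x seq x hidden
#     embedding_table = tf.random.normal((100, 8))      # vocab x hidden
#     token_logits = tf.matmul(token_embeddings, embedding_table, transpose_b=True)
#     assert token_logits.shape == (2, 5, 100)          # batch x seq x vocab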
def call_predict(self, inputs):
    """Inputs will be passed to this method, when is_training = False.
    Caching the past `key` and `value` tensors is
    necessary while predicting, to make the inference/NLG faster
    in the case of AutoRegressive Decoding.
    """
    input_ids_mod = inputs["input_ids"]
    all_cache_key = inputs["all_cache_key"]
    all_cache_value = inputs["all_cache_value"]
    past_length = inputs["past_length"]  # Comes from kwargs

    if self.mask_mode in ["user_defined", "prefix"]:
        input_mask = inputs["input_mask"]
    if self.use_type_embeddings:
        input_type_ids = inputs["input_type_ids"]

    # Convert past_length from 2D to 1D
    past_length = tf.squeeze(past_length, 0)

    # In case of variable batch decoding, we pad the inputs with -1.
    # We then replace -1 with 0, because -1 is not a valid index
    # in the word embeddings.
    # >> input_ids_mod = [[ 1, 5, 7, 8, 10],
    #                     [ 2, 3, -1, -1, -1]]
    #
    # >> input_ids     = [[ 1, 5, 7, 8, 10],
    #                     [ 2, 3,  0,  0,  0]]
    input_ids = input_ids_mod * tf.cast(tf.not_equal(input_ids_mod, -1), tf.int32)
    sequence_length = tf.shape(input_ids)[1]

    # Asserting
    tf.assert_equal(tf.shape(all_cache_value)[0], self.num_hidden_layers)

    # Step 0 of inference. For step 0, we do not have a valid cache; we pass a zero tensor.
    def step_0(input_ids):
        sequence_length = tf.shape(input_ids)[1]
        position_embeddings = self._position_embedding_layer(tf.range(sequence_length))
        return sequence_length, position_embeddings

    # From step_1 (where autoregressive mode starts) onwards, we need to account for the
    # `past_length` of previous words (inputs + generated). Due to our logic, we need to
    # take a transpose of `position_embeddings` in this specific setting.
    def step_other(input_ids):
        sequence_length = tf.shape(input_ids)[1]
        # Because past_length varies with the batch
        position_embeddings = self._position_embedding_layer(past_length + sequence_length)
        position_embeddings = tf.transpose(position_embeddings, [1, 0, 2])
        return sequence_length, position_embeddings

    # Condition to switch functions.
    # If `sum(past_length) == 0`, no outputs have been generated:
    # the given inputs are the first inputs.
    sequence_length, positional_embeddings = tf.cond(
        tf.equal(tf.reduce_sum(past_length), 0),
        lambda: step_0(input_ids),
        lambda: step_other(input_ids),
    )

    all_cache_key = [
        tf.squeeze(item, axis=0)
        for item in tf.split(all_cache_key, num_or_size_splits=self.num_hidden_layers, axis=0)
    ]
    all_cache_value = [
        tf.squeeze(item, axis=0)
        for item in tf.split(all_cache_value, num_or_size_splits=self.num_hidden_layers, axis=0)
    ]

    word_embeddings = self._embedding_layer(input_ids)
    embeddings = word_embeddings
    # Add word_embeddings + position_embeddings + type_embeddings
    if self.use_type_embeddings:
        type_embeddings = self._type_embeddings(input_type_ids)
        embeddings = embeddings + type_embeddings
    if self.use_positonal_embeddings:
        embeddings = embeddings + positional_embeddings

    # Norm + dropout
    embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)
    # Initialize `attention_mask` as empty list
    attention_mask = []
    if self.mask_mode == "user_defined":
        attention_mask = SelfAttentionMask()([embeddings, input_mask])
    if self.mask_mode == "prefix":
        attention_mask = tf.map_fn(prefix_mask, input_mask, fn_output_signature=tf.float32)
    if self.mask_mode == "causal":
        attention_mask = CausalMask()(embeddings)

    encoder_outputs = []
    # Make all -1 positions 0 (as -1 represents padding in the input)
    mask_values = tf.cast(tf.not_equal(input_ids_mod, -1), tf.float32)
    # We want zero values where the inputs were padding (PAD -1 replaced by 0),
    # so we multiply the embeddings with the mask.
    embeddings = embeddings * tf.expand_dims(mask_values, -1)

    for i in range(self.num_hidden_layers):
        layer = self._transformer_layers[i]
        # Fetching
        cache_value = all_cache_value[i]
        cache_key = all_cache_key[i]

        embeddings, cache_key, cache_value = layer(
            [embeddings, attention_mask],
            cache_key=cache_key,
            cache_value=cache_value,
        )

        # Updating
        all_cache_key[i] = cache_key
        all_cache_value[i] = cache_value

        # Mask the next layer embedding (PAD positions to 0)
        embeddings = tf.identity(
            embeddings * tf.expand_dims(mask_values, -1),
            name="encoder_outputs_{}".format(i),
        )
        encoder_outputs.append(embeddings)

    def step_0_gather(past_length, token_embeddings):
        cache_length = tf.reduce_sum(tf.cast(tf.not_equal(input_ids_mod, -1), tf.int32), axis=1) - 1
        # Gather the corresponding last (non-padding) token tensor per example
        last_token_tensor = tf.gather_nd(token_embeddings, tf.expand_dims(cache_length, axis=1), batch_dims=1)
        past_length = past_length + cache_length
        return past_length, last_token_tensor

    def step_other_gather(past_length, token_embeddings):
        past_length = past_length + sequence_length
        last_token_tensor = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_embeddings)
        return past_length, last_token_tensor

    # batch_size x sequence_length x embedding_size
    token_embeddings = self._last_layer_norm(encoder_outputs[-1])

    # Condition to switch functions (when batch_size > 1,
    # past_length will be different for each entry).
    # If `sum(past_length) == 0`, no outputs have been generated:
    # the given inputs are the first inputs.
    past_length, last_token_tensor = tf.cond(
        tf.equal(tf.reduce_sum(past_length), 0),
        lambda: step_0_gather(past_length, token_embeddings),
        lambda: step_other_gather(past_length, token_embeddings),
    )

    # token --> vocab (batch_size x vocab_size)
    last_token_logits = tf.matmul(
        last_token_tensor,
        self.get_embedding_table(),
        transpose_b=True,
        name="token_logits",
    )
    # last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_logits)

    # Expand dims of past_length back to 2D
    past_length = tf.expand_dims(past_length, 0, name="past_length")

    # Stack all layers' key and value tensors together
    # num_layers x batch_size x num_heads x sequence_length x (hidden_dimension/num_heads)
    all_cache_key = tf.stack(all_cache_key, axis=0, name="all_cache_key")
    all_cache_value = tf.stack(all_cache_value, axis=0, name="all_cache_value")

    return {
        "token_embeddings": token_embeddings,
        "last_token_logits": last_token_logits,
        "past_length": past_length,
        "all_cache_key": all_cache_key,
        "all_cache_value": all_cache_value,
    }
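# A minimal sketch (illustrative only) of variable-batch decoding with `call_predict`:
# prompts of different lengths are right-padded with -1, `past_length` starts at zero,
# and after step 0 the caller feeds back the caches and `past_length` and supplies one new
# token per row. Names and cache shapes are assumptions, not the library's public API;
# `input_mask` / `input_type_ids` keys are omitted here and depend on the configuration.
#
#     input_ids = tf.constant([[1, 5, 7, 8, 10],
#                              [2, 3, -1, -1, -1]])          # -1 marks padding
#     past_length = tf.zeros((1, 2), dtype=tf.int32)         # 1 x batch_size
#     cache_shape = (num_hidden_layers, 2, num_heads, 1, head_dim)
#     all_cache_key = tf.zeros(cache_shape)
#     all_cache_value = tf.zeros(cache_shape)
#     for _ in range(max_new_tokens):
#         outputs = model.call_predict(
#             {
#                 "input_ids": input_ids,
#                 "all_cache_key": all_cache_key,
#                 "all_cache_value": all_cache_value,
#                 "past_length": past_length,
#             }
#         )
#         all_cache_key = outputs["all_cache_key"]
#         all_cache_value = outputs["all_cache_value"]
#         past_length = outputs["past_length"]
#         next_ids = tf.argmax(outputs["last_token_logits"], axis=-1, output_type=tf.int32)
#         input_ids = tf.expand_dims(next_ids, axis=1)        # one token per row from step 1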
def call_decoder_predict(self, inputs):
    """Inputs will be passed to this method, when is_training = False and is_decoder = True.
    Caching the past `key` and `value` tensors for decoders is
    necessary while predicting, to make the inference/NLG faster
    in the case of AutoRegressive Decoding.
    """
    input_ids = inputs["input_ids"]
    encoder_hidden_state = inputs["encoder_hidden_states"]
    decoder_encoder_mask = inputs["decoder_encoder_mask"]

    all_cache_key = inputs["all_cache_key"]
    all_cache_value = inputs["all_cache_value"]

    # When `mask_mode` is `causal`, input_mask is not required
    # if self.mask_mode in ["user_defined"]:
    #     input_mask = inputs["input_mask"]
    if self.use_type_embeddings:
        input_type_ids = inputs["input_type_ids"]

    # sequence_length = tf.shape(input_ids)[1]

    all_cache_key = [
        tf.squeeze(item, axis=0)
        for item in tf.split(all_cache_key, num_or_size_splits=self.num_hidden_layers, axis=0)
    ]
    all_cache_value = [
        tf.squeeze(item, axis=0)
        for item in tf.split(all_cache_value, num_or_size_splits=self.num_hidden_layers, axis=0)
    ]

    # If decoder is not sharing embeddings
    word_embeddings = self._embedding_layer(input_ids)
    embeddings = word_embeddings
    # Add word_embeddings + position_embeddings + type_embeddings
    if self.use_type_embeddings:
        type_embeddings = self._type_embeddings(input_type_ids)
        embeddings = embeddings + type_embeddings
    if self.use_positonal_embeddings:
        positional_embeddings = self._position_embedding_layer(input_type_ids)
        embeddings = embeddings + positional_embeddings

    # Norm + dropout
    embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)
    # Initialize `attention_mask` as empty list
    attention_mask = []
    if self.mask_mode == "causal":
        attention_mask = CausalMask()(embeddings)

    decoder_outputs = []
    position_bias = None
    decoder_encoder_position_bias = None
    for i in range(self.num_hidden_layers):
        layer = self._transformer_layers[i]
        # Fetching
        cache_value = all_cache_value[i]
        cache_key = all_cache_key[i]

        (
            embeddings,
            position_bias,
            decoder_encoder_position_bias,
            cache_key,
            cache_value,
        ) = layer(
            [
                embeddings,
                attention_mask,
                encoder_hidden_state,
                decoder_encoder_mask,
            ],
            position_bias=position_bias,
            decoder_encoder_position_bias=decoder_encoder_position_bias,
            cache_key=cache_key,
            cache_value=cache_value,
        )

        # Updating
        all_cache_key[i] = cache_key
        all_cache_value[i] = cache_value

        decoder_outputs.append(embeddings)

    # Stack all layers' key and value tensors together
    # num_layers x batch_size x num_heads x sequence_length x (hidden_dimension/num_heads)
    all_cache_key = tf.stack(all_cache_key, axis=0, name="all_cache_key")
    all_cache_value = tf.stack(all_cache_value, axis=0, name="all_cache_value")

    decoder_outputs[-1] = self._last_layer_norm(decoder_outputs[-1])
    # batch_size x sequence_length x embedding_size
    token_embeddings = self._last_layer_dropout(decoder_outputs[-1])

    # token --> vocab (batch_size x sequence_length x vocab_size)
    token_logits = tf.matmul(
        token_embeddings,
        self.get_embedding_table(),
        transpose_b=True,
        name="token_logits",
    )
    last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_logits)

    return {
        "all_cache_key": all_cache_key,
        "all_cache_value": all_cache_value,
        "token_embeddings": token_embeddings,
        "token_logits": token_logits,
        "last_token_logits": last_token_logits,
    }
def call_decoder(self, inputs):
    """Forward Pass for the Decoder

    Args:
        inputs: dict
        inputs is a dict with keys [`input_ids`, `input_mask`, `input_type_ids`,
        `encoder_hidden_states`, `decoder_encoder_mask`].
        These keys might or might not be present based on `mask_mode` and other criteria.
    """
    input_ids = inputs["input_ids"]
    encoder_output = inputs["encoder_hidden_states"]
    decoder_encoder_mask = inputs["decoder_encoder_mask"]

    if self.mask_mode in ["user_defined"]:
        input_mask = inputs["input_mask"]
    if self.use_type_embeddings:
        input_type_ids = inputs["input_type_ids"]

    sequence_length = tf.shape(input_ids)[1]
    # If decoder is not sharing embeddings
    if self.initialize_embeddings:
        word_embeddings = self._embedding_layer(input_ids)
        embeddings = word_embeddings
        # Add word_embeddings + position_embeddings + type_embeddings
        if self.use_type_embeddings:
            type_embeddings = self._type_embeddings(input_type_ids)
            embeddings = embeddings + type_embeddings
        if self.use_positonal_embeddings:
            positional_embeddings = self._position_embedding_layer(tf.range(sequence_length))
            embeddings = embeddings + positional_embeddings
    else:
        embeddings = inputs["decoder_embeddings"]

    # Norm + dropout
    embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)
    # Initialize `attention_mask` as empty list
    attention_mask = []
    if self.mask_mode == "user_defined":
        attention_mask = SelfAttentionMask()([embeddings, input_mask])
    if self.mask_mode == "causal":
        attention_mask = CausalMask()(embeddings)

    decoder_outputs = []
    for i in range(self.num_hidden_layers):
        layer = self._transformer_layers[i]
        embeddings, _key, _value = layer([embeddings, attention_mask, encoder_output, decoder_encoder_mask])
        decoder_outputs.append(embeddings)

    # batch_size x sequence_length x embedding_size
    token_embeddings = decoder_outputs[-1]
    return {
        "token_embeddings": token_embeddings,
        "all_layer_token_embeddings": decoder_outputs,
    }
def call_training(self, inputs):
    """Forward Pass (training mode)

    Args:
        inputs: dict
        inputs is a dict with keys [`input_ids`, `input_mask`, `input_type_ids`].
        These keys might or might not be present based on `mask_mode` and other criteria.
    """
    input_ids = inputs["input_ids"]
    # When `mask_mode` is `causal`, input_mask is not required
    if self.mask_mode in ["user_defined", "prefix"]:
        input_mask = inputs["input_mask"]
    # Default True in BERT
    if self.use_type_embeddings:
        input_type_ids = inputs["input_type_ids"]

    sequence_length = tf.shape(input_ids)[1]
    word_embeddings = self._embedding_layer(input_ids)
    embeddings = word_embeddings
    # Add word_embeddings + position_embeddings + type_embeddings
    if self.use_type_embeddings:
        type_embeddings = self._type_embeddings(input_type_ids)
        embeddings = embeddings + type_embeddings
    if self.use_positonal_embeddings:
        positional_embeddings = self._position_embedding_layer(tf.range(sequence_length))
        embeddings = embeddings + positional_embeddings

    # Norm + dropout
    embeddings = self._embedding_norm(embeddings)
    embeddings = self._embedding_dropout(embeddings, training=self.use_dropout)
    # Initialize `attention_mask` as empty list
    attention_mask = []
    if self.mask_mode == "user_defined":
        attention_mask = SelfAttentionMask()([embeddings, input_mask])
    if self.mask_mode == "prefix":
        attention_mask = tf.map_fn(prefix_mask, input_mask, dtype=tf.float32)
    if self.mask_mode == "causal":
        attention_mask = CausalMask()(embeddings)

    encoder_outputs = []
    for i in range(self.num_hidden_layers):
        layer = self._transformer_layers[i]
        embeddings, _, _ = layer([embeddings, attention_mask])
        encoder_outputs.append(embeddings)

    # First word of last layer outputs [CLS]
    cls_token_tensor = tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(encoder_outputs[-1])
    # batch_size x embedding_size
    cls_output = self._pooler_layer(cls_token_tensor)
    # batch_size x sequence_length x embedding_size
    token_embeddings = encoder_outputs[-1]

    # UniLM has an extra MLM projection layer
    token_embeddings_extra = self.mlm_layer(token_embeddings)
    # token --> vocab (batch_size x sequence_length x vocab_size)
    token_logits = (
        tf.matmul(
            token_embeddings_extra,
            self.get_embedding_table(),
            transpose_b=True,
            name="token_logits",
        )
        + self._last_logits_bias
    )
    last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_logits)

    return {
        "cls_output": cls_output,
        "token_embeddings": token_embeddings_extra,
        "all_layer_token_embeddings": encoder_outputs,
        "token_logits": token_logits,
        "last_token_logits": last_token_logits,
    }