def build_key(self):
    with tf.compat.v1.variable_scope("embeddings"):
        input_tensor = self.get_embeddings(self.input_ids, self.segment_ids)
        self.input_shape = bc.get_shape_list(input_tensor, expected_rank=3)

    with tf.compat.v1.variable_scope("encoder"):
        self.attention_mask = bc.create_attention_mask_from_input_mask(
            input_tensor, self.input_mask)
        prev_output = bc.reshape_to_matrix(input_tensor)
        for layer_idx in range(self.layers_before_key_pooling):
            with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                intermediate_output, prev_output = self.forward_layer(prev_output)
                intermediate_output = tf.reshape(
                    intermediate_output,
                    [self.batch_size * self.seq_length, self.config.intermediate_size])
                final_output = bc.reshape_from_matrix(prev_output, self.input_shape)
                self.all_layer_outputs.append(final_output)

        self.last_intermediate_output = intermediate_output
        self.last_key_layer = prev_output
        with tf.compat.v1.variable_scope("mr_key"):
            key_vectors = bc.dense(self.key_dimension, self.initializer)(intermediate_output)
            self.debug1 = key_vectors
            key_vectors = tf.reshape(
                key_vectors,
                [self.batch_size, self.seq_length, self.key_dimension])
            key_output = self.key_pooling(key_vectors)
    return key_output
def call(self, input_ids, input_mask, segment_ids):
    with tf.compat.v1.variable_scope("embeddings"):
        self.embedding_layer = Embedding2()
        input_tensor = self.embedding_layer.apply(
            input_ids, segment_ids,
            self.config.initializer_range, self.config.vocab_size,
            self.config.embedding_size, self.config.type_vocab_size,
            self.config.max_position_embeddings, self.config.hidden_dropout_prob,
            self.use_one_hot_embeddings)
        input_tensor = self.embedding_projection(input_tensor)
        self.embedding_output = input_tensor
        input_shape = bc.get_shape_list2(input_tensor)
        batch_size, seq_length, _ = input_shape

    with tf.compat.v1.variable_scope("encoder"):
        self.attention_mask = bc.create_attention_mask_from_input_mask2(
            input_tensor, input_mask)
        prev_output = bc.reshape_to_matrix(input_tensor)
        with tf.compat.v1.variable_scope("layer"):
            intermediate_output, prev_output = self.layer.apply(
                prev_output, batch_size, seq_length, self.attention_mask)
            final_output = bc.reshape_from_matrix2(prev_output, input_shape)
            self.all_layer_outputs.append(final_output)

        for layer_idx in range(1, self.config.num_hidden_layers):
            with tf.compat.v1.variable_scope("layer", reuse=True):
                intermediate_output, prev_output = self.layer.apply(
                    prev_output, batch_size, seq_length, self.attention_mask)
                final_output = bc.reshape_from_matrix2(prev_output, input_shape)
                self.all_layer_outputs.append(final_output)

    return prev_output
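
# The loop above shares one set of layer weights across all hidden layers by
# re-entering the same "layer" variable scope with reuse=True (ALBERT-style
# weight sharing). Below is a minimal sketch of that pattern, not part of the
# model code; names are illustrative and TF1-style graph mode is assumed.
def _sketch_shared_layer_reuse():
    tf.compat.v1.disable_eager_execution()

    def shared_dense(x):
        w = tf.compat.v1.get_variable("w", shape=[4, 4])
        return tf.matmul(x, w)

    x = tf.compat.v1.placeholder(tf.float32, [2, 4])
    with tf.compat.v1.variable_scope("layer"):
        out = shared_dense(x)                  # creates variable "layer/w"
    for _ in range(3):
        with tf.compat.v1.variable_scope("layer", reuse=True):
            out = shared_dense(out)            # reuses "layer/w"; no new variables
    return out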
def call(self, input_vectors, use_context):
    # input_vectors : [num_window, hidden_size]
    batch_size, seq_length, hidden_dim = bc.get_shape_list2(input_vectors)

    # Add position embedding
    input_vectors = bc.embedding_postprocessor2(
        input_tensor=input_vectors,
        token_type_table=self.token_type_table,
        full_position_embeddings=self.full_position_embeddings,
        use_token_type=False,
        token_type_ids=None,
        token_type_vocab_size=1,
        use_position_embeddings=True,
        max_position_embeddings=self.config.max_num_window,
        dropout_prob=self.config.hidden_dropout_prob)

    input_shape = [batch_size, seq_length]
    attention_mask = tf.ones([batch_size, seq_length, seq_length], tf.int32) \
        * tf.expand_dims(use_context, 2)

    with tf.compat.v1.variable_scope("mid"):
        prev_output = bc.reshape_to_matrix(input_vectors)
        for layer_idx in range(self.n_layers):
            with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                intermediate_output, prev_output = self.layer_list[layer_idx].apply(
                    prev_output, batch_size, seq_length, attention_mask)
                final_output = bc.reshape_from_matrix2(prev_output, input_shape)
                self.all_layer_outputs.append(final_output)
    return prev_output
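
# A small illustrative sketch (not part of the model code) of how the
# attention mask above is built: broadcasting use_context over the last axis
# gives mask[b, i, j] = use_context[b, i], i.e. a query position attends to
# the other windows only when its own window is marked as usable.
def _sketch_use_context_mask():
    use_context = tf.constant([[1, 0, 1]], tf.int32)                 # [batch=1, seq=3]
    mask = tf.ones([1, 3, 3], tf.int32) * tf.expand_dims(use_context, 2)
    # mask[0] == [[1, 1, 1],
    #             [0, 0, 0],
    #             [1, 1, 1]]
    return mask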
def apply_3d(self, input_tensor, batch_size, seq_length, attention_mask):
    input_shape = bc.get_shape_list2(input_tensor)
    input_tensor = bc.reshape_to_matrix(input_tensor)
    intermediate_output, layer_output = self.apply(
        input_tensor, batch_size, seq_length, attention_mask)
    return bc.reshape_from_matrix2(layer_output, input_shape)
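
# Illustrative sketch (not part of the model code) of the 3D <-> 2D round-trip
# that apply_3d performs via the bc helpers, assuming reshape_to_matrix /
# reshape_from_matrix2 follow the usual BERT semantics of collapsing the
# leading dimensions; plain tf.reshape shows the same round-trip.
def _sketch_reshape_round_trip():
    x = tf.zeros([8, 128, 768])                    # [batch, seq, hidden]
    x_2d = tf.reshape(x, [-1, 768])                # [batch * seq, hidden]
    x_3d = tf.reshape(x_2d, [8, 128, 768])         # back to [batch, seq, hidden]
    return x_3d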
def build_by_attention(self, key):
    hidden_size = self.config.hidden_size
    with tf.compat.v1.variable_scope("embeddings"):
        lexical_tensor = self.get_lexical_lookup()
        self.embedding_output = self.embedding_postprocessor(
            d_input_ids=self.input_ids,
            input_tensor=lexical_tensor,
            use_token_type=True,
            token_type_ids=self.segment_ids,
            token_type_vocab_size=self.config.type_vocab_size,
            token_type_embedding_name="token_type_embeddings",
            use_position_embeddings=True,
            position_embedding_name="position_embeddings",
            initializer_range=self.config.initializer_range,
            max_position_embeddings=self.config.max_position_embeddings,
            dropout_prob=self.config.hidden_dropout_prob)
        input_tensor = self.embedding_output  # [def_per_batch, seq_length, hidden_size]

    with tf.compat.v1.variable_scope("encoder"):
        num_key_tokens = self.ssdr_config.num_key_tokens
        project_dim = hidden_size * num_key_tokens
        raw_key = bc.dense(project_dim, self.initializer)(key)
        key_tokens = tf.reshape(raw_key, [self.batch_size, num_key_tokens, hidden_size])
        input_tensor = tf.concat([key_tokens, input_tensor], axis=1)
        input_shape = bc.get_shape_list(input_tensor, expected_rank=3)

        mask_for_key = tf.ones([self.batch_size, num_key_tokens], dtype=tf.int64)
        self.input_mask = tf.cast(self.input_mask, tf.int64)
        self.input_mask = tf.concat([mask_for_key, self.input_mask], axis=1)
        self.seq_length = self.seq_length + num_key_tokens

        self.attention_mask = bc.create_attention_mask_from_input_mask(
            input_tensor, self.input_mask)
        prev_output = bc.reshape_to_matrix(input_tensor)
        for layer_idx in range(self.ssdr_config.num_hidden_layers):
            with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                intermediate_output, prev_output = self.forward_layer(prev_output)
                self.all_layer_outputs.append(prev_output)

        final_output = bc.reshape_from_matrix(prev_output, input_shape)
        self.scores = bc.dense(1, self.initializer)(final_output[:, 0, :])

        if self.ssdr_config.info_pooling_method == "first_tokens":
            self.info_output = final_output[:, :num_key_tokens, :]
        elif self.ssdr_config.info_pooling_method == "max_pooling":
            self.info_output = tf.reduce_max(final_output, axis=1)

    return self.scores, self.info_output
def apply_topic_vector(self, input_tensor, topic_ids, layer_idx):
    if layer_idx == 0:
        return self.add_topic_vector(input_tensor, topic_ids)
    else:
        input_tensor = bc.reshape_from_matrix2(input_tensor, self.input_shape)
        # Drop the topic slots appended at the previous layer, then append the
        # topic embeddings for this layer.
        input_tensor = input_tensor[:, :-self.topic_emb_len, :]
        input_tensor = tf.concat([input_tensor, self.topic_tensor[layer_idx]], axis=1)
        input_tensor = bc.reshape_to_matrix(input_tensor)
        return input_tensor
def call(self, input_vectors, attention_mask):
    prev_output = input_vectors
    input_shape = bc.get_shape_list2(input_vectors)
    batch_size, seq_length, _ = input_shape
    prev_output = bc.reshape_to_matrix(prev_output)
    for layer_idx in range(self.n_layers):
        with tf.compat.v1.variable_scope("layer_%d" % (layer_idx + self.layer_idx_base)):
            layer = self.layer_list[layer_idx]
            intermediate_output, prev_output = layer.apply(
                prev_output, batch_size, seq_length, attention_mask)
            final_output = bc.reshape_from_matrix2(prev_output, input_shape)
            self.all_layer_outputs.append(final_output)
    return prev_output
def call(self, input_ids, input_mask, segment_ids, topic_ids):
    with tf.compat.v1.variable_scope("embeddings"):
        self.embedding_layer = Embedding(self.config, self.use_one_hot_embeddings)
        input_tensor = self.embedding_layer.apply(input_ids, segment_ids)
        self.embedding_output = input_tensor

        input_mask = self.extend_input_mask(input_mask)
        topic_tensor, _ = bc.embedding_lookup2(
            topic_ids, self.n_topics, self.topic_embedding,
            self.topic_embedding_size, self.use_one_hot_embeddings)
        self.topic_tensor = tf.reshape(
            topic_tensor, [-1, self.topic_emb_len, self.hidden_size])
        input_tensor = tf.concat([input_tensor, self.topic_tensor], axis=1)

        input_shape = bc.get_shape_list2(input_tensor)
        batch_size, seq_length, _ = input_shape

    with tf.compat.v1.variable_scope("encoder"):
        self.attention_mask = bc.create_attention_mask_from_input_mask2(
            input_tensor, input_mask)
        prev_output = bc.reshape_to_matrix(input_tensor)
        for layer_idx in range(self.n_layers):
            with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                layer = self.layer_list[layer_idx]
                intermediate_output, prev_output = layer.apply(
                    prev_output, batch_size, seq_length, self.attention_mask)
                final_output = bc.reshape_from_matrix2(prev_output, input_shape)
                self.all_layer_outputs.append(final_output)

    self.embedding_table = self.embedding_layer.embedding_table
    self.sequence_output = final_output[:, :-self.topic_emb_len]
    self.pooled_output = mimic_pooling(
        self.sequence_output, self.config.hidden_size, self.config.initializer_range)
    return self.sequence_output
def build(self, value_out, locations):
    with tf.compat.v1.variable_scope("embeddings"):
        input_tensor = self.get_embeddings(self.input_ids, self.segment_ids)
        self.input_shape = bc.get_shape_list(input_tensor, expected_rank=3)

    with tf.compat.v1.variable_scope("encoder"):
        self.attention_mask = bc.create_attention_mask_from_input_mask(
            input_tensor, self.input_mask)
        prev_output = bc.reshape_to_matrix(input_tensor)
        # Replace the entries of prev_output addressed by `locations` with the
        # externally computed `value_out` before running the encoder stack.
        prev_output = tf.tensor_scatter_nd_update(prev_output, locations, value_out)
        for layer_idx in range(self.config.num_hidden_layers):
            with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                intermediate_output, prev_output = self.forward_layer(prev_output)
                final_output = bc.reshape_from_matrix(prev_output, self.input_shape)
                self.all_layer_outputs.append(final_output)
    return self.all_layer_outputs
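
# Illustrative sketch (not part of the model code) of the scatter update used
# in build() above, under one plausible reading where `locations` indexes whole
# rows of the flattened [batch * seq_length, hidden] matrix and `value_out`
# supplies the replacement vectors.
def _sketch_scatter_rows():
    prev_output = tf.zeros([6, 4])                 # e.g. batch * seq = 6, hidden = 4
    locations = tf.constant([[1], [4]])            # rows to overwrite
    value_out = tf.ones([2, 4])                    # replacement vectors
    updated = tf.tensor_scatter_nd_update(prev_output, locations, value_out)
    # Rows 1 and 4 of `updated` are now all ones; the remaining rows stay zero.
    return updated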
def call(self, input_ids, input_mask, segment_ids):
    with tf.compat.v1.variable_scope("embeddings"):
        self.embedding_layer = Embedding(self.config, self.use_one_hot_embeddings)
        input_tensor = self.embedding_layer.apply(input_ids, segment_ids)
        self.embedding_output = input_tensor
        input_shape = bc.get_shape_list2(input_tensor)
        batch_size, seq_length, _ = input_shape

    with tf.compat.v1.variable_scope("lower"):
        self.attention_mask = bc.create_attention_mask_from_input_mask2(
            input_tensor, input_mask)
        prev_output = bc.reshape_to_matrix(input_tensor)
        for layer_idx in range(self.n_layers):
            with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                layer = self.layer_list[layer_idx]
                intermediate_output, prev_output = layer.apply(
                    prev_output, batch_size, seq_length, self.attention_mask)
                final_output = bc.reshape_from_matrix2(prev_output, input_shape)
                self.all_layer_outputs.append(final_output)
    return prev_output
def build(self):
    with tf.compat.v1.variable_scope("dict"):
        with tf.compat.v1.variable_scope("embeddings"):
            input_tensor = self.get_embeddings(self.input_ids, self.segment_ids)

        with tf.compat.v1.variable_scope("encoder"):
            num_key_tokens = self.ssdr_config.num_key_tokens
            input_shape = bc.get_shape_list(input_tensor, expected_rank=3)

            mask_for_key = tf.ones([self.batch_size, num_key_tokens], dtype=tf.int64)
            self.input_mask = tf.cast(self.input_mask, tf.int64)
            self.input_mask = tf.concat([mask_for_key, self.input_mask], axis=1)
            self.seq_length = self.seq_length + num_key_tokens

            self.attention_mask = bc.create_attention_mask_from_input_mask(
                input_tensor, self.input_mask)
            prev_output = bc.reshape_to_matrix(input_tensor)
            for layer_idx in range(self.ssdr_config.num_hidden_layers):
                with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                    intermediate_output, prev_output = self.forward_layer(prev_output)
                    self.all_layer_outputs.append(prev_output)

            final_output = bc.reshape_from_matrix(prev_output, input_shape)
            self.scores = bc.dense(1, self.initializer)(final_output[:, 0, :])

            if self.ssdr_config.info_pooling_method == "first_tokens":
                self.info_output = final_output[:, :num_key_tokens, :]
            elif self.ssdr_config.info_pooling_method == "max_pooling":
                self.info_output = tf.reduce_max(final_output, axis=1)

    return self.scores, self.info_output
def transformer_model(input_tensor,
                      attention_mask=None,
                      input_mask=None,
                      hidden_size=768,
                      num_hidden_layers=12,
                      num_attention_heads=12,
                      mr_num_route=10,
                      intermediate_size=3072,
                      intermediate_act_fn=gelu,
                      hidden_dropout_prob=0.1,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      is_training=True,
                      do_return_all_layers=False):
    """Multi-headed, multi-layer Transformer from "Attention is All You Need".

    This is almost an exact implementation of the original Transformer encoder.

    See the original paper:
    https://arxiv.org/abs/1706.03762

    Also see:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

    Args:
        input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
        attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
            seq_length], with 1 for positions that can be attended to and 0 in
            positions that should not be.
        input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
            Unused in the current implementation.
        hidden_size: int. Hidden size of the Transformer.
        num_hidden_layers: int. Number of layers (blocks) in the Transformer.
        num_attention_heads: int. Number of attention heads in the Transformer.
        mr_num_route: int. Number of routes used by the multi-route (MR) layers.
        intermediate_size: int. The size of the "intermediate" (a.k.a., feed
            forward) layer.
        intermediate_act_fn: function. The non-linear activation function to apply
            to the output of the intermediate/feed-forward layer.
        hidden_dropout_prob: float. Dropout probability for the hidden layers.
        attention_probs_dropout_prob: float. Dropout probability of the attention
            probabilities.
        initializer_range: float. Range of the initializer (stddev of truncated
            normal).
        is_training: bool. If True, the route key is sampled from the router
            logits; otherwise it is taken by argmax.
        do_return_all_layers: Whether to also return all layers or just the final
            layer.

    Returns:
        A tuple (output, key). If `do_return_all_layers` is False, `output` is a
        float Tensor of shape [batch_size, seq_length, hidden_size] (the final
        hidden layer); otherwise it is the list of all layer outputs. `key`
        holds the per-token route indices chosen by the "mr_key" router.

    Raises:
        ValueError: A Tensor shape or parameter is invalid.
    """
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            "The hidden size (%d) is not a multiple of the number of attention "
            "heads (%d)" % (hidden_size, num_attention_heads))

    attention_head_size = int(hidden_size / num_attention_heads)
    input_shape = get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    input_width = input_shape[2]

    initializer = create_initializer(initializer_range)
    ext_tensor = tf.compat.v1.get_variable(
        "ext_tensor",
        shape=[num_hidden_layers, mr_num_route, EXT_SIZE, hidden_size],
        initializer=initializer,
    )
    ext_tensor_inter = tf.compat.v1.get_variable(
        "ext_tensor_inter",
        shape=[num_hidden_layers, mr_num_route, intermediate_size],
        initializer=initializer,
    )

    # The Transformer performs sum residuals on all layers so the input needs
    # to be the same as the hidden size.
    if input_width != hidden_size:
        raise ValueError("The width of the input tensor (%d) != hidden size (%d)"
                         % (input_width, hidden_size))

    # We keep the representation as a 2D tensor to avoid re-shaping it back and
    # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
    # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
    # help the optimizer.
    prev_output = reshape_to_matrix(input_tensor)

    def is_mr_layer(layer_idx):
        if layer_idx > 1:
            return True
        else:
            return False

    all_layer_outputs = []
    for layer_idx in range(num_hidden_layers):
        if not is_mr_layer(layer_idx):
            with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                layer_input = prev_output

                with tf.compat.v1.variable_scope("attention"):
                    attention_heads = []
                    with tf.compat.v1.variable_scope("self"):
                        attention_head = attention_layer(
                            from_tensor=layer_input,
                            to_tensor=layer_input,
                            attention_mask=attention_mask,
                            num_attention_heads=num_attention_heads,
                            size_per_head=attention_head_size,
                            attention_probs_dropout_prob=attention_probs_dropout_prob,
                            initializer_range=initializer_range,
                            do_return_2d_tensor=True,
                            batch_size=batch_size,
                            from_seq_length=seq_length,
                            to_seq_length=seq_length)
                        attention_heads.append(attention_head)

                    attention_output = None
                    if len(attention_heads) == 1:
                        attention_output = attention_heads[0]
                    else:
                        # In the case where we have other sequences, we just concatenate
                        # them to the self-attention head before the projection.
                        attention_output = tf.concat(attention_heads, axis=-1)

                    # Run a linear projection of `hidden_size` then add a residual
                    # with `layer_input`.
                    with tf.compat.v1.variable_scope("output"):
                        attention_output = dense(hidden_size, initializer)(attention_output)
                        attention_output = dropout(attention_output, hidden_dropout_prob)
                        attention_output = layer_norm(attention_output + layer_input)

                # The activation is only applied to the "intermediate" hidden layer.
                with tf.compat.v1.variable_scope("intermediate"):
                    intermediate_output = dense(
                        intermediate_size, initializer,
                        activation=intermediate_act_fn)(attention_output)

                # Down-project back to `hidden_size` then add the residual.
                with tf.compat.v1.variable_scope("output"):
                    layer_output = dense(hidden_size, initializer)(intermediate_output)
                    layer_output = dropout(layer_output, hidden_dropout_prob)
                    layer_output = layer_norm(layer_output + attention_output)
                    prev_output = layer_output
                    all_layer_outputs.append(layer_output)

                with tf.compat.v1.variable_scope("mr_key"):
                    key_output = tf.keras.layers.Dense(
                        mr_num_route,
                        kernel_initializer=create_initializer(initializer_range))(intermediate_output)
                    key_output = dropout(key_output, hidden_dropout_prob)
                    if is_training:
                        key = tf.random.categorical(key_output, 1)  # [batch_size, 1]
                        key = tf.reshape(key, [-1])
                    else:
                        key = tf.math.argmax(input=key_output, axis=1)
        else:
            # Case MR layer
            with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                layer_input = prev_output
                ext_slice = tf.gather(ext_tensor[layer_idx], key)
                ext_interm_slice = tf.gather(ext_tensor_inter[layer_idx], key)
                print("ext_slice (batch*seq, ", ext_slice.shape)

                with tf.compat.v1.variable_scope("attention"):
                    attention_heads = []
                    with tf.compat.v1.variable_scope("self"):
                        attention_head = attention_layer_w_ext(
                            from_tensor=layer_input,
                            to_tensor=layer_input,
                            attention_mask=attention_mask,
                            ext_slice=ext_slice,
                            num_attention_heads=num_attention_heads,
                            size_per_head=attention_head_size,
                            attention_probs_dropout_prob=attention_probs_dropout_prob,
                            initializer_range=initializer_range,
                            do_return_2d_tensor=True,
                            batch_size=batch_size,
                            from_seq_length=seq_length,
                            to_seq_length=seq_length)
                        attention_head = attention_head + ext_slice[:, EXT_ATT_OUT, :]
                        attention_heads.append(attention_head)

                    attention_output = None
                    if len(attention_heads) == 1:
                        attention_output = attention_heads[0]
                    else:
                        # In the case where we have other sequences, we just concatenate
                        # them to the self-attention head before the projection.
                        attention_output = tf.concat(attention_heads, axis=-1)

                    # Run a linear projection of `hidden_size` then add a residual
                    # with `layer_input`.
                    with tf.compat.v1.variable_scope("output"):
                        attention_output = dense(hidden_size, initializer)(attention_output)
                        attention_output = dropout(attention_output, hidden_dropout_prob)
                        attention_output = attention_output + ext_slice[:, EXT_ATT_PROJ, :]
                        attention_output = layer_norm(attention_output + layer_input)

                # The activation is only applied to the "intermediate" hidden layer.
                with tf.compat.v1.variable_scope("intermediate"):
                    intermediate_output = dense(
                        intermediate_size, initializer,
                        activation=intermediate_act_fn)(attention_output)
                    intermediate_output = ext_interm_slice + intermediate_output

                # Down-project back to `hidden_size` then add the residual.
                with tf.compat.v1.variable_scope("output"):
                    layer_output = dense(hidden_size, initializer)(intermediate_output)
                    layer_output = layer_output + ext_slice[:, EXT_LAYER_OUT, :]
                    layer_output = dropout(layer_output, hidden_dropout_prob)
                    layer_output = layer_norm(layer_output + attention_output)
                    prev_output = layer_output
                    all_layer_outputs.append(layer_output)

    if do_return_all_layers:
        final_outputs = []
        for layer_output in all_layer_outputs:
            final_output = reshape_from_matrix(layer_output, input_shape)
            final_outputs.append(final_output)
        return final_outputs, key
    else:
        final_output = reshape_from_matrix(prev_output, input_shape)
        return final_output, key
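
# Usage sketch for transformer_model (illustrative only, not part of the model
# code; shapes are hypothetical, and TF1-style graph mode is assumed because
# the function creates variables with tf.compat.v1.get_variable).
def _sketch_transformer_model_usage():
    tf.compat.v1.disable_eager_execution()
    input_tensor = tf.compat.v1.placeholder(tf.float32, [8, 128, 768])
    attention_mask = tf.compat.v1.placeholder(tf.int32, [8, 128, 128])
    all_layers, key = transformer_model(
        input_tensor,
        attention_mask=attention_mask,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        mr_num_route=10,
        is_training=True,
        do_return_all_layers=True)
    # all_layers: list of [8, 128, 768] tensors, one per layer.
    # key: [8 * 128] route indices sampled by the last "mr_key" router; the MR
    # layers (layer_idx > 1) gather their per-token extension vectors with it.
    return all_layers, key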
def attention_layer_w_ext(from_tensor,
                          to_tensor,
                          attention_mask=None,
                          num_attention_heads=1,
                          size_per_head=512,
                          ext_slice=None,  # [Num_tokens, n_items, hidden_dim]
                          query_act=None,
                          key_act=None,
                          value_act=None,
                          attention_probs_dropout_prob=0.0,
                          initializer_range=0.02,
                          do_return_2d_tensor=False,
                          batch_size=None,
                          from_seq_length=None,
                          to_seq_length=None):
    """Performs multi-headed attention from `from_tensor` to `to_tensor`.

    This is an implementation of multi-headed attention based on "Attention is
    all you Need". If `from_tensor` and `to_tensor` are the same, then this is
    self-attention. Each timestep in `from_tensor` attends to the corresponding
    sequence in `to_tensor`, and returns a fixed-width vector.

    This function first projects `from_tensor` into a "query" tensor and
    `to_tensor` into "key" and "value" tensors. These are (effectively) a list
    of tensors of length `num_attention_heads`, where each tensor is of shape
    [batch_size, seq_length, size_per_head].

    Then, the query and key tensors are dot-producted and scaled. These are
    softmaxed to obtain attention probabilities. The value tensors are then
    interpolated by these probabilities, then concatenated back to a single
    tensor and returned.

    In practice, the multi-headed attention is done with transposes and
    reshapes rather than actual separate tensors.

    Args:
        from_tensor: float Tensor of shape [batch_size, from_seq_length,
            from_width].
        to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
        attention_mask: (optional) int32 Tensor of shape [batch_size,
            from_seq_length, to_seq_length]. The values should be 1 or 0. The
            attention scores will effectively be set to -infinity for any
            positions in the mask that are 0, and will be unchanged for
            positions that are 1.
        num_attention_heads: int. Number of attention heads.
        size_per_head: int. Size of each attention head.
        ext_slice: float Tensor of shape [num_tokens, n_items, hidden_dim];
            per-token extension vectors added at the projection inputs and
            outputs indexed by the EXT_* constants.
        query_act: (optional) Activation function for the query transform.
        key_act: (optional) Activation function for the key transform.
        value_act: (optional) Activation function for the value transform.
        attention_probs_dropout_prob: (optional) float. Dropout probability of
            the attention probabilities.
        initializer_range: float. Range of the weight initializer.
        do_return_2d_tensor: bool. If True, the output will be of shape
            [batch_size * from_seq_length, num_attention_heads * size_per_head].
            If False, the output will be of shape [batch_size, from_seq_length,
            num_attention_heads * size_per_head].
        batch_size: (Optional) int. If the input is 2D, this might be the batch
            size of the 3D version of the `from_tensor` and `to_tensor`.
        from_seq_length: (Optional) If the input is 2D, this might be the seq
            length of the 3D version of the `from_tensor`.
        to_seq_length: (Optional) If the input is 2D, this might be the seq
            length of the 3D version of the `to_tensor`.

    Returns:
        float Tensor of shape [batch_size, from_seq_length,
        num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
        true, this will be of shape [batch_size * from_seq_length,
        num_attention_heads * size_per_head]).

    Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
""" def transpose_for_scores(input_tensor, batch_size, num_attention_heads, seq_length, width): output_tensor = tf.reshape( input_tensor, [batch_size, seq_length, num_attention_heads, width]) output_tensor = tf.transpose(a=output_tensor, perm=[0, 2, 1, 3]) return output_tensor from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) if len(from_shape) != len(to_shape): raise ValueError( "The rank of `from_tensor` must match the rank of `to_tensor`.") if len(from_shape) == 3: batch_size = from_shape[0] from_seq_length = from_shape[1] to_seq_length = to_shape[1] elif len(from_shape) == 2: if (batch_size is None or from_seq_length is None or to_seq_length is None): raise ValueError( "When passing in rank 2 tensors to attention_layer, the values " "for `batch_size`, `from_seq_length`, and `to_seq_length` " "must all be specified.") # Scalar dimensions referenced here: # B = batch size (number of sequences) # F = `from_tensor` sequence length # T = `to_tensor` sequence length # N = `num_attention_heads` # H = `size_per_head` from_tensor_2d = reshape_to_matrix(from_tensor) to_tensor_2d = reshape_to_matrix(to_tensor) def get_ext_slice(idx): return ext_slice[:, idx, :] print("from_tensor_2d ", from_tensor_2d.shape) query_in = from_tensor_2d + get_ext_slice(EXT_QUERY_IN) query_in = from_tensor_2d # `query_layer` = [B*F, N*H] query_layer = tf.keras.layers.Dense( num_attention_heads * size_per_head, activation=query_act, name="query", kernel_initializer=create_initializer(initializer_range))(query_in) query_layer = query_layer + get_ext_slice(EXT_QUERY_OUT) key_in = to_tensor_2d key_in = to_tensor_2d + get_ext_slice(EXT_KEY_IN) # `key_layer` = [B*T, N*H] key_layer = tf.keras.layers.Dense( num_attention_heads * size_per_head, activation=key_act, name="key", kernel_initializer=create_initializer(initializer_range))(key_in) key_layer = key_layer + get_ext_slice(EXT_KEY_OUT) value_in = to_tensor_2d value_in = to_tensor_2d + get_ext_slice(EXT_VALUE_IN) # `value_layer` = [B*T, N*H] value_layer = tf.keras.layers.Dense( num_attention_heads * size_per_head, activation=value_act, name="value", kernel_initializer=create_initializer(initializer_range))(value_in) value_layer = value_layer + get_ext_slice(EXT_VALUE_OUT) # `query_layer` = [B, N, F, H] query_layer = transpose_for_scores(query_layer, batch_size, num_attention_heads, from_seq_length, size_per_head) # `key_layer` = [B, N, T, H] key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, to_seq_length, size_per_head) # Take the dot product between "query" and "key" to get the raw # attention scores. # `attention_scores` = [B, N, F, T] attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) attention_scores = tf.multiply(attention_scores, 1.0 / math.sqrt(float(size_per_head))) if attention_mask is not None: # `attention_mask` = [B, 1, F, T] attention_mask = tf.expand_dims(attention_mask, axis=[1]) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. attention_scores += adder # Normalize the attention scores to probabilities. 
    # `attention_probs` = [B, N, F, T]
    attention_probs = tf.nn.softmax(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

    # `value_layer` = [B, T, N, H]
    value_layer = tf.reshape(
        value_layer,
        [batch_size, to_seq_length, num_attention_heads, size_per_head])

    # `value_layer` = [B, N, T, H]
    value_layer = tf.transpose(a=value_layer, perm=[0, 2, 1, 3])

    # `context_layer` = [B, N, F, H]
    context_layer = tf.matmul(attention_probs, value_layer)

    # `context_layer` = [B, F, N, H]
    context_layer = tf.transpose(a=context_layer, perm=[0, 2, 1, 3])

    if do_return_2d_tensor:
        # `context_layer` = [B*F, N*H]
        context_layer = tf.reshape(
            context_layer,
            [batch_size * from_seq_length, num_attention_heads * size_per_head])
    else:
        # `context_layer` = [B, F, N*H]
        context_layer = tf.reshape(
            context_layer,
            [batch_size, from_seq_length, num_attention_heads * size_per_head])

    return context_layer
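
# Usage sketch for attention_layer_w_ext (illustrative only, not part of the
# model code; EXT_SIZE and the EXT_* index constants are assumed to be defined
# at module level as used in transformer_model above, and TF1-style graph mode
# is assumed for the placeholders).
def _sketch_attention_layer_w_ext_usage():
    tf.compat.v1.disable_eager_execution()
    batch_size, seq_length, hidden = 8, 128, 768
    from_tensor = tf.compat.v1.placeholder(
        tf.float32, [batch_size * seq_length, hidden])
    ext_slice = tf.compat.v1.placeholder(
        tf.float32, [batch_size * seq_length, EXT_SIZE, hidden])
    context = attention_layer_w_ext(
        from_tensor=from_tensor,
        to_tensor=from_tensor,            # self-attention
        ext_slice=ext_slice,
        num_attention_heads=12,
        size_per_head=64,                 # 12 * 64 == hidden, required for the ext additions
        do_return_2d_tensor=True,
        batch_size=batch_size,
        from_seq_length=seq_length,
        to_seq_length=seq_length)
    # context: [batch_size * seq_length, num_attention_heads * size_per_head]
    return context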