def build_model(self):
    with tf.variable_scope("inferring_module"):
        rdim = 768
        update_num = 2
        batch_size = tf.shape(self.sent1)[0]
        dim = self.sent1.get_shape().as_list()[-1]

        sr_cell = GRUCell(num_units=rdim, activation=tf.nn.relu)
        r_cell = sr_cell

        tri_cell = TriangularCell(num_units=rdim,
                                  r_cell=r_cell,
                                  sent1=self.sent1, sent2=self.sent2, sent3=self.sent3,
                                  sent1_length=39,
                                  sent2_length=110,
                                  sent3_length=152,
                                  dim=dim,
                                  use_bias=False, activation=tf.nn.relu,
                                  sent1_mask=self.sent1_mask,
                                  sent2_mask=self.sent2_mask,
                                  sent3_mask=self.sent3_mask,
                                  initializer=None, dtype=tf.float32)

        fake_input = tf.tile(tf.expand_dims(self.mark0, axis=1), [1, update_num, 1])
        self.init_state = tri_cell.zero_state(batch_size=batch_size, dtype=tf.float32)
        self.double_output, last_state = dynamic_rnn(cell=tri_cell,
                                                     inputs=fake_input,
                                                     initial_state=self.init_state)
        r1_output, r2_output, r3_output = last_state[3:]  # each (B, dim)

        temp13 = tf.concat([r1_output, r3_output, r1_output * r3_output], axis=1)
        temp23 = tf.concat([r2_output, r3_output, r2_output * r3_output], axis=1)

        temp13 = dropout(temp13, self.dropout_rate)
        temp23 = dropout(temp23, self.dropout_rate)

        r13 = tf.layers.dense(temp13, 768,
                              activation=tf.tanh,
                              kernel_initializer=create_initializer(0.02))
        r23 = tf.layers.dense(temp23, 768,
                              activation=tf.tanh,
                              kernel_initializer=create_initializer(0.02))

        temp = tf.concat([self.mark0, r13, r23], axis=1)
        refer_output = tf.layers.dense(temp, 768,
                                       activation=None,
                                       kernel_initializer=create_initializer(0.02))
    return refer_output
def build_model(self):
    temp = self.all_sequence[-1]
    with tf.variable_scope("lstm"):
        temp = dropout(temp, 0.1)
        seq_len = tf.reduce_sum(self.sent_mask, axis=1)

        gru_fw = GRUCell(num_units=768, activation=tf.tanh)
        gru_bw = GRUCell(num_units=768, activation=tf.tanh)
        outputs, output_states = bidirectional_dynamic_rnn(gru_fw, gru_bw, temp,
                                                           sequence_length=seq_len,
                                                           dtype=tf.float32)
        gru_output = tf.concat(outputs, axis=2)
        # gru_output = dropout(gru_output, 0.1)
        gru_output = tf.layers.dense(gru_output, units=768,
                                     kernel_initializer=create_initializer(0.02))
        gru_output = dropout(gru_output, 0.1)
        outputs = layer_norm(gru_output + temp)

        in_outputs = tf.layers.dense(outputs, units=768,
                                     activation=tf.tanh,
                                     kernel_initializer=create_initializer(0.02))
        layer_output = tf.layers.dense(in_outputs, 768,
                                       kernel_initializer=create_initializer(0.02))
        layer_output = dropout(layer_output, 0.1)
        layer_output = layer_norm(layer_output + outputs)
    return layer_output
def build_model(self):
    from tensorflow.python.keras.layers import Dense, Dot

    dim = self.sent1.get_shape().as_list()[-1]
    temp_W = tf.layers.dense(self.sent2, dim, name="dense")   # (B, L2, dim)
    temp_W = Dot(axes=[2, 2])([self.sent1, temp_W])           # (B, L1, L2)

    if self.sent1_mask is not None:
        s1_mask_exp = tf.expand_dims(self.sent1_mask, axis=2)  # (B, L1, 1)
        s2_mask_exp = tf.expand_dims(self.sent2_mask, axis=1)  # (B, 1, L2)
        temp_W1 = temp_W - (1 - s1_mask_exp) * 1e20
        temp_W2 = temp_W - (1 - s2_mask_exp) * 1e20
    else:
        temp_W1 = temp_W
        temp_W2 = temp_W

    W1 = tf.nn.softmax(temp_W1, axis=1)
    W2 = tf.nn.softmax(temp_W2, axis=2)

    M1 = Dot(axes=[2, 1])([W2, self.sent2])
    M2 = Dot(axes=[2, 1])([W1, self.sent1])

    s1_cat = tf.concat([M2 - self.sent2, M2 * self.sent2], axis=-1)
    s2_cat = tf.concat([M1 - self.sent1, M1 * self.sent1], axis=-1)

    S1 = tf.layers.dense(s1_cat, dim, activation=tf.nn.relu, name="cat_dense")
    S2 = tf.layers.dense(s2_cat, dim, activation=tf.nn.relu, name="cat_dense", reuse=True)

    if self.is_training:
        S1 = dropout(S1, dropout_prob=0.1)
        S2 = dropout(S2, dropout_prob=0.1)

    if self.sent1_mask is not None:
        S2 = S2 * tf.expand_dims(self.sent1_mask, axis=2)
        S1 = S1 * tf.expand_dims(self.sent2_mask, axis=2)

    C1 = tf.reduce_max(S1, axis=1)
    C2 = tf.reduce_max(S2, axis=1)

    C_cat = tf.concat([C1, C2], axis=1)
    return gelu(tf.layers.dense(C_cat, dim))
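# Illustrative only: a tiny self-contained demo of the masked-softmax trick used
# above (subtracting a large constant from padded positions so their attention
# weight collapses to ~0 after the softmax). The values are made up for the demo.
def masked_softmax_demo():
    scores = tf.constant([[2.0, 1.0, 0.5, 0.1]])     # (1, 4) raw alignment scores
    mask = tf.constant([[1.0, 1.0, 0.0, 0.0]])       # last two positions are padding
    masked_scores = scores - (1.0 - mask) * 1e20     # push padded scores toward -inf
    return tf.nn.softmax(masked_scores, axis=-1)     # padded positions get ~0 probability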
def multi_hop(mark, all_sent, seq_len, gru_layer: BiGRU, dropout_rate, is_first=False):
    length = all_sent.get_shape().as_list()[1]
    rdim = mark.get_shape().as_list()[-1]

    exp_mark = tf.tile(tf.expand_dims(mark, axis=1), [1, length, 1])
    gru_output = gru_layer(tf.concat([all_sent, exp_mark], axis=2), seq_len)
    gru_vec = tf.reduce_max(gru_output, axis=1)
    gru_vec = dropout(gru_vec, dropout_rate)

    trans = "trans" if is_first else "trans1"
    gru_vec = tf.layers.dense(gru_vec, rdim, activation=tf.tanh, name=trans,
                              kernel_initializer=create_initializer(0.02))
    gate = tf.layers.dense(tf.concat([gru_vec, mark], axis=1), rdim,
                           activation=tf.sigmoid, name="gate",
                           kernel_initializer=create_initializer(0.02))
    refer_output = mark * gate + (1 - gate) * gru_vec
    return refer_output, gru_output
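# Illustrative only: a minimal sketch of chaining multi_hop for several refinement
# hops over a [CLS]-style summary vector. The BiGRU construction mirrors the
# build_model variants in this file; the function name, hop count, and the
# tf.AUTO_REUSE handling are assumptions, not part of the original code.
def multi_hop_refine(mark, all_sent, input_mask, dropout_rate, is_training, num_hops=2):
    rdim = mark.get_shape().as_list()[-1]
    dim = all_sent.get_shape().as_list()[-1]
    seq_len = tf.reduce_sum(input_mask, axis=1)
    gru_layer = BiGRU(num_layers=1, num_units=rdim,
                      batch_size=tf.shape(all_sent)[0],
                      input_size=dim + rdim,   # sentence states concatenated with the tiled mark
                      keep_prob=0.9, is_train=is_training,
                      activation=tf.nn.tanh)
    with tf.variable_scope("multi_hop", reuse=tf.AUTO_REUSE):
        for hop in range(num_hops):
            # the gated update keeps part of the previous mark and mixes in new evidence
            mark, _ = multi_hop(mark, all_sent, seq_len, gru_layer,
                                dropout_rate, is_first=(hop == 0))
    return mark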
def build_model(self):
    with tf.variable_scope("inferring_module"), tf.device("/device:GPU:0"):
        rdim = 768
        update_num = 3
        batch_size = tf.shape(self.sent1)[0]
        dim = self.sent1.get_shape().as_list()[-1]

        gru_layer = BiGRU(num_layers=1, num_units=rdim, batch_size=batch_size,
                          input_size=dim, keep_prob=0.9,
                          is_train=self.is_training, activation=tf.nn.tanh)
        seq_len = tf.reduce_sum(self.input_mask, axis=1)
        gru_output = gru_layer(self.all_sent, seq_len=seq_len)

        with tf.variable_scope("att"):
            all_seq_len = self.all_sent.get_shape().as_list()[1]
            cls = tf.tile(tf.expand_dims(self.mark0, axis=1), [1, all_seq_len, 1])
            cat_att = tf.concat([cls, gru_output], axis=2)
            res = tf.layers.dense(cat_att, units=512, activation=tf.nn.relu)
            res = tf.layers.dense(res, units=1, use_bias=False)
            res_mask = tf.expand_dims(tf.cast(self.input_mask, tf.float32), axis=2)
            res = res - (1 - res_mask) * 10000.0
            alpha = tf.nn.softmax(res, 1)
            gru_vec = tf.reduce_sum(alpha * gru_output, axis=1)

        # gru_vec = dropout(gru_vec, self.dropout_rate)
        gru_vec = tf.layers.dense(gru_vec, 768, activation=gelu,
                                  kernel_initializer=create_initializer(0.02))
        gru_vec = dropout(gru_vec, self.dropout_rate)
        gru_vec = layer_norm(gru_vec + self.mark0)
        gru_vec = tf.layers.dense(gru_vec, 768, activation=tf.tanh,
                                  kernel_initializer=create_initializer(0.02))

        # gate = tf.layers.dense(tf.concat([gru_vec, self.mark0], axis=1),
        #                        rdim, activation=tf.sigmoid,
        #                        kernel_initializer=create_initializer(0.02))
        # with tf.variable_scope("merge"):
        #     refer_output = self.mark0 * gate + (1 - gate) * gru_vec
        #     vec_cat = tf.concat([self.mark0, gru_vec], axis=1)
        #     vec_cat = dropout(vec_cat, self.dropout_rate)
        #     pooled_output = tf.layers.dense(vec_cat, 768,
        #                                     activation=tf.tanh,
        #                                     kernel_initializer=create_initializer(0.02))
    return gru_vec
def build_model(self):
    with tf.variable_scope("inferring_module"), tf.device("/device:GPU:0"):
        rdim = 768
        update_num = 3
        batch_size = tf.shape(self.sent1)[0]
        dim = self.sent1.get_shape().as_list()[-1]

        gru_layer = BiGRU(num_layers=1, num_units=rdim, batch_size=batch_size,
                          input_size=dim, keep_prob=0.9,
                          is_train=self.is_training, activation=tf.nn.relu)
        seq_len = tf.reduce_sum(self.input_mask, axis=1)
        gru_output = gru_layer(self.all_sent, seq_len=seq_len)

        gru_vec = tf.reduce_max(gru_output, axis=1)
        gru_vec = dropout(gru_vec, self.dropout_rate)
        gru_vec = tf.layers.dense(gru_vec, 768, activation=tf.tanh,
                                  kernel_initializer=create_initializer(0.02))
        # gate = tf.layers.dense(tf.concat([gru_vec, self.mark0], axis=1),
        #                        rdim, activation=tf.sigmoid,
        #                        kernel_initializer=create_initializer(0.02))

        with tf.variable_scope("merge"):
            # refer_output = self.mark0 * gate + (1 - gate) * gru_vec
            vec_cat = tf.concat([self.mark0, gru_vec], axis=1)
            vec_cat = dropout(vec_cat, self.dropout_rate)
            pooled_output = tf.layers.dense(vec_cat, 768, activation=tf.tanh,
                                            kernel_initializer=create_initializer(0.02))
    return pooled_output
def build_model(self):
    from layers.ParallelInfo import TextCNN, RNNExtract, InteractionExtract, SingleSentenceExtract

    with tf.variable_scope("inferring_module"), tf.device("/device:GPU:0"):
        rdim = 768
        batch_size = tf.shape(self.sent1)[0]
        sent_length = self.all_sent.get_shape().as_list()[1]
        dim = self.sent1.get_shape().as_list()[-1]

        # text_cnn = TextCNN(rdim, [1, 2, 3, 4, 5, 7], 50)
        rnn_ext = RNNExtract(num_units=rdim, batch_size=batch_size, input_size=dim,
                             keep_prob=0.9, is_train=self.is_training)
        # img_ext = InteractionExtract(num_units=256, seq_len=sent_length)

        # text_vec = text_cnn(self.all_sent, mask=self.input_mask)
        rnn_vec = rnn_ext(self.all_sent, input_mask=self.input_mask)
        # img_vec = img_ext(self.all_sent, self.sent1_mask, self.sent2_mask, self.dropout_rate)

        temp_res = tf.concat([rnn_vec, self.mark0], axis=1)
        # temp_res = tf.reshape(temp_res, [-1, 3, dim])
        # alpha = tf.layers.dense(temp_res, 1, activation=tf.tanh)
        # alpha = tf.nn.softmax(alpha, axis=1)
        temp_res = dropout(temp_res, self.dropout_rate)

        # gate0 = tf.layers.dense(temp_res, units=rdim,
        #                         activation=tf.nn.sigmoid,
        #                         kernel_initializer=create_initializer(0.02))
        # gate1 = tf.layers.dense(temp_res, units=rdim,
        #                         activation=tf.nn.sigmoid,
        #                         kernel_initializer=create_initializer(0.02))
        # gate2 = tf.layers.dense(temp_res, units=rdim,
        #                         activation=tf.nn.sigmoid,
        #                         kernel_initializer=create_initializer(0.02))
        # gate = tf.concat([gate0, gate1, gate2], axis=1)
        # res_vec = tf.reshape(temp_res, [-1, 3, rdim])
        # gate = tf.nn.softmax(tf.reshape(gate, [-1, 3, rdim]), axis=1)
        # score = transformer_model(res_vec, hidden_size=rdim, num_hidden_layers=1,
        #                           num_attention_heads=1, intermediate_size=rdim)
        # gate = tf.nn.softmax(score, axis=1)
        # return self.mark0 * gate + (1 - gate) * rnn_vec

        return tf.layers.dense(temp_res, 768, tf.tanh,
                               kernel_initializer=create_initializer(0.02))
def build_model(self):
    with tf.variable_scope("inferring_module"):
        rdim = 768
        update_num = 1
        batch_size = tf.shape(self.sent1)[0]
        dim = self.sent1.get_shape().as_list()[-1]

        sr_cell = GRUCell(num_units=rdim, activation=tf.nn.relu)
        sent_cell = r_cell = sr_cell

        tri_cell = DoubleCell(num_units=rdim,
                              sent_cell=sent_cell, r_cell=r_cell,
                              sent1=self.sent1, sent2=self.sent2,
                              # sent1_length=self.Q_maxlen,
                              # sent2_length=self.C_maxlen,
                              dim=dim,
                              use_bias=False, activation=tf.nn.tanh,
                              sent1_mask=self.sent1_mask,
                              sent2_mask=self.sent2_mask,
                              initializer=None, dtype=tf.float32)

        fake_input = tf.tile(tf.expand_dims(self.mark0, axis=1), [1, update_num, 1])
        self.init_state = tri_cell.zero_state(batch_size=batch_size, dtype=tf.float32)
        self.double_output, last_state = dynamic_rnn(cell=tri_cell,
                                                     inputs=fake_input,
                                                     initial_state=self.init_state)
        refer_output = last_state[2]  # (B, dim)

        temp = tf.concat([refer_output, self.mark0], axis=1)
        temp = dropout(temp, self.dropout_rate)
        gate = tf.layers.dense(temp, 768, activation=tf.sigmoid,
                               kernel_initializer=create_initializer(0.02))
    return refer_output * (1 - gate) + gate * self.mark0
def build_model(self):
    from layers.ParallelInfo import TextCNN, RNNExtract, InteractionExtract, SingleSentenceExtract

    with tf.variable_scope("inferring_module"), tf.device("/device:GPU:0"):
        rdim = 768
        batch_size = tf.shape(self.sent1)[0]
        sent_length = self.all_sent.get_shape().as_list()[1]
        update_num = 3
        dim = self.sent1.get_shape().as_list()[-1]

        gru_layer = BiGRU(num_layers=1, num_units=rdim, batch_size=batch_size,
                          input_size=dim, keep_prob=0.9,
                          is_train=self.is_training, activation=tf.nn.tanh)
        seq_len = tf.reduce_sum(self.input_mask, axis=1)
        gru_output = gru_layer(self.all_sent, seq_len=seq_len)

        with tf.variable_scope("att"):
            all_seq_len = self.all_sent.get_shape().as_list()[1]
            cls = tf.tile(tf.expand_dims(self.mark0, axis=1), [1, all_seq_len, 1])
            cat_att = tf.concat([cls, gru_output], axis=2)
            res = tf.layers.dense(cat_att, units=512, activation=tf.nn.relu)
            res = tf.layers.dense(res, units=1, use_bias=False)
            res_mask = tf.expand_dims(tf.cast(self.input_mask, tf.float32), axis=2)
            res = res - (1 - res_mask) * 10000.0
            alpha = tf.nn.softmax(res, 1)
            gru_vec = tf.reduce_sum(alpha * gru_output, axis=1)

        # gru_vec = dropout(gru_vec, self.dropout_rate)
        gru_vec = tf.layers.dense(gru_vec, 768, activation=gelu,
                                  kernel_initializer=create_initializer(0.02))
        gru_vec = dropout(gru_vec, self.dropout_rate)
        gru_vec = layer_norm(gru_vec + self.mark0)
        gru_vec = tf.layers.dense(gru_vec, 768, activation=tf.tanh,
                                  kernel_initializer=create_initializer(0.02))

        text_cnn = TextCNN(2 * rdim, [1, 2, 3, 4, 5, 7], 128)
        img_ext = InteractionExtract(num_units=256, seq_len=sent_length)

        text_vec = text_cnn(gru_output, mask=self.input_mask)
        # rnn_vec, rnn_att = rnn_ext(self.all_sent, input_mask=self.input_mask, mark0=self.mark0)
        img_vec = img_ext(gru_output, self.sent1_mask, self.sent2_mask, self.dropout_rate)

        temp_res = tf.concat([img_vec, gru_vec, text_vec], axis=1)
        return tf.layers.dense(temp_res, 768, tf.tanh,
                               kernel_initializer=create_initializer(0.02))
def attention_fusion_layer(bert_config, input_tensor, input_ids, input_mask,
                           source_input_tensor, source_input_ids, source_input_mask,
                           is_training=True, scope=None):
    """Attention fusion layer for merging the source representation and the
    target representation."""
    # universal shapes
    input_tensor_shape = modeling.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_tensor_shape[0]
    seq_length = input_tensor_shape[1]
    hidden_size = input_tensor_shape[2]

    source_input_tensor_shape = modeling.get_shape_list(source_input_tensor, expected_rank=3)
    source_seq_length = source_input_tensor_shape[1]
    source_hidden_size = source_input_tensor_shape[2]

    # universal parameters
    UNIVERSAL_DROPOUT_RATE = 0.1
    if not is_training:
        UNIVERSAL_DROPOUT_RATE = 0.0  # we disable dropout when predicting
    UNIVERSAL_INIT_RANGE = bert_config.initializer_range
    NUM_ATTENTION_HEAD = bert_config.num_attention_heads

    # attention fusion module
    with tf.variable_scope(scope, default_name="attention_fusion"):
        ATTENTION_HEAD_SIZE = int(source_hidden_size / NUM_ATTENTION_HEAD)

        with tf.variable_scope("attention"):
            source_attended_repr = self_attention_layer(
                from_tensor=input_tensor,
                to_tensor=source_input_tensor,
                attention_mask=modeling.create_attention_mask_from_input_mask(
                    input_ids, source_input_mask),
                num_attention_heads=NUM_ATTENTION_HEAD,
                size_per_head=ATTENTION_HEAD_SIZE,
                attention_probs_dropout_prob=UNIVERSAL_DROPOUT_RATE,
                initializer_range=UNIVERSAL_INIT_RANGE,
                do_return_2d_tensor=False,
                batch_size=batch_size,
                from_seq_length=seq_length,
                to_seq_length=source_seq_length,
                self_adaptive=True)

        with tf.variable_scope("transform"):
            source_attended_repr = tf.layers.dense(
                source_attended_repr,
                source_hidden_size,
                kernel_initializer=modeling.create_initializer(UNIVERSAL_INIT_RANGE))
            source_attended_repr = modeling.dropout(source_attended_repr,
                                                    UNIVERSAL_DROPOUT_RATE)
            source_attended_repr = modeling.layer_norm(source_attended_repr +
                                                       source_input_tensor)

        final_output = tf.concat([input_tensor, source_attended_repr], axis=-1)
    return final_output
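# Illustrative only: a rough sketch of calling attention_fusion_layer on the
# sequence outputs of two BERT passes (one over the target inputs, one over the
# source inputs). The function name, scope choices, and the shared-weights
# handling via tf.AUTO_REUSE are assumptions, not part of the code above.
def fused_representation(bert_config, input_ids, input_mask, segment_ids,
                         source_input_ids, source_input_mask, source_segment_ids,
                         is_training):
    target_model = modeling.BertModel(
        config=bert_config, is_training=is_training, input_ids=input_ids,
        input_mask=input_mask, token_type_ids=segment_ids, scope="bert")
    # second pass over the source inputs, reusing the same encoder weights
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        source_model = modeling.BertModel(
            config=bert_config, is_training=is_training, input_ids=source_input_ids,
            input_mask=source_input_mask, token_type_ids=source_segment_ids, scope="bert")
    # result: [batch_size, seq_length, 2 * hidden_size]
    return attention_fusion_layer(bert_config,
                                  target_model.get_sequence_output(),
                                  input_ids, input_mask,
                                  source_model.get_sequence_output(),
                                  source_input_ids, source_input_mask,
                                  is_training=is_training)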
def self_attention_layer(from_tensor,
                         to_tensor,
                         self_adaptive=True,
                         attention_mask=None,
                         num_attention_heads=1,
                         size_per_head=512,
                         query_act=None,
                         key_act=None,
                         value_act=None,
                         attention_probs_dropout_prob=0.0,
                         initializer_range=0.02,
                         do_return_2d_tensor=False,
                         batch_size=None,
                         from_seq_length=None,
                         to_seq_length=None):
    """Performs multi-headed attention from `from_tensor` to `to_tensor`.

    This is an implementation of multi-headed attention based on "Attention is
    all you Need". If `from_tensor` and `to_tensor` are the same, then this is
    self-attention. Each timestep in `from_tensor` attends to the corresponding
    sequence in `to_tensor`, and returns a fixed-width vector.

    This function first projects `from_tensor` into a "query" tensor and
    `to_tensor` into "key" and "value" tensors. These are (effectively) a list
    of tensors of length `num_attention_heads`, where each tensor is of shape
    [batch_size, seq_length, size_per_head].

    Then, the query and key tensors are dot-producted and scaled. These are
    softmaxed to obtain attention probabilities. The value tensors are then
    interpolated by these probabilities, then concatenated back to a single
    tensor and returned.

    In practice, the multi-headed attention is done with transposes and
    reshapes rather than actual separate tensors.

    Args:
      from_tensor: float Tensor of shape [batch_size, from_seq_length, from_width].
      to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
      self_adaptive: bool. Whether to apply the self-interactive re-weighting
        of the attention scores before scaling.
      attention_mask: (optional) int32 Tensor of shape [batch_size,
        from_seq_length, to_seq_length]. The values should be 1 or 0. The
        attention scores will effectively be set to -infinity for any positions
        in the mask that are 0, and will be unchanged for positions that are 1.
      num_attention_heads: int. Number of attention heads.
      size_per_head: int. Size of each attention head.
      query_act: (optional) Activation function for the query transform.
      key_act: (optional) Activation function for the key transform.
      value_act: (optional) Activation function for the value transform.
      attention_probs_dropout_prob: (optional) float. Dropout probability of
        the attention probabilities.
      initializer_range: float. Range of the weight initializer.
      do_return_2d_tensor: bool. If True, the output will be of shape
        [batch_size * from_seq_length, num_attention_heads * size_per_head].
        If False, the output will be of shape
        [batch_size, from_seq_length, num_attention_heads * size_per_head].
      batch_size: (Optional) int. If the input is 2D, this might be the batch
        size of the 3D version of the `from_tensor` and `to_tensor`.
      from_seq_length: (Optional) If the input is 2D, this might be the seq
        length of the 3D version of the `from_tensor`.
      to_seq_length: (Optional) If the input is 2D, this might be the seq
        length of the 3D version of the `to_tensor`.

    Returns:
      float Tensor of shape [batch_size, from_seq_length,
        num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
        true, this will be of shape [batch_size * from_seq_length,
        num_attention_heads * size_per_head]).

    Raises:
      ValueError: Any of the arguments or tensor shapes are invalid.
    """
""" def transpose_for_scores(input_tensor, batch_size, num_attention_heads, seq_length, width): output_tensor = tf.reshape( input_tensor, [batch_size, seq_length, num_attention_heads, width]) output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) return output_tensor from_shape = modeling.get_shape_list(from_tensor, expected_rank=[2, 3]) to_shape = modeling.get_shape_list(to_tensor, expected_rank=[2, 3]) if len(from_shape) != len(to_shape): raise ValueError( "The rank of `from_tensor` must match the rank of `to_tensor`.") if len(from_shape) == 3: batch_size = from_shape[0] from_seq_length = from_shape[1] to_seq_length = to_shape[1] elif len(from_shape) == 2: if (batch_size is None or from_seq_length is None or to_seq_length is None): raise ValueError( "When passing in rank 2 tensors to attention_layer, the values " "for `batch_size`, `from_seq_length`, and `to_seq_length` " "must all be specified.") # Scalar dimensions referenced here: # B = batch size (number of sequences) # F = `from_tensor` sequence length # T = `to_tensor` sequence length # N = `num_attention_heads` # H = `size_per_head` from_tensor_2d = modeling.reshape_to_matrix(from_tensor) to_tensor_2d = modeling.reshape_to_matrix(to_tensor) # `query_layer` = [B*F, N*H] raw_query_layer = tf.layers.dense( from_tensor_2d, num_attention_heads * size_per_head, activation=query_act, name="query", kernel_initializer=modeling.create_initializer(initializer_range)) # `key_layer` = [B*T, N*H] raw_key_layer = tf.layers.dense( to_tensor_2d, num_attention_heads * size_per_head, activation=key_act, name="key", kernel_initializer=modeling.create_initializer(initializer_range)) # `value_layer` = [B*T, N*H] raw_value_layer = tf.layers.dense( to_tensor_2d, num_attention_heads * size_per_head, activation=value_act, name="value", kernel_initializer=modeling.create_initializer(initializer_range)) # `query_layer` = [B, N, F, H] query_layer = transpose_for_scores(raw_query_layer, batch_size, num_attention_heads, from_seq_length, size_per_head) # `key_layer` = [B, N, T, H] key_layer = transpose_for_scores(raw_key_layer, batch_size, num_attention_heads, to_seq_length, size_per_head) # Take the dot product between "query" and "key" to get the raw # attention scores. # `attention_scores` = [B, N, F, T] attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) # self-interactive attention (alpha version) # [F, F_sm] x [F, T] x [T_sm, T] => [F, T] if self_adaptive: # `left_matrix` = [B, N, F, F_sm] left_matrix = tf.matmul(query_layer, query_layer, transpose_b=True) left_matrix = tf.nn.softmax(left_matrix) # `right_matrix` = [B, N, T_sm, T] right_matrix = tf.matmul(key_layer, key_layer, transpose_b=True) right_matrix = tf.nn.softmax(right_matrix) right_matrix = tf.transpose(right_matrix, [0, 1, 3, 2]) # `left_product` = [B, N, F, F_sm] x [B, N, F, T] left_product = tf.matmul(left_matrix, attention_scores) # `attention_scores` = [B, N, F, T] x [B, N, T_sm, T] attention_scores = tf.matmul(left_product, right_matrix) attention_scores = tf.multiply(attention_scores, 1.0 / math.sqrt(float(size_per_head))) if attention_mask is not None: # `attention_mask` = [B, 1, F, T] attention_mask = tf.expand_dims(attention_mask, axis=[1]) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. 
        adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        attention_scores += adder

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, N, F, T]
    attention_probs = tf.nn.softmax(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = modeling.dropout(attention_probs,
                                       attention_probs_dropout_prob)

    # `value_layer` = [B, T, N, H]
    value_layer = tf.reshape(
        raw_value_layer,
        [batch_size, to_seq_length, num_attention_heads, size_per_head])

    # `value_layer` = [B, N, T, H]
    value_layer = tf.transpose(value_layer, [0, 2, 1, 3])

    # `context_layer` = [B, N, F, H]
    context_layer = tf.matmul(attention_probs, value_layer)

    # `context_layer` = [B, F, N, H]
    context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

    if do_return_2d_tensor:
        # `context_layer` = [B*F, N*H]
        context_layer = tf.reshape(
            context_layer,
            [batch_size * from_seq_length, num_attention_heads * size_per_head])
    else:
        # `context_layer` = [B, F, N*H]
        context_layer = tf.reshape(
            context_layer,
            [batch_size, from_seq_length, num_attention_heads * size_per_head])

    return context_layer  # , raw_query_layer, raw_value_layer
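# Illustrative only: a minimal sketch of calling self_attention_layer directly on
# two encoder outputs; it relies on the module-level `tf` and `modeling` imports
# used above. The tensor names, head count, and dropout rate are assumptions
# chosen for the example.
def example_cross_attention(target_seq, target_ids, source_seq, source_mask,
                            hidden_size=768, num_heads=12):
    # target_seq: [B, F, hidden_size], source_seq: [B, T, hidden_size]
    attention_mask = modeling.create_attention_mask_from_input_mask(
        target_ids, source_mask)                      # [B, F, T], 1 = attend, 0 = masked
    return self_attention_layer(
        from_tensor=target_seq,
        to_tensor=source_seq,
        self_adaptive=True,                           # enable the self-interactive re-weighting
        attention_mask=attention_mask,
        num_attention_heads=num_heads,
        size_per_head=hidden_size // num_heads,
        attention_probs_dropout_prob=0.1,
        initializer_range=0.02,
        do_return_2d_tensor=False)                    # [B, F, hidden_size]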