def _project_and_split():
    """Project the encoder memory `mem` into per-head attention keys/values.

    NOTE(review): this appears to be a stray module-level copy of the closure
    defined inside `tf_decoder` below.  It reads `fuse_qkv`, `mem`,
    `decoder_args`, `b_init_range`, `k_init_range` and `data_type` as free
    variables, so at module level calling it raises NameError -- confirm
    whether this top-level copy is dead code that should be removed.

    Returns:
        (keys, values): tensors of shape
        [batch, head_num, mem_len, size_per_head] (rank fixed by the
        reshape/transpose at the end of this function).
    """
    if fuse_qkv == True:
        # Fused path: one conv1d produces K and V together, then split
        # along the channel axis (axis=2).
        keys, values = tf.split(tf.layers.conv1d(
            mem,
            decoder_args.hidden_dim * 2,
            1,
            bias_initializer=create_initializer(
                b_init_range, data_type),
            kernel_initializer=create_initializer(
                k_init_range, data_type)), 2, axis=2)
    else:
        # Unfused path: separate K and V projections.  Only the value conv
        # carries an explicit name; presumably this keeps the variable
        # names aligned with what the custom decoder op expects -- TODO
        # confirm against the op's weight-loading code.
        keys = tf.layers.conv1d(
            mem,
            decoder_args.hidden_dim,
            1,
            bias_initializer=create_initializer(
                b_init_range, data_type),
            kernel_initializer=create_initializer(
                k_init_range, data_type))
        values = tf.layers.conv1d(
            mem,
            decoder_args.hidden_dim,
            1,
            bias_initializer=create_initializer(
                b_init_range, data_type),
            kernel_initializer=create_initializer(
                k_init_range, data_type),
            name="value")
    # [batch, mem_len, hidden] -> [batch, head_num, mem_len, size_per_head]
    keys = tf.reshape(keys, [
        tf.shape(keys)[0],
        tf.shape(keys)[1], decoder_args.head_num,
        decoder_args.size_per_head
    ])
    keys = tf.transpose(keys, [0, 2, 1, 3])
    values = tf.reshape(values, [
        tf.shape(values)[0],
        tf.shape(values)[1], decoder_args.head_num,
        decoder_args.size_per_head
    ])
    values = tf.transpose(values, [0, 2, 1, 3])
    return keys, values
def tf_decoder(decoder_args,
               inputs,
               memory,
               memory_sequence_length,
               step,
               cache=None):
    '''Run one decoding step of the decoder transformer by TensorFlow.

    Each of the `decoder_args.num_layer` layers consists of masked
    self-attention (with an incremental key/value cache), encoder-decoder
    cross-attention (also cached), and a position-wise feed-forward network,
    each with a pre-layer-norm and a residual connection.

    Args:
        decoder_args: The arguments for decoder. The details are in the
            class "TransformerArgument" of common.py.
        inputs: A tf.Tensor with shape
            [batch_size * beam_width, 1, hidden_dimension].
            The input tensor of the current decoding step. The rank must be 3.
        memory: A tf.Tensor with shape
            [batch_size * beam_width, max(memory_sequence_length),
             encoder_hidden_dimension].
            The results of the encoder transformer layer. The rank must be 3.
            Note that it must be extended by beam_width times.
        memory_sequence_length: A tf.Tensor with shape
            [batch_size * beam_width], type tf.int.
            The length of each sentence of the results of encoder.
            Note that it must be extended by beam_width times.
        step: A tf.Tensor with tf.int type. The current step in the
            translation process.
        cache: A dict. The cache space to store the keys and values of the
            attention layers. NOTE: mutated in place on every call
            (self_keys/self_values grow by one position per step).

    Outputs:
        outputs: A tf.Tensor with shape
            [batch_size * beam_width, 1, hidden_dimension].
            The results of decoder.
    '''
    k_init_range = decoder_args.kernel_init_range
    b_init_range = decoder_args.bias_init_range
    data_type = decoder_args.dtype
    fuse_qkv = decoder_args.fuse_qkv
    hidden_dim = decoder_args.hidden_dim
    memory_mask = None  # has something

    # Normalize memory / memory_sequence_length into sequences so the code
    # below can uniformly support multiple memories, and build one padding
    # mask per memory.
    if memory is not None and not tf.contrib.framework.nest.is_sequence(memory):
        memory = (memory,)
    if memory_sequence_length is not None:
        if not tf.contrib.framework.nest.is_sequence(memory_sequence_length):
            memory_sequence_length = (memory_sequence_length,)
        memory_mask = [
            build_sequence_mask(
                length,
                num_heads=decoder_args.head_num,
                maximum_length=tf.shape(m)[1],
                data_type=data_type)
            for m, length in zip(memory, memory_sequence_length)]

    for l in range(decoder_args.num_layer):
        layer_name = "layer_{}".format(l)
        layer_cache = cache[layer_name] if cache is not None else None
        with tf.variable_scope(layer_name):
            with tf.variable_scope("masked_multi_head"):
                norm_inputs = norm(inputs)
                if fuse_qkv == True:
                    # Single conv producing Q, K, V; split on the channel axis.
                    queries, keys, values = tf.split(
                        tf.layers.conv1d(
                            norm_inputs,
                            decoder_args.hidden_dim * 3,
                            1,
                            bias_initializer=create_initializer(b_init_range, data_type),
                            kernel_initializer=create_initializer(k_init_range, data_type)),
                        3, axis=2)
                else:
                    # This branch avoids an additional tf.concat to merge the
                    # q, k, v kernels for the decoder op, because the concat
                    # brings a large overhead for small batch sizes.
                    queries = tf.layers.conv1d(
                        norm_inputs,
                        decoder_args.hidden_dim,
                        1,
                        bias_initializer=create_initializer(b_init_range, data_type),
                        kernel_initializer=create_initializer(k_init_range, data_type))
                    keys = tf.layers.conv1d(
                        norm_inputs,
                        decoder_args.hidden_dim,
                        1,
                        bias_initializer=create_initializer(b_init_range, data_type),
                        kernel_initializer=create_initializer(k_init_range, data_type),
                        name="key")
                    values = tf.layers.conv1d(
                        norm_inputs,
                        decoder_args.hidden_dim,
                        1,
                        bias_initializer=create_initializer(b_init_range, data_type),
                        kernel_initializer=create_initializer(k_init_range, data_type),
                        name="value")

                # [batch, 1, hidden] -> [batch, head_num, 1, size_per_head],
                # then append the new position to the cached K/V along time
                # (axis=2) and write the grown tensors back into the cache.
                keys = tf.reshape(keys, [tf.shape(keys)[0], 1,
                                         decoder_args.head_num,
                                         decoder_args.size_per_head])
                keys = tf.transpose(keys, [0, 2, 1, 3])
                values = tf.reshape(values, [tf.shape(values)[0], 1,
                                             decoder_args.head_num,
                                             decoder_args.size_per_head])
                values = tf.transpose(values, [0, 2, 1, 3])
                keys = tf.concat([layer_cache["self_keys"], keys], axis=2)
                values = tf.concat([layer_cache["self_values"], values], axis=2)
                layer_cache["self_keys"] = keys
                layer_cache["self_values"] = values

                queries = tf.reshape(queries, [tf.shape(queries)[0], 1,
                                               decoder_args.head_num,
                                               decoder_args.size_per_head])
                queries = tf.transpose(queries, [0, 2, 1, 3])
                # Scaled dot-product attention; softmax is computed in
                # data_type, then cast back to the score dtype.
                queries *= (decoder_args.size_per_head)**-0.5

                dot = tf.matmul(queries, keys, transpose_b=True)
                attn = tf.cast(tf.nn.softmax(tf.cast(dot, data_type)), dot.dtype)
                context = tf.matmul(attn, values)
                context = tf.transpose(context, [0, 2, 1, 3])
                context = tf.reshape(context, [tf.shape(context)[0], 1,
                                               decoder_args.head_num * decoder_args.size_per_head])
                outputs = tf.layers.conv1d(
                    context,
                    decoder_args.hidden_dim,
                    1,
                    bias_initializer=create_initializer(b_init_range, data_type),
                    kernel_initializer=create_initializer(k_init_range, data_type))

                # drop_and_add: residual connection (dims always match when
                # hidden_dim is consistent across layers).
                input_dim = inputs.get_shape().as_list()[-1]
                output_dim = outputs.get_shape().as_list()[-1]
                if input_dim == output_dim:
                    outputs += inputs
                last_context = outputs

            if memory is not None:
                # Encoder-decoder cross attention, one block per memory.
                for i, (mem, mask) in enumerate(zip(memory, memory_mask)):
                    memory_cache = layer_cache["memory"][i] if layer_cache is not None else None
                    with tf.variable_scope("multi_head" if i == 0 else "multi_head_%d" % i):
                        queries = tf.layers.conv1d(
                            norm(last_context),
                            decoder_args.hidden_dim,
                            1,
                            bias_initializer=create_initializer(b_init_range, data_type),
                            kernel_initializer=create_initializer(k_init_range, data_type))

                        def _project_and_split():
                            # Project the encoder memory into per-head K/V.
                            # Runs only on the first step (empty cache, see
                            # the tf.cond below); later steps reuse the cache.
                            if fuse_qkv == True:
                                keys, values = tf.split(
                                    tf.layers.conv1d(
                                        mem,
                                        decoder_args.hidden_dim * 2,
                                        1,
                                        bias_initializer=create_initializer(b_init_range, data_type),
                                        kernel_initializer=create_initializer(k_init_range, data_type)),
                                    2, axis=2)
                            else:
                                keys = tf.layers.conv1d(
                                    mem,
                                    decoder_args.hidden_dim,
                                    1,
                                    bias_initializer=create_initializer(b_init_range, data_type),
                                    kernel_initializer=create_initializer(k_init_range, data_type))
                                values = tf.layers.conv1d(
                                    mem,
                                    decoder_args.hidden_dim,
                                    1,
                                    bias_initializer=create_initializer(b_init_range, data_type),
                                    kernel_initializer=create_initializer(k_init_range, data_type),
                                    name="value")
                            # -> [batch, head_num, mem_len, size_per_head]
                            keys = tf.reshape(keys, [tf.shape(keys)[0], tf.shape(keys)[1],
                                                     decoder_args.head_num, decoder_args.size_per_head])
                            keys = tf.transpose(keys, [0, 2, 1, 3])
                            values = tf.reshape(values, [tf.shape(values)[0], tf.shape(values)[1],
                                                         decoder_args.head_num, decoder_args.size_per_head])
                            values = tf.transpose(values, [0, 2, 1, 3])
                            return keys, values

                        # Memory K/V are constant across steps: project them
                        # only when the cached time dimension is still 0.
                        keys, values = tf.cond(
                            tf.equal(
                                tf.shape(memory_cache["memory_keys"])[2], 0),
                            true_fn=_project_and_split,
                            false_fn=lambda: (memory_cache["memory_keys"],
                                              memory_cache["memory_values"]))
                        memory_cache["memory_keys"] = keys
                        memory_cache["memory_values"] = values

                        queries = tf.reshape(queries, [tf.shape(queries)[0], 1,
                                                       decoder_args.head_num,
                                                       decoder_args.size_per_head])
                        queries = tf.transpose(queries, [0, 2, 1, 3])
                        queries *= (decoder_args.size_per_head)**-0.5

                        dot = tf.matmul(queries, keys, transpose_b=True)
                        # Mask out padded memory positions with the most
                        # negative representable value before the softmax.
                        dot = tf.cast(tf.cast(dot, data_type) * mask +
                                      ((1.0 - mask) * data_type.min), dot.dtype)
                        attn = tf.cast(tf.nn.softmax(
                            tf.cast(dot, data_type)), dot.dtype)
                        context = tf.matmul(attn, values)
                        context = tf.transpose(context, [0, 2, 1, 3])
                        context = tf.reshape(context, [tf.shape(context)[0], 1,
                                                       decoder_args.head_num * decoder_args.size_per_head])
                        context = tf.layers.conv1d(
                            context,
                            decoder_args.hidden_dim,
                            1,
                            bias_initializer=create_initializer(b_init_range, data_type),
                            kernel_initializer=create_initializer(k_init_range, data_type))

                        # drop_and_add: residual from the self-attention output.
                        input_dim = last_context.get_shape().as_list()[-1]
                        output_dim = context.get_shape().as_list()[-1]
                        if input_dim == output_dim:
                            context += last_context

            with tf.variable_scope("ffn"):
                # forward: pre-norm FFN with 4x expansion and ReLU.
                normed_last_context = norm(context)
                input_dim = normed_last_context.get_shape().as_list()[-1]
                inner = tf.layers.conv1d(
                    normed_last_context,
                    decoder_args.hidden_dim * 4,
                    1,
                    activation=tf.nn.relu,
                    use_bias=True,
                    bias_initializer=create_initializer(b_init_range, data_type),
                    kernel_initializer=create_initializer(k_init_range, data_type))
                transformed = tf.layers.conv1d(
                    inner,
                    input_dim,
                    1,
                    use_bias=True,
                    bias_initializer=create_initializer(b_init_range, data_type),
                    kernel_initializer=create_initializer(k_init_range, data_type))

                # drop_and_add: residual from the cross-attention output.
                input_dim = context.get_shape().as_list()[-1]
                output_dim = transformed.get_shape().as_list()[-1]
                if input_dim == output_dim:
                    transformed += context
        # Feed this layer's output into the next layer.
        inputs = transformed
    outputs = inputs
    return outputs
def attention_layer(from_tensor,
                    to_tensor,
                    attention_mask=None,
                    num_attention_heads=1,
                    size_per_head=512,
                    query_act=None,
                    key_act=None,
                    value_act=None,
                    attention_probs_dropout_prob=0.0,
                    initializer_range=0.02,
                    do_return_2d_tensor=False,
                    batch_size=None,
                    from_seq_length=None,
                    to_seq_length=None,
                    tf_datatype=tf.float32):
    '''Multi-headed scaled dot-product attention (BERT style).

    Projects `from_tensor` to queries and `to_tensor` to keys/values with
    dense layers named "query"/"key"/"value", computes softmax(QK^T / sqrt(H))V
    per head, and returns the merged context.

    Args:
        from_tensor: float tensor of rank 2 ([B*F, W]) or 3 ([B, F, W]).
        to_tensor: float tensor of the same rank as `from_tensor`.
        attention_mask: optional mask of shape [B, F, T] (expanded to
            [B, 1, F, T]) or already [B, 1, F, T]; positions with value 1
            are attended to, positions with value 0 receive a -10000 bias.
        num_attention_heads: number of attention heads N.
        size_per_head: per-head width H.
        query_act / key_act / value_act: optional activations for the
            projection layers.
        attention_probs_dropout_prob: accepted for interface compatibility;
            NOTE: dropout is not applied in this implementation.
        initializer_range: init range for the projection weights.
        do_return_2d_tensor: if True return [B*F, N*H], else [B, F, N*H].
        batch_size, from_seq_length, to_seq_length: required only when the
            input tensors are rank 2.
        tf_datatype: compute dtype for the mask bias.

    Returns:
        The context tensor, [B*F, N*H] or [B, F, N*H] (see above).

    Raises:
        ValueError: if the ranks of the two inputs differ, or rank-2 inputs
            are given without the explicit batch/length arguments.
    '''
    def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                             seq_length, width):
        # [B*S, N*H] -> [B, N, S, H]
        output_tensor = tf.reshape(
            input_tensor, [batch_size, seq_length, num_attention_heads, width])
        output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
        return output_tensor

    from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
    to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

    if len(from_shape) != len(to_shape):
        raise ValueError(
            "The rank of `from_tensor` must match the rank of `to_tensor`.")

    if len(from_shape) == 3:
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]
        to_seq_length = to_shape[1]
    elif len(from_shape) == 2:
        if (batch_size is None or from_seq_length is None or
                to_seq_length is None):
            raise ValueError(
                "When passing in rank 2 tensors to attention_layer, the values "
                "for `batch_size`, `from_seq_length`, and `to_seq_length` "
                "must all be specified.")

    from_tensor_2d = reshape_to_matrix(from_tensor)
    to_tensor_2d = reshape_to_matrix(to_tensor)

    # `query_layer` = [B*F, N*H]
    query_layer = tf.layers.dense(
        from_tensor_2d,
        num_attention_heads * size_per_head,
        activation=query_act,
        name="query",
        use_bias=True,
        bias_initializer=create_initializer(initializer_range, tf_datatype),
        kernel_initializer=create_initializer(initializer_range, tf_datatype))

    # `key_layer` = [B*T, N*H]
    key_layer = tf.layers.dense(
        to_tensor_2d,
        num_attention_heads * size_per_head,
        activation=key_act,
        name="key",
        use_bias=True,
        bias_initializer=create_initializer(initializer_range, tf_datatype),
        kernel_initializer=create_initializer(initializer_range, tf_datatype))

    # `value_layer` = [B*T, N*H]
    value_layer = tf.layers.dense(
        to_tensor_2d,
        num_attention_heads * size_per_head,
        activation=value_act,
        name="value",
        use_bias=True,
        bias_initializer=create_initializer(initializer_range, tf_datatype),
        kernel_initializer=create_initializer(initializer_range, tf_datatype))

    # `query_layer` = [B, N, F, H]
    query_layer = transpose_for_scores(query_layer, batch_size,
                                       num_attention_heads, from_seq_length,
                                       size_per_head)

    # `key_layer` = [B, N, T, H]
    key_layer = transpose_for_scores(key_layer, batch_size,
                                     num_attention_heads, to_seq_length,
                                     size_per_head)

    attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(size_per_head)))

    if attention_mask is not None:
        # BUGFIX: the original code used `if tf.rank(attention_mask) == 3:`.
        # `tf.rank` returns a Tensor, and in TF1 comparing a Tensor to a
        # Python int with `==` falls back to identity comparison, so the
        # branch could never fire and a rank-3 mask was never expanded.
        # Use the statically-known rank instead.
        if attention_mask.shape.ndims == 3:
            # `attention_mask` = [B, 1, F, T]
            attention_mask = tf.expand_dims(attention_mask, axis=[1])

        # 1.0 for positions to attend to, 0.0 for masked positions: turn the
        # mask into an additive bias of 0.0 / -10000.0 before the softmax.
        adder = (1.0 - tf.cast(attention_mask, tf_datatype)) * -10000.0
        attention_scores += adder

    attention_probs = tf.nn.softmax(attention_scores)

    # `value_layer` = [B, N, T, H]
    value_layer = tf.reshape(
        value_layer,
        [batch_size, to_seq_length, num_attention_heads, size_per_head])
    value_layer = tf.transpose(value_layer, [0, 2, 1, 3])

    context_layer = tf.matmul(attention_probs, value_layer)
    # [B, N, F, H] -> [B, F, N, H]
    context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

    if do_return_2d_tensor:
        context_layer = tf.reshape(
            context_layer,
            [batch_size * from_seq_length, num_attention_heads * size_per_head])
    else:
        context_layer = tf.reshape(
            context_layer,
            [batch_size, from_seq_length, num_attention_heads * size_per_head])

    return context_layer
def tf_encoder(input_tensor,
               encoder_args,
               attention_mask=None,
               intermediate_act_fn=gelu,
               initializer_range=0.02):
    '''Run the BERT transformer layers by TensorFlow.

    Args:
        input_tensor: A tf.Tensor with shape
            [batch_size, seq_len, hidden_dimension].
            The input tensor of the encoder. The rank must be 3.
            (The original docstring called this parameter "inputs"; the
            actual parameter name is `input_tensor`.)
        encoder_args: The arguments for encoder. The details are in the
            class "TransformerArgument" of common.py.
        attention_mask: A tf.Tensor. The attention mask for self attention.
        intermediate_act_fn: A callable function. The activation function
            in the FFN. It is gelu in BERT.
        initializer_range: A float value. The range of initializer for all
            weights.

    Outputs:
        outputs: A tf.Tensor with shape
            [batch_size, seq_len, hidden_dimension].
            The results of encoder.

    Raises:
        ValueError: if hidden_dim is not a multiple of head_num.
    '''
    intermediate_size = encoder_args.hidden_dim * 4

    if encoder_args.hidden_dim % encoder_args.head_num != 0:
        raise ValueError(
            "The hidden size (%d) is not a multiple of the number of attention "
            "heads (%d)" % (encoder_args.hidden_dim, encoder_args.head_num))
    # (Removed unused local `attention_head_size`; per-head width is taken
    # from encoder_args.size_per_head below.)

    input_shape = get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    # All layers operate on a 2D view [batch*seq, hidden].
    prev_output = reshape_to_matrix(input_tensor)

    for layer_idx in range(encoder_args.num_layer):
        with tf.variable_scope("layer_%d" % layer_idx, reuse=tf.AUTO_REUSE):
            layer_input = prev_output

            with tf.variable_scope("attention"):
                with tf.variable_scope("self"):
                    attention_head = attention_layer(
                        from_tensor=layer_input,
                        to_tensor=layer_input,
                        attention_mask=attention_mask,
                        num_attention_heads=encoder_args.head_num,
                        size_per_head=encoder_args.size_per_head,
                        initializer_range=initializer_range,
                        do_return_2d_tensor=True,
                        batch_size=batch_size,
                        from_seq_length=seq_length,
                        to_seq_length=seq_length,
                        tf_datatype=encoder_args.dtype)
                    attention_output = attention_head

                # Project back to hidden_dim, then post-layer-norm residual.
                with tf.variable_scope("output"):
                    attention_output = tf.layers.dense(
                        attention_output,
                        encoder_args.hidden_dim,
                        use_bias=True,
                        bias_initializer=create_initializer(
                            initializer_range, encoder_args.dtype),
                        kernel_initializer=create_initializer(
                            initializer_range, encoder_args.dtype))
                    attention_output = layer_norm(attention_output + layer_input)

            # The activation is only applied to the "intermediate" hidden layer.
            with tf.variable_scope("intermediate"):
                intermediate_output = tf.layers.dense(
                    attention_output,
                    intermediate_size,
                    activation=intermediate_act_fn,
                    use_bias=True,
                    bias_initializer=create_initializer(
                        initializer_range, encoder_args.dtype),
                    kernel_initializer=create_initializer(
                        initializer_range, encoder_args.dtype))

            # Down-project back to `hidden_size` then add the residual.
            with tf.variable_scope("output"):
                layer_output = tf.layers.dense(
                    intermediate_output,
                    encoder_args.hidden_dim,
                    use_bias=True,
                    bias_initializer=create_initializer(
                        initializer_range, encoder_args.dtype),
                    kernel_initializer=create_initializer(
                        initializer_range, encoder_args.dtype))
                layer_output = layer_norm(layer_output + attention_output)
            prev_output = layer_output

    # Restore the original rank-3 shape.
    prev_output = tf.reshape(prev_output, shape=tf.shape(input_tensor))
    return prev_output
def tf_decoder(decoder_args,
               inputs,
               memory,
               memory_sequence_length,
               step,
               cache=None):
    '''Run one decoding step of a GPT-2 style (decoder-only) transformer.

    NOTE(review): this redefinition shadows the earlier `tf_decoder` in this
    file -- only this GPT-2 variant is visible after import; confirm that is
    intended.  Compared to the encoder-decoder variant it always uses the
    fused QKV projection, has no cross attention, uses GELU in the FFN, and
    takes the FFN input/residual from the self-attention output.

    Args:
        decoder_args: The arguments for decoder. The details are in the
            class "TransformerArgument" of common.py.
        inputs: A tf.Tensor with shape
            [batch_size * beam_width, 1, hidden_dimension].
            The input tensor of the current decoding step. The rank must be 3.
        memory: Encoder output (unused by the GPT-2 path beyond mask setup;
            cross attention is intentionally omitted, see below).
        memory_sequence_length: A tf.Tensor with shape
            [batch_size * beam_width], type tf.int.
            The length of each sentence of the encoder results.
        step: A tf.Tensor with tf.int type. The current step in the
            translation process.
        cache: A dict. The cache space to store the keys and values of the
            attention layers. NOTE: mutated in place on every call.

    Outputs:
        outputs: A tf.Tensor with shape
            [batch_size * beam_width, 1, hidden_dimension].
            The results of decoder.
    '''
    k_init_range = decoder_args.kernel_init_range
    b_init_range = decoder_args.bias_init_range
    data_type = decoder_args.dtype
    memory_mask = None  # has something

    if memory is not None and not tf.contrib.framework.nest.is_sequence(
            memory):
        memory = (memory, )
    if memory_sequence_length is not None:
        if not tf.contrib.framework.nest.is_sequence(
                memory_sequence_length):
            memory_sequence_length = (memory_sequence_length, )
        memory_mask = [
            build_sequence_mask(length,
                                num_heads=decoder_args.head_num,
                                maximum_length=tf.shape(m)[1],
                                data_type=data_type)
            for m, length in zip(memory, memory_sequence_length)
        ]

    for l in range(decoder_args.num_layer):
        layer_name = "layer_{}".format(l)
        layer_cache = cache[layer_name] if cache is not None else None
        with tf.variable_scope(layer_name):
            with tf.variable_scope("masked_multi_head"):
                norm_inputs = norm(inputs)
                # Fused QKV: one conv produces Q, K, V, split on channels.
                queries, keys, values = tf.split(tf.layers.conv1d(
                    norm_inputs,
                    decoder_args.hidden_dim * 3,
                    1,
                    bias_initializer=create_initializer(
                        b_init_range, data_type),
                    kernel_initializer=create_initializer(
                        k_init_range, data_type)), 3, axis=2)

                # [batch, 1, hidden] -> [batch, head_num, 1, size_per_head],
                # append to cached K/V along time (axis=2), write back.
                keys = tf.reshape(keys, [
                    tf.shape(keys)[0], 1, decoder_args.head_num,
                    decoder_args.size_per_head
                ])
                keys = tf.transpose(keys, [0, 2, 1, 3])
                values = tf.reshape(values, [
                    tf.shape(values)[0], 1, decoder_args.head_num,
                    decoder_args.size_per_head
                ])
                values = tf.transpose(values, [0, 2, 1, 3])
                keys = tf.concat([layer_cache["self_keys"], keys], axis=2)
                values = tf.concat([layer_cache["self_values"], values],
                                   axis=2)
                layer_cache["self_keys"] = keys
                layer_cache["self_values"] = values

                queries = tf.reshape(queries, [
                    tf.shape(queries)[0], 1, decoder_args.head_num,
                    decoder_args.size_per_head
                ])
                queries = tf.transpose(queries, [0, 2, 1, 3])
                queries *= (decoder_args.size_per_head)**-0.5

                # Scaled dot-product attention; softmax in data_type, cast
                # back to the score dtype.
                dot = tf.matmul(queries, keys, transpose_b=True)
                attn = tf.cast(tf.nn.softmax(tf.cast(dot, data_type)),
                               dot.dtype)
                context = tf.matmul(attn, values)
                context = tf.transpose(context, [0, 2, 1, 3])
                context = tf.reshape(context, [
                    tf.shape(context)[0], 1,
                    decoder_args.head_num * decoder_args.size_per_head
                ])
                outputs = tf.layers.conv1d(
                    context,
                    decoder_args.hidden_dim,
                    1,
                    bias_initializer=create_initializer(
                        b_init_range, data_type),
                    kernel_initializer=create_initializer(
                        k_init_range, data_type))

                # drop_and_add: residual connection.
                input_dim = inputs.get_shape().as_list()[-1]
                output_dim = outputs.get_shape().as_list()[-1]
                if input_dim == output_dim:
                    outputs += inputs
                last_context = outputs

            # For GPT-2, we do not need cross attention: the encoder-decoder
            # "multi_head" block (memory projection + cached K/V via tf.cond
            # + masked attention over the encoder memory) that the
            # encoder-decoder variant of this function builds at this point
            # is intentionally omitted here.

            with tf.variable_scope("ffn"):
                # forward
                # For GPT-2, take the inputs from the self attention
                # (not from a cross-attention output).
                normed_last_context = norm(last_context)
                input_dim = normed_last_context.get_shape().as_list()[-1]
                # GPT-2 uses GELU (the encoder-decoder variant uses ReLU
                # via the conv1d `activation` argument instead).
                inner = gelu(
                    tf.layers.conv1d(normed_last_context,
                                     decoder_args.hidden_dim * 4,
                                     1,
                                     use_bias=True,
                                     bias_initializer=create_initializer(
                                         b_init_range, data_type),
                                     kernel_initializer=create_initializer(
                                         k_init_range, data_type)))
                transformed = tf.layers.conv1d(
                    inner,
                    input_dim,
                    1,
                    use_bias=True,
                    bias_initializer=create_initializer(
                        b_init_range, data_type),
                    kernel_initializer=create_initializer(
                        k_init_range, data_type))

                # drop_and_add
                input_dim = context.get_shape().as_list()[-1]
                output_dim = transformed.get_shape().as_list()[-1]
                if input_dim == output_dim:
                    # For GPT-2, residual connection comes from self attention
                    transformed += last_context
        # Feed this layer's output into the next layer.
        inputs = transformed
    outputs = inputs
    return outputs