def residual_mlp_layer(x_flat, intermediate_size, initializer_range=0.02, hidden_dropout_prob=0.1):
    """
    Feed-forward (MLP) sub-block: normalize, expand, project back, add the
    residual, and normalize again.

    Note from the original author: in the original GPT we would return
    layer_norm(x_norm + h1) rather than layer_norm(x + h1); here the residual
    is taken from the *un-normalized* input.

    :param x_flat: The attention output. It should be [batch_size*seq_length, dim]
    :param intermediate_size: width of the hidden projection (the original
                              note says this is usually input_dim * 4)
    :param initializer_range: forwarded to create_initializer for both dense kernels
    :param hidden_dropout_prob: forwarded to dropout(), applied to the projected
                                output before the residual add
    :return: layer_norm(x_flat + dropout(mlp(layer_norm(x_flat))))
    """
    # Only the trailing dimension is used; get_shape_list is also asked to
    # check that the input is rank 2.
    _, model_dim = get_shape_list(x_flat, expected_rank=2)

    normed_input = layer_norm(x_flat, name='mlp_ln0')

    # Expansion: dim -> intermediate_size with GELU.
    expand_init = create_initializer(initializer_range)
    expanded = tf.layers.dense(
        normed_input,
        intermediate_size,
        activation=gelu,
        kernel_initializer=expand_init,
        name='intermediate',
    )

    # Projection back to the model dimension (no activation).
    project_init = create_initializer(initializer_range)
    branch = tf.layers.dense(
        expanded,
        model_dim,
        name='output',
        kernel_initializer=project_init,
    )
    branch = dropout(branch, hidden_dropout_prob)

    # Residual uses the raw input, not the normalized one.
    return layer_norm(x_flat + branch, name='mlp_ln1')
def embed(input_ids,
          vocab_size,
          embedding_size,
          position_offset=0,
          initializer_range=0.02,
          max_position_embeddings=512,
          use_one_hot_embeddings=True):
    """Word and position embeddings.

    :param input_ids: int Tensor of shape [batch_size, seq_length].
    :param vocab_size: number of words in vocab
    :param embedding_size: dimensionality of the embedding
    :param position_offset: aka number of cached tokens. The positions used
                            are [position_offset, position_offset + seq_length).
    :param initializer_range: float. Range of the weight initialization.
    :param max_position_embeddings: int. Maximum sequence length, i.e. number
                                    of rows in the learned position table.
    :param use_one_hot_embeddings: probably want this to be true. Selects
                                   one-hot matmul instead of
                                   tf.nn.embedding_lookup for the word lookup.
    :return: tuple of
             - [batch_size, seq_length, embedding_size] embedded tensor
               (word + position embeddings, passed through layer_norm), and
             - the [vocab_size, embedding_size] word embedding table variable.
    """
    (batch_size, seq_length) = get_shape_list(input_ids, expected_rank=2)

    embedding_table = tf.get_variable(
        name='word_embed',
        shape=[vocab_size, embedding_size],
        initializer=create_initializer(initializer_range),
    )

    # Guard against ids that would index past the end of the word table.
    assert_op = tf.assert_less_equal(tf.reduce_max(input_ids), vocab_size - 1)
    with tf.control_dependencies([assert_op]):
        if use_one_hot_embeddings:
            flat_input_ids = tf.reshape(input_ids, [-1])
            one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
            output_flat = tf.matmul(one_hot_input_ids, embedding_table)
        else:
            output_flat = tf.nn.embedding_lookup(embedding_table, input_ids)

        embedded_input = tf.reshape(output_flat, [batch_size, seq_length, embedding_size])

    # The last position used is position_offset + seq_length - 1, so the bound
    # must include the offset. Checking seq_length alone lets positions run past
    # the table, and tf.one_hot then returns all-zero rows for them, which
    # silently yields a zero position embedding.
    assert_op = tf.assert_less_equal(
        seq_length + position_offset,
        max_position_embeddings,
        message='position_offset + seq_length exceeds max_position_embeddings',
    )
    with tf.control_dependencies([assert_op]):
        full_position_embeddings = tf.get_variable(
            name='pos_embed',
            shape=[max_position_embeddings, embedding_size],
            initializer=create_initializer(initializer_range),
        )
        # Since the position embedding table is a learned variable, we create it
        # using a (long) sequence length `max_position_embeddings`. The actual
        # sequence length might be shorter than this, for faster training of
        # tasks that do not have long sequences.
        #
        # So `full_position_embeddings` is effectively an embedding table
        # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
        # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
        # perform a slice.
        if position_offset == 0:
            embedded_input += tf.slice(full_position_embeddings, [0, 0], [seq_length, -1])[None]
        else:
            # The original author notes that a plain slice starting at the
            # offset did not work here, so the rows are selected with a
            # one-hot matmul instead.
            flat_pos_ids = (tf.range(seq_length, dtype=tf.int32) + position_offset)
            one_hot_pos_ids = tf.one_hot(flat_pos_ids, depth=max_position_embeddings)

            # [seq_length, max_position_embeddings] x [max_position_embeddings, dim]
            seq_embeds = tf.matmul(one_hot_pos_ids, full_position_embeddings)
            embedded_input += seq_embeds[None]

    return layer_norm(embedded_input, name='embed_norm'), embedding_table