def _project_and_split():
    """Project the memory tensor into per-head keys and values.

    This is a nested-function fragment: it reads ``mem``, ``fuse_qkv``,
    ``decoder_args``, ``b_init_range``, ``k_init_range`` and ``data_type``
    from an enclosing scope that is not part of this snippet.

    Returns:
        A ``(keys, values)`` pair. Each projection is reshaped to
        ``[d0, d1, head_num, size_per_head]`` (d0/d1 taken dynamically from
        the projection) and transposed to ``[d0, head_num, d1, size_per_head]``.
    """

    def _position_wise(width, **extra):
        # Kernel-size-1 conv1d over `mem` with the shared initializers.
        return tf.layers.conv1d(
            mem,
            width,
            1,
            bias_initializer=create_initializer(b_init_range, data_type),
            kernel_initializer=create_initializer(k_init_range, data_type),
            **extra)

    def _to_heads(tensor):
        # Split the last axis into heads and move the head axis to position 1.
        per_head = tf.reshape(tensor, [
            tf.shape(tensor)[0],
            tf.shape(tensor)[1],
            decoder_args.head_num,
            decoder_args.size_per_head,
        ])
        return tf.transpose(per_head, [0, 2, 1, 3])

    use_fused_projection = fuse_qkv == True
    if use_fused_projection:
        # One projection of twice the width, split evenly into keys and values.
        projected = _position_wise(decoder_args.hidden_dim * 2)
        k_proj, v_proj = tf.split(projected, 2, axis=2)
    else:
        # Separate projections; only the value layer carries an explicit name.
        k_proj = _position_wise(decoder_args.hidden_dim)
        v_proj = _position_wise(decoder_args.hidden_dim, name="value")

    return _to_heads(k_proj), _to_heads(v_proj)
# Example #2 (rating: 0)
def tf_decoder(decoder_args,
               inputs,
               memory,
               memory_sequence_length,
               step,
               cache=None):
    '''
    Run the decoder transformer layers by TensorFlow for a single decoding step.

    Each layer applies pre-norm masked self attention, then pre-norm cross
    attention over every entry of `memory` in order (each one consuming the
    previous sub-layer's output), then a pre-norm ReLU feed-forward network.
    Every sub-layer adds a residual connection when the dimensions match.

    Args:
        decoder_args: The arguments for decoder. The details are in the class "TransformerArgument" of common.py
        inputs: A tf.Tensor with shape [batch_size * beam_width, 1, hidden_dimension].
                The input tensor of the decoder for the current step. The rank must be 3.
        memory: A tf.Tensor with shape [batch_size * beam_width, max(memory_sequence_length), encoder_hidden_dimension],
                a sequence of such tensors, or None to skip cross attention.
                The results of encoder transformer layer. The rank must be 3.
                Note that it must be extended by beam_width times.
        memory_sequence_length: A tf.Tensor with shape [batch_size * beam_width], type tf.int,
                                a sequence of such tensors (one per memory), or None to
                                attend to every memory position without masking.
                                The length of each sentence of results of encoder.
                                Note that it must be extended by beam_width times.
        step: A tf.Tensor with tf.int type. The current step in the translation process.
              Not read by this implementation; kept for interface compatibility.
        cache: A dict keyed by "layer_{i}", or None. Each layer entry holds
               "self_keys"/"self_values" (extended with this step's keys/values
               along axis 2 and written back) and a "memory" list whose i-th
               dict holds "memory_keys"/"memory_values" (projected from the
               memory when their axis 2 is empty, otherwise reused). When None,
               no past keys/values are used and the memory is projected on
               every call.

    Outputs:
        outputs: A tf.Tensor with shape [batch_size * beam_width, 1, hidden_dimension].
                 The results of decoder.
    '''

    k_init_range = decoder_args.kernel_init_range
    b_init_range = decoder_args.bias_init_range
    data_type = decoder_args.dtype
    fuse_qkv = decoder_args.fuse_qkv
    head_num = decoder_args.head_num
    size_per_head = decoder_args.size_per_head

    def _conv1d(x, units, **kwargs):
        # Position-wise (kernel width 1) projection with the decoder's initializers.
        return tf.layers.conv1d(x, units, 1,
                                bias_initializer=create_initializer(b_init_range, data_type),
                                kernel_initializer=create_initializer(k_init_range, data_type),
                                **kwargs)

    def _split_heads(x, length):
        # [batch, length, head_num * size_per_head] -> [batch, head_num, length, size_per_head]
        x = tf.reshape(x, [tf.shape(x)[0], length, head_num, size_per_head])
        return tf.transpose(x, [0, 2, 1, 3])

    def _merge_heads(x):
        # [batch, head_num, 1, size_per_head] -> [batch, 1, head_num * size_per_head]
        x = tf.transpose(x, [0, 2, 1, 3])
        return tf.reshape(x, [tf.shape(x)[0], 1, head_num * size_per_head])

    def _attend(queries, keys, values, mask=None):
        # Scaled dot-product attention. Masked-out positions (mask == 0) are
        # pushed to data_type.min so the softmax gives them ~zero weight.
        queries = queries * (size_per_head ** -0.5)
        dot = tf.matmul(queries, keys, transpose_b=True)
        if mask is not None:
            dot = tf.cast(tf.cast(dot, data_type) * mask +
                          ((1.0 - mask) * data_type.min), dot.dtype)
        attn = tf.cast(tf.nn.softmax(tf.cast(dot, data_type)), dot.dtype)
        return tf.matmul(attn, values)

    def _project_and_split(mem):
        # Project one memory tensor into keys/values and split them into heads.
        if fuse_qkv:
            keys, values = tf.split(_conv1d(mem, decoder_args.hidden_dim * 2), 2, axis=2)
        else:
            keys = _conv1d(mem, decoder_args.hidden_dim)
            values = _conv1d(mem, decoder_args.hidden_dim, name="value")
        keys = _split_heads(keys, tf.shape(keys)[1])
        values = _split_heads(values, tf.shape(values)[1])
        return keys, values

    # Normalize memory / lengths to sequences and build one mask per memory.
    # A memory without lengths gets a None mask (no masking) instead of crashing
    # in zip(); an already-sequence memory now also gets its masks built.
    memory_mask = None
    if memory is not None:
        if not tf.contrib.framework.nest.is_sequence(memory):
            memory = (memory,)
        if memory_sequence_length is not None:
            if not tf.contrib.framework.nest.is_sequence(memory_sequence_length):
                memory_sequence_length = (memory_sequence_length,)
            memory_mask = [
                build_sequence_mask(
                    length, num_heads=head_num, maximum_length=tf.shape(m)[1], data_type=data_type)
                for m, length in zip(memory, memory_sequence_length)]
        else:
            memory_mask = [None] * len(memory)

    for l in range(decoder_args.num_layer):
        layer_name = "layer_{}".format(l)
        layer_cache = cache[layer_name] if cache is not None else None

        with tf.variable_scope(layer_name):
            with tf.variable_scope("masked_multi_head"):
                norm_inputs = norm(inputs)
                if fuse_qkv:
                    queries, keys, values = tf.split(
                        _conv1d(norm_inputs, decoder_args.hidden_dim * 3), 3, axis=2)
                else:
                    # Separate q/k/v projections avoid an additional tf.concat of the
                    # q, k, v kernels in the decoder op, because that concat brings a
                    # large overhead for small batch sizes.
                    queries = _conv1d(norm_inputs, decoder_args.hidden_dim)
                    keys = _conv1d(norm_inputs, decoder_args.hidden_dim, name="key")
                    values = _conv1d(norm_inputs, decoder_args.hidden_dim, name="value")

                keys = _split_heads(keys, 1)
                values = _split_heads(values, 1)
                if layer_cache is not None:
                    # Append this step's keys/values to the cached ones (axis 2 = length).
                    keys = tf.concat([layer_cache["self_keys"], keys], axis=2)
                    values = tf.concat([layer_cache["self_values"], values], axis=2)
                    layer_cache["self_keys"] = keys
                    layer_cache["self_values"] = values

                context = _merge_heads(_attend(_split_heads(queries, 1), keys, values))
                outputs = _conv1d(context, decoder_args.hidden_dim)

                # drop_and_add
                if inputs.get_shape().as_list()[-1] == outputs.get_shape().as_list()[-1]:
                    outputs += inputs
                last_context = outputs

            if memory is not None:
                for i, (mem, mask) in enumerate(zip(memory, memory_mask)):
                    memory_cache = layer_cache["memory"][i] if layer_cache is not None else None

                    with tf.variable_scope("multi_head" if i == 0 else "multi_head_%d" % i):
                        # The query projection must be created before the key/value
                        # projections to keep the default layer names stable.
                        queries = _conv1d(norm(last_context), decoder_args.hidden_dim)

                        if memory_cache is None:
                            keys, values = _project_and_split(mem)
                        else:
                            # Project the memory only once: on the first step the cached
                            # keys are empty along axis 2, afterwards they are reused.
                            keys, values = tf.cond(
                                tf.equal(tf.shape(memory_cache["memory_keys"])[2], 0),
                                true_fn=lambda: _project_and_split(mem),
                                false_fn=lambda: (memory_cache["memory_keys"],
                                                  memory_cache["memory_values"]))
                            memory_cache["memory_keys"] = keys
                            memory_cache["memory_values"] = values

                        context = _merge_heads(_attend(_split_heads(queries, 1), keys, values, mask))
                        context = _conv1d(context, decoder_args.hidden_dim)

                        # drop_and_add
                        if last_context.get_shape().as_list()[-1] == context.get_shape().as_list()[-1]:
                            context += last_context
                        # Feed this sub-layer's output to the next sub-layer.
                        last_context = context

            with tf.variable_scope("ffn"):
                # forward
                normed_last_context = norm(last_context)
                input_dim = normed_last_context.get_shape().as_list()[-1]
                inner = _conv1d(normed_last_context, decoder_args.hidden_dim * 4,
                                activation=tf.nn.relu, use_bias=True)
                transformed = _conv1d(inner, input_dim, use_bias=True)

                # drop_and_add
                if last_context.get_shape().as_list()[-1] == transformed.get_shape().as_list()[-1]:
                    transformed += last_context
        inputs = transformed
    outputs = inputs
    return outputs
# Example #3 (rating: 0)
def attention_layer(from_tensor,
                    to_tensor,
                    attention_mask=None,
                    num_attention_heads=1,
                    size_per_head=512,
                    query_act=None,
                    key_act=None,
                    value_act=None,
                    attention_probs_dropout_prob=0.0,
                    initializer_range=0.02,
                    do_return_2d_tensor=False,
                    batch_size=None,
                    from_seq_length=None,
                    to_seq_length=None,
                    tf_datatype=tf.float32):
    '''
    Multi-head scaled dot-product attention from `from_tensor` to `to_tensor`.

    Notation: B = batch, F = from length, T = to length, N = heads,
    H = size_per_head.

    Args:
        from_tensor: Query-side tensor of rank 3 ([B, F, width]) or rank 2
            ([B*F, width], in which case the sizes below must be given).
        to_tensor: Key/value-side tensor; must have the same rank as `from_tensor`.
        attention_mask: Optional mask of shape [B, F, T] or [B, 1, F, T]. Positions
            where it is 0 get -10000.0 added to their scores before the softmax.
        num_attention_heads: N.
        size_per_head: H.
        query_act, key_act, value_act: Optional activations for the projections.
        attention_probs_dropout_prob: Accepted for API compatibility; dropout is
            not applied by this implementation.
        initializer_range: Range passed to `create_initializer` for all weights.
        do_return_2d_tensor: Return [B*F, N*H] instead of [B, F, N*H].
        batch_size, from_seq_length, to_seq_length: Required for rank-2 inputs;
            overridden from the static shape for rank-3 inputs.
        tf_datatype: dtype used for the initializers and the mask adder.

    Returns:
        [B*F, N*H] if `do_return_2d_tensor`, otherwise [B, F, N*H].

    Raises:
        ValueError: if the input ranks differ, or rank-2 inputs are passed
            without `batch_size`, `from_seq_length` and `to_seq_length`.
    '''

    def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                             seq_length, width):
        # [B*S, N*H] -> [B, N, S, H]
        output_tensor = tf.reshape(
            input_tensor, [batch_size, seq_length, num_attention_heads, width])

        output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
        return output_tensor

    from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
    to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

    if len(from_shape) != len(to_shape):
        raise ValueError(
            "The rank of `from_tensor` must match the rank of `to_tensor`.")

    if len(from_shape) == 3:
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]
        to_seq_length = to_shape[1]
    elif len(from_shape) == 2:
        if (batch_size is None or from_seq_length is None or to_seq_length is None):
            raise ValueError(
                "When passing in rank 2 tensors to attention_layer, the values "
                "for `batch_size`, `from_seq_length`, and `to_seq_length` "
                "must all be specified.")

    from_tensor_2d = reshape_to_matrix(from_tensor)
    to_tensor_2d = reshape_to_matrix(to_tensor)

    # `query_layer` = [B*F, N*H]
    query_layer = tf.layers.dense(
        from_tensor_2d,
        num_attention_heads * size_per_head,
        activation=query_act,
        name="query",
        use_bias=True,
        bias_initializer=create_initializer(initializer_range, tf_datatype),
        kernel_initializer=create_initializer(initializer_range, tf_datatype))

    # `key_layer` = [B*T, N*H]
    key_layer = tf.layers.dense(
        to_tensor_2d,
        num_attention_heads * size_per_head,
        activation=key_act,
        name="key",
        use_bias=True,
        bias_initializer=create_initializer(initializer_range, tf_datatype),
        kernel_initializer=create_initializer(initializer_range, tf_datatype))

    # `value_layer` = [B*T, N*H]
    value_layer = tf.layers.dense(
        to_tensor_2d,
        num_attention_heads * size_per_head,
        activation=value_act,
        name="value",
        use_bias=True,
        bias_initializer=create_initializer(initializer_range, tf_datatype),
        kernel_initializer=create_initializer(initializer_range, tf_datatype))

    # `query_layer` = [B, N, F, H]
    query_layer = transpose_for_scores(query_layer, batch_size,
                                       num_attention_heads, from_seq_length,
                                       size_per_head)

    # `key_layer` = [B, N, T, H]
    key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
                                     to_seq_length, size_per_head)

    # `attention_scores` = [B, N, F, T]
    attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(size_per_head)))

    if attention_mask is not None:
        attention_mask = tf.convert_to_tensor(attention_mask)
        # Promote a [B, F, T] mask to [B, 1, F, T] so it broadcasts over heads.
        # The static rank is used: `tf.rank(...)` returns a Tensor, and comparing
        # it with a Python int does not give a usable Python bool in graph mode
        # (the old check never fired, mis-broadcasting B against N).
        if attention_mask.shape.ndims == 3:
            attention_mask = tf.expand_dims(attention_mask, axis=1)

        adder = (1.0 - tf.cast(attention_mask, tf_datatype)) * -10000.0

        attention_scores += adder

    attention_probs = tf.nn.softmax(attention_scores)

    # `value_layer` = [B, N, T, H]
    value_layer = tf.reshape(
        value_layer,
        [batch_size, to_seq_length, num_attention_heads, size_per_head])

    value_layer = tf.transpose(value_layer, [0, 2, 1, 3])

    # `context_layer` = [B, N, F, H] -> [B, F, N, H]
    context_layer = tf.matmul(attention_probs, value_layer)

    context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

    if do_return_2d_tensor:
        context_layer = tf.reshape(
            context_layer,
            [batch_size * from_seq_length, num_attention_heads * size_per_head])
    else:
        context_layer = tf.reshape(
            context_layer,
            [batch_size, from_seq_length, num_attention_heads * size_per_head])

    return context_layer
# Example #4 (rating: 0)
def tf_encoder(input_tensor,
               encoder_args,
               attention_mask=None,
               intermediate_act_fn=gelu,
               initializer_range=0.02):
    '''
    Run the BERT-style transformer encoder layers with TensorFlow.

    Each layer is post-norm: self attention, a dense output projection with a
    residual and layer_norm, then a 4x-wide intermediate dense layer and a
    dense down-projection with a residual and layer_norm.

    Args:
        input_tensor: A tf.Tensor with shape [batch_size, seq_len, hidden_dimension].
                      The input of the encoder. The rank must be 3.
        encoder_args: The arguments for encoder. The details are in the class
                      "TransformerArgument" of common.py
        attention_mask: A tf.Tensor. The attention mask for self attention,
                        forwarded unchanged to `attention_layer` in every layer.
        intermediate_act_fn: A callable. The activation of the intermediate
                             (FFN) layer. It is gelu in BERT.
        initializer_range: A float. The range of the initializer for all weights.

    Raises:
        ValueError: if `encoder_args.hidden_dim` is not a multiple of
                    `encoder_args.head_num`.

    Outputs:
        outputs: A tf.Tensor reshaped back to the dynamic shape of `input_tensor`.
                 The results of encoder.
    '''
    hidden_dim = encoder_args.hidden_dim
    head_num = encoder_args.head_num
    dtype = encoder_args.dtype

    if hidden_dim % head_num != 0:
        raise ValueError(
            "The hidden size (%d) is not a multiple of the number of attention "
            "heads (%d)" % (hidden_dim, head_num))

    shape = get_shape_list(input_tensor, expected_rank=3)
    batch_size, seq_length = shape[0], shape[1]

    def _dense(x, units, activation=None):
        # Fully connected layer with bias, using the encoder's initializers.
        return tf.layers.dense(
            x,
            units,
            activation=activation,
            use_bias=True,
            bias_initializer=create_initializer(initializer_range, dtype),
            kernel_initializer=create_initializer(initializer_range, dtype))

    # All layers operate on a flattened [batch * seq, hidden] matrix.
    hidden = reshape_to_matrix(input_tensor)

    for layer_idx in range(encoder_args.num_layer):
        with tf.variable_scope("layer_%d" % layer_idx, reuse=tf.AUTO_REUSE):
            residual = hidden

            with tf.variable_scope("attention"):
                with tf.variable_scope("self"):
                    attended = attention_layer(
                        from_tensor=residual,
                        to_tensor=residual,
                        attention_mask=attention_mask,
                        num_attention_heads=head_num,
                        size_per_head=encoder_args.size_per_head,
                        initializer_range=initializer_range,
                        do_return_2d_tensor=True,
                        batch_size=batch_size,
                        from_seq_length=seq_length,
                        to_seq_length=seq_length,
                        tf_datatype=dtype)

                with tf.variable_scope("output"):
                    attended = layer_norm(_dense(attended, hidden_dim) + residual)

            # Only the intermediate (expansion) layer uses the activation.
            with tf.variable_scope("intermediate"):
                expanded = _dense(attended, hidden_dim * 4,
                                  activation=intermediate_act_fn)

            # Down-project back to hidden_dim, then add the residual.
            with tf.variable_scope("output"):
                hidden = layer_norm(_dense(expanded, hidden_dim) + attended)

    return tf.reshape(hidden, shape=tf.shape(input_tensor))
# Example #5 (rating: 0)
def tf_decoder(decoder_args,
               inputs,
               memory,
               memory_sequence_length,
               step,
               cache=None):
    '''
    Run the GPT-2 style decoder layers by TensorFlow for a single decoding step.

    Each layer is pre-norm self attention with a fused QKV projection followed
    by a pre-norm GELU feed-forward network; each sub-layer adds a residual
    connection when the dimensions match. GPT-2 has no cross attention, so
    `memory`, `memory_sequence_length` and `step` are accepted only for
    interface compatibility with the encoder-decoder version and are ignored.

    Args:
        decoder_args: The arguments for decoder. The details are in the class "TransformerArgument" of common.py
        inputs: A tf.Tensor with shape [batch_size * beam_width, 1, hidden_dimension].
                The input tensor of the decoder for the current step. The rank must be 3.
        memory: Ignored (no cross attention in GPT-2).
        memory_sequence_length: Ignored (no cross attention in GPT-2).
        step: Ignored. A tf.Tensor with tf.int type, the current decoding step.
        cache: A dict keyed by "layer_{i}", or None. Each layer entry holds
               "self_keys"/"self_values", which are extended with this step's
               keys/values along axis 2 and written back. When None, attention
               only sees the current step.

    Outputs:
        outputs: A tf.Tensor with shape [batch_size * beam_width, 1, hidden_dimension].
                 The results of decoder.
    '''

    k_init_range = decoder_args.kernel_init_range
    b_init_range = decoder_args.bias_init_range
    data_type = decoder_args.dtype
    head_num = decoder_args.head_num
    size_per_head = decoder_args.size_per_head

    def _conv1d(x, units, **kwargs):
        # Position-wise (kernel width 1) projection with the decoder's initializers.
        return tf.layers.conv1d(x, units, 1,
                                bias_initializer=create_initializer(b_init_range, data_type),
                                kernel_initializer=create_initializer(k_init_range, data_type),
                                **kwargs)

    def _split_heads(x):
        # [batch, 1, head_num * size_per_head] -> [batch, head_num, 1, size_per_head]
        x = tf.reshape(x, [tf.shape(x)[0], 1, head_num, size_per_head])
        return tf.transpose(x, [0, 2, 1, 3])

    for l in range(decoder_args.num_layer):
        layer_name = "layer_{}".format(l)
        layer_cache = cache[layer_name] if cache is not None else None

        with tf.variable_scope(layer_name):
            with tf.variable_scope("masked_multi_head"):
                queries, keys, values = tf.split(
                    _conv1d(norm(inputs), decoder_args.hidden_dim * 3), 3, axis=2)

                keys = _split_heads(keys)
                values = _split_heads(values)
                if layer_cache is not None:
                    # Append this step's keys/values to the cached ones (axis 2 = length).
                    keys = tf.concat([layer_cache["self_keys"], keys], axis=2)
                    values = tf.concat([layer_cache["self_values"], values], axis=2)
                    layer_cache["self_keys"] = keys
                    layer_cache["self_values"] = values

                queries = _split_heads(queries) * (size_per_head ** -0.5)

                # Only one query position per step, so no causal mask is applied.
                dot = tf.matmul(queries, keys, transpose_b=True)
                attn = tf.cast(tf.nn.softmax(tf.cast(dot, data_type)), dot.dtype)
                context = tf.matmul(attn, values)
                context = tf.transpose(context, [0, 2, 1, 3])
                context = tf.reshape(context, [tf.shape(context)[0], 1, head_num * size_per_head])

                outputs = _conv1d(context, decoder_args.hidden_dim)

                # drop_and_add
                if inputs.get_shape().as_list()[-1] == outputs.get_shape().as_list()[-1]:
                    outputs += inputs
                last_context = outputs

            with tf.variable_scope("ffn"):
                # The FFN consumes the self-attention sub-layer output directly.
                normed_last_context = norm(last_context)
                input_dim = normed_last_context.get_shape().as_list()[-1]
                # GPT-2 uses GELU in the expansion layer.
                inner = gelu(_conv1d(normed_last_context, decoder_args.hidden_dim * 4,
                                     use_bias=True))
                transformed = _conv1d(inner, input_dim, use_bias=True)

                # drop_and_add: check the dimension of the tensor actually added.
                if last_context.get_shape().as_list()[-1] == transformed.get_shape().as_list()[-1]:
                    transformed += last_context
        inputs = transformed
    outputs = inputs
    return outputs