def _body(word_ids, cum_log_probs, finished, step, outputs, my_cache, extra_vars, op_self_cache, op_mem_cache):
            '''Single step of the beam-search decoding while-loop.

            Closed-over state comes from the enclosing scope: embedding_table,
            decoding_args, extended_memory, extended_memory_sequence_length,
            decoder_type, atol_threshold, kernel/bias initializer ranges.

            Returns the updated loop variables in the same order as the
            arguments, with `step` advanced by one.
            '''
            # [batch_size * beam_width, hidden_dim]
            inputs = tf.nn.embedding_lookup(embedding_table, word_ids)
            # [batch_size * beam_width, 1, hidden_dim]
            inputs = tf.expand_dims(inputs, 1)

            # Scale embeddings by sqrt(hidden_dim) and add the sinusoidal
            # position signal for this step. (The former
            # `if position_encoder is not None:` guard was dead code: the
            # encoder is constructed unconditionally on the line above it.)
            # NOTE(review): tf_decoder() below applies a position encoder to
            # its inputs as well — confirm the double application is intended.
            inputs *= decoding_args.decoder_args.hidden_dim**0.5
            position_encoder = SinusoidalPositionEncoder()
            inputs = position_encoder(
                inputs, position=step + 1 if step is not None else None)

            # Reference TensorFlow decoder step (always built; also used as
            # the baseline when cross-checking the custom op).
            with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
                tf_result = tf_decoder(decoder_args=decoding_args.decoder_args,
                                    inputs=inputs,
                                    memory=extended_memory,
                                    memory_sequence_length=extended_memory_sequence_length,
                                    step=step,
                                    cache=my_cache,
                                    kernel_initializer_range=kernel_initializer_range,
                                    bias_initializer_range=bias_initializer_range)

            # decoder_type: 0 = TF only, 1 = custom op only,
            # 2 = run both and report the max elementwise difference.
            if decoder_type != 0:
                # Hand the fused op every variable created from the
                # "transformer/decoding/decoder" scope onward; skip the
                # globals that precede it.
                decoder_vars = tf.global_variables()
                decoder_vars_start_id = 0
                while decoder_vars_start_id < len(decoder_vars):
                    if decoder_vars[decoder_vars_start_id].name.find("transformer/decoding/decoder") != -1:
                        break
                    decoder_vars_start_id += 1
                decoder_vars = decoder_vars[decoder_vars_start_id:]

                # In compare mode, feed tf_result as a pseudo input so the op
                # is sequenced after the TF graph instead of running in
                # parallel with it (see comment inside op_decoder).
                psuedo_input = []
                if decoder_type == 2:
                    psuedo_input = tf_result

                op_result, op_self_cache, op_mem_cache = op_decoder(inputs,
                                                                    step,
                                                                    extended_memory,
                                                                    extended_memory_sequence_length,
                                                                    op_self_cache,
                                                                    op_mem_cache,
                                                                    psuedo_input,
                                                                    decoder_vars,
                                                                    decoding_args.decoder_args,
                                                                    decoding_args.encoder_hidden_dim)

            result = None
            if decoder_type == 0:
                result = tf_result
            elif decoder_type == 1:
                result = op_result
            elif decoder_type == 2:
                # Cross-check: log the largest absolute TF-vs-op difference
                # and whether it is within atol_threshold.
                result = tf_result
                result_2 = op_result

                flatten_result = tf.reshape(result, [-1])
                flatten_result_2 = tf.reshape(result_2, [-1])
                abs_diff = tf.math.abs(flatten_result - flatten_result_2)
                argmax = tf.math.argmax(abs_diff)
                result = tf.Print(result, ["[INFO][PYTHON] step:", step, "max diff: ", abs_diff[argmax],
                                           " op val: ", flatten_result_2[argmax],
                                           " tf val: ", flatten_result[argmax], 
                                           tf.cond(abs_diff[argmax] < atol_threshold, lambda: "True", lambda: "False")])
            else:
                print("[TF][ERROR] decoder type is only 0 or 1 or 2.")
                exit(-1)

            # Final layer norm, then project to vocabulary logits.
            result = tf.contrib.layers.layer_norm(result, begin_norm_axis=-1)
            # [batch_size * beam_width, hidden_dim]
            result = tf.squeeze(result, axis=1)
            logits = tf.layers.dense(result,
                                     decoding_args.vocab_size,
                                     use_bias=True,
                                     bias_initializer=create_initializer(
                                         bias_initializer_range, decoding_args.decoder_args.dtype),
                                     kernel_initializer=create_initializer(
                                         kernel_initializer_range, decoding_args.decoder_args.dtype),
                                     activation=None)

            # Finished beams are forced to keep emitting end_id: a one-hot
            # row with dtype.max on end_id and dtype.min elsewhere dominates
            # the softmax.
            end_ids = tf.fill([decoding_args.decoder_args.batch_size * decoding_args.decoder_args.beam_width],
                              decoding_args.end_id)  # [batch_size * beam_width]
            eos_max_prob = tf.one_hot(end_ids, decoding_args.vocab_size,
                                      on_value=decoding_args.decoder_args.dtype.max,
                                      off_value=decoding_args.decoder_args.dtype.min)  # [batch_size * beam_width, vocab_size]
            # [batch_size * beam_width, vocab_size]
            logits = tf.where(finished, x=eos_max_prob, y=logits)
            logits = tf.cast(logits, tf.float32)
            # [batch_size * beam_width, vocab_size]
            log_probs = tf.nn.log_softmax(logits)

            # Advance beam search; it reorders the caches to follow the
            # selected beams.
            output_id, next_cum_log_probs, finished, my_cache, \
                extra_vars, op_self_cache = beam_search(decoding_args.decoder_args.beam_width,
                                                        decoding_args.vocab_size,
                                                        step,
                                                        log_probs,
                                                        cum_log_probs,
                                                        finished,
                                                        my_cache,
                                                        extra_vars,
                                                        op_self_cache)

            outputs = outputs.write(step, output_id)
            # Only unfinished beams accumulate log-probability.
            cum_log_probs = tf.where(
                finished, x=cum_log_probs, y=next_cum_log_probs)
            finished = tf.logical_or(finished, tf.equal(
                output_id, decoding_args.end_id))

            return output_id, cum_log_probs, finished, step + 1, outputs, my_cache, extra_vars, op_self_cache, op_mem_cache
Пример #2
0
def tf_encoder_opennmt(input_tensor,
                        encoder_args,
                        initializer_range=0.02,
                        sequence_length=None):
    '''
    Run the OpenNMT-style transformer encoder stack in pure TensorFlow.

    Args:
        input_tensor: A tf.Tensor with shape [batch_size, seq_len, hidden_dimension].
                       The inputs tensor of encoder. The rank must be 3.
        encoder_args: The arguments for encoder. The details are in the class
                      "TransformerArgument" of common.py
        initializer_range: A float value.
                           The range of initializer for all weights.
        sequence_length: A tf.Tensor with shape [batch_size], with tf.int type.
                         The sequence length of each sentence in input_tensor.

    Outputs:
        output: A tf.Tensor with shape [batch_size, max(sequence_length), hidden_dimension].
                The results of encoder.
    '''

    compute_dtype = encoder_args.dtype
    shape = get_shape_list(input_tensor, expected_rank=3)
    batch = shape[0]
    seq_len = shape[1]

    # Scale embeddings by sqrt(hidden_dim), then add the sinusoidal
    # position signal.
    input_tensor *= encoder_args.hidden_dim**0.5
    input_tensor = SinusoidalPositionEncoder()(input_tensor,
                                               position=tf.range(seq_len))

    mask = build_sequence_mask(
        sequence_length,
        encoder_args.head_num,
        maximum_length=tf.shape(input_tensor)[1],
        dtype=compute_dtype)

    ffn_inner_dim = encoder_args.hidden_dim * 4
    if encoder_args.hidden_dim % encoder_args.head_num != 0:
        raise ValueError(
            "The hidden size (%d) is not a multiple of the number of attention "
            "heads (%d)" % (encoder_args.hidden_dim, encoder_args.head_num))

    def _split_heads(tensor):
        # [batch, seq, hidden] -> [batch, head_num, seq, size_per_head]
        tensor = tf.reshape(tensor, [batch, seq_len, encoder_args.head_num,
                                     encoder_args.size_per_head])
        return tf.transpose(tensor, [0, 2, 1, 3])

    hidden_states = input_tensor
    for layer_idx in range(encoder_args.num_layer):
        with tf.variable_scope("layer_%d" % layer_idx, reuse=tf.AUTO_REUSE):
            with tf.variable_scope("multi_head"):
                # Pre-norm; layer_norm is computed in fp32 for stability.
                normed = tf.cast(layer_norm(tf.cast(hidden_states, tf.float32)),
                                 compute_dtype)

                # Fused QKV projection, split into the three tensors.
                q, k, v = tf.split(
                    tf.layers.conv1d(normed, encoder_args.hidden_dim * 3, 1),
                    3, axis=2)

                q = _split_heads(q)
                k = _split_heads(k)
                v = _split_heads(v)

                # Scaled dot-product attention.
                q *= (encoder_args.size_per_head)**-0.5
                scores = tf.matmul(q, k, transpose_b=True)

                if mask is not None:
                    # Masked-out positions get dtype.min so softmax ~ zeroes them.
                    scores = tf.cast(
                        tf.cast(scores, compute_dtype) * mask +
                        ((1.0 - mask) * compute_dtype.min), scores.dtype)

                probs = tf.cast(tf.nn.softmax(tf.cast(scores, compute_dtype)),
                                scores.dtype)

                # Merge heads back and project, then residual-add.
                ctx = tf.matmul(probs, v)
                ctx = tf.transpose(ctx, [0, 2, 1, 3])
                ctx = tf.reshape(ctx, [batch, seq_len, encoder_args.hidden_dim])
                attn_proj = tf.layers.conv1d(ctx, encoder_args.hidden_dim, 1)
                attn_out = attn_proj + hidden_states

            with tf.variable_scope("ffn"):
                # Pre-norm feed-forward block with residual connection.
                normed_attn = tf.cast(layer_norm(tf.cast(attn_out, tf.float32)),
                                      compute_dtype)
                inner = tf.layers.conv1d(normed_attn, ffn_inner_dim, 1,
                                         activation=tf.nn.relu)
                projected = tf.layers.conv1d(inner, encoder_args.hidden_dim, 1)
                hidden_states = projected + attn_out

    # Final layer norm in fp32.
    hidden_states = tf.cast(hidden_states, tf.float32)
    return layer_norm(hidden_states, name="LayerNorm")
Пример #3
0
def op_decoder(inputs, step, memory_tensor, memory_sequence_length,
               op_self_cache, op_mem_cache, psuedo_input, decoder_vars,
               decoder_args, memory_hidden_dim):
    '''Run one decoding step through the fused FasterTransformer decoder op.

    `decoder_vars` holds 26 variables per layer, in the order the custom op
    expects; the cache tensors are grown/updated here and returned so the
    caller can thread them through the decode loop.
    '''
    decoder_op_module = tf.load_op_library(
        os.path.join('./lib/libtf_decoder.so'))

    # Add the sinusoidal position signal for the current step.
    inputs = SinusoidalPositionEncoder()(
        inputs, position=step + 1 if step is not None else None)

    # Grow the self-attention K/V cache by one time step:
    # [num_layer, 2 (K/V), time, batch * beam, hidden].
    cache_slot = tf.zeros(
        [decoder_args.num_layer, 2, 1,
         decoder_args.batch_size * decoder_args.beam_width,
         decoder_args.hidden_dim],
        dtype=decoder_args.dtype)
    op_self_cache = tf.concat([op_self_cache, cache_slot], axis=2)

    vars_per_layer = 26
    for layer in range(decoder_args.num_layer):
        # The weights for this layer occupy a contiguous window of 26
        # entries in decoder_vars.
        layer_vars = decoder_vars[layer * vars_per_layer:
                                  (layer + 1) * vars_per_layer]
        op_result, _, _ = decoder_op_module.decoder(
            inputs,
            memory_tensor,
            memory_sequence_length,
            *layer_vars,
            op_self_cache[layer],
            op_mem_cache[layer],
            psuedo_input,  # add tf_result as input to prevent the OP and TF from parallel execution and lead to error result
            max_seq_len=decoder_args.max_seq_len,
            head_num=decoder_args.head_num,
            size_per_head=decoder_args.size_per_head,
            memory_hidden_dim=memory_hidden_dim)
        # Each layer consumes the previous layer's output.
        inputs = op_result

    return op_result, op_self_cache, op_mem_cache
Пример #4
0
def tf_decoder(decoder_args,
               inputs,
               memory,
               memory_sequence_length,
               step,
               cache=None,
               kernel_initializer_range=0.02,
               bias_initializer_range=0):
    '''Run one step of a pre-norm transformer decoder stack in TensorFlow.

    Args:
        decoder_args: Decoder hyper-parameters (see "TransformerArgument"
                      in common.py): num_layer, head_num, size_per_head,
                      hidden_dim, batch_size, beam_width, dtype.
        inputs: Current-step input, one token per beam
                (the reshapes below fix it to
                [batch_size * beam_width, 1, hidden_dim]).
        memory: Encoder output tensor (or a sequence of them) for
                cross-attention; may be None.
        memory_sequence_length: Valid lengths for `memory`, used to build
                                the cross-attention mask.
        step: Current decoding step (0-based); also used for the position
              encoding.
        cache: Dict with per-layer "self_keys"/"self_values" and per-memory
               "memory_keys"/"memory_values" tensors, updated in place.
        kernel_initializer_range / bias_initializer_range: Ranges passed to
               create_initializer for all projections.

    Returns:
        The decoder output, same leading shape as `inputs`.
    '''

    # Add the sinusoidal position signal for the current step.
    position_encoder = SinusoidalPositionEncoder()
    if position_encoder is not None:
        inputs = position_encoder(inputs,
                                  position=step +
                                  1 if step is not None else None)

    # Per-memory cross-attention masks; built below when memory is given.
    memory_mask = None

    # Normalize `memory` / `memory_sequence_length` to tuples so multiple
    # encoder memories are supported uniformly.
    if memory is not None and not tf.contrib.framework.nest.is_sequence(
            memory):
        memory = (memory, )
        if memory_sequence_length is not None:
            if not tf.contrib.framework.nest.is_sequence(
                    memory_sequence_length):
                memory_sequence_length = (memory_sequence_length, )
            memory_mask = [
                build_sequence_mask(length,
                                    num_heads=decoder_args.head_num,
                                    maximum_length=tf.shape(m)[1],
                                    data_type=decoder_args.dtype)
                for m, length in zip(memory, memory_sequence_length)
            ]

    for l in range(decoder_args.num_layer):
        layer_name = "layer_{}".format(l)
        layer_cache = cache[layer_name] if cache is not None else None

        with tf.variable_scope(layer_name):
            # --- Masked (causal) self-attention over the cached history ---
            with tf.variable_scope("masked_multi_head"):
                norm_inputs = norm(inputs)
                queries = tf.layers.conv1d(
                    norm_inputs,
                    decoder_args.hidden_dim,
                    1,
                    activation=None,
                    name="query",
                    use_bias=True,
                    bias_initializer=create_initializer(
                        bias_initializer_range, decoder_args.dtype),
                    kernel_initializer=create_initializer(
                        kernel_initializer_range, decoder_args.dtype))

                keys = tf.layers.conv1d(norm_inputs,
                                        decoder_args.hidden_dim,
                                        1,
                                        activation=None,
                                        name="key",
                                        use_bias=True,
                                        bias_initializer=create_initializer(
                                            bias_initializer_range,
                                            decoder_args.dtype),
                                        kernel_initializer=create_initializer(
                                            kernel_initializer_range,
                                            decoder_args.dtype))

                values = tf.layers.conv1d(
                    norm_inputs,
                    decoder_args.hidden_dim,
                    1,
                    activation=None,
                    name="value",
                    use_bias=True,
                    bias_initializer=create_initializer(
                        bias_initializer_range, decoder_args.dtype),
                    kernel_initializer=create_initializer(
                        kernel_initializer_range, decoder_args.dtype))

                # Split heads: [B*beam, 1, H] -> [B*beam, heads, 1, size].
                keys = tf.reshape(keys, [
                    decoder_args.batch_size * decoder_args.beam_width, 1,
                    decoder_args.head_num, decoder_args.size_per_head
                ])
                keys = tf.transpose(keys, [0, 2, 1, 3])
                values = tf.reshape(values, [
                    decoder_args.batch_size * decoder_args.beam_width, 1,
                    decoder_args.head_num, decoder_args.size_per_head
                ])
                values = tf.transpose(values, [0, 2, 1, 3])

                # Append this step's K/V to the cache; attention then spans
                # the whole decoded history (no causal mask needed since we
                # only ever attend to past steps).
                keys = tf.concat([layer_cache["self_keys"], keys], axis=2)
                values = tf.concat([layer_cache["self_values"], values],
                                   axis=2)
                layer_cache["self_keys"] = keys
                layer_cache["self_values"] = values

                queries = tf.reshape(queries, [
                    decoder_args.batch_size * decoder_args.beam_width, 1,
                    decoder_args.head_num, decoder_args.size_per_head
                ])
                queries = tf.transpose(queries, [0, 2, 1, 3])
                # Scaled dot-product attention.
                queries *= (decoder_args.size_per_head)**-0.5

                dot = tf.matmul(queries, keys, transpose_b=True)

                attn = tf.cast(tf.nn.softmax(tf.cast(dot, decoder_args.dtype)),
                               dot.dtype)
                context = tf.matmul(attn, values)
                context = tf.transpose(context, [0, 2, 1, 3])
                context = tf.reshape(context, [
                    decoder_args.batch_size * decoder_args.beam_width, 1,
                    decoder_args.head_num * decoder_args.size_per_head
                ])

                outputs = tf.layers.conv1d(
                    context,
                    decoder_args.hidden_dim,
                    1,
                    activation=None,
                    use_bias=True,
                    bias_initializer=create_initializer(
                        bias_initializer_range, decoder_args.dtype),
                    kernel_initializer=create_initializer(
                        kernel_initializer_range, decoder_args.dtype))

                # drop_and_add: residual connection only when dims match.
                input_dim = inputs.get_shape().as_list()[-1]
                output_dim = outputs.get_shape().as_list()[-1]
                if input_dim == output_dim:
                    outputs += inputs
                last_context = outputs

            # --- Cross-attention over each encoder memory ---
            if memory is not None:
                for i, (mem, mask) in enumerate(zip(memory, memory_mask)):
                    memory_cache = layer_cache["memory"][
                        i] if layer_cache is not None else None

                    with tf.variable_scope("multi_head" if i ==
                                           0 else "multi_head_%d" % i):
                        queries = tf.layers.conv1d(
                            norm(last_context),
                            decoder_args.hidden_dim,
                            1,
                            activation=None,
                            name="query",
                            use_bias=True,
                            bias_initializer=create_initializer(
                                bias_initializer_range, decoder_args.dtype),
                            kernel_initializer=create_initializer(
                                kernel_initializer_range, decoder_args.dtype))

                        # Project the (static) memory into K/V; only needed
                        # once per decode, so the result is cached.
                        def _project_and_split():
                            keys = tf.layers.conv1d(
                                mem,
                                decoder_args.hidden_dim,
                                1,
                                activation=None,
                                name="key",
                                use_bias=True,
                                bias_initializer=create_initializer(
                                    bias_initializer_range,
                                    decoder_args.dtype),
                                kernel_initializer=create_initializer(
                                    kernel_initializer_range,
                                    decoder_args.dtype))

                            values = tf.layers.conv1d(
                                mem,
                                decoder_args.hidden_dim,
                                1,
                                activation=None,
                                name="value",
                                use_bias=True,
                                bias_initializer=create_initializer(
                                    bias_initializer_range,
                                    decoder_args.dtype),
                                kernel_initializer=create_initializer(
                                    kernel_initializer_range,
                                    decoder_args.dtype))

                            keys = tf.reshape(keys, [
                                decoder_args.batch_size *
                                decoder_args.beam_width,
                                tf.shape(keys)[1], decoder_args.head_num,
                                decoder_args.size_per_head
                            ])
                            keys = tf.transpose(keys, [0, 2, 1, 3])
                            values = tf.reshape(values, [
                                decoder_args.batch_size *
                                decoder_args.beam_width,
                                tf.shape(values)[1], decoder_args.head_num,
                                decoder_args.size_per_head
                            ])
                            values = tf.transpose(values, [0, 2, 1, 3])

                            return keys, values

                        # Project on the first step (cache is empty along
                        # axis 2); afterwards reuse the cached projections.
                        keys, values = tf.cond(tf.equal(
                            tf.shape(memory_cache["memory_keys"])[2], 0),
                                               true_fn=_project_and_split,
                                               false_fn=lambda:
                                               (memory_cache["memory_keys"],
                                                memory_cache["memory_values"]))

                        memory_cache["memory_keys"] = keys
                        memory_cache["memory_values"] = values

                        queries = tf.reshape(queries, [
                            decoder_args.batch_size * decoder_args.beam_width,
                            1, decoder_args.head_num,
                            decoder_args.size_per_head
                        ])
                        queries = tf.transpose(queries, [0, 2, 1, 3])
                        queries *= (decoder_args.size_per_head)**-0.5

                        dot = tf.matmul(queries, keys, transpose_b=True)

                        # Mask padded memory positions with dtype.min so the
                        # softmax assigns them ~zero weight.
                        dot = tf.cast(
                            tf.cast(dot, decoder_args.dtype) * mask +
                            ((1.0 - mask) * decoder_args.dtype.min), dot.dtype)

                        attn = tf.cast(
                            tf.nn.softmax(tf.cast(dot, decoder_args.dtype)),
                            dot.dtype)
                        context = tf.matmul(attn, values)
                        context = tf.transpose(context, [0, 2, 1, 3])
                        context = tf.reshape(context, [
                            decoder_args.batch_size * decoder_args.beam_width,
                            1,
                            decoder_args.head_num * decoder_args.size_per_head
                        ])
                        context = tf.layers.conv1d(
                            context,
                            decoder_args.hidden_dim,
                            1,
                            activation=None,
                            use_bias=True,
                            bias_initializer=create_initializer(
                                bias_initializer_range, decoder_args.dtype),
                            kernel_initializer=create_initializer(
                                kernel_initializer_range, decoder_args.dtype))

                        # drop_and_add: residual connection only when dims match.
                        input_dim = last_context.get_shape().as_list()[-1]
                        output_dim = context.get_shape().as_list()[-1]
                        if input_dim == output_dim:
                            context += last_context

            # --- Position-wise feed-forward block ---
            # NOTE(review): if memory is None, `context` here still refers to
            # the pre-projection self-attention context from the
            # masked_multi_head block (not `last_context`) — looks
            # unintended; confirm memory is always provided by callers.
            with tf.variable_scope("ffn"):
                # forward
                normed_last_context = norm(context)
                input_dim = normed_last_context.get_shape().as_list()[-1]
                inner = tf.layers.conv1d(normed_last_context,
                                         decoder_args.hidden_dim * 4,
                                         1,
                                         activation=tf.nn.relu,
                                         use_bias=True,
                                         bias_initializer=create_initializer(
                                             bias_initializer_range,
                                             decoder_args.dtype),
                                         kernel_initializer=create_initializer(
                                             kernel_initializer_range,
                                             decoder_args.dtype))
                transformed = tf.layers.conv1d(
                    inner,
                    input_dim,
                    1,
                    use_bias=True,
                    bias_initializer=create_initializer(
                        bias_initializer_range, decoder_args.dtype),
                    kernel_initializer=create_initializer(
                        kernel_initializer_range, decoder_args.dtype))

                # drop_and_add: residual connection only when dims match.
                input_dim = context.get_shape().as_list()[-1]
                output_dim = transformed.get_shape().as_list()[-1]
                if input_dim == output_dim:
                    transformed += context

        # Next layer consumes this layer's output.
        inputs = transformed
    outputs = inputs
    return outputs
Пример #5
0
def ft_encoder_opennmt(inputs, encoder_args, encoder_vars_dict,
                       sequence_length):
    '''
    Run the OpenNMT transformer encoder stack with the FasterTransformer op.

    Args:
        inputs: A tf.Tensor with shape [batch_size, seq_len, hidden_dimension].
                The inputs tensor of encoder. The rank must be 3.
        encoder_args: The arguments for encoder. The details are in the class
                      "TransformerArgument" of common.py
        encoder_vars_dict: A dict of tf.Tensor or numpy array.
                            The variables for encoder. They can be either some tensor or some numpy array.
                            The key is the name of the tensor, like 'layer_0/attention/self/query/kernel:0'.
                            The value is the corresponding tensor or numpy array.
        sequence_length: A tf.Tensor or numpy array with shape [batch_size].
                        The sequence length of the sentences
    Outputs:
        outputs: A tensor with shape [batch_size, seq_len, hidden_dimension].
                 The results of encoder.
    '''

    # Self-attention mask derived from the valid sequence lengths.
    attention_mask = build_sequence_mask(sequence_length,
                                         encoder_args.head_num,
                                         maximum_length=tf.shape(inputs)[1],
                                         dtype=encoder_args.dtype)

    # Scale embeddings by sqrt(hidden_dim) and add the sinusoidal
    # position signal.
    inputs *= encoder_args.hidden_dim**0.5
    position_encoder = SinusoidalPositionEncoder()
    inputs = position_encoder(inputs, position=tf.range(tf.shape(inputs)[1]))

    remove_padding = encoder_args.remove_padding
    transformer_op_module = tf.load_op_library(
        os.path.join('./lib/libtf_fastertransformer.so'))
    if remove_padding == True:
        # Compact the batch by dropping padded tokens; sequence_id_offset
        # maps compacted positions back to padded positions, and trt_seq_len
        # holds the cumulative sentence offsets ([0, l0, l0+l1, ...]).
        inputs, sequence_id_offset = transformer_op_module.build_mask_remove_padding(
            inputs, sequence_length)
        trt_seq_len = tf.cumsum(tf.concat([[0], sequence_length], axis=0),
                                axis=0)
    else:
        # Padding kept: build interleaved per-sentence offsets
        # [start_0, end_0, start_1, end_1, ..., batch * max_seq_len], where
        # start_i = i * max_seq_len and end_i = start_i + sequence_length[i].
        sequence_id_offset = []
        batch = tf.shape(inputs)[0]
        max_seq_len = tf.shape(inputs)[1]
        padding_offset = tf.range(0, batch * max_seq_len, max_seq_len)
        squence_offset_with_padding = sequence_length + padding_offset
        c = tf.concat([padding_offset, squence_offset_with_padding], axis=0)
        c_r = tf.reshape(c, [2, -1])
        t = tf.transpose(c_r)
        trt_seq_len = tf.reshape(t, [-1])
        trt_seq_len = tf.concat([trt_seq_len, [batch * max_seq_len]], axis=0)

    for layer_idx in range(encoder_args.num_layer):
        # Per-layer quantization scales, only present in int8 mode.
        if encoder_args.int8_mode != 0:
            amaxList = encoder_vars_dict['layer_%d/amaxList:0' % layer_idx]
        else:
            amaxList = []

        # The fused QKV projection weight/bias are stored as one tensor;
        # split into Q/K/V with tf.split for tensors or np.split for
        # numpy arrays.
        if tf.is_tensor(encoder_vars_dict[
                'transformer/encoder/layer_%d/multi_head/conv1d/kernel:0' %
                layer_idx]) == True:
            q_w, k_w, v_w = tf.split(encoder_vars_dict[
                'transformer/encoder/layer_%d/multi_head/conv1d/kernel:0' %
                layer_idx],
                                     3,
                                     axis=-1)
            q_b, k_b, v_b = tf.split(encoder_vars_dict[
                'transformer/encoder/layer_%d/multi_head/conv1d/bias:0' %
                layer_idx],
                                     3,
                                     axis=-1)
        else:
            q_w, k_w, v_w = np.split(encoder_vars_dict[
                'transformer/encoder/layer_%d/multi_head/conv1d/kernel:0' %
                layer_idx],
                                     3,
                                     axis=-1)
            q_b, k_b, v_b = np.split(encoder_vars_dict[
                'transformer/encoder/layer_%d/multi_head/conv1d/bias:0' %
                layer_idx],
                                     3,
                                     axis=-1)
        # One fused-op call per layer; LayerNorm params are cast to the
        # compute dtype, the rest are passed through as stored.
        outputs = transformer_op_module.open_encoder(
            inputs,
            inputs,
            tf.cast(
                encoder_vars_dict[
                    'transformer/encoder/layer_%d/multi_head/LayerNorm/beta:0'
                    % layer_idx], encoder_args.dtype),
            tf.cast(
                encoder_vars_dict[
                    'transformer/encoder/layer_%d/multi_head/LayerNorm/gamma:0'
                    % layer_idx], encoder_args.dtype),
            q_w,
            q_b,
            k_w,
            k_b,
            v_w,
            v_b,
            attention_mask,
            encoder_vars_dict[
                'transformer/encoder/layer_%d/multi_head/conv1d_1/kernel:0' %
                layer_idx],
            encoder_vars_dict[
                'transformer/encoder/layer_%d/multi_head/conv1d_1/bias:0' %
                layer_idx],
            tf.cast(
                encoder_vars_dict[
                    'transformer/encoder/layer_%d/ffn/LayerNorm/beta:0' %
                    layer_idx], encoder_args.dtype),
            tf.cast(
                encoder_vars_dict[
                    'transformer/encoder/layer_%d/ffn/LayerNorm/gamma:0' %
                    layer_idx], encoder_args.dtype),
            encoder_vars_dict[
                'transformer/encoder/layer_%d/ffn/conv1d/kernel:0' %
                layer_idx],
            encoder_vars_dict['transformer/encoder/layer_%d/ffn/conv1d/bias:0'
                              % layer_idx],
            encoder_vars_dict[
                'transformer/encoder/layer_%d/ffn/conv1d_1/kernel:0' %
                layer_idx],
            encoder_vars_dict[
                'transformer/encoder/layer_%d/ffn/conv1d_1/bias:0' %
                layer_idx],
            sequence_id_offset,
            amaxList,
            trt_seq_len,
            head_num=encoder_args.head_num,
            size_per_head=encoder_args.size_per_head,
            remove_padding=remove_padding,
            int8_mode=encoder_args.int8_mode,
            layer_idx=layer_idx,
            layer_num=encoder_args.num_layer,
            allow_gemm_test=encoder_args.allow_gemm_test)
        # Each layer consumes the previous layer's output.
        inputs = outputs
    if remove_padding == True:
        # Restore the original padded layout before the final norm.
        outputs = transformer_op_module.rebuild_padding(
            outputs, sequence_id_offset, attention_mask)

    # Final layer norm in fp32, then cast back to the compute dtype.
    outputs = tf.cast(outputs, tf.float32)
    outputs = layer_norm(outputs, name="LayerNorm")
    outputs = tf.cast(outputs, encoder_args.dtype)
    return outputs