示例#1
0
    embedding_table = tf.convert_to_tensor(embedding_table)
    memory, memory_sequence_length = generate_encoder_result(
        batch_size, max_seq_len, memory_hidden_dim, tf_datatype)

    finalized_tf_output_ids, finalized_tf_sequence_lengths, tf_output_ids, \
        tf_parent_ids, tf_sequence_lengths = tf_beamsearch_decoding(memory,
                                                                    memory_sequence_length,
                                                                    embedding_table,
                                                                    decoding_args,
                                                                    decoder_type=0)

    all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    finalized_op_output_ids, finalized_op_sequence_lengths, op_output_ids, \
        op_parent_ids, op_sequence_lengths = op_beamsearch_decoding(memory,
                                                         memory_sequence_length,
                                                         embedding_table,
                                                         all_vars,
                                                         decoding_args)

    tf_sampling_target_ids, tf_sampling_target_length = tf_sampling_decoding(
        memory,
        memory_sequence_length,
        embedding_table,
        decoding_sampling_args,
        decoder_type=0)

    op_sampling_target_ids, op_sampling_target_length = op_sampling_decoding(
        memory, memory_sequence_length, embedding_table, all_vars,
        decoding_sampling_args)

    config = tf.ConfigProto()
def translate_sample(args_dict):
    """Build TF and FasterTransformer translation graphs, run the selected
    decoding pipelines over ``args_dict['source']`` and score them with BLEU.

    Six pipelines are built on top of one OpenNMT-style encoder:

    * ``"0"`` ``tf_beamsearch_decoding`` with ``decoder_type=0`` (TF decoding)
    * ``"1"`` ``tf_beamsearch_decoding`` with ``decoder_type=1`` (OP decoder)
    * ``"2"`` ``op_beamsearch_decoding`` (OP decoding)
    * ``"3"`` ``tf_sampling_decoding`` with ``decoder_type=0``
    * ``"4"`` ``tf_sampling_decoding`` with ``decoder_type=1``
    * ``"5"`` ``op_sampling_decoding``

    Every digit that appears in ``args_dict['test_time']`` selects the matching
    pipeline. Each selected pipeline runs in its own session over the whole
    source file (after restoring the hard-coded checkpoints under
    ``translation/ckpt/``). Its sentences are written to ``<name>.txt`` and
    scored against the first N lines of ``args_dict['target']``, where N is
    the number of produced sentences.

    Args:
        args_dict: mapping of run options. Keys read here: ``batch_size``,
            ``beam_width``, ``max_seq_len``, ``encoder_head_number``,
            ``encoder_size_per_head``, ``decoder_head_number``,
            ``decoder_size_per_head``, ``encoder_num_layer``,
            ``decoder_num_layer``, ``data_type`` (``"fp16"`` selects float16,
            anything else float32), ``beam_search_diversity_rate``,
            ``sampling_topk``, ``sampling_topp``, ``source_vocabulary``,
            ``target_vocabulary``, ``source``, ``target``, ``remove_padding``
            (case-insensitive ``"true"`` enables it) and ``test_time``.

    Returns:
        list of ``TranslationResult`` objects, one per selected pipeline, in
        the order 0..5. Each holds the decoded sentences, batch count,
        wall-clock execution time (seconds) and BLEU score.

    Raises:
        OSError: if ``args_dict['target']`` cannot be read.
    """
    print("\n=============== Argument ===============")
    for key in args_dict:
        print("{}: {}".format(key, args_dict[key]))
    print("========================================")

    # Fix every RNG involved so runs are reproducible.
    np.random.seed(1)
    tf.set_random_seed(1)
    random.seed(1)

    start_of_sentence_id = 1
    end_of_sentence_id = 2

    kernel_initializer_range = 0.02
    bias_initializer_range = 0.02

    batch_size = args_dict['batch_size']
    beam_width = args_dict['beam_width']
    max_seq_len = args_dict['max_seq_len']
    encoder_head_num = args_dict['encoder_head_number']
    encoder_size_per_head = args_dict['encoder_size_per_head']
    decoder_head_num = args_dict['decoder_head_number']
    decoder_size_per_head = args_dict['decoder_size_per_head']
    encoder_num_layer = args_dict['encoder_num_layer']
    decoder_num_layer = args_dict['decoder_num_layer']
    encoder_hidden_dim = encoder_head_num * encoder_size_per_head
    decoder_hidden_dim = decoder_head_num * decoder_size_per_head
    tf_datatype = tf.float32
    if args_dict['data_type'] == "fp16":
        tf_datatype = tf.float16
    beam_search_diversity_rate = args_dict['beam_search_diversity_rate']
    sampling_topk = args_dict['sampling_topk']
    sampling_topp = args_dict['sampling_topp']

    source_inputter = WordEmbedder("source_vocabulary",
                                   embedding_size=encoder_hidden_dim,
                                   dtype=tf_datatype)
    target_inputter = WordEmbedder("target_vocabulary",
                                   embedding_size=decoder_hidden_dim,
                                   dtype=tf_datatype)
    inputter = ExampleInputter(source_inputter, target_inputter)
    inputter.initialize({
        "source_vocabulary": args_dict['source_vocabulary'],
        "target_vocabulary": args_dict['target_vocabulary']
    })
    vocab_size = target_inputter.vocabulary_size
    source_file = args_dict['source']
    is_remove_padding = args_dict['remove_padding'].lower() == "true"

    encoder_args = TransformerArgument(
        beam_width=1,
        head_num=encoder_head_num,
        size_per_head=encoder_size_per_head,
        num_layer=encoder_num_layer,
        dtype=tf_datatype,
        kernel_init_range=kernel_initializer_range,
        bias_init_range=bias_initializer_range,
        remove_padding=is_remove_padding)

    # Decoder arguments used by the beam-search pipelines.
    decoder_args = TransformerArgument(
        beam_width=beam_width,
        head_num=decoder_head_num,
        size_per_head=decoder_size_per_head,
        num_layer=decoder_num_layer,
        dtype=tf_datatype,
        kernel_init_range=kernel_initializer_range,
        bias_init_range=bias_initializer_range,
        memory_hidden_dim=encoder_hidden_dim)

    # Independent copy for the sampling pipelines, which decode one
    # hypothesis per sentence (beam_width = 1).
    decoder_args_2 = copy.deepcopy(decoder_args)
    decoder_args_2.beam_width = 1

    decoding_beamsearch_args = DecodingBeamsearchArgument(
        vocab_size, start_of_sentence_id, end_of_sentence_id, max_seq_len,
        decoder_args, beam_search_diversity_rate)

    decoding_sampling_args = DecodingSamplingArgument(
        vocab_size, start_of_sentence_id, end_of_sentence_id, max_seq_len,
        decoder_args_2, sampling_topk, sampling_topp)

    with tf.variable_scope("transformer/encoder", reuse=tf.AUTO_REUSE):
        dataset = inputter.make_inference_dataset(source_file, batch_size)
        iterator = dataset.make_initializable_iterator()
        source = iterator.get_next()
        source_embedding = source_inputter.make_inputs(source)
        source_embedding = tf.cast(source_embedding, tf_datatype)
        memory_sequence_length = source["length"]

        tf_encoder_result = tf_encoder_opennmt(
            source_embedding,
            encoder_args,
            sequence_length=memory_sequence_length)

        # The FasterTransformer encoder reuses the TF encoder's weights,
        # looked up by variable name.
        encoder_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        encoder_variables_dict = {v.name: v for v in encoder_vars}

        ft_encoder_result = ft_encoder_opennmt(
            inputs=source_embedding,
            encoder_args=encoder_args,
            encoder_vars_dict=encoder_variables_dict,
            sequence_length=memory_sequence_length)

    tf_encoder_result = tf.reshape(tf_encoder_result,
                                   tf.shape(source_embedding))
    ft_encoder_result = tf.reshape(ft_encoder_result,
                                   tf.shape(source_embedding))

    with tf.variable_scope("transformer/decoder", reuse=tf.AUTO_REUSE):
        target_inputter.build()
    target_vocab_rev = target_inputter.vocabulary_lookup_reverse()

    # Every pipeline below post-processes lengths the same way: +1, clipped
    # to the id tensor's last dimension.

    ### TF BeamSearch Decoding ###
    tf_beamsearch_target_ids, tf_beamsearch_target_length, _, _, _ = tf_beamsearch_decoding(
        tf_encoder_result,
        memory_sequence_length,
        target_inputter.embedding,
        decoding_beamsearch_args,
        decoder_type=0)

    # tf_beamsearch_target_tokens: [batch_size, beam_width, seq_len]
    tf_beamsearch_target_tokens = target_vocab_rev.lookup(
        tf.cast(tf_beamsearch_target_ids, tf.int64))
    tf_beamsearch_target_length = tf.minimum(
        tf_beamsearch_target_length + 1,
        tf.shape(tf_beamsearch_target_ids)[-1])
    ### end of TF BeamSearch Decoding ###

    ### TF Sampling Decoding ###
    tf_sampling_target_ids, tf_sampling_target_length = tf_sampling_decoding(
        tf_encoder_result,
        memory_sequence_length,
        target_inputter.embedding,
        decoding_sampling_args,
        decoder_type=0)

    # tf_sampling_target_tokens: [batch_size, seq_len]
    tf_sampling_target_tokens = target_vocab_rev.lookup(
        tf.cast(tf_sampling_target_ids, tf.int64))
    tf_sampling_target_length = tf.minimum(
        tf_sampling_target_length + 1,
        tf.shape(tf_sampling_target_ids)[-1])
    ### end of TF Sampling Decoding ###

    ### OP BeamSearch Decoder ###
    op_decoder_beamsearch_target_ids, op_decoder_beamsearch_target_length, _, _, _ = tf_beamsearch_decoding(
        tf_encoder_result,
        memory_sequence_length,
        target_inputter.embedding,
        decoding_beamsearch_args,
        decoder_type=1)

    # op_decoder_beamsearch_target_tokens: [batch_size, beam_width, seq_len]
    op_decoder_beamsearch_target_tokens = target_vocab_rev.lookup(
        tf.cast(op_decoder_beamsearch_target_ids, tf.int64))
    op_decoder_beamsearch_target_length = tf.minimum(
        op_decoder_beamsearch_target_length + 1,
        tf.shape(op_decoder_beamsearch_target_ids)[-1])
    ### end of OP BeamSearch Decoder ###

    ### OP Sampling Decoder ###
    op_decoder_sampling_target_ids, op_decoder_sampling_target_length = tf_sampling_decoding(
        tf_encoder_result,
        memory_sequence_length,
        target_inputter.embedding,
        decoding_sampling_args,
        decoder_type=1)

    op_decoder_sampling_target_tokens = target_vocab_rev.lookup(
        tf.cast(op_decoder_sampling_target_ids, tf.int64))
    op_decoder_sampling_target_length = tf.minimum(
        op_decoder_sampling_target_length + 1,
        tf.shape(op_decoder_sampling_target_ids)[-1])
    ### end of OP Sampling Decoder ###

    ### Prepare Decoding variables for FasterTransformer  ###
    # Must run after all TF decoders above have created their variables.
    all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    decoder_var_start_id = 0

    while all_vars[decoder_var_start_id].name.find(
            "transformer/decoder") == -1:
        decoder_var_start_id += 1
    # decoder_var_start_id + 1 skips the first decoder-scoped variable,
    # which (per the original author) is the target embedding table.
    decoder_variables = all_vars[decoder_var_start_id + 1:]

    ### OP BeamSearch Decoding ###
    op_beamsearch_target_ids, op_beamsearch_target_length, _, _, _ = op_beamsearch_decoding(
        ft_encoder_result, memory_sequence_length, target_inputter.embedding,
        decoder_variables, decoding_beamsearch_args)

    op_beamsearch_target_tokens = target_vocab_rev.lookup(
        tf.cast(op_beamsearch_target_ids, tf.int64))
    op_beamsearch_target_length = tf.minimum(
        op_beamsearch_target_length + 1,
        tf.shape(op_beamsearch_target_ids)[-1])
    ### end of OP BeamSearch Decoding ###

    ### OP Sampling Decoding ###
    op_sampling_target_ids, op_sampling_target_length = op_sampling_decoding(
        ft_encoder_result, memory_sequence_length, target_inputter.embedding,
        decoder_variables, decoding_sampling_args)

    op_sampling_target_tokens = target_vocab_rev.lookup(
        tf.cast(op_sampling_target_ids, tf.int64))
    op_sampling_target_length = tf.minimum(
        op_sampling_target_length + 1,
        tf.shape(op_sampling_target_ids)[-1])
    ### end of OP Sampling Decoding ###

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    time_args = args_dict['test_time']

    class TranslationResult(object):
        """Graph outputs and collected statistics for one decoding pipeline."""

        def __init__(self, token_op, length_op, name):
            self.token_op = token_op    # tensor of decoded token strings
            self.length_op = length_op  # tensor of per-sentence lengths
            self.name = name
            self.file_name = name + ".txt"

            self.token_list = []        # decoded sentences (str)
            self.length_list = []
            self.batch_num = 0
            self.execution_time = 0.0  # seconds
            self.sentence_num = 0
            self.bleu_score = None

    # (selector digit in test_time, token tensor, length tensor, name)
    pipelines = [
        ("0", tf_beamsearch_target_tokens, tf_beamsearch_target_length,
         "tf-decoding-beamsearch"),
        ("1", op_decoder_beamsearch_target_tokens,
         op_decoder_beamsearch_target_length, "op-decoder-beamsearch"),
        ("2", op_beamsearch_target_tokens, op_beamsearch_target_length,
         "op-decoding-beamsearch"),
        ("3", tf_sampling_target_tokens, tf_sampling_target_length,
         "tf-decoding-sampling"),
        ("4", op_decoder_sampling_target_tokens,
         op_decoder_sampling_target_length, "op-decoder-sampling"),
        ("5", op_sampling_target_tokens, op_sampling_target_length,
         "op-decoding-sampling"),
    ]
    translation_result_list = [
        TranslationResult(token_op, length_op, name)
        for selector, token_op, length_op, name in pipelines
        if selector in time_args
    ]

    # Split variables by dtype so each group is restored from its own
    # checkpoint. NOTE(review): the last global variable is deliberately
    # excluded ([:-1]); the reason is not visible here — confirm.
    float_var_list = []
    half_var_list = []
    for var in tf.global_variables()[:-1]:
        if var.dtype.base_dtype == tf.float32:
            float_var_list.append(var)
        elif var.dtype.base_dtype == tf.float16:
            half_var_list.append(var)

    if not translation_result_list:
        print("[WARNING] No put any test cases.")

    cuda_profiler = cudaProfiler()
    cuda_profiler.start()
    try:
        for case_idx, result in enumerate(translation_result_list):
            with tf.Session(config=config) as sess:
                sess.run(tf.global_variables_initializer())
                sess.run(tf.tables_initializer())
                sess.run(iterator.initializer)

                if float_var_list:
                    float_saver = tf.train.Saver(float_var_list)
                    float_saver.restore(sess,
                                        "translation/ckpt/model.ckpt-500000")
                if half_var_list:
                    half_saver = tf.train.Saver(half_var_list)
                    half_saver.restore(
                        sess, "translation/ckpt/fp16_model.ckpt-500000")

                # Beam-search outputs carry an extra beam axis; only the
                # first beam of each sentence is kept.
                is_beamsearch = result.name.find("beamsearch") != -1
                t1 = datetime.now()
                while True:
                    try:
                        batch_tokens, batch_length = sess.run(
                            [result.token_op, result.length_op])
                    except tf.errors.OutOfRangeError:
                        break  # dataset exhausted
                    for tokens, length in zip(batch_tokens, batch_length):
                        if is_beamsearch:
                            tokens, length = tokens[0], length[0]
                        # NOTE(review): lengths were +1'd above; the "- 2"
                        # presumably drops the end-of-sentence token — confirm.
                        result.token_list.append(
                            b" ".join(tokens[:length - 2]).decode("UTF-8"))
                    result.batch_num += 1
                t2 = datetime.now()
                result.execution_time = (t2 - t1).total_seconds()
                result.sentence_num = len(result.token_list)

            # Tokens were decoded from UTF-8, so write them back as UTF-8
            # regardless of the process locale.
            with open(result.file_name, "w", encoding="utf-8") as file_b:
                file_b.writelines(s + "\n" for s in result.token_list)

            # Reference = first N lines of the target file (N = number of
            # produced sentences). Copied in binary mode so bytes are
            # preserved exactly, without shelling out.
            ref_file_path = "./.ref_file.txt"
            try:
                with open(args_dict['target'], "rb") as target_f, \
                        open(ref_file_path, "wb") as ref_f:
                    for _, line in zip(range(result.sentence_num), target_f):
                        ref_f.write(line)
                result.bleu_score = bleu_score(result.file_name,
                                               ref_file_path)
            finally:
                if os.path.exists(ref_file_path):
                    os.remove(ref_file_path)

            # Pause between benchmark cases (presumably to let the GPU settle
            # before the next timed run); not needed after the last one.
            if case_idx + 1 < len(translation_result_list):
                time.sleep(60)
    finally:
        cuda_profiler.stop()

    for t in translation_result_list:
        print(
            "[INFO] {} translates {} batches taking {:.2f} sec to translate {} tokens, BLEU score: {:.2f}, {:.0f} tokens/sec."
            .format(t.name, t.batch_num, t.execution_time,
                    t.bleu_score.sys_len, t.bleu_score.score,
                    t.bleu_score.sys_len / t.execution_time))

    return translation_result_list
    op_encoder_result = op_encoder(inputs=from_tensor,
                                   encoder_args=encoder_args,
                                   attention_mask=attention_mask,
                                   encoder_vars_dict=encoder_variables_dict,
                                   sequence_length=memory_sequence_length)
    op_encoder_result = tf.reshape(
        op_encoder_result, [batch_size, max_seq_len, encoder_hidden_dim])
    op_encoder_result = op_encoder_result * tf.expand_dims(tf.sequence_mask(
        memory_sequence_length, maxlen=max_seq_len, dtype=tf_datatype),
                                                           axis=-1)

    finalized_op_output_ids, finalized_op_sequence_lengths, op_output_ids, \
        op_parent_ids, op_sequence_lengths = op_beamsearch_decoding(op_encoder_result,
                                                                    memory_sequence_length,
                                                                    embedding_table,
                                                                    decoder_variables,
                                                                    decoding_args)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())

        finalized_tf_output_ids_result, tf_output_ids_result, tf_parent_ids_result, \
            tf_sequence_lengths_result = sess.run(
                [finalized_tf_output_ids, tf_output_ids, tf_parent_ids, tf_sequence_lengths])
        finalized_op_output_ids_result, op_output_ids_result, op_parent_ids_result, \
            op_sequence_lengths_result = sess.run(
                [finalized_op_output_ids, op_output_ids, op_parent_ids, op_sequence_lengths])
    ### Prepare Decoding variables for FasterTransformer  ###
    all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    decoder_var_start_id = 0

    while all_vars[decoder_var_start_id].name.find(
            "transformer/decoder") == -1:
        decoder_var_start_id += 1
    encoder_variables = all_vars[:decoder_var_start_id]
    decoder_variables = all_vars[
        decoder_var_start_id +
        1:]  # decoder_var_start_id + 1 means skip the embedding table

    ### OP BeamSearch Decoding ###
    op_beamsearch_target_ids, op_beamsearch_target_length, _, _, _ = op_beamsearch_decoding(
        tf_encoder_result, memory_sequence_length, target_inputter.embedding,
        decoder_variables, decoding_beamsearch_args)

    op_beamsearch_target_tokens = target_vocab_rev.lookup(
        tf.cast(op_beamsearch_target_ids, tf.int64))
    op_beamsearch_target_length = tf.minimum(
        op_beamsearch_target_length + 1,
        tf.shape(op_beamsearch_target_ids)[-1])
    ### end of OP BeamSearch Decoding ###

    ### OP Sampling Decoding ###
    op_sampling_target_ids, op_sampling_target_length = op_sampling_decoding(
        tf_encoder_result, memory_sequence_length, target_inputter.embedding,
        decoder_variables, decoding_sampling_args)

    op_sampling_target_tokens = target_vocab_rev.lookup(