Example #1
def decode():
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.3)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        # Create model and load parameters.
        model = create_model(sess, True)
        model.batch_size = 1  # We decode one sentence at a time.

        # Load vocabularies.
        en_vocab_path = os.path.join(FLAGS.data_dir,
                                     "vocab%d.en" % FLAGS.en_vocab_size)
        fr_vocab_path = os.path.join(FLAGS.data_dir,
                                     "vocab%d.fr" % FLAGS.fr_vocab_size)
        src_vocab, _ = data_tools.initialize_vocabulary(en_vocab_path)
        _, rev_fr_vocab = data_tools.initialize_vocabulary(fr_vocab_path)

        # Read the source- and target-side sentences of the test set.
        dev_src_texts, dev_tgt_texts = data_tools.do_test(FLAGS.data_dir)

        outputPath = os.path.join(FLAGS.data_dir, "dev_data.result")
        correct = 0
        with gfile.GFile(outputPath, mode="w") as outputfile:
            for i in xrange(len(dev_src_texts)):
                src_text = dev_src_texts[i]

                # Convert the raw input sentence into token ids.
                token_ids = data_tools.sentence_to_token_ids(tf.compat.as_bytes(src_text), src_vocab)

                # Find the bucket this sentence falls into.
                bucket_id = min([b for b in xrange(len(_buckets))
                                 if _buckets[b][0] > len(token_ids)])

                # Build the seq2seq encoder inputs, decoder inputs and target weights for this bucket.
                encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                    {bucket_id: [(token_ids, [])]}, bucket_id)

                # Get output logits for the sentence.
                _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                                 target_weights, bucket_id, True)
                # This is a greedy decoder - outputs are just argmaxes of output_logits.
                outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
                # If there is an EOS symbol in outputs, cut them at that point.
                if data_tools.EOS_ID in outputs:
                    outputs = outputs[:outputs.index(data_tools.EOS_ID)]
                # Print out French sentence corresponding to outputs.
                out = " ".join([tf.compat.as_str(rev_fr_vocab[output]) for output in outputs])
                out = "M " + out + "\n"
                if out == dev_tgt_texts[i]:
                    correct += 1
                outputfile.write('input: %s' % src_text)
                outputfile.write('output: %s' % out)
                outputfile.write('target: %s' % dev_tgt_texts[i])
                outputfile.write('\n')
            precision = correct / float(len(dev_tgt_texts))  # float() avoids Python 2 integer division
            outputfile.write('precision = %.3f' % precision)
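
The bucket lookup above picks the smallest bucket whose source length is strictly greater than the number of token ids. A minimal standalone sketch of that rule (select_bucket and the bucket sizes below are illustrative only, not taken from this project's _buckets):

def select_bucket(token_ids, buckets):
    # Index of the smallest bucket whose source side can hold the sentence.
    candidates = [b for b in range(len(buckets)) if buckets[b][0] > len(token_ids)]
    return min(candidates) if candidates else len(buckets) - 1  # fall back to the largest bucket

# Illustrative (source_length, target_length) buckets.
buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
print(select_bucket([4, 8, 15], buckets))       # 0 -> bucket (5, 10)
print(select_bucket(list(range(30)), buckets))  # 3 -> bucket (40, 50)
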
Example #2
def decode():
    with tf.Session() as sess:
        # Create model and load parameters.
        model = create_model(sess, True)
        model.batch_size = 1  # We decode one sentence at a time.

        # Load vocabularies.
        src_vocab_path = os.path.join(
            args.data_dir,
            "vocab%d." % args.src_vocab_size + args.src_extension)
        tgt_vocab_path = os.path.join(
            args.data_dir,
            "vocab%d." % args.tgt_vocab_size + args.tgt_extension)
        src_vocab, _ = data_tools.initialize_vocabulary(src_vocab_path)
        _, rev_tgt_vocab = data_tools.initialize_vocabulary(tgt_vocab_path)

        # Decode from standard input.
        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            # Get token-ids for the input sentence.
            token_ids = data_tools.sentence_to_token_ids(
                tf.compat.as_bytes(sentence), src_vocab)
            # Which bucket does it belong to?
            bucket_id = min([
                b for b in xrange(len(_buckets))
                if _buckets[b][0] > len(token_ids)
            ])
            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)
            # Get output logits for the sentence.
            _, _, output_logits = model.step(sess, encoder_inputs,
                                             decoder_inputs, target_weights,
                                             bucket_id, True)
            # This is a greedy decoder - outputs are just argmaxes of output_logits.
            outputs = [
                int(np.argmax(logit, axis=1)) for logit in output_logits
            ]
            # If there is an EOS symbol in outputs, cut them at that point.
            if data_tools.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_tools.EOS_ID)]
            # Print out target sentence corresponding to outputs.
            print(" ".join([
                tf.compat.as_str(rev_tgt_vocab[output]) for output in outputs
            ]))
            print("> ", end="")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
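
The greedy decoder in these examples just takes the per-step argmax of output_logits and cuts the sequence at the first EOS. A self-contained sketch of that post-processing step (EOS_ID = 2 is assumed here to match the usual PAD/GO/EOS ordering; the examples above read it from data_tools.EOS_ID):

import numpy as np

EOS_ID = 2  # assumed value; the real one comes from data_tools.EOS_ID

def greedy_decode(output_logits, eos_id=EOS_ID):
    # One (batch, vocab) logit array per decoding step; keep row 0 of each.
    outputs = [int(np.argmax(logit, axis=1)[0]) for logit in output_logits]
    # Truncate at the first EOS, if any.
    if eos_id in outputs:
        outputs = outputs[:outputs.index(eos_id)]
    return outputs

# Toy run: three steps over a vocabulary of size 5; the third step predicts EOS.
logits = [np.array([[0.1, 0.2, 0.0, 0.9, 0.3]]),
          np.array([[0.0, 0.8, 0.1, 0.2, 0.1]]),
          np.array([[0.1, 0.1, 0.9, 0.0, 0.0]])]
print(greedy_decode(logits))  # [3, 1]
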
Example #3
def decode_once2(sess, model, sentence):

    # model = Singleton.get_instance(sess)
    model.batch_size = 1  # We decode one sentence at a time.
    # Load vocabularies.
    en_vocab_path = os.path.join(FLAGS.data_dir,
                                 "vocab%d.en" % FLAGS.en_vocab_size)
    fr_vocab_path = os.path.join(FLAGS.data_dir,
                                 "vocab%d.fr" % FLAGS.fr_vocab_size)
    en_vocab, _ = data_tools.initialize_vocabulary(en_vocab_path)
    _, rev_fr_vocab = data_tools.initialize_vocabulary(fr_vocab_path)

    # Get token-ids for the input sentence.
    token_ids = data_tools.sentence_to_token_ids(tf.compat.as_bytes(sentence),
                                                 en_vocab)
    # print("token:%s"%token_ids)
    # print("token-len:%s"%len(token_ids))
    # Which bucket does it belong to?
    l = [b for b in xrange(len(_buckets)) if _buckets[b][0] > len(token_ids)]
    # print("l", l)
    if l:
        bucket_id = min(l)
    else:
        # The sentence is longer than every bucket; fall back to the largest one.
        bucket_id = len(_buckets) - 1

    # Get a 1-element batch to feed the sentence to the model.
    encoder_inputs, decoder_inputs, target_weights = model.get_batch(
        {bucket_id: [(token_ids, [])]}, bucket_id)
    # Get output logits for the sentence.
    _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                     target_weights, bucket_id, True)
    # This is a greedy decoder - outputs are just argmaxes of output_logits.
    outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]

    # If there is an EOS symbol in outputs, cut them at that point.
    if data_tools.EOS_ID in outputs:
        outputs = outputs[:outputs.index(data_tools.EOS_ID)]
    # Print out French sentence corresponding to outputs.
    outputresult = " ".join(
        [tf.compat.as_str(rev_fr_vocab[output]) for output in outputs])

    return outputresult
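
decode_once2 falls back to the last bucket when the sentence does not fit in any bucket, but the placeholders for that bucket still expect at most _buckets[-1][0] source tokens, so an over-long input can still break the feed. A hedged sketch of a guard that truncates over-long inputs before bucketing (truncate_to_buckets is a hypothetical helper, not part of this codebase):

def truncate_to_buckets(token_ids, buckets):
    # Keep only as many leading tokens as the largest bucket's source side allows.
    max_source_len = buckets[-1][0]
    if len(token_ids) >= max_source_len:
        token_ids = token_ids[:max_source_len - 1]  # leave room so the last bucket still fits
    bucket_id = min(b for b in range(len(buckets)) if buckets[b][0] > len(token_ids))
    return token_ids, bucket_id

# Illustrative buckets; a 60-token sentence is clipped into the last bucket.
buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
ids, bucket_id = truncate_to_buckets(list(range(60)), buckets)
print(len(ids), bucket_id)  # 39 3
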
Example #4
def decode():
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.3)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        # Create model and load parameters.
        model = create_model(sess, True)
        model.batch_size = 100  # Decode with a batch size of 100.

        # Load vocabularies.
        en_vocab_path = os.path.join(FLAGS.data_dir,
                                     "vocab%d.en" % FLAGS.en_vocab_size)
        fr_vocab_path = os.path.join(FLAGS.data_dir,
                                     "vocab%d.fr" % FLAGS.fr_vocab_size)
        src_vocab, _ = data_tools.initialize_vocabulary(en_vocab_path)
        _, rev_fr_vocab = data_tools.initialize_vocabulary(fr_vocab_path)

        # Read the source- and target-side sentences of the test set.
        dev_src_texts, dev_tgt_texts = data_tools.do_test(FLAGS.data_dir)

        outputPath = os.path.join(FLAGS.data_dir, "dev_data.result")
        correct = 0
        with gfile.GFile(outputPath, mode="w") as outputfile:
            for i in xrange(len(dev_src_texts)):
                src_text = dev_src_texts[i]

                # Convert the raw input sentence into token ids.
                token_ids = data_tools.sentence_to_token_ids(
                    tf.compat.as_bytes(src_text), src_vocab)

                # Find the bucket this sentence falls into. TODO
                bucket_id = min([
                    b for b in xrange(len(_buckets))
                    if _buckets[b][0] > len(token_ids)
                ])

                # Build the seq2seq encoder inputs, decoder inputs and target weights for this bucket. TODO
                encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                    {bucket_id: [(token_ids, [])]}, bucket_id)

                # Get the decoder's output logits. TODO
                _, _, output_logits = model.step(sess, encoder_inputs,
                                                 decoder_inputs,
                                                 target_weights, bucket_id,
                                                 True)
                # Greedily output the single best token at each step.
                # TODO: beam search could be used to keep multiple candidates.
                # outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
                # Take the argmax of each row of the logits; note there may be ties.
                outputs = [
                    int(np.argmax(logit, axis=1)[0]) for logit in output_logits
                ]

                # Stop once EOS is predicted.
                if data_tools.EOS_ID in outputs:
                    outputs = outputs[:outputs.index(data_tools.EOS_ID)]

                # Join the output ids into the predicted target-side sentence.
                out = "".join([
                    tf.compat.as_str(rev_fr_vocab[output])
                    for output in outputs
                ])
                out += "\n"

                # Accuracy statistic: count exact full-sequence (CTC-style) matches.
                if out == dev_tgt_texts[i]:
                    correct += 1

                outputfile.write('input: %s' % src_text)
                outputfile.write('output: %s' % out)
                outputfile.write('target: %s' % dev_tgt_texts[i])
                outputfile.write('\n')

            precision = correct / float(len(dev_tgt_texts))  # float() avoids Python 2 integer division
            outputfile.write('precision = %.3f' % precision)
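
The TODO in the last example suggests replacing the greedy argmax with beam search. Doing that with this tutorial-style API would mean re-running the decoder for every partial hypothesis; the sketch below only illustrates the search itself over an abstract step_fn(prefix) that returns per-token log-probabilities (step_fn, beam_search and the toy scores are hypothetical, not part of this codebase):

import numpy as np

def beam_search(step_fn, beam_size, max_len, eos_id):
    # Each hypothesis is (token_prefix, cumulative_log_prob).
    beams = [([], 0.0)]
    for _ in range(max_len):
        candidates = []
        for prefix, score in beams:
            log_probs = step_fn(prefix)                     # shape (vocab_size,)
            for tok in np.argsort(log_probs)[-beam_size:]:  # expand the top beam_size tokens
                candidates.append((prefix + [int(tok)], score + float(log_probs[tok])))
        # Keep only the beam_size highest-scoring hypotheses.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
        if beams[0][0][-1] == eos_id:                       # best hypothesis ends in EOS
            return beams[0][0][:-1]
    return beams[0][0]

# Toy step_fn: prefers token 3 for the first two steps, then EOS (id 2).
def toy_step(prefix):
    probs = [0.05, 0.05, 0.7, 0.1, 0.1] if len(prefix) >= 2 else [0.05, 0.05, 0.1, 0.7, 0.1]
    return np.log(np.array(probs))

print(beam_search(toy_step, beam_size=2, max_len=10, eos_id=2))  # [3, 3]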