Example #1
0
    def infer(self, sess, batch, id2word, isBatch=True, beam_szie=5):
        """Run the decoder in inference mode and return the predicted text.

        :param sess: active tf.Session with the model's variables restored
        :param batch: a prepared batch; when ``isBatch`` is False, the raw
                      input to be encoded by sentence2enco first
                      (TODO confirm with callers)
        :param id2word: vocab dict mapping id -> token
        :param isBatch: when False, ``batch`` is first encoded via sentence2enco
        :param beam_szie: beam width (misspelled name kept so existing
                          keyword-argument callers keep working)
        :return: the first beam hypothesis of the first batch element,
                 space-joined into a string
        """
        def predict_ids_to_seq(predict_ids, id2word, beam_szie):
            # Convert beam_search output ids to a string.
            # predict_ids: list of length batch_size; each element is a
            # decode_len * beam_size array.
            # NOTE: the `return` sits inside the innermost loop, so only the
            # first beam of the first element is produced — this preserves
            # the original behavior.
            for single_predict in predict_ids:
                for i in range(beam_szie):
                    predict_list = np.ndarray.tolist(single_predict[:, :, i])
                    predict_seq = [id2word[idx] for idx in predict_list[0]]
                    return " ".join(predict_seq)

        if not isBatch:
            # BUG FIX: the original referenced an undefined name `inputData`
            # (guaranteed NameError on this path); the raw input is `batch`.
            batch = sentence2enco(batch, self.word_to_idx, self.lang)

        # Inference only runs the decode op — no loss — so feed only the
        # encoder inputs plus dropout/batch-size placeholders.
        feed_dict = {
            self.encoder_inputs: batch.encoder_inputs,
            self.encoder_inputs_length: batch.encoder_inputs_length,
            self.keep_prob_placeholder: 1.0,
            self.batch_size: len(batch.encoder_inputs)
        }
        predict = sess.run([self.decoder_predict_decode], feed_dict=feed_dict)
        return predict_ids_to_seq(predict, id2word, beam_szie)
Example #2
0
def predict():
    """Interactive decode loop: restore the latest checkpoint, then read
    sentences from stdin and print the model's reply until EOF."""
    with tf.Session() as sess:
        model = Seq2SeqModel(flags, mode='predict', beam_search=True)
        ckpt = tf.train.get_checkpoint_state(flags.model_dir)
        # Bail out early when no usable checkpoint exists.
        if not (ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path)):
            raise ValueError('No such file:[{}]'.format(flags.model_dir))
        print('Reloading model parameters...')
        model.saver.restore(sess, ckpt.model_checkpoint_path)
        # Prompt once up front; readline() returns '' on EOF, ending the loop.
        sys.stdout.write(">")
        sys.stdout.flush()
        for sentence in iter(sys.stdin.readline, ''):
            batch = sentence2enco(sentence, model.word2id)
            predict_ids = model.infer(sess, batch)
            predict_ids_seq(predict_ids, model.id2word, model.beam_size)
            print(">")
            sys.stdout.flush()
Example #3
0
        # tf.train.get_checkpoint_state locates the latest model file via the checkpoint state file
        ckpt = tf.train.get_checkpoint_state(model_dir)
        # If a saved model exists
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print('Reloading model parameters..')
            # Restore the variables with saver.restore()
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:  # no saved model found
            raise ValueError('No such file:[{}]'.format(model_dir))  # abort

        # Print '>' to prompt the user for a sentence.
        # sys.stdout.write() is roughly a print() without the trailing '\n'
        sys.stdout.write("> ")
        # sys.stdout is buffered; flush() pushes the prompt out immediately
        sys.stdout.flush()
        # read one line of user input
        sentence = sys.stdin.readline()
        while sentence:  # keep serving as long as input arrives ('' on EOF ends the loop)
            # tokenize the sentence, convert to ids, and wrap it in a batch — see data_helpers.py
            batch = sentence2enco(sentence, word_to_id)
            # predict the reply for this input
            predicted_ids = model.infer(batch)
            # convert the beam_search result ids back into a string
            predict_ids_to_seq(predicted_ids, id_to_word, beam_size)
            # print '>' again to prompt for the next sentence
            print("> ")
            # flush so the prompt appears immediately
            sys.stdout.flush()
            # read the next line of input
            sentence = sys.stdin.readline()
Example #4
0
                         FLAGS.num_layers,
                         FLAGS.embedding_size,
                         FLAGS.learning_rate,
                         word2id,
                         mode='decode',
                         use_attention=True,
                         beam_search=True,
                         beam_size=5,
                         max_gradient_norm=5.0)

    # Locate the latest checkpoint via the checkpoint state file.
    ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)

    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        print('Reloading model parameters..')
        # Restore the trained variables into the session.
        model.saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        raise ValueError('No such file:[{}]'.format(FLAGS.model_dir))

    # Prompt without a trailing newline, then flush so it shows immediately.
    sys.stdout.write("> ")
    sys.stdout.flush()
    sentence = sys.stdin.readline()

    # readline() returns '' on EOF, which ends the loop.
    while sentence:
        # Encode the sentence into a model batch (tokenize + map to ids).
        batch = sentence2enco(sentence, word2id)
        predicted_ids = model.infer(sess, batch)
        # print(predicted_ids)
        # Print the beam hypotheses (beam width 5) as text.
        predict_ids_to_seq(predicted_ids, id2word, 5)
        print("> ", "")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
Example #5
0
    :return:
    '''
    for single_predict in predict_ids:
        # Print every beam hypothesis; beams are indexed on the last axis.
        for i in range(beam_szie):
            # i-th beam as a plain Python list
            predict_list = np.ndarray.tolist(single_predict[:, :, i])
            # map ids back to tokens; row [0] is presumably the decoded sequence — verify the array shape
            predict_seq = [id2word[idx] for idx in predict_list[0]]
            print(" ".join(predict_seq))

with tf.Session() as sess:
    # Build the model in decode mode with attention and beam search (width 5).
    model = Seq2SeqModel(FLAGS.rnn_size, FLAGS.num_layers, FLAGS.embedding_size, FLAGS.learning_rate, word2id,
                         mode='decode', use_attention=True, beam_search=True, beam_size=5, max_gradient_norm=5.0)
    # Find the latest checkpoint; abort when none is usable.
    ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
    if not (ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path)):
        raise ValueError('No such file:[{}]'.format(FLAGS.model_dir))
    print('Reloading model parameters..')
    model.saver.restore(sess, ckpt.model_checkpoint_path)
    # Prompt once, then serve predictions until stdin hits EOF (readline() -> '').
    sys.stdout.write("> ")
    sys.stdout.flush()
    for sentence in iter(sys.stdin.readline, ''):
        # Encode the user's sentence into a batch of ids.
        batch = sentence2enco(sentence, word2id)
        # Run inference to get the predicted ids.
        predicted_ids = model.infer(sess, batch)
        # print(predicted_ids)
        # Convert the predicted ids back into text.
        predict_ids_to_seq(predicted_ids, id2word, 5)
        print("> ", "")
        sys.stdout.flush()