Code Example #1
File: flask_s2s.py  Project: spacetiller/mortar
def abc(model=model, sess=sess):
    error = None
    if request.method == 'POST':
        #sentence = request.json['sentence']
        sentence = request.form['siri']
        print("----------siri input words ------" + sentence)

        class TestBucket(object):
            def __init__(self, sentence):
                self.sentence = sentence

            def random(self):
                # return the stored sentence with an empty answer
                return self.sentence, ''

        bucket_id = min(
            [b for b in range(len(buckets)) if buckets[b][0] > len(sentence)])
        data, _ = model.get_batch_data({bucket_id: TestBucket(sentence)},
                                       bucket_id)
        encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
            {bucket_id: TestBucket(sentence)}, bucket_id, data)
        _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                         decoder_weights, bucket_id, True)
        outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
        ret = data_utils.indice_sentence(outputs)
        print("----------p4bot output words ------" + ret)
        return ret
    return render_template('p4bot.html', error=error)
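Note: every snippet on this page leans on module-level names (buckets, model, sess, create_model, data_utils, np, tf, FLAGS) that the scraped excerpts never show. A minimal sketch of that assumed context follows; the bucket sizes and the data_utils surface are illustrative guesses, not the projects' real values.

import numpy as np
import tensorflow as tf  # TF 1.x style, matching these snippets

import data_utils  # project module providing indice_sentence(), read_bucket_dbs(), ...

# (input_length, output_length) pairs; a sentence is routed to the
# smallest bucket whose input side is longer than the sentence
buckets = [(5, 15), (10, 20), (15, 25), (20, 30)]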
Code Example #2
File: seq2seq.py  Project: MSintern/msbot_seq2seq
def decode_line(sess, model, sentence):
    class TestBucket(object):
        def __init__(self, sentence):
            self.sentence = sentence
        def random(self):
            return self.sentence, ''
    bucket_id = min([
        b for b in range(len(buckets))
        if buckets[b][0] > len(sentence)
    ])
    #return '1'
    data, _ = model.get_batch_data(
        {bucket_id: TestBucket(sentence)},
        bucket_id
    )
    #return data[0][0]
    encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
        {bucket_id: TestBucket(sentence)},
        bucket_id,
        data
    )
    #return data[0][0]
    _, _, output_logits = model.step(
        sess,
        encoder_inputs,
        decoder_inputs,
        decoder_weights,
        bucket_id,
        True
    )
    #return data[0][0]
    outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
    ret = data_utils.indice_sentence(outputs)
    return ret
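A possible driver for decode_line, assuming the create_model/FLAGS setup used by the later examples (those names are taken from them, not from this file):

import os

with tf.Session() as sess:
    model = create_model(sess, True)  # forward_only=True for inference
    model.batch_size = 1
    model.saver.restore(sess, os.path.join(FLAGS.model_dir, FLAGS.model_name))
    print(decode_line(sess, model, u'你好'))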
Code Example #3
File: run.py  Project: spacetiller/mortar
def abc(model=model, sess=sess):
    error = None
    if request.method == 'GET':
        #sentence = request.json['sentence']
        sentence = request.args['siri']

        #print ("----------siri input words ------" + sentence)
        class TestBucket(object):
            def __init__(self, sentence):
                self.sentence = sentence

            def random(self):
                return self.sentence, ''

        bucket_id = min(
            [b for b in range(len(buckets)) if buckets[b][0] > len(sentence)])
        data, _ = model.get_batch_data({bucket_id: TestBucket(sentence)},
                                       bucket_id)
        encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
            {bucket_id: TestBucket(sentence)}, bucket_id, data)
        _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                         decoder_weights, bucket_id, True)
        outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
        ret = data_utils.indice_sentence(outputs)
        #print ("----------p4bot output words ------" + ret)
        end = '{"code": 1000, "text": "%s" }' % (ret)
        return "successCallback" + "(" + json.loads(json.dumps(end)) + ")"
    end = '{"code": 1003, "text": "%s" }' % ("请求方法错误")
    return "successCallback" + "(" + json.loads(json.dumps(end)) + ")"
Code Example #4
def test():
    class TestBucket(object):
        def __init__(self, sentence):
            self.sentence = sentence
        def random(self):
            return self.sentence, ''
    with tf.Session() as sess:
        # build the model
        model = create_model(sess, True)
        model.batch_size = 1
        # initialize variables
        
        sess.run(tf.global_variables_initializer())
        model.saver.restore(sess, os.path.join(FLAGS.model_dir, FLAGS.model_name))
        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            # get the smallest bucket id that fits the sentence
            bucket_id = min([b for b in range(len(buckets)) if buckets[b][0] > len(sentence)])
            # preprocess the input sentence
            data, _ = model.get_batch_data({bucket_id: TestBucket(sentence)}, bucket_id)
            encoder_inputs, decoder_inputs, decoder_weights = model.get_batch({bucket_id: TestBucket(sentence)}, bucket_id, data)
            _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, decoder_weights, bucket_id, True)
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            ret = data_utils.indice_sentence(outputs)
            print(ret)
            print("> ", end="")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
Code Example #5
File: s2s.py  Project: zheroic/Seq2Seq_Chatbot_QA
def test():
    print("test mode")
    class TestBucket(object):
        def __init__(self, sentence):
            self.sentence = sentence
        def random(self):
            return self.sentence, ''
    with tf.Session() as sess:
        # build the model
        model = create_model(sess, True)
        model.batch_size = 1
        # initialize variables
        sess.run(tf.global_variables_initializer())
        
        ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
        if ckpt is None or ckpt.model_checkpoint_path is None:
            print('failed to restore model')
            return

        print('restoring model file %s' % ckpt.model_checkpoint_path)

        model.saver.restore(sess, ckpt.model_checkpoint_path)
        print("Input 'exit()' to exit test mode!")
        sys.stdout.write("me > ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        if "exit()" in sentence:
            sentence = False
        while sentence:
            bucket_id = min([
                b for b in range(len(buckets))
                if buckets[b][0] > len(sentence)
            ])
            data, _ = model.get_batch_data(
                {bucket_id: TestBucket(sentence)},
                bucket_id
            )
            encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
                {bucket_id: TestBucket(sentence)},
                bucket_id,
                data
            )
            _, _, output_logits = model.step(
                sess,
                encoder_inputs,
                decoder_inputs,
                decoder_weights,
                bucket_id,
                True
            )
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            ret = data_utils.indice_sentence(outputs)
            print("AI >", ret)
            print("me > ", end="")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
            if "exit()" in sentence:
                break
Code Example #6
File: s2s.py  Project: exueyuan/ChatBotYellowChicken
def test_bleu(count):
    """测试bleu, 这个方法我们不看"""
    print("bleu test mode")
    from nltk.translate.bleu_score import sentence_bleu
    from tqdm import tqdm
    # prepare the data
    print('preparing data')
    bucket_dbs = data_utils.read_bucket_dbs(FLAGS.buckets_dir)
    bucket_sizes = []
    for i in range(len(buckets)):
        bucket_size = bucket_dbs[i].size
        bucket_sizes.append(bucket_size)
        print('bucket {} contains {} records'.format(i, bucket_size))
    total_size = sum(bucket_sizes)
    print('{} records in total'.format(total_size))
    # if count <= 0, sample over all records by default
    if count <= 0:
        count = total_size
    buckets_scale = [
        sum(bucket_sizes[:i + 1]) / total_size
        for i in range(len(bucket_sizes))
    ]
    with tf.Session() as sess:
        # build the model
        model = create_model(sess, True)
        model.batch_size = 1
        # initialize variables
        sess.run(tf.global_variables_initializer())
        model.saver.restore(sess,
                            os.path.join(FLAGS.model_dir, FLAGS.model_name))

        total_score = 0.0
        for i in tqdm(range(count)):
            # pick a bucket to sample from
            random_number = np.random.random_sample()
            bucket_id = min([
                i for i in range(len(buckets_scale))
                if buckets_scale[i] > random_number
            ])
            data, _ = model.get_batch_data(bucket_dbs, bucket_id)
            encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
                bucket_id, data)
            _, _, output_logits = model.step(sess, encoder_inputs,
                                             decoder_inputs, decoder_weights,
                                             bucket_id, True)
            outputs = [
                int(np.argmax(logit, axis=1)) for logit in output_logits
            ]
            ask, _ = data[0]
            all_answers = bucket_dbs[bucket_id].all_answers(ask)
            ret = data_utils.indice_sentence(outputs)
            if not ret:
                continue
            references = [list(x) for x in all_answers]
            score = sentence_bleu(references, list(ret), weights=(1.0, ))
            total_score += score
        print('BLEU: {:.2f} in {} samples'.format(total_score / count * 10,
                                                  count))
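With weights=(1.0,), sentence_bleu reduces to unigram precision, and since the example feeds list(ret), the units are characters. A standalone check with toy strings (not project data):

from nltk.translate.bleu_score import sentence_bleu

references = [list(u'你好呀'), list(u'你好')]
candidate = list(u'你好')
# unigram-only BLEU; 1.0 here because both candidate characters
# appear in a reference and the brevity penalty is 1
print(sentence_bleu(references, candidate, weights=(1.0,)))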
Code Example #7
 def res(sentence):
     bucket_id = min(
         [b for b in range(len(buckets)) if buckets[b][0] > len(sentence)])
     data, _ = model.get_batch_data({bucket_id: TestBucket(sentence)},
                                    bucket_id)
     encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
         {bucket_id: TestBucket(sentence)}, bucket_id, data)
     _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                      decoder_weights, bucket_id, True)
     outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
     ret = data_utils.indice_sentence(outputs)
     return ret
Code Example #8
def test_bleu(count):
    u'Test BLEU'
    from nltk.translate.bleu_score import sentence_bleu
    from tqdm import tqdm
    print(u'preparing data')
    bucket_dbs = data_utils.read_bucket_dbs(
        FLAGS.buckets_dir)  #FLAGS.buckets_dir
    bucket_sizes = []
    for i in range(len(buckets)):
        bucket_size = bucket_dbs[i].size
        bucket_sizes.append(bucket_size)
        print(u'bucket {} contains {} records'.format(i, bucket_size))
    total_size = sum(bucket_sizes)
    print(u'{} records in total'.format(total_size))
    if (count <= 0):
        count = total_size
    buckets_scale = [(sum(bucket_sizes[:(i + 1)]) / total_size)
                     for i in range(len(bucket_sizes))]
    with tf.Session() as sess:
        model = create_model(sess, True)
        model.batch_size = 1
        sess.run(tf.global_variables_initializer())
        model.saver.restore(sess,
                            os.path.join(FLAGS.model_dir, FLAGS.model_name))
        total_score = 0.0
        for i in tqdm(range(count)):
            random_number = np.random.random_sample()
            bucket_id = min([
                i for i in range(len(buckets_scale))
                if (buckets_scale[i] > random_number)
            ])
            (data, _) = model.get_batch_data(bucket_dbs, bucket_id)
            (encoder_inputs, decoder_inputs,
             decoder_weights) = model.get_batch(bucket_dbs, bucket_id, data)
            (_, _, output_logits) = model.step(sess, encoder_inputs,
                                               decoder_inputs, decoder_weights,
                                               bucket_id, True)
            outputs = [
                int(np.argmax(logit, axis=1)) for logit in output_logits
            ]
            (ask, _) = data[0]
            all_answers = bucket_dbs[bucket_id].all_answers(ask)
            ret = data_utils.indice_sentence(outputs)
            if (not ret):
                continue
            references = [list(x) for x in all_answers]
            score = sentence_bleu(references, list(ret), weights=(1.0, ))
            total_score += score
        print(u'BLEU: {:.2f} in {} samples'.format(
            ((total_score / count) * 10), count))
Code Example #9
 def predict(self, buckets, sentence, bucket, sess):
     # determine which bucket the sentence belongs in
     bucket_id = min(
         [b for b in range(len(buckets)) if buckets[b][0] > len(sentence)])
     # preprocess the input sentence
     data, _ = self.get_batch_data({bucket_id: bucket}, bucket_id)
     # encoder/decoder inputs
     encoder_inputs, decoder_inputs, decoder_weights = self.get_batch(
         bucket_id, data)
     # run the model
     _, _, output_logits = self.step(sess, encoder_inputs, decoder_inputs,
                                     decoder_weights, bucket_id, True)
     outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
     ret = data_utils.indice_sentence(outputs)
     return ret
Code Example #10
def test():
    class TestBucket(object):
        def __init__(self, sentence):
            self.sentence = sentence

        def random(self):
            return self.sentence, ''

    with tf.Session() as sess:
        # build the model
        model = create_model(sess, True)
        model.batch_size = 1
        # initialize variables

        sess.run(tf.global_variables_initializer())
        model.saver.restore(sess,
                            os.path.join(FLAGS.model_dir, FLAGS.model_name))
        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            # get the smallest bucket id that fits the sentence
            bucket_id = min([
                b for b in range(len(buckets)) if buckets[b][0] > len(sentence)
            ])
            # preprocess the input sentence, building the question/answer pairs
            data, _ = model.get_batch_data({
                bucket_id: TestBucket(sentence)
            }, bucket_id)  # normally takes bucket_dbs, bucket_id, i.e. mainly to index bucket_dbs[bucket_id]
            # here we only need an empty answer, so the first argument fakes a dict that still supports [bucket_id]
            encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
                {bucket_id: TestBucket(sentence)}, bucket_id,
                data)  # yields encoder_inputs, decoder_inputs, decoder_weights
            _, _, output_logits = model.step(
                sess, encoder_inputs, decoder_inputs, decoder_weights,
                bucket_id, True)  # returns None, the loss and the outputs; only the outputs are used here
            outputs = [
                int(np.argmax(logit, axis=1)) for logit in output_logits
            ]
            # pick the argmax token index at each output step
            ret = data_utils.indice_sentence(outputs)
            print(ret)
            print("> ", end="")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
Code Example #11
 def run(self):
     bucket_id = min([
         b for b in range(len(self.buckets))
         if self.buckets[b][0] > len(self.sentence)
     ])
     # preprocess the input sentence
     data, _ = self.model.get_batch_data(
         {bucket_id: TestBucket(self.sentence)}, bucket_id)
     # encoder/decoder inputs
     encoder_inputs, decoder_inputs, decoder_weights = self.model.get_batch(
         bucket_id, data)
     # run the model
     _, _, output_logits = self.model.step(self.sess_gen, encoder_inputs,
                                           decoder_inputs, decoder_weights,
                                           bucket_id, True)
     outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
     ret = data_utils.indice_sentence(outputs)
     self.InfoSignal.emit(ret)
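This run() reads like the body of a Qt worker thread. A sketch of the surrounding class it appears to assume (PyQt5; the class name and constructor are inferred, not from the source project):

from PyQt5.QtCore import QThread, pyqtSignal

class GenerateThread(QThread):  # hypothetical name
    InfoSignal = pyqtSignal(str)  # carries the bot's reply back to the UI

    def __init__(self, model, sess_gen, buckets, sentence):
        super().__init__()
        self.model = model
        self.sess_gen = sess_gen
        self.buckets = buckets
        self.sentence = sentence

    # run() is the method shown above; Qt invokes it via start()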
Code Example #12
def test():
    class TestBucket(object):
        def __init__(self, sentence):
            self.sentence = sentence
        def random(self):
            return self.sentence, ''
    with tf.Session() as sess:
        # build the model
        model = create_model(sess, True)
        model.batch_size = 1
        # initialize variables
        sess.run(tf.global_variables_initializer())
        model.saver.restore(sess, os.path.join(FLAGS.model_dir, FLAGS.model_name))
        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            bucket_id = min([
                b for b in range(len(buckets))
                if buckets[b][0] > len(sentence)
            ])
            data, _ = model.get_batch_data(
                {bucket_id: TestBucket(sentence)},
                bucket_id
            )
            encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
                {bucket_id: TestBucket(sentence)},
                bucket_id,
                data
            )
            _, _, output_logits = model.step(
                sess,
                encoder_inputs,
                decoder_inputs,
                decoder_weights,
                bucket_id,
                True
            )
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            ret = data_utils.indice_sentence(outputs)
            print(ret)
            print("> ", end="")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
Code Example #13
File: s2s.py  Project: jieliorz/chatbot-1
def serve():
    class TestBucket(object):
        def __init__(self, sentence):
            self.sentence = sentence

        def random(self):
            return self.sentence, ''

    with tf.Session() as sess:
        # build the model
        model = create_model(sess, True)
        model.batch_size = 1
        # initialize variables
        sess.run(tf.global_variables_initializer())
        model.saver.restore(sess,
                            os.path.join(FLAGS.model_dir, FLAGS.model_name))

        # open the UDP socket
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        address = ('127.0.0.1', FLAGS.port)
        s.bind(address)
        print("已启动,端口号为", FLAGS.port)
        while True:
            data, addr = s.recvfrom(2048)
            sentence = data.decode("utf-8")
            print("received:", sentence, "from", addr)

            bucket_id = min([
                b for b in range(len(buckets)) if buckets[b][0] > len(sentence)
            ])
            data, _ = model.get_batch_data({bucket_id: TestBucket(sentence)},
                                           bucket_id)
            encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
                {bucket_id: TestBucket(sentence)}, bucket_id, data)
            _, _, output_logits = model.step(sess, encoder_inputs,
                                             decoder_inputs, decoder_weights,
                                             bucket_id, True)
            outputs = [
                int(np.argmax(logit, axis=1)) for logit in output_logits
            ]
            ret = data_utils.indice_sentence(outputs)
            print("return ", ret)
            s.sendto(('%s' % ret).encode(), addr)
        s.close()
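A matching UDP client for the server above, as a sketch; ask() is a hypothetical helper, and the port must match FLAGS.port:

import socket

def ask(sentence, port, host='127.0.0.1'):
    # send one question and block for the bot's reply
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    s.sendto(sentence.encode('utf-8'), (host, port))
    reply, _ = s.recvfrom(2048)
    s.close()
    return reply.decode('utf-8')

print(ask(u'你好', 5000))  # 5000 is a placeholder for FLAGS.port's value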
Code Example #14
def test():
    class TestBucket(object):
        def __init__(self, sentence):
            self.sentence = sentence

        def random(self):
            return (self.sentence, u'')

    with tf.Session() as sess:
        model = create_model(sess, True)
        model.batch_size = 1
        sess.run(tf.global_variables_initializer())
        model.saver.restore(sess,
                            os.path.join(FLAGS.model_dir, FLAGS.model_name))
        sys.stdout.write(u'> ')
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            bucket_id = min([
                b for b in range(len(buckets))
                if (buckets[b][0] > len(sentence))
            ])
            (data,
             _) = model.get_batch_data({
                 bucket_id: TestBucket(sentence),
             }, bucket_id)
            (encoder_inputs, decoder_inputs,
             decoder_weights) = model.get_batch(
                 {
                     bucket_id: TestBucket(sentence),
                 }, bucket_id, data)
            (_, _, output_logits) = model.step(sess, encoder_inputs,
                                               decoder_inputs, decoder_weights,
                                               bucket_id, True)
            outputs = [
                int(np.argmax(logit, axis=1)) for logit in output_logits
            ]
            ret = data_utils.indice_sentence(outputs)
            print(ret)
            print(u'> ', end=u'')
            sys.stdout.flush()
            sentence = sys.stdin.readline()
Code Example #15
def test(sentence):
    class TestBucket(object):
        def __init__(self, sentence):
            self.sentence = sentence
        def random(self):
            return self.sentence, ''
    with tf.Session() as sess:
        # build the model
        model = create_model(sess, True)
        model.batch_size = 1
        # initialize variables
        #sess.run(tf.initialize_all_variables())
        model.saver.restore(sess, save_path='/home/Seq2Seq/model3/model')
        bucket_id = min([
            b for b in range(len(buckets))
            if buckets[b][0] > len(sentence)
        ])
        data, _ = model.get_batch_data(
            {bucket_id: TestBucket(sentence)},
            bucket_id
        )
        encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
            {bucket_id: TestBucket(sentence)},
            bucket_id,
            data
        )
        _, _, output_logits = model.step(
            sess,
            encoder_inputs,
            decoder_inputs,
            decoder_weights,
            bucket_id,
            True
        )
        outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
        ret = data_utils.indice_sentence(outputs)
        with open('/home/Seq2Seq/1.txt', 'w') as f:
            f.write(ret)
Code Example #16
File: run.py  Project: spacetiller/mortar
def abcd(model=model, sess=sess):
    sentence = '你好啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊'  # fixed test input: "hello" plus filler characters

    class TestBucket(object):
        def __init__(self, sentence):
            self.sentence = sentence

        def random(self):
            return self.sentence, ''

    #print(time.time())
    bucket_id = min(
        [b for b in range(len(buckets)) if buckets[b][0] > len(sentence)])
    #print(time.time())
    data, _ = model.get_batch_data({bucket_id: TestBucket(sentence)},
                                   bucket_id)
    #print(time.time())
    encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
        {bucket_id: TestBucket(sentence)}, bucket_id, data)
    _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                     decoder_weights, bucket_id, True)
    outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
    ret = data_utils.indice_sentence(outputs)
    return ret
Code Example #17
def test_general_ques():
    print("test mode automatically with the general question")

    class TestBucket(object):
        def __init__(self, sentence):
            self.sentence = sentence

        def random(self):
            return self.sentence, ''

    with tf.Session() as sess:
        # build the model
        model = create_model(sess, True)
        model.batch_size = 1
        # initialize variables
        sess.run(tf.global_variables_initializer())

        ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
        if ckpt is None or ckpt.model_checkpoint_path is None:
            print('failed to restore model')
            return

        print('restoring model file %s' % ckpt.model_checkpoint_path)

        model.saver.restore(sess, ckpt.model_checkpoint_path)
        # print("Input 'exit()' to exit test mode!")
        # sys.stdout.write("me > ")
        # sys.stdout.flush()
        # sentence = sys.stdin.readline()
        # if "exit()" in sentence:
        #     sentence = False
        # now = datetime.datetime.now()
        output_name = 'training' + ckpt.model_checkpoint_path.split(
            '/')[-2][-1] + '-' + ckpt.model_checkpoint_path.split(
                '/')[-1] + '_qa_test_' + datetime.datetime.now().strftime(
                    "%Y-%B-%d-%I-%M") + '.txt'
        output_file = open(
            '/home/honghaier1688/workspaces/Seq2Seq_Chatbot_QA-master/test_auto_result/'
            + output_name, "w")
        output_file.write('restore model file %s \n' %
                          ckpt.model_checkpoint_path)
        with open(
                '/home/honghaier1688/workspaces/Seq2Seq_Chatbot_QA-master/test_auto_result/general_ques.txt'
        ) as f:
            while True:
                sentence = f.readline()
                if "exit()" in sentence:
                    break
                if sentence:
                    bucket_id = min([
                        b for b in range(len(buckets))
                        if buckets[b][0] > len(sentence)
                    ])
                    data, _ = model.get_batch_data(
                        {bucket_id: TestBucket(sentence)}, bucket_id)
                    encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
                        {bucket_id: TestBucket(sentence)}, bucket_id, data)
                    _, _, output_logits = model.step(sess, encoder_inputs,
                                                     decoder_inputs,
                                                     decoder_weights,
                                                     bucket_id, True)
                    outputs = [
                        int(np.argmax(logit, axis=1))
                        for logit in output_logits
                    ]
                    ret = data_utils.indice_sentence(outputs)
                    segments = "\nme > " + sentence + "AI > " + ret + '\n'
                    output_file.write(segments)
                # print("AI >", ret)
                # print("me > ", end="")
                # sys.stdout.flush()
                # sentence = sys.stdin.readline()
                else:
                    break
        output_file.close()
Code Example #18
def test_bleu(count):
    """测试bleu"""
    from nltk.translate.bleu_score import sentence_bleu
    from tqdm import tqdm
    # prepare the data
    print('preparing data')
    bucket_dbs = data_utils.read_bucket_dbs(FLAGS.buckets_dir)
    bucket_sizes = []
    for i in range(len(buckets)):
        bucket_size = bucket_dbs[i].size
        bucket_sizes.append(bucket_size)
        print('bucket {} contains {} records'.format(i, bucket_size))
    total_size = sum(bucket_sizes)
    print('{} records in total'.format(total_size))
    # if count <= 0, sample over all records by default
    if count <= 0:
        count = total_size
    buckets_scale = [
        sum(bucket_sizes[:i + 1]) / total_size
        for i in range(len(bucket_sizes))
    ]
    with tf.Session() as sess:
        # build the model
        model = create_model(sess, True)
        model.batch_size = 1
        # initialize variables
        sess.run(tf.global_variables_initializer())
        model.saver.restore(sess, os.path.join(FLAGS.model_dir, FLAGS.model_name))

        total_score = 0.0
        for i in tqdm(range(count)):
            # pick a bucket to sample from
            random_number = np.random.random_sample()
            bucket_id = min([
                i for i in range(len(buckets_scale))
                if buckets_scale[i] > random_number
            ])
            data, _ = model.get_batch_data(
                bucket_dbs,
                bucket_id
            )
            encoder_inputs, decoder_inputs, decoder_weights = model.get_batch(
                bucket_dbs,
                bucket_id,
                data
            )
            _, _, output_logits = model.step(
                sess,
                encoder_inputs,
                decoder_inputs,
                decoder_weights,
                bucket_id,
                True
            )
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            ask, _ = data[0]
            all_answers = bucket_dbs[bucket_id].all_answers(ask)
            ret = data_utils.indice_sentence(outputs)
            if not ret:
                continue
            references = [list(x) for x in all_answers]
            score = sentence_bleu(
                references,
                list(ret),
                weights=(1.0,)
            )
            total_score += score
        print('BLEU: {:.2f} in {} samples'.format(total_score / count * 10, count))