示例#1
0
def load_model(model_prefix, word_vocab, batch_size):
    """Rebuild a HanModelGraph for inference and restore its best checkpoint.

    Args:
        model_prefix: Path prefix of the saved model. Reads
            ``<model_prefix>.config.json``, ``<model_prefix>.label_vocab`` and
            the checkpoint ``<model_prefix>.best.model``.
        word_vocab: Word vocabulary handed to ``HanModelGraph``.
        batch_size: Batch size the inference graph is built with.

    Returns:
        ``(valid_graph, sess, label_vocab, FLAGS)``. The caller owns ``sess``
        and must close it when done.

    Raises:
        Whatever the initializer run or ``saver.restore`` raises (e.g. for a
        missing or incompatible checkpoint); the session is closed before the
        error propagates.
    """
    FLAGS = load_namespace(model_prefix + ".config.json")
    label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2')
    num_classes = label_vocab.size()
    best_path = model_prefix + ".best.model"
    with tf.Graph().as_default():
        initializer = tf.contrib.layers.xavier_initializer()
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = HanModelGraph(num_classes=num_classes, word_vocab=word_vocab,
                                        dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                                        lambda_l2=FLAGS.lambda_l2,
                                        context_lstm_dim=FLAGS.context_lstm_dim,
                                        is_training=False, batch_size=batch_size)
        print("ValidGraph Build")
        # Restore every variable under the "Model" scope except the word
        # embeddings, which keep the value the global initializer assigns below.
        vars_ = {
            var.name.split(":")[0]: var
            for var in tf.global_variables()
            if "word_embedding" not in var.name and var.name.startswith("Model")
        }
        saver = tf.train.Saver(vars_)

        # 0 lets TensorFlow choose the thread-pool sizes; soft placement falls
        # back to another device when an op has no kernel on the requested one.
        config = tf.ConfigProto(intra_op_parallelism_threads=0,
                                inter_op_parallelism_threads=0,
                                allow_soft_placement=True)

        sess = tf.Session(config=config)
        try:
            sess.run(tf.global_variables_initializer())
            saver.restore(sess, best_path)
        except Exception:
            # Don't leak the session when the checkpoint cannot be restored.
            sess.close()
            raise
        return valid_graph, sess, label_vocab, FLAGS
示例#2
0
    # Backfill options that may be absent from a loaded namespace (presumably
    # configs saved before these options existed) with their defaults.
    # NOTE(review): dict.has_key exists only in Python 2; under Python 3 these
    # checks raise AttributeError. `"key" not in options.__dict__` works in both.
    if not options.__dict__.has_key("with_target_lattice"):
        options.__dict__["with_target_lattice"] = False

    if not options.__dict__.has_key("add_first_word_prob_for_phrase"):
        options.__dict__["add_first_word_prob_for_phrase"] = False

    if not options.__dict__.has_key("pretrain_with_max_matching"):
        options.__dict__["pretrain_with_max_matching"] = False

    if not options.__dict__.has_key("reward_type"):
        options.__dict__["reward_type"] = "bleu"

    # The same namespace object is mutated in place and also returned.
    return options


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, help='Configuration file.')

    # Indexing os.environ raised KeyError when CUDA_VISIBLE_DEVICES was unset
    # (e.g. on CPU-only machines), aborting before any argument handling.
    print("CUDA_VISIBLE_DEVICES " +
          os.environ.get('CUDA_VISIBLE_DEVICES', '<unset>'))
    FLAGS, unparsed = parser.parse_known_args()

    # A config file, when given, replaces the command-line namespace entirely.
    if FLAGS.config_path is not None:
        print('Loading the configuration from ' + FLAGS.config_path)
        FLAGS = namespace_utils.load_namespace(FLAGS.config_path)

    # enrich_options presumably backfills defaults for missing options.
    FLAGS = enrich_options(FLAGS)

    sys.stdout.flush()
    # Arguments argparse did not recognise are forwarded to tf.app.run.
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
示例#3
0
from __future__ import print_function
import argparse
import re
import os
import sys
import json
import time
import numpy as np
import codecs

import MHQA_data_stream
import namespace_utils

def longest_candidate_index(candidates):
    """Return the index of the longest element of *candidates*.

    Ties go to the lowest index (``np.argmax`` semantics). An empty sequence
    yields -1, which can never match a valid reference index.
    """
    if len(candidates) == 0:
        return -1
    # np.argmax needs a real sequence. Given a generator, it wraps it in a 0-d
    # object array and always returns 0, so the old code compared every
    # reference against index 0 instead of the longest candidate.
    return int(np.argmax([len(x) for x in candidates]))


if __name__ == '__main__':
    # load the configuration file
    FLAGS = namespace_utils.load_namespace("/u/nalln478/ws/exp.multihop_qa/sub.MHQA/config.json")

    in_path = "/u/nalln478/ws/exp.multihop_qa/sub.MHQA/data/dev.json"
    print('Loading test set from {}.'.format(in_path))
    testset, _ = MHQA_data_stream.read_data_file(in_path, FLAGS)
    print('Number of samples: {}'.format(len(testset)))

    # Longest-candidate baseline: count how often the longest candidate is the
    # reference answer; cands_total accumulates the candidate-set sizes.
    right = 0.0
    total = 0.0
    cands_total = 0.0
    for i, (question, passage, entity_start, entity_end, edges, candidates, ref,
            ids, candidates_str) in enumerate(testset):
        if longest_candidate_index(candidates) == ref:
            right += 1.0
        total += 1.0
        cands_total += len(candidates)
示例#4
0
                        default="prediction",
                        help='prediction or probs')

    args, unparsed = parser.parse_known_args()

    # Unpack the parsed command-line arguments into locals.
    model_prefix = args.model_prefix
    in_path = args.in_path
    out_path = args.out_path
    word_vec_path = args.word_vec_path
    mode = args.mode
    out_json_path = None
    dump_prob_path = None

    # load the configuration file
    print('Loading configurations.')
    FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
    print(FLAGS)

    # Optional feature / ablation flags: each defaults to False when the saved
    # config does not define it.
    with_POS = False
    if hasattr(FLAGS, 'with_POS'): with_POS = FLAGS.with_POS
    with_NER = False
    if hasattr(FLAGS, 'with_NER'): with_NER = FLAGS.with_NER
    wo_char = False
    if hasattr(FLAGS, 'wo_char'): wo_char = FLAGS.wo_char

    wo_left_match = False
    if hasattr(FLAGS, 'wo_left_match'): wo_left_match = FLAGS.wo_left_match

    wo_right_match = False
    if hasattr(FLAGS, 'wo_right_match'): wo_right_match = FLAGS.wo_right_match
from SentenceMatchModelGraph import SentenceMatchModelGraph
from SentenceMatchDataStream import SentenceMatchDataStream


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_prefix', type=str, required=True,
                        help='Prefix to the models.')
    parser.add_argument('--in_path', type=str,
                        default='../data_quora/dev.tsv',
                        help='the path to the test file.')
    parser.add_argument('--out_path', type=str, required=True,
                        help='The path to the output file.')
    parser.add_argument('--word_vec_path', type=str,
                        default='../data_quora/wordvec.txt',
                        help='word embedding file for the input file.')
    args, unparsed = parser.parse_known_args()

    # Options saved next to the trained model.
    print('Loading configurations.')
    options = namespace_utils.load_namespace(args.model_prefix + ".config.json")

    # Fall back to the embedding file recorded in the model config.
    # NOTE(review): --word_vec_path has a non-None default, so this branch can
    # only run if that default is removed.
    if args.word_vec_path is None:
        args.word_vec_path = options.word_vec_path

    print('Loading vocabs.')
    word_vocab = Vocab(args.word_vec_path, fileformat='txt3')
    label_vocab = Vocab(args.model_prefix + ".label_vocab", fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    # The character vocabulary is loaded only when the config enables with_char.
    char_vocab = (Vocab(args.model_prefix + ".char_vocab", fileformat='txt2')
                  if options.with_char else None)
    if char_vocab is not None:
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
示例#6
0
文件: main.py 项目: wwt17/textsum3
import tensorflow as tf
import os
from vocab_utils import Vocab as EntVocab
import namespace_utils
from SentenceMatchModelGraph import SentenceMatchModelGraph

if __name__ == "__main__":
    config_path = "./configs/snli.sample.config"
    config_FLAGS = namespace_utils.load_namespace(config_path)
    # Override the input format from the config: this script reads TSV.
    config_FLAGS.__dict__["in_format"] = 'tsv'
    word_vec_path = config_FLAGS.word_vec_path
    log_dir = config_FLAGS.model_dir
    path_prefix = os.path.join(log_dir,
                               "SentenceMatch.{}".format(config_FLAGS.suffix))

    ent_word_vocab = EntVocab(word_vec_path, fileformat='txt3')
    print("word_vocab shape is {}".format(ent_word_vocab.word_vecs.shape))

    best_path = path_prefix + ".best.model"
    label_path = path_prefix + ".label_vocab"
    print("best_path: {}".format(best_path))

    # The checkpoint is detected through its ".index" file. Fail fast with an
    # IOError that names the missing path instead of a bare Exception; IOError
    # still subclasses Exception, so existing handlers keep catching it.
    checkpoint_index = best_path + ".index"
    if not os.path.exists(checkpoint_index):
        raise IOError("no pretrained model: {} not found".format(checkpoint_index))

    print("Loading label vocab")
    label_vocab = EntVocab(label_path, fileformat='txt2')

    num_classes = label_vocab.size()
    print("Number of labels: {}".format(num_classes))
        'The type of LSTM out, mean pooled of all steps or select the last step.'
    )
    parser.add_argument('--pos_weight',
                        type=float,
                        default=0.0,
                        help='Pos weight of weighted cross entropy losses')
    parser.add_argument('--seed',
                        type=int,
                        default=2018,
                        help='Initial seed for algorithm.')

    parser.add_argument('--config_path', type=str, help='Configuration file.')

    # print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])
    args, unparsed = parser.parse_known_args()
    # A config file, when given, replaces the parsed command-line arguments
    # entirely; otherwise the parsed arguments themselves are the options.
    if args.config_path is not None:
        print('Loading the configuration from ' + args.config_path)
        FLAGS = namespace_utils.load_namespace(args.config_path)
    else:
        FLAGS = args
    sys.stdout.flush()

    # Backfill options that older configs lack (backward compatibility).
    FLAGS = enrich_options(FLAGS)

    # Regular run, or the cross-validation entry point when with_cv is set.
    if not FLAGS.with_cv:
        main(FLAGS)
    else:
        main_cv(FLAGS)

    # Example invocation:
    # python SentenceMatchTrainer.py --config_path ../configs/snli.test.config
示例#8
0
    # Dataset tag: 'tre' when --is_trec is True (bool or the string 'True'),
    # otherwise 'wik'.
    op = 'wik'
    if args.is_trec == True or args.is_trec == 'True':
        op = 'tre'
    # NOTE(review): there is no separator between '../models' and op, so this
    # yields '../modelswik' / '../modelstre' -- confirm that is intended.
    log_dir = '../models' + op
    path_prefix = log_dir + "/SentenceMatch.normal"

    #model_prefix = args.model_prefix
    in_path = args.in_path
    word_vec_path = args.word_vec_path
    out_json_path = None
    dump_prob_path = None

    # load the configuration file; the path is <path_prefix><index>.config.json
    # (assumes args.index is a string -- TODO confirm).
    print('Loading configurations.')
    FLAGS = namespace_utils.load_namespace(path_prefix + args.index +
                                           ".config.json")
    print(FLAGS)

    # Optional feature / ablation flags: each defaults to False when the saved
    # config does not define it.
    with_POS = False
    if hasattr(FLAGS, 'with_POS'): with_POS = FLAGS.with_POS
    with_NER = False
    if hasattr(FLAGS, 'with_NER'): with_NER = FLAGS.with_NER
    wo_char = False
    if hasattr(FLAGS, 'wo_char'): wo_char = FLAGS.wo_char

    wo_left_match = False
    if hasattr(FLAGS, 'wo_left_match'): wo_left_match = FLAGS.wo_left_match

    wo_right_match = False
    if hasattr(FLAGS, 'wo_right_match'): wo_right_match = FLAGS.wo_right_match
def get_test_result(in_p, root_path):
    """Evaluate the pretrained SNLI SentenceMatch model on a test file.

    Args:
        in_p: Path of the test file fed to ``SentenceMatchDataStream``.
        root_path: Project root. Model files are read from
            ``<root_path>/stsapp/src/logs/SentenceMatch.snli.*`` and word vectors
            from ``<root_path>/stsapp/src/data/snli/wordvec.txt``. Predictions
            are written to ``<root_path>/stsapp/src/result.txt``.

    Returns:
        ``(acc, result)`` exactly as returned by ``train.evaluation``.
    """
    print('Loading configurations.')
    model_prefix = root_path + "/stsapp/src/logs/SentenceMatch.snli"
    word_vec_path = root_path + "/stsapp/src/data/snli/wordvec.txt"
    in_path = in_p
    out_path = root_path + "/stsapp/src/result.txt"

    print("access decoder")
    options = namespace_utils.load_namespace(model_prefix + ".config.json")

    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    # Always bind char_vocab. Previously, models trained without character
    # features crashed with UnboundLocalError when building the data stream.
    char_vocab = None
    if options.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

    print('Build SentenceMatchDataStream ... ')
    testDataStream = SentenceMatchDataStream(in_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                             label_vocab=label_vocab,
                                             isShuffle=False, isLoop=True, isSort=True, options=options)
    print('Number of instances in devDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in devDataStream: {}'.format(testDataStream.get_num_batch()))
    sys.stdout.flush()

    best_path = model_prefix + ".best.model"
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        # Called for its side effect of adding a global-step variable to the graph.
        tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                                                  is_training=False, options=options)

        init_op = tf.global_variables_initializer()
        # Restore every variable under "Model" except the word embeddings,
        # which keep the value assigned by init_op.
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        # The context manager closes the session on return or error; it used
        # to leak on every call.
        with tf.Session() as sess:
            sess.run(init_op)
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")
            acc, result = train.evaluation(sess, valid_graph, testDataStream, outpath=out_path,
                                           label_vocab=label_vocab)

        print("Accuracy for test set is : ", colored(acc, 'green'), "\n")

        return acc, result
示例#10
0
        # The outgoing-neighbour mask must have the same shape as the neighbour
        # index and edge arrays.
        # NOTE(review): assert is stripped under `python -O`; raise explicitly
        # if this check must always run.
        assert self.out_neigh_mask.shape == self.out_neigh_indices.shape
        assert self.out_neigh_mask.shape == self.out_neigh_edges.shape

        # [batch_size, sent_len_max]
        # pad_2d_vals presumably pads each row up to the given max length --
        # confirm in padding_utils. src uses options.max_src_len; sent_inp and
        # sent_out use options.max_answer_len.
        self.src = padding_utils.pad_2d_vals(ori_batch.src, len(ori_batch.src),
                                             self.options.max_src_len)
        self.sent_inp = padding_utils.pad_2d_vals(ori_batch.sent_inp,
                                                  len(ori_batch.sent_inp),
                                                  self.options.max_answer_len)
        self.sent_out = padding_utils.pad_2d_vals(ori_batch.sent_out,
                                                  len(ori_batch.sent_out),
                                                  self.options.max_answer_len)


if __name__ == "__main__":
    # Load the config, build the edge-label and encoder/decoder word
    # vocabularies, read a dataset and wrap it in a G2SDataStream.
    FLAGS = namespace_utils.load_namespace('../config.json')
    print('Collecting vocab')
    # Edge-label vocabulary: the first whitespace-separated token of each line.
    # NOTE(review): open mode 'rU' was removed in Python 3.11 (ValueError), the
    # file handle is never explicitly closed, and a blank line would raise
    # IndexError on split()[0].
    allEdgelabels = set([line.strip().split()[0] \
            for line in open('../data/edgelabel_vocab.en', 'rU')])
    edgelabel_vocab = Vocab(voc=allEdgelabels,
                            dim=FLAGS.edgelabel_dim,
                            fileformat='build')
    word_vocab_enc = Vocab('../data/vectors.en.st', fileformat='txt2')
    word_vocab_dec = Vocab('../data/vectors.de.st', fileformat='txt2')
    print('Loading trainset')
    trainset, _, _, _, _ = read_amr_file('../data/newstest2013.tok.json',
                                         FLAGS, word_vocab_enc, word_vocab_dec,
                                         None, edgelabel_vocab)
    print('Build DataStream ... ')
    trainDataStream = G2SDataStream(trainset,
                                    word_vocab_enc,
示例#11
0
def _hyp_to_sentence(hyp, cur_batch, word_vocab, FLAGS):
    """Render beam hypothesis *hyp* for the first instance of *cur_batch*.

    The batch's first id-to-phrase table is passed to ``idx_seq_to_string``
    only when ``FLAGS.with_phrase_projection`` is set; otherwise None is passed.
    """
    cur_passage = cur_batch.instances[0][0]
    cur_id2phrase = None
    if FLAGS.with_phrase_projection:
        _, cur_id2phrase = cur_batch.phrase_vocabs[0]
    return hyp.idx_seq_to_string(cur_passage, cur_id2phrase, word_vocab,
                                 FLAGS)


def question_gen_run(argv):
    """Decode a test set with a trained model and write the generated text.

    Args:
        argv: ``[model_prefix, in_path, out_path, mode]``. ``mode`` is one of
            ``'pointwise'``, ``'greedy'``, ``'multinomial'``,
            ``'greedy_evaluate'`` or ``'beam_evaluate'``. Any other value runs
            beam search and writes every hypothesis with its average log-prob.

    Modes ending in ``'evaluate'`` write references to ``<out_path>.ref`` and
    predictions to ``<out_path>.pred``; all other modes write to ``out_path``.
    Modes containing ``'beam'`` use a batch size of 1.

    Note: written for Python 2 -- it uses ``xrange`` and concatenates
    ``.encode('utf-8')`` results with ``str`` literals.
    """
    print(sys.argv)
    model_prefix = argv[0]
    in_path = argv[1]
    out_path = argv[2]
    mode = argv[3]

    # Indexing os.environ raised KeyError when the variable was unset.
    print("CUDA_VISIBLE_DEVICES " +
          os.environ.get('CUDA_VISIBLE_DEVICES', '<unset>'))

    # load the configuration file
    print('Loading configurations from ' + model_prefix + ".config.json")
    FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
    FLAGS = NP2P_trainer.enrich_options(FLAGS)

    # load vocabs; each one is optional and enabled by the config
    print('Loading vocabs.')
    word_vocab = char_vocab = POS_vocab = NER_vocab = None
    if FLAGS.with_word:
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
    if FLAGS.with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
        print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
    if FLAGS.with_NER:
        NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2')
        print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))

    print('Loading test set.')
    if FLAGS.infile_format == 'fof':
        testset, _ = NP2P_data_stream.read_generation_datasets_from_fof(
            in_path, isLower=FLAGS.isLower)
    elif FLAGS.infile_format == 'plain':
        testset, _ = NP2P_data_stream.read_all_GenerationDatasets(
            in_path, isLower=FLAGS.isLower)
    else:
        testset, _ = NP2P_data_stream.read_all_GQA_questions(
            in_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa)
    print('Number of samples: {}'.format(len(testset)))

    print('Build DataStream ... ')
    # Beam-search branches only read instances[0] of each batch, so they
    # decode one instance per batch.
    batch_size = -1
    if mode.find('beam') >= 0: batch_size = 1
    devDataStream = NP2P_data_stream.QADataStream(testset,
                                                  word_vocab,
                                                  char_vocab,
                                                  POS_vocab,
                                                  NER_vocab,
                                                  options=FLAGS,
                                                  isShuffle=False,
                                                  isLoop=False,
                                                  isSort=True,
                                                  batch_size=batch_size)
    print('Number of instances in testDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(
        devDataStream.get_num_batch()))

    best_path = model_prefix + ".best.model"
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-0.01, 0.01)
        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=False,
                                   initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab,
                                         char_vocab=char_vocab,
                                         POS_vocab=POS_vocab,
                                         NER_vocab=NER_vocab,
                                         options=FLAGS,
                                         mode="decode")

        # Restore everything under "Model" except the word embeddings, which
        # keep the value assigned by the global initializer.
        # (tf.global_variables replaces the deprecated tf.all_variables alias.)
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        initializer = tf.global_variables_initializer()
        sess = tf.Session()
        # Output files and the session are released in `finally`. Previously
        # they leaked on any exception, and the session was never closed.
        open_files = []
        try:
            sess.run(initializer)
            saver.restore(sess, best_path)  # restore the model

            total = 0
            correct = 0
            if mode.endswith('evaluate'):
                ref_outfile = open(out_path + ".ref", 'wt')
                open_files.append(ref_outfile)
                pred_outfile = open(out_path + ".pred", 'wt')
                open_files.append(pred_outfile)
            else:
                outfile = open(out_path, 'wt')
                open_files.append(outfile)
            total_num = devDataStream.get_num_batch()
            devDataStream.reset()
            for i in range(total_num):
                cur_batch = devDataStream.get_batch(i)
                if mode == 'pointwise':
                    (sentences, prediction_lengths, generator_input_idx,
                     generator_output_idx) = search(sess, valid_graph,
                                                    word_vocab, cur_batch,
                                                    FLAGS, decode_mode=mode)
                    for j in xrange(cur_batch.batch_size):
                        # Token-level accuracy against the reference answer.
                        cur_total = cur_batch.answer_lengths[j]
                        cur_correct = 0
                        for k in xrange(cur_total):
                            if generator_output_idx[
                                    j, k] == cur_batch.in_answer_words[j, k]:
                                cur_correct += 1.0
                        total += cur_total
                        correct += cur_correct
                        outfile.write(
                            cur_batch.instances[j][1].tokText.encode('utf-8') +
                            "\n")
                        outfile.write(sentences[j].encode('utf-8') + "\n")
                        outfile.write("========\n")
                    outfile.flush()
                    # Guard: batches whose answers are all empty used to
                    # divide by zero here.
                    accuracy = correct / float(total) * 100 if total else 0.0
                    print('Current dev accuracy is %d/%d=%.2f' %
                          (correct, total, accuracy))
                elif mode in ['greedy', 'multinomial']:
                    print('Batch {}'.format(i))
                    (sentences, prediction_lengths, generator_input_idx,
                     generator_output_idx) = search(sess, valid_graph,
                                                    word_vocab, cur_batch,
                                                    FLAGS, decode_mode=mode)
                    for j in xrange(cur_batch.batch_size):
                        outfile.write(
                            cur_batch.instances[j][1].ID_num.encode('utf-8') +
                            "\n")
                        outfile.write(
                            cur_batch.instances[j][1].tokText.encode('utf-8') +
                            "\n")
                        outfile.write(sentences[j].encode('utf-8') + "\n")
                        outfile.write("========\n")
                    outfile.flush()
                elif mode == 'greedy_evaluate':
                    print('Batch {}'.format(i))
                    (sentences, prediction_lengths, generator_input_idx,
                     generator_output_idx) = search(sess, valid_graph,
                                                    word_vocab, cur_batch,
                                                    FLAGS,
                                                    decode_mode="greedy")
                    for j in xrange(cur_batch.batch_size):
                        ref_outfile.write(
                            cur_batch.instances[j][1].tokText.encode('utf-8') +
                            "\n")
                        pred_outfile.write(sentences[j].encode('utf-8') + "\n")
                    ref_outfile.flush()
                    pred_outfile.flush()
                elif mode == 'beam_evaluate':
                    print('Instance {}'.format(i))
                    ref_outfile.write(
                        cur_batch.instances[0][1].tokText.encode('utf-8') +
                        "\n")
                    ref_outfile.flush()
                    hyps = run_beam_search(sess, valid_graph, word_vocab,
                                           cur_batch, FLAGS)
                    # Only the top-ranked hypothesis is kept for evaluation.
                    cur_sent = _hyp_to_sentence(hyps[0], cur_batch,
                                                word_vocab, FLAGS)
                    pred_outfile.write(cur_sent.encode('utf-8') + "\n")
                    pred_outfile.flush()
                else:  # beam search
                    print('Instance {}'.format(i))
                    hyps = run_beam_search(sess, valid_graph, word_vocab,
                                           cur_batch, FLAGS)
                    outfile.write(
                        "Input: " +
                        cur_batch.instances[0][0].tokText.encode('utf-8') +
                        "\n")
                    outfile.write(
                        "Truth: " +
                        cur_batch.instances[0][1].tokText.encode('utf-8') +
                        "\n")
                    for j in xrange(len(hyps)):
                        hyp = hyps[j]
                        cur_sent = _hyp_to_sentence(hyp, cur_batch,
                                                    word_vocab, FLAGS)
                        outfile.write("Hyp-{}: ".format(j) +
                                      cur_sent.encode('utf-8') +
                                      " {}".format(hyp.avg_log_prob()) + "\n")
                    outfile.flush()
        finally:
            for f in open_files:
                f.close()
            sess.close()
示例#12
0
文件: Decoder.py 项目: yyboston/QA
    def decode(self,
               model_prefix,
               in_path,
               out_path,
               word_vec_path,
               mode,
               out_json_path=None,
               dump_prob_path=None):
        """Restore a trained SentenceMatch model and evaluate it on a test file.

        Args:
            model_prefix: Path prefix of the saved model. Reads
                ``<model_prefix>.config.json``, ``.label_vocab`` and
                ``.char_vocab``, plus ``.POS_vocab`` / ``.NER_vocab`` when the
                config enables them, and the ``.best.model`` checkpoint.
            in_path: Test file passed to ``SentenceMatchDataStream``.
            out_path: Output path forwarded to ``SentenceMatchTrainer.evaluate``.
            word_vec_path: Word-embedding file loaded as the word vocabulary.
            mode: Forwarded unchanged to ``SentenceMatchTrainer.evaluate``.
            out_json_path: NOTE(review): not referenced before the evaluate
                call -- confirm whether still needed.
            dump_prob_path: NOTE(review): not referenced before the evaluate
                call -- confirm whether still needed.
        """
        #     model_prefix = args.model_prefix
        #     in_path = args.in_path
        #     out_path = args.out_path
        #     word_vec_path = args.word_vec_path
        #     mode = args.mode
        #     out_json_path = None
        #     dump_prob_path = None

        # load the configuration file
        print('Loading configurations.')
        FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
        print(FLAGS)

        # Feature / ablation flags default to False when the saved config does
        # not define them.
        # NOTE(review): only with_POS, with_NER and wo_char are used below. The
        # graph constructor reads FLAGS.wo_* directly, so a config missing any
        # wo_* attribute still raises AttributeError there despite these
        # defaults.
        with_POS = False
        if hasattr(FLAGS, 'with_POS'): with_POS = FLAGS.with_POS
        with_NER = False
        if hasattr(FLAGS, 'with_NER'): with_NER = FLAGS.with_NER
        wo_char = False
        if hasattr(FLAGS, 'wo_char'): wo_char = FLAGS.wo_char

        wo_left_match = False
        if hasattr(FLAGS, 'wo_left_match'): wo_left_match = FLAGS.wo_left_match

        wo_right_match = False
        if hasattr(FLAGS, 'wo_right_match'):
            wo_right_match = FLAGS.wo_right_match

        wo_full_match = False
        if hasattr(FLAGS, 'wo_full_match'): wo_full_match = FLAGS.wo_full_match

        wo_maxpool_match = False
        if hasattr(FLAGS, 'wo_maxpool_match'):
            wo_maxpool_match = FLAGS.wo_maxpool_match

        wo_attentive_match = False
        if hasattr(FLAGS, 'wo_attentive_match'):
            wo_attentive_match = FLAGS.wo_attentive_match

        wo_max_attentive_match = False
        if hasattr(FLAGS, 'wo_max_attentive_match'):
            wo_max_attentive_match = FLAGS.wo_max_attentive_match

        # load vocabs
        print('Loading vocabs.')
        word_vocab = Vocab(word_vec_path, fileformat='txt3')
        label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
        num_classes = label_vocab.size()

        POS_vocab = None
        NER_vocab = None
        char_vocab = None
        if with_POS:
            POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
        if with_NER:
            NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2')
        # NOTE(review): the char vocab is loaded unconditionally, so
        # <model_prefix>.char_vocab must exist even for wo_char models. It is
        # only discarded (set to None) after the data stream is built.
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

        print('Build SentenceMatchDataStream ... ')
        testDataStream = SentenceMatchTrainer.SentenceMatchDataStream(
            in_path,
            word_vocab=word_vocab,
            char_vocab=char_vocab,
            POS_vocab=POS_vocab,
            NER_vocab=NER_vocab,
            label_vocab=label_vocab,
            batch_size=FLAGS.batch_size,
            isShuffle=False,
            isLoop=True,
            isSort=True,
            max_char_per_word=FLAGS.max_char_per_word,
            max_sent_length=FLAGS.max_sent_length)
        print('Number of instances in testDataStream: {}'.format(
            testDataStream.get_num_instance()))
        print('Number of batches in testDataStream: {}'.format(
            testDataStream.get_num_batch()))

        # The graph and evaluator get no char vocab when characters are ablated.
        if wo_char: char_vocab = None

        init_scale = 0.01
        best_path = model_prefix + ".best.model"
        print('Decoding on the test set:')
        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(
                -init_scale, init_scale)
            with tf.variable_scope("Model",
                                   reuse=False,
                                   initializer=initializer):
                valid_graph = SentenceMatchModelGraph(
                    num_classes,
                    word_vocab=word_vocab,
                    char_vocab=char_vocab,
                    POS_vocab=POS_vocab,
                    NER_vocab=NER_vocab,
                    dropout_rate=FLAGS.dropout_rate,
                    learning_rate=FLAGS.learning_rate,
                    optimize_type=FLAGS.optimize_type,
                    lambda_l2=FLAGS.lambda_l2,
                    char_lstm_dim=FLAGS.char_lstm_dim,
                    context_lstm_dim=FLAGS.context_lstm_dim,
                    aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                    is_training=False,
                    MP_dim=FLAGS.MP_dim,
                    context_layer_num=FLAGS.context_layer_num,
                    aggregation_layer_num=FLAGS.aggregation_layer_num,
                    fix_word_vec=FLAGS.fix_word_vec,
                    with_filter_layer=FLAGS.with_filter_layer,
                    with_highway=FLAGS.with_highway,
                    word_level_MP_dim=FLAGS.word_level_MP_dim,
                    with_match_highway=FLAGS.with_match_highway,
                    with_aggregation_highway=FLAGS.with_aggregation_highway,
                    highway_layer_num=FLAGS.highway_layer_num,
                    with_lex_decomposition=FLAGS.with_lex_decomposition,
                    lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                    with_char=(not FLAGS.wo_char),
                    with_left_match=(not FLAGS.wo_left_match),
                    with_right_match=(not FLAGS.wo_right_match),
                    with_full_match=(not FLAGS.wo_full_match),
                    with_maxpool_match=(not FLAGS.wo_maxpool_match),
                    with_attentive_match=(not FLAGS.wo_attentive_match),
                    with_max_attentive_match=(
                        not FLAGS.wo_max_attentive_match))

            # Restore every variable under "Model" except the word embeddings,
            # which keep the value set by the global initializer below.
            # NOTE(review): tf.all_variables is a deprecated alias of
            # tf.global_variables.
            vars_ = {}
            for var in tf.all_variables():
                if "word_embedding" in var.name: continue
                if not var.name.startswith("Model"): continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            sess = tf.Session()
            sess.run(tf.global_variables_initializer())
            step = 0  # NOTE(review): not used before the evaluate call.
            # Collapse any '//' introduced by concatenating the model prefix.
            best_path = best_path.replace('//', '/')
            saver.restore(sess, best_path)

            accuracy = SentenceMatchTrainer.evaluate(testDataStream,
                                                     valid_graph,
                                                     sess,
                                                     outpath=out_path,
                                                     label_vocab=label_vocab,
                                                     mode=mode,
                                                     char_vocab=char_vocab,
                                                     POS_vocab=POS_vocab,
                                                     NER_vocab=NER_vocab)