def load_model(model_prefix, word_vocab, batch_size):
    """Rebuild the HAN inference graph and restore its best checkpoint.

    Reads ``<model_prefix>.config.json`` and ``<model_prefix>.label_vocab``,
    builds a ``HanModelGraph`` with ``is_training=False`` inside a fresh
    TensorFlow graph under the ``"Model"`` variable scope, then restores
    every variable whose name starts with ``"Model"`` (except the word
    embedding) from ``<model_prefix>.best.model``.

    Args:
        model_prefix: path prefix shared by the saved model artifacts.
        word_vocab: vocabulary handed to ``HanModelGraph``.
        batch_size: batch size the graph is built with.

    Returns:
        ``(valid_graph, sess, label_vocab, FLAGS)`` -- the built graph, the
        open session holding the restored weights (never closed here; the
        caller owns it), the label vocabulary and the loaded configuration.
    """
    config_flags = load_namespace(model_prefix + ".config.json")
    label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2')
    n_labels = label_vocab.size()
    checkpoint_path = model_prefix + ".best.model"

    with tf.Graph().as_default():
        xavier_init = tf.contrib.layers.xavier_initializer()
        with tf.variable_scope("Model", reuse=False, initializer=xavier_init):
            valid_graph = HanModelGraph(
                num_classes=n_labels,
                word_vocab=word_vocab,
                dropout_rate=config_flags.dropout_rate,
                learning_rate=config_flags.learning_rate,
                lambda_l2=config_flags.lambda_l2,
                context_lstm_dim=config_flags.context_lstm_dim,
                is_training=False,
                batch_size=batch_size,
            )
        print("ValidGraph Build")

        # Map checkpoint names (tensor name minus ":<index>") to variables,
        # leaving the word embedding out of the restore.
        restorable = {
            variable.name.split(":")[0]: variable
            for variable in tf.global_variables()
            if "word_embedding" not in variable.name
            and variable.name.startswith("Model")
        }
        saver = tf.train.Saver(restorable)

        session_config = tf.ConfigProto(
            intra_op_parallelism_threads=0,
            inter_op_parallelism_threads=0,
            allow_soft_placement=True,
        )
        sess = tf.Session(config=session_config)
        # Initialise everything first so variables skipped by the saver
        # still hold a value, then overwrite the rest from the checkpoint.
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, checkpoint_path)
    return valid_graph, sess, label_vocab, config_flags
if not options.__dict__.has_key("with_target_lattice"): options.__dict__["with_target_lattice"] = False if not options.__dict__.has_key("add_first_word_prob_for_phrase"): options.__dict__["add_first_word_prob_for_phrase"] = False if not options.__dict__.has_key("pretrain_with_max_matching"): options.__dict__["pretrain_with_max_matching"] = False if not options.__dict__.has_key("reward_type"): options.__dict__["reward_type"] = "bleu" return options if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--config_path', type=str, help='Configuration file.') print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES']) FLAGS, unparsed = parser.parse_known_args() if FLAGS.config_path is not None: print('Loading the configuration from ' + FLAGS.config_path) FLAGS = namespace_utils.load_namespace(FLAGS.config_path) FLAGS = enrich_options(FLAGS) sys.stdout.flush() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
from __future__ import print_function import argparse import re import os import sys import json import time import numpy as np import codecs import MHQA_data_stream import namespace_utils if __name__ == '__main__': # load the configuration file FLAGS = namespace_utils.load_namespace("/u/nalln478/ws/exp.multihop_qa/sub.MHQA/config.json") in_path = "/u/nalln478/ws/exp.multihop_qa/sub.MHQA/data/dev.json" print('Loading test set from {}.'.format(in_path)) testset, _ = MHQA_data_stream.read_data_file(in_path, FLAGS) print('Number of samples: {}'.format(len(testset))) right = 0.0 total = 0.0 cands_total = 0.0 for i, (question, passage, entity_start, entity_end, edges, candidates, ref, ids, candidates_str) in enumerate(testset): if np.argmax(len(x) for x in candidates) == ref: right += 1.0 total += 1.0 cands_total += len(candidates)
default="prediction", help='prediction or probs') args, unparsed = parser.parse_known_args() model_prefix = args.model_prefix in_path = args.in_path out_path = args.out_path word_vec_path = args.word_vec_path mode = args.mode out_json_path = None dump_prob_path = None # load the configuration file print('Loading configurations.') FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json") print(FLAGS) with_POS = False if hasattr(FLAGS, 'with_POS'): with_POS = FLAGS.with_POS with_NER = False if hasattr(FLAGS, 'with_NER'): with_NER = FLAGS.with_NER wo_char = False if hasattr(FLAGS, 'wo_char'): wo_char = FLAGS.wo_char wo_left_match = False if hasattr(FLAGS, 'wo_left_match'): wo_left_match = FLAGS.wo_left_match wo_right_match = False if hasattr(FLAGS, 'wo_right_match'): wo_right_match = FLAGS.wo_right_match
from SentenceMatchModelGraph import SentenceMatchModelGraph
from SentenceMatchDataStream import SentenceMatchDataStream

if __name__ == '__main__':
    # Command-line entry point: parse arguments, load the saved model
    # configuration, then load the word/label (and optional char) vocabs.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_prefix', type=str, required=True, help='Prefix to the models.')
    parser.add_argument('--in_path', type=str, default='../data_quora/dev.tsv', help='the path to the test file.')
    parser.add_argument('--out_path', type=str, required=True, help='The path to the output file.')
    parser.add_argument('--word_vec_path', type=str, default='../data_quora/wordvec.txt', help='word embedding file for the input file.')
    # Unrecognised arguments are collected in `unparsed` instead of erroring.
    args, unparsed = parser.parse_known_args()

    # load the configuration file
    print('Loading configurations.')
    options = namespace_utils.load_namespace(args.model_prefix + ".config.json")

    # NOTE(review): --word_vec_path has a non-None default, so this fallback
    # to the path stored in the model config cannot trigger from the command
    # line; confirm whether the default was meant to be None.
    if args.word_vec_path is None:
        args.word_vec_path = options.word_vec_path

    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(args.word_vec_path, fileformat='txt3')
    label_vocab = Vocab(args.model_prefix + ".label_vocab", fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    # The char vocab exists only for models trained with character features.
    char_vocab = None
    if options.with_char:
        char_vocab = Vocab(args.model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
import tensorflow as tf
import os
from vocab_utils import Vocab as EntVocab
import namespace_utils
from SentenceMatchModelGraph import SentenceMatchModelGraph

if __name__ == "__main__":
    # Load the sample SNLI configuration and locate the trained model.
    config_path = "./configs/snli.sample.config"
    config_FLAGS = namespace_utils.load_namespace(config_path)
    # Force tab-separated input regardless of what the config file says.
    config_FLAGS.__dict__["in_format"] = 'tsv'
    word_vec_path = config_FLAGS.word_vec_path
    log_dir = config_FLAGS.model_dir
    # Model artifacts share the prefix <model_dir>/SentenceMatch.<suffix>
    path_prefix = os.path.join(log_dir, "SentenceMatch.{}".format(config_FLAGS.suffix))

    ent_word_vocab = EntVocab(word_vec_path, fileformat='txt3')
    print("word_vocab shape is {}".format(ent_word_vocab.word_vecs.shape))

    best_path = path_prefix + ".best.model"
    label_path = path_prefix + ".label_vocab"
    print("best_path: {}".format(best_path))
    # The existence of "<best_path>.index" is taken as the signal that a
    # trained checkpoint is present; without it there is nothing to load.
    if os.path.exists(best_path + ".index"):
        print("Loading label vocab")
        label_vocab = EntVocab(label_path, fileformat='txt2')
    else:
        raise Exception("no pretrained model")
    num_classes = label_vocab.size()
    print("Number of labels: {}".format(num_classes))
'The type of LSTM out, mean pooled of all steps or select the last step.' ) parser.add_argument('--pos_weight', type=float, default=0.0, help='Pos weight of weighted cross entropy losses') parser.add_argument('--seed', type=int, default=2018, help='Initial seed for algorithm.') parser.add_argument('--config_path', type=str, help='Configuration file.') # print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES']) args, unparsed = parser.parse_known_args() if args.config_path is not None: print('Loading the configuration from ' + args.config_path) FLAGS = namespace_utils.load_namespace(args.config_path) else: FLAGS = args sys.stdout.flush() # enrich arguments to backwards compatibility FLAGS = enrich_options(FLAGS) if not FLAGS.with_cv: main(FLAGS) else: main_cv(FLAGS) # python SentenceMatchTrainer.py --config_path ../configs/snli.test.config
op = 'wik' if args.is_trec == True or args.is_trec == 'True': op = 'tre' log_dir = '../models' + op path_prefix = log_dir + "/SentenceMatch.normal" #model_prefix = args.model_prefix in_path = args.in_path word_vec_path = args.word_vec_path out_json_path = None dump_prob_path = None # load the configuration file print('Loading configurations.') FLAGS = namespace_utils.load_namespace(path_prefix + args.index + ".config.json") print(FLAGS) with_POS = False if hasattr(FLAGS, 'with_POS'): with_POS = FLAGS.with_POS with_NER = False if hasattr(FLAGS, 'with_NER'): with_NER = FLAGS.with_NER wo_char = False if hasattr(FLAGS, 'wo_char'): wo_char = FLAGS.wo_char wo_left_match = False if hasattr(FLAGS, 'wo_left_match'): wo_left_match = FLAGS.wo_left_match wo_right_match = False if hasattr(FLAGS, 'wo_right_match'): wo_right_match = FLAGS.wo_right_match
def get_test_result(in_p, root_path):
    """Evaluate the saved SNLI sentence-matching model on a test file.

    Model artifacts are read from
    ``<root_path>/stsapp/src/logs/SentenceMatch.snli.*`` and word vectors
    from ``<root_path>/stsapp/src/data/snli/wordvec.txt``.
    ``train.evaluation`` is given ``<root_path>/stsapp/src/result.txt`` as
    its output path.

    Args:
        in_p: path of the test file to evaluate.
        root_path: project root that contains the ``stsapp`` directory.

    Returns:
        ``(acc, result)`` exactly as returned by ``train.evaluation``.
    """
    print('Loading configurations.')
    model_prefix = root_path + "/stsapp/src/logs/SentenceMatch.snli"
    word_vec_path = root_path + "/stsapp/src/data/snli/wordvec.txt"
    in_path = in_p
    out_path = root_path + "/stsapp/src/result.txt"
    print("access decoder")
    options = namespace_utils.load_namespace(model_prefix + ".config.json")

    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    # Bind char_vocab on every path: it used to be assigned only when
    # options.with_char was true, so the data stream below raised NameError
    # for models trained without character features.
    char_vocab = None
    if options.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

    print('Build SentenceMatchDataStream ... ')
    testDataStream = SentenceMatchDataStream(in_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                             label_vocab=label_vocab, isShuffle=False, isLoop=True,
                                             isSort=True, options=options)
    print('Number of instances in devDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in devDataStream: {}'.format(testDataStream.get_num_batch()))
    sys.stdout.flush()

    best_path = model_prefix + ".best.model"
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        # NOTE(review): the returned step tensor is unused here; the call is
        # kept in case the model graph looks the global step up -- confirm.
        tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                                                  is_training=False, options=options)
        init_op = tf.global_variables_initializer()

        # Restore only variables under "Model", skipping the word embedding.
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        # The session is not returned to the caller, so close it on exit;
        # otherwise every call leaks a session and its resources.
        with tf.Session() as sess:
            sess.run(init_op)
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")
            acc, result = train.evaluation(sess, valid_graph, testDataStream, outpath=out_path,
                                           label_vocab=label_vocab)
    print("Accuracy for test set is : ", colored(acc, 'green'), "\n")
    return acc, result
assert self.out_neigh_mask.shape == self.out_neigh_indices.shape assert self.out_neigh_mask.shape == self.out_neigh_edges.shape # [batch_size, sent_len_max] self.src = padding_utils.pad_2d_vals(ori_batch.src, len(ori_batch.src), self.options.max_src_len) self.sent_inp = padding_utils.pad_2d_vals(ori_batch.sent_inp, len(ori_batch.sent_inp), self.options.max_answer_len) self.sent_out = padding_utils.pad_2d_vals(ori_batch.sent_out, len(ori_batch.sent_out), self.options.max_answer_len) if __name__ == "__main__": FLAGS = namespace_utils.load_namespace('../config.json') print('Collecting vocab') allEdgelabels = set([line.strip().split()[0] \ for line in open('../data/edgelabel_vocab.en', 'rU')]) edgelabel_vocab = Vocab(voc=allEdgelabels, dim=FLAGS.edgelabel_dim, fileformat='build') word_vocab_enc = Vocab('../data/vectors.en.st', fileformat='txt2') word_vocab_dec = Vocab('../data/vectors.de.st', fileformat='txt2') print('Loading trainset') trainset, _, _, _, _ = read_amr_file('../data/newstest2013.tok.json', FLAGS, word_vocab_enc, word_vocab_dec, None, edgelabel_vocab) print('Build DataStream ... ') trainDataStream = G2SDataStream(trainset, word_vocab_enc,
def question_gen_run(argv):
    """Decode a test set with a trained NP2P generation model.

    Args:
        argv: positional arguments ``[model_prefix, in_path, out_path, mode]``.
            ``mode`` selects the decoding routine: ``'pointwise'``,
            ``'greedy'``, ``'multinomial'``, ``'greedy_evaluate'`` or
            ``'beam_evaluate'``; any other value falls through to beam search
            with every hypothesis written out.

    Modes ending in ``'evaluate'`` write ``<out_path>.ref`` and
    ``<out_path>.pred``; all other modes write ``out_path``.
    """
    model_prefix = argv[0]
    in_path = argv[1]
    out_path = argv[2]
    mode = argv[3]
    # Log the arguments actually used (previously printed sys.argv, which is
    # misleading when this function is called programmatically).
    print(argv)
    # .get(): a missing variable must not abort decoding with a KeyError.
    print("CUDA_VISIBLE_DEVICES " + os.environ.get('CUDA_VISIBLE_DEVICES', '<unset>'))

    # load the configuration file
    print('Loading configurations from ' + model_prefix + ".config.json")
    FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
    FLAGS = NP2P_trainer.enrich_options(FLAGS)

    # load vocabs
    print('Loading vocabs.')
    word_vocab = char_vocab = POS_vocab = NER_vocab = None
    if FLAGS.with_word:
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
    if FLAGS.with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
        print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
    if FLAGS.with_NER:
        NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2')
        print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))

    print('Loading test set.')
    if FLAGS.infile_format == 'fof':
        testset, _ = NP2P_data_stream.read_generation_datasets_from_fof(in_path, isLower=FLAGS.isLower)
    elif FLAGS.infile_format == 'plain':
        testset, _ = NP2P_data_stream.read_all_GenerationDatasets(in_path, isLower=FLAGS.isLower)
    else:
        testset, _ = NP2P_data_stream.read_all_GQA_questions(in_path, isLower=FLAGS.isLower,
                                                              switch=FLAGS.switch_qa)
    print('Number of samples: {}'.format(len(testset)))

    print('Build DataStream ... ')
    # Beam-search modes decode one instance at a time; -1 leaves the batch
    # size to the data stream.
    batch_size = 1 if 'beam' in mode else -1
    devDataStream = NP2P_data_stream.QADataStream(testset, word_vocab, char_vocab, POS_vocab, NER_vocab,
                                                  options=FLAGS, isShuffle=False, isLoop=False,
                                                  isSort=True, batch_size=batch_size)
    print('Number of instances in testDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(devDataStream.get_num_batch()))

    best_path = model_prefix + ".best.model"
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-0.01, 0.01)
        with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=False, initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab,
                                         POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                                         options=FLAGS, mode="decode")

        # Restore everything under "Model" except the word embedding.
        # tf.global_variables replaces the deprecated tf.all_variables.
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)
        init_op = tf.global_variables_initializer()

        # Close the session and every output file even if decoding raises.
        with tf.Session() as sess:
            sess.run(init_op)
            saver.restore(sess, best_path)  # restore the model

            total = 0
            correct = 0
            outfile = ref_outfile = pred_outfile = None
            try:
                if mode.endswith('evaluate'):
                    ref_outfile = open(out_path + ".ref", 'wt')
                    pred_outfile = open(out_path + ".pred", 'wt')
                else:
                    outfile = open(out_path, 'wt')

                total_num = devDataStream.get_num_batch()
                devDataStream.reset()
                for i in xrange(total_num):
                    cur_batch = devDataStream.get_batch(i)
                    if mode == 'pointwise':
                        sentences, _, _, generator_output_idx = search(
                            sess, valid_graph, word_vocab, cur_batch, FLAGS, decode_mode=mode)
                        for j in xrange(cur_batch.batch_size):
                            # Token-level accuracy against the reference answer.
                            cur_total = cur_batch.answer_lengths[j]
                            cur_correct = 0
                            for k in xrange(cur_total):
                                if generator_output_idx[j, k] == cur_batch.in_answer_words[j, k]:
                                    cur_correct += 1.0
                            total += cur_total
                            correct += cur_correct
                            outfile.write(cur_batch.instances[j][1].tokText.encode('utf-8') + "\n")
                            outfile.write(sentences[j].encode('utf-8') + "\n")
                            outfile.write("========\n")
                        outfile.flush()
                        # Avoid ZeroDivisionError while no answer tokens have been seen.
                        accuracy = correct / float(total) * 100 if total else 0.0
                        print('Current dev accuracy is %d/%d=%.2f' % (correct, total, accuracy))
                    elif mode in ('greedy', 'multinomial'):
                        print('Batch {}'.format(i))
                        sentences = search(sess, valid_graph, word_vocab, cur_batch, FLAGS,
                                           decode_mode=mode)[0]
                        for j in xrange(cur_batch.batch_size):
                            outfile.write(cur_batch.instances[j][1].ID_num.encode('utf-8') + "\n")
                            outfile.write(cur_batch.instances[j][1].tokText.encode('utf-8') + "\n")
                            outfile.write(sentences[j].encode('utf-8') + "\n")
                            outfile.write("========\n")
                        outfile.flush()
                    elif mode == 'greedy_evaluate':
                        print('Batch {}'.format(i))
                        sentences = search(sess, valid_graph, word_vocab, cur_batch, FLAGS,
                                           decode_mode="greedy")[0]
                        for j in xrange(cur_batch.batch_size):
                            ref_outfile.write(cur_batch.instances[j][1].tokText.encode('utf-8') + "\n")
                            pred_outfile.write(sentences[j].encode('utf-8') + "\n")
                        ref_outfile.flush()
                        pred_outfile.flush()
                    else:
                        # Both remaining branches use beam search on a single
                        # instance (batch_size == 1 for beam modes).
                        print('Instance {}'.format(i))
                        cur_passage = cur_batch.instances[0][0]
                        cur_id2phrase = None
                        if FLAGS.with_phrase_projection:
                            _, cur_id2phrase = cur_batch.phrase_vocabs[0]
                        if mode == 'beam_evaluate':
                            ref_outfile.write(cur_batch.instances[0][1].tokText.encode('utf-8') + "\n")
                            ref_outfile.flush()
                            hyps = run_beam_search(sess, valid_graph, word_vocab, cur_batch, FLAGS)
                            cur_sent = hyps[0].idx_seq_to_string(cur_passage, cur_id2phrase, word_vocab, FLAGS)
                            pred_outfile.write(cur_sent.encode('utf-8') + "\n")
                            pred_outfile.flush()
                        else:
                            hyps = run_beam_search(sess, valid_graph, word_vocab, cur_batch, FLAGS)
                            outfile.write("Input: " + cur_batch.instances[0][0].tokText.encode('utf-8') + "\n")
                            outfile.write("Truth: " + cur_batch.instances[0][1].tokText.encode('utf-8') + "\n")
                            for j, hyp in enumerate(hyps):
                                cur_sent = hyp.idx_seq_to_string(cur_passage, cur_id2phrase, word_vocab, FLAGS)
                                outfile.write("Hyp-{}: ".format(j) + cur_sent.encode('utf-8')
                                              + " {}".format(hyp.avg_log_prob()) + "\n")
                            outfile.flush()
            finally:
                for handle in (outfile, ref_outfile, pred_outfile):
                    if handle is not None:
                        handle.close()
def decode(self, model_prefix, in_path, out_path, word_vec_path, mode, out_json_path=None, dump_prob_path=None):
    """Decode ``in_path`` with a trained SentenceMatch model and evaluate it.

    Args:
        model_prefix: path prefix of the saved artifacts (``.config.json``,
            ``.label_vocab``, ``.char_vocab``, ``.POS_vocab``, ``.NER_vocab``,
            ``.best.model``).
        in_path: file to decode.
        out_path: forwarded as ``outpath`` to ``SentenceMatchTrainer.evaluate``.
        word_vec_path: word-embedding file used to build the word vocab.
        mode: forwarded to ``SentenceMatchTrainer.evaluate``.
        out_json_path: accepted for interface compatibility; unused here.
        dump_prob_path: accepted for interface compatibility; unused here.
    """
    # load the configuration file
    print('Loading configurations.')
    FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
    print(FLAGS)

    # Optional switches default to False when a config predates them.
    with_POS = getattr(FLAGS, 'with_POS', False)
    with_NER = getattr(FLAGS, 'with_NER', False)
    wo_char = getattr(FLAGS, 'wo_char', False)
    wo_left_match = getattr(FLAGS, 'wo_left_match', False)
    wo_right_match = getattr(FLAGS, 'wo_right_match', False)
    wo_full_match = getattr(FLAGS, 'wo_full_match', False)
    wo_maxpool_match = getattr(FLAGS, 'wo_maxpool_match', False)
    wo_attentive_match = getattr(FLAGS, 'wo_attentive_match', False)
    wo_max_attentive_match = getattr(FLAGS, 'wo_max_attentive_match', False)

    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    POS_vocab = None
    NER_vocab = None
    if with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
    if with_NER:
        NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2')
    char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
    print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

    print('Build SentenceMatchDataStream ... ')
    testDataStream = SentenceMatchTrainer.SentenceMatchDataStream(
        in_path, word_vocab=word_vocab, char_vocab=char_vocab,
        POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab,
        batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True,
        isSort=True, max_char_per_word=FLAGS.max_char_per_word,
        max_sent_length=FLAGS.max_sent_length)
    print('Number of instances in testDataStream: {}'.format(
        testDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(
        testDataStream.get_num_batch()))

    # The data stream above still receives the char vocab; only the graph
    # (and evaluation) drop it when character features are disabled.
    if wo_char:
        char_vocab = None

    init_scale = 0.01
    best_path = model_prefix + ".best.model"
    print('Decoding on the test set:')
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            # The with_* switches use the getattr-defaulted locals above;
            # reading FLAGS.wo_* directly raised AttributeError for configs
            # without those keys, defeating the defaults.
            valid_graph = SentenceMatchModelGraph(
                num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                dropout_rate=FLAGS.dropout_rate,
                learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type,
                lambda_l2=FLAGS.lambda_l2,
                char_lstm_dim=FLAGS.char_lstm_dim,
                context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                is_training=False, MP_dim=FLAGS.MP_dim,
                context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec,
                with_filter_layer=FLAGS.with_filter_layer,
                with_highway=FLAGS.with_highway,
                word_level_MP_dim=FLAGS.word_level_MP_dim,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                with_lex_decomposition=FLAGS.with_lex_decomposition,
                lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                with_char=(not wo_char),
                with_left_match=(not wo_left_match),
                with_right_match=(not wo_right_match),
                with_full_match=(not wo_full_match),
                with_maxpool_match=(not wo_maxpool_match),
                with_attentive_match=(not wo_attentive_match),
                with_max_attentive_match=(not wo_max_attentive_match))

        # Restore everything under "Model" except the word embedding.
        # tf.global_variables replaces the deprecated tf.all_variables.
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        # Collapse accidental double slashes from path concatenation.
        best_path = best_path.replace('//', '/')
        saver.restore(sess, best_path)

        accuracy = SentenceMatchTrainer.evaluate(testDataStream, valid_graph, sess, outpath=out_path,
                                                 label_vocab=label_vocab, mode=mode, char_vocab=char_vocab,
                                                 POS_vocab=POS_vocab, NER_vocab=NER_vocab)