def main(FLAGS):
    # np.random.seed(FLAGS.seed)
    train_path = FLAGS.train_path
    train_feat_path = FLAGS.train_feat_path
    dev_path = FLAGS.dev_path
    dev_feat_path = FLAGS.dev_feat_path
    word_vec_path = FLAGS.word_vec_path
    word_vec_path2 = FLAGS.word_vec_path2
    char_vec_path = FLAGS.char_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    # word_vocab2 = Vocab(word_vec_path2, fileformat='txt3')
    # word_vocab = Vocab(word_vec_path, word_vec_path2, fileformat='txt4')
    char_vocab = None

    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    has_pre_trained_model = False
    if os.path.exists(best_path + ".index"):
        has_pre_trained_model = True
        print('Loading vocabs from a pre-trained model ...')
        label_vocab = Vocab(label_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(char_path, fileformat='txt2')
    else:
        print('Collecting words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path)
        print('Number of words: {}'.format(len(all_words)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
        label_vocab.dump_to_txt2(label_path)
        if FLAGS.with_char:
            print('Number of chars: {}'.format(len(all_chars)))
            if char_vec_path == "":
                char_vocab = Vocab(fileformat='voc', voc=all_chars, dim=FLAGS.char_emb_dim)
            else:
                char_vocab = Vocab(char_vec_path, fileformat='txt3')
            char_vocab.dump_to_txt2(char_path)

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    if FLAGS.with_char:
        print('char_vocab shape is {}'.format(char_vocab.word_vecs.shape))
    num_classes = label_vocab.size()
    print("Number of labels: {}".format(num_classes))
    sys.stdout.flush()

    print('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(train_path, train_feat_path, word_vocab=word_vocab,
                                              char_vocab=char_vocab, label_vocab=label_vocab,
                                              isShuffle=True, isLoop=True, isSort=True, options=FLAGS)
    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    sys.stdout.flush()

    devDataStream = SentenceMatchDataStream(dev_path, dev_feat_path, word_vocab=word_vocab,
                                            char_vocab=char_vocab, label_vocab=label_vocab,
                                            isShuffle=False, isLoop=True, isSort=True, options=FLAGS)
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    sys.stdout.flush()

    init_scale = 0.01
    with tf.Graph().as_default():
        # tf.set_random_seed(FLAGS.seed)
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                                                  is_training=True, options=FLAGS, global_step=global_step)
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                                                  is_training=False, options=FLAGS)

        initializer = tf.global_variables_initializer()
        initializer_local = tf.local_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            # if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        # Feed pre-trained embeddings into the initializer. Guard the char
        # embedding: char_vocab stays None when FLAGS.with_char is off, and the
        # original unconditional feed would crash in that case.
        init_feed = {train_graph.w_embedding: word_vocab.word_vecs}
        if char_vocab is not None:
            init_feed[train_graph.c_embedding] = char_vocab.word_vecs
        sess.run(initializer, feed_dict=init_feed)
        # sess.run(initializer, feed_dict={train_graph.w_embedding: word_vocab.word_vecs,
        #                                  train_graph.w_embedding_trainable: word_vocab2.word_vecs,
        #                                  train_graph.c_embedding: char_vocab.word_vecs})
        sess.run(initializer_local)
        if has_pre_trained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

        # training
        train(sess, saver, train_graph, valid_graph, trainDataStream, devDataStream,
              FLAGS, best_path, label_vocab)
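
# --- Usage sketch (not from the original source) ---
# This variant of main() receives the parsed FLAGS namespace as an argument rather
# than reading a module-level global, so it needs a small driver. The block below
# is a minimal sketch under that assumption: it declares only the flags that
# main() dereferences directly (the model graph reads many more from the same
# namespace), and the defaults shown are illustrative, not the project's.
if __name__ == "__main__":
    import argparse
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--train_path', required=True, help='training data file')
    arg_parser.add_argument('--train_feat_path', default='', help='precomputed features for training data')
    arg_parser.add_argument('--dev_path', required=True, help='dev data file')
    arg_parser.add_argument('--dev_feat_path', default='', help='precomputed features for dev data')
    arg_parser.add_argument('--word_vec_path', required=True, help='word vectors in txt3 format')
    arg_parser.add_argument('--word_vec_path2', default='', help='optional second word-vector file')
    arg_parser.add_argument('--char_vec_path', default='', help='char vectors; "" trains char embeddings from scratch')
    arg_parser.add_argument('--char_emb_dim', type=int, default=20, help='char embedding dim when trained from scratch')
    arg_parser.add_argument('--model_dir', required=True, help='directory for checkpoints and configs')
    arg_parser.add_argument('--suffix', default='normal', help='name suffix for checkpoint files')
    arg_parser.add_argument('--with_char', action='store_true', help='use character-level composition')
    main(arg_parser.parse_args())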

def main(_):
    print('Configurations:')
    print(FLAGS)

    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if FLAGS.train == "sick":
        train_path = FLAGS.SICK_train_path
        dev_path = FLAGS.SICK_dev_path
    if FLAGS.test == "sick":
        test_path = FLAGS.SICK_test_path
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    parser, image_feats = None, None
    if FLAGS.with_dep:
        parser = Parser('snli')
    if FLAGS.with_image:
        image_feats = ImageFeatures()
    word_vocab = Vocab(word_vec_path, fileformat='txt3', parser=parser, beginning=FLAGS.beginning)

    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    POS_path = path_prefix + ".POS_vocab"
    NER_path = path_prefix + ".NER_vocab"
    DEP_path = path_prefix + ".DEP_vocab"
    has_pre_trained_model = False
    POS_vocab = None
    NER_vocab = None
    DEP_vocab = None
    print('has pretrained model: ', os.path.exists(best_path))
    print('best_path: ' + best_path)
    if os.path.exists(best_path + '.meta'):
        has_pre_trained_model = True
        label_vocab = Vocab(label_path, fileformat='txt2')
        char_vocab = Vocab(char_path, fileformat='txt2')
        if FLAGS.with_POS:
            POS_vocab = Vocab(POS_path, fileformat='txt2')
        if FLAGS.with_NER:
            NER_vocab = Vocab(NER_path, fileformat='txt2')
    else:
        print('Collect words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(
            train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER)
        print('Number of words: {}'.format(len(all_words)))
        print('Number of labels: {}'.format(len(all_labels)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
        label_vocab.dump_to_txt2(label_path)
        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc', voc=all_chars, dim=FLAGS.char_emb_dim, beginning=FLAGS.beginning)
        char_vocab.dump_to_txt2(char_path)
        if FLAGS.with_POS:
            print('Number of POSs: {}'.format(len(all_POSs)))
            POS_vocab = Vocab(fileformat='voc', voc=all_POSs, dim=FLAGS.POS_dim)
            POS_vocab.dump_to_txt2(POS_path)
        if FLAGS.with_NER:
            print('Number of NERs: {}'.format(len(all_NERs)))
            NER_vocab = Vocab(fileformat='voc', voc=all_NERs, dim=FLAGS.NER_dim)
            NER_vocab.dump_to_txt2(NER_path)

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build DataStream ... ')
    print('Reading trainDataStream')
    if not FLAGS.decoding_only:
        trainDataStream = DataStream(
            train_path, word_vocab=word_vocab, char_vocab=char_vocab,
            POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab,
            batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True,
            max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length,
            with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats,
            sick_data=(FLAGS.test == "sick"))
        print('Reading devDataStream')
        devDataStream = DataStream(
            dev_path, word_vocab=word_vocab, char_vocab=char_vocab,
            POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab,
            batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True,
            max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length,
            with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats,
            sick_data=(FLAGS.test == "sick"))
    # The test stream is built unconditionally: decoding at the end of this
    # function needs it even when FLAGS.decoding_only is set.
    print('Reading testDataStream')
    testDataStream = DataStream(
        test_path, word_vocab=word_vocab, char_vocab=char_vocab,
        POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab,
        batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True,
        max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length,
        with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats,
        sick_data=(FLAGS.test == "sick"))

    print('save cache file')
    # word_vocab.parser.save_cache()
    # image_feats.save_feat()

    if not FLAGS.decoding_only:
        print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
        print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
        print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
        print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    sys.stdout.flush()

    if FLAGS.wo_char:
        char_vocab = None

    best_accuracy = 0.0
    init_scale = 0.01
    if not FLAGS.decoding_only:
        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(-init_scale, init_scale)
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = ModelGraph(
                    num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                    POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                    dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                    optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2,
                    char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
                    aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True,
                    MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num,
                    aggregation_layer_num=FLAGS.aggregation_layer_num,
                    fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer,
                    with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim,
                    with_match_highway=FLAGS.with_match_highway,
                    with_aggregation_highway=FLAGS.with_aggregation_highway,
                    highway_layer_num=FLAGS.highway_layer_num,
                    with_lex_decomposition=FLAGS.with_lex_decomposition,
                    lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                    with_left_match=(not FLAGS.wo_left_match),
                    with_right_match=(not FLAGS.wo_right_match),
                    with_full_match=(not FLAGS.wo_full_match),
                    with_maxpool_match=(not FLAGS.wo_maxpool_match),
                    with_attentive_match=(not FLAGS.wo_attentive_match),
                    with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                    with_dep=FLAGS.with_dep, with_image=FLAGS.with_image,
                    image_with_hypothesis_only=FLAGS.image_with_hypothesis_only,
                    with_img_full_match=FLAGS.with_img_full_match,
                    # NOTE: the original passes FLAGS.with_img_full_match here as
                    # well, which looks like a copy-paste slip (probably
                    # FLAGS.with_img_maxpool_match was intended).
                    with_img_maxpool_match=FLAGS.with_img_full_match,
                    with_img_attentive_match=FLAGS.with_img_attentive_match,
                    image_context_layer=FLAGS.image_context_layer,
                    with_img_max_attentive_match=FLAGS.with_img_max_attentive_match,
                    img_dim=FLAGS.img_dim)
                tf.summary.scalar("Training Loss", train_graph.get_loss())  # Add a scalar summary for the snapshot loss.
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = ModelGraph(
                    num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                    POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                    dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                    optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2,
                    char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
                    aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False,
                    MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num,
                    aggregation_layer_num=FLAGS.aggregation_layer_num,
                    fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer,
                    with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim,
                    with_match_highway=FLAGS.with_match_highway,
                    with_aggregation_highway=FLAGS.with_aggregation_highway,
                    highway_layer_num=FLAGS.highway_layer_num,
                    with_lex_decomposition=FLAGS.with_lex_decomposition,
                    lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                    with_left_match=(not FLAGS.wo_left_match),
                    with_right_match=(not FLAGS.wo_right_match),
                    with_full_match=(not FLAGS.wo_full_match),
                    with_maxpool_match=(not FLAGS.wo_maxpool_match),
                    with_attentive_match=(not FLAGS.wo_attentive_match),
                    with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                    with_dep=FLAGS.with_dep, with_image=FLAGS.with_image,
                    image_with_hypothesis_only=FLAGS.image_with_hypothesis_only,
                    with_img_attentive_match=FLAGS.with_img_attentive_match,
                    with_img_full_match=FLAGS.with_img_full_match,
                    with_img_maxpool_match=FLAGS.with_img_full_match,  # same apparent copy-paste slip as above
                    image_context_layer=FLAGS.image_context_layer,
                    with_img_max_attentive_match=FLAGS.with_img_max_attentive_match,
                    img_dim=FLAGS.img_dim)

            initializer = tf.global_variables_initializer()
            vars_ = {}
            for var in tf.global_variables():  # tf.all_variables() in the original; a deprecated alias
                if "word_embedding" in var.name:
                    continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            sess = tf.Session()
            sess.run(initializer)  # , feed_dict={valid_graph.emb_init: word_vocab.word_vecs, train_graph.emb_init: word_vocab.word_vecs})
            if has_pre_trained_model:
                print("Restoring model from " + best_path)
                saver.restore(sess, best_path)
                print("DONE!")
            # if best_path.startswith('bimpm_baseline'):
            #     best_path = best_path + '_sick'

            print('Start the training loop.')
            train_size = trainDataStream.get_num_batch()
            max_steps = train_size * FLAGS.max_epochs
            total_loss = 0.0
            start_time = time.time()
            for step in xrange(max_steps):
                # read data
                cur_batch = trainDataStream.nextBatch()
                (label_batch, sent1_batch, sent2_batch, label_id_batch,
                 word_idx_1_batch, word_idx_2_batch,
                 char_matrix_idx_1_batch, char_matrix_idx_2_batch,
                 sent1_length_batch, sent2_length_batch,
                 sent1_char_length_batch, sent2_char_length_batch,
                 POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch,
                 dependency1_batch, dependency2_batch, dep_con1_batch, dep_con2_batch,
                 img_feats_batch, img_id_batch) = cur_batch
                feed_dict = {
                    train_graph.get_truth(): label_id_batch,
                    train_graph.get_question_lengths(): sent1_length_batch,
                    train_graph.get_passage_lengths(): sent2_length_batch,
                    train_graph.get_in_question_words(): word_idx_1_batch,
                    train_graph.get_in_passage_words(): word_idx_2_batch,
                    # train_graph.get_emb_init(): word_vocab.word_vecs,
                    # train_graph.get_in_question_dependency(): dependency1_batch,
                    # train_graph.get_in_passage_dependency(): dependency2_batch,
                    # train_graph.get_question_char_lengths(): sent1_char_length_batch,
                    # train_graph.get_passage_char_lengths(): sent2_char_length_batch,
                    # train_graph.get_in_question_chars(): char_matrix_idx_1_batch,
                    # train_graph.get_in_passage_chars(): char_matrix_idx_2_batch,
                }
                if FLAGS.with_dep:
                    feed_dict[train_graph.get_in_question_dependency()] = dependency1_batch
                    feed_dict[train_graph.get_in_passage_dependency()] = dependency2_batch
                    feed_dict[train_graph.get_in_question_dep_con()] = dep_con1_batch
                    feed_dict[train_graph.get_in_passage_dep_con()] = dep_con2_batch
                if FLAGS.with_image:
                    feed_dict[train_graph.get_image_feats()] = img_feats_batch
                if char_vocab is not None:
                    feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch
                    feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch
                    feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch
                    feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch
                if POS_vocab is not None:
                    feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch
                    feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch
                if NER_vocab is not None:
                    feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch
                    feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch

                _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()],
                                         feed_dict=feed_dict)
                total_loss += loss_value

                if step % 100 == 0:
                    print('{} '.format(step), end="")
                    sys.stdout.flush()

                # Save a checkpoint and evaluate the model periodically.
                if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                    print()
                    # Print status to stdout.
                    duration = time.time() - start_time
                    start_time = time.time()
                    print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                    total_loss = 0.0

                    # Evaluate against the validation set.
                    print('Validation Data Eval:')
                    accuracy = evaluate(devDataStream, valid_graph, sess, char_vocab=char_vocab,
                                        POS_vocab=POS_vocab, NER_vocab=NER_vocab, word_vocab=word_vocab)
                    print("Current accuracy on dev is %.2f" % accuracy)
                    # accuracy_train = evaluate(trainDataStream, valid_graph, sess, char_vocab=char_vocab,
                    #                           POS_vocab=POS_vocab, NER_vocab=NER_vocab)
                    # print("Current accuracy on train is %.2f" % accuracy_train)
                    if accuracy > best_accuracy:
                        best_accuracy = accuracy
                        saver.save(sess, best_path)

        print("Best accuracy on dev set is %.2f" % best_accuracy)

    # decoding
    print('Decoding on the test set:')
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = ModelGraph(
                num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2,
                char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False,
                MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer,
                with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                with_lex_decomposition=FLAGS.with_lex_decomposition,
                lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                with_left_match=(not FLAGS.wo_left_match),
                with_right_match=(not FLAGS.wo_right_match),
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                with_dep=FLAGS.with_dep, with_image=FLAGS.with_image,
                image_with_hypothesis_only=FLAGS.image_with_hypothesis_only,
                with_img_attentive_match=FLAGS.with_img_attentive_match,
                with_img_full_match=FLAGS.with_img_full_match,
                with_img_maxpool_match=FLAGS.with_img_full_match,  # same apparent copy-paste slip as in the training graph
                image_context_layer=FLAGS.image_context_layer,
                with_img_max_attentive_match=FLAGS.with_img_max_attentive_match,
                img_dim=FLAGS.img_dim)

        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())  # , feed_dict={valid_graph.emb_init: word_vocab.word_vecs})
        step = 0
        saver.restore(sess, best_path)

        accuracy = evaluate(testDataStream, valid_graph, sess,
                            outpath=FLAGS.suffix + FLAGS.train + FLAGS.test + ".result",
                            char_vocab=char_vocab, label_vocab=label_vocab, word_vocab=word_vocab)
        print("Accuracy for test set is %.2f" % accuracy)
        # NOTE: trainDataStream is only built when FLAGS.decoding_only is off,
        # so this final call raises NameError in decoding-only runs.
        accuracy_train = evaluate(trainDataStream, valid_graph, sess,
                                  char_vocab=char_vocab, word_vocab=word_vocab)
        print("Accuracy for train set is %.2f" % accuracy_train)
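
# Usage note (not from the original source): when this variant is launched with
# --decoding_only, the training block above is skipped and only testDataStream
# plus the saved ".best.model" checkpoint are used. A hypothetical invocation
# (script name assumed) could be:
#
#   python main_image_dep.py --decoding_only --model_dir logs --suffix snli \
#       --test_path data/snli/test.tsv --word_vec_path data/glove/wordvec.txt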

# (fragment) "parser" below is an argparse.ArgumentParser built earlier in the
# original decode script; only the tail of that script appears here.
args, unparsed = parser.parse_known_args()

# load the configuration file
print('Loading configurations.')
options = namespace_utils.load_namespace(args.model_prefix + ".config.json")

if args.word_vec_path is None:
    args.word_vec_path = options.word_vec_path

# load vocabs
print('Loading vocabs.')
word_vocab = Vocab(args.word_vec_path, fileformat='txt3')
label_vocab = Vocab(args.model_prefix + ".label_vocab", fileformat='txt2')
print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
num_classes = label_vocab.size()

char_vocab = None  # the original left this undefined when options.with_char is False, causing a NameError below
if options.with_char:
    char_vocab = Vocab(args.model_prefix + ".char_vocab", fileformat='txt2')
    print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

print('Build SentenceMatchDataStream ... ')
testDataStream = SentenceMatchDataStream(args.in_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                         label_vocab=label_vocab, isShuffle=False, isLoop=True,
                                         isSort=True, options=options)

def get_test_result(in_p, root_path):
    print('Loading configurations.')
    model_prefix = root_path + "/stsapp/src/logs/SentenceMatch.snli"
    word_vec_path = root_path + "/stsapp/src/data/snli/wordvec.txt"
    in_path = in_p
    out_path = root_path + "/stsapp/src/result.txt"
    print("access decoder")
    options = namespace_utils.load_namespace(model_prefix + ".config.json")
    if word_vec_path is None:
        word_vec_path = options.word_vec_path

    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    char_vocab = None  # the original left this undefined when options.with_char is False
    if options.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

    print('Build SentenceMatchDataStream ... ')
    testDataStream = SentenceMatchDataStream(in_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                             label_vocab=label_vocab, isShuffle=False, isLoop=True,
                                             isSort=True, options=options)
    # the original messages said "devDataStream" here, which was misleading
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    sys.stdout.flush()

    best_path = model_prefix + ".best.model"
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                                                  is_training=False, options=options)

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        print("Restoring model from " + best_path)
        saver.restore(sess, best_path)
        print("DONE!")
        acc, result = train.evaluation(sess, valid_graph, testDataStream,
                                       outpath=out_path, label_vocab=label_vocab)
        print("Accuracy for test set is : ", colored(acc, 'green'), "\n")
        # print(result['probs'])
        return acc, result
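
# Usage sketch (paths are hypothetical): get_test_result() hard-codes the model
# prefix, word vectors, and output file relative to root_path, so a caller only
# supplies the input file and the project root, e.g.
#
#   acc, result = get_test_result('/opt/sts/pairs.tsv', '/opt/sts')
#   print('accuracy:', acc)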

def main(_):
    # for x in range(100):
    #     Generate_random_initialization()
    #     print(FLAGS.is_aggregation_lstm, FLAGS.context_lstm_dim, FLAGS.context_layer_num,
    #           FLAGS.aggregation_lstm_dim, FLAGS.aggregation_layer_num,
    #           FLAGS.max_window_size, FLAGS.MP_dim)
    print('Configurations:')
    # print(FLAGS)

    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    POS_path = path_prefix + ".POS_vocab"
    NER_path = path_prefix + ".NER_vocab"
    has_pre_trained_model = False
    POS_vocab = None
    NER_vocab = None
    if os.path.exists(best_path):
        has_pre_trained_model = True
        label_vocab = Vocab(label_path, fileformat='txt2')
        char_vocab = Vocab(char_path, fileformat='txt2')
        if FLAGS.with_POS:
            POS_vocab = Vocab(POS_path, fileformat='txt2')
        if FLAGS.with_NER:
            NER_vocab = Vocab(NER_path, fileformat='txt2')
    else:
        print('Collect words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(
            train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER)
        print('Number of words: {}'.format(len(all_words)))
        print('Number of labels: {}'.format(len(all_labels)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
        label_vocab.dump_to_txt2(label_path)
        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc', voc=all_chars, dim=FLAGS.char_emb_dim)
        char_vocab.dump_to_txt2(char_path)
        if FLAGS.with_POS:
            print('Number of POSs: {}'.format(len(all_POSs)))
            POS_vocab = Vocab(fileformat='voc', voc=all_POSs, dim=FLAGS.POS_dim)
            POS_vocab.dump_to_txt2(POS_path)
        if FLAGS.with_NER:
            print('Number of NERs: {}'.format(len(all_NERs)))
            NER_vocab = Vocab(fileformat='voc', voc=all_NERs, dim=FLAGS.NER_dim)
            NER_vocab.dump_to_txt2(NER_path)

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(
        train_path, word_vocab=word_vocab, char_vocab=char_vocab,
        POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab,
        batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True,
        max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length,
        is_as=FLAGS.is_answer_selection)
    devDataStream = SentenceMatchDataStream(
        dev_path, word_vocab=word_vocab, char_vocab=char_vocab,
        POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab,
        batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True,
        max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length,
        is_as=FLAGS.is_answer_selection)
    testDataStream = SentenceMatchDataStream(
        test_path, word_vocab=word_vocab, char_vocab=char_vocab,
        POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab,
        batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True,
        max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length,
        is_as=FLAGS.is_answer_selection)

    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    sys.stdout.flush()

    if FLAGS.wo_char:
        char_vocab = None

    output_res_index = 1
    while True:
        Generate_random_initialization()
        st_cuda = ''
        if FLAGS.is_server == True:
            st_cuda = str(os.environ['CUDA_VISIBLE_DEVICES']) + '.'
        output_res_file = open('../result/' + st_cuda + str(output_res_index), 'wt')
        output_res_index += 1
        output_res_file.write(str(FLAGS) + '\n\n')
        stt = str(FLAGS)
        best_accuracy = 0.0
        init_scale = 0.01
        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(-init_scale, init_scale)
            # with tf.name_scope("Train"):
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = SentenceMatchModelGraph(
                    num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                    POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                    dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                    optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2,
                    char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
                    aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True,
                    MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num,
                    aggregation_layer_num=FLAGS.aggregation_layer_num,
                    fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer,
                    with_input_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim,
                    with_match_highway=FLAGS.with_match_highway,
                    with_aggregation_highway=FLAGS.with_aggregation_highway,
                    highway_layer_num=FLAGS.highway_layer_num,
                    with_lex_decomposition=FLAGS.with_lex_decomposition,
                    lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                    with_left_match=(not FLAGS.wo_left_match),
                    with_right_match=(not FLAGS.wo_right_match),
                    with_full_match=(not FLAGS.wo_full_match),
                    with_maxpool_match=(not FLAGS.wo_maxpool_match),
                    with_attentive_match=(not FLAGS.wo_attentive_match),
                    with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                    with_bilinear_att=FLAGS.attention_type,
                    type1=FLAGS.type1, type2=FLAGS.type2, type3=FLAGS.type3,
                    with_aggregation_attention=(not FLAGS.wo_agg_self_att),
                    is_answer_selection=FLAGS.is_answer_selection,
                    is_shared_attention=FLAGS.is_shared_attention,
                    modify_loss=FLAGS.modify_loss,
                    is_aggregation_lstm=FLAGS.is_aggregation_lstm,
                    max_window_size=FLAGS.max_window_size,
                    prediction_mode=FLAGS.prediction_mode,
                    context_lstm_dropout=(not FLAGS.wo_lstm_drop_out),
                    is_aggregation_siamese=FLAGS.is_aggregation_siamese,
                    unstack_cnn=FLAGS.unstack_cnn,
                    with_context_self_attention=FLAGS.with_context_self_attention)
                tf.summary.scalar("Training Loss", train_graph.get_loss())  # Add a scalar summary for the snapshot loss.
            # with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = SentenceMatchModelGraph(
                    num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                    POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                    dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                    optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2,
                    char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
                    aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False,
                    MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num,
                    aggregation_layer_num=FLAGS.aggregation_layer_num,
                    fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer,
                    with_input_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim,
                    with_match_highway=FLAGS.with_match_highway,
                    with_aggregation_highway=FLAGS.with_aggregation_highway,
                    highway_layer_num=FLAGS.highway_layer_num,
                    with_lex_decomposition=FLAGS.with_lex_decomposition,
                    lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                    with_left_match=(not FLAGS.wo_left_match),
                    with_right_match=(not FLAGS.wo_right_match),
                    with_full_match=(not FLAGS.wo_full_match),
                    with_maxpool_match=(not FLAGS.wo_maxpool_match),
                    with_attentive_match=(not FLAGS.wo_attentive_match),
                    with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                    with_bilinear_att=FLAGS.attention_type,
                    type1=FLAGS.type1, type2=FLAGS.type2, type3=FLAGS.type3,
                    with_aggregation_attention=(not FLAGS.wo_agg_self_att),
                    is_answer_selection=FLAGS.is_answer_selection,
                    is_shared_attention=FLAGS.is_shared_attention,
                    modify_loss=FLAGS.modify_loss,
                    is_aggregation_lstm=FLAGS.is_aggregation_lstm,
                    max_window_size=FLAGS.max_window_size,
                    prediction_mode=FLAGS.prediction_mode,
                    context_lstm_dropout=(not FLAGS.wo_lstm_drop_out),
                    is_aggregation_siamese=FLAGS.is_aggregation_siamese,
                    unstack_cnn=FLAGS.unstack_cnn,
                    with_context_self_attention=FLAGS.with_context_self_attention)

            initializer = tf.global_variables_initializer()
            vars_ = {}
            # for var in tf.all_variables():
            for var in tf.global_variables():
                if "word_embedding" in var.name:
                    continue
                # if not var.name.startswith("Model"): continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            with tf.Session() as sess:
                sess.run(initializer)
                if has_pre_trained_model:
                    print("Restoring model from " + best_path)
                    saver.restore(sess, best_path)
                    print("DONE!")

                print('Start the training loop.')
                train_size = trainDataStream.get_num_batch()
                max_steps = train_size * FLAGS.max_epochs
                total_loss = 0.0
                start_time = time.time()
                for step in xrange(max_steps):
                    # read data
                    cur_batch, batch_index = trainDataStream.nextBatch()
                    (label_batch, sent1_batch, sent2_batch, label_id_batch,
                     word_idx_1_batch, word_idx_2_batch,
                     char_matrix_idx_1_batch, char_matrix_idx_2_batch,
                     sent1_length_batch, sent2_length_batch,
                     sent1_char_length_batch, sent2_char_length_batch,
                     POS_idx_1_batch, POS_idx_2_batch,
                     NER_idx_1_batch, NER_idx_2_batch) = cur_batch
                    feed_dict = {
                        train_graph.get_truth(): label_id_batch,
                        train_graph.get_question_lengths(): sent1_length_batch,
                        train_graph.get_passage_lengths(): sent2_length_batch,
                        train_graph.get_in_question_words(): word_idx_1_batch,
                        train_graph.get_in_passage_words(): word_idx_2_batch,
                        # train_graph.get_question_char_lengths(): sent1_char_length_batch,
                        # train_graph.get_passage_char_lengths(): sent2_char_length_batch,
                        # train_graph.get_in_question_chars(): char_matrix_idx_1_batch,
                        # train_graph.get_in_passage_chars(): char_matrix_idx_2_batch,
                    }
                    if char_vocab is not None:
                        feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch
                        feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch
                        feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch
                        feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch
                    if POS_vocab is not None:
                        feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch
                        feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch
                    if NER_vocab is not None:
                        feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch
                        feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch
                    if FLAGS.is_answer_selection == True:
                        feed_dict[train_graph.get_question_count()] = trainDataStream.question_count(batch_index)
                        feed_dict[train_graph.get_answer_count()] = trainDataStream.answer_count(batch_index)

                    _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()],
                                             feed_dict=feed_dict)
                    total_loss += loss_value
                    if FLAGS.is_answer_selection == True and FLAGS.is_server == False:
                        print("q: {} a: {} loss_value: {}".format(trainDataStream.question_count(batch_index),
                                                                  trainDataStream.answer_count(batch_index),
                                                                  loss_value))

                    if step % 100 == 0:
                        print('{} '.format(step), end="")
                        sys.stdout.flush()

                    # Save a checkpoint and evaluate the model periodically.
                    if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                        # print(total_loss)
                        # Print status to stdout.
                        duration = time.time() - start_time
                        start_time = time.time()
                        output_res_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
                        total_loss = 0.0

                        # Evaluate against the validation set.
                        output_res_file.write('valid- ')
                        my_map, my_mrr = evaluate(devDataStream, valid_graph, sess, char_vocab=char_vocab,
                                                  POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab)
                        output_res_file.write("map: '{}', mrr: '{}'\n".format(my_map, my_mrr))
                        # print("dev map: {}".format(my_map))
                        # print("Current accuracy is %.2f" % accuracy)
                        # accuracy = my_map
                        # if accuracy > best_accuracy:
                        #     best_accuracy = accuracy
                        #     saver.save(sess, best_path)

                        # Evaluate against the test set.
                        output_res_file.write('test- ')
                        my_map, my_mrr = evaluate(testDataStream, valid_graph, sess, char_vocab=char_vocab,
                                                  POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab)
                        # the original format string was missing the closing quote after the mrr value
                        output_res_file.write("map: '{}', mrr: '{}'\n\n".format(my_map, my_mrr))
                        if FLAGS.is_server == False:
                            print("test map: {}".format(my_map))

                        # Evaluate against the train set only for the final epoch.
                        if (step + 1) == max_steps:
                            output_res_file.write('train- ')
                            my_map, my_mrr = evaluate(trainDataStream, valid_graph, sess, char_vocab=char_vocab,
                                                      POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab)
                            output_res_file.write("map: '{}', mrr: '{}'\n".format(my_map, my_mrr))

        # print("Best accuracy on dev set is %.2f" % best_accuracy)

        # # decoding
        # print('Decoding on the test set:')
        # init_scale = 0.01
        # with tf.Graph().as_default():
        #     initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        #     with tf.variable_scope("Model", reuse=False, initializer=initializer):
        #         valid_graph = SentenceMatchModelGraph(
        #             num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
        #             POS_vocab=POS_vocab, NER_vocab=NER_vocab,
        #             dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
        #             optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2,
        #             char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
        #             aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False,
        #             MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num,
        #             aggregation_layer_num=FLAGS.aggregation_layer_num,
        #             fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer,
        #             with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim,
        #             with_match_highway=FLAGS.with_match_highway,
        #             with_aggregation_highway=FLAGS.with_aggregation_highway,
        #             highway_layer_num=FLAGS.highway_layer_num,
        #             with_lex_decomposition=FLAGS.with_lex_decomposition,
        #             lex_decompsition_dim=FLAGS.lex_decompsition_dim,
        #             with_left_match=(not FLAGS.wo_left_match),
        #             with_right_match=(not FLAGS.wo_right_match),
        #             with_full_match=(not FLAGS.wo_full_match),
        #             with_maxpool_match=(not FLAGS.wo_maxpool_match),
        #             with_attentive_match=(not FLAGS.wo_attentive_match),
        #             with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
        #             with_bilinear_att=(not FLAGS.wo_bilinear_att),
        #             type1=FLAGS.type1, type2=FLAGS.type2, type3=FLAGS.type3,
        #             with_aggregation_attention=not FLAGS.wo_agg_self_att,
        #             is_answer_selection=FLAGS.is_answer_selection,
        #             is_shared_attention=FLAGS.is_shared_attention,
        #             modify_loss=FLAGS.modify_loss, is_aggregation_lstm=FLAGS.is_aggregation_lstm,
        #             max_window_size=FLAGS.max_window_size,
        #             prediction_mode=FLAGS.prediction_mode,
        #             context_lstm_dropout=not FLAGS.wo_lstm_drop_out,
        #             is_aggregation_siamese=FLAGS.is_aggregation_siamese)
        #
        #     vars_ = {}
        #     for var in tf.global_variables():
        #         if "word_embedding" in var.name: continue
        #         if not var.name.startswith("Model"): continue
        #         vars_[var.name.split(":")[0]] = var
        #     saver = tf.train.Saver(vars_)
        #
        #     sess = tf.Session()
        #     sess.run(tf.global_variables_initializer())
        #     step = 0
        #     saver.restore(sess, best_path)
        #
        #     accuracy, mrr = evaluate(testDataStream, valid_graph, sess, char_vocab=char_vocab,
        #                              POS_vocab=POS_vocab, NER_vocab=NER_vocab,
        #                              label_vocab=label_vocab, mode='trec')
        #     output_res_file.write("map for test set is %.2f\n" % accuracy)

        output_res_file.close()
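
# Sketch (not from the original source): Generate_random_initialization() drives
# the random hyperparameter search in the while-loop above, but its definition is
# not part of this snippet. Under the assumption that it simply resamples fields
# of the global FLAGS, a minimal illustrative version could look like the one
# below; the sampled flags and value ranges are guesses based on the flags this
# file already uses, not the project's actual search space.
import random

def Generate_random_initialization_sketch():
    FLAGS.learning_rate = 10 ** random.uniform(-4, -2)        # log-uniform learning rate
    FLAGS.dropout_rate = random.uniform(0.0, 0.5)
    FLAGS.context_lstm_dim = random.choice([50, 100, 200])
    FLAGS.aggregation_lstm_dim = random.choice([50, 100, 200])
    FLAGS.aggregation_layer_num = random.randint(1, 2)
    FLAGS.MP_dim = random.choice([10, 20, 50])
    FLAGS.max_window_size = random.randint(1, 3)
    FLAGS.is_aggregation_lstm = random.random() < 0.5         # LSTM vs. CNN aggregation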

def main(_):
    print('Configurations:')
    print(FLAGS)

    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    POS_path = path_prefix + ".POS_vocab"
    NER_path = path_prefix + ".NER_vocab"
    has_pre_trained_model = False
    POS_vocab = None
    NER_vocab = None
    if os.path.exists(best_path):
        has_pre_trained_model = True
        label_vocab = Vocab(label_path, fileformat='txt2')
        char_vocab = Vocab(char_path, fileformat='txt2')
        if FLAGS.with_POS:
            POS_vocab = Vocab(POS_path, fileformat='txt2')
        if FLAGS.with_NER:
            NER_vocab = Vocab(NER_path, fileformat='txt2')
    else:
        print('Collect words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(
            train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER)
        print('Number of words: {}'.format(len(all_words)))
        print('Number of labels: {}'.format(len(all_labels)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
        label_vocab.dump_to_txt2(label_path)
        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc', voc=all_chars, dim=FLAGS.char_emb_dim)
        char_vocab.dump_to_txt2(char_path)
        if FLAGS.with_POS:
            print('Number of POSs: {}'.format(len(all_POSs)))
            POS_vocab = Vocab(fileformat='voc', voc=all_POSs, dim=FLAGS.POS_dim)
            POS_vocab.dump_to_txt2(POS_path)
        if FLAGS.with_NER:
            print('Number of NERs: {}'.format(len(all_NERs)))
            NER_vocab = Vocab(fileformat='voc', voc=all_NERs, dim=FLAGS.NER_dim)
            NER_vocab.dump_to_txt2(NER_path)

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(
        train_path, word_vocab=word_vocab, char_vocab=char_vocab,
        POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab,
        batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True,
        max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length)
    devDataStream = SentenceMatchDataStream(
        dev_path, word_vocab=word_vocab, char_vocab=char_vocab,
        POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab,
        batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True,
        max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length)
    testDataStream = SentenceMatchDataStream(
        test_path, word_vocab=word_vocab, char_vocab=char_vocab,
        POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab,
        batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True,
        max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length)

    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    sys.stdout.flush()

    if FLAGS.wo_char:
        char_vocab = None

    best_accuracy = 0.0
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        # with tf.name_scope("Train"):
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = SentenceMatchModelGraph(
                num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2,
                char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True,
                MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer,
                with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                with_lex_decomposition=FLAGS.with_lex_decomposition,
                lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                with_left_match=(not FLAGS.wo_left_match),
                with_right_match=(not FLAGS.wo_right_match),
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match))
            tf.summary.scalar("Training Loss", train_graph.get_loss())  # Add a scalar summary for the snapshot loss.
        # with tf.name_scope("Valid"):
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(
                num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2,
                char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False,
                MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer,
                with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                with_lex_decomposition=FLAGS.with_lex_decomposition,
                lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                with_left_match=(not FLAGS.wo_left_match),
                with_right_match=(not FLAGS.wo_right_match),
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match))

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():  # tf.all_variables() in the original; a deprecated alias
            if "word_embedding" in var.name:
                continue
            # if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pre_trained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            # read data
            cur_batch = trainDataStream.nextBatch()
            (label_batch, sent1_batch, sent2_batch, label_id_batch,
             word_idx_1_batch, word_idx_2_batch,
             char_matrix_idx_1_batch, char_matrix_idx_2_batch,
             sent1_length_batch, sent2_length_batch,
             sent1_char_length_batch, sent2_char_length_batch,
             POS_idx_1_batch, POS_idx_2_batch,
             NER_idx_1_batch, NER_idx_2_batch) = cur_batch
            feed_dict = {
                train_graph.get_truth(): label_id_batch,
                train_graph.get_question_lengths(): sent1_length_batch,
                train_graph.get_passage_lengths(): sent2_length_batch,
                train_graph.get_in_question_words(): word_idx_1_batch,
                train_graph.get_in_passage_words(): word_idx_2_batch,
                # train_graph.get_question_char_lengths(): sent1_char_length_batch,
                # train_graph.get_passage_char_lengths(): sent2_char_length_batch,
                # train_graph.get_in_question_chars(): char_matrix_idx_1_batch,
                # train_graph.get_in_passage_chars(): char_matrix_idx_2_batch,
            }
            if char_vocab is not None:
                feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch
                feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch
                feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch
                feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch
            if POS_vocab is not None:
                feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch
                feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch
            if NER_vocab is not None:
                feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch
                feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch

            _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()],
                                     feed_dict=feed_dict)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                print()
                # Print status to stdout.
                duration = time.time() - start_time
                start_time = time.time()
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                total_loss = 0.0

                # Evaluate against the validation set.
                print('Validation Data Eval:')
                accuracy = evaluate(devDataStream, valid_graph, sess, char_vocab=char_vocab,
                                    POS_vocab=POS_vocab, NER_vocab=NER_vocab)
                print("Current accuracy is %.2f" % accuracy)
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    saver.save(sess, best_path)

    print("Best accuracy on dev set is %.2f" % best_accuracy)

    # decoding
    print('Decoding on the test set:')
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(
                num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2,
                char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False,
                MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer,
                with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                with_lex_decomposition=FLAGS.with_lex_decomposition,
                lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                with_left_match=(not FLAGS.wo_left_match),
                with_right_match=(not FLAGS.wo_right_match),
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match))

        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        step = 0
        saver.restore(sess, best_path)

        accuracy = evaluate(testDataStream, valid_graph, sess, char_vocab=char_vocab,
                            POS_vocab=POS_vocab, NER_vocab=NER_vocab)
        print("Accuracy for test set is %.2f" % accuracy)
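
# Refactoring sketch (not in the original): the feed_dict assembly above is
# repeated almost verbatim in every variant of this script. A small helper like
# the hypothetical build_feed_dict below could consolidate it; "graph" is any of
# the *ModelGraph objects exposing the same getter API used above, and "batch"
# is the 16-tuple produced by SentenceMatchDataStream.nextBatch().
def build_feed_dict(graph, batch, char_vocab=None, POS_vocab=None, NER_vocab=None):
    (label_batch, sent1_batch, sent2_batch, label_id_batch,
     word_idx_1_batch, word_idx_2_batch,
     char_matrix_idx_1_batch, char_matrix_idx_2_batch,
     sent1_length_batch, sent2_length_batch,
     sent1_char_length_batch, sent2_char_length_batch,
     POS_idx_1_batch, POS_idx_2_batch,
     NER_idx_1_batch, NER_idx_2_batch) = batch
    # mandatory word-level inputs
    feed_dict = {
        graph.get_truth(): label_id_batch,
        graph.get_question_lengths(): sent1_length_batch,
        graph.get_passage_lengths(): sent2_length_batch,
        graph.get_in_question_words(): word_idx_1_batch,
        graph.get_in_passage_words(): word_idx_2_batch,
    }
    # optional character-, POS-, and NER-level inputs
    if char_vocab is not None:
        feed_dict[graph.get_question_char_lengths()] = sent1_char_length_batch
        feed_dict[graph.get_passage_char_lengths()] = sent2_char_length_batch
        feed_dict[graph.get_in_question_chars()] = char_matrix_idx_1_batch
        feed_dict[graph.get_in_passage_chars()] = char_matrix_idx_2_batch
    if POS_vocab is not None:
        feed_dict[graph.get_in_question_poss()] = POS_idx_1_batch
        feed_dict[graph.get_in_passage_poss()] = POS_idx_2_batch
    if NER_vocab is not None:
        feed_dict[graph.get_in_question_ners()] = NER_idx_1_batch
        feed_dict[graph.get_in_passage_ners()] = NER_idx_2_batch
    return feed_dict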

def main(_):
    print('Configurations:')
    print(FLAGS)

    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    tolower = FLAGS.use_lower_letter
    FLAGS.rl_matches = json.loads(FLAGS.rl_matches)

    # if not os.path.exists(log_dir):
    #     os.makedirs(log_dir)

    path_prefix = log_dir + "/TriMatch.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3', tolower=tolower)
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    POS_path = path_prefix + ".POS_vocab"
    NER_path = path_prefix + ".NER_vocab"
    has_pre_trained_model = False
    POS_vocab = None
    NER_vocab = None
    print('best path:', best_path)
    if os.path.exists(best_path + '.data-00000-of-00001') and not FLAGS.create_new_model:
        print('Using pretrained model')
        has_pre_trained_model = True
        label_vocab = Vocab(label_path, fileformat='txt2', tolower=tolower)
        char_vocab = Vocab(char_path, fileformat='txt2', tolower=tolower)
        if FLAGS.with_POS:
            POS_vocab = Vocab(POS_path, fileformat='txt2', tolower=tolower)
        if FLAGS.with_NER:
            NER_vocab = Vocab(NER_path, fileformat='txt2', tolower=tolower)
    else:
        print('Creating new model')
        print('Collect words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(
            train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER, tolower=tolower)
        if FLAGS.use_options:
            all_labels = ['0', '1']
        print('Number of words: {}'.format(len(all_words)))
        print('Number of labels: {}'.format(len(all_labels)))
        # for word in all_labels:
        #     print('label', word)
        # input('check')
        label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2, tolower=tolower)
        label_vocab.dump_to_txt2(label_path)
        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc', voc=all_chars, dim=FLAGS.char_emb_dim, tolower=tolower)
        char_vocab.dump_to_txt2(char_path)
        if FLAGS.with_POS:
            print('Number of POSs: {}'.format(len(all_POSs)))
            POS_vocab = Vocab(fileformat='voc', voc=all_POSs, dim=FLAGS.POS_dim, tolower=tolower)
            POS_vocab.dump_to_txt2(POS_path)
        if FLAGS.with_NER:
            print('Number of NERs: {}'.format(len(all_NERs)))
            NER_vocab = Vocab(fileformat='voc', voc=all_NERs, dim=FLAGS.NER_dim, tolower=tolower)
            NER_vocab.dump_to_txt2(NER_path)

    print('all_labels:', label_vocab)
    print('has pretrained model:', has_pre_trained_model)
    # for word in word_vocab.word_vecs:
    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build TriMatchDataStream ... ')
') gen_concat_mat = False gen_split_mat = False if FLAGS.matching_option == 7: gen_concat_mat = True if FLAGS.concat_context: gen_split_mat = True trainDataStream = TriMatchDataStream( train_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=(not FLAGS.wo_sort_instance_based_on_length), max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, max_hyp_length=FLAGS.max_hyp_length, max_choice_length=FLAGS.max_choice_length, tolower=tolower, gen_concat_mat=gen_concat_mat, gen_split_mat=gen_split_mat) devDataStream = TriMatchDataStream( dev_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=(not FLAGS.wo_sort_instance_based_on_length), max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, max_hyp_length=FLAGS.max_hyp_length, max_choice_length=FLAGS.max_choice_length, tolower=tolower, gen_concat_mat=gen_concat_mat, gen_split_mat=gen_split_mat) testDataStream = TriMatchDataStream( test_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=(not FLAGS.wo_sort_instance_based_on_length), max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, max_hyp_length=FLAGS.max_hyp_length, max_choice_length=FLAGS.max_choice_length, tolower=tolower, gen_concat_mat=gen_concat_mat, gen_split_mat=gen_split_mat) print('Number of instances in trainDataStream: {}'.format( trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format( devDataStream.get_num_instance())) print('Number of instances in testDataStream: {}'.format( testDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format( trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format( devDataStream.get_num_batch())) print('Number of batches in testDataStream: {}'.format( testDataStream.get_num_batch())) sys.stdout.flush() if FLAGS.wo_char: char_vocab = None best_accuracy = 0.0 init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) # with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = TriMatchModelGraph( num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, match_to_question=FLAGS.match_to_question, match_to_passage=FLAGS.match_to_passage, match_to_choice=FLAGS.match_to_choice, with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not 
FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), use_options=FLAGS.use_options, num_options=num_options, with_no_match=FLAGS.with_no_match, verbose=FLAGS.verbose, matching_option=FLAGS.matching_option, concat_context=FLAGS.concat_context, tied_aggre=FLAGS.tied_aggre, rl_training_method=FLAGS.rl_training_method, rl_matches=FLAGS.rl_matches) tf.summary.scalar("Training Loss", train_graph.get_loss() ) # Add a scalar summary for the snapshot loss. if FLAGS.verbose: valid_graph = train_graph else: # with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = TriMatchModelGraph( num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, match_to_question=FLAGS.match_to_question, match_to_passage=FLAGS.match_to_passage, match_to_choice=FLAGS.match_to_choice, with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=( not FLAGS.wo_max_attentive_match), use_options=FLAGS.use_options, num_options=num_options, with_no_match=FLAGS.with_no_match, matching_option=FLAGS.matching_option, concat_context=FLAGS.concat_context, tied_aggre=FLAGS.tied_aggre, rl_training_method=FLAGS.rl_training_method, rl_matches=FLAGS.rl_matches) initializer = tf.global_variables_initializer() vars_ = {} for var in tf.global_variables(): # print(var.name,var.get_shape().as_list()) if "word_embedding" in var.name: continue # if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) # input('check') config = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.Session(config=config) sess.run(initializer) if has_pre_trained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() sub_loss_counter = 0.0 for step in range(max_steps): # read data cur_batch = trainDataStream.nextBatch() (label_batch, sent1_batch, sent2_batch, sent3_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch, word_idx_3_batch, char_matrix_idx_1_batch, char_matrix_idx_2_batch, char_matrix_idx_3_batch, sent1_length_batch, sent2_length_batch, sent3_length_batch, sent1_char_length_batch, sent2_char_length_batch, sent3_char_length_batch, POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch, concat_mat_batch, split_mat_batch_q, split_mat_batch_c) = cur_batch # print(label_id_batch) if FLAGS.verbose: print(label_id_batch) print(sent1_length_batch) print(sent2_length_batch) print(sent3_length_batch) # print(word_idx_1_batch) # print(word_idx_2_batch) # print(word_idx_3_batch) # print(sent1_batch) # 
        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        sub_loss_counter = 0.0
        for step in range(max_steps):
            # read data
            cur_batch = trainDataStream.nextBatch()
            (label_batch, sent1_batch, sent2_batch, sent3_batch, label_id_batch,
             word_idx_1_batch, word_idx_2_batch, word_idx_3_batch,
             char_matrix_idx_1_batch, char_matrix_idx_2_batch, char_matrix_idx_3_batch,
             sent1_length_batch, sent2_length_batch, sent3_length_batch,
             sent1_char_length_batch, sent2_char_length_batch, sent3_char_length_batch,
             POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch,
             concat_mat_batch, split_mat_batch_q, split_mat_batch_c) = cur_batch
            # print(label_id_batch)
            if FLAGS.verbose:
                print(label_id_batch)
                print(sent1_length_batch)
                print(sent2_length_batch)
                print(sent3_length_batch)
                # print(word_idx_1_batch)
                # print(word_idx_2_batch)
                # print(word_idx_3_batch)
                # print(sent1_batch)
                # print(sent2_batch)
                # print(sent3_batch)
                print(concat_mat_batch)
                input('check')

            feed_dict = {
                train_graph.get_truth(): label_id_batch,
                train_graph.get_passage_lengths(): sent1_length_batch,
                train_graph.get_question_lengths(): sent2_length_batch,
                train_graph.get_choice_lengths(): sent3_length_batch,
                train_graph.get_in_passage_words(): word_idx_1_batch,
                train_graph.get_in_question_words(): word_idx_2_batch,
                train_graph.get_in_choice_words(): word_idx_3_batch,
            }
            if char_vocab is not None:
                feed_dict[train_graph.get_passage_char_lengths()] = sent1_char_length_batch
                feed_dict[train_graph.get_question_char_lengths()] = sent2_char_length_batch
                feed_dict[train_graph.get_choice_char_lengths()] = sent3_char_length_batch
                feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_1_batch
                feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_2_batch
                feed_dict[train_graph.get_in_choice_chars()] = char_matrix_idx_3_batch
            if POS_vocab is not None:
                feed_dict[train_graph.get_in_passage_poss()] = POS_idx_1_batch
                feed_dict[train_graph.get_in_question_poss()] = POS_idx_2_batch
            if NER_vocab is not None:
                feed_dict[train_graph.get_in_passage_ners()] = NER_idx_1_batch
                feed_dict[train_graph.get_in_question_ners()] = NER_idx_2_batch
            if concat_mat_batch is not None:
                feed_dict[train_graph.concat_idx_mat] = concat_mat_batch
            if split_mat_batch_q is not None:
                feed_dict[train_graph.split_idx_mat_q] = split_mat_batch_q
                feed_dict[train_graph.split_idx_mat_c] = split_mat_batch_c

            if FLAGS.verbose:
                return_list = sess.run(
                    [train_graph.get_train_op(), train_graph.get_loss(),
                     train_graph.get_predictions(), train_graph.get_prob(),
                     train_graph.all_probs, train_graph.correct]
                    + train_graph.matching_vectors,
                    feed_dict=feed_dict)
                print(len(return_list))
                _, loss_value, pred, prob, all_probs, correct = return_list[0:6]
                print('pred=', pred)
                print('prob=', prob)
                print('logits=', all_probs)
                print('correct=', correct)
                for val in return_list[6:]:
                    if isinstance(val, list):
                        print('list len ', len(val))
                        for objj in val:
                            print('this shape=', objj.shape)
                    else:
                        print('this shape=', val.shape)
                    # print(val)
                    input('check')
            else:
                _, loss_value = sess.run(
                    [train_graph.get_train_op(), train_graph.get_loss()],
                    feed_dict=feed_dict)
            total_loss += loss_value
            sub_loss_counter += loss_value

            if step % int(FLAGS.display_every) == 0:
                print('{},{} '.format(step, sub_loss_counter), end="")
                sys.stdout.flush()
                sub_loss_counter = 0.0

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                print()
                # Print status to stdout.
                duration = time.time() - start_time
                start_time = time.time()
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                total_loss = 0.0
                # Evaluate against the validation set.
                print('Validation Data Eval:')
                if FLAGS.predict_val:
                    outpath = path_prefix + '.iter%d' % (step) + '.probs'
                else:
                    outpath = None
                accuracy = evaluate(devDataStream, valid_graph, sess,
                                    char_vocab=char_vocab, POS_vocab=POS_vocab,
                                    NER_vocab=NER_vocab,
                                    use_options=FLAGS.use_options,
                                    outpath=outpath, mode='prob')
                print("Current accuracy on dev set is %.2f" % accuracy)
                if accuracy >= best_accuracy:
                    best_accuracy = accuracy
                    saver.save(sess, best_path)
                    print('saving the current model.')
                    accuracy = evaluate(testDataStream, valid_graph, sess,
                                        char_vocab=char_vocab, POS_vocab=POS_vocab,
                                        NER_vocab=NER_vocab,
                                        use_options=FLAGS.use_options,
                                        outpath=outpath, mode='prob')
                    print("Current accuracy on test set is %.2f" % accuracy)
                print("Best accuracy on dev set is %.2f" % best_accuracy)

    # decoding
    print('Decoding on the test set:')
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = TriMatchModelGraph(
                num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2,
                char_lstm_dim=FLAGS.char_lstm_dim,
                context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                is_training=False, MP_dim=FLAGS.MP_dim,
                context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec, with_highway=FLAGS.with_highway,
                word_level_MP_dim=FLAGS.word_level_MP_dim,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                match_to_question=FLAGS.match_to_question,
                match_to_passage=FLAGS.match_to_passage,
                match_to_choice=FLAGS.match_to_choice,
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                use_options=FLAGS.use_options, num_options=num_options,
                with_no_match=FLAGS.with_no_match,
                matching_option=FLAGS.matching_option,
                concat_context=FLAGS.concat_context, tied_aggre=FLAGS.tied_aggre,
                rl_training_method=FLAGS.rl_training_method,
                rl_matches=FLAGS.rl_matches)

        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        step = 0
        saver.restore(sess, best_path)

        accuracy = evaluate(testDataStream, valid_graph, sess,
                            char_vocab=char_vocab, POS_vocab=POS_vocab,
                            NER_vocab=NER_vocab, use_options=FLAGS.use_options)
        print("Accuracy for test set is %.2f" % accuracy)
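# For orientation, a minimal sketch of what an evaluate() pass over one of the
# streams above could look like. This is illustrative only and is not the
# project's actual evaluate(): it assumes the get_num_batch()/nextBatch()
# interface seen above, a graph exposing get_predictions(), and a hypothetical
# build_feed_dict helper mirroring the feed construction in the training loop.
def _sketch_evaluate(data_stream, graph, sess, build_feed_dict):
    correct = 0
    total = 0
    for _ in range(data_stream.get_num_batch()):
        cur_batch = data_stream.nextBatch()
        feed_dict, label_id_batch = build_feed_dict(graph, cur_batch)
        pred = sess.run(graph.get_predictions(), feed_dict=feed_dict)
        for p, y in zip(pred, label_id_batch):
            correct += int(p == y)
            total += 1
    return 100.0 * correct / max(total, 1)  # accuracy in percent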
def main(_):
    print('Configurations:')
    print(FLAGS)

    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    result_dir = '../result'
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    path_prefix = log_dir + "/Han.{}".format(FLAGS.suffix)
    save_namespace(FLAGS, path_prefix + ".config.json")

    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    label_path = path_prefix + ".label_vocab"

    print('Collect words and labels ...')
    (all_words, all_labels) = collect_vocabs(train_path)
    print('Number of words: {}'.format(len(all_words)))
    print('Number of labels: {}'.format(len(all_labels)))
    label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
    label_vocab.dump_to_txt2(label_path)

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build HanDataStream ... ')
    trainDataStream = HanDataStream(inpath=train_path, word_vocab=word_vocab,
                                    label_vocab=label_vocab, isShuffle=True,
                                    isLoop=True, max_sent_length=FLAGS.max_sent_length)
    devDataStream = HanDataStream(inpath=dev_path, word_vocab=word_vocab,
                                  label_vocab=label_vocab, isShuffle=False,
                                  isLoop=True, max_sent_length=FLAGS.max_sent_length)
    testDataStream = HanDataStream(inpath=test_path, word_vocab=word_vocab,
                                   label_vocab=label_vocab, isShuffle=False,
                                   isLoop=True, max_sent_length=FLAGS.max_sent_length)

    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))

    with tf.Graph().as_default():
        initializer = tf.contrib.layers.xavier_initializer()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = HanModelGraph(num_classes=num_classes, word_vocab=word_vocab,
                                        dropout_rate=FLAGS.dropout_rate,
                                        learning_rate=FLAGS.learning_rate,
                                        lambda_l2=FLAGS.lambda_l2,
                                        context_lstm_dim=FLAGS.context_lstm_dim,
                                        is_training=True, batch_size=FLAGS.batch_size)
        # Add a scalar summary for the snapshot loss.
        tf.summary.scalar("Training Loss", train_graph.loss)
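        # The scalar summary above only takes effect if it is evaluated and
        # written out, and this trainer never creates a FileWriter for it.
        # A sketch of the usual TF1 wiring (not present in this file):
        #
        #   merged = tf.summary.merge_all()
        #   writer = tf.summary.FileWriter(log_dir, sess.graph)
        #   summary, _ = sess.run([merged, train_graph.train_op], feed_dict=feed_dict)
        #   writer.add_summary(summary, step)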
        print("Train graph built.")
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = HanModelGraph(num_classes=num_classes, word_vocab=word_vocab,
                                        dropout_rate=FLAGS.dropout_rate,
                                        learning_rate=FLAGS.learning_rate,
                                        lambda_l2=FLAGS.lambda_l2,
                                        context_lstm_dim=FLAGS.context_lstm_dim,
                                        is_training=False, batch_size=1)
        print("Dev graph built.")

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            # if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        output_res_file = open(result_dir + '/' + FLAGS.suffix, 'wt')
        output_res_file.write(str(FLAGS))
        with tf.Session() as sess:
            sess.run(initializer)
            train_size = trainDataStream.get_num_instance()
            max_steps = (train_size * FLAGS.max_epochs) // FLAGS.batch_size
            epoch_size = max_steps // FLAGS.max_epochs  # + 1
            total_loss = 0.0
            start_time = time.time()
            best_accuracy = 0
            for step in range(max_steps):
                # read data; the legacy inline batching below was replaced by
                # the get_feed_dict helper:
                # _truth = []
                # _sents_length = []
                # _in_text_words = []
                # for i in range(FLAGS.batch_size):
                #     cur_instance, instance_index = trainDataStream.nextInstance()
                #     (label, text, label_id, word_idx, sents_length) = cur_instance
                #     _truth.append(label_id)
                #     _sents_length.append(sents_length)
                #     _in_text_words.append(word_idx)
                # feed_dict = {
                #     train_graph.truth: np.array(_truth),
                #     train_graph.sents_length: tuple(_sents_length),
                #     train_graph.in_text_words: tuple(_in_text_words),
                # }
                feed_dict = get_feed_dict(data_stream=trainDataStream,
                                          graph=train_graph,
                                          batch_size=FLAGS.batch_size,
                                          is_testing=False)
                _, loss_value, _score = sess.run([train_graph.train_op,
                                                  train_graph.loss,
                                                  train_graph.batch_class_scores],
                                                 feed_dict=feed_dict)
                total_loss += loss_value
                if step % 100 == 0:
                    print('{} '.format(step), end="")
                    sys.stdout.flush()

                if (step + 1) % epoch_size == 0 or (step + 1) == max_steps:
                    # print(total_loss)
                    duration = time.time() - start_time
                    start_time = time.time()
                    print(duration, step, "Loss: ", total_loss)
                    output_res_file.write('\nStep %d: loss = %.2f (%.3f sec)\n'
                                          % (step, total_loss, duration))
                    total_loss = 0.0
                    # Evaluate against the validation set.
                    output_res_file.write('valid- ')
                    dev_accuracy = evaluate(devDataStream, valid_graph, sess)
                    output_res_file.write("%.2f\n" % dev_accuracy)
                    print("Current dev accuracy is %.2f" % dev_accuracy)
                    if dev_accuracy > best_accuracy:
                        best_accuracy = dev_accuracy
                        saver.save(sess, best_path)
                    output_res_file.write('test- ')
                    test_accuracy = evaluate(testDataStream, valid_graph, sess)
                    print("Current test accuracy is %.2f" % test_accuracy)
                    output_res_file.write("%.2f\n" % test_accuracy)

    output_res_file.close()
    sys.stdout.flush()
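# A minimal sketch of the get_feed_dict helper used above, reconstructed from
# the legacy inline batching kept in the comments: it drains batch_size
# instances from the stream and feeds truth / sents_length / in_text_words.
# The project's real helper lives elsewhere and may differ (e.g. in how it
# handles is_testing), so treat this as illustrative only.
def _sketch_get_feed_dict(data_stream, graph, batch_size, is_testing=False):
    _truth, _sents_length, _in_text_words = [], [], []
    for _ in range(batch_size):
        cur_instance, _ = data_stream.nextInstance()
        (label, text, label_id, word_idx, sents_length) = cur_instance
        _truth.append(label_id)
        _sents_length.append(sents_length)
        _in_text_words.append(word_idx)
    return {
        graph.truth: np.array(_truth),
        graph.sents_length: tuple(_sents_length),
        graph.in_text_words: tuple(_in_text_words),
    }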
def decode(self, model_prefix, in_path, out_path, word_vec_path, mode,
           out_json_path=None, dump_prob_path=None):
    # model_prefix = args.model_prefix
    # in_path = args.in_path
    # out_path = args.out_path
    # word_vec_path = args.word_vec_path
    # mode = args.mode
    # out_json_path = None
    # dump_prob_path = None

    # load the configuration file
    print('Loading configurations.')
    FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
    print(FLAGS)

    # Configs saved by older versions may lack the newer options, so each one
    # falls back to a default when the attribute is missing.
    with_POS = False
    if hasattr(FLAGS, 'with_POS'):
        with_POS = FLAGS.with_POS
    with_NER = False
    if hasattr(FLAGS, 'with_NER'):
        with_NER = FLAGS.with_NER
    wo_char = False
    if hasattr(FLAGS, 'wo_char'):
        wo_char = FLAGS.wo_char
    wo_left_match = False
    if hasattr(FLAGS, 'wo_left_match'):
        wo_left_match = FLAGS.wo_left_match
    wo_right_match = False
    if hasattr(FLAGS, 'wo_right_match'):
        wo_right_match = FLAGS.wo_right_match
    wo_full_match = False
    if hasattr(FLAGS, 'wo_full_match'):
        wo_full_match = FLAGS.wo_full_match
    wo_maxpool_match = False
    if hasattr(FLAGS, 'wo_maxpool_match'):
        wo_maxpool_match = FLAGS.wo_maxpool_match
    wo_attentive_match = False
    if hasattr(FLAGS, 'wo_attentive_match'):
        wo_attentive_match = FLAGS.wo_attentive_match
    wo_max_attentive_match = False
    if hasattr(FLAGS, 'wo_max_attentive_match'):
        wo_max_attentive_match = FLAGS.wo_max_attentive_match

    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    POS_vocab = None
    NER_vocab = None
    char_vocab = None
    if with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
    if with_NER:
        NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2')
    char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
    print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
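    # The hasattr chain above could be collapsed with getattr, e.g.
    # wo_char = getattr(FLAGS, 'wo_char', False); it is kept expanded here to
    # mirror the original code.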
    print('Build SentenceMatchDataStream ... ')
    testDataStream = SentenceMatchTrainer.SentenceMatchDataStream(
        in_path, word_vocab=word_vocab, char_vocab=char_vocab,
        POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab,
        batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True,
        max_char_per_word=FLAGS.max_char_per_word,
        max_sent_length=FLAGS.max_sent_length)
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))

    if wo_char:
        char_vocab = None

    init_scale = 0.01
    best_path = model_prefix + ".best.model"
    print('Decoding on the test set:')
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(
                num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2,
                char_lstm_dim=FLAGS.char_lstm_dim,
                context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                is_training=False, MP_dim=FLAGS.MP_dim,
                context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec,
                with_filter_layer=FLAGS.with_filter_layer,
                with_highway=FLAGS.with_highway,
                word_level_MP_dim=FLAGS.word_level_MP_dim,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                with_lex_decomposition=FLAGS.with_lex_decomposition,
                lex_decompsition_dim=FLAGS.lex_decompsition_dim,  # attribute name is misspelled in the stored config
                with_char=(not FLAGS.wo_char),
                with_left_match=(not FLAGS.wo_left_match),
                with_right_match=(not FLAGS.wo_right_match),
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match))

        # skip the word_embedding variable: it is rebuilt from the pretrained
        # vectors rather than stored in the checkpoint
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        step = 0
        best_path = best_path.replace('//', '/')
        saver.restore(sess, best_path)

        accuracy = SentenceMatchTrainer.evaluate(
            testDataStream, valid_graph, sess, outpath=out_path,
            label_vocab=label_vocab, mode=mode, char_vocab=char_vocab,
            POS_vocab=POS_vocab, NER_vocab=NER_vocab)
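# Hypothetical usage of decode(). The enclosing class is not shown in this
# file, so `MatchDecoder` and every path below are placeholders:
#
#   decoder = MatchDecoder()
#   decoder.decode(model_prefix='logs/SentenceMatch.sample',
#                  in_path='data/test.tsv', out_path='out/test.probs',
#                  word_vec_path='data/wordvec.txt', mode='prob')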
def main(FLAGS):
    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    has_pre_trained_model = False
    char_vocab = None
    if os.path.exists(best_path + ".index"):
        has_pre_trained_model = True
        logger.info('Loading vocabs from a pre-trained model ...')
        label_vocab = Vocab(label_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(char_path, fileformat='txt2')
    else:
        logger.info('Collecting words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path)
        logger.info('Number of words: {}'.format(len(all_words)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
        label_vocab.dump_to_txt2(label_path)
        if FLAGS.with_char:
            logger.info('Number of chars: {}'.format(len(all_chars)))
            char_vocab = Vocab(fileformat='voc', voc=all_chars, dim=FLAGS.char_emb_dim)
            char_vocab.dump_to_txt2(char_path)

    logger.info('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    num_classes = label_vocab.size()
    logger.info("Number of labels: {}".format(num_classes))
    sys.stdout.flush()

    logger.info('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(train_path, word_vocab=word_vocab,
                                              char_vocab=char_vocab,
                                              label_vocab=label_vocab,
                                              isShuffle=True, isLoop=True,
                                              isSort=True, options=FLAGS)
    logger.info('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    logger.info('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    sys.stdout.flush()
    devDataStream = SentenceMatchDataStream(dev_path, word_vocab=word_vocab,
                                            char_vocab=char_vocab,
                                            label_vocab=label_vocab,
                                            isShuffle=False, isLoop=True,
                                            isSort=True, options=FLAGS)
    logger.info('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    logger.info('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    sys.stdout.flush()

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab,
                                                  char_vocab=char_vocab,
                                                  is_training=True, options=FLAGS,
                                                  global_step=global_step)
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab,
                                                  char_vocab=char_vocab,
                                                  is_training=False, options=FLAGS)

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        # Create the summary writer and dump the current TensorFlow graph to the log.
        train_writer = tf.summary.FileWriter(SUMMARY_DIR, sess.graph)
        # valid_writer = tf.summary.FileWriter(SUMMARY_DIR + '/valid')
        sess.run(initializer)
        if has_pre_trained_model:
            logger.info("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            logger.info("DONE!")
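        # If the commented-out valid_writer above is revived, dev accuracy can
        # be logged next to the training summaries with a manually built
        # tf.Summary, e.g. (step and dev_accuracy come from the train loop):
        #
        #   summ = tf.Summary(value=[tf.Summary.Value(tag='dev_accuracy',
        #                                             simple_value=dev_accuracy)])
        #   valid_writer.add_summary(summ, global_step=step)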
        # training
        train(sess, saver, train_graph, valid_graph, trainDataStream,
              devDataStream, FLAGS, best_path, train_writer, label_vocab)
        train_writer.close()
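# The events file produced by train_writer (the graph, plus any summaries
# added during training) can be inspected with TensorBoard, e.g.:
#   tensorboard --logdir <SUMMARY_DIR>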
def main(FLAGS):
    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    dev_path_target = FLAGS.dev_path_target
    test_path_target = FLAGS.test_path_target
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
        os.makedirs(os.path.join(log_dir, '../result_source'))
        os.makedirs(os.path.join(log_dir, '../logits_source'))
        os.makedirs(os.path.join(log_dir, '../result_target'))
        os.makedirs(os.path.join(log_dir, '../logits_target'))
    log_dir_target = FLAGS.model_dir + '_target'
    if not os.path.exists(log_dir_target):
        os.makedirs(log_dir_target)

    path_prefix = log_dir + "/ESIM.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
    path_prefix_target = log_dir_target + "/ESIM.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix_target + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    best_path_target = path_prefix_target + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    has_pre_trained_model = False
    char_vocab = None
    # if os.path.exists(best_path + ".index"):
    print('Collecting words, chars and labels ...')
    (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path)
    print('Number of words: {}'.format(len(all_words)))
    label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
    label_vocab.dump_to_txt2(label_path)

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    num_classes = label_vocab.size()
    print("Number of labels: {}".format(num_classes))
    sys.stdout.flush()

    print('Build SentenceMatchDataStream ... ')
    trainDataStream = DataStream(train_path, word_vocab=word_vocab, label_vocab=None,
                                 isShuffle=True, isLoop=True, isSort=True, options=FLAGS)
    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    sys.stdout.flush()
    devDataStream = DataStream(dev_path, word_vocab=word_vocab, label_vocab=None,
                               isShuffle=True, isLoop=True, isSort=True, options=FLAGS)
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    sys.stdout.flush()
    testDataStream = DataStream(test_path, word_vocab=word_vocab, label_vocab=None,
                                isShuffle=True, isLoop=True, isSort=True, options=FLAGS)
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    sys.stdout.flush()
    devDataStream_target = DataStream(dev_path_target, word_vocab=word_vocab,
                                      label_vocab=None, isShuffle=True, isLoop=True,
                                      isSort=True, options=FLAGS)
    print('Number of instances in devDataStream_target: {}'.format(devDataStream_target.get_num_instance()))
    print('Number of batches in devDataStream_target: {}'.format(devDataStream_target.get_num_batch()))
    sys.stdout.flush()
    testDataStream_target = DataStream(test_path_target, word_vocab=word_vocab,
                                       label_vocab=None, isShuffle=True, isLoop=True,
                                       isSort=True, options=FLAGS)
    print('Number of instances in testDataStream_target: {}'.format(testDataStream_target.get_num_instance()))
    print('Number of batches in testDataStream_target: {}'.format(testDataStream_target.get_num_batch()))
    sys.stdout.flush()

    with tf.Graph().as_default():
        initializer = tf.contrib.layers.xavier_initializer()
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = Model(num_classes, word_vocab=word_vocab, is_training=True,
                                options=FLAGS, global_step=global_step)
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = Model(num_classes, word_vocab=word_vocab, is_training=False,
                                options=FLAGS)

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            # if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1)
        config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(initializer)
            # training
            train(sess, saver, train_graph, valid_graph, trainDataStream,
                  devDataStream, testDataStream, devDataStream_target,
                  testDataStream_target, FLAGS, best_path, best_path_target)
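# Note on the session config used above: per_process_gpu_memory_fraction=1
# permits the process to claim the whole GPU, while allow_growth=True makes
# TensorFlow allocate memory lazily instead of grabbing it all upfront. A
# sketch of an equivalent shared helper, if one wanted to reuse the config
# across the trainers in this file (name hypothetical):
def _sketch_make_session():
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1)
    config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
    config.gpu_options.allow_growth = True  # grow allocations on demand
    return tf.Session(config=config)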
def main(FLAGS):
    tf.logging.set_verbosity(tf.logging.INFO)
    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    kg_path = FLAGS.kg_path
    wordnet_path = FLAGS.wordnet_path
    lemma_vec_path = FLAGS.lemma_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
        os.makedirs(os.path.join(log_dir, '../result'))
        os.makedirs(os.path.join(log_dir, '../logits'))

    path_prefix = log_dir + "/KEIM.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    lemma_vocab = Vocab(lemma_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    char_vocab = None

    tf.logging.info('Collecting words, chars and labels ...')
    (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path)
    tf.logging.info('Number of words: {}'.format(len(all_words)))
    label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
    label_vocab.dump_to_txt2(label_path)
    if FLAGS.with_char:
        tf.logging.info('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc', voc=all_chars, dim=FLAGS.char_emb_dim)
        char_vocab.dump_to_txt2(char_path)

    tf.logging.info('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    tf.logging.info('lemma_word_vocab shape is {}'.format(lemma_vocab.word_vecs.shape))
    num_classes = label_vocab.size()
    tf.logging.info("Number of labels: {}".format(num_classes))
    sys.stdout.flush()

    with open(wordnet_path, 'rb') as f:
        wordnet_vocab = pkl.load(f)
    tf.logging.info('wordnet_vocab size is {}'.format(len(wordnet_vocab)))
    with open(kg_path, 'rb') as f:
        kg = pkl.load(f)
    tf.logging.info('kg size is {}'.format(len(kg)))
    tf.logging.info('Build SentenceMatchDataStream ... ')
    trainDataStream = DataStream(train_path, word_vocab=word_vocab,
                                 char_vocab=char_vocab, label_vocab=None, kg=kg,
                                 wordnet_vocab=wordnet_vocab, lemma_vocab=lemma_vocab,
                                 isShuffle=True, isLoop=True, isSort=True, options=FLAGS)
    tf.logging.info('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    tf.logging.info('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    sys.stdout.flush()
    devDataStream = DataStream(dev_path, word_vocab=word_vocab,
                               char_vocab=char_vocab, label_vocab=None, kg=kg,
                               wordnet_vocab=wordnet_vocab, lemma_vocab=lemma_vocab,
                               isShuffle=True, isLoop=True, isSort=True, options=FLAGS)
    tf.logging.info('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    tf.logging.info('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    sys.stdout.flush()
    testDataStream = DataStream(test_path, word_vocab=word_vocab,
                                char_vocab=char_vocab, label_vocab=None, kg=kg,
                                wordnet_vocab=wordnet_vocab, lemma_vocab=lemma_vocab,
                                isShuffle=True, isLoop=True, isSort=True, options=FLAGS)
    tf.logging.info('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    tf.logging.info('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    sys.stdout.flush()

    with tf.Graph().as_default():
        initializer = tf.contrib.layers.xavier_initializer()
        # initializer = tf.truncated_normal_initializer(stddev=0.02)
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = Model(num_classes, word_vocab=word_vocab,
                                char_vocab=char_vocab, lemma_vocab=lemma_vocab,
                                is_training=True, options=FLAGS,
                                global_step=global_step)
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = Model(num_classes, word_vocab=word_vocab,
                                char_vocab=char_vocab, lemma_vocab=lemma_vocab,
                                is_training=False, options=FLAGS)

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            # if "word_embedding" in var.name: continue
            # if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1)
        config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(initializer)
            # training
            train(sess, saver, train_graph, valid_graph, trainDataStream,
                  devDataStream, testDataStream, FLAGS, best_path)
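# These trainers are driven by an argparse-based FLAGS namespace (saved and
# reloaded via namespace_utils). A minimal, hypothetical entry point for the
# KEIM main above follows; the real scripts define many more flags, so this is
# a sketch, not the project's actual CLI:
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_path', type=str, help='training data file')
    parser.add_argument('--dev_path', type=str, help='dev data file')
    parser.add_argument('--test_path', type=str, help='test data file')
    parser.add_argument('--word_vec_path', type=str, help='pretrained word vectors')
    parser.add_argument('--lemma_vec_path', type=str, help='pretrained lemma vectors')
    parser.add_argument('--kg_path', type=str, help='pickled knowledge-graph file')
    parser.add_argument('--wordnet_path', type=str, help='pickled WordNet vocab file')
    parser.add_argument('--model_dir', type=str, help='directory for checkpoints')
    parser.add_argument('--suffix', type=str, default='sample', help='model name suffix')
    FLAGS, _ = parser.parse_known_args()
    main(FLAGS)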