def validate_and_save(model, devset_batches, log_file, best_accu):
    path_prefix = FLAGS.log_dir + "/MHQA.{}".format(FLAGS.suffix)
    start_time = time.time()
    res_dict = evaluate_dataset(model, devset_batches)
    dev_loss = res_dict['dev_loss']
    dev_accu = res_dict['dev_accu']
    dev_right = int(res_dict['dev_right'])
    dev_total = int(res_dict['dev_total'])
    print('Dev loss = %.4f' % dev_loss)
    log_file.write('Dev loss = %.4f\n' % dev_loss)
    print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total))
    log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total))
    log_file.flush()
    if best_accu < dev_accu:
        print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu))
        best_path = path_prefix + '.model.bin'
        torch.save(model.state_dict(), best_path)
        best_accu = dev_accu
        FLAGS.best_accu = dev_accu
        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
    duration = time.time() - start_time
    print('Duration %.3f sec' % (duration))
    print('-------------')
    log_file.write('-------------\n')
    return best_accu
def validate_and_save(sess, saver, FLAGS, log_file, devDataStream, valid_graph, path_prefix, best_accu):
    best_path = path_prefix + ".best.model"
    start_time = time.time()
    print('Validation Data Eval:')
    res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS)
    dev_loss = res_dict['dev_loss']
    dev_accu = res_dict['dev_accu']
    dev_right = int(res_dict['dev_right'])
    dev_total = int(res_dict['dev_total'])
    print('Dev loss = %.4f' % dev_loss)
    log_file.write('Dev loss = %.4f\n' % dev_loss)
    print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total))
    log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total))
    log_file.flush()
    if best_accu < dev_accu:
        print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu))
        saver.save(sess, best_path)
        best_accu = dev_accu
        FLAGS.best_accu = dev_accu
        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
    duration = time.time() - start_time
    print('Duration %.3f sec' % (duration))
    sys.stdout.flush()
    return best_accu
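# A minimal usage sketch (not part of the original scripts): how the TensorFlow
# variant of validate_and_save above could be wired into a training loop.
# `run_train_step` and the other arguments are hypothetical stand-ins for
# objects the surrounding code already builds; the real loops below differ in detail.
def training_loop_sketch(sess, saver, FLAGS, log_file, trainDataStream,
                         devDataStream, train_graph, valid_graph,
                         path_prefix, run_train_step):
    best_accu = 0.0
    num_batches_per_epoch = trainDataStream.get_num_batch()
    max_steps = num_batches_per_epoch * FLAGS.max_epochs
    for step in range(max_steps):
        run_train_step(sess, train_graph, trainDataStream.nextBatch())
        # Validate once per epoch and keep only the best checkpoint.
        if (step + 1) % num_batches_per_epoch == 0:
            best_accu = validate_and_save(sess, saver, FLAGS, log_file,
                                          devDataStream, valid_graph,
                                          path_prefix, best_accu)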
def main(_):
    print('Configurations:')
    print(FLAGS)  # print all configuration parameters
    root_path = FLAGS.root_path
    train_path = root_path + FLAGS.train_path
    dev_path = root_path + FLAGS.dev_path
    test_path = root_path + FLAGS.test_path
    word_vec_path = root_path + FLAGS.word_vec_path
    model_dir = root_path + FLAGS.model_dir
    if tf.gfile.Exists(model_dir + '/mnist_with_summaries'):
        print("delete summaries")
        tf.gfile.DeleteRecursively(model_dir + '/mnist_with_summaries')
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    path_prefix = model_dir + "/SentenceMatch.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")  # save the configuration
    best_path = path_prefix + '.best.model'
    label_path = path_prefix + ".label_vocab"
    has_pre_trained_model = False
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        print("-------has_pre_trained_model--------")
        print(ckpt.model_checkpoint_path)
        has_pre_trained_model = True

    ############# build vocabs #################
    print('Collect words, chars and labels ...')
    (all_words, all_labels) = collect_vocabs(train_path)
    print('Number of words: {}'.format(len(all_words)))
    print('Number of labels: {}'.format(len(all_labels)))
    word_vocab = Vocab(pattern='word')  # instantiate the vocabulary class
    word_vocab.patternWord(word_vec_path, model_dir)
    label_vocab = Vocab(pattern="label")
    label_vocab.patternLabel(all_labels, label_path)
    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = len(all_labels)
    if FLAGS.wo_char:
        char_vocab = None

    ##### Build SentenceMatchDataStream ################
    print('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(
        train_path, word_vocab=word_vocab, label_vocab=label_vocab,
        batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=False,
        max_sent_length=FLAGS.max_sent_length)
    devDataStream = SentenceMatchDataStream(
        dev_path, word_vocab=word_vocab, label_vocab=label_vocab,
        batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=False,
        max_sent_length=FLAGS.max_sent_length)
    testDataStream = SentenceMatchDataStream(
        test_path, word_vocab=word_vocab, label_vocab=label_vocab,
        batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=False,
        max_sent_length=FLAGS.max_sent_length)
    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    sys.stdout.flush()

    best_accuracy = 0.0
    init_scale = 0.01
    g_2 = tf.Graph()
    with g_2.as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = SentenceMatchModelGraph(
                num_classes, word_vocab=word_vocab, dropout_rate=FLAGS.dropout_rate,
                learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                lambda_l2=FLAGS.lambda_l2, with_word=True,
                context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True,
                MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer,
                with_highway=FLAGS.with_highway, with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                with_lex_decomposition=FLAGS.with_lex_decomposition,
                lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                with_left_match=(not FLAGS.wo_left_match),
                with_right_match=(not FLAGS.wo_right_match),
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match))
            tf.summary.scalar("Training Loss", train_graph.get_loss())
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(
                num_classes, word_vocab=word_vocab, dropout_rate=FLAGS.dropout_rate,
                learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                lambda_l2=FLAGS.lambda_l2, with_word=True,
                context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False,
                MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer,
                with_highway=FLAGS.with_highway, with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                with_lex_decomposition=FLAGS.with_lex_decomposition,
                lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                with_left_match=(not FLAGS.wo_left_match),
                with_right_match=(not FLAGS.wo_right_match),
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match))

        initializer = tf.global_variables_initializer()
        saver = tf.train.Saver()
        # vars_ = {}
        # for var in tf.global_variables():
        #     if "word_embedding" in var.name: continue
        #     vars_[var.name.split(":")[0]] = var
        # saver = tf.train.Saver(vars_)

        sess = tf.Session()
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(model_dir + '/mnist_with_summaries/train', sess.graph)
        sess.run(initializer)
        if has_pre_trained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in range(max_steps):
            cur_batch = trainDataStream.nextBatch()
            (label_id_batch, word_idx_1_batch, word_idx_2_batch,
             sent1_length_batch, sent2_length_batch) = cur_batch
            feed_dict = {
                train_graph.get_truth(): label_id_batch,
                train_graph.get_question_lengths(): sent1_length_batch,
                train_graph.get_passage_lengths(): sent2_length_batch,
                train_graph.get_in_question_words(): word_idx_1_batch,
                train_graph.get_in_passage_words(): word_idx_2_batch,
            }
            # in_question_repres, in_ques = sess.run([train_graph.in_question_repres, train_graph.in_ques], feed_dict=feed_dict)
            # print(in_question_repres, in_ques)
            # break
            _, loss_value, summary = sess.run(
                [train_graph.get_train_op(), train_graph.get_loss(), merged],
                feed_dict=feed_dict)
            total_loss += loss_value
            if step % 5000 == 0:
                # train_writer.add_summary(summary, step)
                print("step:", step, "loss:", loss_value)

            # Evaluate on the dev set at the end of every epoch and at the last step.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                print()
                duration = time.time() - start_time
                start_time = time.time()
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                total_loss = 0.0
                print('Validation Data Eval:')
                accuracy = evaluate(devDataStream, valid_graph, sess)
                print("Current accuracy is %.2f" % accuracy)
                if accuracy >= best_accuracy:
                    print('Saving model since it\'s the best so far')
                    best_accuracy = accuracy
                    saver.save(sess, best_path)
                sys.stdout.flush()
        print("Best accuracy on dev set is %.2f" % best_accuracy)
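# The loop above calls an `evaluate(devDataStream, valid_graph, sess)` helper that
# is not shown in this file. A minimal sketch of what such a helper could look like
# for the word-only data stream used above; it assumes the evaluation graph exposes
# a correct-prediction counter (called `get_eval_correct()` here, which is an
# assumption), so the real accessor and return values may differ.
def evaluate_sketch(dataStream, valid_graph, sess):
    total = 0
    correct = 0
    for _ in range(dataStream.get_num_batch()):
        (label_id_batch, word_idx_1_batch, word_idx_2_batch,
         sent1_length_batch, sent2_length_batch) = dataStream.nextBatch()
        feed_dict = {
            valid_graph.get_truth(): label_id_batch,
            valid_graph.get_question_lengths(): sent1_length_batch,
            valid_graph.get_passage_lengths(): sent2_length_batch,
            valid_graph.get_in_question_words(): word_idx_1_batch,
            valid_graph.get_in_passage_words(): word_idx_2_batch,
        }
        # count how many predictions in this batch match the gold labels
        correct += sess.run(valid_graph.get_eval_correct(), feed_dict=feed_dict)
        total += len(label_id_batch)
    return correct / float(total)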
def main(_): print('Configurations:') print(FLAGS) train_path = FLAGS.train_path dev_path = FLAGS.dev_path test_path = FLAGS.test_path word_vec_path = FLAGS.word_vec_path log_dir = FLAGS.model_dir if FLAGS.train == "sick": train_path = FLAGS.SICK_train_path dev_path = FLAGS.SICK_dev_path if FLAGS.test == "sick": test_path = FLAGS.SICK_test_path if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix) namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") # build vocabs parser,image_feats = None, None if FLAGS.with_dep: parser=Parser('snli') if FLAGS.with_image: image_feats=ImageFeatures() word_vocab = Vocab(word_vec_path, fileformat='txt3', parser=parser, beginning=FLAGS.beginning) #fileformat='txt3' best_path = path_prefix + '.best.model' char_path = path_prefix + ".char_vocab" label_path = path_prefix + ".label_vocab" POS_path = path_prefix + ".POS_vocab" NER_path = path_prefix + ".NER_vocab" DEP_path = path_prefix + ".DEP_vocab" has_pre_trained_model = False POS_vocab = None NER_vocab = None DEP_vocab = None print('has pretrained model: ', os.path.exists(best_path)) print('best_path: ' + best_path) if os.path.exists(best_path + '.meta'): has_pre_trained_model = True label_vocab = Vocab(label_path, fileformat='txt2') char_vocab = Vocab(char_path, fileformat='txt2') if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2') if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2') else: print('Collect words, chars and labels ...') (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER) print('Number of words: {}'.format(len(all_words))) print('Number of labels: {}'.format(len(all_labels))) label_vocab = Vocab(fileformat='voc', voc=all_labels,dim=2) label_vocab.dump_to_txt2(label_path) print('Number of chars: {}'.format(len(all_chars))) char_vocab = Vocab(fileformat='voc', voc=all_chars,dim=FLAGS.char_emb_dim, beginning=FLAGS.beginning) char_vocab.dump_to_txt2(char_path) if FLAGS.with_POS: print('Number of POSs: {}'.format(len(all_POSs))) POS_vocab = Vocab(fileformat='voc', voc=all_POSs,dim=FLAGS.POS_dim) POS_vocab.dump_to_txt2(POS_path) if FLAGS.with_NER: print('Number of NERs: {}'.format(len(all_NERs))) NER_vocab = Vocab(fileformat='voc', voc=all_NERs,dim=FLAGS.NER_dim) NER_vocab.dump_to_txt2(NER_path) print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape)) print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape)) num_classes = label_vocab.size() print('Build DataStream ... 
') print('Reading trainDataStream') if not FLAGS.decoding_only: trainDataStream = DataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats, sick_data=(FLAGS.test == "sick")) print('Reading devDataStream') devDataStream = DataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats, sick_data=(FLAGS.test == "sick")) print('Reading testDataStream') testDataStream = DataStream(test_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats, sick_data=(FLAGS.test == "sick")) print('save cache file') #word_vocab.parser.save_cache() #image_feats.save_feat() if not FLAGS.decoding_only: print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch())) print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance())) print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch())) sys.stdout.flush() if FLAGS.wo_char: char_vocab = None best_accuracy = 0.0 init_scale = 0.01 if not FLAGS.decoding_only: with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = ModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num,with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), with_dep=FLAGS.with_dep, 
with_image=FLAGS.with_image, image_with_hypothesis_only=FLAGS.image_with_hypothesis_only, with_img_full_match=FLAGS.with_img_full_match, with_img_maxpool_match=FLAGS.with_img_full_match, with_img_attentive_match=FLAGS.with_img_attentive_match, image_context_layer=FLAGS.image_context_layer, with_img_max_attentive_match=FLAGS.with_img_max_attentive_match, img_dim=FLAGS.img_dim) tf.summary.scalar("Training Loss", train_graph.get_loss()) # Add a scalar summary for the snapshot loss. with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = ModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_with_hypothesis_only=FLAGS.image_with_hypothesis_only, with_img_attentive_match=FLAGS.with_img_attentive_match, with_img_full_match=FLAGS.with_img_full_match, with_img_maxpool_match=FLAGS.with_img_full_match, image_context_layer=FLAGS.image_context_layer, with_img_max_attentive_match=FLAGS.with_img_max_attentive_match, img_dim=FLAGS.img_dim) initializer = tf.global_variables_initializer() vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(initializer) #, feed_dict={valid_graph.emb_init: word_vocab.word_vecs, train_graph.emb_init: word_vocab.word_vecs}) if has_pre_trained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") #if best_path.startswith('bimpm_baseline'): #best_path = best_path + '_sick' print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): # read data cur_batch = trainDataStream.nextBatch() (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch, char_matrix_idx_1_batch, char_matrix_idx_2_batch, sent1_length_batch, sent2_length_batch, sent1_char_length_batch, sent2_char_length_batch, POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch, dependency1_batch, dependency2_batch, dep_con1_batch, dep_con2_batch, img_feats_batch, img_id_batch) = cur_batch feed_dict = { train_graph.get_truth(): label_id_batch, train_graph.get_question_lengths(): sent1_length_batch, train_graph.get_passage_lengths(): sent2_length_batch, 
train_graph.get_in_question_words(): word_idx_1_batch, train_graph.get_in_passage_words(): word_idx_2_batch, #train_graph.get_emb_init(): word_vocab.word_vecs, #train_graph.get_in_question_dependency(): dependency1_batch, #train_graph.get_in_passage_dependency(): dependency2_batch, # train_graph.get_question_char_lengths(): sent1_char_length_batch, # train_graph.get_passage_char_lengths(): sent2_char_length_batch, # train_graph.get_in_question_chars(): char_matrix_idx_1_batch, # train_graph.get_in_passage_chars(): char_matrix_idx_2_batch, } if FLAGS.with_dep: feed_dict[train_graph.get_in_question_dependency()] = dependency1_batch feed_dict[train_graph.get_in_passage_dependency()] = dependency2_batch feed_dict[train_graph.get_in_question_dep_con()] = dep_con1_batch feed_dict[train_graph.get_in_passage_dep_con()] = dep_con2_batch if FLAGS.with_image: feed_dict[train_graph.get_image_feats()] = img_feats_batch if char_vocab is not None: feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch if POS_vocab is not None: feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch if NER_vocab is not None: feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()], feed_dict=feed_dict) total_loss += loss_value if step % 100==0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps: print() # Print status to stdout. duration = time.time() - start_time start_time = time.time() print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) total_loss = 0.0 # Evaluate against the validation set. 
print('Validation Data Eval:') accuracy = evaluate(devDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, word_vocab=word_vocab) print("Current accuracy on dev is %.2f" % accuracy) #accuracy_train = evaluate(trainDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab) #print("Current accuracy on train is %.2f" % accuracy_train) if accuracy>best_accuracy: best_accuracy = accuracy saver.save(sess, best_path) print("Best accuracy on dev set is %.2f" % best_accuracy) # decoding print('Decoding on the test set:') init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.variable_scope("Model", reuse=False, initializer=initializer): valid_graph = ModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_with_hypothesis_only=FLAGS.image_with_hypothesis_only, with_img_attentive_match=FLAGS.with_img_attentive_match, with_img_full_match=FLAGS.with_img_full_match, with_img_maxpool_match=FLAGS.with_img_full_match, image_context_layer=FLAGS.image_context_layer, with_img_max_attentive_match=FLAGS.with_img_max_attentive_match, img_dim=FLAGS.img_dim) vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(tf.global_variables_initializer())#, feed_dict={valid_graph.emb_init: word_vocab.word_vecs}) step = 0 saver.restore(sess, best_path) accuracy = evaluate(testDataStream, valid_graph, sess, outpath=FLAGS.suffix+ FLAGS.train + FLAGS.test + ".result",char_vocab=char_vocab,label_vocab=label_vocab, word_vocab=word_vocab) print("Accuracy for test set is %.2f" % accuracy) accuracy_train = evaluate(trainDataStream, valid_graph, sess,char_vocab=char_vocab,word_vocab=word_vocab) print("Accuracy for train set is %.2f" % accuracy_train)
def main(_): #for x in range (100): # Generate_random_initialization() # print (FLAGS.is_aggregation_lstm, FLAGS.context_lstm_dim, FLAGS.context_layer_num, FLAGS. aggregation_lstm_dim, FLAGS.aggregation_layer_num, FLAGS.max_window_size, FLAGS.MP_dim) print('Configurations:') #print(FLAGS) train_path = FLAGS.train_path dev_path = FLAGS.dev_path test_path = FLAGS.test_path word_vec_path = FLAGS.word_vec_path log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix) namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") # build vocabs word_vocab = Vocab(word_vec_path, fileformat='txt3') best_path = path_prefix + '.best.model' char_path = path_prefix + ".char_vocab" label_path = path_prefix + ".label_vocab" POS_path = path_prefix + ".POS_vocab" NER_path = path_prefix + ".NER_vocab" has_pre_trained_model = False POS_vocab = None NER_vocab = None if os.path.exists(best_path): has_pre_trained_model = True label_vocab = Vocab(label_path, fileformat='txt2') char_vocab = Vocab(char_path, fileformat='txt2') if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2') if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2') else: print('Collect words, chars and labels ...') (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER) print('Number of words: {}'.format(len(all_words))) print('Number of labels: {}'.format(len(all_labels))) label_vocab = Vocab(fileformat='voc', voc=all_labels,dim=2) label_vocab.dump_to_txt2(label_path) print('Number of chars: {}'.format(len(all_chars))) char_vocab = Vocab(fileformat='voc', voc=all_chars,dim=FLAGS.char_emb_dim) char_vocab.dump_to_txt2(char_path) if FLAGS.with_POS: print('Number of POSs: {}'.format(len(all_POSs))) POS_vocab = Vocab(fileformat='voc', voc=all_POSs,dim=FLAGS.POS_dim) POS_vocab.dump_to_txt2(POS_path) if FLAGS.with_NER: print('Number of NERs: {}'.format(len(all_NERs))) NER_vocab = Vocab(fileformat='voc', voc=all_NERs,dim=FLAGS.NER_dim) NER_vocab.dump_to_txt2(NER_path) print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape)) print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape)) num_classes = label_vocab.size() print('Build SentenceMatchDataStream ... 
') trainDataStream = SentenceMatchDataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, is_as=FLAGS.is_answer_selection) devDataStream = SentenceMatchDataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, is_as=FLAGS.is_answer_selection) testDataStream = SentenceMatchDataStream(test_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, is_as=FLAGS.is_answer_selection) print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance())) print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch())) print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch())) sys.stdout.flush() if FLAGS.wo_char: char_vocab = None output_res_index = 1 while True: Generate_random_initialization() st_cuda = '' if FLAGS.is_server == True: st_cuda = str(os.environ['CUDA_VISIBLE_DEVICES']) + '.' 
output_res_file = open('../result/' + st_cuda + str(output_res_index), 'wt') output_res_index += 1 output_res_file.write(str(FLAGS) + '\n\n') stt = str (FLAGS) best_accuracy = 0.0 init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) # with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer, with_input_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), with_bilinear_att=(FLAGS.attention_type) , type1=FLAGS.type1, type2 = FLAGS.type2, type3=FLAGS.type3, with_aggregation_attention=not FLAGS.wo_agg_self_att, is_answer_selection= FLAGS.is_answer_selection, is_shared_attention=FLAGS.is_shared_attention, modify_loss=FLAGS.modify_loss, is_aggregation_lstm=FLAGS.is_aggregation_lstm , max_window_size=FLAGS.max_window_size , prediction_mode=FLAGS.prediction_mode, context_lstm_dropout=not FLAGS.wo_lstm_drop_out, is_aggregation_siamese=FLAGS.is_aggregation_siamese , unstack_cnn=FLAGS.unstack_cnn,with_context_self_attention=FLAGS.with_context_self_attention) tf.summary.scalar("Training Loss", train_graph.get_loss()) # Add a scalar summary for the snapshot loss. 
# with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer, with_input_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), with_bilinear_att=(FLAGS.attention_type) , type1=FLAGS.type1, type2 = FLAGS.type2, type3=FLAGS.type3, with_aggregation_attention=not FLAGS.wo_agg_self_att, is_answer_selection= FLAGS.is_answer_selection, is_shared_attention=FLAGS.is_shared_attention, modify_loss=FLAGS.modify_loss, is_aggregation_lstm=FLAGS.is_aggregation_lstm, max_window_size=FLAGS.max_window_size , prediction_mode=FLAGS.prediction_mode, context_lstm_dropout=not FLAGS.wo_lstm_drop_out, is_aggregation_siamese=FLAGS.is_aggregation_siamese , unstack_cnn=FLAGS.unstack_cnn,with_context_self_attention=FLAGS.with_context_self_attention) initializer = tf.global_variables_initializer() vars_ = {} #for var in tf.all_variables(): for var in tf.global_variables(): if "word_embedding" in var.name: continue # if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) with tf.Session() as sess: sess.run(initializer) if has_pre_trained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): # read data cur_batch, batch_index = trainDataStream.nextBatch() (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch, char_matrix_idx_1_batch, char_matrix_idx_2_batch, sent1_length_batch, sent2_length_batch, sent1_char_length_batch, sent2_char_length_batch, POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch) = cur_batch feed_dict = { train_graph.get_truth(): label_id_batch, train_graph.get_question_lengths(): sent1_length_batch, train_graph.get_passage_lengths(): sent2_length_batch, train_graph.get_in_question_words(): word_idx_1_batch, train_graph.get_in_passage_words(): word_idx_2_batch, # train_graph.get_question_char_lengths(): sent1_char_length_batch, # train_graph.get_passage_char_lengths(): sent2_char_length_batch, # train_graph.get_in_question_chars(): char_matrix_idx_1_batch, # train_graph.get_in_passage_chars(): char_matrix_idx_2_batch, } if char_vocab is not None: 
feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch if POS_vocab is not None: feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch if NER_vocab is not None: feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch if FLAGS.is_answer_selection == True: feed_dict[train_graph.get_question_count()] = trainDataStream.question_count(batch_index) feed_dict[train_graph.get_answer_count()] = trainDataStream.answer_count(batch_index) _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()], feed_dict=feed_dict) total_loss += loss_value if FLAGS.is_answer_selection == True and FLAGS.is_server == False: print ("q: {} a: {} loss_value: {}".format(trainDataStream.question_count(batch_index) ,trainDataStream.answer_count(batch_index), loss_value)) if step % 100==0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps: #print(total_loss) # Print status to stdout. duration = time.time() - start_time start_time = time.time() output_res_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) total_loss = 0.0 #Evaluate against the validation set. output_res_file.write('valid- ') my_map, my_mrr = evaluate(devDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab) output_res_file.write("map: '{}', mrr: '{}'\n".format(my_map, my_mrr)) #print ("dev map: {}".format(my_map)) #print("Current accuracy is %.2f" % accuracy) #accuracy = my_map #if accuracy>best_accuracy: # best_accuracy = accuracy # saver.save(sess, best_path) # Evaluate against the test set. output_res_file.write ('test- ') my_map, my_mrr = evaluate(testDataStream, valid_graph, sess, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab) output_res_file.write("map: '{}', mrr: '{}\n\n".format(my_map, my_mrr)) if FLAGS.is_server == False: print ("test map: {}".format(my_map)) #Evaluate against the train set only for final epoch. 
if (step + 1) == max_steps: output_res_file.write ('train- ') my_map, my_mrr = evaluate(trainDataStream, valid_graph, sess, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab) output_res_file.write("map: '{}', mrr: '{}'\n".format(my_map, my_mrr)) # print("Best accuracy on dev set is %.2f" % best_accuracy) # # decoding # print('Decoding on the test set:') # init_scale = 0.01 # with tf.Graph().as_default(): # initializer = tf.random_uniform_initializer(-init_scale, init_scale) # with tf.variable_scope("Model", reuse=False, initializer=initializer): # valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, # dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, # lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, # aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, # context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, # fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway, # word_level_MP_dim=FLAGS.word_level_MP_dim, # with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, # highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, # lex_decompsition_dim=FLAGS.lex_decompsition_dim, # with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), # with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), # with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), # with_bilinear_att=(not FLAGS.wo_bilinear_att) # , type1=FLAGS.type1, type2 = FLAGS.type2, type3=FLAGS.type3, # with_aggregation_attention=not FLAGS.wo_agg_self_att, # is_answer_selection= FLAGS.is_answer_selection, # is_shared_attention=FLAGS.is_shared_attention, # modify_loss=FLAGS.modify_loss,is_aggregation_lstm=FLAGS.is_aggregation_lstm, # max_window_size=FLAGS.max_window_size, # prediction_mode=FLAGS.prediction_mode, # context_lstm_dropout=not FLAGS.wo_lstm_drop_out, # is_aggregation_siamese=FLAGS.is_aggregation_siamese) # # vars_ = {} # for var in tf.global_variables(): # if "word_embedding" in var.name: continue # if not var.name.startswith("Model"): continue # vars_[var.name.split(":")[0]] = var # saver = tf.train.Saver(vars_) # # sess = tf.Session() # sess.run(tf.global_variables_initializer()) # step = 0 # saver.restore(sess, best_path) # # accuracy, mrr = evaluate(testDataStream, valid_graph, sess,char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab # , mode='trec') # output_res_file.write("map for test set is %.2f\n" % accuracy) output_res_file.close()
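# The answer-selection loop above reports MAP and MRR through `evaluate`, which is
# not shown here. Below is a small, self-contained sketch of those two metrics,
# assuming each question comes with candidate scores and binary relevance labels;
# the function name and input layout are illustrative only, not the project's API.
def mean_average_precision_and_mrr(per_question):
    """per_question: list of (scores, labels) pairs, one pair per question;
    labels are 1 for a correct candidate answer and 0 otherwise."""
    ap_sum, rr_sum, n = 0.0, 0.0, 0
    for scores, labels in per_question:
        # rank candidates by descending model score
        ranked = [label for _, label in
                  sorted(zip(scores, labels), key=lambda x: -x[0])]
        if not any(ranked):
            continue  # questions without any correct candidate are skipped
        n += 1
        hits, precisions = 0, []
        for rank, label in enumerate(ranked, start=1):
            if label:
                hits += 1
                precisions.append(hits / float(rank))
                if hits == 1:
                    rr_sum += 1.0 / rank  # reciprocal rank of the first hit
        ap_sum += sum(precisions) / len(precisions)
    if n == 0:
        return 0.0, 0.0
    return ap_sum / n, rr_sum / n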
def main(_):
    print('Configurations:')
    print(FLAGS)
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()
    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading train set.')
    if FLAGS.infile_format == 'fof':
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_nary_from_fof(FLAGS.train_path, FLAGS)
    else:
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_nary_file(FLAGS.train_path, FLAGS)
    random.shuffle(trainset)
    devset = trainset[:200]
    trainset = trainset[200:]
    print('Number of training samples: {}'.format(len(trainset)))
    print('Number of dev samples: {}'.format(len(devset)))
    max_node = trn_node
    max_in_neigh = trn_in_neigh
    max_out_neigh = trn_out_neigh
    max_sent = trn_sent
    print('Max node number: {}, while max allowed is {}'.format(max_node, FLAGS.max_node_num))
    print('Max parent number: {}, truncated to {}'.format(max_in_neigh, FLAGS.max_in_neigh_num))
    print('Max children number: {}, truncated to {}'.format(max_out_neigh, FLAGS.max_out_neigh_num))
    print('Max entity size: {}, truncated to {}'.format(max_sent, FLAGS.max_entity_size))

    word_vocab = None
    char_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        char_vocab = None
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2')
    else:
        print('Collecting vocabs.')
        (allWords, allChars, allEdgelabels) = G2S_data_stream.collect_vocabs(trainset)
        print('Number of words: {}'.format(len(allWords)))
        print('Number of allChars: {}'.format(len(allChars)))
        print('Number of allEdgelabels: {}'.format(len(allEdgelabels)))
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        char_vocab = None
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        edgelabel_vocab = Vocab(voc=allEdgelabels, dim=FLAGS.edgelabel_dim, fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")
    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(trainset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS,
                                                    isShuffle=True, isLoop=True, isSort=False)
    devDataStream = G2S_data_stream.G2SDataStream(devset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS,
                                                  isShuffle=False, isLoop=False, isSort=False)
    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    sys.stdout.flush()

    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = ModelGraph(word_vocab=word_vocab, Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab, options=FLAGS, mode='train')
        with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab, Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab, options=FLAGS, mode='evaluate')

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.all_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")
            if abs(best_accu) < 0.00001:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        last_step = 0
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            _, loss_value, _ = train_graph.execute(sess, cur_batch, FLAGS, is_train=True)
            total_loss += loss_value
            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss / (step - last_step), duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss / (step - last_step), duration))
                sys.stdout.flush()
                log_file.flush()
                last_step = step
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS, suffix=str(step))
                dev_loss = res_dict['dev_loss']
                dev_accu = res_dict['dev_accu']
                dev_right = int(res_dict['dev_right'])
                dev_total = int(res_dict['dev_total'])
                print('Dev loss = %.4f' % dev_loss)
                log_file.write('Dev loss = %.4f\n' % dev_loss)
                print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total))
                log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total))
                log_file.flush()
                if best_accu < dev_accu:
                    print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu))
                    saver.save(sess, best_path)
                    best_accu = dev_accu
                    FLAGS.best_accu = dev_accu
                    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                    json.dump(res_dict['data'], open(FLAGS.output_path, 'w'))
                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()
                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()

    log_file.close()
def main(FLAGS):
    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    dev_path_target = FLAGS.dev_path_target
    test_path_target = FLAGS.test_path_target
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
        os.makedirs(os.path.join(log_dir, '../result_source'))
        os.makedirs(os.path.join(log_dir, '../logits_source'))
        os.makedirs(os.path.join(log_dir, '../result_target'))
        os.makedirs(os.path.join(log_dir, '../logits_target'))
    log_dir_target = FLAGS.model_dir + '_target'
    if not os.path.exists(log_dir_target):
        os.makedirs(log_dir_target)

    path_prefix = log_dir + "/ESIM.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
    path_prefix_target = log_dir_target + "/ESIM.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix_target + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    best_path_target = path_prefix_target + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    has_pre_trained_model = False
    char_vocab = None
    # if os.path.exists(best_path + ".index"):
    print('Collecting words, chars and labels ...')
    (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path)
    print('Number of words: {}'.format(len(all_words)))
    label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
    label_vocab.dump_to_txt2(label_path)
    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    num_classes = label_vocab.size()
    print("Number of labels: {}".format(num_classes))
    sys.stdout.flush()

    print('Build SentenceMatchDataStream ... ')
    trainDataStream = DataStream(train_path, word_vocab=word_vocab, label_vocab=None,
                                 isShuffle=True, isLoop=True, isSort=True, options=FLAGS)
    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    sys.stdout.flush()
    devDataStream = DataStream(dev_path, word_vocab=word_vocab, label_vocab=None,
                               isShuffle=True, isLoop=True, isSort=True, options=FLAGS)
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    sys.stdout.flush()
    testDataStream = DataStream(test_path, word_vocab=word_vocab, label_vocab=None,
                                isShuffle=True, isLoop=True, isSort=True, options=FLAGS)
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    sys.stdout.flush()
    devDataStream_target = DataStream(dev_path_target, word_vocab=word_vocab, label_vocab=None,
                                      isShuffle=True, isLoop=True, isSort=True, options=FLAGS)
    print('Number of instances in devDataStream_target: {}'.format(devDataStream_target.get_num_instance()))
    print('Number of batches in devDataStream_target: {}'.format(devDataStream_target.get_num_batch()))
    sys.stdout.flush()
    testDataStream_target = DataStream(test_path_target, word_vocab=word_vocab, label_vocab=None,
                                       isShuffle=True, isLoop=True, isSort=True, options=FLAGS)
    print('Number of instances in testDataStream_target: {}'.format(testDataStream_target.get_num_instance()))
    print('Number of batches in testDataStream_target: {}'.format(testDataStream_target.get_num_batch()))
    sys.stdout.flush()

    with tf.Graph().as_default():
        initializer = tf.contrib.layers.xavier_initializer()
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = Model(num_classes, word_vocab=word_vocab, is_training=True,
                                options=FLAGS, global_step=global_step)
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = Model(num_classes, word_vocab=word_vocab, is_training=False, options=FLAGS)

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
            # if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1)
        config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(initializer)
            # training
            train(sess, saver, train_graph, valid_graph, trainDataStream, devDataStream,
                  testDataStream, devDataStream_target, testDataStream_target,
                  FLAGS, best_path, best_path_target)
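# All of the mains in this file build a training graph and an evaluation graph inside
# the same "Model" variable scope, with reuse=True on the second construction so both
# graphs share one set of weights. A stripped-down TF1-style sketch of that pattern;
# the tiny linear model below is illustrative only, not the ESIM/BiMPM model used here.
import tensorflow as tf

def build_shared_graphs(input_dim, num_classes):
    def logits(x):
        # a single weight matrix, created on the first call and reused on the second
        w = tf.get_variable("w", [input_dim, num_classes])
        return tf.matmul(x, w)

    initializer = tf.random_uniform_initializer(-0.01, 0.01)
    x_train = tf.placeholder(tf.float32, [None, input_dim], name="x_train")
    x_valid = tf.placeholder(tf.float32, [None, input_dim], name="x_valid")
    with tf.variable_scope("Model", reuse=None, initializer=initializer):
        train_logits = logits(x_train)   # creates Model/w
    with tf.variable_scope("Model", reuse=True, initializer=initializer):
        valid_logits = logits(x_valid)   # reuses Model/w, no new variables
    return train_logits, valid_logits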
def main(_): print('Configurations:') print(FLAGS) train_path = FLAGS.train_path dev_path = FLAGS.dev_path test_path = FLAGS.test_path word_vec_path = FLAGS.word_vec_path log_dir = FLAGS.model_dir tolower = FLAGS.use_lower_letter FLAGS.rl_matches = json.loads(FLAGS.rl_matches) # if not os.path.exists(log_dir): # os.makedirs(log_dir) path_prefix = log_dir + "/TriMatch.{}".format(FLAGS.suffix) namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") # build vocabs word_vocab = Vocab(word_vec_path, fileformat='txt3', tolower=tolower) best_path = path_prefix + '.best.model' char_path = path_prefix + ".char_vocab" label_path = path_prefix + ".label_vocab" POS_path = path_prefix + ".POS_vocab" NER_path = path_prefix + ".NER_vocab" has_pre_trained_model = False POS_vocab = None NER_vocab = None print('best path:', best_path) if os.path.exists(best_path + '.data-00000-of-00001') and not (FLAGS.create_new_model): print('Using pretrained model') has_pre_trained_model = True label_vocab = Vocab(label_path, fileformat='txt2', tolower=tolower) char_vocab = Vocab(char_path, fileformat='txt2', tolower=tolower) if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2', tolower=tolower) if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2', tolower=tolower) else: print('Creating new model') print('Collect words, chars and labels ...') (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER, tolower=tolower) if FLAGS.use_options: all_labels = ['0', '1'] print('Number of words: {}'.format(len(all_words))) print('Number of labels: {}'.format(len(all_labels))) # for word in all_labels: # print('label',word) # input('check') label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2, tolower=tolower) label_vocab.dump_to_txt2(label_path) print('Number of chars: {}'.format(len(all_chars))) char_vocab = Vocab(fileformat='voc', voc=all_chars, dim=FLAGS.char_emb_dim, tolower=tolower) char_vocab.dump_to_txt2(char_path) if FLAGS.with_POS: print('Number of POSs: {}'.format(len(all_POSs))) POS_vocab = Vocab(fileformat='voc', voc=all_POSs, dim=FLAGS.POS_dim, tolower=tolower) POS_vocab.dump_to_txt2(POS_path) if FLAGS.with_NER: print('Number of NERs: {}'.format(len(all_NERs))) NER_vocab = Vocab(fileformat='voc', voc=all_NERs, dim=FLAGS.NER_dim, tolower=tolower) NER_vocab.dump_to_txt2(NER_path) print('all_labels:', label_vocab) print('has pretrained model:', has_pre_trained_model) # for word in word_vocab.word_vecs: print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape)) print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape)) num_classes = label_vocab.size() print('Build TriMatchDataStream ... 
') gen_concat_mat = False gen_split_mat = False if FLAGS.matching_option == 7: gen_concat_mat = True if FLAGS.concat_context: gen_split_mat = True trainDataStream = TriMatchDataStream( train_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=(not FLAGS.wo_sort_instance_based_on_length), max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, max_hyp_length=FLAGS.max_hyp_length, max_choice_length=FLAGS.max_choice_length, tolower=tolower, gen_concat_mat=gen_concat_mat, gen_split_mat=gen_split_mat) devDataStream = TriMatchDataStream( dev_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=(not FLAGS.wo_sort_instance_based_on_length), max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, max_hyp_length=FLAGS.max_hyp_length, max_choice_length=FLAGS.max_choice_length, tolower=tolower, gen_concat_mat=gen_concat_mat, gen_split_mat=gen_split_mat) testDataStream = TriMatchDataStream( test_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=(not FLAGS.wo_sort_instance_based_on_length), max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, max_hyp_length=FLAGS.max_hyp_length, max_choice_length=FLAGS.max_choice_length, tolower=tolower, gen_concat_mat=gen_concat_mat, gen_split_mat=gen_split_mat) print('Number of instances in trainDataStream: {}'.format( trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format( devDataStream.get_num_instance())) print('Number of instances in testDataStream: {}'.format( testDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format( trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format( devDataStream.get_num_batch())) print('Number of batches in testDataStream: {}'.format( testDataStream.get_num_batch())) sys.stdout.flush() if FLAGS.wo_char: char_vocab = None best_accuracy = 0.0 init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) # with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = TriMatchModelGraph( num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, match_to_question=FLAGS.match_to_question, match_to_passage=FLAGS.match_to_passage, match_to_choice=FLAGS.match_to_choice, with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not 
FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), use_options=FLAGS.use_options, num_options=num_options, with_no_match=FLAGS.with_no_match, verbose=FLAGS.verbose, matching_option=FLAGS.matching_option, concat_context=FLAGS.concat_context, tied_aggre=FLAGS.tied_aggre, rl_training_method=FLAGS.rl_training_method, rl_matches=FLAGS.rl_matches) tf.summary.scalar("Training Loss", train_graph.get_loss() ) # Add a scalar summary for the snapshot loss. if FLAGS.verbose: valid_graph = train_graph else: # with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = TriMatchModelGraph( num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, match_to_question=FLAGS.match_to_question, match_to_passage=FLAGS.match_to_passage, match_to_choice=FLAGS.match_to_choice, with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=( not FLAGS.wo_max_attentive_match), use_options=FLAGS.use_options, num_options=num_options, with_no_match=FLAGS.with_no_match, matching_option=FLAGS.matching_option, concat_context=FLAGS.concat_context, tied_aggre=FLAGS.tied_aggre, rl_training_method=FLAGS.rl_training_method, rl_matches=FLAGS.rl_matches) initializer = tf.global_variables_initializer() vars_ = {} for var in tf.global_variables(): # print(var.name,var.get_shape().as_list()) if "word_embedding" in var.name: continue # if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) # input('check') config = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.Session(config=config) sess.run(initializer) if has_pre_trained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() sub_loss_counter = 0.0 for step in range(max_steps): # read data cur_batch = trainDataStream.nextBatch() (label_batch, sent1_batch, sent2_batch, sent3_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch, word_idx_3_batch, char_matrix_idx_1_batch, char_matrix_idx_2_batch, char_matrix_idx_3_batch, sent1_length_batch, sent2_length_batch, sent3_length_batch, sent1_char_length_batch, sent2_char_length_batch, sent3_char_length_batch, POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch, concat_mat_batch, split_mat_batch_q, split_mat_batch_c) = cur_batch # print(label_id_batch) if FLAGS.verbose: print(label_id_batch) print(sent1_length_batch) print(sent2_length_batch) print(sent3_length_batch) # print(word_idx_1_batch) # print(word_idx_2_batch) # print(word_idx_3_batch) # print(sent1_batch) # 
print(sent2_batch) # print(sent3_batch) print(concat_mat_batch) input('check') feed_dict = { train_graph.get_truth(): label_id_batch, train_graph.get_passage_lengths(): sent1_length_batch, train_graph.get_question_lengths(): sent2_length_batch, train_graph.get_choice_lengths(): sent3_length_batch, train_graph.get_in_passage_words(): word_idx_1_batch, train_graph.get_in_question_words(): word_idx_2_batch, train_graph.get_in_choice_words(): word_idx_3_batch, # train_graph.get_question_char_lengths(): sent1_char_length_batch, # train_graph.get_passage_char_lengths(): sent2_char_length_batch, # train_graph.get_in_question_chars(): char_matrix_idx_1_batch, # train_graph.get_in_passage_chars(): char_matrix_idx_2_batch, } if char_vocab is not None: feed_dict[train_graph.get_passage_char_lengths( )] = sent1_char_length_batch feed_dict[train_graph.get_question_char_lengths( )] = sent2_char_length_batch feed_dict[train_graph.get_choice_char_lengths( )] = sent3_char_length_batch feed_dict[train_graph.get_in_passage_chars( )] = char_matrix_idx_1_batch feed_dict[train_graph.get_in_question_chars( )] = char_matrix_idx_2_batch feed_dict[train_graph.get_in_choice_chars( )] = char_matrix_idx_3_batch if POS_vocab is not None: feed_dict[train_graph.get_in_passage_poss()] = POS_idx_1_batch feed_dict[train_graph.get_in_question_poss()] = POS_idx_2_batch if NER_vocab is not None: feed_dict[train_graph.get_in_passage_ners()] = NER_idx_1_batch feed_dict[train_graph.get_in_question_ners()] = NER_idx_2_batch if concat_mat_batch is not None: feed_dict[train_graph.concat_idx_mat] = concat_mat_batch if split_mat_batch_q is not None: feed_dict[train_graph.split_idx_mat_q] = split_mat_batch_q feed_dict[train_graph.split_idx_mat_c] = split_mat_batch_c if FLAGS.verbose: return_list = sess.run([ train_graph.get_train_op(), train_graph.get_loss(), train_graph.get_predictions(), train_graph.get_prob(), train_graph.all_probs, train_graph.correct ] + train_graph.matching_vectors, feed_dict=feed_dict) print(len(return_list)) _, loss_value, pred, prob, all_probs, correct = return_list[ 0:6] print('pred=', pred) print('prob=', prob) print('logits=', all_probs) print('correct=', correct) for val in return_list[6:]: if isinstance(val, list): print('list len ', len(val)) for objj in val: print('this shape=', val.shape) print('this shape=', val.shape) # print(val) input('check') else: _, loss_value = sess.run( [train_graph.get_train_op(), train_graph.get_loss()], feed_dict=feed_dict) total_loss += loss_value sub_loss_counter += loss_value if step % int(FLAGS.display_every) == 0: print('{},{} '.format(step, sub_loss_counter), end="") sys.stdout.flush() sub_loss_counter = 0.0 # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or ( step + 1) == max_steps: print() # Print status to stdout. duration = time.time() - start_time start_time = time.time() print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) total_loss = 0.0 # Evaluate against the validation set. 
print('Validation Data Eval:') if FLAGS.predict_val: outpath = path_prefix + '.iter%d' % (step) + '.probs' else: outpath = None accuracy = evaluate(devDataStream, valid_graph, sess, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, use_options=FLAGS.use_options, outpath=outpath, mode='prob') print("Current accuracy on dev set is %.2f" % accuracy) if accuracy >= best_accuracy: best_accuracy = accuracy saver.save(sess, best_path) print('saving the current model.') accuracy = evaluate(testDataStream, valid_graph, sess, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, use_options=FLAGS.use_options, outpath=outpath, mode='prob') print("Current accuracy on test set is %.2f" % accuracy) print("Best accuracy on dev set is %.2f" % best_accuracy) # decoding print('Decoding on the test set:') init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.variable_scope("Model", reuse=False, initializer=initializer): valid_graph = TriMatchModelGraph( num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, match_to_question=FLAGS.match_to_question, match_to_passage=FLAGS.match_to_passage, match_to_choice=FLAGS.match_to_choice, with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), use_options=FLAGS.use_options, num_options=num_options, with_no_match=FLAGS.with_no_match, matching_option=FLAGS.matching_option, concat_context=FLAGS.concat_context, tied_aggre=FLAGS.tied_aggre, rl_training_method=FLAGS.rl_training_method, rl_matches=FLAGS.rl_matches) vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(tf.global_variables_initializer()) step = 0 saver.restore(sess, best_path) accuracy = evaluate(testDataStream, valid_graph, sess, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, use_options=FLAGS.use_options) print("Accuracy for test set is %.2f" % accuracy)
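# The TriMatch script above builds its tf.train.Saver over a filtered variable dict that skips
# any "word_embedding" variable, so the (large, fixed) embedding matrix is not written into the
# checkpoint and is reloaded from the vector file instead. Below is a minimal, self-contained
# sketch of that selective-checkpointing pattern, assuming TensorFlow 1.x; the toy variable
# names and the /tmp checkpoint path are illustrative only, not taken from the model code.
import tensorflow as tf

def build_selective_saver():
    # Mirror the filtering loop used in the training script: drop embeddings from the checkpoint.
    vars_ = {}
    for var in tf.global_variables():
        if "word_embedding" in var.name:
            continue
        vars_[var.name.split(":")[0]] = var
    return tf.train.Saver(vars_)

def demo_selective_saver(ckpt_path="/tmp/selective.ckpt"):
    with tf.Graph().as_default():
        tf.get_variable("word_embedding", shape=[10000, 300])  # excluded from the checkpoint
        tf.get_variable("dense_kernel", shape=[300, 2])        # included in the checkpoint
        saver = build_selective_saver()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            saver.save(sess, ckpt_path)
            saver.restore(sess, ckpt_path)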
def main(_): if FLAGS.is_shuffle == 'True': FLAGS.is_shuffle = True else: FLAGS.is_shuffle = False print('Configuration') log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix) namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") output_res_file = open('../result/' + FLAGS.run_id, 'wt') while (Get_Next_box_size(FLAGS.start_batch) == True): output_res_file.write('Q' + str(FLAGS) + '\n') print('Q' + str(FLAGS)) train_path = FLAGS.train_path dev_path = FLAGS.dev_path test_path = FLAGS.test_path best_path = path_prefix + '.best.model' #zero_pad = True zero_pad = False if FLAGS.prediction_mode == 'list_wise' and FLAGS.loss_type == 'list_mle': zero_pad = True trainDataStream = SentenceMatchDataStream(train_path, is_training=True, isShuffle=FLAGS.is_shuffle, isLoop=True, isSort=True, zero_pad=zero_pad, is_ndcg=FLAGS.is_ndcg) #isShuggle must be true because it dtermines we are training or not. #train_testDataStream = SentenceMatchDataStream(train_path, isShuffle=False, isLoop=True, isSort=True) testDataStream = SentenceMatchDataStream(test_path, is_training=False, isShuffle=False, isLoop=True, isSort=True, is_ndcg=FLAGS.is_ndcg) devDataStream = SentenceMatchDataStream(dev_path, is_training=False, isShuffle=False, isLoop=True, isSort=True, is_ndcg=FLAGS.is_ndcg) print('Number of instances in trainDataStream: {}'.format( trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format( devDataStream.get_num_instance())) print('Number of instances in testDataStream: {}'.format( testDataStream.get_num_instance())) # print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch())) # print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch())) # print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch())) sys.stdout.flush() output_res_index = 1 # best_test_acc = 0 max_test_ndcg = np.zeros(10) max_valid = np.zeros(10) max_test = np.zeros(10) # max_dev_ndcg = 0 while output_res_index <= FLAGS.iter_count: # st_cuda = '' ssst = FLAGS.run_id ssst += str(FLAGS.start_batch) # output_res_file = open('../result/' + ssst + '.'+ st_cuda + str(output_res_index), 'wt') # output_sentence_file = open('../result/' + ssst + '.'+ st_cuda + str(output_res_index) + "S", 'wt') # output_train_file = open('../result/' + ssst + '.'+ st_cuda + str(output_res_index) + "T", 'wt') # output_sentences = [] output_res_index += 1 # output_res_file.write(str(FLAGS) + '\n\n') # stt = str (FLAGS) # best_dev_acc = 0.0 init_scale = 0.001 with tf.Graph().as_default(): # tf.set_random_seed(0) # np.random.seed(123) input_dim = 136 if train_path.find("2008") > 0: input_dim = 46 initializer = tf.random_uniform_initializer( -init_scale, init_scale) with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = SentenceMatchModelGraph( num_classes=3, is_training=True, learning_rate=FLAGS.learning_rate, lambda_l2=FLAGS.lambda_l2, prediction_mode=FLAGS.prediction_mode, q_count=FLAGS.question_count_per_batch, loss_type=FLAGS.loss_type, pos_avg=FLAGS.pos_avg, input_dim=input_dim) tf.summary.scalar("Training Loss", train_graph.get_loss( )) # Add a scalar summary for the snapshot loss. 
# with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = SentenceMatchModelGraph( num_classes=3, is_training=True, learning_rate=FLAGS.learning_rate, lambda_l2=FLAGS.lambda_l2, prediction_mode=FLAGS.prediction_mode, q_count=1, loss_type=FLAGS.loss_type, pos_avg=FLAGS.pos_avg, input_dim=input_dim) # tf.set_random_seed(123) # np.random.seed(123) initializer = tf.global_variables_initializer() # tf.set_random_seed(123) # np.random.seed(123) vars_ = {} #for var in tf.all_variables(): for var in tf.global_variables(): vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) max_valid_iter = np.zeros(10) max_test_ndcg_iter = np.zeros(10) with tf.Session() as sess: # tf.set_random_seed(123) # np.random.seed(123) sess.run(initializer) train_size = trainDataStream.get_num_batch() max_steps = (train_size * FLAGS.max_epochs ) // FLAGS.question_count_per_batch epoch_size = max_steps // (FLAGS.max_epochs) + 1 total_loss = 0.0 start_time = time.time() for step in range(max_steps): # read data _truth = [] _input_vector = [] _mask = [] for i in range(FLAGS.question_count_per_batch): cur_batch, batch_index = trainDataStream.nextBatch( ) (label_id_batch, input_vector_batch, mask_batch) = cur_batch if FLAGS.prediction_mode == 'list_wise' and FLAGS.loss_type == 'list_mle': label_id_batch, input_vector_batch = sort_mle( label_id_batch, input_vector_batch) _truth.append(label_id_batch) _input_vector.append(input_vector_batch) _mask.append(mask_batch) #print (_truth) feed_dict = { train_graph.get_truth(): tuple(_truth), train_graph.get_input_vector(): tuple(_input_vector), train_graph.get_mask(): tuple(_mask) } _, loss_value = sess.run([ train_graph.get_train_op(), train_graph.get_loss() ], feed_dict=feed_dict) #print (loss_value) #print (sess.run([train_graph.truth, train_graph.soft_truth], feed_dict=feed_dict)) #loss_value = sess.run([train_graph.logits1], feed_dict=feed_dict) import math if math.isnan(loss_value): print(step) print( sess.run([ train_graph.truth, train_graph.mask, train_graph.mask2, train_graph.mask01 ], feed_dict=feed_dict)) total_loss += loss_value if (step + 1) % epoch_size == 0 or (step + 1) == max_steps: if (step + 1) == max_steps: print(total_loss) # duration = time.time() - start_time # start_time = time.time() # total_loss = 0.0 for ndcg_ind in range(10): v_map = evaluate(devDataStream, valid_graph, sess, is_ndcg=FLAGS.is_ndcg, top_k=ndcg_ind) if v_map > max_valid[ndcg_ind]: max_valid[ndcg_ind] = v_map flag_valid = False if v_map > max_valid_iter[ndcg_ind]: max_valid_iter[ndcg_ind] = v_map flag_valid = True te_map = evaluate(testDataStream, valid_graph, sess, is_ndcg=FLAGS.is_ndcg, flag_valid=flag_valid, top_k=ndcg_ind) if te_map > max_test[ndcg_ind]: max_test[ndcg_ind] = te_map if flag_valid == True: # if te_map > max_test_ndcg[ndcg_ind] and FLAGS.store_best == True: # #best_test_acc = my_map # saver.save(sess, best_path) # if te_map > max_test_ndcg[ndcg_ind]: # max_test_ndcg[ndcg_ind] = te_map # if te_map > max_test_ndcg_iter[ndcg_ind]: # max_test_ndcg_iter[ndcg_ind] = te_map max_test_ndcg_iter[ndcg_ind] = te_map #print ("{} - {}".format(v_map, my_map)) for ndcg_ind in range(10): if max_test_ndcg_iter[ndcg_ind] > max_test_ndcg[ ndcg_ind]: max_test_ndcg[ndcg_ind] = max_test_ndcg_iter[ ndcg_ind] #print (total_loss) print("{}-{}: {}".format(FLAGS.start_batch, output_res_index - 1, max_test_ndcg_iter)) output_res_file.write("{}-{}: {}\n".format(FLAGS.start_batch, output_res_index - 1, max_test_ndcg_iter)) print("*{}-{}: 
{}-{}-{}".format(FLAGS.fold, FLAGS.start_batch, max_valid, max_test, max_test_ndcg)) output_res_file.write("{}-{}: {}-{}-{}\n".format( FLAGS.fold, FLAGS.start_batch, max_valid, max_test, max_test_ndcg)) FLAGS.start_batch += FLAGS.step_batch output_res_file.close()
def main(_): print('Configurations:') print(FLAGS) log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix) log_file_path = path_prefix + ".log" print('Log file path: {}'.format(log_file_path)) log_file = open(log_file_path, 'wt') log_file.write("{}\n".format(FLAGS)) log_file.flush() # save configuration namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") word_vocab_enc = None word_vocab_dec = None char_vocab = None edgelabel_vocab = None has_pretrained_model = False best_path = path_prefix + ".best.model" if os.path.exists(best_path + ".index"): has_pretrained_model = True print('!!Existing pretrained model. Loading vocabs.') word_vocab_enc = Vocab(FLAGS.word_vec_src_path, fileformat='txt2') print('word_vocab SRC: {}'.format(word_vocab_enc.word_vecs.shape)) word_vocab_dec = Vocab(FLAGS.word_vec_tgt_path, fileformat='txt2') print('word_vocab TGT: {}'.format(word_vocab_dec.word_vecs.shape)) if FLAGS.with_char: char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2') print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2') else: print('Collecting vocabs.') word_vocab_enc = Vocab(FLAGS.word_vec_src_path, fileformat='txt2') word_vocab_dec = Vocab(FLAGS.word_vec_tgt_path, fileformat='txt2') if FLAGS.with_char: char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build') char_vocab.dump_to_txt2(path_prefix + ".char_vocab") allEdgelabels = set([line.strip().split()[0] \ for line in open(FLAGS.edgelabel_vocab, 'rU')]) edgelabel_vocab = Vocab(voc=allEdgelabels, dim=FLAGS.edgelabel_dim, fileformat='build') edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab") print('word vocab SRC size {}'.format(word_vocab_enc.vocab_size)) print('word vocab TGT size {}'.format(word_vocab_dec.vocab_size)) sys.stdout.flush() print('Loading train set.') if FLAGS.infile_format == 'fof': trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_from_fof(FLAGS.train_path, FLAGS, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab) else: trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_file(FLAGS.train_path, FLAGS, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab) print('Number of training samples: {}'.format(len(trainset))) print('Loading test set.') if FLAGS.infile_format == 'fof': testset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_from_fof(FLAGS.test_path, FLAGS, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab) else: testset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_file(FLAGS.test_path, FLAGS, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab) print('Number of test samples: {}'.format(len(testset))) max_node = max(trn_node, tst_node) max_in_neigh = max(trn_in_neigh, tst_in_neigh) max_out_neigh = max(trn_out_neigh, tst_out_neigh) max_sent = max(trn_sent, tst_sent) print('Max node number: {}, while max allowed is {}'.format(max_node, FLAGS.max_node_num)) print('Max parent number: {}, truncated to {}'.format(max_in_neigh, FLAGS.max_in_neigh_num)) print('Max children number: {}, truncated to {}'.format(max_out_neigh, FLAGS.max_out_neigh_num)) print('Max answer length: {}, truncated to {}'.format(max_sent, FLAGS.max_answer_len)) print('Build DataStream ... 
') trainDataStream = G2S_data_stream.G2SDataStream(trainset, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab, options=FLAGS, isShuffle=True, isLoop=True, isSort=True) devDataStream = G2S_data_stream.G2SDataStream(testset, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab, options=FLAGS, isShuffle=False, isLoop=False, isSort=True) print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch())) sys.stdout.flush() # initialize the best bleu and accu scores for current training session best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0 best_bleu = FLAGS.best_bleu if FLAGS.__dict__.has_key('best_bleu') else 0.0 if best_accu > 0.0: print('With initial dev accuracy {}'.format(best_accu)) if best_bleu > 0.0: print('With initial dev BLEU score {}'.format(best_bleu)) init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = ModelGraph(word_vocab_enc=word_vocab_enc, word_vocab_dec=word_vocab_dec, Edgelabel_vocab=edgelabel_vocab, char_vocab=char_vocab, options=FLAGS, mode=FLAGS.mode) assert FLAGS.mode in ('ce_train', 'rl_train', ) valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu' with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = ModelGraph(word_vocab_enc=word_vocab_enc, word_vocab_dec=word_vocab_dec, Edgelabel_vocab=edgelabel_vocab, char_vocab=char_vocab, options=FLAGS, mode=valid_mode) initializer = tf.global_variables_initializer() for var in tf.trainable_variables(): print(var) vars_ = {} for var in tf.all_variables(): if FLAGS.fix_word_vec and "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) sess.run(initializer) if has_pretrained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001: print("Getting BLEU score for the model") sys.stdout.flush() best_bleu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_bleu'] FLAGS.best_bleu = best_bleu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('BLEU = %.4f' % best_bleu) sys.stdout.flush() log_file.write('BLEU = %.4f\n' % best_bleu) if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001: print("Getting ACCU score for the model") best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu'] FLAGS.best_accu = best_accu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('ACCU = %.4f' % best_accu) log_file.write('ACCU = %.4f\n' % best_accu) print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): cur_batch = trainDataStream.nextBatch() cur_batch = G2S_data_stream.G2SBatchPadd(cur_batch) if FLAGS.mode == 'rl_train': loss_value = train_graph.run_rl_training_subsample(sess, cur_batch, 
FLAGS) elif FLAGS.mode == 'ce_train': loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS) total_loss += loss_value if step % 100==0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or (step != 0 and step%2000 == 0): print() duration = time.time() - start_time print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) log_file.flush() sys.stdout.flush() total_loss = 0.0 # Evaluate against the validation set. start_time = time.time() print('Validation Data Eval:') res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS, suffix=str(step)) if valid_graph.mode == 'evaluate': dev_loss = res_dict['dev_loss'] dev_accu = res_dict['dev_accu'] dev_right = int(res_dict['dev_right']) dev_total = int(res_dict['dev_total']) print('Dev loss = %.4f' % dev_loss) log_file.write('Dev loss = %.4f\n' % dev_loss) print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total)) log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total)) log_file.flush() if best_accu < dev_accu: print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu)) saver.save(sess, best_path) best_accu = dev_accu FLAGS.best_accu = dev_accu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") else: dev_bleu = res_dict['dev_bleu'] print('Dev bleu = %.4f' % dev_bleu) log_file.write('Dev bleu = %.4f\n' % dev_bleu) log_file.flush() if best_bleu < dev_bleu: print('Saving weights, BLEU {} (prev_best) < {} (cur)'.format(best_bleu, dev_bleu)) saver.save(sess, best_path) best_bleu = dev_bleu FLAGS.best_bleu = dev_bleu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") duration = time.time() - start_time print('Duration %.3f sec' % (duration)) sys.stdout.flush() log_file.write('Duration %.3f sec\n' % (duration)) log_file.flush() log_file.close()
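# The G2S loop above saves the model only when the dev metric (accuracy or BLEU) improves, and
# writes the new best value back into the .config.json so a resumed run starts from it. A minimal
# sketch of that "save on improvement" pattern; save_checkpoint is a hypothetical callable standing
# in for saver.save / namespace_utils.save_namespace.
def maybe_save_best(cur_metric, best_metric, save_checkpoint, log_file=None):
    """Return the (possibly updated) best metric, checkpointing when it improves."""
    if best_metric < cur_metric:
        msg = 'Saving weights, metric {} (prev_best) < {} (cur)'.format(best_metric, cur_metric)
        print(msg)
        if log_file is not None:
            log_file.write(msg + '\n')
            log_file.flush()
        save_checkpoint()          # e.g. lambda: saver.save(sess, best_path)
        best_metric = cur_metric   # caller also records the new best in the saved config
    return best_metric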
def main(_): print('Configurations:') print(FLAGS) train_path = FLAGS.train_path dev_path = FLAGS.dev_path test_path = FLAGS.test_path word_vec_path = FLAGS.word_vec_path log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix) namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") # build vocabs word_vocab = Vocab(word_vec_path, fileformat='txt3') best_path = path_prefix + '.best.model' char_path = path_prefix + ".char_vocab" label_path = path_prefix + ".label_vocab" POS_path = path_prefix + ".POS_vocab" NER_path = path_prefix + ".NER_vocab" has_pre_trained_model = False POS_vocab = None NER_vocab = None if os.path.exists(best_path): has_pre_trained_model = True label_vocab = Vocab(label_path, fileformat='txt2') char_vocab = Vocab(char_path, fileformat='txt2') if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2') if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2') else: print('Collect words, chars and labels ...') (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER) print('Number of words: {}'.format(len(all_words))) print('Number of labels: {}'.format(len(all_labels))) label_vocab = Vocab(fileformat='voc', voc=all_labels,dim=2) label_vocab.dump_to_txt2(label_path) print('Number of chars: {}'.format(len(all_chars))) char_vocab = Vocab(fileformat='voc', voc=all_chars,dim=FLAGS.char_emb_dim) char_vocab.dump_to_txt2(char_path) if FLAGS.with_POS: print('Number of POSs: {}'.format(len(all_POSs))) POS_vocab = Vocab(fileformat='voc', voc=all_POSs,dim=FLAGS.POS_dim) POS_vocab.dump_to_txt2(POS_path) if FLAGS.with_NER: print('Number of NERs: {}'.format(len(all_NERs))) NER_vocab = Vocab(fileformat='voc', voc=all_NERs,dim=FLAGS.NER_dim) NER_vocab.dump_to_txt2(NER_path) print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape)) print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape)) num_classes = label_vocab.size() print('Build SentenceMatchDataStream ... 
') trainDataStream = SentenceMatchDataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length) devDataStream = SentenceMatchDataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length) testDataStream = SentenceMatchDataStream(test_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length) print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance())) print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch())) print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch())) sys.stdout.flush() if FLAGS.wo_char: char_vocab = None best_accuracy = 0.0 init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) # with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num,with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match)) tf.summary.scalar("Training Loss", train_graph.get_loss()) # Add a scalar summary for the snapshot loss. 
# with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match)) initializer = tf.global_variables_initializer() vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue # if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(initializer) if has_pre_trained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): # read data cur_batch = trainDataStream.nextBatch() (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch, char_matrix_idx_1_batch, char_matrix_idx_2_batch, sent1_length_batch, sent2_length_batch, sent1_char_length_batch, sent2_char_length_batch, POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch) = cur_batch feed_dict = { train_graph.get_truth(): label_id_batch, train_graph.get_question_lengths(): sent1_length_batch, train_graph.get_passage_lengths(): sent2_length_batch, train_graph.get_in_question_words(): word_idx_1_batch, train_graph.get_in_passage_words(): word_idx_2_batch, # train_graph.get_question_char_lengths(): sent1_char_length_batch, # train_graph.get_passage_char_lengths(): sent2_char_length_batch, # train_graph.get_in_question_chars(): char_matrix_idx_1_batch, # train_graph.get_in_passage_chars(): char_matrix_idx_2_batch, } if char_vocab is not None: feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch if POS_vocab is not None: feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch if NER_vocab is not None: feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()], feed_dict=feed_dict) 
total_loss += loss_value if step % 100==0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps: print() # Print status to stdout. duration = time.time() - start_time start_time = time.time() print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) total_loss = 0.0 # Evaluate against the validation set. print('Validation Data Eval:') accuracy = evaluate(devDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab) print("Current accuracy is %.2f" % accuracy) if accuracy>best_accuracy: best_accuracy = accuracy saver.save(sess, best_path) print("Best accuracy on dev set is %.2f" % best_accuracy) # decoding print('Decoding on the test set:') init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.variable_scope("Model", reuse=False, initializer=initializer): valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match)) vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(tf.global_variables_initializer()) step = 0 saver.restore(sess, best_path) accuracy = evaluate(testDataStream, valid_graph, sess,char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab) print("Accuracy for test set is %.2f" % accuracy)
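# The BiMPM training loop above feeds character / POS / NER tensors only when the corresponding
# vocab was built (char_vocab, POS_vocab and NER_vocab may each be None). This is a self-contained
# sketch of that conditional feed_dict pattern, assuming TensorFlow 1.x; the placeholder names and
# dummy shapes are illustrative only.
import numpy as np
import tensorflow as tf

def conditional_feed_demo(use_chars=True):
    with tf.Graph().as_default():
        word_idx = tf.placeholder(tf.int32, [None, None], name="word_idx")
        char_idx = tf.placeholder(tf.int32, [None, None, None], name="char_idx")
        word_sum = tf.reduce_sum(word_idx)  # stands in for the real model's loss/train op

        feed_dict = {word_idx: np.zeros([2, 5], dtype=np.int32)}
        if use_chars:  # mirrors: if char_vocab is not None: feed_dict[...] = ...
            feed_dict[char_idx] = np.zeros([2, 5, 4], dtype=np.int32)

        with tf.Session() as sess:
            return sess.run(word_sum, feed_dict=feed_dict)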
def main(_): log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix) log_file_path = path_prefix + ".log" print('Log file path: {}'.format(log_file_path)) log_file = open(log_file_path, 'wt') log_file.write("{}\n".format(FLAGS)) log_file.flush() # save configuration namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('Loading data.') FLAGS.num_relations = 2 trainset = G2S_data_stream.read_bionlp_file(FLAGS.train_path, FLAGS.train_dep_path, FLAGS) if FLAGS.dev_gen == 'shuffle': random.shuffle(trainset) elif FLAGS.dev_gen == 'last': trainset.reverse() N = int(len(trainset)*FLAGS.dev_percent) devset = trainset[:N] trainset = trainset[N:] print('Number of training samples: {}'.format(len(trainset))) print('Number of dev samples: {}'.format(len(devset))) print('Number of relations: {}'.format(FLAGS.num_relations)) word_vocab = None char_vocab = None POS_vocab = None edgelabel_vocab = None has_pretrained_model = False best_path = path_prefix + ".best.model" if os.path.exists(best_path + ".index"): has_pretrained_model = True print('!!Existing pretrained model. Loading vocabs.') word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') print('word_vocab: {}'.format(word_vocab.word_vecs.shape)) if FLAGS.with_char: char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2') print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) if FLAGS.with_POS: POS_vocab = Vocab(path_prefix + ".POS_vocab", fileformat='txt2') print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape)) edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2') print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape)) else: print('Collecting vocabs.') all_words = set() all_chars = set() all_poses = set() all_edgelabels = set() G2S_data_stream.collect_vocabs(trainset, all_words, all_chars, all_poses, all_edgelabels) G2S_data_stream.collect_vocabs(devset, all_words, all_chars, all_poses, all_edgelabels) print('Number of words: {}'.format(len(all_words))) print('Number of chars: {}'.format(len(all_chars))) print('Number of poses: {}'.format(len(all_poses))) print('Number of edgelabels: {}'.format(len(all_edgelabels))) word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') if FLAGS.with_char: char_vocab = Vocab(voc=all_chars, dim=FLAGS.char_dim, fileformat='build') char_vocab.dump_to_txt2(path_prefix + ".char_vocab") if FLAGS.with_POS: POS_vocab = Vocab(voc=all_poses, dim=FLAGS.POS_dim, fileformat='build') POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab") edgelabel_vocab = Vocab(voc=all_edgelabels, dim=FLAGS.edgelabel_dim, fileformat='build') edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab") print('word vocab size {}'.format(word_vocab.vocab_size)) sys.stdout.flush() print('Build DataStream ... 
') trainDataStream = G2S_data_stream.G2SDataStream(FLAGS, trainset, word_vocab, char_vocab, POS_vocab, edgelabel_vocab, isShuffle=True, isLoop=True, isSort=True, is_training=True) devDataStream = G2S_data_stream.G2SDataStream(FLAGS, devset, word_vocab, char_vocab, POS_vocab, edgelabel_vocab, isShuffle=False, isLoop=False, isSort=True) print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch())) sys.stdout.flush() FLAGS.trn_bch_num = trainDataStream.get_num_batch() # initialize the best bleu and accu scores for current training session best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0 if best_accu > 0.0: print('With initial dev accuracy {}'.format(best_accu)) init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = ModelGraph(word_vocab, char_vocab, POS_vocab, edgelabel_vocab, FLAGS, mode='train') with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = ModelGraph(word_vocab, char_vocab, POS_vocab, edgelabel_vocab, FLAGS, mode='evaluate') initializer = tf.global_variables_initializer() vars_ = {} for var in tf.all_variables(): if FLAGS.fix_word_vec and "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(initializer) if has_pretrained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") if abs(best_accu) < 1e-5: print("Getting ACCU score for the model") best_accu = evaluate(sess, valid_graph, devDataStream, FLAGS)['dev_f1'] FLAGS.best_accu = best_accu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('ACCU = %.4f' % best_accu) log_file.write('ACCU = %.4f\n' % best_accu) print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs last_step = 0 total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): cur_batch = trainDataStream.nextBatch() _, _, cur_loss, _ = train_graph.execute(sess, cur_batch, FLAGS, is_train=True) total_loss += cur_loss if step % 100==0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps: print() duration = time.time() - start_time print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss/(step-last_step), duration)) log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss/(step-last_step), duration)) sys.stdout.flush() log_file.flush() last_step = step total_loss = 0.0 # Evaluate against the validation set. 
start_time = time.time() print('Validation Data Eval:') res_dict = evaluate(sess, valid_graph, devDataStream, FLAGS) dev_loss = res_dict['dev_loss'] dev_accu = res_dict['dev_f1'] dev_precision = res_dict['dev_precision'] dev_recall = res_dict['dev_recall'] print('Dev loss = %.4f' % dev_loss) log_file.write('Dev loss = %.4f\n' % dev_loss) print('Dev F1 = %.4f, P = %.4f, R = %.4f' % (dev_accu, dev_precision, dev_recall)) log_file.write('Dev F1 = %.4f, P = %.4f, R = %.4f\n' % (dev_accu, dev_precision, dev_recall)) log_file.flush() if best_accu < dev_accu: print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu)) saver.save(sess, best_path) best_accu = dev_accu FLAGS.best_accu = dev_accu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") duration = time.time() - start_time print('Duration %.3f sec' % (duration)) sys.stdout.flush() log_file.write('Duration %.3f sec\n' % (duration)) log_file.flush() start_time = time.time() log_file.close()
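# The BioNLP script above carves its dev set off the front of the training list after either
# shuffling it or reversing it, taking the first dev_percent fraction. A small sketch of that
# split as a pure function; the argument names and the optional seed are illustrative additions.
import random

def split_dev(trainset, dev_percent, dev_gen='shuffle', seed=None):
    """Return (trainset, devset), where devset is the first dev_percent slice of the reordered data."""
    data = list(trainset)
    if dev_gen == 'shuffle':
        random.Random(seed).shuffle(data)
    elif dev_gen == 'last':
        data.reverse()
    n_dev = int(len(data) * dev_percent)
    return data[n_dev:], data[:n_dev]

# e.g. train, dev = split_dev(range(100), 0.1, dev_gen='shuffle', seed=0)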
def main(_): print('Configurations:') print(FLAGS) train_path = FLAGS.train_path dev_path = FLAGS.dev_path test_path = FLAGS.test_path word_vec_path = FLAGS.word_vec_path log_dir = FLAGS.model_dir result_dir = '../result' if not os.path.exists(log_dir): os.makedirs(log_dir) if not os.path.exists(result_dir): os.makedirs(result_dir) path_prefix = log_dir + "/Han.{}".format(FLAGS.suffix) save_namespace(FLAGS, path_prefix + ".config.json") word_vocab = Vocab(word_vec_path, fileformat='txt3') best_path = path_prefix + '.best.model' label_path = path_prefix + ".label_vocab" print('Collect words and labels ...') (all_words, all_labels) = collect_vocabs(train_path) print('Number of words: {}'.format(len(all_words))) print('Number of labels: {}'.format(len(all_labels))) label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2) label_vocab.dump_to_txt2(label_path) print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape)) print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape)) num_classes = label_vocab.size() print('Build HanDataStream ... ') trainDataStream = HanDataStream(inpath=train_path, word_vocab=word_vocab, label_vocab=label_vocab, isShuffle=True, isLoop=True, max_sent_length=FLAGS.max_sent_length) devDataStream = HanDataStream(inpath=dev_path, word_vocab=word_vocab, label_vocab=label_vocab, isShuffle=False, isLoop=True, max_sent_length=FLAGS.max_sent_length) testDataStream = HanDataStream(inpath=test_path, word_vocab=word_vocab, label_vocab=label_vocab, isShuffle=False, isLoop=True, max_sent_length=FLAGS.max_sent_length) print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance())) print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance())) with tf.Graph().as_default(): initializer = tf.contrib.layers.xavier_initializer() with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = HanModelGraph(num_classes=num_classes, word_vocab=word_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, lambda_l2=FLAGS.lambda_l2, context_lstm_dim=FLAGS.context_lstm_dim, is_training=True, batch_size = FLAGS.batch_size) tf.summary.scalar("Training Loss", train_graph.loss) # Add a scalar summary for the snapshot loss. 
print("Train Graph Build") with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = HanModelGraph(num_classes=num_classes, word_vocab=word_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, lambda_l2=FLAGS.lambda_l2, context_lstm_dim=FLAGS.context_lstm_dim, is_training=False, batch_size = 1) print ("dev Graph Build") initializer = tf.global_variables_initializer() vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue # if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) output_res_file = open(result_dir + '/' + FLAGS.suffix, 'wt') output_res_file.write(str(FLAGS)) with tf.Session() as sess: sess.run(initializer) train_size = trainDataStream.get_num_instance() max_steps = (train_size * FLAGS.max_epochs) // FLAGS.batch_size epoch_size = max_steps // (FLAGS.max_epochs) # + 1 total_loss = 0.0 start_time = time.time() best_accuracy = 0 for step in range(max_steps): # read data # _truth = [] # _sents_length = [] # _in_text_words = [] # for i in range(FLAGS.batch_size): # cur_instance, instance_index = trainDataStream.nextInstance () # (label,text,label_id, word_idx, sents_length) = cur_instance # # _truth.append(label_id) # _sents_length.append(sents_length) # _in_text_words.append(word_idx) # # feed_dict = { # train_graph.truth: np.array(_truth), # train_graph.sents_length: tuple(_sents_length), # train_graph.in_text_words: tuple(_in_text_words), # } feed_dict = get_feed_dict(data_stream=trainDataStream, graph=train_graph, batch_size=FLAGS.batch_size, is_testing=False) _, loss_value, _score = sess.run([train_graph.train_op, train_graph.loss , train_graph.batch_class_scores], feed_dict=feed_dict) total_loss += loss_value if step % 100 == 0: print('{} '.format(step), end="") sys.stdout.flush() if (step + 1) % epoch_size == 0 or (step + 1) == max_steps: # print(total_loss) duration = time.time() - start_time start_time = time.time() print(duration, step, "Loss: ", total_loss) output_res_file.write('\nStep %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) total_loss = 0.0 # Evaluate against the validation set. output_res_file.write('valid- ') dev_accuracy = evaluate(devDataStream, valid_graph, sess) output_res_file.write("%.2f\n" % dev_accuracy) print("Current dev accuracy is %.2f" % dev_accuracy) if dev_accuracy > best_accuracy: best_accuracy = dev_accuracy saver.save(sess, best_path) output_res_file.write('test- ') test_accuracy = evaluate(testDataStream, valid_graph, sess) print("Current test accuracy is %.2f" % test_accuracy) output_res_file.write("%.2f\n" % test_accuracy) output_res_file.close() sys.stdout.flush()
def main(_): print('Configurations:') print(FLAGS) log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix) init_model_prefix = FLAGS.init_model # "/u/zhigwang/zhigwang1/sentence_generation/mscoco/logs/NP2P.phrase_ce_train" log_file_path = path_prefix + ".log" print('Log file path: {}'.format(log_file_path)) log_file = open(log_file_path, 'wt') log_file.write("{}\n".format(FLAGS)) log_file.flush() # save configuration namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('Loading train set.') if FLAGS.infile_format == 'fof': trainset, train_ans_len = NP2P_data_stream.read_generation_datasets_from_fof( FLAGS.train_path, isLower=FLAGS.isLower) if FLAGS.max_answer_len > train_ans_len: FLAGS.max_answer_len = train_ans_len else: trainset, train_ans_len = NP2P_data_stream.read_all_GQA_questions( FLAGS.train_path, isLower=FLAGS.isLower) print('Number of training samples: {}'.format(len(trainset))) print('Loading test set.') if FLAGS.infile_format == 'fof': testset, test_ans_len = NP2P_data_stream.read_generation_datasets_from_fof( FLAGS.test_path, isLower=FLAGS.isLower) else: testset, test_ans_len = NP2P_data_stream.read_all_GQA_questions( FLAGS.test_path, isLower=FLAGS.isLower) print('Number of test samples: {}'.format(len(testset))) max_actual_len = max(train_ans_len, test_ans_len) print('Max answer length: {}, truncated to {}'.format( max_actual_len, FLAGS.max_answer_len)) word_vocab = None POS_vocab = None NER_vocab = None char_vocab = None has_pretrained_model = False best_path = path_prefix + ".best.model" if os.path.exists(init_model_prefix + ".best.model.index"): has_pretrained_model = True print('!!Existing pretrained model. Loading vocabs.') if FLAGS.with_word: word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') print('word_vocab: {}'.format(word_vocab.word_vecs.shape)) if FLAGS.with_char: char_vocab = Vocab(init_model_prefix + ".char_vocab", fileformat='txt2') print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) if FLAGS.with_POS: POS_vocab = Vocab(init_model_prefix + ".POS_vocab", fileformat='txt2') print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape)) if FLAGS.with_NER: NER_vocab = Vocab(init_model_prefix + ".NER_vocab", fileformat='txt2') print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape)) else: print('Collecting vocabs.') (allWords, allChars, allPOSs, allNERs) = NP2P_data_stream.collect_vocabs(trainset) print('Number of words: {}'.format(len(allWords))) print('Number of allChars: {}'.format(len(allChars))) print('Number of allPOSs: {}'.format(len(allPOSs))) print('Number of allNERs: {}'.format(len(allNERs))) if FLAGS.with_word: word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') if FLAGS.with_char: char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build') char_vocab.dump_to_txt2(path_prefix + ".char_vocab") if FLAGS.with_POS: POS_vocab = Vocab(voc=allPOSs, dim=FLAGS.POS_dim, fileformat='build') POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab") if FLAGS.with_NER: NER_vocab = Vocab(voc=allNERs, dim=FLAGS.NER_dim, fileformat='build') NER_vocab.dump_to_txt2(path_prefix + ".NER_vocab") print('word vocab size {}'.format(word_vocab.vocab_size)) sys.stdout.flush() print('Build DataStream ... 
') trainDataStream = NP2P_data_stream.QADataStream(trainset, word_vocab, char_vocab, POS_vocab, NER_vocab, options=FLAGS, isShuffle=True, isLoop=True, isSort=True) devDataStream = NP2P_data_stream.QADataStream(testset, word_vocab, char_vocab, POS_vocab, NER_vocab, options=FLAGS, isShuffle=False, isLoop=False, isSort=True) print('Number of instances in trainDataStream: {}'.format( trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format( devDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format( trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format( devDataStream.get_num_batch())) sys.stdout.flush() init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, options=FLAGS, mode="rl_train_for_phrase") with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, options=FLAGS, mode="decode") initializer = tf.global_variables_initializer() vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(initializer) if has_pretrained_model: print("Restoring model from " + init_model_prefix + ".best.model") saver.restore(sess, init_model_prefix + ".best.model") print("DONE!") sys.stdout.flush() # for first-time rl training, we get the current BLEU score print("First-time rl training, get the current BLEU score on dev") sys.stdout.flush() best_bleu = evaluate(sess, valid_graph, devDataStream, word_vocab, options=FLAGS) print('First-time bleu = %.4f' % best_bleu) log_file.write('First-time bleu = %.4f\n' % best_bleu) print('Start the training loop.') sys.stdout.flush() train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): cur_batch = trainDataStream.nextBatch() if FLAGS.with_baseline: # greedy search (greedy_sentences, _, _, _) = NP2P_beam_decoder.search(sess, valid_graph, word_vocab, cur_batch, FLAGS, decode_mode="greedy") if FLAGS.with_target_lattice: (sampled_sentences, sampled_prediction_lengths, sampled_generator_input_idx, sampled_generator_output_idx ) = cur_batch.sample_a_partition() else: # multinomial sampling (sampled_sentences, sampled_prediction_lengths, sampled_generator_input_idx, sampled_generator_output_idx) = NP2P_beam_decoder.search( sess, valid_graph, word_vocab, cur_batch, FLAGS, decode_mode="multinomial") # calculate rewards rewards = [] for i in xrange(cur_batch.batch_size): # print(sampled_sentences[i]) # print(sampled_generator_input_idx[i]) # print(sampled_generator_output_idx[i]) cur_toks = cur_batch.instances[i][1].tokText.split() # r = sentence_bleu([cur_toks], sampled_sentences[i].split(), smoothing_function=cc.method3) r = 1.0 b = 0.0 if FLAGS.with_baseline: b = sentence_bleu([cur_toks], greedy_sentences[i].split(), smoothing_function=cc.method3) # r = metric_utils.evaluate_captions([cur_toks],[sampled_sentences[i]]) # b = metric_utils.evaluate_captions([cur_toks],[greedy_sentences[i]]) 
rewards.append(1.0 * (r - b)) rewards = np.array(rewards, dtype=np.float32) # sys.exit(-1) # update parameters feed_dict = train_graph.run_encoder(sess, cur_batch, FLAGS, only_feed_dict=True) feed_dict[train_graph.reward] = rewards feed_dict[ train_graph.gen_input_words] = sampled_generator_input_idx feed_dict[ train_graph.in_answer_words] = sampled_generator_output_idx feed_dict[train_graph.answer_lengths] = sampled_prediction_lengths (_, loss_value) = sess.run([train_graph.train_op, train_graph.loss], feed_dict) total_loss += loss_value if step % 100 == 0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or ( step + 1) == max_steps: print() duration = time.time() - start_time print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) log_file.flush() sys.stdout.flush() total_loss = 0.0 # Evaluate against the validation set. start_time = time.time() print('Validation Data Eval:') dev_bleu = evaluate(sess, valid_graph, devDataStream, word_vocab, options=FLAGS) print('Dev bleu = %.4f' % dev_bleu) log_file.write('Dev bleu = %.4f\n' % dev_bleu) log_file.flush() if best_bleu < dev_bleu: print('Saving weights, BLEU {} (prev_best) < {} (cur)'. format(best_bleu, dev_bleu)) best_bleu = dev_bleu saver.save(sess, best_path) # TODO: save model duration = time.time() - start_time print('Duration %.3f sec' % (duration)) sys.stdout.flush() log_file.write('Duration %.3f sec\n' % (duration)) log_file.flush() log_file.close()
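# The RL loop above rewards each multinomially sampled sentence by its sentence-level BLEU minus
# the BLEU of the greedy decode (a self-critical baseline), using NLTK's method3 smoothing.
# A standalone sketch of that reward computation; note the loop above currently hard-codes r = 1.0
# with the BLEU line commented out, whereas this sketch computes both terms.
import numpy as np
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def self_critical_rewards(references, sampled, greedy=None):
    """reward_i = BLEU(sampled_i) - BLEU(greedy_i); greedy may be None (no baseline)."""
    cc = SmoothingFunction()
    rewards = []
    for i, ref in enumerate(references):
        ref_toks = ref.split()
        r = sentence_bleu([ref_toks], sampled[i].split(), smoothing_function=cc.method3)
        b = 0.0
        if greedy is not None:
            b = sentence_bleu([ref_toks], greedy[i].split(), smoothing_function=cc.method3)
        rewards.append(r - b)
    return np.array(rewards, dtype=np.float32)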
def main(FLAGS):
    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    has_pre_trained_model = False
    char_vocab = None
    if os.path.exists(best_path + ".index"):
        has_pre_trained_model = True
        logger.info('Loading vocabs from a pre-trained model ...')
        label_vocab = Vocab(label_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(char_path, fileformat='txt2')
    else:
        logger.info('Collecting words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path)
        logger.info('Number of words: {}'.format(len(all_words)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
        label_vocab.dump_to_txt2(label_path)
        if FLAGS.with_char:
            logger.info('Number of chars: {}'.format(len(all_chars)))
            char_vocab = Vocab(fileformat='voc', voc=all_chars, dim=FLAGS.char_emb_dim)
            char_vocab.dump_to_txt2(char_path)

    logger.info('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    num_classes = label_vocab.size()
    logger.info("Number of labels: {}".format(num_classes))
    sys.stdout.flush()

    logger.info('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                              label_vocab=label_vocab,
                                              isShuffle=True, isLoop=True, isSort=True, options=FLAGS)
    logger.info('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    logger.info('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    sys.stdout.flush()

    devDataStream = SentenceMatchDataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                            label_vocab=label_vocab,
                                            isShuffle=False, isLoop=True, isSort=True, options=FLAGS)
    logger.info('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    logger.info('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    sys.stdout.flush()

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                                                  is_training=True, options=FLAGS, global_step=global_step)
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                                                  is_training=False, options=FLAGS)

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        # Initialize the summary writer and write the current TensorFlow graph to the log.
        train_writer = tf.summary.FileWriter(SUMMARY_DIR, sess.graph)
        # valid_writer = tf.summary.FileWriter(SUMMARY_DIR + '/valid')
        sess.run(initializer)

        if has_pre_trained_model:
            logger.info("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            logger.info("DONE!")

        # training
        train(sess, saver, train_graph, valid_graph, trainDataStream, devDataStream, FLAGS, best_path,
              train_writer, label_vocab)
        train_writer.close()
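# The train() helper called above is defined elsewhere, so how train_writer is
# actually fed is not visible here. As a hypothetical TF 1.x sketch (not this
# project's train()), this is the usual FileWriter pattern: merge scalar
# summaries, evaluate them alongside the training ops, and append them to the
# event file with the current step.
import tensorflow as tf

def write_loss_summaries(log_dir="/tmp/summaries"):
    with tf.Graph().as_default():
        loss = tf.placeholder(tf.float32, shape=[], name="loss")
        tf.summary.scalar("train_loss", loss)
        merged = tf.summary.merge_all()
        with tf.Session() as sess:
            writer = tf.summary.FileWriter(log_dir, sess.graph)
            for step, loss_value in enumerate([2.0, 1.5, 1.2]):
                summary = sess.run(merged, feed_dict={loss: loss_value})
                writer.add_summary(summary, global_step=step)
            writer.close()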
def main(_): print('Configurations:') print(FLAGS) log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix) log_file_path = path_prefix + ".log" print('Log file path: {}'.format(log_file_path)) log_file = open(log_file_path, 'wt') log_file.write("{}\n".format(FLAGS)) log_file.flush() # save configuration namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('Loading train set.') trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_file( FLAGS.train_path) print('Number of training samples: {}'.format(len(trainset))) print('Loading dev set.') devset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_file( FLAGS.test_path) print('Number of dev samples: {}'.format(len(devset))) if FLAGS.finetune_path != "": print('Loading finetune set.') ftset, ft_node, ft_in_neigh, ft_out_neigh, ft_sent = G2S_data_stream.read_amr_file( FLAGS.finetune_path) print('Number of finetune samples: {}'.format(len(ftset))) else: ftset, ft_node, ft_in_neigh, ft_out_neigh, ft_sent = (None, 0, 0, 0, 0) max_node = max(trn_node, tst_node, ft_node) max_in_neigh = max(trn_in_neigh, tst_in_neigh, ft_in_neigh) max_out_neigh = max(trn_out_neigh, tst_out_neigh, ft_out_neigh) max_sent = max(trn_sent, tst_sent, ft_sent) print('Max node number: {}, while max allowed is {}'.format( max_node, FLAGS.max_node_num)) print('Max parent number: {}, truncated to {}'.format( max_in_neigh, FLAGS.max_in_neigh_num)) print('Max children number: {}, truncated to {}'.format( max_out_neigh, FLAGS.max_out_neigh_num)) print('Max answer length: {}, truncated to {}'.format( max_sent, FLAGS.max_answer_len)) word_vocab = None char_vocab = None edgelabel_vocab = None has_pretrained_model = False best_path = path_prefix + ".best.model" if os.path.exists(best_path + ".index"): has_pretrained_model = True print('!!Existing pretrained model. Loading vocabs.') word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') print('word_vocab: {}'.format(word_vocab.word_vecs.shape)) char_vocab = None if FLAGS.with_char: char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2') print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2') else: print('Collecting vocabs.') (allWords, allChars, allEdgelabels) = G2S_data_stream.collect_vocabs(trainset) print('Number of words: {}'.format(len(allWords))) print('Number of allChars: {}'.format(len(allChars))) print('Number of allEdgelabels: {}'.format(len(allEdgelabels))) word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') char_vocab = None if FLAGS.with_char: char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build') char_vocab.dump_to_txt2(path_prefix + ".char_vocab") edgelabel_vocab = Vocab(voc=allEdgelabels, dim=FLAGS.edgelabel_dim, fileformat='build') edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab") print('word vocab size {}'.format(word_vocab.vocab_size)) sys.stdout.flush() print('Build DataStream ... 
') trainDataStream = G2S_data_stream.G2SDataStream(trainset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS, isShuffle=True, isLoop=True, isSort=True) devDataStream = G2S_data_stream.G2SDataStream(devset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS, isShuffle=False, isLoop=False, isSort=True) print('Number of instances in trainDataStream: {}'.format( trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format( devDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format( trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format( devDataStream.get_num_batch())) if ftset != None: ftDataStream = G2S_data_stream.G2SDataStream(ftset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS, isShuffle=True, isLoop=True, isSort=True) print('Number of instances in ftDataStream: {}'.format( ftDataStream.get_num_instance())) print('Number of batches in ftDataStream: {}'.format( ftDataStream.get_num_batch())) sys.stdout.flush() # initialize the best bleu and accu scores for current training session best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0 best_bleu = FLAGS.best_bleu if FLAGS.__dict__.has_key('best_bleu') else 0.0 if best_accu > 0.0: print('With initial dev accuracy {}'.format(best_accu)) if best_bleu > 0.0: print('With initial dev BLEU score {}'.format(best_bleu)) init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = ModelGraph(word_vocab=word_vocab, Edgelabel_vocab=edgelabel_vocab, char_vocab=char_vocab, options=FLAGS, mode=FLAGS.mode) assert FLAGS.mode in ('ce_train', 'rl_train', 'transformer') valid_mode = 'evaluate' if FLAGS.mode in ( 'ce_train', 'transformer') else 'evaluate_bleu' with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = ModelGraph(word_vocab=word_vocab, Edgelabel_vocab=edgelabel_vocab, char_vocab=char_vocab, options=FLAGS, mode=valid_mode) initializer = tf.global_variables_initializer() vars_ = {} for var in tf.all_variables(): if FLAGS.fix_word_vec and "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(initializer) if has_pretrained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001: print("Getting BLEU score for the model") sys.stdout.flush() best_bleu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_bleu'] FLAGS.best_bleu = best_bleu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('BLEU = %.4f' % best_bleu) sys.stdout.flush() log_file.write('BLEU = %.4f\n' % best_bleu) if FLAGS.mode in ('ce_train', 'rl_train', 'transformer') and abs(best_accu) < 0.00001: print("Getting ACCU score for the model") best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu'] FLAGS.best_accu = best_accu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('ACCU = %.4f' % best_accu) log_file.write('ACCU = %.4f\n' % best_accu) print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): 
cur_batch = trainDataStream.nextBatch() if FLAGS.mode == 'rl_train': loss_value = train_graph.run_rl_training_subsample( sess, cur_batch, FLAGS) elif FLAGS.mode in ('ce_train', 'rl_train', 'transformer'): loss_value = train_graph.run_ce_training( sess, cur_batch, FLAGS) total_loss += loss_value if step % 100 == 0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or \ (trainDataStream.get_num_batch() > 10000 and (step+1)%2000 == 0): print() duration = time.time() - start_time print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) log_file.flush() sys.stdout.flush() total_loss = 0.0 if ftset != None: best_accu, best_bleu = fine_tune(sess, saver, FLAGS, log_file, ftDataStream, devDataStream, train_graph, valid_graph, path_prefix, best_accu, best_bleu) else: best_accu, best_bleu = validate_and_save( sess, saver, FLAGS, log_file, devDataStream, valid_graph, path_prefix, best_accu, best_bleu) start_time = time.time() log_file.close()
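# The periodic-evaluation condition above fires at each epoch boundary, at the
# final step, and additionally every 2000 steps when an epoch is longer than
# 10000 batches. A small helper (names are illustrative) expressing the same
# schedule:
def should_evaluate(step, batches_per_epoch, max_steps,
                    long_epoch_threshold=10000, long_epoch_every=2000):
    """Return True when the training loop should checkpoint/evaluate."""
    end_of_epoch = (step + 1) % batches_per_epoch == 0
    last_step = (step + 1) == max_steps
    long_epoch_tick = (batches_per_epoch > long_epoch_threshold
                       and (step + 1) % long_epoch_every == 0)
    return end_of_epoch or last_step or long_epoch_tick

# e.g. should_evaluate(1999, 25000, 250000) -> True (mid-epoch tick)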
def main(_): print('Configurations:') print(FLAGS) log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/MHQA.{}".format(FLAGS.suffix) log_file_path = path_prefix + ".log" print('Log file path: {}'.format(log_file_path)) log_file = open(log_file_path, 'wt') log_file.write("{}\n".format(FLAGS)) log_file.flush() # save configuration namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('Loading train set.') trainset, _, _ = MHQA_data_stream.read_data_file(FLAGS.train_path, FLAGS) print('Number of training samples: {}'.format(len(trainset))) print('Loading dev set.') devset, _, _ = MHQA_data_stream.read_data_file(FLAGS.dev_path, FLAGS) print('Number of dev samples: {}'.format(len(devset))) word_vocab = None char_vocab = None has_pretrained_model = False best_path = path_prefix + ".best.model" if os.path.exists(best_path + ".index"): has_pretrained_model = True print('!!Existing pretrained model. Loading vocabs.') word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') print('word_vocab: {}'.format(word_vocab.word_vecs.shape)) char_vocab = None if FLAGS.with_char: char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2') print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) else: print('Collecting vocabs.') (allWords, allChars) = MHQA_data_stream.collect_vocabs(trainset) print('Number of words: {}'.format(len(allWords))) print('Number of allChars: {}'.format(len(allChars))) word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') char_vocab = None if FLAGS.with_char: char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build') char_vocab.dump_to_txt2(path_prefix + ".char_vocab") print('word vocab size {}'.format(word_vocab.vocab_size)) sys.stdout.flush() print('Build DataStream ... 
') trainDataStream = MHQA_data_stream.DataStream(trainset, word_vocab, char_vocab, options=FLAGS, isShuffle=True, isLoop=True, isSort=True, has_ref=True) devDataStream = MHQA_data_stream.DataStream(devset, word_vocab, char_vocab, options=FLAGS, isShuffle=False, isLoop=False, isSort=True, has_ref=True) print('Number of instances in trainDataStream: {}'.format( trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format( devDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format( trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format( devDataStream.get_num_batch())) sys.stdout.flush() # initialize the best bleu and accu scores for current training session best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0 if best_accu > 0.0: print('With initial dev accuracy {}'.format(best_accu)) init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab, options=FLAGS, has_ref=True, is_training=True) with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab, options=FLAGS, has_ref=True, is_training=False) initializer = tf.global_variables_initializer() vars_ = {} for var in tf.all_variables(): if FLAGS.fix_word_vec and "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var print(var) saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(initializer) if has_pretrained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") if abs(best_accu) < 0.00001: print("Getting ACCU score for the model") best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu'] FLAGS.best_accu = best_accu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('ACCU = %.4f' % best_accu) log_file.write('ACCU = %.4f\n' % best_accu) print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): cur_batch = trainDataStream.nextBatch() cur_batch = MHQA_data_stream.BatchPadded(cur_batch) _, cur_loss, _ = train_graph.execute(sess, cur_batch, FLAGS) total_loss += cur_loss if step % 100 == 0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or \ (trainDataStream.get_num_batch() > 10000 and (step+1)%2000 == 0): print() duration = time.time() - start_time print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) log_file.flush() sys.stdout.flush() total_loss = 0.0 best_accu = validate_and_save(sess, saver, FLAGS, log_file, devDataStream, valid_graph, path_prefix, best_accu) start_time = time.time() log_file.close()
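# Like the loop above, most of these drivers build their tf.train.Saver over a
# filtered variable map, so the (optionally fixed) word-embedding matrix and any
# variable outside the "Model" scope never enter the checkpoint. A minimal
# TF 1.x sketch of that filtering with dummy variables (all names below are
# placeholders, not variables from this model):
import tensorflow as tf

def build_filtered_saver():
    with tf.Graph().as_default():
        with tf.variable_scope("Model"):
            tf.get_variable("dense_kernel", shape=[3, 3])
            tf.get_variable("word_embedding", shape=[10, 5], trainable=False)
        tf.get_variable("outside_scope_var", shape=[1])  # not under "Model"
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        return tf.train.Saver(vars_)  # saves only Model/dense_kernel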
def main_cv(FLAGS): # np.random.seed(FLAGS.seed) for fold in range(FLAGS.cv_folds): print("Start training fold " + str(fold)) train_path = FLAGS.cv_train_path + str(fold) + '.tsv' train_feat_path = FLAGS.cv_train_feat_path + str(fold) + '.tsv' dev_path = FLAGS.cv_dev_path + str(fold) + '.tsv' dev_feat_path = FLAGS.cv_dev_feat_path + str(fold) + '.tsv' word_vec_path = FLAGS.word_vec_path char_vec_path = FLAGS.char_vec_path log_dir = FLAGS.model_dir + '/cv_fold_' + str(fold) if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix) namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") # build vocabs word_vocab = Vocab(word_vec_path, fileformat='txt3') char_vocab = None best_path = path_prefix + '.best.model' char_path = path_prefix + ".char_vocab" label_path = path_prefix + ".label_vocab" has_pre_trained_model = False if os.path.exists(best_path + ".index"): has_pre_trained_model = True print('Loading vocabs from a pre-trained model ...') label_vocab = Vocab(label_path, fileformat='txt2') if FLAGS.with_char: char_vocab = Vocab(char_path, fileformat='txt2') else: print('Collecting words, chars and labels ...') (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path) print('Number of words: {}'.format(len(all_words))) label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2) label_vocab.dump_to_txt2(label_path) if FLAGS.with_char: print('Number of chars: {}'.format(len(all_chars))) if char_vec_path == "": char_vocab = Vocab(fileformat='voc', voc=all_chars, dim=FLAGS.char_emb_dim) else: char_vocab = Vocab(char_vec_path, fileformat='txt3') char_vocab.dump_to_txt2(char_path) print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape)) if FLAGS.with_char: print('char_vocab shape is {}'.format(char_vocab.word_vecs.shape)) num_classes = label_vocab.size() print("Number of labels: {}".format(num_classes)) sys.stdout.flush() print('Build SentenceMatchDataStream ... 
') trainDataStream = SentenceMatchDataStream(train_path, train_feat_path, word_vocab=word_vocab, char_vocab=char_vocab, label_vocab=label_vocab, isShuffle=True, isLoop=True, isSort=True, options=FLAGS) print('Number of instances in trainDataStream: {}'.format( trainDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format( trainDataStream.get_num_batch())) sys.stdout.flush() devDataStream = SentenceMatchDataStream(dev_path, dev_feat_path, word_vocab=word_vocab, char_vocab=char_vocab, label_vocab=label_vocab, isShuffle=False, isLoop=True, isSort=True, options=FLAGS) print('Number of instances in devDataStream: {}'.format( devDataStream.get_num_instance())) print('Number of batches in devDataStream: {}'.format( devDataStream.get_num_batch())) sys.stdout.flush() init_scale = 0.01 with tf.Graph().as_default(): # tf.set_random_seed(FLAGS.seed) initializer = tf.random_uniform_initializer( -init_scale, init_scale) global_step = tf.train.get_or_create_global_step() with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, is_training=True, options=FLAGS, global_step=global_step) with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, is_training=False, options=FLAGS) initializer = tf.global_variables_initializer() initializer_local = tf.local_variables_initializer() vars_ = {} for var in tf.global_variables(): if "word_embedding" in var.name: continue # if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(initializer, feed_dict={ train_graph.w_embedding: word_vocab.word_vecs, train_graph.c_embedding: char_vocab.word_vecs }) sess.run(initializer_local) if has_pre_trained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") # training train(sess, saver, train_graph, valid_graph, trainDataStream, devDataStream, FLAGS, best_path, label_vocab) print()
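# Above, the global initializer is run with a feed_dict that supplies the
# pre-trained word/char embedding matrices through graph placeholders
# (train_graph.w_embedding / train_graph.c_embedding), a common TF 1.x way to
# keep large embedding constants out of the serialized graph. A minimal,
# hypothetical sketch of that placeholder-initializer pattern:
import numpy as np
import tensorflow as tf

def init_embedding_from_placeholder(pretrained):
    """Create an embedding variable whose initial value is fed at init time."""
    with tf.Graph().as_default():
        ph = tf.placeholder(tf.float32, shape=pretrained.shape)
        emb = tf.get_variable("word_embedding", initializer=ph, trainable=False)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer(), feed_dict={ph: pretrained})
            return sess.run(emb)

# Example: init_embedding_from_placeholder(np.random.rand(100, 50).astype(np.float32))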
def main(): print(FLAGS.__dict__) log_dir = FLAGS.log_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/MHQA.{}".format(FLAGS.suffix) log_file_path = path_prefix + ".log" print('Log file path: {}'.format(log_file_path)) log_file = open(log_file_path, 'wt') log_file.write("{}\n".format(FLAGS)) log_file.flush() # save configuration namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") n_gpu = torch.cuda.device_count() print('device: {}, n_gpu: {}, grad_accum_steps: {}'.format( device, n_gpu, FLAGS.grad_accum_steps)) log_file.write('device: {}, n_gpu: {}, grad_accum_steps: {}\n'.format( device, n_gpu, FLAGS.grad_accum_steps)) glove_vocab = None glove_embedding = None if FLAGS.embedding_model.find('elmo') < 0: print('Loading GloVe model from: {}'.format(FLAGS.glove_path)) glove_vocab, glove_embedding = MHQA_data_stream.load_glove( FLAGS.glove_path) print('Loading train set.') trainset, _ = MHQA_data_stream.read_data_file(FLAGS.train_path, FLAGS) trainset_batches = MHQA_data_stream.make_batches(trainset, FLAGS, glove_vocab) print('Number of training samples: {}'.format(len(trainset))) print('Number of training batches: {}'.format(len(trainset_batches))) print('Loading dev set.') devset, _ = MHQA_data_stream.read_data_file(FLAGS.dev_path, FLAGS) devset_batches = MHQA_data_stream.make_batches(devset, FLAGS, glove_vocab) print('Number of dev samples: {}'.format(len(devset))) print('Number of dev batches: {}'.format(len(devset_batches))) # model print('Compiling model.') model = MHQA_model_graph.ModelGraph(FLAGS, glove_embedding) if os.path.exists(path_prefix + ".model.bin"): print('!!Existing pretrained model. Loading the model...') model.load_state_dict(torch.load(path_prefix + ".model.bin")) model.to(device) # pretrained performance best_accu = 0.0 if os.path.exists(path_prefix + ".model.bin"): best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ and abs(FLAGS.best_accu) > 1e-4 \ else evaluate_dataset(model, devset_batches) FLAGS.best_accu = best_accu print("!!Accuracy for pretrained model is {}".format(best_accu)) # optimizer train_updates = len(trainset_batches) * FLAGS.num_epochs if FLAGS.grad_accum_steps > 1: train_updates = train_updates // FLAGS.grad_accum_steps if FLAGS.optim == 'bertadam': optimizer = BertAdam(model.parameters(), lr=FLAGS.learning_rate, warmup=FLAGS.warmup_proportion, t_total=train_updates) elif FLAGS.optim == 'adam': optimizer = Adam(model.parameters(), lr=FLAGS.learning_rate, weight_decay=FLAGS.lambda_l2) else: assert False, 'unsupported optimizer type: {}'.format(FLAGS.optim) print('Start the training loop, total *updating* steps = {}'.format( train_updates)) finished_steps, finished_epochs = 0, 0 train_batch_ids = list(range(0, len(trainset_batches))) model.train() while finished_epochs < FLAGS.num_epochs: epoch_start = time.time() epoch_loss = [] print('Current epoch takes {} steps'.format(len(train_batch_ids))) random.shuffle(train_batch_ids) start_time = time.time() for id in train_batch_ids: ori_batch = trainset_batches[id] batch = {k: v.to(device) if type(v) == torch.Tensor else v \ for k, v in ori_batch.items()} outputs = model(batch) loss = outputs['loss'] epoch_loss.append(loss.item()) if n_gpu > 1: loss = loss.mean() if FLAGS.grad_accum_steps > 1: loss = loss / FLAGS.grad_accum_steps loss.backward() # just calculate gradient finished_steps += 1 if finished_steps % FLAGS.grad_accum_steps == 0: optimizer.step() optimizer.zero_grad() if 
finished_steps % 100 == 0: print('{} '.format(finished_steps), end="") sys.stdout.flush() if torch.cuda.is_available(): torch.cuda.empty_cache() # Save a checkpoint and evaluate the model periodically. if finished_steps > 0 and finished_steps % 1000 == 0: best_accu = validate_and_save(model, devset_batches, log_file, best_accu) duration = time.time() - start_time print('Training loss = %.2f (%.3f sec)' % (float(sum(epoch_loss)), duration)) log_file.write('Training loss = %.2f (%.3f sec)\n' % (float(sum(epoch_loss)), duration)) finished_epochs += 1 best_accu = validate_and_save(model, devset_batches, log_file, best_accu) log_file.close()
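# The PyTorch loop above scales each loss by FLAGS.grad_accum_steps, calls
# backward() on every mini-batch, and only steps the optimizer once per
# accumulation window. A self-contained sketch of that gradient-accumulation
# pattern with a toy model (all sizes and names here are illustrative):
import torch
import torch.nn as nn

def train_with_grad_accum(grad_accum_steps=4, num_batches=8):
    model = nn.Linear(10, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()
    model.train()
    for step in range(num_batches):
        x, y = torch.randn(16, 10), torch.randn(16, 1)
        loss = loss_fn(model(x), y) / grad_accum_steps  # scale so summed grads match one big batch
        loss.backward()                                 # accumulate gradients only
        if (step + 1) % grad_accum_steps == 0:
            optimizer.step()                            # one update per accumulation window
            optimizer.zero_grad()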
def main(_): print('Configurations:') print(FLAGS) log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix) log_file_path = path_prefix + ".log" print('Log file path: {}'.format(log_file_path)) log_file = open(log_file_path, 'wt') log_file.write("{}\n".format(FLAGS)) log_file.flush() # save configuration namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('Loading train set.') if FLAGS.infile_format == 'fof': trainset, train_ans_len = NP2P_data_stream.read_generation_datasets_from_fof( FLAGS.train_path, isLower=FLAGS.isLower) elif FLAGS.infile_format == 'plain': trainset, train_ans_len = NP2P_data_stream.read_all_GenerationDatasets( FLAGS.train_path, isLower=FLAGS.isLower) else: trainset, train_ans_len = NP2P_data_stream.read_all_GQA_questions( FLAGS.train_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa) print('Number of training samples: {}'.format(len(trainset))) print('Loading test set.') if FLAGS.infile_format == 'fof': testset, test_ans_len = NP2P_data_stream.read_generation_datasets_from_fof( FLAGS.test_path, isLower=FLAGS.isLower) elif FLAGS.infile_format == 'plain': testset, test_ans_len = NP2P_data_stream.read_all_GenerationDatasets( FLAGS.test_path, isLower=FLAGS.isLower) else: testset, test_ans_len = NP2P_data_stream.read_all_GQA_questions( FLAGS.test_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa) print('Number of test samples: {}'.format(len(testset))) max_actual_len = max(train_ans_len, test_ans_len) print('Max answer length: {}, truncated to {}'.format( max_actual_len, FLAGS.max_answer_len)) word_vocab = None POS_vocab = None NER_vocab = None char_vocab = None has_pretrained_model = False best_path = path_prefix + ".best.model" if os.path.exists(best_path + ".index"): has_pretrained_model = True print('!!Existing pretrained model. Loading vocabs.') if FLAGS.with_word: word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') print('word_vocab: {}'.format(word_vocab.word_vecs.shape)) if FLAGS.with_char: char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2') print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) if FLAGS.with_POS: POS_vocab = Vocab(path_prefix + ".POS_vocab", fileformat='txt2') print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape)) if FLAGS.with_NER: NER_vocab = Vocab(path_prefix + ".NER_vocab", fileformat='txt2') print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape)) else: print('Collecting vocabs.') (allWords, allChars, allPOSs, allNERs) = NP2P_data_stream.collect_vocabs(trainset) print('Number of words: {}'.format(len(allWords))) print('Number of allChars: {}'.format(len(allChars))) print('Number of allPOSs: {}'.format(len(allPOSs))) print('Number of allNERs: {}'.format(len(allNERs))) if FLAGS.with_word: word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') if FLAGS.with_char: char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build') char_vocab.dump_to_txt2(path_prefix + ".char_vocab") if FLAGS.with_POS: POS_vocab = Vocab(voc=allPOSs, dim=FLAGS.POS_dim, fileformat='build') POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab") if FLAGS.with_NER: NER_vocab = Vocab(voc=allNERs, dim=FLAGS.NER_dim, fileformat='build') NER_vocab.dump_to_txt2(path_prefix + ".NER_vocab") print('word vocab size {}'.format(word_vocab.vocab_size)) sys.stdout.flush() print('Build DataStream ... 
') trainDataStream = NP2P_data_stream.QADataStream(trainset, word_vocab, char_vocab, POS_vocab, NER_vocab, options=FLAGS, isShuffle=True, isLoop=True, isSort=True) devDataStream = NP2P_data_stream.QADataStream(testset, word_vocab, char_vocab, POS_vocab, NER_vocab, options=FLAGS, isShuffle=False, isLoop=False, isSort=True) print('Number of instances in trainDataStream: {}'.format( trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format( devDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format( trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format( devDataStream.get_num_batch())) sys.stdout.flush() init_scale = 0.01 # initialize the best bleu and accu scores for current training session best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0 best_bleu = FLAGS.best_bleu if FLAGS.__dict__.has_key('best_bleu') else 0.0 if best_accu > 0.0: print('With initial dev accuracy {}'.format(best_accu)) if best_bleu > 0.0: print('With initial dev BLEU score {}'.format(best_bleu)) with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, options=FLAGS, mode=FLAGS.mode) assert FLAGS.mode in ( 'ce_train', 'rl_train', ) valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu' with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, options=FLAGS, mode=valid_mode) initializer = tf.global_variables_initializer() vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(initializer) if has_pretrained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001: print("Getting BLEU score for the model") best_bleu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_bleu'] FLAGS.best_bleu = best_bleu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('BLEU = %.4f' % best_bleu) log_file.write('BLEU = %.4f\n' % best_bleu) if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001: print("Getting ACCU score for the model") best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu'] FLAGS.best_accu = best_accu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('ACCU = %.4f' % best_accu) log_file.write('ACCU = %.4f\n' % best_accu) print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): cur_batch = trainDataStream.nextBatch() if FLAGS.mode == 'rl_train': loss_value = train_graph.run_rl_training_2( sess, cur_batch, FLAGS) elif FLAGS.mode == 'ce_train': loss_value = train_graph.run_ce_training( sess, cur_batch, FLAGS) total_loss += loss_value if step % 100 == 0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. 
if (step + 1) % trainDataStream.get_num_batch() == 0 or ( step + 1) == max_steps: print() duration = time.time() - start_time print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) log_file.flush() sys.stdout.flush() total_loss = 0.0 # Evaluate against the validation set. start_time = time.time() print('Validation Data Eval:') res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS, suffix=str(step)) if valid_graph.mode == 'evaluate': dev_loss = res_dict['dev_loss'] dev_accu = res_dict['dev_accu'] dev_right = int(res_dict['dev_right']) dev_total = int(res_dict['dev_total']) print('Dev loss = %.4f' % dev_loss) log_file.write('Dev loss = %.4f\n' % dev_loss) print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total)) log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total)) log_file.flush() if best_accu < dev_accu: print('Saving weights, ACCU {} (prev_best) < {} (cur)'. format(best_accu, dev_accu)) saver.save(sess, best_path) best_accu = dev_accu FLAGS.best_accu = dev_accu namespace_utils.save_namespace( FLAGS, path_prefix + ".config.json") else: dev_bleu = res_dict['dev_bleu'] print('Dev bleu = %.4f' % dev_bleu) log_file.write('Dev bleu = %.4f\n' % dev_bleu) log_file.flush() if best_bleu < dev_bleu: print('Saving weights, BLEU {} (prev_best) < {} (cur)'. format(best_bleu, dev_bleu)) saver.save(sess, best_path) best_bleu = dev_bleu FLAGS.best_bleu = dev_bleu namespace_utils.save_namespace( FLAGS, path_prefix + ".config.json") duration = time.time() - start_time print('Duration %.3f sec' % (duration)) sys.stdout.flush() log_file.write('Duration %.3f sec\n' % (duration)) log_file.flush() log_file.close()
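# In the evaluation branch above, the metric that gates checkpointing depends on
# the validation graph's mode: dev accuracy for 'evaluate', dev BLEU otherwise.
# A small, hypothetical helper capturing that decision (not part of the original
# code):
def checkpoint_decision(valid_mode, res_dict, best_accu, best_bleu):
    """Return (should_save, best_accu, best_bleu) after one dev evaluation."""
    if valid_mode == 'evaluate':
        if res_dict['dev_accu'] > best_accu:
            return True, res_dict['dev_accu'], best_bleu
    elif res_dict['dev_bleu'] > best_bleu:
        return True, best_accu, res_dict['dev_bleu']
    return False, best_accu, best_bleu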
def main(FLAGS): tf.logging.set_verbosity(tf.logging.INFO) train_path = FLAGS.train_path dev_path = FLAGS.dev_path test_path = FLAGS.test_path word_vec_path = FLAGS.word_vec_path kg_path = FLAGS.kg_path wordnet_path = FLAGS.wordnet_path lemma_vec_path = FLAGS.lemma_vec_path log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) os.makedirs(os.path.join(log_dir, '../result')) os.makedirs(os.path.join(log_dir, '../logits')) path_prefix = log_dir + "/KEIM.{}".format(FLAGS.suffix) namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") # build vocabs word_vocab = Vocab(word_vec_path, fileformat='txt3') lemma_vocab = Vocab(lemma_vec_path, fileformat='txt3') best_path = path_prefix + '.best.model' char_path = path_prefix + ".char_vocab" label_path = path_prefix + ".label_vocab" char_vocab = None tf.logging.info('Collecting words, chars and labels ...') (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path) tf.logging.info('Number of words: {}'.format(len(all_words))) label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2) label_vocab.dump_to_txt2(label_path) if FLAGS.with_char: tf.logging.info('Number of chars: {}'.format(len(all_chars))) char_vocab = Vocab(fileformat='voc', voc=all_chars, dim=FLAGS.char_emb_dim) char_vocab.dump_to_txt2(char_path) tf.logging.info('word_vocab shape is {}'.format( word_vocab.word_vecs.shape)) tf.logging.info('lemma_word_vocab shape is {}'.format( lemma_vocab.word_vecs.shape)) num_classes = label_vocab.size() tf.logging.info("Number of labels: {}".format(num_classes)) sys.stdout.flush() with open(wordnet_path, 'rb') as f: wordnet_vocab = pkl.load(f) tf.logging.info('wordnet_vocab shape is {}'.format(len(wordnet_vocab))) with open(kg_path, 'rb') as f: kg = pkl.load(f) tf.logging.info('kg shape is {}'.format(len(kg))) tf.logging.info('Build SentenceMatchDataStream ... 
') trainDataStream = DataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab, label_vocab=None, kg=kg, wordnet_vocab=wordnet_vocab, lemma_vocab=lemma_vocab, isShuffle=True, isLoop=True, isSort=True, options=FLAGS) tf.logging.info('Number of instances in trainDataStream: {}'.format( trainDataStream.get_num_instance())) tf.logging.info('Number of batches in trainDataStream: {}'.format( trainDataStream.get_num_batch())) sys.stdout.flush() devDataStream = DataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab, label_vocab=None, kg=kg, wordnet_vocab=wordnet_vocab, lemma_vocab=lemma_vocab, isShuffle=True, isLoop=True, isSort=True, options=FLAGS) tf.logging.info('Number of instances in devDataStream: {}'.format( devDataStream.get_num_instance())) tf.logging.info('Number of batches in devDataStream: {}'.format( devDataStream.get_num_batch())) sys.stdout.flush() testDataStream = DataStream(test_path, word_vocab=word_vocab, char_vocab=char_vocab, label_vocab=None, kg=kg, wordnet_vocab=wordnet_vocab, lemma_vocab=lemma_vocab, isShuffle=True, isLoop=True, isSort=True, options=FLAGS) tf.logging.info('Number of instances in testDataStream: {}'.format( testDataStream.get_num_instance())) tf.logging.info('Number of batches in testDataStream: {}'.format( testDataStream.get_num_batch())) sys.stdout.flush() with tf.Graph().as_default(): initializer = tf.contrib.layers.xavier_initializer() # initializer = tf.truncated_normal_initializer(stddev=0.02) global_step = tf.train.get_or_create_global_step() with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = Model(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, lemma_vocab=lemma_vocab, is_training=True, options=FLAGS, global_step=global_step) with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = Model(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, lemma_vocab=lemma_vocab, is_training=False, options=FLAGS) initializer = tf.global_variables_initializer() vars_ = {} for var in tf.global_variables(): # if "word_embedding" in var.name: continue # if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1) config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options) config.gpu_options.allow_growth = True with tf.Session(config=config) as sess: sess.run(initializer) # training train(sess, saver, train_graph, valid_graph, trainDataStream, devDataStream, testDataStream, FLAGS, best_path)
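# The KEIM session above enables soft placement and GPU memory growth; with
# per_process_gpu_memory_fraction set to 1 there is effectively no memory cap,
# so allow_growth governs allocation. A minimal TF 1.x sketch of that session
# configuration:
import tensorflow as tf

def make_session():
    gpu_options = tf.GPUOptions(allow_growth=True)
    config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
    return tf.Session(config=config)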