def load_model(model_prefix, word_vocab, batch_size):
    FLAGS = load_namespace(model_prefix + ".config.json")
    label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2')
    num_classes = label_vocab.size()
    best_path = model_prefix + ".best.model"
    with tf.Graph().as_default():
        initializer = tf.contrib.layers.xavier_initializer()
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = HanModelGraph(num_classes=num_classes, word_vocab=word_vocab,
                                        dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                                        lambda_l2=FLAGS.lambda_l2, context_lstm_dim=FLAGS.context_lstm_dim,
                                        is_training=False, batch_size=batch_size)
        vars_ = {}
        print("ValidGraph Build")
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)
        config = tf.ConfigProto(intra_op_parallelism_threads=0,
                                inter_op_parallelism_threads=0,
                                allow_soft_placement=True)
        sess = tf.Session(config=config)
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, best_path)
        return valid_graph, sess, label_vocab, FLAGS
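# --- Hedged usage sketch (not part of the original script) ---
# load_model above returns the inference graph, a session with restored
# weights, the label vocabulary, and the saved FLAGS. The model prefix and
# word-vector path below are hypothetical placeholders.
#
#   word_vocab = Vocab('wordvec.txt', fileformat='txt3')
#   valid_graph, sess, label_vocab, FLAGS = load_model('logs/Han.sample', word_vocab, batch_size=1)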
wo_maxpool_match = False
if hasattr(FLAGS, 'wo_maxpool_match'):
    wo_maxpool_match = FLAGS.wo_maxpool_match
wo_attentive_match = False
if hasattr(FLAGS, 'wo_attentive_match'):
    wo_attentive_match = FLAGS.wo_attentive_match
wo_max_attentive_match = False
if hasattr(FLAGS, 'wo_max_attentive_match'):
    wo_max_attentive_match = FLAGS.wo_max_attentive_match

# load vocabs
print('Loading vocabs.')
word_vocab = Vocab(word_vec_path, fileformat='txt3')
label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2')
print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
num_classes = label_vocab.size()
POS_vocab = None
NER_vocab = None
char_vocab = None
if with_POS:
    POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
if with_NER:
    NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2')
char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
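# --- Hedged sketch, not in the original: the hasattr checks above guard
# against config.json files saved before these options existed. The same
# backward-compatibility pattern can be written with getattr defaults;
# get_option is a hypothetical helper shown only for illustration.
def get_option(flags, name, default=False):
    # Fall back to a default when an older saved config lacks the option.
    return getattr(flags, name, default)

# Example (hypothetical): wo_maxpool_match = get_option(FLAGS, 'wo_maxpool_match')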
parser.add_argument('--model_prefix', type=str, required=True, help='Prefix to the models.')
parser.add_argument('--in_path', type=str, default='../data_quora/dev.tsv', help='the path to the test file.')
parser.add_argument('--out_path', type=str, required=True, help='The path to the output file.')
parser.add_argument('--word_vec_path', type=str, default='../data_quora/wordvec.txt',
                    help='word embedding file for the input file.')
args, unparsed = parser.parse_known_args()

# load the configuration file
print('Loading configurations.')
options = namespace_utils.load_namespace(args.model_prefix + ".config.json")
if args.word_vec_path is None:
    args.word_vec_path = options.word_vec_path

# load vocabs
print('Loading vocabs.')
word_vocab = Vocab(args.word_vec_path, fileformat='txt3')
label_vocab = Vocab(args.model_prefix + ".label_vocab", fileformat='txt2')
print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
num_classes = label_vocab.size()

char_vocab = None
if options.with_char:
    char_vocab = Vocab(args.model_prefix + ".char_vocab", fileformat='txt2')
    print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

print('Build SentenceMatchDataStream ... ')
testDataStream = SentenceMatchDataStream(args.in_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                         label_vocab=label_vocab,
                                         isShuffle=False, isLoop=True, isSort=True, options=options)
print('Number of instances in devDataStream: {}'.format(testDataStream.get_num_instance()))
def main_cv(FLAGS): # np.random.seed(FLAGS.seed) for fold in range(FLAGS.cv_folds): print("Start training fold " + str(fold)) train_path = FLAGS.cv_train_path + str(fold) + '.tsv' train_feat_path = FLAGS.cv_train_feat_path + str(fold) + '.tsv' dev_path = FLAGS.cv_dev_path + str(fold) + '.tsv' dev_feat_path = FLAGS.cv_dev_feat_path + str(fold) + '.tsv' word_vec_path = FLAGS.word_vec_path char_vec_path = FLAGS.char_vec_path log_dir = FLAGS.model_dir + '/cv_fold_' + str(fold) if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix) namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") # build vocabs word_vocab = Vocab(word_vec_path, fileformat='txt3') char_vocab = None best_path = path_prefix + '.best.model' char_path = path_prefix + ".char_vocab" label_path = path_prefix + ".label_vocab" has_pre_trained_model = False if os.path.exists(best_path + ".index"): has_pre_trained_model = True print('Loading vocabs from a pre-trained model ...') label_vocab = Vocab(label_path, fileformat='txt2') if FLAGS.with_char: char_vocab = Vocab(char_path, fileformat='txt2') else: print('Collecting words, chars and labels ...') (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path) print('Number of words: {}'.format(len(all_words))) label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2) label_vocab.dump_to_txt2(label_path) if FLAGS.with_char: print('Number of chars: {}'.format(len(all_chars))) if char_vec_path == "": char_vocab = Vocab(fileformat='voc', voc=all_chars, dim=FLAGS.char_emb_dim) else: char_vocab = Vocab(char_vec_path, fileformat='txt3') char_vocab.dump_to_txt2(char_path) print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape)) if FLAGS.with_char: print('char_vocab shape is {}'.format(char_vocab.word_vecs.shape)) num_classes = label_vocab.size() print("Number of labels: {}".format(num_classes)) sys.stdout.flush() print('Build SentenceMatchDataStream ... 
') trainDataStream = SentenceMatchDataStream(train_path, train_feat_path, word_vocab=word_vocab, char_vocab=char_vocab, label_vocab=label_vocab, isShuffle=True, isLoop=True, isSort=True, options=FLAGS) print('Number of instances in trainDataStream: {}'.format( trainDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format( trainDataStream.get_num_batch())) sys.stdout.flush() devDataStream = SentenceMatchDataStream(dev_path, dev_feat_path, word_vocab=word_vocab, char_vocab=char_vocab, label_vocab=label_vocab, isShuffle=False, isLoop=True, isSort=True, options=FLAGS) print('Number of instances in devDataStream: {}'.format( devDataStream.get_num_instance())) print('Number of batches in devDataStream: {}'.format( devDataStream.get_num_batch())) sys.stdout.flush() init_scale = 0.01 with tf.Graph().as_default(): # tf.set_random_seed(FLAGS.seed) initializer = tf.random_uniform_initializer( -init_scale, init_scale) global_step = tf.train.get_or_create_global_step() with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, is_training=True, options=FLAGS, global_step=global_step) with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, is_training=False, options=FLAGS) initializer = tf.global_variables_initializer() initializer_local = tf.local_variables_initializer() vars_ = {} for var in tf.global_variables(): if "word_embedding" in var.name: continue # if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(initializer, feed_dict={ train_graph.w_embedding: word_vocab.word_vecs, train_graph.c_embedding: char_vocab.word_vecs }) sess.run(initializer_local) if has_pre_trained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") # training train(sess, saver, train_graph, valid_graph, trainDataStream, devDataStream, FLAGS, best_path, label_vocab) print()
def main(_): print('Configurations:') print(FLAGS) log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix) log_file_path = path_prefix + ".log" print('Log file path: {}'.format(log_file_path)) log_file = open(log_file_path, 'wt') log_file.write("{}\n".format(FLAGS)) log_file.flush() # save configuration namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('Loading train set.') trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_file( FLAGS.train_path) print('Number of training samples: {}'.format(len(trainset))) print('Loading dev set.') devset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_file( FLAGS.test_path) print('Number of dev samples: {}'.format(len(devset))) if FLAGS.finetune_path != "": print('Loading finetune set.') ftset, ft_node, ft_in_neigh, ft_out_neigh, ft_sent = G2S_data_stream.read_amr_file( FLAGS.finetune_path) print('Number of finetune samples: {}'.format(len(ftset))) else: ftset, ft_node, ft_in_neigh, ft_out_neigh, ft_sent = (None, 0, 0, 0, 0) max_node = max(trn_node, tst_node, ft_node) max_in_neigh = max(trn_in_neigh, tst_in_neigh, ft_in_neigh) max_out_neigh = max(trn_out_neigh, tst_out_neigh, ft_out_neigh) max_sent = max(trn_sent, tst_sent, ft_sent) print('Max node number: {}, while max allowed is {}'.format( max_node, FLAGS.max_node_num)) print('Max parent number: {}, truncated to {}'.format( max_in_neigh, FLAGS.max_in_neigh_num)) print('Max children number: {}, truncated to {}'.format( max_out_neigh, FLAGS.max_out_neigh_num)) print('Max answer length: {}, truncated to {}'.format( max_sent, FLAGS.max_answer_len)) word_vocab = None char_vocab = None edgelabel_vocab = None has_pretrained_model = False best_path = path_prefix + ".best.model" if os.path.exists(best_path + ".index"): has_pretrained_model = True print('!!Existing pretrained model. Loading vocabs.') word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') print('word_vocab: {}'.format(word_vocab.word_vecs.shape)) char_vocab = None if FLAGS.with_char: char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2') print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2') else: print('Collecting vocabs.') (allWords, allChars, allEdgelabels) = G2S_data_stream.collect_vocabs(trainset) print('Number of words: {}'.format(len(allWords))) print('Number of allChars: {}'.format(len(allChars))) print('Number of allEdgelabels: {}'.format(len(allEdgelabels))) word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') char_vocab = None if FLAGS.with_char: char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build') char_vocab.dump_to_txt2(path_prefix + ".char_vocab") edgelabel_vocab = Vocab(voc=allEdgelabels, dim=FLAGS.edgelabel_dim, fileformat='build') edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab") print('word vocab size {}'.format(word_vocab.vocab_size)) sys.stdout.flush() print('Build DataStream ... 
') trainDataStream = G2S_data_stream.G2SDataStream(trainset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS, isShuffle=True, isLoop=True, isSort=True) devDataStream = G2S_data_stream.G2SDataStream(devset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS, isShuffle=False, isLoop=False, isSort=True) print('Number of instances in trainDataStream: {}'.format( trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format( devDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format( trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format( devDataStream.get_num_batch())) if ftset != None: ftDataStream = G2S_data_stream.G2SDataStream(ftset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS, isShuffle=True, isLoop=True, isSort=True) print('Number of instances in ftDataStream: {}'.format( ftDataStream.get_num_instance())) print('Number of batches in ftDataStream: {}'.format( ftDataStream.get_num_batch())) sys.stdout.flush() # initialize the best bleu and accu scores for current training session best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0 best_bleu = FLAGS.best_bleu if FLAGS.__dict__.has_key('best_bleu') else 0.0 if best_accu > 0.0: print('With initial dev accuracy {}'.format(best_accu)) if best_bleu > 0.0: print('With initial dev BLEU score {}'.format(best_bleu)) init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = ModelGraph(word_vocab=word_vocab, Edgelabel_vocab=edgelabel_vocab, char_vocab=char_vocab, options=FLAGS, mode=FLAGS.mode) assert FLAGS.mode in ('ce_train', 'rl_train', 'transformer') valid_mode = 'evaluate' if FLAGS.mode in ( 'ce_train', 'transformer') else 'evaluate_bleu' with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = ModelGraph(word_vocab=word_vocab, Edgelabel_vocab=edgelabel_vocab, char_vocab=char_vocab, options=FLAGS, mode=valid_mode) initializer = tf.global_variables_initializer() vars_ = {} for var in tf.all_variables(): if FLAGS.fix_word_vec and "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(initializer) if has_pretrained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001: print("Getting BLEU score for the model") sys.stdout.flush() best_bleu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_bleu'] FLAGS.best_bleu = best_bleu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('BLEU = %.4f' % best_bleu) sys.stdout.flush() log_file.write('BLEU = %.4f\n' % best_bleu) if FLAGS.mode in ('ce_train', 'rl_train', 'transformer') and abs(best_accu) < 0.00001: print("Getting ACCU score for the model") best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu'] FLAGS.best_accu = best_accu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('ACCU = %.4f' % best_accu) log_file.write('ACCU = %.4f\n' % best_accu) print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): 
cur_batch = trainDataStream.nextBatch() if FLAGS.mode == 'rl_train': loss_value = train_graph.run_rl_training_subsample( sess, cur_batch, FLAGS) elif FLAGS.mode in ('ce_train', 'rl_train', 'transformer'): loss_value = train_graph.run_ce_training( sess, cur_batch, FLAGS) total_loss += loss_value if step % 100 == 0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or \ (trainDataStream.get_num_batch() > 10000 and (step+1)%2000 == 0): print() duration = time.time() - start_time print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) log_file.flush() sys.stdout.flush() total_loss = 0.0 if ftset != None: best_accu, best_bleu = fine_tune(sess, saver, FLAGS, log_file, ftDataStream, devDataStream, train_graph, valid_graph, path_prefix, best_accu, best_bleu) else: best_accu, best_bleu = validate_and_save( sess, saver, FLAGS, log_file, devDataStream, valid_graph, path_prefix, best_accu, best_bleu) start_time = time.time() log_file.close()
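# --- Hedged sketch, not in the original: the checkpoint/evaluation condition
# used in the loop above, factored into a small predicate purely for
# readability. The function name and its placement here are assumptions.
def should_validate(step, num_batch, max_steps):
    # Validate at every epoch boundary, at the final step, and every 2000
    # steps when a single epoch is very long (more than 10000 batches).
    end_of_epoch = (step + 1) % num_batch == 0
    final_step = (step + 1) == max_steps
    long_epoch_tick = num_batch > 10000 and (step + 1) % 2000 == 0
    return end_of_epoch or final_step or long_epoch_tick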
mode = args.mode

print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])

# load the configuration file
print('Loading configurations from ' + model_prefix + ".config.json")
FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
FLAGS = NP2P_trainer.enrich_options(FLAGS)
if args.beam_size != -1:
    FLAGS.beam_size = args.beam_size

# load vocabs
print('Loading vocabs.')
enc_word_vocab = dec_word_vocab = char_vocab = POS_vocab = NER_vocab = None
if FLAGS.with_word:
    enc_word_vocab = Vocab(FLAGS.enc_word_vec_path, fileformat='txt2')
    print('enc_word_vocab: {}'.format(enc_word_vocab.word_vecs.shape))
    dec_word_vocab = Vocab(FLAGS.dec_word_vec_path, fileformat='txt2')
    print('dec_word_vocab: {}'.format(dec_word_vocab.word_vecs.shape))
if FLAGS.with_char:
    char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
    print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
if FLAGS.with_POS:
    POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
    print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
if FLAGS.with_NER:
    NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2')
    print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))

print('Loading test set.')
if FLAGS.infile_format == 'fof':
model_prefix = args.model_prefix
in_path = args.in_path
cache_size = args.cache_size
use_dep = args.decode

print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])

# load the configuration file
print('Loading configurations from ' + model_prefix + ".config.json")
FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
FLAGS = NP2P_trainer.enrich_options(FLAGS)

# load vocabs
print('Loading vocabs.')
word_vocab = char_vocab = POS_vocab = NER_vocab = None
word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
if FLAGS.with_char:
    char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
    print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
if FLAGS.with_POS:
    POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
    print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
action_vocab = Vocab(model_prefix + ".action_vocab", fileformat='txt2')
print('action_vocab: {}'.format(action_vocab.word_vecs.shape))
feat_vocab = Vocab(model_prefix + ".feat_vocab", fileformat='txt2')
print('feat_vocab: {}'.format(feat_vocab.word_vecs.shape))

print('Loading test set.')
if use_dep:
    testset = NP2P_data_stream.read_Testset(in_path)
def main(_): print('Configurations:') print(FLAGS) log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix) log_file_path = path_prefix + ".log" print('Log file path: {}'.format(log_file_path)) log_file = open(log_file_path, 'wt') log_file.write("{}\n".format(FLAGS)) log_file.flush() # save configuration namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('Loading train set.') if FLAGS.infile_format == 'fof': trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_nary_from_fof( FLAGS.train_path, FLAGS) else: trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_nary_file( FLAGS.train_path, FLAGS) random.shuffle(trainset) devset = trainset[:200] trainset = trainset[200:] print('Number of training samples: {}'.format(len(trainset))) print('Number of dev samples: {}'.format(len(devset))) max_node = trn_node max_in_neigh = trn_in_neigh max_out_neigh = trn_out_neigh max_sent = trn_sent print('Max node number: {}, while max allowed is {}'.format( max_node, FLAGS.max_node_num)) print('Max parent number: {}, truncated to {}'.format( max_in_neigh, FLAGS.max_in_neigh_num)) print('Max children number: {}, truncated to {}'.format( max_out_neigh, FLAGS.max_out_neigh_num)) print('Max entity size: {}, truncated to {}'.format( max_sent, FLAGS.max_entity_size)) word_vocab = None char_vocab = None edgelabel_vocab = None has_pretrained_model = False best_path = path_prefix + ".best.model" if os.path.exists(best_path + ".index"): has_pretrained_model = True print('!!Existing pretrained model. Loading vocabs.') word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') print('word_vocab: {}'.format(word_vocab.word_vecs.shape)) char_vocab = None if FLAGS.with_char: char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2') print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2') else: print('Collecting vocabs.') (allWords, allChars, allEdgelabels) = G2S_data_stream.collect_vocabs(trainset) print('Number of words: {}'.format(len(allWords))) print('Number of allChars: {}'.format(len(allChars))) print('Number of allEdgelabels: {}'.format(len(allEdgelabels))) word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') char_vocab = None if FLAGS.with_char: char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build') char_vocab.dump_to_txt2(path_prefix + ".char_vocab") edgelabel_vocab = Vocab(voc=allEdgelabels, dim=FLAGS.edgelabel_dim, fileformat='build') edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab") print('word vocab size {}'.format(word_vocab.vocab_size)) sys.stdout.flush() print('Build DataStream ... 
') trainDataStream = G2S_data_stream.G2SDataStream(trainset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS, isShuffle=True, isLoop=True, isSort=False) devDataStream = G2S_data_stream.G2SDataStream(devset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS, isShuffle=False, isLoop=False, isSort=False) print('Number of instances in trainDataStream: {}'.format( trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format( devDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format( trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format( devDataStream.get_num_batch())) sys.stdout.flush() # initialize the best bleu and accu scores for current training session best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0 if best_accu > 0.0: print('With initial dev accuracy {}'.format(best_accu)) init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = ModelGraph(word_vocab=word_vocab, Edgelabel_vocab=edgelabel_vocab, char_vocab=char_vocab, options=FLAGS, mode='train') with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = ModelGraph(word_vocab=word_vocab, Edgelabel_vocab=edgelabel_vocab, char_vocab=char_vocab, options=FLAGS, mode='evaluate') initializer = tf.global_variables_initializer() vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(initializer) if has_pretrained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") if abs(best_accu) < 0.00001: print("Getting ACCU score for the model") best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu'] FLAGS.best_accu = best_accu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('ACCU = %.4f' % best_accu) log_file.write('ACCU = %.4f\n' % best_accu) print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs last_step = 0 total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): cur_batch = trainDataStream.nextBatch() _, loss_value, _ = train_graph.execute(sess, cur_batch, FLAGS, is_train=True) total_loss += loss_value if step % 100 == 0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or ( step + 1) == max_steps: print() duration = time.time() - start_time print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss / (step - last_step), duration)) log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss / (step - last_step), duration)) sys.stdout.flush() log_file.flush() last_step = step total_loss = 0.0 # Evaluate against the validation set. 
start_time = time.time() print('Validation Data Eval:') res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS, suffix=str(step)) dev_loss = res_dict['dev_loss'] dev_accu = res_dict['dev_accu'] dev_right = int(res_dict['dev_right']) dev_total = int(res_dict['dev_total']) print('Dev loss = %.4f' % dev_loss) log_file.write('Dev loss = %.4f\n' % dev_loss) print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total)) log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total)) log_file.flush() if best_accu < dev_accu: print('Saving weights, ACCU {} (prev_best) < {} (cur)'. format(best_accu, dev_accu)) saver.save(sess, best_path) best_accu = dev_accu FLAGS.best_accu = dev_accu namespace_utils.save_namespace( FLAGS, path_prefix + ".config.json") json.dump(res_dict['data'], open(FLAGS.output_path, 'w')) duration = time.time() - start_time print('Duration %.3f sec' % (duration)) sys.stdout.flush() log_file.write('Duration %.3f sec\n' % (duration)) log_file.flush() log_file.close()
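# --- Hedged sketch, not in the original: the best-checkpoint bookkeeping used
# above (save only when dev accuracy improves, then record the new best),
# written as a small helper. The helper name is a hypothetical choice.
def maybe_save_best(sess, saver, best_path, best_accu, dev_accu):
    # Persist weights only on improvement and return the updated best value,
    # so the caller can also refresh FLAGS.best_accu and the config on disk.
    if dev_accu > best_accu:
        saver.save(sess, best_path)
        return dev_accu
    return best_accu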
parser.add_argument('--word_vec_path', type=str, help='word embedding file for the input file.')
args, unparsed = parser.parse_known_args()

# load the configuration file
print('Loading configurations.')
options = namespace_utils.load_namespace(args.model_prefix + "ESIM.xnli.config.json")
if args.word_vec_path is None:
    args.word_vec_path = options.word_vec_path

# load vocabs
print('Loading vocabs.')
word_vocab = Vocab(args.word_vec_path, fileformat='txt3')
print('word_vocab: {}'.format(word_vocab.word_vecs.shape))

print('Build DataStream ... ')
testDataStream = DataStream(args.in_path, word_vocab=word_vocab, label_vocab=None,
                            isShuffle=False, isLoop=True, isSort=True, options=options)
print('Number of instances in devDataStream: {}'.format(testDataStream.get_num_instance()))
print('Number of batches in devDataStream: {}'.format(testDataStream.get_num_batch()))
sys.stdout.flush()
def main(_): print('Configurations:') print(FLAGS) train_path = FLAGS.train_path dev_path = FLAGS.dev_path test_path = FLAGS.test_path word_vec_path = FLAGS.word_vec_path log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix) namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") # build vocabs word_vocab = Vocab(word_vec_path, fileformat='txt3') best_path = path_prefix + '.best.model' char_path = path_prefix + ".char_vocab" label_path = path_prefix + ".label_vocab" POS_path = path_prefix + ".POS_vocab" NER_path = path_prefix + ".NER_vocab" has_pre_trained_model = False POS_vocab = None NER_vocab = None if os.path.exists(best_path): has_pre_trained_model = True label_vocab = Vocab(label_path, fileformat='txt2') char_vocab = Vocab(char_path, fileformat='txt2') if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2') if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2') else: print('Collect words, chars and labels ...') (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER) print('Number of words: {}'.format(len(all_words))) print('Number of labels: {}'.format(len(all_labels))) label_vocab = Vocab(fileformat='voc', voc=all_labels,dim=2) label_vocab.dump_to_txt2(label_path) print('Number of chars: {}'.format(len(all_chars))) char_vocab = Vocab(fileformat='voc', voc=all_chars,dim=FLAGS.char_emb_dim) char_vocab.dump_to_txt2(char_path) if FLAGS.with_POS: print('Number of POSs: {}'.format(len(all_POSs))) POS_vocab = Vocab(fileformat='voc', voc=all_POSs,dim=FLAGS.POS_dim) POS_vocab.dump_to_txt2(POS_path) if FLAGS.with_NER: print('Number of NERs: {}'.format(len(all_NERs))) NER_vocab = Vocab(fileformat='voc', voc=all_NERs,dim=FLAGS.NER_dim) NER_vocab.dump_to_txt2(NER_path) print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape)) print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape)) num_classes = label_vocab.size() print('Build SentenceMatchDataStream ... 
') trainDataStream = SentenceMatchDataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length) devDataStream = SentenceMatchDataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length) testDataStream = SentenceMatchDataStream(test_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length) print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance())) print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch())) print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch())) sys.stdout.flush() if FLAGS.wo_char: char_vocab = None best_accuracy = 0.0 init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) # with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num,with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match)) tf.summary.scalar("Training Loss", train_graph.get_loss()) # Add a scalar summary for the snapshot loss. 
# with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match)) initializer = tf.global_variables_initializer() vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue # if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(initializer) if has_pre_trained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): # read data cur_batch = trainDataStream.nextBatch() (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch, char_matrix_idx_1_batch, char_matrix_idx_2_batch, sent1_length_batch, sent2_length_batch, sent1_char_length_batch, sent2_char_length_batch, POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch) = cur_batch feed_dict = { train_graph.get_truth(): label_id_batch, train_graph.get_question_lengths(): sent1_length_batch, train_graph.get_passage_lengths(): sent2_length_batch, train_graph.get_in_question_words(): word_idx_1_batch, train_graph.get_in_passage_words(): word_idx_2_batch, # train_graph.get_question_char_lengths(): sent1_char_length_batch, # train_graph.get_passage_char_lengths(): sent2_char_length_batch, # train_graph.get_in_question_chars(): char_matrix_idx_1_batch, # train_graph.get_in_passage_chars(): char_matrix_idx_2_batch, } if char_vocab is not None: feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch if POS_vocab is not None: feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch if NER_vocab is not None: feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()], feed_dict=feed_dict) 
total_loss += loss_value if step % 100==0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps: print() # Print status to stdout. duration = time.time() - start_time start_time = time.time() print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) total_loss = 0.0 # Evaluate against the validation set. print('Validation Data Eval:') accuracy = evaluate(devDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab) print("Current accuracy is %.2f" % accuracy) if accuracy>best_accuracy: best_accuracy = accuracy saver.save(sess, best_path) print("Best accuracy on dev set is %.2f" % best_accuracy) # decoding print('Decoding on the test set:') init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.variable_scope("Model", reuse=False, initializer=initializer): valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match)) vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(tf.global_variables_initializer()) step = 0 saver.restore(sess, best_path) accuracy = evaluate(testDataStream, valid_graph, sess,char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab) print("Accuracy for test set is %.2f" % accuracy)
def main(_): print('Configurations:') print(FLAGS) train_path = FLAGS.train_path dev_path = FLAGS.dev_path test_path = FLAGS.test_path word_vec_path = FLAGS.word_vec_path log_dir = FLAGS.model_dir result_dir = '../result' if not os.path.exists(log_dir): os.makedirs(log_dir) if not os.path.exists(result_dir): os.makedirs(result_dir) path_prefix = log_dir + "/Han.{}".format(FLAGS.suffix) save_namespace(FLAGS, path_prefix + ".config.json") word_vocab = Vocab(word_vec_path, fileformat='txt3') best_path = path_prefix + '.best.model' label_path = path_prefix + ".label_vocab" print('Collect words and labels ...') (all_words, all_labels) = collect_vocabs(train_path) print('Number of words: {}'.format(len(all_words))) print('Number of labels: {}'.format(len(all_labels))) label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2) label_vocab.dump_to_txt2(label_path) print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape)) print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape)) num_classes = label_vocab.size() print('Build HanDataStream ... ') trainDataStream = HanDataStream(inpath=train_path, word_vocab=word_vocab, label_vocab=label_vocab, isShuffle=True, isLoop=True, max_sent_length=FLAGS.max_sent_length) devDataStream = HanDataStream(inpath=dev_path, word_vocab=word_vocab, label_vocab=label_vocab, isShuffle=False, isLoop=True, max_sent_length=FLAGS.max_sent_length) testDataStream = HanDataStream(inpath=test_path, word_vocab=word_vocab, label_vocab=label_vocab, isShuffle=False, isLoop=True, max_sent_length=FLAGS.max_sent_length) print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance())) print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance())) with tf.Graph().as_default(): initializer = tf.contrib.layers.xavier_initializer() with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = HanModelGraph(num_classes=num_classes, word_vocab=word_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, lambda_l2=FLAGS.lambda_l2, context_lstm_dim=FLAGS.context_lstm_dim, is_training=True, batch_size = FLAGS.batch_size) tf.summary.scalar("Training Loss", train_graph.loss) # Add a scalar summary for the snapshot loss. 
print("Train Graph Build") with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = HanModelGraph(num_classes=num_classes, word_vocab=word_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, lambda_l2=FLAGS.lambda_l2, context_lstm_dim=FLAGS.context_lstm_dim, is_training=False, batch_size = 1) print ("dev Graph Build") initializer = tf.global_variables_initializer() vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue # if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) output_res_file = open(result_dir + '/' + FLAGS.suffix, 'wt') output_res_file.write(str(FLAGS)) with tf.Session() as sess: sess.run(initializer) train_size = trainDataStream.get_num_instance() max_steps = (train_size * FLAGS.max_epochs) // FLAGS.batch_size epoch_size = max_steps // (FLAGS.max_epochs) # + 1 total_loss = 0.0 start_time = time.time() best_accuracy = 0 for step in range(max_steps): # read data # _truth = [] # _sents_length = [] # _in_text_words = [] # for i in range(FLAGS.batch_size): # cur_instance, instance_index = trainDataStream.nextInstance () # (label,text,label_id, word_idx, sents_length) = cur_instance # # _truth.append(label_id) # _sents_length.append(sents_length) # _in_text_words.append(word_idx) # # feed_dict = { # train_graph.truth: np.array(_truth), # train_graph.sents_length: tuple(_sents_length), # train_graph.in_text_words: tuple(_in_text_words), # } feed_dict = get_feed_dict(data_stream=trainDataStream, graph=train_graph, batch_size=FLAGS.batch_size, is_testing=False) _, loss_value, _score = sess.run([train_graph.train_op, train_graph.loss , train_graph.batch_class_scores], feed_dict=feed_dict) total_loss += loss_value if step % 100 == 0: print('{} '.format(step), end="") sys.stdout.flush() if (step + 1) % epoch_size == 0 or (step + 1) == max_steps: # print(total_loss) duration = time.time() - start_time start_time = time.time() print(duration, step, "Loss: ", total_loss) output_res_file.write('\nStep %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) total_loss = 0.0 # Evaluate against the validation set. output_res_file.write('valid- ') dev_accuracy = evaluate(devDataStream, valid_graph, sess) output_res_file.write("%.2f\n" % dev_accuracy) print("Current dev accuracy is %.2f" % dev_accuracy) if dev_accuracy > best_accuracy: best_accuracy = dev_accuracy saver.save(sess, best_path) output_res_file.write('test- ') test_accuracy = evaluate(testDataStream, valid_graph, sess) print("Current test accuracy is %.2f" % test_accuracy) output_res_file.write("%.2f\n" % test_accuracy) output_res_file.close() sys.stdout.flush()
### print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES']) # load the configuration file print('Loading configurations from ' + model_prefix + ".config.json") FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json") FLAGS = NP2P_trainer.enrich_options(FLAGS) os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES']) # load vocabs print('Loading vocabs.') word_vocab = char_vocab = POS_vocab = NER_vocab = None if FLAGS.with_word: word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') print('word_vocab: {}'.format(word_vocab.word_vecs.shape)) if FLAGS.with_char: char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2') print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) if FLAGS.with_POS: POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2') print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape)) if FLAGS.with_NER: NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2') print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape)) if FLAGS.with_template: template_vocab = Vocab(model_prefix + ".template_vocab", fileformat='txt2') print('template_vocab: {}'.format(template_vocab.word_vecs.shape))
out_path = args.out_path

# print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])

# load the configuration file
print('Loading configurations from ' + model_prefix + ".config.json")
FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
FLAGS = MHQA_trainer.enrich_options(FLAGS)
if FLAGS.max_passage_size < 3000:
    FLAGS.max_passage_size = 3000
print('Maximal passage size {}'.format(FLAGS.max_passage_size))

# load vocabs
print('Loading vocabs.')
word_vocab = Vocab(word_vec_path, fileformat='txt2')
print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
char_vocab = None
if FLAGS.with_char:
    char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
    print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

subset_ids = json.load(codecs.open('/u/nalln478/ws/exp.multihop_qa/data.wikihop/distance_subset.json', 'rU', 'utf-8'))
subset_ids = set(subset_ids)
print('Loading test set from {}.'.format(in_path))
testset, test_filtered, _ = MHQA_data_stream.read_data_file(in_path, FLAGS, subset_ids=subset_ids)
print('Number of samples: {}'.format(len(testset)))

print('Build DataStream ... ')
testDataStream = MHQA_data_stream.DataStream(testset, word_vocab, char_vocab, options=FLAGS,
reasonet_lambda = 10
reasonet_terminate_mode = 'original'
reasonet_keep_first = True
reasonet_logit_combine = 'sum'
if reasonet_training:
    reasonet_steps = FLAGS.reasonet_steps
    reasonet_hidden_dim = FLAGS.reasonet_hidden_dim
    reasonet_lambda = FLAGS.reasonet_lambda
    reasonet_terminate_mode = FLAGS.reasonet_terminate_mode
    reasonet_keep_first = FLAGS.reasonet_keep_first
    reasonet_logit_combine = FLAGS.reasonet_logit_combine

# load vocabs
print('Loading vocabs.')
word_vocab = Vocab(word_vec_path, fileformat='txt3', tolower=FLAGS.use_lower_letter)
label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2', tolower=FLAGS.use_lower_letter)
print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
num_classes = label_vocab.size()
POS_vocab = None
NER_vocab = None
char_vocab = None
if with_POS:
    POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2', tolower=FLAGS.use_lower_letter)
def main(_): print('Configurations:') print(FLAGS) log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix) init_model_prefix = FLAGS.init_model # "/u/zhigwang/zhigwang1/sentence_generation/mscoco/logs/NP2P.phrase_ce_train" log_file_path = path_prefix + ".log" print('Log file path: {}'.format(log_file_path)) log_file = open(log_file_path, 'wt') log_file.write("{}\n".format(FLAGS)) log_file.flush() # save configuration namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('Loading train set.') if FLAGS.infile_format == 'fof': trainset, train_ans_len = NP2P_data_stream.read_generation_datasets_from_fof( FLAGS.train_path, isLower=FLAGS.isLower) if FLAGS.max_answer_len > train_ans_len: FLAGS.max_answer_len = train_ans_len else: trainset, train_ans_len = NP2P_data_stream.read_all_GQA_questions( FLAGS.train_path, isLower=FLAGS.isLower) print('Number of training samples: {}'.format(len(trainset))) print('Loading test set.') if FLAGS.infile_format == 'fof': testset, test_ans_len = NP2P_data_stream.read_generation_datasets_from_fof( FLAGS.test_path, isLower=FLAGS.isLower) else: testset, test_ans_len = NP2P_data_stream.read_all_GQA_questions( FLAGS.test_path, isLower=FLAGS.isLower) print('Number of test samples: {}'.format(len(testset))) max_actual_len = max(train_ans_len, test_ans_len) print('Max answer length: {}, truncated to {}'.format( max_actual_len, FLAGS.max_answer_len)) word_vocab = None POS_vocab = None NER_vocab = None char_vocab = None has_pretrained_model = False best_path = path_prefix + ".best.model" if os.path.exists(init_model_prefix + ".best.model.index"): has_pretrained_model = True print('!!Existing pretrained model. Loading vocabs.') if FLAGS.with_word: word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') print('word_vocab: {}'.format(word_vocab.word_vecs.shape)) if FLAGS.with_char: char_vocab = Vocab(init_model_prefix + ".char_vocab", fileformat='txt2') print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) if FLAGS.with_POS: POS_vocab = Vocab(init_model_prefix + ".POS_vocab", fileformat='txt2') print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape)) if FLAGS.with_NER: NER_vocab = Vocab(init_model_prefix + ".NER_vocab", fileformat='txt2') print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape)) else: print('Collecting vocabs.') (allWords, allChars, allPOSs, allNERs) = NP2P_data_stream.collect_vocabs(trainset) print('Number of words: {}'.format(len(allWords))) print('Number of allChars: {}'.format(len(allChars))) print('Number of allPOSs: {}'.format(len(allPOSs))) print('Number of allNERs: {}'.format(len(allNERs))) if FLAGS.with_word: word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') if FLAGS.with_char: char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build') char_vocab.dump_to_txt2(path_prefix + ".char_vocab") if FLAGS.with_POS: POS_vocab = Vocab(voc=allPOSs, dim=FLAGS.POS_dim, fileformat='build') POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab") if FLAGS.with_NER: NER_vocab = Vocab(voc=allNERs, dim=FLAGS.NER_dim, fileformat='build') NER_vocab.dump_to_txt2(path_prefix + ".NER_vocab") print('word vocab size {}'.format(word_vocab.vocab_size)) sys.stdout.flush() print('Build DataStream ... 
') trainDataStream = NP2P_data_stream.QADataStream(trainset, word_vocab, char_vocab, POS_vocab, NER_vocab, options=FLAGS, isShuffle=True, isLoop=True, isSort=True) devDataStream = NP2P_data_stream.QADataStream(testset, word_vocab, char_vocab, POS_vocab, NER_vocab, options=FLAGS, isShuffle=False, isLoop=False, isSort=True) print('Number of instances in trainDataStream: {}'.format( trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format( devDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format( trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format( devDataStream.get_num_batch())) sys.stdout.flush() init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, options=FLAGS, mode="rl_train_for_phrase") with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, options=FLAGS, mode="decode") initializer = tf.global_variables_initializer() vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(initializer) if has_pretrained_model: print("Restoring model from " + init_model_prefix + ".best.model") saver.restore(sess, init_model_prefix + ".best.model") print("DONE!") sys.stdout.flush() # for first-time rl training, we get the current BLEU score print("First-time rl training, get the current BLEU score on dev") sys.stdout.flush() best_bleu = evaluate(sess, valid_graph, devDataStream, word_vocab, options=FLAGS) print('First-time bleu = %.4f' % best_bleu) log_file.write('First-time bleu = %.4f\n' % best_bleu) print('Start the training loop.') sys.stdout.flush() train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): cur_batch = trainDataStream.nextBatch() if FLAGS.with_baseline: # greedy search (greedy_sentences, _, _, _) = NP2P_beam_decoder.search(sess, valid_graph, word_vocab, cur_batch, FLAGS, decode_mode="greedy") if FLAGS.with_target_lattice: (sampled_sentences, sampled_prediction_lengths, sampled_generator_input_idx, sampled_generator_output_idx ) = cur_batch.sample_a_partition() else: # multinomial sampling (sampled_sentences, sampled_prediction_lengths, sampled_generator_input_idx, sampled_generator_output_idx) = NP2P_beam_decoder.search( sess, valid_graph, word_vocab, cur_batch, FLAGS, decode_mode="multinomial") # calculate rewards rewards = [] for i in xrange(cur_batch.batch_size): # print(sampled_sentences[i]) # print(sampled_generator_input_idx[i]) # print(sampled_generator_output_idx[i]) cur_toks = cur_batch.instances[i][1].tokText.split() # r = sentence_bleu([cur_toks], sampled_sentences[i].split(), smoothing_function=cc.method3) r = 1.0 b = 0.0 if FLAGS.with_baseline: b = sentence_bleu([cur_toks], greedy_sentences[i].split(), smoothing_function=cc.method3) # r = metric_utils.evaluate_captions([cur_toks],[sampled_sentences[i]]) # b = metric_utils.evaluate_captions([cur_toks],[greedy_sentences[i]]) 
rewards.append(1.0 * (r - b)) rewards = np.array(rewards, dtype=np.float32) # sys.exit(-1) # update parameters feed_dict = train_graph.run_encoder(sess, cur_batch, FLAGS, only_feed_dict=True) feed_dict[train_graph.reward] = rewards feed_dict[ train_graph.gen_input_words] = sampled_generator_input_idx feed_dict[ train_graph.in_answer_words] = sampled_generator_output_idx feed_dict[train_graph.answer_lengths] = sampled_prediction_lengths (_, loss_value) = sess.run([train_graph.train_op, train_graph.loss], feed_dict) total_loss += loss_value if step % 100 == 0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or ( step + 1) == max_steps: print() duration = time.time() - start_time print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) log_file.flush() sys.stdout.flush() total_loss = 0.0 # Evaluate against the validation set. start_time = time.time() print('Validation Data Eval:') dev_bleu = evaluate(sess, valid_graph, devDataStream, word_vocab, options=FLAGS) print('Dev bleu = %.4f' % dev_bleu) log_file.write('Dev bleu = %.4f\n' % dev_bleu) log_file.flush() if best_bleu < dev_bleu: print('Saving weights, BLEU {} (prev_best) < {} (cur)'. format(best_bleu, dev_bleu)) best_bleu = dev_bleu saver.save(sess, best_path) # TODO: save model duration = time.time() - start_time print('Duration %.3f sec' % (duration)) sys.stdout.flush() log_file.write('Duration %.3f sec\n' % (duration)) log_file.flush() log_file.close()
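# --- Hedged sketch, not in the original: the reward used above is a
# self-critical baseline, r - b, where r scores the sampled sentence and b
# scores the greedy sentence against the reference. The original loop sets
# r = 1.0 with the BLEU line commented out; the sketch below shows the
# BLEU-based variant with NLTK's smoothed sentence_bleu, as an assumption.
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def self_critical_reward(reference_tokens, sampled_tokens, greedy_tokens):
    # Reward the sampled sequence by how much it beats the greedy baseline.
    smoother = SmoothingFunction().method3
    r = sentence_bleu([reference_tokens], sampled_tokens, smoothing_function=smoother)
    b = sentence_bleu([reference_tokens], greedy_tokens, smoothing_function=smoother)
    return r - b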
def main(_): #for x in range (100): # Generate_random_initialization() # print (FLAGS.is_aggregation_lstm, FLAGS.context_lstm_dim, FLAGS.context_layer_num, FLAGS. aggregation_lstm_dim, FLAGS.aggregation_layer_num, FLAGS.max_window_size, FLAGS.MP_dim) print('Configurations:') #print(FLAGS) train_path = FLAGS.train_path dev_path = FLAGS.dev_path test_path = FLAGS.test_path word_vec_path = FLAGS.word_vec_path log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix) namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") # build vocabs word_vocab = Vocab(word_vec_path, fileformat='txt3') best_path = path_prefix + '.best.model' char_path = path_prefix + ".char_vocab" label_path = path_prefix + ".label_vocab" POS_path = path_prefix + ".POS_vocab" NER_path = path_prefix + ".NER_vocab" has_pre_trained_model = False POS_vocab = None NER_vocab = None if os.path.exists(best_path): has_pre_trained_model = True label_vocab = Vocab(label_path, fileformat='txt2') char_vocab = Vocab(char_path, fileformat='txt2') if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2') if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2') else: print('Collect words, chars and labels ...') (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER) print('Number of words: {}'.format(len(all_words))) print('Number of labels: {}'.format(len(all_labels))) label_vocab = Vocab(fileformat='voc', voc=all_labels,dim=2) label_vocab.dump_to_txt2(label_path) print('Number of chars: {}'.format(len(all_chars))) char_vocab = Vocab(fileformat='voc', voc=all_chars,dim=FLAGS.char_emb_dim) char_vocab.dump_to_txt2(char_path) if FLAGS.with_POS: print('Number of POSs: {}'.format(len(all_POSs))) POS_vocab = Vocab(fileformat='voc', voc=all_POSs,dim=FLAGS.POS_dim) POS_vocab.dump_to_txt2(POS_path) if FLAGS.with_NER: print('Number of NERs: {}'.format(len(all_NERs))) NER_vocab = Vocab(fileformat='voc', voc=all_NERs,dim=FLAGS.NER_dim) NER_vocab.dump_to_txt2(NER_path) print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape)) print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape)) num_classes = label_vocab.size() print('Build SentenceMatchDataStream ... 
') trainDataStream = SentenceMatchDataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, is_as=FLAGS.is_answer_selection) devDataStream = SentenceMatchDataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, is_as=FLAGS.is_answer_selection) testDataStream = SentenceMatchDataStream(test_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, is_as=FLAGS.is_answer_selection) print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance())) print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch())) print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch())) sys.stdout.flush() if FLAGS.wo_char: char_vocab = None output_res_index = 1 while True: Generate_random_initialization() st_cuda = '' if FLAGS.is_server == True: st_cuda = str(os.environ['CUDA_VISIBLE_DEVICES']) + '.' 
output_res_file = open('../result/' + st_cuda + str(output_res_index), 'wt') output_res_index += 1 output_res_file.write(str(FLAGS) + '\n\n') stt = str (FLAGS) best_accuracy = 0.0 init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) # with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer, with_input_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), with_bilinear_att=(FLAGS.attention_type) , type1=FLAGS.type1, type2 = FLAGS.type2, type3=FLAGS.type3, with_aggregation_attention=not FLAGS.wo_agg_self_att, is_answer_selection= FLAGS.is_answer_selection, is_shared_attention=FLAGS.is_shared_attention, modify_loss=FLAGS.modify_loss, is_aggregation_lstm=FLAGS.is_aggregation_lstm , max_window_size=FLAGS.max_window_size , prediction_mode=FLAGS.prediction_mode, context_lstm_dropout=not FLAGS.wo_lstm_drop_out, is_aggregation_siamese=FLAGS.is_aggregation_siamese , unstack_cnn=FLAGS.unstack_cnn,with_context_self_attention=FLAGS.with_context_self_attention) tf.summary.scalar("Training Loss", train_graph.get_loss()) # Add a scalar summary for the snapshot loss. 
# with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer, with_input_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), with_bilinear_att=(FLAGS.attention_type) , type1=FLAGS.type1, type2 = FLAGS.type2, type3=FLAGS.type3, with_aggregation_attention=not FLAGS.wo_agg_self_att, is_answer_selection= FLAGS.is_answer_selection, is_shared_attention=FLAGS.is_shared_attention, modify_loss=FLAGS.modify_loss, is_aggregation_lstm=FLAGS.is_aggregation_lstm, max_window_size=FLAGS.max_window_size , prediction_mode=FLAGS.prediction_mode, context_lstm_dropout=not FLAGS.wo_lstm_drop_out, is_aggregation_siamese=FLAGS.is_aggregation_siamese , unstack_cnn=FLAGS.unstack_cnn,with_context_self_attention=FLAGS.with_context_self_attention) initializer = tf.global_variables_initializer() vars_ = {} #for var in tf.all_variables(): for var in tf.global_variables(): if "word_embedding" in var.name: continue # if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) with tf.Session() as sess: sess.run(initializer) if has_pre_trained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): # read data cur_batch, batch_index = trainDataStream.nextBatch() (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch, char_matrix_idx_1_batch, char_matrix_idx_2_batch, sent1_length_batch, sent2_length_batch, sent1_char_length_batch, sent2_char_length_batch, POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch) = cur_batch feed_dict = { train_graph.get_truth(): label_id_batch, train_graph.get_question_lengths(): sent1_length_batch, train_graph.get_passage_lengths(): sent2_length_batch, train_graph.get_in_question_words(): word_idx_1_batch, train_graph.get_in_passage_words(): word_idx_2_batch, # train_graph.get_question_char_lengths(): sent1_char_length_batch, # train_graph.get_passage_char_lengths(): sent2_char_length_batch, # train_graph.get_in_question_chars(): char_matrix_idx_1_batch, # train_graph.get_in_passage_chars(): char_matrix_idx_2_batch, } if char_vocab is not None: 
feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch if POS_vocab is not None: feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch if NER_vocab is not None: feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch if FLAGS.is_answer_selection == True: feed_dict[train_graph.get_question_count()] = trainDataStream.question_count(batch_index) feed_dict[train_graph.get_answer_count()] = trainDataStream.answer_count(batch_index) _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()], feed_dict=feed_dict) total_loss += loss_value if FLAGS.is_answer_selection == True and FLAGS.is_server == False: print ("q: {} a: {} loss_value: {}".format(trainDataStream.question_count(batch_index) ,trainDataStream.answer_count(batch_index), loss_value)) if step % 100==0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps: #print(total_loss) # Print status to stdout. duration = time.time() - start_time start_time = time.time() output_res_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) total_loss = 0.0 #Evaluate against the validation set. output_res_file.write('valid- ') my_map, my_mrr = evaluate(devDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab) output_res_file.write("map: '{}', mrr: '{}'\n".format(my_map, my_mrr)) #print ("dev map: {}".format(my_map)) #print("Current accuracy is %.2f" % accuracy) #accuracy = my_map #if accuracy>best_accuracy: # best_accuracy = accuracy # saver.save(sess, best_path) # Evaluate against the test set. output_res_file.write ('test- ') my_map, my_mrr = evaluate(testDataStream, valid_graph, sess, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab) output_res_file.write("map: '{}', mrr: '{}\n\n".format(my_map, my_mrr)) if FLAGS.is_server == False: print ("test map: {}".format(my_map)) #Evaluate against the train set only for final epoch. 
if (step + 1) == max_steps: output_res_file.write ('train- ') my_map, my_mrr = evaluate(trainDataStream, valid_graph, sess, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab) output_res_file.write("map: '{}', mrr: '{}'\n".format(my_map, my_mrr)) # print("Best accuracy on dev set is %.2f" % best_accuracy) # # decoding # print('Decoding on the test set:') # init_scale = 0.01 # with tf.Graph().as_default(): # initializer = tf.random_uniform_initializer(-init_scale, init_scale) # with tf.variable_scope("Model", reuse=False, initializer=initializer): # valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, # dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, # lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, # aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, # context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, # fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway, # word_level_MP_dim=FLAGS.word_level_MP_dim, # with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, # highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, # lex_decompsition_dim=FLAGS.lex_decompsition_dim, # with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), # with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), # with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), # with_bilinear_att=(not FLAGS.wo_bilinear_att) # , type1=FLAGS.type1, type2 = FLAGS.type2, type3=FLAGS.type3, # with_aggregation_attention=not FLAGS.wo_agg_self_att, # is_answer_selection= FLAGS.is_answer_selection, # is_shared_attention=FLAGS.is_shared_attention, # modify_loss=FLAGS.modify_loss,is_aggregation_lstm=FLAGS.is_aggregation_lstm, # max_window_size=FLAGS.max_window_size, # prediction_mode=FLAGS.prediction_mode, # context_lstm_dropout=not FLAGS.wo_lstm_drop_out, # is_aggregation_siamese=FLAGS.is_aggregation_siamese) # # vars_ = {} # for var in tf.global_variables(): # if "word_embedding" in var.name: continue # if not var.name.startswith("Model"): continue # vars_[var.name.split(":")[0]] = var # saver = tf.train.Saver(vars_) # # sess = tf.Session() # sess.run(tf.global_variables_initializer()) # step = 0 # saver.restore(sess, best_path) # # accuracy, mrr = evaluate(testDataStream, valid_graph, sess,char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab # , mode='trec') # output_res_file.write("map for test set is %.2f\n" % accuracy) output_res_file.close()
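# --- Hedged illustration (not part of the original script) -----------------
# The answer-selection loop above logs "map" and "mrr" values returned by
# evaluate(). As a minimal, self-contained sketch of what those two metrics
# mean, the helper below computes MAP and MRR from per-question lists of
# (score, is_relevant) pairs. The data layout here is an assumption for
# illustration only; it is not the evaluate() implementation used above.

def mean_average_precision_and_mrr(questions):
    """questions: list of lists of (score, is_relevant) tuples, one list per question."""
    ap_sum, rr_sum, counted = 0.0, 0.0, 0
    for candidates in questions:
        ranked = sorted(candidates, key=lambda x: x[0], reverse=True)
        num_rel, precisions, first_rank = 0, [], None
        for rank, (_, rel) in enumerate(ranked, start=1):
            if rel:
                num_rel += 1
                precisions.append(num_rel / float(rank))
                if first_rank is None:
                    first_rank = rank
        if num_rel == 0:
            continue  # questions with no relevant answer are usually skipped
        ap_sum += sum(precisions) / num_rel
        rr_sum += 1.0 / first_rank
        counted += 1
    return ap_sum / counted, rr_sum / counted

# Toy usage: two questions with three candidate answers each.
_map, _mrr = mean_average_precision_and_mrr(
    [[(0.9, 1), (0.4, 0), (0.1, 0)], [(0.8, 0), (0.7, 1), (0.2, 1)]])
print('toy map: {:.3f}, mrr: {:.3f}'.format(_map, _mrr))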
help='The path to the output file.') parser.add_argument('--word_vec_path', type=str, help='word embedding file for the input file.') args, unparsed = parser.parse_known_args() # load the configuration file tf.logging.info('Loading configurations.') options = namespace_utils.load_namespace(args.model_prefix + "KEIM.snli.config.json") if args.word_vec_path is None: args.word_vec_path = options.word_vec_path # load vocabs tf.logging.info('Loading vocabs.') word_vocab = Vocab(args.word_vec_path, fileformat='txt3') tf.logging.info('word_vocab: {}'.format(word_vocab.word_vecs.shape)) lemma_vocab = Vocab(options.lemma_vec_path, fileformat='txt3') tf.logging.info('lemma_vocab: {}'.format(lemma_vocab.word_vecs.shape)) char_vocab = None if options.with_char: char_vocab = Vocab(args.model_prefix + ".char_vocab", fileformat='txt2') tf.logging.info('char_vocab: {}'.format(char_vocab.word_vecs.shape)) tf.logging.info('Build SentenceMatchDataStream ... ') testDataStream = DataStream(args.in_path, word_vocab=word_vocab, char_vocab=char_vocab,
cache_size = args.cache_size use_dep = args.decode oracle.utils.pushidx_feat_num = (1 + args.cache_size) * 5 print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES']) # load the configuration file print('Loading configurations from ' + model_prefix + ".config.json") FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json") FLAGS = NP2P_trainer.enrich_options(FLAGS) # load vocabs print('Loading vocabs.') word_vocab = char_vocab = POS_vocab = NER_vocab = None word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') print('word_vocab: {}'.format(word_vocab.word_vecs.shape)) if FLAGS.with_char: char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2') print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) if FLAGS.with_POS: POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2') print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape)) FLAGS.feat_num = 72 + args.cache_size * 5 action_vocab = Vocab(model_prefix + ".action_vocab", fileformat='txt2') print('action_vocab: {}'.format(action_vocab.word_vecs.shape)) feat_vocab = Vocab(model_prefix + ".feat_vocab", fileformat='txt2') print('feat_vocab: {}'.format(feat_vocab.word_vecs.shape)) print('Loading test set.') if use_dep:
#########################################################################main(FLAGS)
# DONOTCHANGE: Reserved for nsml
train_path = DATASET_PATH
log_dir = config.model_dir
char_vocab = None
# if os.path.exists(best_path + ".index"):
if config.mode == 'train':
    print('Collecting words, chars and labels ...')
    # (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path)
    (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs_kin(train_path)
    print('Number of words: {}'.format(len(all_words)))
    label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
    # label_vocab.dump_to_txt2(label_path)
    word_vocab = Vocab(fileformat='voc', voc=all_words, dim=config.word_emb_dim)
    if config.with_char:
        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc', voc=all_chars, dim=config.char_emb_dim)
        # char_vocab.dump_to_txt2(char_path)
else:
    print('test seq ')
    word_vocab = []
    label_vocab = []
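# --- Hedged illustration (not part of the original script) -----------------
# The branch above builds Vocab objects from the token sets returned by
# collect_vocabs_kin(). As a rough sketch of what such a vocabulary holds
# (a token->id map plus a randomly initialised embedding matrix), a tiny
# stand-in class is shown below. It only mirrors the general idea; the real
# Vocab class in this repo has a different interface and file formats.

import numpy as np

class ToyVocab(object):
    def __init__(self, tokens, dim):
        self.word2id = {tok: i for i, tok in enumerate(sorted(tokens))}
        self.word_vecs = np.random.uniform(
            -0.1, 0.1, size=(len(self.word2id), dim)).astype(np.float32)

    def size(self):
        return len(self.word2id)

    def to_index_sequence(self, tokens):
        # unknown tokens are silently dropped in this toy version
        return [self.word2id[t] for t in tokens if t in self.word2id]

toy = ToyVocab({'good', 'bad', 'movie'}, dim=4)
print('{} {}'.format(toy.size(), toy.to_index_sequence(['good', 'movie', 'unknown'])))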
def main(_): print('Configurations:') print(FLAGS) # 打印各个参数 root_path = FLAGS.root_path train_path = root_path + FLAGS.train_path dev_path = root_path + FLAGS.dev_path test_path = root_path + FLAGS.test_path word_vec_path = root_path + FLAGS.word_vec_path model_dir = root_path + FLAGS.model_dir if tf.gfile.Exists(model_dir + '/mnist_with_summaries'): print("delete summaries") tf.gfile.DeleteRecursively(model_dir + '/mnist_with_summaries') if not os.path.exists(model_dir): os.makedirs(model_dir) path_prefix = model_dir + "/SentenceMatch.{}".format(FLAGS.suffix) namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") # 保存参数 best_path = path_prefix + '.best.model' label_path = path_prefix + ".label_vocab" has_pre_trained_model = False ckpt = tf.train.get_checkpoint_state(model_dir) if ckpt and ckpt.model_checkpoint_path: print("-------has_pre_trained_model--------") print(ckpt.model_checkpoint_path) has_pre_trained_model = True ############# build vocabs################# print('Collect words, chars and labels ...') (all_words, all_labels) = collect_vocabs(train_path) print('Number of words: {}'.format(len(all_words))) print('Number of labels: {}'.format(len(all_labels))) word_vocab = Vocab(pattern='word') # 定义一个类 word_vocab.patternWord(word_vec_path, model_dir) label_vocab = Vocab(pattern="label") label_vocab.patternLabel(all_labels, label_path) print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape)) print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape)) num_classes = len(all_labels) if FLAGS.wo_char: char_vocab = None ##### Build SentenceMatchDataStream ################ print('Build SentenceMatchDataStream ... ') trainDataStream = SentenceMatchDataStream( train_path, word_vocab=word_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=False, max_sent_length=FLAGS.max_sent_length) devDataStream = SentenceMatchDataStream( dev_path, word_vocab=word_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=False, max_sent_length=FLAGS.max_sent_length) testDataStream = SentenceMatchDataStream( test_path, word_vocab=word_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=False, max_sent_length=FLAGS.max_sent_length) print('Number of instances in trainDataStream: {}'.format( trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format( devDataStream.get_num_instance())) print('Number of instances in testDataStream: {}'.format( testDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format( trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format( devDataStream.get_num_batch())) print('Number of batches in testDataStream: {}'.format( testDataStream.get_num_batch())) sys.stdout.flush() best_accuracy = 0.0 init_scale = 0.01 g_2 = tf.Graph() with g_2.as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = SentenceMatchModelGraph( num_classes, word_vocab=word_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, with_word=True, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match)) tf.summary.scalar("Training Loss", train_graph.get_loss()) with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = SentenceMatchModelGraph( num_classes, word_vocab=word_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, with_word=True, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match)) initializer = tf.global_variables_initializer() saver = tf.train.Saver() # vars_ = {} # for var in tf.global_variables(): # if "word_embedding" in var.name: continue # vars_[var.name.split(":")[0]] = var # saver = tf.train.Saver(vars_) sess = tf.Session() merged = tf.summary.merge_all() train_writer = tf.summary.FileWriter( model_dir + '/mnist_with_summaries/train', sess.graph) sess.run(initializer) if has_pre_trained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() for step in iter(range(max_steps)): cur_batch = trainDataStream.nextBatch() (label_id_batch, word_idx_1_batch, word_idx_2_batch, sent1_length_batch, sent2_length_batch) = cur_batch feed_dict = { train_graph.get_truth(): label_id_batch, train_graph.get_question_lengths(): sent1_length_batch, train_graph.get_passage_lengths(): sent2_length_batch, train_graph.get_in_question_words(): word_idx_1_batch, train_graph.get_in_passage_words(): word_idx_2_batch, } # in_question_repres,in_ques=sess.run([train_graph.in_question_repres,train_graph.in_ques],feed_dict=feed_dict) # print(in_question_repres,in_ques) # break _, loss_value, summary = sess.run( [train_graph.get_train_op(), train_graph.get_loss(), merged], feed_dict=feed_dict) total_loss += loss_value if step % 5000 == 0: # train_writer.add_summary(summary, step) print("step:", step, "loss:", loss_value) if (step + 1) % trainDataStream.get_num_batch() == 0 or ( step + 1) == max_steps: print() duration = time.time() - start_time start_time 
= time.time()
print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
total_loss = 0.0
print('Validation Data Eval:')
accuracy = evaluate(devDataStream, valid_graph, sess)
print("Current accuracy is %.2f" % accuracy)
if accuracy >= best_accuracy:
    print('Saving model since it\'s the best so far')
    best_accuracy = accuracy
    saver.save(sess, best_path)
sys.stdout.flush()
print("Best accuracy on dev set is %.2f" % best_accuracy)
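# --- Hedged illustration (not part of the original script) -----------------
# The loop above evaluates on dev (and possibly saves a checkpoint) whenever
# (step + 1) % num_batches == 0, i.e. once per pass over the training data,
# plus once more at the very last step. The tiny dry run below just prints
# which steps would trigger an evaluation under that rule; the batch and
# epoch counts are made-up numbers for illustration.

num_batches = 5       # assumed batches per epoch
max_epochs = 3        # assumed number of epochs
max_steps_demo = num_batches * max_epochs
eval_steps = [step for step in range(max_steps_demo)
              if (step + 1) % num_batches == 0 or (step + 1) == max_steps_demo]
print(eval_steps)     # -> [4, 9, 14]: the last batch of each epoch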
def main(_): log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix) log_file_path = path_prefix + ".log" print('Log file path: {}'.format(log_file_path)) log_file = open(log_file_path, 'wt') log_file.write("{}\n".format(FLAGS)) log_file.flush() # save configuration namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('Loading data.') FLAGS.num_relations = 2 trainset = G2S_data_stream.read_bionlp_file(FLAGS.train_path, FLAGS.train_dep_path, FLAGS) if FLAGS.dev_gen == 'shuffle': random.shuffle(trainset) elif FLAGS.dev_gen == 'last': trainset.reverse() N = int(len(trainset)*FLAGS.dev_percent) devset = trainset[:N] trainset = trainset[N:] print('Number of training samples: {}'.format(len(trainset))) print('Number of dev samples: {}'.format(len(devset))) print('Number of relations: {}'.format(FLAGS.num_relations)) word_vocab = None char_vocab = None POS_vocab = None edgelabel_vocab = None has_pretrained_model = False best_path = path_prefix + ".best.model" if os.path.exists(best_path + ".index"): has_pretrained_model = True print('!!Existing pretrained model. Loading vocabs.') word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') print('word_vocab: {}'.format(word_vocab.word_vecs.shape)) if FLAGS.with_char: char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2') print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) if FLAGS.with_POS: POS_vocab = Vocab(path_prefix + ".POS_vocab", fileformat='txt2') print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape)) edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2') print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape)) else: print('Collecting vocabs.') all_words = set() all_chars = set() all_poses = set() all_edgelabels = set() G2S_data_stream.collect_vocabs(trainset, all_words, all_chars, all_poses, all_edgelabels) G2S_data_stream.collect_vocabs(devset, all_words, all_chars, all_poses, all_edgelabels) print('Number of words: {}'.format(len(all_words))) print('Number of chars: {}'.format(len(all_chars))) print('Number of poses: {}'.format(len(all_poses))) print('Number of edgelabels: {}'.format(len(all_edgelabels))) word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') if FLAGS.with_char: char_vocab = Vocab(voc=all_chars, dim=FLAGS.char_dim, fileformat='build') char_vocab.dump_to_txt2(path_prefix + ".char_vocab") if FLAGS.with_POS: POS_vocab = Vocab(voc=all_poses, dim=FLAGS.POS_dim, fileformat='build') POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab") edgelabel_vocab = Vocab(voc=all_edgelabels, dim=FLAGS.edgelabel_dim, fileformat='build') edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab") print('word vocab size {}'.format(word_vocab.vocab_size)) sys.stdout.flush() print('Build DataStream ... 
') trainDataStream = G2S_data_stream.G2SDataStream(FLAGS, trainset, word_vocab, char_vocab, POS_vocab, edgelabel_vocab, isShuffle=True, isLoop=True, isSort=True, is_training=True) devDataStream = G2S_data_stream.G2SDataStream(FLAGS, devset, word_vocab, char_vocab, POS_vocab, edgelabel_vocab, isShuffle=False, isLoop=False, isSort=True) print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch())) sys.stdout.flush() FLAGS.trn_bch_num = trainDataStream.get_num_batch() # initialize the best bleu and accu scores for current training session best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0 if best_accu > 0.0: print('With initial dev accuracy {}'.format(best_accu)) init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = ModelGraph(word_vocab, char_vocab, POS_vocab, edgelabel_vocab, FLAGS, mode='train') with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = ModelGraph(word_vocab, char_vocab, POS_vocab, edgelabel_vocab, FLAGS, mode='evaluate') initializer = tf.global_variables_initializer() vars_ = {} for var in tf.all_variables(): if FLAGS.fix_word_vec and "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(initializer) if has_pretrained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") if abs(best_accu) < 1e-5: print("Getting ACCU score for the model") best_accu = evaluate(sess, valid_graph, devDataStream, FLAGS)['dev_f1'] FLAGS.best_accu = best_accu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('ACCU = %.4f' % best_accu) log_file.write('ACCU = %.4f\n' % best_accu) print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs last_step = 0 total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): cur_batch = trainDataStream.nextBatch() _, _, cur_loss, _ = train_graph.execute(sess, cur_batch, FLAGS, is_train=True) total_loss += cur_loss if step % 100==0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps: print() duration = time.time() - start_time print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss/(step-last_step), duration)) log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss/(step-last_step), duration)) sys.stdout.flush() log_file.flush() last_step = step total_loss = 0.0 # Evaluate against the validation set. 
start_time = time.time() print('Validation Data Eval:') res_dict = evaluate(sess, valid_graph, devDataStream, FLAGS) dev_loss = res_dict['dev_loss'] dev_accu = res_dict['dev_f1'] dev_precision = res_dict['dev_precision'] dev_recall = res_dict['dev_recall'] print('Dev loss = %.4f' % dev_loss) log_file.write('Dev loss = %.4f\n' % dev_loss) print('Dev F1 = %.4f, P = %.4f, R = %.4f' % (dev_accu, dev_precision, dev_recall)) log_file.write('Dev F1 = %.4f, P = %.4f, R = %.4f\n' % (dev_accu, dev_precision, dev_recall)) log_file.flush() if best_accu < dev_accu: print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu)) saver.save(sess, best_path) best_accu = dev_accu FLAGS.best_accu = dev_accu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") duration = time.time() - start_time print('Duration %.3f sec' % (duration)) sys.stdout.flush() log_file.write('Duration %.3f sec\n' % (duration)) log_file.flush() start_time = time.time() log_file.close()
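# --- Hedged illustration (not part of the original script) -----------------
# Whenever a new best dev F1 is found above, the script stores the score in
# FLAGS.best_accu and re-serialises FLAGS next to the checkpoint with
# namespace_utils.save_namespace(...). A minimal stand-in for that save/load
# round trip, using only the standard library, is sketched below; it is not
# the namespace_utils module used in this repo, and the file path is a demo
# path.

import json
from argparse import Namespace

def demo_save_namespace(ns, path):
    with open(path, 'w') as f:
        json.dump(vars(ns), f, indent=2, sort_keys=True)

def demo_load_namespace(path):
    with open(path) as f:
        return Namespace(**json.load(f))

flags = Namespace(learning_rate=0.001, best_accu=0.0)
flags.best_accu = 0.8123                       # pretend this is the new best dev score
demo_save_namespace(flags, '/tmp/G2S.demo.config.json')
print(demo_load_namespace('/tmp/G2S.demo.config.json').best_accu)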
def main(_): print('Configurations:') print(FLAGS) train_path = FLAGS.train_path dev_path = FLAGS.dev_path test_path = FLAGS.test_path word_vec_path = FLAGS.word_vec_path log_dir = FLAGS.model_dir if FLAGS.train == "sick": train_path = FLAGS.SICK_train_path dev_path = FLAGS.SICK_dev_path if FLAGS.test == "sick": test_path = FLAGS.SICK_test_path if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix) namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") # build vocabs parser,image_feats = None, None if FLAGS.with_dep: parser=Parser('snli') if FLAGS.with_image: image_feats=ImageFeatures() word_vocab = Vocab(word_vec_path, fileformat='txt3', parser=parser, beginning=FLAGS.beginning) #fileformat='txt3' best_path = path_prefix + '.best.model' char_path = path_prefix + ".char_vocab" label_path = path_prefix + ".label_vocab" POS_path = path_prefix + ".POS_vocab" NER_path = path_prefix + ".NER_vocab" DEP_path = path_prefix + ".DEP_vocab" has_pre_trained_model = False POS_vocab = None NER_vocab = None DEP_vocab = None print('has pretrained model: ', os.path.exists(best_path)) print('best_path: ' + best_path) if os.path.exists(best_path + '.meta'): has_pre_trained_model = True label_vocab = Vocab(label_path, fileformat='txt2') char_vocab = Vocab(char_path, fileformat='txt2') if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2') if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2') else: print('Collect words, chars and labels ...') (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER) print('Number of words: {}'.format(len(all_words))) print('Number of labels: {}'.format(len(all_labels))) label_vocab = Vocab(fileformat='voc', voc=all_labels,dim=2) label_vocab.dump_to_txt2(label_path) print('Number of chars: {}'.format(len(all_chars))) char_vocab = Vocab(fileformat='voc', voc=all_chars,dim=FLAGS.char_emb_dim, beginning=FLAGS.beginning) char_vocab.dump_to_txt2(char_path) if FLAGS.with_POS: print('Number of POSs: {}'.format(len(all_POSs))) POS_vocab = Vocab(fileformat='voc', voc=all_POSs,dim=FLAGS.POS_dim) POS_vocab.dump_to_txt2(POS_path) if FLAGS.with_NER: print('Number of NERs: {}'.format(len(all_NERs))) NER_vocab = Vocab(fileformat='voc', voc=all_NERs,dim=FLAGS.NER_dim) NER_vocab.dump_to_txt2(NER_path) print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape)) print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape)) num_classes = label_vocab.size() print('Build DataStream ... 
') print('Reading trainDataStream') if not FLAGS.decoding_only: trainDataStream = DataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats, sick_data=(FLAGS.test == "sick")) print('Reading devDataStream') devDataStream = DataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats, sick_data=(FLAGS.test == "sick")) print('Reading testDataStream') testDataStream = DataStream(test_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats, sick_data=(FLAGS.test == "sick")) print('save cache file') #word_vocab.parser.save_cache() #image_feats.save_feat() if not FLAGS.decoding_only: print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch())) print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance())) print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch())) sys.stdout.flush() if FLAGS.wo_char: char_vocab = None best_accuracy = 0.0 init_scale = 0.01 if not FLAGS.decoding_only: with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = ModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num,with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), with_dep=FLAGS.with_dep, 
with_image=FLAGS.with_image, image_with_hypothesis_only=FLAGS.image_with_hypothesis_only, with_img_full_match=FLAGS.with_img_full_match, with_img_maxpool_match=FLAGS.with_img_full_match, with_img_attentive_match=FLAGS.with_img_attentive_match, image_context_layer=FLAGS.image_context_layer, with_img_max_attentive_match=FLAGS.with_img_max_attentive_match, img_dim=FLAGS.img_dim) tf.summary.scalar("Training Loss", train_graph.get_loss()) # Add a scalar summary for the snapshot loss. with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = ModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_with_hypothesis_only=FLAGS.image_with_hypothesis_only, with_img_attentive_match=FLAGS.with_img_attentive_match, with_img_full_match=FLAGS.with_img_full_match, with_img_maxpool_match=FLAGS.with_img_full_match, image_context_layer=FLAGS.image_context_layer, with_img_max_attentive_match=FLAGS.with_img_max_attentive_match, img_dim=FLAGS.img_dim) initializer = tf.global_variables_initializer() vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(initializer) #, feed_dict={valid_graph.emb_init: word_vocab.word_vecs, train_graph.emb_init: word_vocab.word_vecs}) if has_pre_trained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") #if best_path.startswith('bimpm_baseline'): #best_path = best_path + '_sick' print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): # read data cur_batch = trainDataStream.nextBatch() (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch, char_matrix_idx_1_batch, char_matrix_idx_2_batch, sent1_length_batch, sent2_length_batch, sent1_char_length_batch, sent2_char_length_batch, POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch, dependency1_batch, dependency2_batch, dep_con1_batch, dep_con2_batch, img_feats_batch, img_id_batch) = cur_batch feed_dict = { train_graph.get_truth(): label_id_batch, train_graph.get_question_lengths(): sent1_length_batch, train_graph.get_passage_lengths(): sent2_length_batch, 
train_graph.get_in_question_words(): word_idx_1_batch, train_graph.get_in_passage_words(): word_idx_2_batch, #train_graph.get_emb_init(): word_vocab.word_vecs, #train_graph.get_in_question_dependency(): dependency1_batch, #train_graph.get_in_passage_dependency(): dependency2_batch, # train_graph.get_question_char_lengths(): sent1_char_length_batch, # train_graph.get_passage_char_lengths(): sent2_char_length_batch, # train_graph.get_in_question_chars(): char_matrix_idx_1_batch, # train_graph.get_in_passage_chars(): char_matrix_idx_2_batch, } if FLAGS.with_dep: feed_dict[train_graph.get_in_question_dependency()] = dependency1_batch feed_dict[train_graph.get_in_passage_dependency()] = dependency2_batch feed_dict[train_graph.get_in_question_dep_con()] = dep_con1_batch feed_dict[train_graph.get_in_passage_dep_con()] = dep_con2_batch if FLAGS.with_image: feed_dict[train_graph.get_image_feats()] = img_feats_batch if char_vocab is not None: feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch if POS_vocab is not None: feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch if NER_vocab is not None: feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()], feed_dict=feed_dict) total_loss += loss_value if step % 100==0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps: print() # Print status to stdout. duration = time.time() - start_time start_time = time.time() print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) total_loss = 0.0 # Evaluate against the validation set. 
print('Validation Data Eval:') accuracy = evaluate(devDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, word_vocab=word_vocab) print("Current accuracy on dev is %.2f" % accuracy) #accuracy_train = evaluate(trainDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab) #print("Current accuracy on train is %.2f" % accuracy_train) if accuracy>best_accuracy: best_accuracy = accuracy saver.save(sess, best_path) print("Best accuracy on dev set is %.2f" % best_accuracy) # decoding print('Decoding on the test set:') init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.variable_scope("Model", reuse=False, initializer=initializer): valid_graph = ModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, lex_decompsition_dim=FLAGS.lex_decompsition_dim, with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match), with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_with_hypothesis_only=FLAGS.image_with_hypothesis_only, with_img_attentive_match=FLAGS.with_img_attentive_match, with_img_full_match=FLAGS.with_img_full_match, with_img_maxpool_match=FLAGS.with_img_full_match, image_context_layer=FLAGS.image_context_layer, with_img_max_attentive_match=FLAGS.with_img_max_attentive_match, img_dim=FLAGS.img_dim) vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(tf.global_variables_initializer())#, feed_dict={valid_graph.emb_init: word_vocab.word_vecs}) step = 0 saver.restore(sess, best_path) accuracy = evaluate(testDataStream, valid_graph, sess, outpath=FLAGS.suffix+ FLAGS.train + FLAGS.test + ".result",char_vocab=char_vocab,label_vocab=label_vocab, word_vocab=word_vocab) print("Accuracy for test set is %.2f" % accuracy) accuracy_train = evaluate(trainDataStream, valid_graph, sess,char_vocab=char_vocab,word_vocab=word_vocab) print("Accuracy for train set is %.2f" % accuracy_train)
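# --- Hedged illustration (not part of the original script) -----------------
# Both the training and decoding halves above build the Saver from a filtered
# dict so that the (large, possibly fixed) word-embedding variable is kept
# out of save/restore. The snippet below isolates that pattern with two toy
# variables; the variable names and shapes are invented for illustration and
# assume the TF 1.x API used throughout this file.

import tensorflow as tf

with tf.Graph().as_default():
    with tf.variable_scope("Model"):
        word_embedding = tf.get_variable("word_embedding", shape=[10, 3])
        context_w = tf.get_variable("context_w", shape=[3, 3])

    vars_ = {}
    for var in tf.global_variables():
        if "word_embedding" in var.name:
            continue                       # keep embeddings out of the checkpoint
        vars_[var.name.split(":")[0]] = var

    saver = tf.train.Saver(vars_)          # only saves/restores Model/context_w
    print(sorted(vars_.keys()))            # -> ['Model/context_w']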
self.options.max_src_len) self.sent_inp = padding_utils.pad_2d_vals(ori_batch.sent_inp, len(ori_batch.sent_inp), self.options.max_answer_len) self.sent_out = padding_utils.pad_2d_vals(ori_batch.sent_out, len(ori_batch.sent_out), self.options.max_answer_len) if __name__ == "__main__": FLAGS = namespace_utils.load_namespace('../config.json') print('Collecting vocab') allEdgelabels = set([line.strip().split()[0] \ for line in open('../data/edgelabel_vocab.en', 'rU')]) edgelabel_vocab = Vocab(voc=allEdgelabels, dim=FLAGS.edgelabel_dim, fileformat='build') word_vocab_enc = Vocab('../data/vectors.en.st', fileformat='txt2') word_vocab_dec = Vocab('../data/vectors.de.st', fileformat='txt2') print('Loading trainset') trainset, _, _, _, _ = read_amr_file('../data/newstest2013.tok.json', FLAGS, word_vocab_enc, word_vocab_dec, None, edgelabel_vocab) print('Build DataStream ... ') trainDataStream = G2SDataStream(trainset, word_vocab_enc, word_vocab_dec, None, edgelabel_vocab, options=FLAGS, isShuffle=True,
model_prefix = args.model_prefix
in_path = args.in_path
out_path = args.out_path
mode = args.mode
print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])

# load the configuration file
print('Loading configurations from ' + model_prefix + ".config.json")
FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
FLAGS = G2S_trainer.enrich_options(FLAGS)

# load vocabs
print('Loading vocabs.')
word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
edgelabel_vocab = Vocab(model_prefix + ".edgelabel_vocab", fileformat='txt2')
print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape))
char_vocab = None
if FLAGS.with_char:
    char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
    print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

print('Loading test set from {}.'.format(in_path))
testset, _, _, _, _ = G2S_data_stream.read_amr_file(in_path)
print('Number of samples: {}'.format(len(testset)))

print('Build DataStream ... ')
batch_size = -1
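# --- Hedged illustration (not part of the original script) -----------------
# The decode path above loads several "<prefix>.*_vocab" files with
# fileformat='txt2'. As a rough sketch, such files can be read as
# "token v1 v2 ... vD" lines into an id map plus an embedding matrix, as
# below. The exact on-disk layout of txt2 files is an assumption here; only
# the general idea (one token and one vector per line) is illustrated.

import numpy as np

def load_text_vectors(path):
    tokens, vectors = [], []
    with open(path) as f:
        for line in f:
            parts = line.rstrip('\n').split()
            if len(parts) < 2:
                continue                    # skip headers / blank lines
            tokens.append(parts[0])
            vectors.append([float(x) for x in parts[1:]])
    word2id = {tok: i for i, tok in enumerate(tokens)}
    return word2id, np.asarray(vectors, dtype=np.float32)

# usage (assuming a small demo vocab file exists):
# word2id, word_vecs = load_text_vectors('demo_vocab.txt')
# print(word_vecs.shape)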
model_prefix = args.model_prefix in_path = args.in_path cache_size = args.cache_size use_dep = args.decode print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES']) # load the configuration file print('Loading configurations from ' + model_prefix + ".config.json") FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json") FLAGS = NP2P_trainer.enrich_options(FLAGS) # load vocabs print('Loading vocabs.') word_vocab = char_vocab = POS_vocab = NER_vocab = None word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') print('word_vocab: {}'.format(word_vocab.word_vecs.shape)) if FLAGS.with_char: char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2') print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) if FLAGS.with_POS: POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2') print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape)) action_vocab = Vocab(model_prefix + ".action_vocab", fileformat='txt2') print('action_vocab: {}'.format(action_vocab.word_vecs.shape)) feat_vocab = Vocab(model_prefix + ".feat_vocab", fileformat='txt2') print('feat_vocab: {}'.format(feat_vocab.word_vecs.shape)) print('Loading test set.') if use_dep: testset = NP2P_data_stream.read_Testset(in_path, ulfdep=args.ulf)
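# --- Hedged illustration (not part of the original script) -----------------
# enrich_options(FLAGS) above back-fills options that did not exist when an
# older config.json was written, so that later code can read them without
# hasattr checks. A toy version of that idea is sketched below; the default
# names and values are invented and are not the defaults used by
# NP2P_trainer.

from argparse import Namespace

def toy_enrich_options(flags, defaults=None):
    defaults = defaults or {'with_char': False, 'with_POS': False, 'beam_size': 5}
    for key, value in defaults.items():
        if not hasattr(flags, key):
            setattr(flags, key, value)      # only fill in what is missing
    return flags

cfg = toy_enrich_options(Namespace(word_vec_path='wordvec.txt'))
print('{} {}'.format(cfg.with_char, cfg.beam_size))   # -> False 5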
config_FLAGS.__dict__["in_format"] = 'tsv' word_vec_path = config_FLAGS.word_vec_path log_dir = config_FLAGS.model_dir path_prefix = os.path.join(log_dir, "SentenceMatch.{}".format(config_FLAGS.suffix)) ent_word_vocab = EntVocab(word_vec_path, fileformat='txt3') print("word_vocab shape is {}".format(ent_word_vocab.word_vecs.shape)) best_path = path_prefix + ".best.model" label_path = path_prefix + ".label_vocab" print("best_path: {}".format(best_path)) if os.path.exists(best_path + ".index"): print("Loading label vocab") label_vocab = EntVocab(label_path, fileformat='txt2') else: raise Exception("no pretrained model") num_classes = label_vocab.size() print("Number of labels: {}".format(num_classes)) global_step = tf.train.get_global_step() # define entailment model config_FLAGS = namespace_utils.load_namespace(config_path) entailment_model = SentenceMatchModelGraph(3, word_vocab=ent_word_vocab, is_training=True, options=config_FLAGS, global_step=global_step,
def main(_): print('Configurations:') print(FLAGS) train_path = FLAGS.train_path dev_path = FLAGS.dev_path test_path = FLAGS.test_path word_vec_path = FLAGS.word_vec_path log_dir = FLAGS.model_dir tolower = FLAGS.use_lower_letter FLAGS.rl_matches = json.loads(FLAGS.rl_matches) # if not os.path.exists(log_dir): # os.makedirs(log_dir) path_prefix = log_dir + "/TriMatch.{}".format(FLAGS.suffix) namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") # build vocabs word_vocab = Vocab(word_vec_path, fileformat='txt3', tolower=tolower) best_path = path_prefix + '.best.model' char_path = path_prefix + ".char_vocab" label_path = path_prefix + ".label_vocab" POS_path = path_prefix + ".POS_vocab" NER_path = path_prefix + ".NER_vocab" has_pre_trained_model = False POS_vocab = None NER_vocab = None print('best path:', best_path) if os.path.exists(best_path + '.data-00000-of-00001') and not (FLAGS.create_new_model): print('Using pretrained model') has_pre_trained_model = True label_vocab = Vocab(label_path, fileformat='txt2', tolower=tolower) char_vocab = Vocab(char_path, fileformat='txt2', tolower=tolower) if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2', tolower=tolower) if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2', tolower=tolower) else: print('Creating new model') print('Collect words, chars and labels ...') (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER, tolower=tolower) if FLAGS.use_options: all_labels = ['0', '1'] print('Number of words: {}'.format(len(all_words))) print('Number of labels: {}'.format(len(all_labels))) # for word in all_labels: # print('label',word) # input('check') label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2, tolower=tolower) label_vocab.dump_to_txt2(label_path) print('Number of chars: {}'.format(len(all_chars))) char_vocab = Vocab(fileformat='voc', voc=all_chars, dim=FLAGS.char_emb_dim, tolower=tolower) char_vocab.dump_to_txt2(char_path) if FLAGS.with_POS: print('Number of POSs: {}'.format(len(all_POSs))) POS_vocab = Vocab(fileformat='voc', voc=all_POSs, dim=FLAGS.POS_dim, tolower=tolower) POS_vocab.dump_to_txt2(POS_path) if FLAGS.with_NER: print('Number of NERs: {}'.format(len(all_NERs))) NER_vocab = Vocab(fileformat='voc', voc=all_NERs, dim=FLAGS.NER_dim, tolower=tolower) NER_vocab.dump_to_txt2(NER_path) print('all_labels:', label_vocab) print('has pretrained model:', has_pre_trained_model) # for word in word_vocab.word_vecs: print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape)) print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape)) num_classes = label_vocab.size() print('Build TriMatchDataStream ... 
') gen_concat_mat = False gen_split_mat = False if FLAGS.matching_option == 7: gen_concat_mat = True if FLAGS.concat_context: gen_split_mat = True trainDataStream = TriMatchDataStream( train_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=(not FLAGS.wo_sort_instance_based_on_length), max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, max_hyp_length=FLAGS.max_hyp_length, max_choice_length=FLAGS.max_choice_length, tolower=tolower, gen_concat_mat=gen_concat_mat, gen_split_mat=gen_split_mat) devDataStream = TriMatchDataStream( dev_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=(not FLAGS.wo_sort_instance_based_on_length), max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, max_hyp_length=FLAGS.max_hyp_length, max_choice_length=FLAGS.max_choice_length, tolower=tolower, gen_concat_mat=gen_concat_mat, gen_split_mat=gen_split_mat) testDataStream = TriMatchDataStream( test_path, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=(not FLAGS.wo_sort_instance_based_on_length), max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, max_hyp_length=FLAGS.max_hyp_length, max_choice_length=FLAGS.max_choice_length, tolower=tolower, gen_concat_mat=gen_concat_mat, gen_split_mat=gen_split_mat) print('Number of instances in trainDataStream: {}'.format( trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format( devDataStream.get_num_instance())) print('Number of instances in testDataStream: {}'.format( testDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format( trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format( devDataStream.get_num_batch())) print('Number of batches in testDataStream: {}'.format( testDataStream.get_num_batch())) sys.stdout.flush() if FLAGS.wo_char: char_vocab = None best_accuracy = 0.0 init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) # with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = TriMatchModelGraph( num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, match_to_question=FLAGS.match_to_question, match_to_passage=FLAGS.match_to_passage, match_to_choice=FLAGS.match_to_choice, with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not 
FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), use_options=FLAGS.use_options, num_options=num_options, with_no_match=FLAGS.with_no_match, verbose=FLAGS.verbose, matching_option=FLAGS.matching_option, concat_context=FLAGS.concat_context, tied_aggre=FLAGS.tied_aggre, rl_training_method=FLAGS.rl_training_method, rl_matches=FLAGS.rl_matches) tf.summary.scalar("Training Loss", train_graph.get_loss() ) # Add a scalar summary for the snapshot loss. if FLAGS.verbose: valid_graph = train_graph else: # with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = TriMatchModelGraph( num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, match_to_question=FLAGS.match_to_question, match_to_passage=FLAGS.match_to_passage, match_to_choice=FLAGS.match_to_choice, with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=( not FLAGS.wo_max_attentive_match), use_options=FLAGS.use_options, num_options=num_options, with_no_match=FLAGS.with_no_match, matching_option=FLAGS.matching_option, concat_context=FLAGS.concat_context, tied_aggre=FLAGS.tied_aggre, rl_training_method=FLAGS.rl_training_method, rl_matches=FLAGS.rl_matches) initializer = tf.global_variables_initializer() vars_ = {} for var in tf.global_variables(): # print(var.name,var.get_shape().as_list()) if "word_embedding" in var.name: continue # if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) # input('check') config = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.Session(config=config) sess.run(initializer) if has_pre_trained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() sub_loss_counter = 0.0 for step in range(max_steps): # read data cur_batch = trainDataStream.nextBatch() (label_batch, sent1_batch, sent2_batch, sent3_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch, word_idx_3_batch, char_matrix_idx_1_batch, char_matrix_idx_2_batch, char_matrix_idx_3_batch, sent1_length_batch, sent2_length_batch, sent3_length_batch, sent1_char_length_batch, sent2_char_length_batch, sent3_char_length_batch, POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch, concat_mat_batch, split_mat_batch_q, split_mat_batch_c) = cur_batch # print(label_id_batch) if FLAGS.verbose: print(label_id_batch) print(sent1_length_batch) print(sent2_length_batch) print(sent3_length_batch) # print(word_idx_1_batch) # print(word_idx_2_batch) # print(word_idx_3_batch) # print(sent1_batch) # 
            # print(sent2_batch)
            # print(sent3_batch)
            print(concat_mat_batch)
            input('check')

        feed_dict = {
            train_graph.get_truth(): label_id_batch,
            train_graph.get_passage_lengths(): sent1_length_batch,
            train_graph.get_question_lengths(): sent2_length_batch,
            train_graph.get_choice_lengths(): sent3_length_batch,
            train_graph.get_in_passage_words(): word_idx_1_batch,
            train_graph.get_in_question_words(): word_idx_2_batch,
            train_graph.get_in_choice_words(): word_idx_3_batch,
        }
        # Character-level, POS and NER inputs are only fed when the corresponding
        # vocabularies were built.
        if char_vocab is not None:
            feed_dict[train_graph.get_passage_char_lengths()] = sent1_char_length_batch
            feed_dict[train_graph.get_question_char_lengths()] = sent2_char_length_batch
            feed_dict[train_graph.get_choice_char_lengths()] = sent3_char_length_batch
            feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_1_batch
            feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_2_batch
            feed_dict[train_graph.get_in_choice_chars()] = char_matrix_idx_3_batch
        if POS_vocab is not None:
            feed_dict[train_graph.get_in_passage_poss()] = POS_idx_1_batch
            feed_dict[train_graph.get_in_question_poss()] = POS_idx_2_batch
        if NER_vocab is not None:
            feed_dict[train_graph.get_in_passage_ners()] = NER_idx_1_batch
            feed_dict[train_graph.get_in_question_ners()] = NER_idx_2_batch
        if concat_mat_batch is not None:
            feed_dict[train_graph.concat_idx_mat] = concat_mat_batch
        if split_mat_batch_q is not None:
            feed_dict[train_graph.split_idx_mat_q] = split_mat_batch_q
            feed_dict[train_graph.split_idx_mat_c] = split_mat_batch_c

        if FLAGS.verbose:
            return_list = sess.run(
                [train_graph.get_train_op(), train_graph.get_loss(),
                 train_graph.get_predictions(), train_graph.get_prob(),
                 train_graph.all_probs, train_graph.correct] + train_graph.matching_vectors,
                feed_dict=feed_dict)
            print(len(return_list))
            _, loss_value, pred, prob, all_probs, correct = return_list[0:6]
            print('pred=', pred)
            print('prob=', prob)
            print('logits=', all_probs)
            print('correct=', correct)
            for val in return_list[6:]:
                if isinstance(val, list):
                    print('list len ', len(val))
                    for objj in val:
                        print('this shape=', objj.shape)
                else:
                    print('this shape=', val.shape)
                # print(val)
            input('check')
        else:
            _, loss_value = sess.run(
                [train_graph.get_train_op(), train_graph.get_loss()],
                feed_dict=feed_dict)

        total_loss += loss_value
        sub_loss_counter += loss_value
        if step % int(FLAGS.display_every) == 0:
            print('{},{} '.format(step, sub_loss_counter), end="")
            sys.stdout.flush()
            sub_loss_counter = 0.0

        # Save a checkpoint and evaluate the model periodically.
        if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
            print()
            # Print status to stdout.
            duration = time.time() - start_time
            start_time = time.time()
            print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
            total_loss = 0.0
            # Evaluate against the validation set.
print('Validation Data Eval:') if FLAGS.predict_val: outpath = path_prefix + '.iter%d' % (step) + '.probs' else: outpath = None accuracy = evaluate(devDataStream, valid_graph, sess, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, use_options=FLAGS.use_options, outpath=outpath, mode='prob') print("Current accuracy on dev set is %.2f" % accuracy) if accuracy >= best_accuracy: best_accuracy = accuracy saver.save(sess, best_path) print('saving the current model.') accuracy = evaluate(testDataStream, valid_graph, sess, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, use_options=FLAGS.use_options, outpath=outpath, mode='prob') print("Current accuracy on test set is %.2f" % accuracy) print("Best accuracy on dev set is %.2f" % best_accuracy) # decoding print('Decoding on the test set:') init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.variable_scope("Model", reuse=False, initializer=initializer): valid_graph = TriMatchModelGraph( num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type, lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, fix_word_vec=FLAGS.fix_word_vec, with_highway=FLAGS.with_highway, word_level_MP_dim=FLAGS.word_level_MP_dim, with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway, highway_layer_num=FLAGS.highway_layer_num, match_to_question=FLAGS.match_to_question, match_to_passage=FLAGS.match_to_passage, match_to_choice=FLAGS.match_to_choice, with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), use_options=FLAGS.use_options, num_options=num_options, with_no_match=FLAGS.with_no_match, matching_option=FLAGS.matching_option, concat_context=FLAGS.concat_context, tied_aggre=FLAGS.tied_aggre, rl_training_method=FLAGS.rl_training_method, rl_matches=FLAGS.rl_matches) vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(tf.global_variables_initializer()) step = 0 saver.restore(sess, best_path) accuracy = evaluate(testDataStream, valid_graph, sess, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, use_options=FLAGS.use_options) print("Accuracy for test set is %.2f" % accuracy)
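# The Savers above are built from a hand-filtered variable map so that the large,
# fixed word-embedding matrix is never written into (or required from) a checkpoint.
# Below is a minimal, self-contained sketch of that pattern (TensorFlow 1.x assumed);
# the variable names inside the demo graph are illustrative only.
def _filtered_saver_demo():
    import tensorflow as tf
    with tf.Graph().as_default():
        with tf.variable_scope("Model"):
            tf.get_variable("word_embedding", shape=[100, 8])   # excluded from checkpoints
            tf.get_variable("dense_kernel", shape=[8, 2])       # saved / restored
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)       # only the filtered variables are checkpointed
        return saver, sorted(vars_.keys())  # keys -> ['Model/dense_kernel']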
def main(_): print('Configurations:') print(FLAGS) log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix) log_file_path = path_prefix + ".log" print('Log file path: {}'.format(log_file_path)) log_file = open(log_file_path, 'wt') log_file.write("{}\n".format(FLAGS)) log_file.flush() # save configuration namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") word_vocab_enc = None word_vocab_dec = None char_vocab = None edgelabel_vocab = None has_pretrained_model = False best_path = path_prefix + ".best.model" if os.path.exists(best_path + ".index"): has_pretrained_model = True print('!!Existing pretrained model. Loading vocabs.') word_vocab_enc = Vocab(FLAGS.word_vec_src_path, fileformat='txt2') print('word_vocab SRC: {}'.format(word_vocab_enc.word_vecs.shape)) word_vocab_dec = Vocab(FLAGS.word_vec_tgt_path, fileformat='txt2') print('word_vocab TGT: {}'.format(word_vocab_dec.word_vecs.shape)) if FLAGS.with_char: char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2') print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2') else: print('Collecting vocabs.') word_vocab_enc = Vocab(FLAGS.word_vec_src_path, fileformat='txt2') word_vocab_dec = Vocab(FLAGS.word_vec_tgt_path, fileformat='txt2') if FLAGS.with_char: char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build') char_vocab.dump_to_txt2(path_prefix + ".char_vocab") allEdgelabels = set([line.strip().split()[0] \ for line in open(FLAGS.edgelabel_vocab, 'rU')]) edgelabel_vocab = Vocab(voc=allEdgelabels, dim=FLAGS.edgelabel_dim, fileformat='build') edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab") print('word vocab SRC size {}'.format(word_vocab_enc.vocab_size)) print('word vocab TGT size {}'.format(word_vocab_dec.vocab_size)) sys.stdout.flush() print('Loading train set.') if FLAGS.infile_format == 'fof': trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_from_fof(FLAGS.train_path, FLAGS, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab) else: trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_file(FLAGS.train_path, FLAGS, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab) print('Number of training samples: {}'.format(len(trainset))) print('Loading test set.') if FLAGS.infile_format == 'fof': testset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_from_fof(FLAGS.test_path, FLAGS, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab) else: testset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_file(FLAGS.test_path, FLAGS, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab) print('Number of test samples: {}'.format(len(testset))) max_node = max(trn_node, tst_node) max_in_neigh = max(trn_in_neigh, tst_in_neigh) max_out_neigh = max(trn_out_neigh, tst_out_neigh) max_sent = max(trn_sent, tst_sent) print('Max node number: {}, while max allowed is {}'.format(max_node, FLAGS.max_node_num)) print('Max parent number: {}, truncated to {}'.format(max_in_neigh, FLAGS.max_in_neigh_num)) print('Max children number: {}, truncated to {}'.format(max_out_neigh, FLAGS.max_out_neigh_num)) print('Max answer length: {}, truncated to {}'.format(max_sent, FLAGS.max_answer_len)) print('Build DataStream ... 
') trainDataStream = G2S_data_stream.G2SDataStream(trainset, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab, options=FLAGS, isShuffle=True, isLoop=True, isSort=True) devDataStream = G2S_data_stream.G2SDataStream(testset, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab, options=FLAGS, isShuffle=False, isLoop=False, isSort=True) print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch())) sys.stdout.flush() # initialize the best bleu and accu scores for current training session best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0 best_bleu = FLAGS.best_bleu if FLAGS.__dict__.has_key('best_bleu') else 0.0 if best_accu > 0.0: print('With initial dev accuracy {}'.format(best_accu)) if best_bleu > 0.0: print('With initial dev BLEU score {}'.format(best_bleu)) init_scale = 0.01 with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = ModelGraph(word_vocab_enc=word_vocab_enc, word_vocab_dec=word_vocab_dec, Edgelabel_vocab=edgelabel_vocab, char_vocab=char_vocab, options=FLAGS, mode=FLAGS.mode) assert FLAGS.mode in ('ce_train', 'rl_train', ) valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu' with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = ModelGraph(word_vocab_enc=word_vocab_enc, word_vocab_dec=word_vocab_dec, Edgelabel_vocab=edgelabel_vocab, char_vocab=char_vocab, options=FLAGS, mode=valid_mode) initializer = tf.global_variables_initializer() for var in tf.trainable_variables(): print(var) vars_ = {} for var in tf.all_variables(): if FLAGS.fix_word_vec and "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) sess.run(initializer) if has_pretrained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001: print("Getting BLEU score for the model") sys.stdout.flush() best_bleu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_bleu'] FLAGS.best_bleu = best_bleu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('BLEU = %.4f' % best_bleu) sys.stdout.flush() log_file.write('BLEU = %.4f\n' % best_bleu) if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001: print("Getting ACCU score for the model") best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu'] FLAGS.best_accu = best_accu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('ACCU = %.4f' % best_accu) log_file.write('ACCU = %.4f\n' % best_accu) print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): cur_batch = trainDataStream.nextBatch() cur_batch = G2S_data_stream.G2SBatchPadd(cur_batch) if FLAGS.mode == 'rl_train': loss_value = train_graph.run_rl_training_subsample(sess, cur_batch, 
FLAGS) elif FLAGS.mode == 'ce_train': loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS) total_loss += loss_value if step % 100==0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or (step != 0 and step%2000 == 0): print() duration = time.time() - start_time print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) log_file.flush() sys.stdout.flush() total_loss = 0.0 # Evaluate against the validation set. start_time = time.time() print('Validation Data Eval:') res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS, suffix=str(step)) if valid_graph.mode == 'evaluate': dev_loss = res_dict['dev_loss'] dev_accu = res_dict['dev_accu'] dev_right = int(res_dict['dev_right']) dev_total = int(res_dict['dev_total']) print('Dev loss = %.4f' % dev_loss) log_file.write('Dev loss = %.4f\n' % dev_loss) print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total)) log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total)) log_file.flush() if best_accu < dev_accu: print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu)) saver.save(sess, best_path) best_accu = dev_accu FLAGS.best_accu = dev_accu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") else: dev_bleu = res_dict['dev_bleu'] print('Dev bleu = %.4f' % dev_bleu) log_file.write('Dev bleu = %.4f\n' % dev_bleu) log_file.flush() if best_bleu < dev_bleu: print('Saving weights, BLEU {} (prev_best) < {} (cur)'.format(best_bleu, dev_bleu)) saver.save(sess, best_path) best_bleu = dev_bleu FLAGS.best_bleu = dev_bleu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") duration = time.time() - start_time print('Duration %.3f sec' % (duration)) sys.stdout.flush() log_file.write('Duration %.3f sec\n' % (duration)) log_file.flush() log_file.close()
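# A compressed sketch of the bookkeeping done in the evaluation branch above: weights
# are saved only when the dev metric improves, and the new best score is written back
# into the run's config file so that a resumed run starts from it.  json and the
# save_weights callable stand in for namespace_utils and saver.save here -- they are
# assumptions of this sketch, not the project's actual API.
import json

def maybe_save_best(dev_score, best_score, config_dict, config_path, save_weights):
    """Return the (possibly updated) best score; persist weights/config on improvement."""
    if dev_score > best_score:
        save_weights()                       # e.g. saver.save(sess, best_path)
        config_dict['best_accu'] = dev_score
        with open(config_path, 'w') as fout:
            json.dump(config_dict, fout, indent=2)
        return dev_score
    return best_score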
def get_test_result(in_p, root_path):
    print('Loading configurations.')
    model_prefix = root_path + "/stsapp/src/logs/SentenceMatch.snli"
    word_vec_path = root_path + "/stsapp/src/data/snli/wordvec.txt"
    in_path = in_p
    out_path = root_path + "/stsapp/src/result.txt"
    print("access decoder")
    options = namespace_utils.load_namespace(model_prefix + ".config.json")
    if word_vec_path is None:
        word_vec_path = options.word_vec_path

    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()
    char_vocab = None  # keep defined even when the model was trained without characters
    if options.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

    print('Build SentenceMatchDataStream ... ')
    testDataStream = SentenceMatchDataStream(in_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                             label_vocab=label_vocab,
                                             isShuffle=False, isLoop=True, isSort=True, options=options)
    print('Number of instances in devDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in devDataStream: {}'.format(testDataStream.get_num_batch()))
    sys.stdout.flush()

    best_path = model_prefix + ".best.model"
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                                                  is_training=False, options=options)
        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        print("Restoring model from " + best_path)
        saver.restore(sess, best_path)
        print("DONE!")
        acc, result = train.evaluation(sess, valid_graph, testDataStream, outpath=out_path,
                                       label_vocab=label_vocab)
        print("Accuracy for test set is : ", colored(acc, 'green'), "\n")
        # print(result['probs'])
        return acc, result
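# Hypothetical usage of get_test_result: both paths below are placeholders.  in_p is a
# tab-separated file of sentence pairs to score, and root_path is the project root that
# contains the stsapp/src/logs and stsapp/src/data/snli directories hard-coded above.
if __name__ == '__main__':
    acc, result = get_test_result('/tmp/pairs_to_score.tsv', '/home/user/project')
    print('test accuracy:', acc)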
def main(_): print('Configurations:') print(FLAGS) log_dir = FLAGS.model_dir if not os.path.exists(log_dir): os.makedirs(log_dir) path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix) log_file_path = path_prefix + ".log" print('Log file path: {}'.format(log_file_path)) log_file = open(log_file_path, 'wt') log_file.write("{}\n".format(FLAGS)) log_file.flush() # save configuration namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('Loading train set.') if FLAGS.infile_format == 'fof': trainset, train_ans_len = NP2P_data_stream.read_generation_datasets_from_fof( FLAGS.train_path, isLower=FLAGS.isLower) elif FLAGS.infile_format == 'plain': trainset, train_ans_len = NP2P_data_stream.read_all_GenerationDatasets( FLAGS.train_path, isLower=FLAGS.isLower) else: trainset, train_ans_len = NP2P_data_stream.read_all_GQA_questions( FLAGS.train_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa) print('Number of training samples: {}'.format(len(trainset))) print('Loading test set.') if FLAGS.infile_format == 'fof': testset, test_ans_len = NP2P_data_stream.read_generation_datasets_from_fof( FLAGS.test_path, isLower=FLAGS.isLower) elif FLAGS.infile_format == 'plain': testset, test_ans_len = NP2P_data_stream.read_all_GenerationDatasets( FLAGS.test_path, isLower=FLAGS.isLower) else: testset, test_ans_len = NP2P_data_stream.read_all_GQA_questions( FLAGS.test_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa) print('Number of test samples: {}'.format(len(testset))) max_actual_len = max(train_ans_len, test_ans_len) print('Max answer length: {}, truncated to {}'.format( max_actual_len, FLAGS.max_answer_len)) word_vocab = None POS_vocab = None NER_vocab = None char_vocab = None has_pretrained_model = False best_path = path_prefix + ".best.model" if os.path.exists(best_path + ".index"): has_pretrained_model = True print('!!Existing pretrained model. Loading vocabs.') if FLAGS.with_word: word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') print('word_vocab: {}'.format(word_vocab.word_vecs.shape)) if FLAGS.with_char: char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2') print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) if FLAGS.with_POS: POS_vocab = Vocab(path_prefix + ".POS_vocab", fileformat='txt2') print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape)) if FLAGS.with_NER: NER_vocab = Vocab(path_prefix + ".NER_vocab", fileformat='txt2') print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape)) else: print('Collecting vocabs.') (allWords, allChars, allPOSs, allNERs) = NP2P_data_stream.collect_vocabs(trainset) print('Number of words: {}'.format(len(allWords))) print('Number of allChars: {}'.format(len(allChars))) print('Number of allPOSs: {}'.format(len(allPOSs))) print('Number of allNERs: {}'.format(len(allNERs))) if FLAGS.with_word: word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') if FLAGS.with_char: char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build') char_vocab.dump_to_txt2(path_prefix + ".char_vocab") if FLAGS.with_POS: POS_vocab = Vocab(voc=allPOSs, dim=FLAGS.POS_dim, fileformat='build') POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab") if FLAGS.with_NER: NER_vocab = Vocab(voc=allNERs, dim=FLAGS.NER_dim, fileformat='build') NER_vocab.dump_to_txt2(path_prefix + ".NER_vocab") print('word vocab size {}'.format(word_vocab.vocab_size)) sys.stdout.flush() print('Build DataStream ... 
') trainDataStream = NP2P_data_stream.QADataStream(trainset, word_vocab, char_vocab, POS_vocab, NER_vocab, options=FLAGS, isShuffle=True, isLoop=True, isSort=True) devDataStream = NP2P_data_stream.QADataStream(testset, word_vocab, char_vocab, POS_vocab, NER_vocab, options=FLAGS, isShuffle=False, isLoop=False, isSort=True) print('Number of instances in trainDataStream: {}'.format( trainDataStream.get_num_instance())) print('Number of instances in devDataStream: {}'.format( devDataStream.get_num_instance())) print('Number of batches in trainDataStream: {}'.format( trainDataStream.get_num_batch())) print('Number of batches in devDataStream: {}'.format( devDataStream.get_num_batch())) sys.stdout.flush() init_scale = 0.01 # initialize the best bleu and accu scores for current training session best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0 best_bleu = FLAGS.best_bleu if FLAGS.__dict__.has_key('best_bleu') else 0.0 if best_accu > 0.0: print('With initial dev accuracy {}'.format(best_accu)) if best_bleu > 0.0: print('With initial dev BLEU score {}'.format(best_bleu)) with tf.Graph().as_default(): initializer = tf.random_uniform_initializer(-init_scale, init_scale) with tf.name_scope("Train"): with tf.variable_scope("Model", reuse=None, initializer=initializer): train_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, options=FLAGS, mode=FLAGS.mode) assert FLAGS.mode in ( 'ce_train', 'rl_train', ) valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu' with tf.name_scope("Valid"): with tf.variable_scope("Model", reuse=True, initializer=initializer): valid_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, options=FLAGS, mode=valid_mode) initializer = tf.global_variables_initializer() vars_ = {} for var in tf.all_variables(): if "word_embedding" in var.name: continue if not var.name.startswith("Model"): continue vars_[var.name.split(":")[0]] = var saver = tf.train.Saver(vars_) sess = tf.Session() sess.run(initializer) if has_pretrained_model: print("Restoring model from " + best_path) saver.restore(sess, best_path) print("DONE!") if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001: print("Getting BLEU score for the model") best_bleu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_bleu'] FLAGS.best_bleu = best_bleu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('BLEU = %.4f' % best_bleu) log_file.write('BLEU = %.4f\n' % best_bleu) if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001: print("Getting ACCU score for the model") best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu'] FLAGS.best_accu = best_accu namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") print('ACCU = %.4f' % best_accu) log_file.write('ACCU = %.4f\n' % best_accu) print('Start the training loop.') train_size = trainDataStream.get_num_batch() max_steps = train_size * FLAGS.max_epochs total_loss = 0.0 start_time = time.time() for step in xrange(max_steps): cur_batch = trainDataStream.nextBatch() if FLAGS.mode == 'rl_train': loss_value = train_graph.run_rl_training_2( sess, cur_batch, FLAGS) elif FLAGS.mode == 'ce_train': loss_value = train_graph.run_ce_training( sess, cur_batch, FLAGS) total_loss += loss_value if step % 100 == 0: print('{} '.format(step), end="") sys.stdout.flush() # Save a checkpoint and evaluate the model periodically. 
if (step + 1) % trainDataStream.get_num_batch() == 0 or ( step + 1) == max_steps: print() duration = time.time() - start_time print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) log_file.flush() sys.stdout.flush() total_loss = 0.0 # Evaluate against the validation set. start_time = time.time() print('Validation Data Eval:') res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS, suffix=str(step)) if valid_graph.mode == 'evaluate': dev_loss = res_dict['dev_loss'] dev_accu = res_dict['dev_accu'] dev_right = int(res_dict['dev_right']) dev_total = int(res_dict['dev_total']) print('Dev loss = %.4f' % dev_loss) log_file.write('Dev loss = %.4f\n' % dev_loss) print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total)) log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total)) log_file.flush() if best_accu < dev_accu: print('Saving weights, ACCU {} (prev_best) < {} (cur)'. format(best_accu, dev_accu)) saver.save(sess, best_path) best_accu = dev_accu FLAGS.best_accu = dev_accu namespace_utils.save_namespace( FLAGS, path_prefix + ".config.json") else: dev_bleu = res_dict['dev_bleu'] print('Dev bleu = %.4f' % dev_bleu) log_file.write('Dev bleu = %.4f\n' % dev_bleu) log_file.flush() if best_bleu < dev_bleu: print('Saving weights, BLEU {} (prev_best) < {} (cur)'. format(best_bleu, dev_bleu)) saver.save(sess, best_path) best_bleu = dev_bleu FLAGS.best_bleu = dev_bleu namespace_utils.save_namespace( FLAGS, path_prefix + ".config.json") duration = time.time() - start_time print('Duration %.3f sec' % (duration)) sys.stdout.flush() log_file.write('Duration %.3f sec\n' % (duration)) log_file.flush() log_file.close()
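# The resume logic above reads best_accu / best_bleu with FLAGS.__dict__.has_key, which
# exists only under Python 2 (consistent with the xrange calls in this script).  A
# version-agnostic sketch of the same lookup, using getattr with a default, is shown
# below; `flags` is any parsed-namespace object and is an assumption of this sketch.
def initial_best_scores(flags):
    """Return (best_accu, best_bleu) recorded in a restored config, defaulting to 0.0."""
    return getattr(flags, 'best_accu', 0.0), getattr(flags, 'best_bleu', 0.0)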
def main_func(_): print(FLAGS) save_path = FLAGS.train_dir + "tfFile/" if not os.path.exists(save_path): os.makedirs(save_path) print("1. Loading WordVocab data...") wordVocab = Vocab() wordVocab.fromText_format3(FLAGS.train_dir, FLAGS.wordvec_path) sys.stdout.flush() prepare = Prepare() if FLAGS.hasTfrecords: print("2. Has Tfrecords File---Train---") total_lines = prepare.processTFrecords_hasDone(savePath=save_path, taskNumber=FLAGS.taskNumber) else: print("2. Start generating TFrecords File--train...") total_lines = prepare.processTFrecords(wordVocab, savePath=save_path, max_len=FLAGS.max_len, taskNumber=FLAGS.taskNumber) print("totalLines_train_0:", total_lines[0]) print("totalLines_train_1:", total_lines[1]) sys.stdout.flush() test_path = FLAGS.train_dir + FLAGS.test_path if FLAGS.hasTfrecords: print("3. Has TFrecords File--test...") totalLines_test = prepare.processTFrecords_test_hasDone(test_path=test_path, taskNumber=1) else: print("3. Start generating TFrecords File--test...") totalLines_test = prepare.processTFrecords_test(wordVocab, savePath=save_path, test_path=test_path, max_len=FLAGS.max_len, taskNumber=1) print("totalLines_test:", totalLines_test) sys.stdout.flush() print("4. Start loading TFrecords File...") taskNameList = [] for i in range(FLAGS.taskNumber): string = FLAGS.train_dir + 'tfFile/train-' + str(i) + '.tfrecords' taskNameList.append(string) print("taskNameList: ", taskNameList) sys.stdout.flush() ################ n = total_lines[0] / total_lines[1] + 1 if \ total_lines[0] % total_lines[1] != 0 else \ total_lines[0] / total_lines[1] print("n: ", n) num_batches_per_epoch_train_0 = int(total_lines[0] / FLAGS.batch_size) + 1 if \ total_lines[0] % FLAGS.batch_size != 0 else int( total_lines[0] / FLAGS.batch_size) print("batch_numbers_train_0:", num_batches_per_epoch_train_0) batch_size_1 = FLAGS.batch_size / n num_batches_per_epoch_test = int(totalLines_test / FLAGS.batch_size) + 1 if \ totalLines_test % FLAGS.batch_size != 0 else \ int(totalLines_test / FLAGS.batch_size) print("batch_numbers_test:", num_batches_per_epoch_test) with tf.Graph().as_default(): all_test = prepare.read_records( taskname=save_path + "test-0.tfrecords", max_len=FLAGS.max_len, epochs=FLAGS.num_epochs, batch_size=FLAGS.batch_size) all_train_0 = prepare.read_records( taskname=taskNameList[0], max_len=FLAGS.max_len, epochs=FLAGS.num_epochs, batch_size=FLAGS.batch_size) all_train_1 = prepare.read_records( taskname=taskNameList[1], max_len=FLAGS.max_len, epochs=FLAGS.num_epochs, batch_size=batch_size_1) print("Loading Model...") sys.stdout.flush() session_conf = tf.ConfigProto( allow_soft_placement=FLAGS.allow_soft_placement, log_device_placement=FLAGS.log_device_placement) session_conf.gpu_options.allow_growth = True sess = tf.Session(config=session_conf) with sess.as_default(): print("------------train model--------------") m_train = mtl_model.MTLModel(max_len=FLAGS.max_len, filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))), num_filters=FLAGS.num_filters, num_hidden=FLAGS.num_hidden, word_vocab=wordVocab, l2_reg_lambda=FLAGS.l2_reg_lambda, learning_rate=FLAGS.learning_rate, adv=FLAGS.adv) m_train.build_train_op() print("\n\n") saver = tf.train.Saver(tf.global_variables(), max_to_keep=20) init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) sess.run(init_op) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) has_pre_trained_model = False out_dir = os.path.abspath(os.path.join(FLAGS.train_dir, "runs")) 
print(out_dir) if not os.path.exists(out_dir): os.makedirs(out_dir) else: print("continue training models") ckpt = tf.train.get_checkpoint_state(out_dir) if ckpt and ckpt.model_checkpoint_path: print("-------has_pre_trained_model--------") print(ckpt.model_checkpoint_path) has_pre_trained_model = True sys.stdout.flush() checkpoint_prefix = os.path.join(out_dir, "model") if has_pre_trained_model: print("Restoring model from " + ckpt.model_checkpoint_path) saver.restore(sess, ckpt.model_checkpoint_path) print("DONE!") sys.stdout.flush() def dev_whole(num_batches_per_epoch_test): accuracies = [] losses = [] for j in range(num_batches_per_epoch_test): input_y_test, input_left_test, input_centre_test = sess.run( [all_test[0], all_test[1], all_test[2]]) loss, accuracy, loss_adv, loss_ce = sess.run( [m_train.tensors[1][1], m_train.tensors[1][0], m_train.tensors[1][2], m_train.tensors[1][3]], feed_dict={ m_train.input_task_0: 0, m_train.input_left_0: input_left_real_0, m_train.input_right_0: input_centre_real_0, m_train.input_y_0: input_y_real_0, m_train.dropout_keep_prob: FLAGS.dropout_keep_prob, m_train.input_task_1: 1, m_train.input_left_1: input_left_test, m_train.input_right_1: input_centre_test, m_train.input_y_1: input_y_test, }) losses.append(loss_ce) accuracies.append(accuracy) # print("specfic_prob: ", prob_test) sys.stdout.flush() return np.mean(np.array(losses)), np.mean(np.array(accuracies)) def overfit(dev_accuracy): n = len(dev_accuracy) if n < 4: return False for i in range(n - 4, n): if dev_accuracy[i] > dev_accuracy[i - 1]: return False return True dev_accuracy = [] total_train_loss = [] train_loss_0 = 0 train_loss_1 = 0 loss_task_0 = 0 loss_task_1 = 0 adv_0 = 0 adv_1 = 0 acc_1 = 0 count = 0 try: while not coord.should_stop(): ## for each epoch for i in range(num_batches_per_epoch_train_0 * FLAGS.num_epochs): ## for each batch input_y_real_0, input_left_real_0, input_centre_real_0 = sess.run([all_train_0[0], all_train_0[1], all_train_0[2]]) input_y_real_1, input_left_real_1, input_centre_real_1 = sess.run([all_train_1[0], all_train_1[1], all_train_1[2]]) # acc, loss, loss_adv = m_train.tensors[0] # _, current_step_0, loss_0, accuracy_0, loss_adv_0 = sess.run( # [m_train.train_ops[0][0], m_train.train_ops[0][1], # m_train.tensors[0][1], m_train.tensors[0][0], m_train.tensors[0][2]], # feed_dict={ # m_train.input_task_0: 0, # m_train.input_left_0: input_left_real_0, # m_train.input_right_0: input_centre_real_0, # m_train.input_y_0: input_y_real_0, # m_train.dropout_keep_prob: FLAGS.dropout_keep_prob, # m_train.input_task_1: 1, # m_train.input_left_1: input_left_real_1, # m_train.input_right_1: input_centre_real_1, # m_train.input_y_1: input_y_real_1, # }) # all_loss_adv += loss_adv_0 # train_acc += accuracy_0 # train_loss_0 += loss_0 # train_loss += loss_0 # # _, current_step_1, loss_1, accuracy_1, loss_adv_1 = sess.run( # [m_train.train_ops[1][0], m_train.train_ops[1][1], # m_train.tensors[1][1], m_train.tensors[1][0], m_train.tensors[1][2]], # feed_dict={ # m_train.input_task_0: 0, # m_train.input_left_0: input_left_real_0, # m_train.input_right_0: input_centre_real_0, # m_train.input_y_0: input_y_real_0, # m_train.dropout_keep_prob: FLAGS.dropout_keep_prob, # m_train.input_task_1: 1, # m_train.input_left_1: input_left_real_1, # m_train.input_right_1: input_centre_real_1, # m_train.input_y_1: input_y_real_1, # }) _, loss_0, accuracy_0, loss_adv_0, loss_ce_0 = sess.run( [m_train.train_ops[0], m_train.tensors[0][1], m_train.tensors[0][0], m_train.tensors[0][2], 
m_train.tensors[0][3]], feed_dict={ m_train.input_task_0: 0, m_train.input_left_0: input_left_real_0, m_train.input_right_0: input_centre_real_0, m_train.input_y_0: input_y_real_0, m_train.dropout_keep_prob: FLAGS.dropout_keep_prob, m_train.input_task_1: 1, m_train.input_left_1: input_left_real_1, m_train.input_right_1: input_centre_real_1, m_train.input_y_1: input_y_real_1, }) train_loss_0 += loss_0 loss_task_0 += loss_ce_0 adv_0 += loss_adv_0 _, loss_1, accuracy_1, loss_adv_1, loss_ce_1 = sess.run( [m_train.train_ops[1], m_train.tensors[1][1], m_train.tensors[1][0], m_train.tensors[1][2], m_train.tensors[1][3]], feed_dict={ m_train.input_task_0: 0, m_train.input_left_0: input_left_real_0, m_train.input_right_0: input_centre_real_0, m_train.input_y_0: input_y_real_0, m_train.dropout_keep_prob: FLAGS.dropout_keep_prob, m_train.input_task_1: 1, m_train.input_left_1: input_left_real_1, m_train.input_right_1: input_centre_real_1, m_train.input_y_1: input_y_real_1, }) train_loss_1 += loss_1 loss_task_1 += loss_ce_1 adv_1 += loss_adv_1 acc_1 += accuracy_1 count += 1 if count % 500 == 0: print("loss {}, acc {}".format(loss_0, accuracy_0)) print("--loss {}, acc {}, loss_adv {}, loss_ce {}--".format(loss_1, accuracy_1, loss_adv_1, loss_ce_1)) sys.stdout.flush() if count % num_batches_per_epoch_train_0 == 0 or \ count == num_batches_per_epoch_train_0 * FLAGS.num_epochs: print("train_0: ", count / num_batches_per_epoch_train_0, " epoch, train_loss_0:", train_loss_0, "loss_task_0: ", loss_task_0, "adv_0: ", adv_0) print( "train_1: ", count / num_batches_per_epoch_train_0, " epoch, train_loss_1: ", train_loss_1, "loss_task_1: ", loss_task_1, "adv_1: ", adv_1, "acc_1 : ", acc_1 / num_batches_per_epoch_train_0) total_train_loss.append(loss_task_1) train_loss_0 = 0 train_loss_1 = 0 loss_task_0 = 0 loss_task_1 = 0 adv_0 = 0 adv_1 = 0 acc_1 = 0 sys.stdout.flush() print("\n------------------Evaluation:-----------------------") _, accuracy = dev_whole(num_batches_per_epoch_test) dev_accuracy.append(accuracy) print("--------Recently dev accuracy:--------") print(dev_accuracy[-10:]) print("--------Recently loss_task_1:------") print(total_train_loss[-10:]) if overfit(dev_accuracy): print('-----Overfit!!----') break print("") sys.stdout.flush() # continue path = saver.save(sess, checkpoint_prefix, global_step=count) print("-------------------Saved model checkpoint to {}--------------------".format(path)) sys.stdout.flush() output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=[ 'task_1/prob']) for node in output_graph_def.node: if node.op == 'RefSwitch': node.op = 'Switch' for index in xrange(len(node.input)): if 'moving_' in node.input[index]: node.input[index] = node.input[index] + '/read' elif node.op == 'AssignSub': node.op = 'Sub' if 'use_locking' in node.attr: del node.attr['use_locking'] with tf.gfile.GFile(FLAGS.train_dir + "runs/mtlmodel_specfic.pb", "wb") as f: f.write(output_graph_def.SerializeToString()) print("%d ops in the final graph.\n" % len(output_graph_def.node)) except tf.errors.OutOfRangeError: print("Done") finally: print("--------------------------finally---------------------------") print("current_step:", count) coord.request_stop() coord.join(threads) sess.close()
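# The export step above freezes the trained multi-task model: variables are folded into
# constants, the batch-norm bookkeeping ops that have no inference-time kernel are
# patched (RefSwitch -> Switch reading the moving averages, AssignSub -> Sub), and the
# result is written as a standalone GraphDef.  A minimal sketch of that step, assuming
# TensorFlow 1.x and reusing the output node name from the code above:
def freeze_to_pb(sess, pb_path, output_node_names=('task_1/prob',)):
    graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names=list(output_node_names))
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
    with tf.gfile.GFile(pb_path, 'wb') as f:
        f.write(graph_def.SerializeToString())
    return len(graph_def.node)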