def evaluate(sess, valid_graph, devDataStream, options=None, suffix=''):
    devDataStream.reset()
    gen = []
    ref = []
    dev_loss = 0.0
    dev_right = 0.0
    dev_total = 0.0
    for batch_index in xrange(devDataStream.get_num_batch()):  # for each batch
        cur_batch = devDataStream.get_batch(batch_index)
        cur_batch = G2S_data_stream.G2SBatchPadd(cur_batch)
        if valid_graph.mode == 'evaluate':
            accu_value, loss_value = valid_graph.run_ce_training(sess, cur_batch, options, only_eval=True)
            dev_loss += loss_value
            dev_right += accu_value
            dev_total += np.sum(cur_batch.sent_len)
        elif valid_graph.mode == 'evaluate_bleu':
            gen.extend(valid_graph.run_greedy(sess, cur_batch, options).tolist())
            ref.extend(cur_batch.sent_out.tolist())
        else:
            assert False

    if valid_graph.mode == 'evaluate':
        return {'dev_loss': dev_loss,
                'dev_accu': 1.0 * dev_right / dev_total,
                'dev_right': dev_right,
                'dev_total': dev_total, }
    else:
        return {'dev_bleu': document_bleu(valid_graph.word_vocab, gen, ref, suffix), }
def evaluate(sess, valid_graph, devDataStream, options=None, suffix=''):
    devDataStream.reset()
    gen = []
    ref = []
    dev_loss = 0.0
    dev_right1 = 0.0
    dev_right2 = 0.0
    dev_total = 0.0
    for batch_index in xrange(devDataStream.get_num_batch()):  # for each batch
        cur_batch = devDataStream.get_batch(batch_index)
        cur_batch = G2S_data_stream.G2SBatchPadd(cur_batch)
        # file = open('../data/train_output.txt', 'a+')
        if valid_graph.mode == 'evaluate':
            dic = valid_graph.word_vocab.id2word
            accu1_value, loss_value, vocab_score1, vocab_score2, greedy_words, loss_weights = \
                valid_graph.run_ce_training(sess, cur_batch, options, only_eval=True)
            accu2_value = accu1_value[1]
            accu1_value = accu1_value[0]  # accu_value contains accu1 and accu2
            dev_loss += loss_value
            dev_right1 += accu1_value
            dev_right2 += accu2_value
            dev_total += np.sum(cur_batch.sent_len)
            # with tf.Session() as sess:
            #     # sess.run(tf.global_variables_initializer())
            #     vocab_score1 = sess.run([vocab_score1])
            #     greedy_words = sess.run([greedy_words])
            #     vocab_score2 = sess.run([vocab_score2])
            # file.write("target: " + str(cur_batch.target_ref) + '\n')
            # file.write("out_seqs1: " + str(_values_to_words(vocab_score1, loss_weights, dic)) + '\n')
            # file.write("greedy_words: " + str(_values_to_words(greedy_words, loss_weights, dic)) + '\n')
            # file.write("out_seqs2: " + str(_values_to_words(vocab_score2, loss_weights, dic)) + '\n')
        elif valid_graph.mode == 'evaluate_bleu':
            gen.extend(valid_graph.run_greedy(sess, cur_batch, options).tolist())
            ref.extend(cur_batch.sent_out.tolist())
        else:
            assert False

    if valid_graph.mode == 'evaluate':
        return {'dev_loss': dev_loss,
                'dev_accu1': 1.0 * dev_right1 / dev_total,
                'dev_accu': 1.0 * dev_right2 / dev_total,
                'dev_right1': dev_right1,
                'dev_right': dev_right2,
                'dev_total': dev_total, }
    else:
        return {'dev_bleu': document_bleu(valid_graph.word_vocab, gen, ref, suffix), }
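# The commented-out debugging code above relies on a _values_to_words helper that is not
# shown in this fragment. The minimal sketch below only illustrates what such a helper
# might look like; its signature and the zero-weight masking are assumptions, not the
# repository's actual implementation.
def _values_to_words(id_matrix, loss_weights, id2word):
    # Map a [batch, time] matrix of word ids back to words, skipping positions whose
    # loss weight is zero (i.e. padding).
    sentences = []
    for ids, weights in zip(id_matrix, loss_weights):
        words = [id2word[int(idx)] for idx, w in zip(ids, weights) if w > 0.0]
        sentences.append(' '.join(words))
    return sentences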
def fine_tune(sess, saver, FLAGS, log_file, ftDataStream, devDataStream,
              train_graph, valid_graph, path_prefix, best_accu, best_bleu):
    print('=====Start the fine tuning.')
    sys.stdout.flush()
    max_steps = ftDataStream.get_num_batch() * 1
    best_path = path_prefix + ".best.model"
    total_loss = 0.0
    start_time = time.time()
    for step in xrange(max_steps):
        cur_batch = ftDataStream.nextBatch()
        cur_batch = G2S_data_stream.G2SBatchPadd(cur_batch)
        if FLAGS.mode == 'rl_train':
            loss_value = train_graph.run_rl_training_subsample(sess, cur_batch, FLAGS)
        elif FLAGS.mode == 'ce_train':
            loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS)
        total_loss += loss_value

        if step % 100 == 0:
            print('{} '.format(step), end="")
            sys.stdout.flush()

        # Save a checkpoint and evaluate the model periodically.
        if (step + 1) % ftDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
            print()
            duration = time.time() - start_time
            print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
            sys.stdout.flush()
            log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
            log_file.flush()
            best_accu, best_bleu = validate_and_save(sess, saver, FLAGS, log_file,
                                                     devDataStream, valid_graph, path_prefix,
                                                     best_accu, best_bleu)
            total_loss = 0.0
            start_time = time.time()
    print('=====End the fine tuning.')
    sys.stdout.flush()
    return best_accu, best_bleu
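# fine_tune() and the training loop below call validate_and_save(), which is defined
# elsewhere in this file. The sketch below only illustrates the contract the callers
# rely on -- evaluate on the dev stream, save the model and the config when the score
# improves, and return the updated bests; the details are assumptions.
def validate_and_save(sess, saver, FLAGS, log_file, devDataStream, valid_graph,
                      path_prefix, best_accu, best_bleu):
    best_path = path_prefix + ".best.model"
    res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS)
    if valid_graph.mode == 'evaluate' and res_dict['dev_accu'] > best_accu:
        best_accu = res_dict['dev_accu']
        FLAGS.best_accu = best_accu
        saver.save(sess, best_path)
        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
        log_file.write('Dev accu = %.4f\n' % best_accu)
    elif valid_graph.mode == 'evaluate_bleu' and res_dict['dev_bleu'] > best_bleu:
        best_bleu = res_dict['dev_bleu']
        FLAGS.best_bleu = best_bleu
        saver.save(sess, best_path)
        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
        log_file.write('Dev bleu = %.4f\n' % best_bleu)
    log_file.flush()
    return best_accu, best_bleu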
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    word_vocab_enc = None
    word_vocab_dec = None
    char_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab_enc = Vocab(FLAGS.word_vec_src_path, fileformat='txt2')
        print('word_vocab SRC: {}'.format(word_vocab_enc.word_vecs.shape))
        word_vocab_dec = Vocab(FLAGS.word_vec_tgt_path, fileformat='txt2')
        print('word_vocab TGT: {}'.format(word_vocab_dec.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2')
    else:
        print('Collecting vocabs.')
        word_vocab_enc = Vocab(FLAGS.word_vec_src_path, fileformat='txt2')
        word_vocab_dec = Vocab(FLAGS.word_vec_tgt_path, fileformat='txt2')
        if FLAGS.with_char:
            # NOTE: allChars is not collected in this version of the script; it must be
            # built (e.g. by a collect_vocabs call) before with_char can be used here.
            char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        allEdgelabels = set([line.strip().split()[0]
                             for line in open(FLAGS.edgelabel_vocab, 'rU')])
        edgelabel_vocab = Vocab(voc=allEdgelabels, dim=FLAGS.edgelabel_dim, fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")

    print('word vocab SRC size {}'.format(word_vocab_enc.vocab_size))
    print('word vocab TGT size {}'.format(word_vocab_dec.vocab_size))
    sys.stdout.flush()

    print('Loading train set.')
    if FLAGS.infile_format == 'fof':
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_from_fof(
            FLAGS.train_path, FLAGS, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    else:
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_file(
            FLAGS.train_path, FLAGS, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    print('Number of training samples: {}'.format(len(trainset)))

    print('Loading test set.')
    if FLAGS.infile_format == 'fof':
        testset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_from_fof(
            FLAGS.test_path, FLAGS, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    else:
        testset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_file(
            FLAGS.test_path, FLAGS, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    print('Number of test samples: {}'.format(len(testset)))

    max_node = max(trn_node, tst_node)
    max_in_neigh = max(trn_in_neigh, tst_in_neigh)
    max_out_neigh = max(trn_out_neigh, tst_out_neigh)
    max_sent = max(trn_sent, tst_sent)
    print('Max node number: {}, while max allowed is {}'.format(max_node, FLAGS.max_node_num))
    print('Max parent number: {}, truncated to {}'.format(max_in_neigh, FLAGS.max_in_neigh_num))
    print('Max children number: {}, truncated to {}'.format(max_out_neigh, FLAGS.max_out_neigh_num))
    print('Max answer length: {}, truncated to {}'.format(max_sent, FLAGS.max_answer_len))

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(trainset, word_vocab_enc, word_vocab_dec, char_vocab,
                                                    edgelabel_vocab, options=FLAGS,
                                                    isShuffle=True, isLoop=True, isSort=True)
    devDataStream = G2S_data_stream.G2SDataStream(testset, word_vocab_enc, word_vocab_dec, char_vocab,
                                                  edgelabel_vocab, options=FLAGS,
                                                  isShuffle=False, isLoop=False, isSort=True)
    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    sys.stdout.flush()

    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0
    best_bleu = FLAGS.best_bleu if FLAGS.__dict__.has_key('best_bleu') else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))
    if best_bleu > 0.0:
        print('With initial dev BLEU score {}'.format(best_bleu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = ModelGraph(word_vocab_enc=word_vocab_enc, word_vocab_dec=word_vocab_dec,
                                         Edgelabel_vocab=edgelabel_vocab, char_vocab=char_vocab,
                                         options=FLAGS, mode=FLAGS.mode)

        assert FLAGS.mode in ('ce_train', 'rl_train', )
        valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu'

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = ModelGraph(word_vocab_enc=word_vocab_enc, word_vocab_dec=word_vocab_dec,
                                         Edgelabel_vocab=edgelabel_vocab, char_vocab=char_vocab,
                                         options=FLAGS, mode=valid_mode)

        initializer = tf.global_variables_initializer()

        for var in tf.trainable_variables():
            print(var)

        vars_ = {}
        for var in tf.all_variables():
            if FLAGS.fix_word_vec and "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                print("Getting BLEU score for the model")
                sys.stdout.flush()
                best_bleu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_bleu']
                FLAGS.best_bleu = best_bleu
                namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                print('BLEU = %.4f' % best_bleu)
                sys.stdout.flush()
                log_file.write('BLEU = %.4f\n' % best_bleu)

            if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            cur_batch = G2S_data_stream.G2SBatchPadd(cur_batch)
            if FLAGS.mode == 'rl_train':
                loss_value = train_graph.run_rl_training_subsample(sess, cur_batch, FLAGS)
            elif FLAGS.mode == 'ce_train':
                loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps \
                    or (step != 0 and step % 2000 == 0):
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
                log_file.flush()
                sys.stdout.flush()
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS, suffix=str(step))
                if valid_graph.mode == 'evaluate':
                    dev_loss = res_dict['dev_loss']
                    dev_accu = res_dict['dev_accu']
                    dev_right = int(res_dict['dev_right'])
                    dev_total = int(res_dict['dev_total'])
                    print('Dev loss = %.4f' % dev_loss)
                    log_file.write('Dev loss = %.4f\n' % dev_loss)
                    print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total))
                    log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total))
                    log_file.flush()
                    if best_accu < dev_accu:
                        print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu))
                        saver.save(sess, best_path)
                        best_accu = dev_accu
                        FLAGS.best_accu = dev_accu
                        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                else:
                    dev_bleu = res_dict['dev_bleu']
                    print('Dev bleu = %.4f' % dev_bleu)
                    log_file.write('Dev bleu = %.4f\n' % dev_bleu)
                    log_file.flush()
                    if best_bleu < dev_bleu:
                        print('Saving weights, BLEU {} (prev_best) < {} (cur)'.format(best_bleu, dev_bleu))
                        saver.save(sess, best_path)
                        best_bleu = dev_bleu
                        FLAGS.best_bleu = dev_bleu
                        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()
                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()

    log_file.close()
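# This trainer is normally launched through TensorFlow's flag machinery. A typical entry
# point would look like the sketch below; parsing the command-line arguments into FLAGS
# is assumed to happen above it and is not reproduced here.
if __name__ == '__main__':
    tf.app.run(main=main)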
FLAGS = G2S_trainer.enrich_options(FLAGS)

# load vocabs
print('Loading vocabs.')
word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
edgelabel_vocab = Vocab(model_prefix + ".edgelabel_vocab", fileformat='txt2')
print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape))
char_vocab = None
if FLAGS.with_char:
    char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
    print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

print('Loading test set from {}.'.format(in_path))
testset, _, _, _, _ = G2S_data_stream.read_amr_file(in_path)
print('Number of samples: {}'.format(len(testset)))

print('Build DataStream ... ')
batch_size = -1
if mode not in ('pointwise', 'multinomial', 'greedy', 'greedy_evaluate', ):
    batch_size = 1
devDataStream = G2S_data_stream.G2SDataStream(testset, word_vocab, char_vocab,
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading train set.')
    trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_file(FLAGS.train_path)
    print('Number of training samples: {}'.format(len(trainset)))

    print('Loading dev set.')
    devset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_file(FLAGS.test_path)
    print('Number of dev samples: {}'.format(len(devset)))

    if FLAGS.finetune_path != "":
        print('Loading finetune set.')
        ftset, ft_node, ft_in_neigh, ft_out_neigh, ft_sent = G2S_data_stream.read_amr_file(FLAGS.finetune_path)
        print('Number of finetune samples: {}'.format(len(ftset)))
    else:
        ftset, ft_node, ft_in_neigh, ft_out_neigh, ft_sent = (None, 0, 0, 0, 0)

    max_node = max(trn_node, tst_node, ft_node)
    max_in_neigh = max(trn_in_neigh, tst_in_neigh, ft_in_neigh)
    max_out_neigh = max(trn_out_neigh, tst_out_neigh, ft_out_neigh)
    max_sent = max(trn_sent, tst_sent, ft_sent)
    print('Max node number: {}, while max allowed is {}'.format(max_node, FLAGS.max_node_num))
    print('Max parent number: {}, truncated to {}'.format(max_in_neigh, FLAGS.max_in_neigh_num))
    print('Max children number: {}, truncated to {}'.format(max_out_neigh, FLAGS.max_out_neigh_num))
    print('Max answer length: {}, truncated to {}'.format(max_sent, FLAGS.max_answer_len))

    word_vocab = None
    char_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2')
    else:
        print('Collecting vocabs.')
        (allWords, allChars, allEdgelabels) = G2S_data_stream.collect_vocabs(trainset)
        print('Number of words: {}'.format(len(allWords)))
        print('Number of allChars: {}'.format(len(allChars)))
        print('Number of allEdgelabels: {}'.format(len(allEdgelabels)))

        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        edgelabel_vocab = Vocab(voc=allEdgelabels, dim=FLAGS.edgelabel_dim, fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(trainset, word_vocab, char_vocab, edgelabel_vocab,
                                                    options=FLAGS, isShuffle=True, isLoop=True, isSort=True)
    devDataStream = G2S_data_stream.G2SDataStream(devset, word_vocab, char_vocab, edgelabel_vocab,
                                                  options=FLAGS, isShuffle=False, isLoop=False, isSort=True)
    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    if ftset != None:
        ftDataStream = G2S_data_stream.G2SDataStream(ftset, word_vocab, char_vocab, edgelabel_vocab,
                                                     options=FLAGS, isShuffle=True, isLoop=True, isSort=True)
        print('Number of instances in ftDataStream: {}'.format(ftDataStream.get_num_instance()))
        print('Number of batches in ftDataStream: {}'.format(ftDataStream.get_num_batch()))
    sys.stdout.flush()

    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0
    best_bleu = FLAGS.best_bleu if FLAGS.__dict__.has_key('best_bleu') else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))
    if best_bleu > 0.0:
        print('With initial dev BLEU score {}'.format(best_bleu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = ModelGraph(word_vocab=word_vocab, Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab, options=FLAGS, mode=FLAGS.mode)

        assert FLAGS.mode in ('ce_train', 'rl_train', 'transformer')
        valid_mode = 'evaluate' if FLAGS.mode in ('ce_train', 'transformer') else 'evaluate_bleu'

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab, Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab, options=FLAGS, mode=valid_mode)

        initializer = tf.global_variables_initializer()

        vars_ = {}
        for var in tf.all_variables():
            if FLAGS.fix_word_vec and "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                print("Getting BLEU score for the model")
                sys.stdout.flush()
                best_bleu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_bleu']
                FLAGS.best_bleu = best_bleu
                namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                print('BLEU = %.4f' % best_bleu)
                sys.stdout.flush()
                log_file.write('BLEU = %.4f\n' % best_bleu)

            if FLAGS.mode in ('ce_train', 'rl_train', 'transformer') and abs(best_accu) < 0.00001:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            if FLAGS.mode == 'rl_train':
                loss_value = train_graph.run_rl_training_subsample(sess, cur_batch, FLAGS)
            elif FLAGS.mode in ('ce_train', 'rl_train', 'transformer'):
                loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or \
                    (trainDataStream.get_num_batch() > 10000 and (step + 1) % 2000 == 0):
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
                log_file.flush()
                sys.stdout.flush()
                total_loss = 0.0

                if ftset != None:
                    best_accu, best_bleu = fine_tune(sess, saver, FLAGS, log_file, ftDataStream, devDataStream,
                                                     train_graph, valid_graph, path_prefix, best_accu, best_bleu)
                else:
                    best_accu, best_bleu = validate_and_save(sess, saver, FLAGS, log_file, devDataStream,
                                                             valid_graph, path_prefix, best_accu, best_bleu)
                start_time = time.time()

    log_file.close()
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading train set.')
    if FLAGS.infile_format == 'fof':
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_nary_from_fof(
            FLAGS.train_path, FLAGS)
    else:
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_nary_file(
            FLAGS.train_path, FLAGS)
    random.shuffle(trainset)
    devset = trainset[:200]
    trainset = trainset[200:]
    print('Number of training samples: {}'.format(len(trainset)))
    print('Number of dev samples: {}'.format(len(devset)))

    max_node = trn_node
    max_in_neigh = trn_in_neigh
    max_out_neigh = trn_out_neigh
    max_sent = trn_sent
    print('Max node number: {}, while max allowed is {}'.format(max_node, FLAGS.max_node_num))
    print('Max parent number: {}, truncated to {}'.format(max_in_neigh, FLAGS.max_in_neigh_num))
    print('Max children number: {}, truncated to {}'.format(max_out_neigh, FLAGS.max_out_neigh_num))
    print('Max entity size: {}, truncated to {}'.format(max_sent, FLAGS.max_entity_size))

    word_vocab = None
    char_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2')
    else:
        print('Collecting vocabs.')
        (allWords, allChars, allEdgelabels) = G2S_data_stream.collect_vocabs(trainset)
        print('Number of words: {}'.format(len(allWords)))
        print('Number of allChars: {}'.format(len(allChars)))
        print('Number of allEdgelabels: {}'.format(len(allEdgelabels)))

        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        edgelabel_vocab = Vocab(voc=allEdgelabels, dim=FLAGS.edgelabel_dim, fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(trainset, word_vocab, char_vocab, edgelabel_vocab,
                                                    options=FLAGS, isShuffle=True, isLoop=True, isSort=False)
    devDataStream = G2S_data_stream.G2SDataStream(devset, word_vocab, char_vocab, edgelabel_vocab,
                                                  options=FLAGS, isShuffle=False, isLoop=False, isSort=False)
    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    sys.stdout.flush()

    # initialize the best accu score for current training session
    best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = ModelGraph(word_vocab=word_vocab, Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab, options=FLAGS, mode='train')

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab, Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab, options=FLAGS, mode='evaluate')

        initializer = tf.global_variables_initializer()

        vars_ = {}
        for var in tf.all_variables():
            if "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if abs(best_accu) < 0.00001:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        last_step = 0
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            _, loss_value, _ = train_graph.execute(sess, cur_batch, FLAGS, is_train=True)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss / (step - last_step), duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss / (step - last_step), duration))
                sys.stdout.flush()
                log_file.flush()
                last_step = step
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS, suffix=str(step))
                dev_loss = res_dict['dev_loss']
                dev_accu = res_dict['dev_accu']
                dev_right = int(res_dict['dev_right'])
                dev_total = int(res_dict['dev_total'])
                print('Dev loss = %.4f' % dev_loss)
                log_file.write('Dev loss = %.4f\n' % dev_loss)
                print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total))
                log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total))
                log_file.flush()
                if best_accu < dev_accu:
                    print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu))
                    saver.save(sess, best_path)
                    best_accu = dev_accu
                    FLAGS.best_accu = dev_accu
                    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                    json.dump(res_dict['data'], open(FLAGS.output_path, 'w'))

                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()
                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()

    log_file.close()
char_vocab = None
POS_vocab = None
if FLAGS.with_char:
    char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
    print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
if FLAGS.with_POS:
    POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
    print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
edgelabel_vocab = Vocab(model_prefix + ".edgelabel_vocab", fileformat='txt2')
print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape))

print('Loading test set from {}.'.format(in_path))
if not hasattr(FLAGS, 'num_relations'):
    FLAGS.num_relations = 2
testset = G2S_data_stream.read_bionlp_file(in_path, in_dep_path, FLAGS)
print('Number of samples: {}'.format(len(testset)))

print('Build DataStream ... ')
testDataStream = G2S_data_stream.G2SDataStream(FLAGS, testset, word_vocab, char_vocab, POS_vocab, edgelabel_vocab,
                                               isShuffle=False, isLoop=False, isSort=False)
print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
print('Number of batches in testDataStream: {}'.format(
print('Loading vocabs.')
word_vocab_enc = Vocab(FLAGS.word_vec_src_path, fileformat='txt2')
print('word_vocab SRC: {}'.format(word_vocab_enc.word_vecs.shape))
word_vocab_dec = Vocab(FLAGS.word_vec_tgt_path, fileformat='txt2')
print('word_vocab TGT: {}'.format(word_vocab_dec.word_vecs.shape))
edgelabel_vocab = Vocab(model_prefix + ".edgelabel_vocab", fileformat='txt2')
print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape))
char_vocab = None
if FLAGS.with_char:
    char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
    print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

print('Loading test set from {}.'.format(in_path))
if FLAGS.infile_format == 'fof':
    testset, _, _, _, _ = G2S_data_stream.read_amr_from_fof(in_path, FLAGS, word_vocab_enc, word_vocab_dec,
                                                            char_vocab, edgelabel_vocab)
else:
    testset, _, _, _, _ = G2S_data_stream.read_amr_file(in_path, FLAGS, word_vocab_enc, word_vocab_dec,
                                                        char_vocab, edgelabel_vocab)
print('Number of samples: {}'.format(len(testset)))

print('Build DataStream ... ')
batch_size = -1
if mode not in ('pointwise', 'multinomial', 'greedy', 'greedy_evaluate', ):
    batch_size = 1
devDataStream = G2S_data_stream.G2SDataStream(testset, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab,
                                               options=FLAGS, isShuffle=False, isLoop=False, isSort=True,
                                               batch_size=batch_size)
print('Number of instances in testDataStream: {}'.format(devDataStream.get_num_instance()))
print('Number of batches in testDataStream: {}'.format(devDataStream.get_num_batch()))
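# What typically follows the stream construction above is a decode loop. The sketch below
# only illustrates that flow, reusing calls that appear elsewhere in this section
# (ModelGraph, saver.restore, run_greedy, G2SBatchPadd); the exact mode string and the
# output handling are assumptions, not this repository's actual decoder.
with tf.Graph().as_default():
    with tf.variable_scope("Model", reuse=None):
        valid_graph = ModelGraph(word_vocab_enc=word_vocab_enc, word_vocab_dec=word_vocab_dec,
                                 Edgelabel_vocab=edgelabel_vocab, char_vocab=char_vocab,
                                 options=FLAGS, mode='evaluate_bleu')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, model_prefix + ".best.model")
        outputs = []
        for batch_index in xrange(devDataStream.get_num_batch()):
            cur_batch = devDataStream.get_batch(batch_index)
            cur_batch = G2S_data_stream.G2SBatchPadd(cur_batch)
            # greedy decoding for this batch; word ids are collected for later detokenization
            outputs.extend(valid_graph.run_greedy(sess, cur_batch, FLAGS).tolist())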
def main(_):
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading data.')
    FLAGS.num_relations = 2
    trainset = G2S_data_stream.read_bionlp_file(FLAGS.train_path, FLAGS.train_dep_path, FLAGS)
    if FLAGS.dev_gen == 'shuffle':
        random.shuffle(trainset)
    elif FLAGS.dev_gen == 'last':
        trainset.reverse()
    N = int(len(trainset) * FLAGS.dev_percent)
    devset = trainset[:N]
    trainset = trainset[N:]
    print('Number of training samples: {}'.format(len(trainset)))
    print('Number of dev samples: {}'.format(len(devset)))
    print('Number of relations: {}'.format(FLAGS.num_relations))

    word_vocab = None
    char_vocab = None
    POS_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        if FLAGS.with_POS:
            POS_vocab = Vocab(path_prefix + ".POS_vocab", fileformat='txt2')
            print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2')
        print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape))
    else:
        print('Collecting vocabs.')
        all_words = set()
        all_chars = set()
        all_poses = set()
        all_edgelabels = set()
        G2S_data_stream.collect_vocabs(trainset, all_words, all_chars, all_poses, all_edgelabels)
        G2S_data_stream.collect_vocabs(devset, all_words, all_chars, all_poses, all_edgelabels)
        print('Number of words: {}'.format(len(all_words)))
        print('Number of chars: {}'.format(len(all_chars)))
        print('Number of poses: {}'.format(len(all_poses)))
        print('Number of edgelabels: {}'.format(len(all_edgelabels)))

        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(voc=all_chars, dim=FLAGS.char_dim, fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        if FLAGS.with_POS:
            POS_vocab = Vocab(voc=all_poses, dim=FLAGS.POS_dim, fileformat='build')
            POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab")
        edgelabel_vocab = Vocab(voc=all_edgelabels, dim=FLAGS.edgelabel_dim, fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(FLAGS, trainset, word_vocab, char_vocab, POS_vocab,
                                                    edgelabel_vocab, isShuffle=True, isLoop=True, isSort=True,
                                                    is_training=True)
    devDataStream = G2S_data_stream.G2SDataStream(FLAGS, devset, word_vocab, char_vocab, POS_vocab,
                                                  edgelabel_vocab, isShuffle=False, isLoop=False, isSort=True)
    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    sys.stdout.flush()
    FLAGS.trn_bch_num = trainDataStream.get_num_batch()

    # initialize the best accu score for current training session
    best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = ModelGraph(word_vocab, char_vocab, POS_vocab, edgelabel_vocab, FLAGS, mode='train')

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = ModelGraph(word_vocab, char_vocab, POS_vocab, edgelabel_vocab, FLAGS, mode='evaluate')

        initializer = tf.global_variables_initializer()

        vars_ = {}
        for var in tf.all_variables():
            if FLAGS.fix_word_vec and "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if abs(best_accu) < 1e-5:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess, valid_graph, devDataStream, FLAGS)['dev_f1']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        last_step = 0
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            _, _, cur_loss, _ = train_graph.execute(sess, cur_batch, FLAGS, is_train=True)
            total_loss += cur_loss

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss / (step - last_step), duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss / (step - last_step), duration))
                sys.stdout.flush()
                log_file.flush()
                last_step = step
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                res_dict = evaluate(sess, valid_graph, devDataStream, FLAGS)
                dev_loss = res_dict['dev_loss']
                dev_accu = res_dict['dev_f1']
                dev_precision = res_dict['dev_precision']
                dev_recall = res_dict['dev_recall']
                print('Dev loss = %.4f' % dev_loss)
                log_file.write('Dev loss = %.4f\n' % dev_loss)
                print('Dev F1 = %.4f, P = %.4f, R = %.4f' % (dev_accu, dev_precision, dev_recall))
                log_file.write('Dev F1 = %.4f, P = %.4f, R = %.4f\n' % (dev_accu, dev_precision, dev_recall))
                log_file.flush()
                if best_accu < dev_accu:
                    print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu))
                    saver.save(sess, best_path)
                    best_accu = dev_accu
                    FLAGS.best_accu = dev_accu
                    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()
                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()
                start_time = time.time()

    log_file.close()
# load vocabs
print('Loading vocabs.')
word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
edgelabel_vocab = Vocab(model_prefix + ".edgelabel_vocab", fileformat='txt2')
print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape))
char_vocab = None
if FLAGS.with_char:
    char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
    print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

print('Loading test set from {}.'.format(in_path))
if FLAGS.infile_format == 'fof':
    testset = G2S_data_stream.read_nary_from_fof(in_path, FLAGS, is_rev=False)
    testset_rev = G2S_data_stream.read_nary_from_fof(in_path, FLAGS, is_rev=True)
else:
    testset = G2S_data_stream.read_nary_file(in_path, FLAGS, is_rev=False)
    testset_rev = G2S_data_stream.read_nary_file(in_path, FLAGS, is_rev=True)
print('Number of samples: {}'.format(len(testset)))

print('Build DataStream ... ')
batch_size = -1
devDataStream = G2S_data_stream.G2SDataStream(testset, word_vocab,
# load the configuration file
print('Loading configurations from ' + model_prefix + ".config.json")
FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
FLAGS = G2S_trainer.enrich_options(FLAGS)

# load vocabs
print('Loading vocabs.')
word_vocab_enc = Vocab(FLAGS.word_vec_src_path, fileformat='txt2')
print('word_vocab SRC: {}'.format(word_vocab_enc.word_vecs.shape))
word_vocab_dec = Vocab(FLAGS.word_vec_tgt_path, fileformat='txt2')
print('word_vocab TGT: {}'.format(word_vocab_dec.word_vecs.shape))

print('Loading test set from {}.'.format(in_path))
if FLAGS.infile_format == 'fof':
    testset, _, _, _ = G2S_data_stream.read_amr_from_fof(in_path)
else:
    testset, _, _, _ = G2S_data_stream.read_amr_file(in_path)
print('Number of samples: {}'.format(len(testset)))

print('Build DataStream ... ')
batch_size = -1
if mode not in ('pointwise', 'multinomial', 'greedy', 'greedy_evaluate', ):
    batch_size = 1
devDataStream = G2S_data_stream.G2SDataStream(testset,