def prepare_evaluate(self, ckpt_path=None):
  # Load an initial checkpoint to use for decoding
  if FLAGS.mode == 'evalall':
    if FLAGS.load_best_eval_model:
      tf.logging.info('Loading best eval checkpoint')
      ckpt_path = util.load_ckpt(self._saver, self._sess,
                                 ckpt_dir='eval_' + FLAGS.eval_method)
    elif FLAGS.eval_ckpt_path:
      ckpt_path = util.load_ckpt(self._saver, self._sess,
                                 ckpt_path=FLAGS.eval_ckpt_path)
    else:
      tf.logging.info('Loading best train checkpoint')
      ckpt_path = util.load_ckpt(self._saver, self._sess)
  elif FLAGS.mode == 'eval':
    _ = util.load_ckpt(self._saver, self._sess, ckpt_path=ckpt_path)  # load a new checkpoint

  if FLAGS.single_pass:
    # Make a descriptive decode directory name
    ckpt_name = "ckpt-" + ckpt_path.split('-')[-1]  # this is something of the form "ckpt-123456"
    self._decode_dir = os.path.join(FLAGS.log_root, get_decode_dir_name(ckpt_name))
    tf.logging.info('Save evaluation results to ' + self._decode_dir)
    if os.path.exists(self._decode_dir):
      if FLAGS.mode == 'eval':
        return False  # The checkpoint has already been evaluated. Evaluate the next one.
      else:
        raise Exception("single_pass decode directory %s should not already exist" % self._decode_dir)
  else:  # Generic decode dir name
    self._decode_dir = os.path.join(FLAGS.log_root, "decode")

  # Make the decode dir if necessary
  if not os.path.exists(self._decode_dir):
    os.mkdir(self._decode_dir)

  if FLAGS.single_pass:
    # Make the dirs to contain output written in the correct format for pyrouge
    self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
    if not os.path.exists(self._rouge_ref_dir):
      os.mkdir(self._rouge_ref_dir)
    self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded")
    if not os.path.exists(self._rouge_dec_dir):
      os.mkdir(self._rouge_dec_dir)
    if FLAGS.save_vis:
      self._rouge_vis_dir = os.path.join(self._decode_dir, "visualize")
      if not os.path.exists(self._rouge_vis_dir):
        os.mkdir(self._rouge_vis_dir)
    if FLAGS.save_pkl:
      self._result_dir = os.path.join(self._decode_dir, "result")
      if not os.path.exists(self._result_dir):
        os.mkdir(self._result_dir)

  return True
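# A minimal sketch of what util.load_ckpt (used throughout this file) presumably
# does: restore a checkpoint into the session and return its path, reading the
# newest checkpoint from a sub-directory of FLAGS.log_root when no explicit
# ckpt_path is given. The retry loop and the 'checkpoint_best' index for eval
# dirs (written by run_eval below) follow the usual pointer-generator pattern;
# the suffix "_sketch" marks that this is an assumption, not the repo's util.py.
def load_ckpt_sketch(saver, sess, ckpt_dir="train", ckpt_path=None):
  while True:
    try:
      if ckpt_path is None:
        # eval dirs keep their own 'checkpoint_best' index file; train dirs use the default
        latest_filename = "checkpoint_best" if ckpt_dir.startswith("eval") else None
        ckpt_state = tf.train.get_checkpoint_state(
            os.path.join(FLAGS.log_root, ckpt_dir), latest_filename=latest_filename)
        path_to_load = ckpt_state.model_checkpoint_path
      else:
        path_to_load = ckpt_path
      tf.logging.info('Loading checkpoint %s', path_to_load)
      saver.restore(sess, path_to_load)
      return path_to_load
    except Exception:
      # typically the eval job starts before the train job has written a checkpoint
      tf.logging.info('Failed to load checkpoint from %s. Sleeping for 10 secs...', ckpt_dir)
      time.sleep(10)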
def run_eval(model, batcher):
  """Repeatedly runs eval iterations, logging to screen and writing summaries.
  Saves the model with the best loss seen so far."""
  model.build_graph()  # build the graph
  saver = tf.train.Saver(max_to_keep=3)  # we will keep 3 best checkpoints at a time
  sess = tf.Session(config=util.get_config())
  eval_dir = os.path.join(FLAGS.log_root, "eval_loss")  # make a subdir of the root dir for eval data
  bestmodel_save_path = os.path.join(eval_dir, 'bestmodel')  # this is where checkpoints of best models are saved
  summary_writer = tf.summary.FileWriter(eval_dir)
  running_avg_loss = 0  # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
  best_loss = None  # will hold the best loss achieved so far
  train_dir = os.path.join(FLAGS.log_root, "train")

  while True:
    ckpt_state = tf.train.get_checkpoint_state(train_dir)
    tf.logging.info('max_enc_steps: %d, max_dec_steps: %d', FLAGS.max_enc_steps, FLAGS.max_dec_steps)
    _ = util.load_ckpt(saver, sess)  # load a new checkpoint
    batch = batcher.next_batch()  # get the next batch

    # run eval on the batch
    t0 = time.time()
    results = model.run_eval_step(sess, batch)
    t1 = time.time()
    tf.logging.info('seconds for batch: %.2f', t1 - t0)

    # print the loss and coverage loss to screen
    loss = results['loss']
    tf.logging.info('loss: %f', loss)
    train_step = results['global_step']
    tf.logging.info("pgen_avg: %f", results['p_gen_avg'])
    if FLAGS.coverage:
      tf.logging.info("coverage_loss: %f", results['coverage_loss'])

    # add summaries
    summaries = results['summaries']
    summary_writer.add_summary(summaries, train_step)

    # calculate running avg loss
    running_avg_loss = util.calc_running_avg_loss(np.asscalar(loss), running_avg_loss,
                                                  summary_writer, train_step, 'running_avg_loss')

    # If running_avg_loss is best so far, save this checkpoint (early stopping).
    # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
    if best_loss is None or running_avg_loss < best_loss:
      tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s',
                      running_avg_loss, bestmodel_save_path)
      saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
      best_loss = running_avg_loss

    # flush the summary writer every so often
    if train_step % 100 == 0:
      summary_writer.flush()
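# A minimal sketch of what util.calc_running_avg_loss (called above) is assumed
# to do: keep an exponential moving average of the eval loss and write it to
# TensorBoard under the given tag, so the early-stopping decision is not driven
# by single-batch noise. The decay of 0.99 and the clipping at 12 follow the
# common pointer-generator recipe and are assumptions about this repo's util.py.
def calc_running_avg_loss_sketch(loss, running_avg_loss, summary_writer, step, tag, decay=0.99):
  if running_avg_loss == 0:  # on the first iteration just take the raw loss
    running_avg_loss = loss
  else:
    running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
  running_avg_loss = min(running_avg_loss, 12)  # clip so extreme values don't wreck the plot
  loss_sum = tf.Summary()
  loss_sum.value.add(tag=tag, simple_value=running_avg_loss)
  summary_writer.add_summary(loss_sum, step)
  tf.logging.info('running_avg_loss: %f', running_avg_loss)
  return running_avg_loss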
def convert_to_coverage_model():
  """Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint"""
  tf.logging.info("converting non-coverage model to coverage model..")

  # initialize an entire coverage model from scratch
  sess = tf.Session(config=util.get_config())
  print "initializing everything..."
  sess.run(tf.global_variables_initializer())

  # load all non-coverage weights from checkpoint
  saver = tf.train.Saver([v for v in tf.global_variables()
                          if "coverage" not in v.name and "Adagrad" not in v.name])
  print "restoring non-coverage variables..."
  curr_ckpt = util.load_ckpt(saver, sess)
  print "restored."

  # save this model and quit
  ckpt_path = os.path.join(FLAGS.log_root, "train", "model.ckpt_cov")
  step = curr_ckpt.split('-')[1]
  new_fname = ckpt_path + '-' + step + '-init'
  print "saving model to %s..." % (new_fname)
  new_saver = tf.train.Saver()  # this one will save all variables that now exist
  new_saver.save(sess, new_fname)
  print "saved."
  exit()
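# A minimal sketch of what util.get_config (used for every tf.Session in this
# file) is assumed to return: a TF1 session config with soft device placement
# and on-demand GPU memory growth, the usual setup for long-running train/eval
# jobs that share a machine. The real util.py may set additional options.
def get_config_sketch():
  config = tf.ConfigProto(allow_soft_placement=True)  # fall back to CPU if an op has no GPU kernel
  config.gpu_options.allow_growth = True  # grab GPU memory incrementally instead of all at once
  return config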
def __init__(self, model, batcher, vocab):
  """Initialize decoder.

  Args:
    model: a SentSelector object.
    batcher: a Batcher object.
    vocab: Vocabulary object
  """
  # get the data split set
  if "train" in FLAGS.data_path:
    self._dataset = "train"
  elif "val" in FLAGS.data_path:
    self._dataset = "val"
  elif "test" in FLAGS.data_path:
    self._dataset = "test"
  else:
    raise ValueError("FLAGS.data_path %s should contain one of train, val or test" % (FLAGS.data_path))

  # create the data loader
  self._batcher = batcher

  if FLAGS.eval_gt_rouge:  # no need to load model
    # Make a descriptive decode directory name
    self._decode_dir = os.path.join(FLAGS.log_root, 'select_gt' + self._dataset)
    tf.logging.info('Save evaluation results to ' + self._decode_dir)
    if os.path.exists(self._decode_dir):
      raise Exception("single_pass decode directory %s should not already exist" % self._decode_dir)

    # Make the decode dir
    os.makedirs(self._decode_dir)

    # Make the dirs to contain output written in the correct format for pyrouge
    self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
    if not os.path.exists(self._rouge_ref_dir):
      os.mkdir(self._rouge_ref_dir)
    self._rouge_gt_dir = os.path.join(self._decode_dir, "gt_selected")
    if not os.path.exists(self._rouge_gt_dir):
      os.mkdir(self._rouge_gt_dir)
  else:
    self._model = model
    self._model.build_graph()
    self._vocab = vocab
    self._saver = tf.train.Saver()  # we use this to load checkpoints for decoding
    self._sess = tf.Session(config=util.get_config())

    # Load an initial checkpoint to use for decoding
    if FLAGS.load_best_eval_model:
      tf.logging.info('Loading best eval checkpoint')
      ckpt_path = util.load_ckpt(self._saver, self._sess, ckpt_dir='eval')
    elif FLAGS.eval_ckpt_path:
      ckpt_path = util.load_ckpt(self._saver, self._sess, ckpt_path=FLAGS.eval_ckpt_path)
    else:
      tf.logging.info('Loading best train checkpoint')
      ckpt_path = util.load_ckpt(self._saver, self._sess)

    if FLAGS.single_pass:
      # Make a descriptive decode directory name
      ckpt_name = "ckpt-" + ckpt_path.split('-')[-1]  # this is something of the form "ckpt-123456"
      decode_root_dir, decode_dir = get_decode_dir_name(ckpt_name, self._dataset)
      self._decode_root_dir = os.path.join(FLAGS.log_root, decode_root_dir)
      self._decode_dir = os.path.join(FLAGS.log_root, decode_root_dir, decode_dir)
      tf.logging.info('Save evaluation results to ' + self._decode_dir)
      if os.path.exists(self._decode_dir):
        raise Exception("single_pass decode directory %s should not already exist" % self._decode_dir)
    else:  # Generic decode dir name
      self._decode_dir = os.path.join(FLAGS.log_root, "select")

    # Make the decode dir if necessary
    if not os.path.exists(self._decode_dir):
      os.makedirs(self._decode_dir)

    if FLAGS.single_pass:
      # Make the dirs to contain output written in the correct format for pyrouge
      self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
      if not os.path.exists(self._rouge_ref_dir):
        os.mkdir(self._rouge_ref_dir)
      self._rouge_dec_dir = os.path.join(self._decode_dir, "selected")
      if not os.path.exists(self._rouge_dec_dir):
        os.mkdir(self._rouge_dec_dir)
      if FLAGS.save_pkl:
        self._result_dir = os.path.join(self._decode_dir, "select_result")
        if not os.path.exists(self._result_dir):
          os.mkdir(self._result_dir)

      self._probs_pkl_path = os.path.join(self._decode_root_dir, "probs.pkl")
      if not os.path.exists(self._probs_pkl_path):
        self._make_probs_pkl = True
      else:
        self._make_probs_pkl = False

  self._precision = []
  self._recall = []
  self._accuracy = []
  self._ratio = []
  self._select_sent_num = []
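# A hedged sketch of the two-argument get_decode_dir_name used in the selector
# __init__ above. Only the return contract matches the call site: a root
# directory grouping everything produced from one checkpoint, plus a sub-directory
# for this particular run. The naming scheme below is made up for illustration;
# the real helper presumably also encodes decoding hyperparameters in the names.
def get_decode_dir_name_sketch(ckpt_name, dataset):
  decode_root_dir = "select_%s_%s" % (dataset, ckpt_name)  # e.g. "select_val_ckpt-123456"
  decode_dir = "%s_results" % decode_root_dir              # hypothetical per-run sub-dir name
  return decode_root_dir, decode_dir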
def run_training(model, batcher, sess_context_manager, sv, summary_writer,
                 pretrained_saver=None, saver=None):
  """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
  tf.logging.info("starting run_training")
  ckpt_path = os.path.join(FLAGS.log_root, "train", "model.ckpt")

  with sess_context_manager as sess:
    if FLAGS.pretrained_selector_path:
      tf.logging.info('Loading pretrained selector model')
      _ = util.load_ckpt(pretrained_saver, sess, ckpt_path=FLAGS.pretrained_selector_path)

    for _ in range(FLAGS.max_train_iter):  # repeats until interrupted
      batch = batcher.next_batch()
      if not batch:
        tf.logging.info('training has finished - no more batches left...')
        return

      tf.logging.info('running training step...')
      t0 = time.time()
      results = model.run_train_step(sess, batch)
      t1 = time.time()
      tf.logging.info('seconds for training step: %.3f', t1 - t0)

      loss = results['loss']
      tf.logging.info('loss: %f', loss)  # print the loss to screen
      if not np.isfinite(loss):
        raise Exception("Loss is not finite. Stopping.")

      train_step = results['global_step']  # we need this to update our running average loss

      recall, ratio, _ = util.get_batch_ratio(batch.original_articles_sents,
                                              batch.original_extracts_ids, results['probs'])
      write_to_summary(ratio, 'SentSelector/select_ratio/recall=0.9', train_step, summary_writer)

      # get the summaries and iteration number so we can write summaries to tensorboard
      summaries = results['summaries']  # we will write these summaries to tensorboard using summary_writer
      summary_writer.add_summary(summaries, train_step)  # write the summaries

      if train_step % 100 == 0:  # flush the summary writer every so often
        summary_writer.flush()

      if train_step % FLAGS.save_model_every == 0:
        if FLAGS.pretrained_selector_path:
          saver.save(sess, ckpt_path, global_step=train_step)
        else:
          sv.saver.save(sess, ckpt_path, global_step=train_step)

      print 'Step: ', train_step
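# A minimal sketch of the write_to_summary helper called in run_training above:
# wrap a single scalar in a tf.Summary protobuf and hand it to the FileWriter,
# so the selection ratio shows up in TensorBoard next to the graph summaries.
# The name and argument order match the call site; the body is an assumption.
def write_to_summary(value, tag_name, step, summary_writer):
  summary = tf.Summary()
  summary.value.add(tag=tag_name, simple_value=value)
  summary_writer.add_summary(summary, step)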