def convert_linear_attn_to_hier_model():
  """Load linear-attention checkpoint, add initialized extra variables for hierarchical (section-level) attention, and save as new checkpoint"""
  tf.logging.info("converting linear model to hier model..")

  # initialize an entire hierarchical model from scratch
  sess = tf.Session(config=util.get_config())
  print("initializing everything...")
  sess.run(tf.global_variables_initializer())

  # load all weights from checkpoint except the new section-attention variables
  saver = tf.train.Saver([
      v for v in tf.global_variables()
      if "Linear--Section-Features" not in v.name and "v_sec" not in v.name
      and "Adagrad" not in v.name
  ])
  print("restoring variables...")
  curr_ckpt = util.load_ckpt(saver, sess)
  print("restored.")

  # save this model and quit
  new_fname = curr_ckpt
  print("saving model to %s..." % new_fname)
  new_saver = tf.train.Saver()  # this one will save all variables that now exist
  new_saver.save(sess, new_fname)
  print("saved.")
  exit()
def convert_to_coverage_model():
  """Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint"""
  tf.logging.info("converting non-coverage model to coverage model..")

  # initialize an entire coverage model from scratch
  sess = tf.Session(config=util.get_config())
  if FLAGS.debug:
    print('entering debug mode')
    sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type=FLAGS.ui_type)
    sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
  print("initializing everything...")
  sess.run(tf.global_variables_initializer())

  # load all non-coverage weights from checkpoint
  saver = tf.train.Saver([
      v for v in tf.global_variables()
      if "coverage" not in v.name and "Adagrad" not in v.name
  ])
  print("restoring non-coverage variables...")
  curr_ckpt = util.load_ckpt(saver, sess)
  print("restored.")

  # save this model and quit
  new_fname = curr_ckpt + '_cov_init'
  print("saving model to %s..." % new_fname)
  new_saver = tf.train.Saver()  # this one will save all variables that now exist
  new_saver.save(sess, new_fname)
  print("saved.")
  exit()
def restore_best_model():
  """Load bestmodel file from eval directory, add variables for adagrad, and save to train directory"""
  tf.logging.info("Restoring bestmodel for training...")

  # Initialize all vars in the model
  sess = tf.Session(config=util.get_config())
  print("Initializing all variables...")
  sess.run(tf.global_variables_initializer())

  # Restore the best model from eval dir
  saver = tf.train.Saver(
      [v for v in tf.global_variables() if "Adagrad" not in v.name])
  print("Restoring all non-adagrad variables from best model in eval dir...")
  curr_ckpt = util.load_ckpt(saver, sess, "eval")
  print("Restored %s." % curr_ckpt)

  # Save this model to train dir and quit
  new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model")
  new_fname = os.path.join(FLAGS.log_root, "train", new_model_name)
  print("Saving model to %s..." % new_fname)
  # this saver saves all variables that now exist, including Adagrad variables
  new_saver = tf.train.Saver()
  new_saver.save(sess, new_fname)
  print("Saved.")
  exit()
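# A minimal sketch of what util.load_ckpt (used throughout this file) is
# assumed to do: poll for the latest checkpoint under FLAGS.log_root/<ckpt_dir>,
# restore it into the session, and return the checkpoint path. The real helper
# lives in util.py; the retry loop, the 10-second sleep, and the exact
# signature below are assumptions, not the repo's confirmed implementation.
def _load_ckpt_sketch(saver, sess, ckpt_dir="train", latest_filename=None):
  ckpt_dir_path = os.path.join(FLAGS.log_root, ckpt_dir)
  while True:
    try:
      # find the most recent checkpoint file (optionally via a named
      # "latest" file, as the decoder's latest_filename argument suggests)
      ckpt_state = tf.train.get_checkpoint_state(
          ckpt_dir_path, latest_filename=latest_filename)
      ckpt_path = ckpt_state.model_checkpoint_path
      tf.logging.info('Loading checkpoint %s', ckpt_path)
      saver.restore(sess, ckpt_path)
      return ckpt_path
    except Exception:
      # no checkpoint yet (or it is being written); wait and retry
      tf.logging.info('Failed to load checkpoint from %s. Sleeping %i secs...',
                      ckpt_dir_path, 10)
      time.sleep(10)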
def __init__(self, model, batcher, vocab):
  """Initialize decoder.

  Args:
    model: a Seq2SeqAttentionModel object.
    batcher: a Batcher object.
    vocab: Vocabulary object
  """
  self._model = model
  self._model.build_graph()
  self._batcher = batcher
  self._vocab = vocab
  self._saver = tf.train.Saver()  # we use this to load checkpoints for decoding
  self._sess = tf.Session(config=util.get_config())

  # Load an initial checkpoint to use for decoding
  decode_checkpoint = FLAGS.decode_checkpoint if FLAGS.decode_checkpoint else None
  ckpt_path = util.load_ckpt(
      self._saver, self._sess, latest_filename=decode_checkpoint)

  if FLAGS.single_pass:
    # Make a descriptive decode directory name
    # this is something of the form "ckpt-123456"
    ckpt_name = "ckpt-" + ckpt_path.split('-')[-1]
    self._decode_dir = os.path.join(FLAGS.log_root,
                                    get_decode_dir_name(ckpt_name))
    # if os.path.exists(self._decode_dir):
    #   if not FLAGS.custom_decode_name:
    #     raise Exception(
    #         "single_pass decode directory %s should not already exist" % self._decode_dir)
    #   else:
    #     self._decode_dir = os.path.join(
    #         FLAGS.log_root, get_decode_dir_name(ckpt_name)) + '_' + FLAGS.custom_decode_name
  else:  # Generic decode dir name
    self._decode_dir = os.path.join(FLAGS.log_root, "decode")

  # Make the decode dir if necessary
  if not os.path.exists(self._decode_dir):
    os.mkdir(self._decode_dir)

  if FLAGS.single_pass:
    # Make the dirs to contain output written in the correct format for pyrouge
    self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
    if not os.path.exists(self._rouge_ref_dir):
      os.mkdir(self._rouge_ref_dir)
    self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded")
    if not os.path.exists(self._rouge_dec_dir):
      os.mkdir(self._rouge_dec_dir)
def setup_training(model, batcher):
  """Does setup before starting training (run_training)"""
  train_dir = os.path.join(FLAGS.log_root, "train")
  if not os.path.exists(train_dir):
    os.makedirs(train_dir)

  model.build_graph()  # build the graph
  if FLAGS.convert_to_coverage_model:
    assert FLAGS.coverage, "To convert your non-coverage model to a coverage model, run with convert_to_coverage_model=True and coverage=True"
    convert_to_coverage_model()
  if FLAGS.convert_linear_to_hier_attn:
    convert_linear_attn_to_hier_model()
  if FLAGS.restore_best_model:
    restore_best_model()
  saver = tf.train.Saver(max_to_keep=3)  # keep 3 checkpoints at a time

  sv = tf.train.Supervisor(
      logdir=train_dir,
      is_chief=True,
      saver=saver,
      summary_op=None,
      save_summaries_secs=60,  # save summaries for tensorboard every 60 secs
      save_model_secs=60,  # checkpoint every 60 secs
      global_step=model.global_step)
  summary_writer = sv.summary_writer
  tf.logging.info("Preparing or waiting for session...")
  sess_context_manager = sv.prepare_or_wait_for_session(
      config=util.get_config())
  if FLAGS.debug:
    print('entering debug mode\n\n\n\n\n\n\n\n\n')
    sess_context_manager = tf_debug.LocalCLIDebugWrapperSession(
        sess_context_manager)
    sess_context_manager.add_tensor_filter("has_inf_or_nan",
                                           tf_debug.has_inf_or_nan)
  tf.logging.info("Created session.")
  try:
    # this is an infinite loop until interrupted
    run_training(model, batcher, sess_context_manager, sv, summary_writer)
  except KeyboardInterrupt:
    tf.logging.info(
        "Caught keyboard interrupt on worker. Stopping supervisor...")
    sv.stop()
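# A minimal sketch of the run_training loop invoked above (hypothetical; the
# actual function is defined elsewhere in this file and may log and summarize
# more than shown): repeatedly pull a batch from the batcher, run one training
# step, and write the returned summaries until the supervisor is stopped.
def _run_training_sketch(model, batcher, sess_context_manager, sv,
                         summary_writer):
  tf.logging.info("starting run_training")
  with sess_context_manager as sess:
    while True:  # repeats until interrupted (KeyboardInterrupt handled by caller)
      batch = batcher.next_batch()
      t0 = time.time()
      results = model.run_train_step(sess, batch)
      t1 = time.time()
      tf.logging.info('seconds for training step: %.3f', t1 - t0)
      tf.logging.info('loss: %f', results['loss'])
      # write summaries at the current global step so tensorboard stays in sync
      summary_writer.add_summary(results['summaries'], results['global_step'])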
def run_eval(model, batcher, vocab, hier=False):
  """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
  model.build_graph()  # build the graph
  # we will keep 3 best checkpoints at a time
  saver = tf.train.Saver(max_to_keep=3)
  sess = tf.Session(config=util.get_config())

  # make a subdir of the root dir for eval data
  eval_dir = os.path.join(FLAGS.log_root, "eval")
  # this is where checkpoints of best models are saved
  bestmodel_save_path = os.path.join(eval_dir, 'bestmodel')
  summary_writer = tf.summary.FileWriter(eval_dir)
  # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
  running_avg_loss = 0
  best_loss = None  # will hold the best loss achieved so far

  while True:
    _ = util.load_ckpt(saver, sess)  # load a new checkpoint
    batch = batcher.next_batch()  # get the next batch

    # run eval on the batch
    t0 = time.time()
    results = model.run_eval_step(sess, batch)
    t1 = time.time()
    tf.logging.info('seconds for batch: %.2f', t1 - t0)

    # print the loss and coverage loss to screen
    loss = results['loss']
    tf.logging.info('loss: %f', loss)
    if FLAGS.coverage:
      coverage_loss = results['coverage_loss']
      tf.logging.info("coverage_loss: %f", coverage_loss)

    # add summaries
    summaries = results['summaries']
    train_step = results['global_step']
    summary_writer.add_summary(summaries, train_step)

    # calculate running avg loss
    if hier:
      if np.isfinite(loss):
        running_avg_loss = calc_running_avg_loss(
            np.asscalar(loss), running_avg_loss, summary_writer, train_step)
      else:
        print('Warn: loss is nan, skipped one step in calculating average loss')
        running_avg_loss = None
    else:
      running_avg_loss = calc_running_avg_loss(
          np.asscalar(loss), running_avg_loss, summary_writer, train_step)

    # If running_avg_loss is best so far, save this checkpoint (early stopping).
    # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir.
    # Skip the check entirely when the loss was nan (running_avg_loss is None).
    if running_avg_loss is not None and (best_loss is None or
                                         running_avg_loss < best_loss):
      tf.logging.info(
          'Found new best model with %.3f running_avg_loss. Saving to %s',
          running_avg_loss, bestmodel_save_path)
      saver.save(
          sess,
          bestmodel_save_path,
          global_step=train_step,
          latest_filename='checkpoint_best')
      best_loss = running_avg_loss

    # flush the summary writer every so often
    if train_step % 100 == 0:
      summary_writer.flush()
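# A minimal sketch of the exponential moving average that calc_running_avg_loss
# is assumed to compute for the early-stopping criterion above. The real helper
# is defined elsewhere in the repo; the decay of 0.99 and the clip at 12 are
# assumptions modeled on the standard pointer-generator implementation.
def _calc_running_avg_loss_sketch(loss, running_avg_loss, summary_writer, step,
                                  decay=0.99):
  if running_avg_loss == 0:  # on the first iteration, just take the loss
    running_avg_loss = loss
  else:
    running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
  running_avg_loss = min(running_avg_loss, 12)  # clip so the summary stays readable
  # write the smoothed loss to tensorboard under its own tag
  loss_sum = tf.Summary()
  loss_sum.value.add(tag='running_avg_loss/decay=%f' % decay,
                     simple_value=running_avg_loss)
  summary_writer.add_summary(loss_sum, step)
  tf.logging.info('running_avg_loss: %f', running_avg_loss)
  return running_avg_loss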