Example #1
def run_eval(model, batcher):
  """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
  model.build_graph() # build the graph
  saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time
  sess = tf.Session(config=util.get_config())
  eval_dir = os.path.join(FLAGS.log_root, "eval_loss") # make a subdir of the root dir for eval data
  bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved
  summary_writer = tf.summary.FileWriter(eval_dir)
  running_avg_loss = 0 # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
  running_avg_ratio = 0 # running average of the sentence-selection ratio, smoothed the same way as the loss
  best_loss = None  # will hold the best loss achieved so far
  train_dir = os.path.join(FLAGS.log_root, "train")

  while True:
    ckpt_state = tf.train.get_checkpoint_state(train_dir)
    tf.logging.info('max_enc_steps: %d, max_dec_steps: %d', FLAGS.max_enc_steps, FLAGS.max_dec_steps)
    _ = util.load_ckpt(saver, sess) # load a new checkpoint
    batch = batcher.next_batch() # get the next batch

    # run eval on the batch
    t0 = time.time()
    results = model.run_eval_step(sess, batch)
    t1 = time.time()
    tf.logging.info('seconds for batch: %.2f', t1-t0)

    # print the loss and coverage loss to screen
    loss = results['loss']
    tf.logging.info('loss: %f', loss)
    train_step = results['global_step']

    tf.logging.info("pgen_avg: %f", results['p_gen_avg'])

    if FLAGS.coverage:
      tf.logging.info("coverage_loss: %f", results['coverage_loss'])

    if FLAGS.inconsistent_loss:
      tf.logging.info('inconsistent_loss: %f', results['inconsist_loss'])

    tf.logging.info("selector_loss: %f", results['selector_loss'])
    recall, ratio, _ = util.get_batch_ratio(batch.original_articles_sents, batch.original_extracts_ids, results['probs'])
    write_to_summary(ratio, 'SentSelector/select_ratio/recall=0.9', train_step, summary_writer)

    # add summaries
    summaries = results['summaries']
    summary_writer.add_summary(summaries, train_step)

    # calculate running avg loss
    running_avg_loss = util.calc_running_avg_loss(float(loss), running_avg_loss, summary_writer, train_step, 'running_avg_loss')
    running_avg_ratio = util.calc_running_avg_loss(ratio, running_avg_ratio, summary_writer, train_step, 'running_avg_ratio')

    # If running_avg_loss is best so far, save this checkpoint (early stopping).
    # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
    if best_loss is None or running_avg_loss < best_loss:
      tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s', running_avg_loss, bestmodel_save_path)
      saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
      best_loss = running_avg_loss

    # flush the summary writer every so often
    if train_step % 100 == 0:
      summary_writer.flush()
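
Both examples smooth the raw values with util.calc_running_avg_loss before comparing against best_loss, but that helper is not part of the listing. Below is a minimal sketch that matches the call signature used above; the decay factor of 0.99 and the clipping at 12 are assumptions, not taken from the listing itself.

def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, tag, decay=0.99):
  """Exponentially-decayed running average of a scalar, also written to TensorBoard."""
  if running_avg_loss == 0:  # on the first call, seed the average with the raw value
    running_avg_loss = loss
  else:
    running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
  running_avg_loss = min(running_avg_loss, 12)  # clip so early spikes don't distort the plot
  summary = tf.Summary()
  summary.value.add(tag=tag, simple_value=running_avg_loss)
  summary_writer.add_summary(summary, step)
  tf.logging.info('%s: %f', tag, running_avg_loss)
  return running_avg_loss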
Example #2
def run_training(model,
                 batcher,
                 sess_context_manager,
                 sv,
                 summary_writer,
                 pretrained_saver=None,
                 saver=None):
    """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
    tf.logging.info("starting run_training")
    ckpt_path = os.path.join(FLAGS.log_root, "train", "model.ckpt")

    with sess_context_manager as sess:
        if FLAGS.pretrained_selector_path:
            tf.logging.info('Loading pretrained selector model')
            _ = util.load_ckpt(pretrained_saver,
                               sess,
                               ckpt_path=FLAGS.pretrained_selector_path)

        for _ in range(FLAGS.max_train_iter):  # repeats until interrupted
            batch = batcher.next_batch()  # one batch holds 5 articles

            tf.logging.info('running training step...')
            t0 = time.time()
            results = model.run_train_step(sess, batch)
            print("run train step finish")
            t1 = time.time()
            tf.logging.info('seconds for training step: %.3f', t1 - t0)

            loss = results['loss']
            tf.logging.info('loss: %f', loss)  # print the loss to screen

            if not np.isfinite(loss):
                raise Exception("Loss is not finite. Stopping.")

            train_step = results['global_step']  # we need this to update our running average loss

            recall, ratio, _ = util.get_batch_ratio(batch.original_articles_sents,
                                                    batch.original_extracts_ids, results['probs'])
            write_to_summary(ratio, 'SentSelector/select_ratio/recall=0.9',
                             train_step, summary_writer)

            # get the summaries and iteration number so we can write summaries to tensorboard
            summaries = results['summaries']  # we will write these summaries to tensorboard using summary_writer
            summary_writer.add_summary(summaries, train_step)  # write the summaries
            if train_step % 100 == 0:  # flush the summary writer every so often
                summary_writer.flush()

            if train_step % FLAGS.save_model_every == 0:
                if FLAGS.pretrained_selector_path:
                    saver.save(sess, ckpt_path, global_step=train_step)
                else:
                    sv.saver.save(sess, ckpt_path, global_step=train_step)

            print('Step: ', train_step)
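
Both run_eval and run_training also call a module-level write_to_summary(value, tag_name, step, summary_writer) helper that the listing does not include. A minimal sketch consistent with that call signature, assuming it simply wraps the scalar in a tf.Summary proto:

def write_to_summary(value, tag_name, step, summary_writer):
    """Write a single scalar value to TensorBoard under the given tag."""
    summary = tf.Summary()
    summary.value.add(tag=tag_name, simple_value=value)
    summary_writer.add_summary(summary, step)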