Example #1
    def prepare_evaluate(self, ckpt_path=None):
        # Load an initial checkpoint to use for decoding
        if FLAGS.mode == 'evalall':
            if FLAGS.load_best_eval_model:
                tf.logging.info('Loading best eval checkpoint')
                ckpt_path = util.load_ckpt(self._saver,
                                           self._sess,
                                           ckpt_dir='eval_' +
                                           FLAGS.eval_method)
            elif FLAGS.eval_ckpt_path:
                ckpt_path = util.load_ckpt(self._saver,
                                           self._sess,
                                           ckpt_path=FLAGS.eval_ckpt_path)
            else:
                tf.logging.info('Loading best train checkpoint')
                ckpt_path = util.load_ckpt(self._saver, self._sess)
        elif FLAGS.mode == 'eval':
            _ = util.load_ckpt(self._saver, self._sess,
                               ckpt_path=ckpt_path)  # load a new checkpoint

        if FLAGS.single_pass:
            # Make a descriptive decode directory name
            ckpt_name = "ckpt-" + ckpt_path.split('-')[
                -1]  # this is something of the form "ckpt-123456"
            self._decode_dir = os.path.join(FLAGS.log_root,
                                            get_decode_dir_name(ckpt_name))
            tf.logging.info('Save evaluation results to ' + self._decode_dir)
            if os.path.exists(self._decode_dir):
                if FLAGS.mode == 'eval':
                    return False  # The checkpoint has already been evaluated. Evaluate next one.
                else:
                    raise Exception(
                        "single_pass decode directory %s should not already exist"
                        % self._decode_dir)
        else:  # Generic decode dir name
            self._decode_dir = os.path.join(FLAGS.log_root, "decode")

        # Make the decode dir if necessary
        if not os.path.exists(self._decode_dir):
            os.mkdir(self._decode_dir)

        if FLAGS.single_pass:
            # Make the dirs to contain output written in the correct format for pyrouge
            self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
            if not os.path.exists(self._rouge_ref_dir):
                os.mkdir(self._rouge_ref_dir)
            self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded")
            if not os.path.exists(self._rouge_dec_dir):
                os.mkdir(self._rouge_dec_dir)
            if FLAGS.save_vis:
                self._rouge_vis_dir = os.path.join(self._decode_dir,
                                                   "visualize")
                if not os.path.exists(self._rouge_vis_dir):
                    os.mkdir(self._rouge_vis_dir)
            if FLAGS.save_pkl:
                self._result_dir = os.path.join(self._decode_dir, "result")
                if not os.path.exists(self._result_dir):
                    os.mkdir(self._result_dir)
        return True
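
The checkpoint loading above goes through util.load_ckpt, which is not shown in these snippets. Below is a minimal sketch of a loader with this signature, assuming the pointer-generator convention of polling the checkpoint directory under FLAGS.log_root until a checkpoint is available (and the same module-level imports of os, time and tf as the surrounding code); the actual helper in the repository may differ.

def load_ckpt(saver, sess, ckpt_dir="train", ckpt_path=None):
    """Restore a checkpoint into sess, retrying until one is available (sketch only)."""
    while True:
        try:
            path = ckpt_path
            if path is None:
                # run_eval below saves its best checkpoints with latest_filename='checkpoint_best'
                latest_filename = "checkpoint_best" if ckpt_dir.startswith("eval") else None
                state = tf.train.get_checkpoint_state(
                    os.path.join(FLAGS.log_root, ckpt_dir),
                    latest_filename=latest_filename)
                path = state.model_checkpoint_path
            tf.logging.info('Loading checkpoint %s', path)
            saver.restore(sess, path)
            return path
        except Exception:
            tf.logging.info('Failed to load checkpoint from %s. Sleeping 10 secs...', ckpt_dir)
            time.sleep(10)
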
def run_eval(model, batcher):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    model.build_graph()  # build the graph
    saver = tf.train.Saver(max_to_keep=3)  # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())
    eval_dir = os.path.join(FLAGS.log_root, "eval_loss")  # make a subdir of the root dir for eval data
    bestmodel_save_path = os.path.join(eval_dir, 'bestmodel')  # this is where checkpoints of best models are saved
    summary_writer = tf.summary.FileWriter(eval_dir)
    running_avg_loss = 0  # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
    best_loss = None  # will hold the best loss achieved so far
    train_dir = os.path.join(FLAGS.log_root, "train")

    while True:
        ckpt_state = tf.train.get_checkpoint_state(train_dir)
        tf.logging.info('max_enc_steps: %d, max_dec_steps: %d', FLAGS.max_enc_steps, FLAGS.max_dec_steps)
        _ = util.load_ckpt(saver, sess)  # load a new checkpoint
        batch = batcher.next_batch()  # get the next batch

        # run eval on the batch
        t0 = time.time()
        results = model.run_eval_step(sess, batch)
        t1 = time.time()
        tf.logging.info('seconds for batch: %.2f', t1 - t0)

        # print the loss and coverage loss to screen
        loss = results['loss']
        tf.logging.info('loss: %f', loss)
        train_step = results['global_step']

        tf.logging.info("pgen_avg: %f", results['p_gen_avg'])

        if FLAGS.coverage:
            tf.logging.info("coverage_loss: %f", results['coverage_loss'])

        # add summaries
        summaries = results['summaries']
        summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss
        running_avg_loss = util.calc_running_avg_loss(np.asscalar(loss), running_avg_loss, summary_writer, train_step,
                                                      'running_avg_loss')

        # If running_avg_loss is best so far, save this checkpoint (early stopping).
        # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
        if best_loss is None or running_avg_loss < best_loss:
            tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s', running_avg_loss,
                            bestmodel_save_path)
            saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
            best_loss = running_avg_loss

        # flush the summary writer every so often
        if train_step % 100 == 0:
            summary_writer.flush()
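
The early-stopping criterion in run_eval hinges on util.calc_running_avg_loss, which is not shown here. In the original pointer-generator code this is an exponential moving average of the eval loss that is also written to TensorBoard; a sketch along those lines, with the decay constant and clipping value as assumptions:

def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step,
                          name='running_avg_loss', decay=0.99):
    """Exponential moving average of the eval loss, also logged to TensorBoard (sketch)."""
    if running_avg_loss == 0:  # first call: start from the current batch loss
        running_avg_loss = loss
    else:
        running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    running_avg_loss = min(running_avg_loss, 12)  # clip so early spikes don't dominate the plot
    summary = tf.Summary()
    summary.value.add(tag=name, simple_value=running_avg_loss)
    summary_writer.add_summary(summary, step)
    tf.logging.info('%s: %f', name, running_avg_loss)
    return running_avg_loss
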
def convert_to_coverage_model():
    """Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint"""
    tf.logging.info("converting non-coverage model to coverage model..")

    # initialize an entire coverage model from scratch
    sess = tf.Session(config=util.get_config())
    print "initializing everything..."
    sess.run(tf.global_variables_initializer())

    # load all non-coverage weights from checkpoint
    saver = tf.train.Saver([v for v in tf.global_variables() if "coverage" not in v.name and "Adagrad" not in v.name])
    print "restoring non-coverage variables..."
    curr_ckpt = util.load_ckpt(saver, sess)
    print "restored."

    # save this model and quit
    ckpt_path = os.path.join(FLAGS.log_root, "train", "model.ckpt_cov")
    step = curr_ckpt.split('-')[1]
    new_fname = ckpt_path + '-' + step + '-init'
    print "saving model to %s..." % (new_fname)
    new_saver = tf.train.Saver()  # this one will save all variables that now exist
    new_saver.save(sess, new_fname)
    print "saved."
    exit()
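
Every tf.Session in these examples is configured via util.get_config(). In this family of repos the helper typically just enables soft device placement and on-demand GPU memory growth; a minimal sketch:

def get_config():
    """Session config: let TF fall back across devices and grow GPU memory on demand (sketch)."""
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    return config
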
Example #4
    def __init__(self, model, batcher, vocab):
        """Initialize decoder.

        Args:
          model: a SentSelector object.
          batcher: a Batcher object.
          vocab: a Vocabulary object.
        """
        # get the data split set
        if "train" in FLAGS.data_path:
            self._dataset = "train"
        elif "val" in FLAGS.data_path:
            self._dataset = "val"
        elif "test" in FLAGS.data_path:
            self._dataset = "test"
        else:
            raise ValueError(
                "FLAGS.data_path %s should contain one of train, val or test" %
                (FLAGS.data_path))

        # create the data loader
        self._batcher = batcher

        if FLAGS.eval_gt_rouge:  # no need to load model
            # Make a descriptive decode directory name
            self._decode_dir = os.path.join(FLAGS.log_root,
                                            'select_gt' + self._dataset)
            tf.logging.info('Save evaluation results to ' + self._decode_dir)
            if os.path.exists(self._decode_dir):
                raise Exception(
                    "single_pass decode directory %s should not already exist"
                    % self._decode_dir)

            # Make the decode dir
            os.makedirs(self._decode_dir)

            # Make the dirs to contain output written in the correct format for pyrouge
            self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
            if not os.path.exists(self._rouge_ref_dir):
                os.mkdir(self._rouge_ref_dir)
            self._rouge_gt_dir = os.path.join(self._decode_dir, "gt_selected")
            if not os.path.exists(self._rouge_gt_dir):
                os.mkdir(self._rouge_gt_dir)
        else:
            self._model = model
            self._model.build_graph()
            self._vocab = vocab
            self._saver = tf.train.Saver()  # we use this to load checkpoints for decoding
            self._sess = tf.Session(config=util.get_config())

            # Load an initial checkpoint to use for decoding
            if FLAGS.load_best_eval_model:
                tf.logging.info('Loading best eval checkpoint')
                ckpt_path = util.load_ckpt(self._saver,
                                           self._sess,
                                           ckpt_dir='eval')
            elif FLAGS.eval_ckpt_path:
                ckpt_path = util.load_ckpt(self._saver,
                                           self._sess,
                                           ckpt_path=FLAGS.eval_ckpt_path)
            else:
                tf.logging.info('Loading best train checkpoint')
                ckpt_path = util.load_ckpt(self._saver, self._sess)

            if FLAGS.single_pass:
                # Make a descriptive decode directory name
                ckpt_name = "ckpt-" + ckpt_path.split('-')[
                    -1]  # this is something of the form "ckpt-123456"
                decode_root_dir, decode_dir = get_decode_dir_name(
                    ckpt_name, self._dataset)
                self._decode_root_dir = os.path.join(FLAGS.log_root,
                                                     decode_root_dir)
                self._decode_dir = os.path.join(FLAGS.log_root,
                                                decode_root_dir, decode_dir)
                tf.logging.info('Save evaluation results to ' +
                                self._decode_dir)
                if os.path.exists(self._decode_dir):
                    raise Exception(
                        "single_pass decode directory %s should not already exist"
                        % self._decode_dir)
            else:  # Generic decode dir name
                self._decode_dir = os.path.join(FLAGS.log_root, "select")

            # Make the decode dir if necessary
            if not os.path.exists(self._decode_dir):
                os.makedirs(self._decode_dir)

            if FLAGS.single_pass:
                # Make the dirs to contain output written in the correct format for pyrouge
                self._rouge_ref_dir = os.path.join(self._decode_dir,
                                                   "reference")
                if not os.path.exists(self._rouge_ref_dir):
                    os.mkdir(self._rouge_ref_dir)
                self._rouge_dec_dir = os.path.join(self._decode_dir,
                                                   "selected")
                if not os.path.exists(self._rouge_dec_dir):
                    os.mkdir(self._rouge_dec_dir)
                if FLAGS.save_pkl:
                    self._result_dir = os.path.join(self._decode_dir,
                                                    "select_result")
                    if not os.path.exists(self._result_dir):
                        os.mkdir(self._result_dir)

                self._probs_pkl_path = os.path.join(self._decode_root_dir,
                                                    "probs.pkl")
                if not os.path.exists(self._probs_pkl_path):
                    self._make_probs_pkl = True
                else:
                    self._make_probs_pkl = False
                self._precision = []
                self._recall = []
                self._accuracy = []
                self._ratio = []
                self._select_sent_num = []
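
get_decode_dir_name is called above with a (ckpt_name, dataset) pair and is expected to return a root directory plus a per-checkpoint subdirectory, so that probs.pkl in _decode_root_dir can be shared across checkpoints of the same run. The helper itself is not shown; the following is a purely hypothetical sketch that is consistent with the call site (the real naming scheme in the repository may encode more hyperparameters):

def get_decode_dir_name(ckpt_name, dataset):
    """Return (root_dir, sub_dir) names for selector decode output (hypothetical sketch)."""
    # The root groups every decode run for this data split; the checkpoint name
    # keeps individual checkpoints apart inside that root.
    root_dir = "select_" + dataset
    return root_dir, ckpt_name
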
def run_training(model,
                 batcher,
                 sess_context_manager,
                 sv,
                 summary_writer,
                 pretrained_saver=None,
                 saver=None):
    """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
    tf.logging.info("starting run_training")
    ckpt_path = os.path.join(FLAGS.log_root, "train", "model.ckpt")

    with sess_context_manager as sess:
        if FLAGS.pretrained_selector_path:
            tf.logging.info('Loading pretrained selector model')
            _ = util.load_ckpt(pretrained_saver,
                               sess,
                               ckpt_path=FLAGS.pretrained_selector_path)

        for _ in range(FLAGS.max_train_iter):  # repeats until interrupted
            batch = batcher.next_batch()
            if not batch:
                tf.logging.info(
                    'training has finished - no more batches left...')
                return

            tf.logging.info('running training step...')
            t0 = time.time()
            results = model.run_train_step(sess, batch)
            t1 = time.time()
            tf.logging.info('seconds for training step: %.3f', t1 - t0)

            loss = results['loss']
            tf.logging.info('loss: %f', loss)  # print the loss to screen

            if not np.isfinite(loss):
                raise Exception("Loss is not finite. Stopping.")

            train_step = results['global_step']  # we need this to update our running average loss

            recall, ratio, _ = util.get_batch_ratio(
                batch.original_articles_sents, batch.original_extracts_ids,
                results['probs'])
            write_to_summary(ratio, 'SentSelector/select_ratio/recall=0.9',
                             train_step, summary_writer)

            # get the summaries and iteration number so we can write summaries to tensorboard
            summaries = results['summaries']  # we will write these summaries to tensorboard using summary_writer
            summary_writer.add_summary(summaries,
                                       train_step)  # write the summaries
            if train_step % 100 == 0:  # flush the summary writer every so often
                summary_writer.flush()

            if train_step % FLAGS.save_model_every == 0:
                if FLAGS.pretrained_selector_path:
                    saver.save(sess, ckpt_path, global_step=train_step)
                else:
                    sv.saver.save(sess, ckpt_path, global_step=train_step)

            print('Step: %d' % train_step)
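
write_to_summary, used above to log the recall-constrained selection ratio, is not shown in these snippets. Writing a raw Python scalar to TensorBoard outside the graph is normally done by building a tf.Summary proto directly; a minimal sketch matching the call site:

def write_to_summary(value, tag_name, step, summary_writer):
    """Write a single scalar value to TensorBoard under the given tag (sketch)."""
    summary = tf.Summary()
    summary.value.add(tag=tag_name, simple_value=value)
    summary_writer.add_summary(summary, step)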