def convert_linear_attn_to_hier_model():
    """Re-save the loaded checkpoint with the section-attention variables added.

    Every global variable of the already-built graph is initialized, then all
    variables whose names contain none of "Linear--Section-Features", "v_sec"
    or "Adagrad" are restored through ``util.load_ckpt``. Variables matching
    one of those substrings keep their freshly initialized values. The full
    variable set is then saved under the same path that ``util.load_ckpt``
    returned (no suffix is appended, unlike ``convert_to_coverage_model``).

    Raises:
      SystemExit: always, once the checkpoint has been written; this is a
        one-shot conversion and never returns to the caller.
    """
    # Local import: the file-level import block is not part of this function.
    import sys

    tf.logging.info("converting linear model to hier model..")

    # Initialize the whole model first so that variables absent from the
    # checkpoint still hold defined values when everything is saved.
    sess = tf.Session(config=util.get_config())
    print("initializing everything...")
    sess.run(tf.global_variables_initializer())

    # Restore only the variables expected to exist in the old checkpoint.
    restorable_vars = [
        v for v in tf.global_variables()
        if "Linear--Section-Features" not in v.name and "v_sec" not in v.name
        and "Adagrad" not in v.name
    ]
    saver = tf.train.Saver(restorable_vars)
    print("restoring variables...")
    curr_ckpt = util.load_ckpt(saver, sess)
    print("restored.")

    # Save the complete model under the path that was just loaded.
    new_fname = curr_ckpt
    print("saving model to %s..." % new_fname)
    # A Saver built without a var_list covers all variables that now exist.
    new_saver = tf.train.Saver()
    new_saver.save(sess, new_fname)
    print("saved.")
    # sys.exit() rather than exit(): the latter is injected by the `site`
    # module for interactive use and is not guaranteed to exist in programs.
    sys.exit()
def convert_to_coverage_model():
    """Load a non-coverage checkpoint, add initialized coverage variables, and save it as a new checkpoint.

    Every global variable of the already-built graph is initialized, then all
    variables whose names contain neither "coverage" nor "Adagrad" are
    restored through ``util.load_ckpt``. The remaining variables keep their
    freshly initialized values. The full variable set is saved to the loaded
    checkpoint path with ``'_cov_init'`` appended.

    When ``FLAGS.debug`` is set the session is wrapped in the TensorFlow CLI
    debugger with a "has_inf_or_nan" tensor filter.

    Raises:
      SystemExit: always, once the checkpoint has been written; this is a
        one-shot conversion and never returns to the caller.
    """
    # Local import: the file-level import block is not part of this function.
    import sys

    tf.logging.info("converting non-coverage model to coverage model..")

    # Initialize an entire coverage model from scratch.
    sess = tf.Session(config=util.get_config())
    if FLAGS.debug:
        print('entering debug mode')
        sess = tf_debug.LocalCLIDebugWrapperSession(sess,
                                                    ui_type=FLAGS.ui_type)
        sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
    print("initializing everything...")
    sess.run(tf.global_variables_initializer())

    # Load all non-coverage, non-optimizer weights from the checkpoint.
    non_coverage_vars = [
        v for v in tf.global_variables()
        if "coverage" not in v.name and "Adagrad" not in v.name
    ]
    saver = tf.train.Saver(non_coverage_vars)
    print("restoring non-coverage variables...")
    curr_ckpt = util.load_ckpt(saver, sess)
    print("restored.")

    # Save the converted model next to the original one and quit.
    new_fname = curr_ckpt + '_cov_init'
    print("saving model to %s..." % new_fname)
    # A Saver built without a var_list covers all variables that now exist.
    new_saver = tf.train.Saver()
    new_saver.save(sess, new_fname)
    print("saved.")
    # sys.exit() rather than exit(): the latter is injected by the `site`
    # module for interactive use and is not guaranteed to exist in programs.
    sys.exit()
def restore_best_model():
    """Load the best model from the eval directory, add Adagrad variables, and save it to the train directory.

    Every global variable of the already-built graph is initialized, then all
    variables whose names do not contain "Adagrad" are restored through
    ``util.load_ckpt(saver, sess, "eval")``. The full variable set (including
    the freshly initialized Adagrad variables) is saved to
    ``<FLAGS.log_root>/train/<name>``, where ``<name>`` is the restored
    checkpoint's file name with "bestmodel" replaced by "model".

    Raises:
      SystemExit: always, once the checkpoint has been written; this is a
        one-shot conversion and never returns to the caller.
    """
    # Local import: the file-level import block is not part of this function.
    import sys

    tf.logging.info("Restoring bestmodel for training...")

    # Initialize all vars in the model. Uses the same non-deprecated API as
    # the other converters in this file (tf.initialize_all_variables is the
    # deprecated spelling of the same op).
    sess = tf.Session(config=util.get_config())
    print("Initializing all variables...")
    sess.run(tf.global_variables_initializer())

    # Restore the best model from the eval dir, skipping optimizer slots.
    non_adagrad_vars = [
        v for v in tf.global_variables() if "Adagrad" not in v.name
    ]
    saver = tf.train.Saver(non_adagrad_vars)
    print("Restoring all non-adagrad variables from best model in eval dir...")
    curr_ckpt = util.load_ckpt(saver, sess, "eval")
    print("Restored %s." % curr_ckpt)

    # Save this model to the train dir and quit. basename() strips the
    # directory part whatever separator the platform uses.
    new_model_name = os.path.basename(curr_ckpt).replace("bestmodel", "model")
    new_fname = os.path.join(FLAGS.log_root, "train", new_model_name)
    print("Saving model to %s..." % (new_fname))
    # This saver saves all variables that now exist, including Adagrad variables.
    new_saver = tf.train.Saver()
    new_saver.save(sess, new_fname)
    print("Saved.")
    # sys.exit() rather than exit(): the latter is injected by the `site`
    # module for interactive use and is not guaranteed to exist in programs.
    sys.exit()
Beispiel #4
0
    def __init__(self, model, batcher, vocab):
        """Build the graph, restore a checkpoint and prepare the output dirs.

        Args:
          model: a Seq2SeqAttentionModel object; its graph is built here.
          batcher: a Batcher object.
          vocab: Vocabulary object

        The decode directory is ``<FLAGS.log_root>/decode`` unless
        ``FLAGS.single_pass`` is set, in which case its name is produced by
        ``get_decode_dir_name`` from the restored checkpoint and it
        additionally receives "reference" and "decoded" subdirectories.
        """
        self._model = model
        model.build_graph()
        self._batcher = batcher
        self._vocab = vocab
        # The saver is only used to load checkpoints for decoding.
        self._saver = tf.train.Saver()
        self._sess = tf.Session(config=util.get_config())

        # Restore the initial checkpoint; an unset/empty flag becomes None.
        ckpt_path = util.load_ckpt(
            self._saver,
            self._sess,
            latest_filename=FLAGS.decode_checkpoint or None)

        single_pass = FLAGS.single_pass
        if single_pass:
            # Descriptive directory name built from "ckpt-<text after the
            # last '-' of the checkpoint path>", e.g. "ckpt-123456".
            # NOTE: an earlier guard that rejected an already-existing
            # single_pass directory (or appended FLAGS.custom_decode_name)
            # is disabled, so an existing directory is reused as is.
            step_suffix = ckpt_path.split('-')[-1]
            decode_subdir = get_decode_dir_name("ckpt-" + step_suffix)
            self._decode_dir = os.path.join(FLAGS.log_root, decode_subdir)
        else:
            # Generic decode dir name
            self._decode_dir = os.path.join(FLAGS.log_root, "decode")

        # Create the decode dir unless it is already there.
        if not os.path.exists(self._decode_dir):
            os.mkdir(self._decode_dir)

        if not single_pass:
            return

        # Directories holding output written in the format pyrouge expects.
        self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
        if not os.path.exists(self._rouge_ref_dir):
            os.mkdir(self._rouge_ref_dir)
        self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded")
        if not os.path.exists(self._rouge_dec_dir):
            os.mkdir(self._rouge_dec_dir)
def setup_training(model, batcher):
    """Prepare the train directory, graph and supervised session, then train.

    Args:
      model: the model to train; must provide ``build_graph()`` and a
        ``global_step`` attribute.
      batcher: handed through unchanged to ``run_training``.

    If one of the conversion flags (``convert_to_coverage_model``,
    ``convert_linear_to_hier_attn``, ``restore_best_model``) is set, the
    matching helper is invoked after the graph is built; each of those
    helpers terminates the process once it has saved its checkpoint, so
    training does not start in that run.
    """
    train_path = os.path.join(FLAGS.log_root, "train")
    if not os.path.exists(train_path):
        os.makedirs(train_path)

    # The graph has to exist before any Saver or conversion helper runs.
    model.build_graph()

    # One-shot checkpoint conversions.
    if FLAGS.convert_to_coverage_model:
        assert FLAGS.coverage, "To convert your non-coverage model to a coverage model, run with convert_to_coverage_model=True and coverage=True"
        convert_to_coverage_model()
    if FLAGS.convert_linear_to_hier_attn:
        convert_linear_attn_to_hier_model()
    if FLAGS.restore_best_model:
        restore_best_model()

    # Keep 3 checkpoints at a time.
    ckpt_saver = tf.train.Saver(max_to_keep=3)

    supervisor = tf.train.Supervisor(
        logdir=train_path,
        is_chief=True,
        saver=ckpt_saver,
        summary_op=None,
        # save summaries for tensorboard every 60 secs
        save_summaries_secs=60,
        # checkpoint every 60 secs
        save_model_secs=60,
        global_step=model.global_step)
    writer = supervisor.summary_writer

    tf.logging.info("Preparing or waiting for session...")
    session = supervisor.prepare_or_wait_for_session(config=util.get_config())
    if FLAGS.debug:
        # Wrap the session in the CLI debugger and register the inf/nan filter.
        print('entering debug mode\n\n\n\n\n\n\n\n\n')
        session = tf_debug.LocalCLIDebugWrapperSession(session)
        session.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)

    tf.logging.info("Created session.")
    try:
        # Loops until interrupted.
        run_training(model, batcher, session, supervisor, writer)
    except KeyboardInterrupt:
        tf.logging.info(
            "Caught keyboard interrupt on worker. Stopping supervisor...")
        supervisor.stop()
def run_eval(model, batcher, vocab, hier=False):
    """Repeatedly run eval iterations, logging to screen and writing summaries.

    Saves the model with the best running-average loss seen so far to
    ``<FLAGS.log_root>/eval/bestmodel-<global_step>`` (at most 3 are kept).
    The loop never terminates on its own.

    Args:
      model: the model to evaluate; must provide ``build_graph()`` and
        ``run_eval_step(sess, batch)`` returning a dict with the keys
        'loss', 'summaries', 'global_step' (and 'coverage_loss' when
        ``FLAGS.coverage`` is set).
      batcher: provides ``next_batch()``.
      vocab: not used by this function.
      hier: when True, a step whose loss is not finite is skipped: it is
        neither folded into the running average nor considered for the
        best-model checkpoint.
    """
    model.build_graph()  # build the graph
    # we will keep 3 best checkpoints at a time
    saver = tf.train.Saver(max_to_keep=3)
    sess = tf.Session(config=util.get_config())
    # make a subdir of the root dir for eval data
    eval_dir = os.path.join(FLAGS.log_root, "eval")
    # this is where checkpoints of best models are saved
    bestmodel_save_path = os.path.join(eval_dir, 'bestmodel')
    summary_writer = tf.summary.FileWriter(eval_dir)
    # the eval job keeps a smoother, running average loss to tell it when to
    # implement early stopping
    running_avg_loss = 0
    best_loss = None  # will hold the best loss achieved so far

    while True:
        _ = util.load_ckpt(saver, sess)  # load a new checkpoint
        batch = batcher.next_batch()  # get the next batch

        # run eval on the batch
        t0 = time.time()
        results = model.run_eval_step(sess, batch)
        t1 = time.time()
        tf.logging.info('seconds for batch: %.2f', t1 - t0)

        # print the loss and coverage loss to screen
        loss = results['loss']
        tf.logging.info('loss: %f', loss)
        if FLAGS.coverage:
            coverage_loss = results['coverage_loss']
            tf.logging.info("coverage_loss: %f", coverage_loss)

        # add summaries
        summaries = results['summaries']
        train_step = results['global_step']
        summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss; only hier mode filters non-finite losses
        loss_is_usable = (not hier) or bool(np.isfinite(loss))
        if loss_is_usable:
            # .item() is the replacement for the deprecated np.asscalar
            running_avg_loss = calc_running_avg_loss(
                np.asarray(loss).item(), running_avg_loss, summary_writer,
                train_step)
        else:
            # Leave running_avg_loss untouched so the next finite loss
            # continues the same average; resetting it to None would hand
            # None to calc_running_avg_loss on the following step.
            print(
                'Warn: Loss nan, skipped one step in calculating average loss'
            )

        # If running_avg_loss is best so far, save this checkpoint (early stopping).
        # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir.
        # A skipped step never qualifies, even while best_loss is still None,
        # so a checkpoint is not saved on the strength of a NaN loss.
        if loss_is_usable and (best_loss is None
                               or running_avg_loss < best_loss):
            tf.logging.info(
                'Found new best model with %.3f running_avg_loss. Saving to %s',
                running_avg_loss, bestmodel_save_path)
            saver.save(sess,
                       bestmodel_save_path,
                       global_step=train_step,
                       latest_filename='checkpoint_best')
            best_loss = running_avg_loss

        # flush the summary writer every so often
        if train_step % 100 == 0:
            summary_writer.flush()