Ejemplo n.º 1
0
    def decode(self, ckpt_file=None):
        """Decode every batch from the batcher and write one result per line.

        Loads a checkpoint through ``misc_utils.load_ckpt``, then repeatedly
        pulls a batch from ``self._batcher``, runs beam search on it, strips
        everything from the first ``data.STOP_DECODING`` token onwards, and
        writes the output of ``self.depreciated_processing`` followed by a
        newline to the file at ``self._decode_dir``. The loop stops when the
        batcher returns ``None``.

        Args:
            ckpt_file: forwarded to ``misc_utils.load_ckpt`` as its fourth
                argument; defaults to None.
        """
        FLAGS = self._FLAGS

        # load latest checkpoint
        misc_utils.load_ckpt(self._saver, self._sess, self._ckpt_dir,
                             ckpt_file)

        counter = 0
        # NOTE: despite its name, self._decode_dir is opened as a file here.
        # The context manager guarantees the file is flushed and closed even
        # if decoding raises part-way through the dataset.
        with open(self._decode_dir, "w") as f:
            while True:
                # 1 example repeated across batch
                batch = self._batcher.next_batch()
                if batch is None:
                    # finished decoding dataset in single_pass mode
                    tf.logging.info(
                        "Decoder has finished reading dataset for single_pass."
                    )
                    break

                # The values below are not consumed by this method; they are
                # still computed exactly as before because the helpers that
                # produce them are defined outside this block.
                original_article = batch.original_articles[0]  # string
                original_abstract = batch.original_abstracts[0]  # string
                original_abstract_sents = batch.original_abstracts_sents[
                    0]  # list of strings

                article_withunks = data.show_art_oovs(original_article,
                                                      self._vocab)  # string
                abstract_withunks = data.show_abs_oovs(
                    original_abstract, self._vocab,
                    (batch.art_oovs[0]
                     if FLAGS.pointer_gen else None))  # string

                # Run beam search to get best Hypothesis
                best_hyp = beam_search.run_beam_search(
                    self._sess, self._model, self._vocab, batch, FLAGS)

                # Extract the output ids from the hypothesis (skipping the
                # first token) and convert back to words
                output_ids = [int(t) for t in best_hyp.tokens[1:]]
                decoded_words = data.outputids2words(
                    output_ids, self._vocab,
                    (batch.art_oovs[0] if FLAGS.pointer_gen else None))

                # Remove the [STOP] token (and anything after it) from
                # decoded_words, if present
                try:
                    # index of the (first) [STOP] symbol
                    fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                except ValueError:
                    pass  # no [STOP] token: keep the words unchanged
                else:
                    decoded_words = decoded_words[:fst_stop_idx]

                # Write the processed decoded summary, one example per line
                processed = self.depreciated_processing(decoded_words)
                f.write(processed + "\n")

                counter += 1
                if counter % 100 == 0:
                    print("%d sentences decoded" % counter)
def convert_to_coverage_model():
    """Turn a non-coverage checkpoint into a coverage-initialized one.

    Initializes every global variable, restores from the checkpoint only
    those variables whose names contain neither "coverage" nor "Adagrad",
    saves all variables to ``<checkpoint path>_cov_init``, and then calls
    ``exit()``.
    """
    tf.logging.info("converting non-coverage model to coverage model..")

    # Start from a fully initialized set of variables.
    session = tf.Session(config=utils.get_config())
    print("initializing everything...")
    session.run(tf.global_variables_initializer())

    # Collect the variables that are expected to be in the old checkpoint.
    restorable_vars = []
    for var in tf.global_variables():
        if "coverage" in var.name or "Adagrad" in var.name:
            continue
        restorable_vars.append(var)
    restore_saver = tf.train.Saver(restorable_vars)
    print("restoring non-coverage variables...")
    restored_ckpt = utils.load_ckpt(restore_saver, session)
    print("restored.")

    # Write everything that now exists under the new name, then quit.
    target_path = restored_ckpt + '_cov_init'
    print("saving model to %s..." % (target_path))
    full_saver = tf.train.Saver()  # no var list: saves all current variables
    full_saver.save(session, target_path)
    print("saved.")
    exit()
def restore_best_model():
    """Copy the best eval checkpoint into the train directory.

    Initializes every global variable, restores all variables whose names do
    not contain "Adagrad" from the checkpoint returned by
    ``utils.load_ckpt(saver, sess, "eval")``, then saves all variables
    (including the freshly initialized Adagrad ones) under
    ``<PARAMS.log_root>/train/`` with "bestmodel" replaced by "model" in the
    file name. Finally calls ``exit()``.
    """
    tf.logging.info("Restoring bestmodel for training...")

    # Initialize all vars in the model
    sess = tf.Session(config=utils.get_config())
    print("Initializing all variables...")
    sess.run(tf.global_variables_initializer())

    # Restore the best model from eval dir.
    # tf.global_variables() replaces the deprecated tf.all_variables() and
    # matches the initializer used above.
    saver = tf.train.Saver(
        [v for v in tf.global_variables() if "Adagrad" not in v.name])
    print("Restoring all non-adagrad variables from best model in eval dir...")
    curr_ckpt = utils.load_ckpt(saver, sess, "eval")
    print("Restored %s." % curr_ckpt)

    # Save this model to train dir and quit
    new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model")
    new_fname = os.path.join(PARAMS.log_root, "train", new_model_name)
    print("Saving model to %s..." % (new_fname))
    # this saver saves all variables that now exist, including Adagrad
    # variables
    new_saver = tf.train.Saver()
    new_saver.save(sess, new_fname)
    print("Saved.")
    exit()
def run_eval(model, batcher, vocab):
    """Run eval iterations forever, logging and checkpointing the best model.

    Each iteration reloads a checkpoint via ``utils.load_ckpt``, evaluates one
    batch, logs the loss (and the coverage loss when ``PARAMS.coverage`` is
    set), writes summaries, and updates a running average of the loss. When
    that running average is the lowest seen so far, the session is saved as
    ``bestmodel-<global_step>`` in the eval directory.

    The loop has no exit condition, so this function does not return.

    Args:
        model: object providing ``build_graph()`` and
            ``run_eval_step(sess, batch)``.
        batcher: object providing ``next_batch()``.
        vocab: not used by this function.
    """
    model.build_graph()  # build the graph
    # we will keep 3 best checkpoints at a time
    saver = tf.train.Saver(max_to_keep=3)
    sess = tf.Session(config=utils.get_config())
    # make a subdir of the root dir for eval data
    eval_dir = os.path.join(PARAMS.log_root, "eval")
    # this is where checkpoints of best models are saved
    bestmodel_save_path = os.path.join(eval_dir, 'bestmodel')
    summary_writer = tf.summary.FileWriter(eval_dir)
    # the eval job keeps a smoother, running average loss to tell it when to
    # implement early stopping
    running_avg_loss = 0
    best_loss = None  # will hold the best loss achieved so far

    while True:
        _ = utils.load_ckpt(saver, sess)  # load a new checkpoint
        batch = batcher.next_batch()  # get the next batch

        # run eval on the batch
        t0 = time.time()
        results = model.run_eval_step(sess, batch)
        t1 = time.time()
        tf.logging.info('seconds for batch: %.2f', t1 - t0)

        # print the loss and coverage loss to screen
        loss = results['loss']
        tf.logging.info('loss: %f', loss)
        if PARAMS.coverage:
            coverage_loss = results['coverage_loss']
            tf.logging.info("coverage_loss: %f", coverage_loss)

        # add summaries
        summaries = results['summaries']
        train_step = results['global_step']
        summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss.
        # np.asscalar was deprecated in NumPy 1.16 and removed in 1.23;
        # .item() is its documented equivalent and works on every version.
        running_avg_loss = calc_running_avg_loss(np.asarray(loss).item(),
                                                 running_avg_loss,
                                                 summary_writer, train_step)

        # If running_avg_loss is best so far, save this checkpoint (early
        # stopping). These checkpoints will appear as
        # bestmodel-<iteration_number> in the eval dir
        if best_loss is None or running_avg_loss < best_loss:
            tf.logging.info(
                'Found new best model with %.3f running_avg_loss. Saving to %s',
                running_avg_loss, bestmodel_save_path)
            saver.save(sess,
                       bestmodel_save_path,
                       global_step=train_step,
                       latest_filename='checkpoint_best')
            best_loss = running_avg_loss

        # flush the summary writer every so often
        if train_step % 100 == 0:
            summary_writer.flush()
    def initialize_or_restore_session(self, ckpt_file=None):
        """Create the session, initialize variables, and optionally restore.

        A new ``tf.Session`` on ``self._graph`` is stored in ``self._sess``
        and the global-variables initializer is run in it. If either
        ``self._logdir`` or ``ckpt_file`` is truthy, ``misc_utils.load_ckpt``
        is then called with ``self._saver`` to restore from a checkpoint.

        Args:
            ckpt_file: forwarded to ``misc_utils.load_ckpt`` as ``ckpt_file``;
                defaults to None.
        """
        with self._graph.as_default():
            self._sess = tf.Session(
                graph=self._graph, config=misc_utils.get_config())
            init_op = tf.global_variables_initializer()
            self._sess.run(init_op)

            # Nothing to restore from: keep the freshly initialized values.
            if not (self._logdir or ckpt_file):
                return

            # Restore from the log directory and/or the given checkpoint file.
            misc_utils.load_ckpt(saver=self._saver,
                                 sess=self._sess,
                                 ckpt_dir=self._logdir,
                                 ckpt_file=ckpt_file)
    def __init__(self, model, batcher, vocab):
        """Set up the decoder: graph, session, checkpoint and output dirs.

        Args:
          model: object whose ``build_graph()`` is called here and which is
            stored for later decoding.
          batcher: stored as ``self._batcher``.
          vocab: stored as ``self._vocab``.

        Raises:
          Exception: if ``hps.single_pass`` is set and the decode directory
            already exists, or if ``hps.rouge_eval_only`` is set and
            ``hps.eval_path`` is empty or does not exist.
        """
        self._model = model
        self._model.build_graph()
        self._batcher = batcher
        self._vocab = vocab
        # The saver is what loads checkpoints for decoding.
        self._saver = tf.train.Saver()
        self._sess = tf.Session(config=util.get_config())

        # Restore an initial checkpoint; its path names the output directory
        # in single_pass mode.
        restored_path = util.load_ckpt(self._saver, self._sess)

        if not hps.single_pass:
            # Generic decode dir name
            self._decode_dir = os.path.join(hps.log_root, "decode")
        else:
            # Descriptive decode dir name built from the text after the last
            # '-' in the checkpoint path, e.g. "ckpt-123456".
            step_suffix = restored_path.split('-')[-1]
            self._decode_dir = os.path.join(
                hps.log_root, get_decode_dir_name("ckpt-" + step_suffix))
            if os.path.exists(self._decode_dir):
                raise Exception(
                    "single_pass decode directory %s should not already exist"
                    % self._decode_dir)

        # Create the decode dir unless it is already there.
        if not os.path.exists(self._decode_dir):
            os.mkdir(self._decode_dir)

        if hps.single_pass:
            # Sub-directories holding output in the format pyrouge expects.
            # NOTE(review): _rouge_ref_dir is only assigned in this branch.
            self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
            if not os.path.exists(self._rouge_ref_dir):
                os.mkdir(self._rouge_ref_dir)
            self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded")
            if not os.path.exists(self._rouge_dec_dir):
                os.mkdir(self._rouge_dec_dir)

        if hps.rouge_eval_only:
            if not hps.eval_path:
                raise Exception(
                    "Must specify path to folder containing decoded files for evaluation"
                )
            # Evaluate previously decoded files from the user-supplied folder.
            self._rouge_dec_dir = hps.eval_path
            if not os.path.exists(self._rouge_dec_dir):
                raise Exception(
                    "Folder containing decoded files for evaluation does not exist!"
                )
    def decode(self):
        """Decode batches and either write ROUGE files or log the results.

        Behaviour depends on the ``hps`` flags:

        * ``hps.rouge_eval_only``: skip decoding, run ``rouge_eval`` on the
          existing reference/decoded directories, log the results and return.
        * ``hps.single_pass``: decode until the batcher returns ``None``,
          writing each example with ``self.write_for_rouge``; then run the
          ROUGE evaluation and return.
        * otherwise: decode indefinitely, printing each result and writing it
          for the attention visualizer, and reload a checkpoint whenever more
          than ``SECS_UNTIL_NEW_CKPT`` seconds have passed since the last
          load.

        Raises:
            AssertionError: if the batcher returns ``None`` while
                ``hps.single_pass`` is not set.
        """
        if hps.rouge_eval_only:
            # The two directory arguments need matching %s placeholders;
            # without them the logging call fails to format the message.
            tf.logging.info(
                "ROUGE only mode, Starting ROUGE eval on %s and %s...",
                self._rouge_ref_dir, self._rouge_dec_dir)
            results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
            rouge_log(results_dict, self._decode_dir)
            return

        t0 = time.time()  # time of the most recent checkpoint load
        counter = 0
        while True:
            # 1 example repeated across batch
            batch = self._batcher.next_batch()
            if batch is None:  # finished decoding dataset in single_pass mode
                assert hps.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir, self._rouge_dec_dir)
                results_dict = rouge_eval(self._rouge_ref_dir,
                                          self._rouge_dec_dir)
                rouge_log(results_dict, self._decode_dir)
                return

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            original_abstract_sents = batch.original_abstracts_sents[
                0]  # list of strings

            article_withunks = data.show_art_oovs(original_article,
                                                  self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab,
                (batch.art_oovs[0] if hps.copy_source else None))  # string

            # Run beam search to get best Hypothesis
            best_hyp = run_beam_search(self._sess, self._model, self._vocab,
                                       batch)

            # Extract the output ids from the hypothesis (skipping the first
            # token) and convert back to words
            output_ids = [int(t) for t in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self._vocab,
                (batch.art_oovs[0] if hps.copy_source else None))

            # Remove the [STOP] token (and anything after it) from
            # decoded_words, if present
            try:
                # index of the (first) [STOP] symbol
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
            except ValueError:
                pass  # no [STOP] token: keep the words unchanged
            else:
                decoded_words = decoded_words[:fst_stop_idx]
            decoded_output = ' '.join(decoded_words)  # single string

            if hps.single_pass:
                # write ref summary and decoded summary to file, to eval with
                # pyrouge later
                self.write_for_rouge(original_abstract_sents, decoded_words,
                                     counter)
                counter += 1  # this is how many examples we've decoded
            else:
                # log output to screen
                print_results(article_withunks, abstract_withunks,
                              decoded_output)
                # write info to .json file for visualization tool
                self.write_for_attnvis(article_withunks, abstract_withunks,
                                       decoded_words, best_hyp.attn_dists,
                                       best_hyp.p_gens)

                # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so, load a new
                # checkpoint in place and restart the timer
                t1 = time.time()
                if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                    tf.logging.info(
                        'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                        t1 - t0)
                    _ = util.load_ckpt(self._saver, self._sess)
                    t0 = time.time()