コード例 #1
0
    def decode(self):
        """Decode every batch from the batcher and write the results to disk.

        Batches are pulled from ``self._batcher`` until ``next_batch()``
        returns ``None``. For each batch, beam search is run and the best
        "arg" hypothesis is converted to words, cut at the first
        ``utils.ARG_STOP_DECODING`` token, handed to ``self.write_to_file``
        and appended to the summary file at ``self._summary_path``.

        When ``self._model.hps.model`` is ``"sep_dec"`` or ``"shd_dec"`` the
        "kp" hypothesis is processed the same way (cut at
        ``utils.KP_STOP_DECODING``) and written with ``self.write_to_file``;
        it is not added to the summary file.

        The summary file is always closed, including when an exception is
        raised mid-loop, and the total decoding time is logged once the
        batcher is exhausted.

        Returns:
            None.
        """
        t0 = time.time()
        counter = 0
        # The context manager closes the summary file on every exit path.
        with open(self._summary_path, "w") as summary_file:
            while True:
                batch = self._batcher.next_batch()
                if batch is None:
                    tf.logging.info(
                        "Decoder has finished reading dataset for "
                        "single_pass.")
                    tf.logging.info("Output has been saved in %s and %s",
                                    self._ref_dir, self._dec_dir)
                    # Leave the loop (rather than returning) so the elapsed
                    # time below is still reported.
                    break

                # NOTE(review): only element 0 of each batch field is used;
                # presumably a batch holds a single example -- confirm.
                arg_withunks = utils.show_abs_oovs(batch.original_arg[0],
                                                   self._tgt_vocab, None)

                best_hyp_arg, best_hyp_kp = beam_search.run_beam_search(
                    self._sess, self._model, self._tgt_vocab, batch)

                # Skip the first token of the hypothesis before converting
                # ids back to words.
                output_ids = [int(t) for t in best_hyp_arg.tokens[1:]]
                decoded_words = utils.outputids2words(
                    output_ids, self._tgt_vocab, None)
                try:
                    fst_stop_idx = decoded_words.index(
                        utils.ARG_STOP_DECODING)
                    decoded_words = decoded_words[:fst_stop_idx]
                except ValueError:
                    # No stop token was generated; keep the full sequence.
                    pass

                self.write_to_file(batch.original_arg_sents[0],
                                   decoded_words, counter, "arg")

                summary_file.write("ID: %d\n" % counter)
                summary_file.write("OP: %s\n" % batch.original_src)
                summary_file.write("ARG: %s\n" % arg_withunks)
                summary_file.write(
                    "Generation: %s\n" % " ".join(decoded_words))
                summary_file.write("=" * 50 + "\n")

                if self._model.hps.model in ["sep_dec", "shd_dec"]:
                    output_ids = [int(t) for t in best_hyp_kp.tokens[1:]]
                    decoded_words = utils.outputids2words(
                        output_ids, self._tgt_vocab, None)
                    try:
                        fst_stop_idx = decoded_words.index(
                            utils.KP_STOP_DECODING)
                        decoded_words = decoded_words[:fst_stop_idx]
                    except ValueError:
                        # No stop token was generated; keep the full
                        # sequence.
                        pass

                    self.write_to_file(batch.original_kp_sents[0],
                                       decoded_words, counter, "kp")
                counter += 1

        tf.logging.info("Decoding took %.3f seconds", time.time() - t0)
コード例 #2
0
    def predict(self, text, tokenize=True, beam_search=True):
        """Summarize ``text`` and return the summary as a string.

        Args:
            text (str or list): Source text. A ``str`` is segmented with
                ``jieba.cut`` when ``tokenize`` is true; any other input is
                passed to ``source2ids`` unchanged.
            tokenize (bool, optional):
                Whether to segment a string input. Defaults to True.
            beam_search (bool, optional):
                Decode with ``self.beam_search`` when true, otherwise with
                ``self.greedy_search``. Defaults to True.

        Returns:
            str: The decoded summary with ``<SOS>`` and ``<EOS>`` removed and
            surrounding whitespace stripped.
        """
        needs_segmentation = isinstance(text, str) and tokenize
        if needs_segmentation:
            text = list(jieba.cut(text))

        source_ids, oov_words = source2ids(text, self.vocab)
        source_tensor = torch.tensor(source_ids).to(self.DEVICE)
        oov_count = torch.tensor([len(oov_words)]).to(self.DEVICE)

        # 1.0 where the id is non-zero, 0.0 elsewhere -- presumably id 0 is
        # the padding token; confirm against the vocabulary.
        # NOTE(review): the mask is built from the un-batched tensor while
        # the search receives the batched one; verify the expected shape.
        is_nonzero = torch.ne(source_tensor, 0)
        padding_mask = is_nonzero.byte().float()

        # Add a leading dimension before handing the source to the search.
        batched_source = source_tensor.unsqueeze(0)

        # Keyword arguments common to both search strategies.
        shared_kwargs = {
            "max_sum_len": config.max_dec_steps,
            "len_oovs": oov_count,
            "x_padding_masks": padding_mask,
        }
        if beam_search:
            token_ids = self.beam_search(batched_source,
                                         beam_width=config.beam_size,
                                         **shared_kwargs)
        else:
            token_ids = self.greedy_search(batched_source, **shared_kwargs)

        decoded = outputids2words(token_ids, oov_words, self.vocab)
        for marker in ('<SOS>', '<EOS>'):
            decoded = decoded.replace(marker, '')
        return decoded.strip()
コード例 #3
0
    def run(self):
        """Decode every batch, write ROUGE files, then run ROUGE evaluation.

        Batches are read from ``self.batcher`` until ``next_batch()`` returns
        ``None``. Each batch is decoded with ``self.beam_search``; the result
        is converted to words, cut at the first ``dataset.EOS_TOKEN`` and
        written via ``write_for_rouge`` together with the reference abstract.
        Progress is printed every 1000 examples. Afterwards ``rouge_eval`` is
        run on the reference/decoded directories and its results are passed
        to ``rouge_log``.
        """
        processed = 0
        tick = time.time()

        while True:
            batch = self.batcher.next_batch()
            if batch is None:
                break

            # Best hypothesis for this batch according to beam search.
            hypothesis = self.beam_search(batch)

            # Skip the first token, then map the remaining ids back to words.
            token_ids = [int(tok) for tok in hypothesis.tokens[1:]]
            oov_words = batch.art_oovs[0] if config.pointer_gen else None
            words = utils.outputids2words(token_ids, self.vocab, oov_words)

            # Keep only what precedes the first end-of-sequence token.
            try:
                words = words[:words.index(dataset.EOS_TOKEN)]
            except ValueError:
                # No end-of-sequence token present; use the whole output.
                pass

            write_for_rouge(batch.original_abstracts_sents[0], words,
                            processed, self._rouge_ref_dir,
                            self._rouge_dec_dir)
            processed += 1

            if not processed % 1000:
                print('%d example in %d sec' % (processed,
                                                time.time() - tick))
                # Restart the timer so each report covers the latest 1000
                # examples only.
                tick = time.time()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._test_dir)