Example #1
    def decode(self):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
        t0 = time.time()
        counter = FLAGS.decode_after
        while True:
            tf.reset_default_graph()
            batch = self._batcher.next_batch()  # 1 example repeated across batch
            if batch is None:  # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info("Decoder has finished reading dataset for single_pass.")
                tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir,
                                self._rouge_dec_dir)
                results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
                rouge_log(results_dict, self._decode_dir)
                return

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

            article_withunks = data.show_art_oovs(original_article, self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab,
                                                   (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            # Run beam search to get best Hypothesis
            if FLAGS.ac_training:
                best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch, self._dqn,
                                                       self._dqn_sess, self._dqn_graph)
            else:
                best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)
            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words(output_ids, self._vocab,
                                                 (batch.art_oovs[0] if FLAGS.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)  # index of the (first) [STOP] symbol
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words
            decoded_output = ' '.join(decoded_words)  # single string

            if FLAGS.single_pass:
                self.write_for_rouge(original_abstract_sents, decoded_words,
                                     counter)  # write ref summary and decoded summary to file, to eval with pyrouge later
                counter += 1  # this is how many examples we've decoded
            else:
                print_results(article_withunks, abstract_withunks, decoded_output)  # log output to screen
                self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists,
                                       best_hyp.p_gens)  # write info to .json file for visualization tool

                # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
                t1 = time.time()
                if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                    tf.logging.info(
                        'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                        t1 - t0)
                    _ = util.load_ckpt(self._saver, self._sess, FLAGS.decode_from)
                    t0 = time.time()
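Example #1 (like most of the loops below) is written as a method of a beam-search decoder class and assumes an outer driver has already built the model, the batcher, and the vocabulary before calling decode(). A minimal sketch of such a driver follows; the class names (Vocab, Batcher, SummarizationModel, BeamSearchDecoder) and flags mirror the pointer-generator-style code in these examples, but the exact constructor signatures are assumptions and vary between the projects quoted here.

# Hypothetical driver for the decode() loop above; constructor signatures are assumptions
# modeled on pointer-generator-style projects and may differ in any given fork.
# FLAGS / hps are the project's tf.app.flags object and hyperparameter namedtuple (not defined here).
from data import Vocab
from batcher import Batcher
from model import SummarizationModel
from decode import BeamSearchDecoder

vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)               # load the vocabulary file
batcher = Batcher(FLAGS.data_path, vocab, hps, single_pass=FLAGS.single_pass)
model = SummarizationModel(hps, vocab)                          # graph built in decode mode
decoder = BeamSearchDecoder(model, batcher, vocab)              # class whose decode() is shown above
decoder.decode()                                                # loops until the dataset is exhausted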
Example #2
File: decode.py  Project: sra4077/RLSeq2Seq
  def decode(self):
    """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
    t0 = time.time()
    counter = FLAGS.decode_after
    while True:
      tf.reset_default_graph()
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None: # finished decoding dataset in single_pass mode
        assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
        return

      original_article = batch.original_articles[0]  # string
      original_abstract = batch.original_abstracts[0]  # string
      original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

      article_withunks = data.show_art_oovs(original_article, self._vocab) # string
      abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string

      # Run beam search to get best Hypothesis
      if FLAGS.ac_training:
        best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch, self._dqn, self._dqn_sess, self._dqn_graph)
      else:
        best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)
      # Extract the output ids from the hypothesis and convert back to words
      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))

      # Remove the [STOP] token from decoded_words, if necessary
      try:
        fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
        decoded_words = decoded_words[:fst_stop_idx]
      except ValueError:
        decoded_words = decoded_words
      decoded_output = ' '.join(decoded_words) # single string

      if FLAGS.single_pass:
        self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
        counter += 1 # this is how many examples we've decoded
      else:
        print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen
        self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool

        # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
        t1 = time.time()
        if t1-t0 > SECS_UNTIL_NEW_CKPT:
          tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', t1-t0)
          _ = util.load_ckpt(self._saver, self._sess, FLAGS.decode_from)
          t0 = time.time()
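The try/except block that truncates the output at the first [STOP] token recurs almost verbatim in every example on this page. As a small refactoring sketch (assuming data.STOP_DECODING is the stop symbol string used above), it could be pulled out into a helper:

def trim_at_stop(decoded_words, stop_token):
    """Return decoded_words truncated at the first stop_token, or unchanged if it is absent."""
    try:
        fst_stop_idx = decoded_words.index(stop_token)  # index of the (first) [STOP] symbol
        return decoded_words[:fst_stop_idx]
    except ValueError:                                  # no [STOP] was generated; keep everything
        return decoded_words

# inside the decode loops above:
#   decoded_words = trim_at_stop(decoded_words, data.STOP_DECODING)
#   decoded_output = ' '.join(decoded_words)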
Example #3
  def evaluate(self):
    """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
    t0 = time.time()
    counter = 0

    while True:
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None: # finished decoding dataset in single_pass mode
        assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        tf.logging.info("Output has been saved in %s and %s. Starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_results, rouge_results_str = rouge_log(rouge_results_dict, self._decode_dir)
        t1 = time.time()
        tf.logging.info("evaluation time: %.3f min", (t1-t0)/60.0)
        return rouge_results, rouge_results_str

      if FLAGS.decode_method == 'greedy':
        output_ids = self._model.run_greedy_search(self._sess, batch)
        for i in range(FLAGS.batch_size):
          self.process_one_article(batch.original_articles_sents[i], batch.original_abstracts_sents[i], \
                                   batch.original_extracts_ids[i], output_ids[i], \
                                   batch.art_oovs[i], None, None, None, counter)
          counter += 1
      elif FLAGS.decode_method == 'beam':
        # Run beam search to get best Hypothesis
        best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)

        # Extract the output ids from the hypothesis and convert back to words
        output_ids = [int(t) for t in best_hyp.tokens[1:]]    # remove start token
        best_hyp.log_probs = best_hyp.log_probs[1:]   # remove start token probability
        self.process_one_article(batch.original_articles_sents[0], batch.original_abstracts_sents[0], \
                                 batch.original_extracts_ids[0], output_ids, batch.art_oovs[0], \
                                 best_hyp.attn_dists, best_hyp.p_gens, best_hyp.log_probs, counter)
        counter += 1
Example #4
    def predict(self, input_data):
        """Run beam search on a single pre-built batch and return the decoded summary as a string."""

        batch = input_data

        best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)

        # Extract the output ids from the hypothesis and convert back to words
        output_ids = [int(t) for t in best_hyp.tokens[1:]]
        decoded_words = data.outputids2words(output_ids, self._vocab,
                                             (batch.art_oovs[0] if FLAGS.pointer_gen else None))

        # Remove the [STOP] token from decoded_words, if necessary
        try:
            fst_stop_idx = decoded_words.index(data.STOP_DECODING)  # index of the (first) [STOP] symbol
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            decoded_words = decoded_words

        for i in range(len(decoded_words)):
            if not type(decoded_words[i]) is str:
                decoded_words[i] = str(decoded_words[i], encoding='utf-8')

        decoded_output = " ".join(decoded_words)  # single string
        return decoded_output
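Example #4 wraps a single beam-search run behind a predict() method, which suits serving one request at a time rather than evaluating a dataset. A hedged usage sketch follows; the batcher here is an assumption and only needs to yield a batch with one example repeated across the beam, as the other examples do.

# Hypothetical single-request call; `predictor` is an instance of the class defining predict().
batch = batcher.next_batch()        # one example repeated beam_size times
summary = predictor.predict(batch)  # returns the decoded summary as a single string
print(summary)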
Example #5
def run_test(use_cuda=False):
	max_dec_steps = 100
	min_dec_step = 50
	beam_size = 20
	vocab = Vocab(
		'/home/zhaoyuekai/torch_code/data/summary/finished_files/vocab',
		50000
		)
	hps = hyper_params(50, 100, True, 20, 'encode')
	dataloader = Batcher(
		"/home/zhaoyuekai/torch_code/data/summary/finished_files/chunked/train_*"
		, vocab, hps, single_pass=True)
	net = Summarization_Model(
		50000, 128, 256, 400, unif_mag=0.02,trunc_norm_std=1e-4, 
		use_coverage=False, pointer_gen=True
		)
	if use_cuda:
		net = net.cuda()

	step = 0
	while step < 2:
		batch = dataloader.next_batch()
		batch = batch2var(batch, use_cuda)
		h = run_beam_search(
				net, vocab, batch, beam_size, max_dec_steps, min_dec_step, 
				use_cuda
			)
		print(len(h.tokens))
		print(h.tokens)
		step += 1
Example #6
    def decodeOneSample(self, batches):

        batch = batches[0]
        original_article = batch.original_articles[0]
        original_abstract = batch.original_abstracts[0]
        original_abstract_sents = batch.original_abstracts_sents[0]

        article_withunks = data.show_art_oovs(original_article, self._vocab)
        abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab,
                                               batch.art_oovs[0])

        best_hypothesis = beam_search.run_beam_search(self._session,
                                                      self._model, self._vocab,
                                                      batch, self._hps)

        output_ids = [int(t) for t in best_hypothesis.tokens[1:]]
        decoded_words = data.outputids2words(output_ids, self._vocab,
                                             batch.art_oovs[0])
        try:
            fst_stop_idx = decoded_words.index(
                data.STOP_DECODING)  # index of the (first) [STOP] symbol
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            decoded_words = decoded_words

        decoded_output = ' '.join(decoded_words)  # single string

        self.write_for_rouge(original_abstract_sents, decoded_words, 0,
                             original_article)
        self.rouge_eval()
        print_results(article_withunks, abstract_withunks, decoded_output)
        self.write_for_attnvis(article_withunks, abstract_withunks,
                               decoded_words, best_hypothesis.attn_dists,
                               best_hypothesis.p_gens)
Example #7
    def decode(self):

        decode_relations = []
        original_relations = []
        while True:
            batch = self._batcher.next_batch(
            )  # 1 example repeated across batch
            if batch is None:  # finished decoding dataset in single_pass mode

                # todo after finish final batch, compute precision, recall and f1_score
                p, r, f1 = calculate_measure_cmp(decode_relations,
                                                 original_relations)
                tf.logging.info("p: %.4f, r: %.4f, f1: %.4f", p, r, f1)

                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir, self._rouge_dec_dir)
                return

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string

            article_withunks = data.show_art_oovs(original_article,
                                                  self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            # Run beam search to get best Hypothesis
            best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                   self._vocab, batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words_webnlg(
                output_ids, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(
                    data.STOP_DECODING)  # index of the (first) [STOP] symbol
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words
            decoded_output = ' '.join(decoded_words)  # single string

            decode_relations.append(decoded_output)
            original_relations.append(original_abstract)

            print_results(article_withunks, abstract_withunks,
                          decoded_output)  # log output to screen
            self.write_for_attnvis(
                article_withunks, abstract_withunks, decoded_words,
                best_hyp.attn_dists, best_hyp.p_gens
            )  # write info to .json file for visualization tool
Example #8
  def _decode(self):
    """
    """
    t0 = time.time()
    counter = 0
    while True:
      batch = self.batcher._next_batch()
      if batch is None: 
        assert FLAGS.onetime, "Dataset exhausted, but we are not in onetime mode"
        print('INFO: Decoder has finished reading dataset for onetime.')
        print('INFO: Output has been saved in {} and {}, start ROUGE eval...'.format(self.rouge_ref_dir, self.rouge_dec_dir))
        results_dict = rouge_eval(self.rouge_ref_dir, self.rouge_dec_dir)
        rouge_log(results_dict, self.decode_dir)
        return

      original_article = batch.original_articles[0]
      original_abstract = batch.original_abstracts[0]

      article_withunks = data.show_art_oovs(original_article, self.vocab)
      abstract_withunks = data.show_abs_oovs(original_abstract, self.vocab, (batch.art_oovs[0] if FLAGS.pointer else None))

      best_hyp = beam_search.run_beam_search(self.sess, self.model, self.vocab, batch)

      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      decoded_words = data.outputids2words(output_ids, self.vocab, (batch.art_oovs[0] if FLAGS.pointer else None))

      try:
        fst_stop_idx = decoded_words.index(data.DECODING_END)
        decoded_words = decoded_words[:fst_stop_idx]
      except ValueError:
        decoded_words = decoded_words
      decoded_output = ' '.join(decoded_words)

      if FLAGS.onetime:
        self._write_for_rouge(original_abstract, decoded_words, counter)
        counter += 1
      else:
        print("")
        print('INFO: ARTICLE: {}'.format(article_withunks))
        print('INFO: REFERENCE SUMMARY: {}'.format(abstract_withunks))
        print('INFO: GENERATED SUMMARY: {}'.format(decoded_output))
        print("")
        self._write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.pointers)

        t1 = time.time()
        if t1-t0 > SECS_UNTIL_NEW_CKPT:
          print('INFO: Decoding for {} seconds, loading new checkpoint'.format(t1-t0))
          while True:
            try:
              ckpt_state = tf.train.get_checkpoint_state(train_dir)
              print('INFO: Loading checkpoint {}'.format(ckpt_state.model_checkpoint_path))
              self.saver.restore(self.sess, ckpt_state.model_checkpoint_path)
              break
            except:
              print('ERROR: Failed to restore checkpoint: {}, sleep for {} secs'.format(train_dir, 10))
              time.sleep(10)
          t0 = time.time()
Example #9
def generate_summary(spacy_article, ideal_summary_length_tokens=60):
    """
    Generates summary of the given article. Note that this is slow (~20 seconds on a single CPU).
    
    Args:
        spacy_article: Spacy-processed text. The model was trained on the output of
        doc.spacy_text(), so for best results the input here should also come from doc.spacy_text().
    
    Returns:
        Tuple of unicode summary of the text and scalar score of its quality. Score is approximately
        an average log-likelihood of the summary (so it is < 0.) and typically is in the range
        [-.2, -.5]. Summaries with scores below -.4 are usually not very good.
    """
    assert isinstance(spacy_article, Doc)

    # These imports are slow - lazy import.
    from batcher import Batch, Example
    from beam_search import run_beam_search
    from io_processing import process_article, process_output

    if _model is None:
        _load_model()

    # Handle short inputs
    article_tokens, _, orig_article_tokens = process_article(spacy_article)
    if len(article_tokens) <= ideal_summary_length_tokens:
        return spacy_article.text, 0.

    min_summary_length = min(10 + len(article_tokens) / 10,
                             2 * ideal_summary_length_tokens / 3)
    max_summary_length = min(10 + len(article_tokens) / 5,
                             3 * ideal_summary_length_tokens / 2)

    # Make input data
    example = Example(' '.join(article_tokens),
                      abstract='',
                      vocab=_vocab,
                      hps=_hps)
    batch = Batch([example] * _beam_size, _hps, _vocab)

    # Generate output
    hyp, score = run_beam_search(
        _sess,
        _model,
        _vocab,
        batch,
        _beam_size,
        max_summary_length,
        min_summary_length,
        _settings.trace_path,
    )

    # Extract the output ids from the hypothesis and convert back to words
    return process_output(hyp.token_strings[1:], orig_article_tokens), score
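Example #9 hides the whole pipeline behind one generate_summary() call on a spaCy Doc and also returns an approximate average log-likelihood as a quality score. A hedged usage sketch follows; the spaCy pipeline name is an assumption, and per the docstring the input should come from the same spaCy processing used at training time.

import spacy

nlp = spacy.load("en_core_web_sm")   # any English pipeline that yields a Doc (assumption)
doc = nlp(article_text)              # article_text: the raw article string
summary, score = generate_summary(doc, ideal_summary_length_tokens=60)
if score < -0.4:                     # per the docstring, scores below -0.4 are usually poor
    print("Warning: low-confidence summary (score %.2f)" % score)
print(summary)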
Example #10
  def decode(self):
    """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
    t0 = time.time()
    counter = 0
    out_num=0
    summaries=[]
    while True:
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None: # finished decoding dataset in single_pass mode
        assert self.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")

      ##I commented those lines
        # tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
        # results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        # rouge_log(results_dict, self._decode_dir)
        print(out_num)
        return summaries

      original_article = batch.original_articles[0]  # string
      # I commented those lines
      # original_abstract = batch.original_abstracts[0]  # string
      # original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

      article_withunks = data.show_art_oovs(original_article, self._vocab) # string
      # I commented this line
      # abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string

      # Run beam search to get best Hypothesis
      best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch,self.beam_size,self.max_dec_steps,self.min_dec_steps)

      # Extract the output ids from the hypothesis and convert back to words
      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if self.pointer_gen else None))

      # Remove the [STOP] token from decoded_words, if necessary
      try:
        fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
        decoded_words = decoded_words[:fst_stop_idx]
      except ValueError:
        decoded_words = decoded_words
      decoded_output = ' '.join(decoded_words) # single string

      if self.single_pass:
        summaries.append(decoded_output)
        # open('s'+str(out_num)+'.txt','w').write(decoded_output)
        # with open('output'+str(out_num)+'.txt','w') as output:
          # output.write(original_article+'\n*******************************************\n\n'+decoded_output)
        out_num+=1
        print(out_num)
        #this line is commented by me
        # self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
        counter += 1 # this is how many examples we've decoded
Example #11
    def decode(self):
        t0 = time.time()
        counter = 0
        summary_file = open(self._summary_path, "w")
        while True:
            batch = self._batcher.next_batch()
            if batch is None:
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                tf.logging.info("Output has been saved in %s and %s",
                                self._ref_dir, self._dec_dir)
                break  # exit the loop so summary_file is closed and timing is logged below

            arg_withunks = utils.show_abs_oovs(batch.original_arg[0],
                                               self._tgt_vocab, None)

            best_hyp_arg, best_hyp_kp = beam_search.run_beam_search(
                self._sess, self._model, self._tgt_vocab, batch)
            output_ids = [int(t) for t in best_hyp_arg.tokens[1:]]
            decoded_words = utils.outputids2words(output_ids, self._tgt_vocab,
                                                  None)
            try:
                fst_stop_idx = decoded_words.index(utils.ARG_STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words

            self.write_to_file(batch.original_arg_sents[0], decoded_words,
                               counter, "arg")

            summary_file.write("ID: %d\n" % counter)
            summary_file.write("OP: %s\n" % batch.original_src)
            summary_file.write("ARG: %s\n" % arg_withunks)
            summary_file.write("Generation: %s\n" % " ".join(decoded_words))
            summary_file.write("=" * 50 + "\n")

            if self._model.hps.model in ["sep_dec", "shd_dec"]:
                output_ids = [int(t) for t in best_hyp_kp.tokens[1:]]
                decoded_words = utils.outputids2words(output_ids,
                                                      self._tgt_vocab, None)
                try:
                    fst_stop_idx = decoded_words.index(utils.KP_STOP_DECODING)
                    decoded_words = decoded_words[:fst_stop_idx]
                except ValueError:
                    decoded_words = decoded_words

                self.write_to_file(batch.original_kp_sents[0], decoded_words,
                                   counter, "kp")
            counter += 1

        summary_file.close()
        tf.logging.info("Decoding took %.3f seconds", time.time() - t0)
Example #12
    def decode2(self, batcher):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
        t0 = time.time()
        counter = 0
        while True:
            batch = batcher.next_batch()  # 1 example repeated across batch
            if batch is None:  # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir, self._rouge_dec_dir)
                results_dict = rouge_eval(self._rouge_ref_dir,
                                          self._rouge_dec_dir)
                rouge_log(results_dict, self._decode_dir)
                return

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            original_abstract_sents = batch.original_abstracts_sents[
                0]  # list of strings

            article_withunks = data.show_art_oovs(original_article,
                                                  self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            # Run beam search to get best Hypothesis
            best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                   self._vocab, batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(
                    data.STOP_DECODING)  # index of the (first) [STOP] symbol
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words
            decoded_output = ' '.join(decoded_words)  # single string

            print_results(article_withunks, abstract_withunks,
                          decoded_output)  # log output to screen
            return decoded_output
Example #13
    def decode_one_batch(self, batch, withRouge=True):
        original_article = batch.original_articles[0]  # string
        original_abstract = batch.original_abstracts[0]  # string
        original_abstract_sents = batch.original_abstracts_sents[
            0]  # list of strings
        original_uuid = batch.uuids[0]  # string

        article_withunks = data.show_art_oovs(original_article,
                                              self._vocab)  # string
        abstract_withunks = data.show_abs_oovs(
            original_abstract, self._vocab,
            (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

        # Run beam search to get best Hypothesis
        best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                               self._vocab, batch)

        # Extract the output ids from the hypothesis and convert back to words
        output_ids = [int(t) for t in best_hyp.tokens[1:]]
        decoded_words = data.outputids2words(
            output_ids, self._vocab,
            (batch.art_oovs[0] if FLAGS.pointer_gen else None))

        # Remove the [STOP] token from decoded_words, if necessary
        try:
            fst_stop_idx = decoded_words.index(
                data.STOP_DECODING)  # index of the (first) [STOP] symbol
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            decoded_words = decoded_words
        decoded_output = ' '.join(decoded_words)  # single string

        if FLAGS.single_pass:
            # write ref summary and decoded summary to file, to eval with pyrouge later
            # self.write_for_rouge(original_abstract_sents, decoded_words,self._counter)
            self.write_for_flink(original_uuid, original_article,
                                 decoded_words, original_abstract_sents)
            self._counter += 1  # this is how many examples we've decoded
        else:
            print_results(article_withunks, abstract_withunks,
                          decoded_output)  # log output to screen
            self.write_for_attnvis(
                article_withunks, abstract_withunks, decoded_words,
                best_hyp.attn_dists, best_hyp.p_gens
            )  # write info to .json file for visualization tool
Example #14
def decode_example(sess, model, vocab, batch, counter, hps):
    # Run beam search to get best Hypothesis
    best_hyp = beam_search.run_beam_search(sess, model, vocab, batch, counter,
                                           hps)

    # Extract the output ids from the hypothesis and convert back to words
    output_ids = [int(t) for t in best_hyp.tokens[1:]]
    decoded_words = data.outputids2words(
        output_ids, vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))

    # Remove the [STOP] token from decoded_words, if necessary
    try:
        fst_stop_idx = decoded_words.index(
            data.STOP_DECODING)  # index of the (first) [STOP] symbol
        decoded_words = decoded_words[:fst_stop_idx]
    except ValueError:
        decoded_words = decoded_words
    decoded_output = ' '.join(decoded_words)  # single string
    return decoded_words, decoded_output, best_hyp
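The module-level decode_example() helper above bundles beam search and [STOP]-token trimming for one batch, so the caller only has to loop over the batcher. A minimal sketch of such a caller, reusing the sess/model/vocab/batcher/hps objects assumed by the surrounding examples:

# Sketch of a loop driving decode_example(); all names follow the surrounding examples.
counter = 0
while True:
    batch = batcher.next_batch()      # one example repeated across the batch
    if batch is None:                 # dataset exhausted in single_pass mode
        break
    decoded_words, decoded_output, best_hyp = decode_example(
        sess, model, vocab, batch, counter, hps)
    print(decoded_output)
    counter += 1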
Example #15
  def decode(self):
    """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
    t0 = time.time()
    counter = 0
    while True:
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None: # finished decoding dataset in single_pass mode
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        return

      original_article = batch.original_articles[0]  # string

      article_withunks = data.show_art_oovs(original_article, self._vocab) # string

      # Run beam search to get best Hypothesis
      best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)

      # Extract the output ids from the hypothesis and convert back to words
      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))

      # Remove the [STOP] token from decoded_words, if necessary
      try:
        fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
        decoded_words = decoded_words[:fst_stop_idx]
      except ValueError:
        decoded_words = decoded_words
      decoded_output = ' '.join(decoded_words) # single string

      if FLAGS.single_pass:
        self.write_for_rouge(decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
        counter += 1 # this is how many examples we've decoded
      else:
        print_results(article_withunks, decoded_output) # log output to screen
        
        # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
        t1 = time.time()
        if t1-t0 > SECS_UNTIL_NEW_CKPT:
          tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', t1-t0)
          _ = util.load_ckpt(self._saver, self._sess)
          t0 = time.time()
Example #16
  def decode(self):
    """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
    
    t0 = time.time()
    counter = 0
    while True:
      batch = self._batcher.next_batch()

      if batch is None: # finished decoding dataset in single_pass mode
        assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
        return
      best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch) 
      # Run beam search to get best Hypothesis
      #best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)  ### I am here now @@@@@

      # Extract the output ids from the hypothesis and convert back to words
      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      #decoded_words = data.outputids2words(output_ids, self._vocab)
      decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))
      print (decoded_words)
      try:
          fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
          decoded_words = decoded_words[:fst_stop_idx]
      except ValueError:
          decoded_words = decoded_words
      decoded_output = ' '.join(decoded_words) # single string
      print (decoded_output)
      #### print actual test input when decoding
      #print (batch.enc_batch_field)
      original_abstract_sents = batch.original_summaries_sents[0]  # list of strings
      print (original_abstract_sents)
      if FLAGS.single_pass:
          self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
          counter += 1 # this is how many examples we've decoded
      #results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
      #rouge_log(results_dict, self._decode_dir)
Example #17
    def run_beam_decoder(self,batch):
        original_article=batch.original_articles[0]
        original_abstract=batch.original_abstracts[0]

        article_withunks=data.show_art_oovs(original_article,self.vocab)
        abstract_withunks=data.show_abs_oovs(original_abstract,self.vocab,(batch.art_oovs[0] if self.hp.pointer_gen==True else None))

        best_list=beam_search.run_beam_search(self.sess,self.model,self.vocab,batch)
        print(best_list.tokens)
        output_ids=[int(t) for t in best_list.tokens[1:]]
        decoded_words=data.outputids2words(output_ids,self.vocab,(batch.art_oovs[0] if self.hp.pointer_gen==True else None))
        print(decoded_words)

        try:
            fst_stop_idx=decoded_words.index(data.STOP_DECODING)
            decoded_words=decoded_words[:fst_stop_idx]
        except ValueError:
            decoded_words=decoded_words
        decoded_output=' '.join(decoded_words)

        return article_withunks,abstract_withunks,decoded_output
Example #18
def do(batches, model, vocab, sess, FLAGS, counter, i):
    print('work : {0}, len : {1}'.format(i, len(batches)))
    for batch in batches:
        article = batch.original_articles[0]
        original_abstract_sents = batch.original_abstracts_sents[
            0]  # list of strings
        best_hyps = beam_search.run_beam_search(sess, model, vocab, batch)
        output_ids = [int(t) for t in best_hyps.tokens[1:]]
        decoded_words = data.outputids2words_beam(
            output_ids, vocab,
            (batch.art_oovs[0] if FLAGS.pointer_gen else None))
        try:
            fst_stop_idx = decoded_words.index(
                data.STOP_DECODING)  # index of the (first) [STOP] symbol
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            decoded_words = decoded_words
        write_for_rouge_beam(
            original_abstract_sents, decoded_words, article, counter,
            FLAGS.dec_path, FLAGS.ref_path, FLAGS.all_path
        )  # write ref summary and decoded summary to file, to eval with pyrouge later
        counter += 1  # this is how many examples we've decoded
Example #19
def run_beam_search_decode(model, examples_list, id2word, data, ckpt_dir):
    ones_reward = np.ones([FLAGS.batch_size])
    saver = tf.train.Saver()  # we use this to load checkpoints for decoding
    sess = tf.Session(config=util.get_config())
    # Load an initial checkpoint to use for decoding
    ckpt_path = util.load_ckpt(saver, sess, ckpt_dir=ckpt_dir)
    print('Finished loading model')
    ckpt_name = 'BS_' + data + "-ckpt_from-" + ckpt_dir
    ckpt_name = ckpt_name + ckpt_path.split('-')[-1] if ckpt_dir.split(
        '_'
    )[0] == 'eval' else ckpt_name  # this is something of the form "ckpt-123456"
    decode_dir = os.path.join(FLAGS.exp_name, ckpt_name)
    if os.path.exists(decode_dir):
        raise Exception(
            "single_pass decode directory %s should not already exist" %
            decode_dir)
    if not os.path.exists(decode_dir): os.mkdir(decode_dir)
    test_file_name = '{}/{}_results.txt'.format(decode_dir, data)
    f = open(test_file_name, 'w')
    for i in xrange(len(examples_list)):
        if i % 200 == 0:
            print(i)
        batch_test_plot, batch_test_ending, batch_test_ext_plot, batch_test_len_plot, batch_test_len_ending, test_max_plot_oovs, batch_plot_oovs, batch_plot_mask = examples_list[
            i]
        best_hyp = beam_search.run_beam_search(
            sess, model, id2word, batch_test_plot, batch_test_ending,
            batch_test_ext_plot, batch_test_len_plot, batch_test_len_ending,
            test_max_plot_oovs, batch_plot_mask, ones_reward)

        # Extract the output ids from the hypothesis and convert back to words
        output_ids = [int(t) for t in best_hyp.tokens[1:]]
        logit_words = result_ids2words(output_ids, id2word, batch_plot_oovs[0])
        if 'PAD' not in logit_words:
            logit_words.append('PAD')
        mask_logit = logit_words[0:logit_words.index('PAD')]
        logit_txt = ' '.join(mask_logit)
        new_line = '{}\n'.format(logit_txt)
        f.write(new_line)
    print('Finished writing results!')
Example #20
    def decode(self):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
        # t0 = time.time()
        batch = self._batcher.next_batch()  # 1 example repeated across batch

        original_article = batch.original_articles[0]  # string
        original_abstract = batch.original_abstracts[0]  # string

        # input data
        article_withunks = data.show_art_oovs(original_article,
                                              self._vocab)  # string
        abstract_withunks = data.show_abs_oovs(
            original_abstract, self._vocab,
            (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

        # Run beam search to get best Hypothesis
        best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                               self._vocab, batch)

        # Extract the output ids from the hypothesis and convert back to words
        output_ids = [int(t) for t in best_hyp.tokens[1:]]
        decoded_words = data.outputids2words(
            output_ids, self._vocab,
            (batch.art_oovs[0] if FLAGS.pointer_gen else None))

        # Remove the [STOP] token from decoded_words, if necessary
        try:
            fst_stop_idx = decoded_words.index(
                data.STOP_DECODING)  # index of the (first) [STOP] symbol
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            decoded_words = decoded_words
        decoded_output = ' '.join(decoded_words)  # single string

        # tf.logging.info('ARTICLE:  %s', article)
        #  tf.logging.info('GENERATED SUMMARY: %s', decoded_output)

        sys.stdout.write(decoded_output)
Example #21
    def decode(self):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
        t0 = time.time()
        counter = 0
        while True:
            batch = self._batcher.next_batch(
            )  # 1 example repeated across batch
            if batch is None:  # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir, self._rouge_dec_dir)
                #results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
                #rouge_log(results_dict, self._decode_dir)
                return
            '''
      original_article = batch.original_articles[0]  # string
      original_abstract = batch.original_abstracts[0]  # string
      original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

      article_withunks = data.show_art_oovs(original_article, self._vocab) # string
      abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string
      '''
            # Run beam search to get best Hypothesis
            best_hyp = beam_search.run_beam_search(
                self._sess, self._model, self._vocab,
                batch)  ### I am here now @@@@@

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))
            print(decoded_words)
Example #22
    def decode(self, batches):
        counter = 0
        for batch in batches:
            if (counter < 10000):
                original_article = batch.original_articles[0]
                original_abstract = batch.original_abstracts[0]
                original_abstract_sents = batch.original_abstracts_sents[0]

                article_withunks = data.show_art_oovs(original_article,
                                                      self._vocab)
                abstract_withunks = data.show_abs_oovs(original_abstract,
                                                       self._vocab,
                                                       batch.art_oovs[0])

                best_hypothesis = beam_search.run_beam_search(
                    self._session, self._model, self._vocab, batch, self._hps)

                output_ids = [int(t) for t in best_hypothesis.tokens[1:]]
                decoded_words = data.outputids2words(output_ids, self._vocab,
                                                     batch.art_oovs[0])
                try:
                    fst_stop_idx = decoded_words.index(
                        data.STOP_DECODING
                    )  # index of the (first) [STOP] symbol
                    decoded_words = decoded_words[:fst_stop_idx]
                except ValueError:
                    decoded_words = decoded_words

                decoded_output = ' '.join(decoded_words)  # single string

                self.write_for_rouge(original_abstract_sents, decoded_words,
                                     counter, original_article)
                counter += 1
            else:
                break

        self.rouge_eval()
Example #23
    def decode(self):
        """
        Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode
        indefinitely, loading latest checkpoint at regular intervals.
        """
        counter = 0
        scores = []

        while True:
            batch = self._batcher.next_batch()  # 1 example repeated across batch
            if batch is None: # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info("Decoder has finished reading dataset for single_pass.")
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir,
                    self._rouge_dec_dir,
                )
                tf.logging.info("Mean score: %s", sum(scores) / len(scores))
                return

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string

            article_withunks = data.show_art_oovs(original_article, self._vocab) # string
            abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, batch.art_oovs[0]) # string

            # Run beam search to get best Hypothesis
            t_beam = time.time()
            best_hyp, best_score = beam_search.run_beam_search(
                self._sess, self._model, self._vocab, batch, FLAGS.beam_size, FLAGS.max_dec_steps,
                FLAGS.min_dec_steps, FLAGS.trace_path
            )
            scores.append(best_score)
            tf.logging.info("Time to decode one example: %f", time.time() - t_beam)
            tf.logging.info("Mean score: %s", sum(scores) / len(scores))

            # Extract the output ids from the hypothesis and convert back to words
            decoded_words = best_hyp.token_strings[1:]

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words
            decoded_output = ' '.join(decoded_words) # single string

            if FLAGS.single_pass:
                self.write_for_rouge(original_abstract, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
                counter += 1 # this is how many examples we've decoded
            else:
                # log output to screen
                print_results(
                    article_withunks, abstract_withunks, decoded_output, best_hyp, [best_score]
                )
                # write info to .json file for visualization tool
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists,
                    best_hyp.p_gens, best_hyp.log_probs
                )

                raw_input()
Example #24
File: decode.py  Project: JoJoJun/PG_model
    def decode(self):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely,
        loading latest checkpoint at regular intervals"""
        t0 = time.time()
        counter = 0
        if FLAGS.decode_bleu:
            ref_file = os.path.join(self._bleu_dec_dir, "reference.txt")
            decoded_file = os.path.join(self._bleu_dec_dir, "decoded.txt")
            if os.path.exists(decoded_file):
                tf.logging.info('Deleting %s', decoded_file)
                os.remove(decoded_file)
            if os.path.exists(ref_file):
                tf.logging.info('Deleting %s', ref_file)
                os.remove(ref_file)
        while True:
            batch = self._batcher.next_batch(
            )  # 1 example repeated across batch
            if batch is None:  # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                if FLAGS.decode_rouge:
                    tf.logging.info(
                        "Output has been saved in %s and %s. Now starting ROUGE eval...",
                        self._rouge_ref_dir, self._rouge_dec_dir)
                    try:
                        t0 = time.time()
                        results_dict = rouge_eval(self._rouge_ref_dir,
                                                  self._rouge_dec_dir)
                        rouge_log(results_dict, self._decode_dir)
                        t1 = time.time()
                        tf.logging.info(
                            'calculate Rouge score cost %d seconds', t1 - t0)
                    except Exception as e:
                        tf.logging.error('Error computing ROUGE: %s', e)
                if FLAGS.decode_bleu:
                    ref_file = os.path.join(self._bleu_dec_dir,
                                            "reference.txt")
                    decoded_file = os.path.join(self._bleu_dec_dir,
                                                "decoded.txt")

                    t0 = time.time()
                    bleu, bleu1, bleu2, bleu3, bleu4 = calcu_bleu(
                        decoded_file, ref_file)
                    sys_bleu = sys_bleu_file(decoded_file, ref_file)
                    sys_bleu_perl = sys_bleu_perl_file(decoded_file, ref_file)
                    t1 = time.time()

                    tf.logging.info(bcolors.HEADER +
                                    '-----------BLEU SCORE-----------' +
                                    bcolors.ENDC)
                    tf.logging.info(
                        bcolors.OKGREEN + '%f \t %f \t %f \t %f \t %f' +
                        bcolors.ENDC, bleu, bleu1, bleu2, bleu3, bleu4)
                    tf.logging.info(
                        bcolors.OKGREEN + 'sys_bleu %f' + bcolors.ENDC,
                        sys_bleu)
                    tf.logging.info(
                        bcolors.OKGREEN + 'sys_bleu_perl %s' + bcolors.ENDC,
                        sys_bleu_perl)
                    tf.logging.info(bcolors.HEADER +
                                    '-----------BLEU SCORE-----------' +
                                    bcolors.ENDC)
                    tf.logging.info('calculate BLEU score cost %d seconds',
                                    t1 - t0)
                break

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            original_abstract_sents = batch.original_abstracts_sents[
                0]  # list of strings

            article_withunks = data.show_art_oovs(original_article,
                                                  self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            # Run beam search to get best Hypothesis
            best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                   self._vocab, batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(
                    data.STOP_DECODING)  # index of the (first) [STOP] symbol
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words
            decoded_output = ''.join(decoded_words)  # single string

            if FLAGS.single_pass:
                print_results(article_withunks, abstract_withunks,
                              decoded_output, counter)  # log output to screen
                if FLAGS.decode_rouge:
                    self.write_for_rouge(
                        original_abstract_sents, decoded_words, counter
                    )  # write ref summary and decoded summary to file, to eval with pyrouge later
                if FLAGS.decode_bleu:
                    self.write_for_bleu(original_abstract_sents, decoded_words)
                counter += 1  # this is how many examples we've decoded
            else:
                print_results(article_withunks, abstract_withunks,
                              decoded_output)  # log output to screen
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words,
                    best_hyp.attn_dists, best_hyp.p_gens
                )  # write info to .json file for visualization tool

                # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
                t1 = time.time()
                if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                    tf.logging.info(
                        'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                        t1 - t0)
                    _ = util.load_ckpt(self._saver, self._sess)
                    t0 = time.time()
Example #25
    def decode(self, output_dir=None):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
        t0 = time.time()
        counter = 0
        idx = 0

        # buffers used to accumulate values during decoding
        output_str = ""
        beam_search_str = ""
        metadata = []

        # evaluate over a fixed number of test set
        while True:  #idx <=100 :

            print("[%d]" % idx)

            batch = self._batcher.next_batch(
            )  # 1 example repeated across batch

            #      if idx < 11000:
            #          idx += 1
            #          continue

            if batch is None:  # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir, self._rouge_dec_dir)
                results_dict = rouge_eval(self._rouge_ref_dir,
                                          self._rouge_dec_dir)
                rouge_log(results_dict, self._decode_dir)
                return

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            original_abstract_sents = batch.original_abstracts_sents[
                0]  # list of strings

            article_withunks = data.show_art_oovs(original_article,
                                                  self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            # Run beam search to get all the hypotheses
            all_hyp = beam_search.run_beam_search(
                self._sess, self._model, self._vocab, batch, counter,
                self._lm_model, self._lm_word2idx, self._lm_idx2word
            )  #TODO changed the method signature just to look at the outputs of beam search

            if FLAGS.save_values:
                for h in all_hyp:
                    output_ids = [int(t) for t in h.tokens[1:]]
                    search_str = str(
                        data.outputids2words(output_ids, self._vocab,
                                             (batch.art_oovs[0]
                                              if FLAGS.pointer_gen else None)))
                    beam_search_str += search_str
                beam_search_str += "\n"

            # Extract the best hypothesis
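            # all_hyp is assumed to be returned best-first by beam search, so index 0 is the top-scoring hypothesis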
            best_hyp = all_hyp[0]

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))
            metadata.append(decoded_words)

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(
                    data.STOP_DECODING)  # index of the (first) [STOP] symbol
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words
            decoded_output = ' '.join(decoded_words)  # single string

            ###########################
            # print best-hypothesis statistics

            hyp_stat = ""

            # log prob
            hyp_stat += "\navg log prob: %s.\n" % best_hyp.avg_log_prob

            # words overlap with article: this is buggy
            #      tri, bi, uni = word_overlap.gram_search(ngrams(nltk.pos_tag(article_withunks.strip().split()), 3), ngrams(nltk.pos_tag(decoded_output.strip().split()), 3))
            #      hyp_stat += "trigram overlap: %s. bigram overlap: %s. unigram overlap: %s.\n"%(uni, bi, tri)

            print_statistics.get_overlap(article_withunks.strip(),
                                         decoded_output.strip(),
                                         match_count=self.overlap_dict)
            hyp_stat += "word overlap: "
            for key, value in self.overlap_dict.items():
                hyp_stat += "\n%d-gram avg overlap: %d" % (key, value /
                                                           (counter + 1))

            # num sentences and avg length
            self.total_nsentence += len(decoded_output.strip().split(
                "."))  # sentences are separated by "."
            self.total_length += len(decoded_output.strip().split())
            avg_nsentence, avg_length = self.total_nsentence / (
                counter + 1), self.total_length / (counter + 1)

            hyp_stat += "\nnum sentences: %s. avg len: %s.\n" % (avg_nsentence,
                                                                 avg_length)

            # entropy??
            if FLAGS.print_info:
                print(hyp_stat)
            ###########################

            # saves data into numpy files for analysis
            if FLAGS.save_values:
                save_decode_data.save_data_iteration(self._decode_dir, counter,
                                                     best_hyp)

            if FLAGS.single_pass:  #change to counter later
                self.write_for_rouge(
                    original_abstract_sents, decoded_words, counter
                )  # write ref summary and decoded summary to file, to eval with pyrouge later
                # writing all the output combined to a file
                if FLAGS.print_info:
                    output = '\nARTICLE:  %s\n REFERENCE SUMMARY: %s\n' 'GENERATED SUMMARY: %s\n' % (
                        article_withunks, abstract_withunks, decoded_output)
                    print(output)
                    output_str += output
                # Leena: modifying this to save more stuff
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words,
                    best_hyp.attn_dists, best_hyp.p_gens, counter,
                    best_hyp.log_prob, best_hyp.avg_log_prob,
                    best_hyp.average_pgen)  #change to counter later
                counter += 1  # this is how many examples we've decoded

            else:  # Leena: I mainly use the branch above, so this branch may not be fully up to date
                print_results(article_withunks, abstract_withunks,
                              decoded_output)  # log output to screen
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words,
                    best_hyp.attn_dists, best_hyp.p_gens,
                    counter)  # write info to .json file for visualization tool

                # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
                t1 = time.time()
                if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                    tf.logging.info(
                        'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                        t1 - t0)
                    _ = util.load_ckpt(self._saver, self._sess)
                    t0 = time.time()

            idx += 1

        # Leena: save the full decoded output and beam-search output strings to a file
        if FLAGS.save_values:
            save_decode_data.save_data_once(self._decode_dir,
                                            FLAGS.result_path, output_str,
                                            beam_search_str, metadata)
Example #26
0
def decode(test_path, rl):
    sess = tf.Session(config=get_config())
    if FLAGS.beam == True:
        FLAGS.batch_size = FLAGS.beam_size
    FLAGS.max_dec_steps = 1
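    # max_dec_steps is set to 1 because decoding is assumed to run one token step at a time during greedy/beam search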
    print('batch size ', FLAGS.batch_size)
    #if rl == False:
    summarizationModel = PointerNet(FLAGS, vocab)
    #elif rl==True:
    #    if FLAGS.gamma > 0:
    #        import rl_model_gamma
    #        summarizationModel = rl_model_gamma.RLNet(FLAGS, vocab)
    #    else:
    #        import rl_model
    #        summarizationModel = rl_model.RLNet(FLAGS, vocab)
    summarizationModel.build_graph()
    saver = tf.train.Saver()
    best_model = load_best_model(FLAGS.restore_path)
    print('best model : {0}'.format(best_model))
    saver.restore(sess, save_path=best_model)
    counter = 0
    batcher = Batcher(test_path,
                      vocab,
                      FLAGS,
                      single_pass=FLAGS.single_pass,
                      decode_after=FLAGS.decode_after)
    batches = batcher.fill_batch_queue(
        is_training=False)  # 1 example repeated across batch
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    if FLAGS.beam == False:
        for batch in batches:
            article = batch.original_articles[0]
            original_abstract_sents = batch.original_abstracts_sents  # list of strings
            #print('*****************start**************')
            best_hyps = beam_search.run_greedy_search(sess, summarizationModel,
                                                      vocab, batch)
            output_ids = [[int(t) for t in best_hyp.tokens[1:]]
                          for best_hyp in best_hyps]
            decoded_words = data.outputids2words_greedy(
                output_ids, vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))
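            # remove_stop_index is assumed to truncate each hypothesis at the first [STOP] token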
            decoded_words = remove_stop_index(decoded_words, data)
            write_for_rouge_greedy(
                original_abstract_sents, decoded_words, article, counter,
                FLAGS.dec_path, FLAGS.ref_path, FLAGS.all_path
            )  # write ref summary and decoded summary to file, to eval with pyrouge later
            counter += FLAGS.batch_size  # this is how many examples we've decoded
            print('counter ... ', counter)
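            # log a timestamp every 5 batches (assuming the batch size is 64)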
            if counter % (5 * 64) == 0:
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    else:
        for batch in batches:
            article = batch.original_articles[0]
            original_abstract_sents = batch.original_abstracts_sents[
                0]  # list of strings
            #print('*****************start**************')
            best_hyps = beam_search.run_beam_search(sess, summarizationModel,
                                                    vocab, batch)
            #print('best hyp : {0}'.format(best_hyp))
            output_ids = [int(t) for t in best_hyps.tokens[1:]]
            decoded_words = data.outputids2words_beam(
                output_ids, vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))
            try:
                fst_stop_idx = decoded_words.index(
                    data.STOP_DECODING)  # index of the (first) [STOP] symbol
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words
            #decoded_words = ' '.join(decoded_words)
            write_for_rouge_beam(
                original_abstract_sents, decoded_words, article, counter,
                FLAGS.dec_path, FLAGS.ref_path, FLAGS.all_path
            )  # write ref summary and decoded summary to file, to eval with pyrouge later
            counter += 1  # this is how many examples we've decoded
            print('counter ... ', counter)
            if counter % 100 == 0:
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
Example #27
0
    def decode(self):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
        if not FLAGS.generate:
            t0 = time.time()
            counter = 0
            while True:
                batch = self._batcher.next_batch(
                )  # 1 example repeated across batch
                if batch is None:  # finished decoding dataset in single_pass mode
                    assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                    tf.logging.info(
                        "Decoder has finished reading dataset for single_pass."
                    )
                    tf.logging.info(
                        "Output has been saved in %s and %s. Now starting ROUGE eval...",
                        self._rouge_ref_dir, self._rouge_dec_dir)
                    results_dict = rouge_eval(self._rouge_ref_dir,
                                              self._rouge_dec_dir)
                    rouge_log(results_dict, self._decode_dir)
                    return

                original_article = batch.original_articles[0]  # string
                original_abstract = batch.original_abstracts[0]  # string
                original_abstract_sents = batch.original_abstracts_sents[
                    0]  # list of strings

                article_withunks = data.show_art_oovs(original_article,
                                                      self._vocab)  # string
                abstract_withunks = data.show_abs_oovs(
                    original_abstract, self._vocab,
                    (batch.art_oovs[0]
                     if FLAGS.pointer_gen else None))  # string

                # Run beam search to get best Hypothesis
                best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                       self._vocab, batch)

                # Extract the output ids from the hypothesis and convert back to words
                output_ids = [int(t) for t in best_hyp.tokens[1:]]
                decoded_words = data.outputids2words(
                    output_ids, self._vocab,
                    (batch.art_oovs[0] if FLAGS.pointer_gen else None))

                # Remove the [STOP] token from decoded_words, if necessary
                try:
                    fst_stop_idx = decoded_words.index(
                        data.STOP_DECODING
                    )  # index of the (first) [STOP] symbol
                    decoded_words = decoded_words[:fst_stop_idx]
                except ValueError:
                    decoded_words = decoded_words
                decoded_output = ' '.join(decoded_words)  # single string

                if FLAGS.single_pass:
                    # write ref summary and decoded summary to file, to eval with pyrouge later
                    self.write_for_rouge(original_abstract_sents,
                                         decoded_words, counter)
                    counter += 1  # this is how many examples we've decoded
                else:
                    print_results(article_withunks, abstract_withunks,
                                  decoded_output)  # log output to screen
                    self.write_for_attnvis(
                        article_withunks, abstract_withunks, decoded_words,
                        best_hyp.attn_dists, best_hyp.p_gens
                    )  # write info to .json file for visualization tool

                    # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
                    t1 = time.time()
                    if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                        tf.logging.info(
                            'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                            t1 - t0)
                        _ = util.load_ckpt(self._saver, self._sess)
                        t0 = time.time()
        # when generate=True
        else:
            counter = 0
            while True:
                batch = self._batcher.next_batch(
                )  # 1 example repeated across batch
                if batch is None:  # finished decoding dataset in single_pass mode
                    assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                    tf.logging.info(
                        "Decoder has finished reading dataset for single_pass."
                    )
                    return

                original_article = batch.original_articles[0]  # string
                # original_abstract = batch.original_abstracts[0]  # string
                # original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

                article_withunks = data.show_art_oovs(original_article,
                                                      self._vocab)  # string
                # abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

                # Run beam search to get best Hypothesis
                best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                       self._vocab, batch)

                # Extract the output ids from the hypothesis and convert back to words
                output_ids = [int(t) for t in best_hyp.tokens[1:]]
                decoded_words = data.outputids2words(
                    output_ids, self._vocab,
                    (batch.art_oovs[0] if FLAGS.pointer_gen else None))

                # Remove the [STOP] token from decoded_words, if necessary
                try:
                    fst_stop_idx = decoded_words.index(
                        data.STOP_DECODING
                    )  # index of the (first) [STOP] symbol
                    decoded_words = decoded_words[:fst_stop_idx]
                except ValueError:
                    decoded_words = decoded_words
                decoded_output = ' '.join(decoded_words)  # single string

                counter += 1
                # log output to screen
                print(
                    "---------------------------------------------------------------------------"
                )
                tf.logging.info('ARTICLE:  %s', article_withunks)
                tf.logging.info('GENERATED SUMMARY: %s', decoded_output)
                print(
                    "---------------------------------------------------------------------------"
                )

                # self.write_for_rouge(original_abstract_sents, decoded_words, counter)
                # Write to file
                decoded_sents = []
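                # split the decoded words into sentences at "." so each sentence is written on its own line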
                while len(decoded_words) > 0:
                    try:
                        fst_period_idx = decoded_words.index(".")
                    except ValueError:  # there is text remaining that doesn't end in "."
                        fst_period_idx = len(decoded_words)
                    sent = decoded_words[:fst_period_idx +
                                         1]  # sentence up to and including the period
                    decoded_words = decoded_words[fst_period_idx +
                                                  1:]  # everything else
                    decoded_sents.append(' '.join(sent))

                # pyrouge calls a perl script that puts the data into HTML files.
                # Therefore we need to make our output HTML safe.
                decoded_sents = [make_html_safe(w) for w in decoded_sents]

                # Write to file
                result_file = os.path.join(self._result_dir,
                                           "%06d_summary.txt" % counter)

                with open(result_file, "w") as f:
                    for idx, sent in enumerate(decoded_sents):
                        f.write(sent) if idx == len(
                            decoded_sents) - 1 else f.write(sent + "\n")
Example #28
0
    def decode(self):
        """
        Decode examples until data is exhausted (if FLAGS.single_pass) and return,
        or decode indefinitely, loading latest checkpoint at regular intervals
        """
        t0 = time.time()
        start_time = t0
        counter = 0
        while True:
            # 1 example repeated across batch
            batch = self._batcher.next_batch()
            if batch is None:  # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass, using %d seconds.",
                    time.time() - start_time)
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir, self._rouge_dec_dir)
                # todo: need to update for rouge
                # results_dict = rouge_eval(self._rouge_ref_dir,
                #                           self._rouge_dec_dir)
                # rouge_log(results_dict, self._decode_dir)
                return

            original_context = batch.original_contexts[0]  # string
            original_query = batch.original_querys[0]
            original_summarization = batch.original_summarizations[0]  # string
            # original_abstract_sents = batch.original_abstracts_sents[
            #    0]  # list of strings

            context_withunks = data.show_art_oovs(original_context,
                                                  self._vocab)
            abstract_withunks = data.show_abs_oovs(
                original_summarization, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            # Run beam search to get best Hypothesis
            best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                   self._vocab, batch)

            #  export_path = os.path.join(FLAGS.export_dir,str(FLAGS.export_version))
            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                # index of the (first) [STOP] symbol
                fst_stop_idx = decoded_words.index(data.MARK_EOS)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words
            decoded_output = ''.join(decoded_words)  # single string

            if FLAGS.single_pass:
                # todo: need to check
                # write ref summary and decoded summary to file, to eval with pyrouge later
                self.write_result(original_context, original_summarization,
                                  decoded_words, counter)
                # self.write_for_eval(original_summarization, output_ids,
                #                     counter)
                counter += 1  # this is how many examples we've decoded
            else:
                # log output to screen
                print_results(context_withunks, abstract_withunks,
                              decoded_output)
                # write info to .json file for visualization tool
                self.write_for_attnvis(context_withunks, abstract_withunks,
                                       decoded_words, best_hyp.attn_dists)

                # Check if SECS_UNTIL_NEW_CKPT has elapsed;
                # if so return so we can load a new checkpoint
                t1 = time.time()
                if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                    tf.logging.info(
                        'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                        t1 - t0)
                    _ = util.load_ckpt(self._saver, self._sess)
                    t0 = time.time()
Example #29
0
    def decode(self):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
        t0 = time.time()
        counter = 0
        while True:
            batch = self._batcher.next_batch(
            )  # 1 example repeated across batch
            if batch is None:  # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir, self._rouge_dec_dir)
                # results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
                # rouge_log(results_dict, self._decode_dir)
                return

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            original_abstract_sents = batch.original_abstracts_sents[
                0]  # list of strings

            original_topic = batch.original_topics[0]

            article_withunks = data.show_art_oovs(original_article,
                                                  self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            # Run beam search to get best Hypothesis
            all_sentences = []
            best_hyps = beam_search.run_beam_search(self._sess, self._model,
                                                    self._vocab, batch)
            put_ids = [int(t) for t in best_hyps[0].tokens[1:]]
            standard_words = data.outputids2words(
                put_ids, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))
            # Extract the output ids from the hypothesis and convert back to words
            score = []
            for best_hyp in best_hyps:
                output_ids = [int(t) for t in best_hyp.tokens[1:]]
                tmp_decoded = data.outputids2words(
                    output_ids, self._vocab,
                    (batch.art_oovs[0] if FLAGS.pointer_gen else None))
                all_sentences.extend(self.removes(tmp_decoded))
                all_sentences.append('\n@next\n')
                score.append(self.get_score(original_topic, tmp_decoded))
            all_sentences.extend(self.removes(standard_words))
            all_sentences.append('\n@next\n')
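            # pick the hypothesis whose decoded words score highest against the original topic (as measured by get_score)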
            max_index = score.index(max(score))
            put_ids = [int(t) for t in best_hyps[max_index].tokens[1:]]
            decoded_words = data.outputids2words(
                put_ids, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))
            all_sentences.extend(self.removes(decoded_words))
            decoded_words = all_sentences
            # Remove the [STOP] token from decoded_words, if necessary

            decoded_output = ' '.join(decoded_words)  # single string

            if FLAGS.single_pass:
                self.write_for_rouge(
                    original_abstract_sents, decoded_words, counter
                )  # write ref summary and decoded summary to file, to eval with pyrouge later
                counter += 1  # this is how many examples we've decoded
            else:
                print_results(article_withunks, abstract_withunks,
                              decoded_output)  # log output to screen
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words,
                    best_hyp.attn_dists, best_hyp.topic_attn_dists, best_hyp.
                    p_gens)  # write info to .json file for visualization tool

                # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
                t1 = time.time()
                if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                    tf.logging.info(
                        'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                        t1 - t0)
                    _ = util.load_ckpt(self._saver, self._sess)
                    t0 = time.time()
Example #30
0
    def bs_decode(self, batcher):
        """Decode examples until data is exhausted (if self._hps.single_pass) and
        return, or decode indefinitely, loading latest checkpoint at regular
        intervals"""
        # t0 = time.time()
        if self._hps.single_pass:
            ref_file = os.path.join(self._rouge_ref_dir, "reference.txt")
            decoded_file = os.path.join(self._rouge_dec_dir, "decoded.txt")
            overview_file = os.path.join(self._decode_dir, "overview.txt")
            ref_f = open(ref_file, "a", 'utf-8')
            dec_f = open(decoded_file, "a", 'utf-8')
            ove_f = open(overview_file, "a", 'utf-8')
        counter = 0
        try:
            while True:
                # 1 example repeated across batch
                batch = batcher.next_batch()
                if batch is None:
                    # finished decoding dataset in single_pass mode
                    assert self._hps.single_pass, (
                        "Dataset exhausted, but we are not in single_pass mode"
                    )
                    print(
                        "Decoder has finished reading dataset for single_pass."
                    )
                    if self._hps.single_pass:
                        ref_f.close()
                        dec_f.close()
                        ove_f.close()
                        return
                    # print(
                    #     "Output has been saved in %s and %s. \
                    #     Now starting ROUGE eval..." % (
                    #         self._rouge_ref_dir, self._rouge_dec_dir))
                    # results_dict = rouge_eval(
                    #     self._rouge_ref_dir, self._rouge_dec_dir)
                    # rouge_log(results_dict, self._decode_dir)

                _, _, best_hyps = beam_search.run_beam_search(
                    self._sess, self._model, self._vocab, batch)
                # is the beam_size here 1?
                outputs_ids = [[int(t) for t in hyp.tokens[1:]]
                               for hyp in best_hyps[0]]

                original_articles = batch.original_articles
                original_abstracts = batch.original_abstracts
                # original_abstract_sents = batch.original_abstracts_sents[0]
                # list of strings

                art_oovs = [
                    batch.art_oovs[i] for i in xrange(self._hps.batch_size)
                ]
                # articles_withunks = data.show_art_oovs(original_articles, self._vocab)
                # abstracts_withunks = data.show_abs_oovs(original_abstracts, self._vocab, art_oovs)

                # Run beam search to get best Hypothesis

                decoded_words_list = data.outputsids2words(
                    outputs_ids, self._vocab, art_oovs)
                # art_oovs[0] alone is not enough; all batch_size examples
                # should be included
                decoded_outputs = []

                # Remove the [STOP] token from decoded_words, if necessary
                for decoded_words in decoded_words_list:
                    try:
                        fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                        decoded_words = decoded_words[:fst_stop_idx]
                    except ValueError:
                        pass
                    decoded_outputs.append(' '.join(decoded_words))

                if self._hps.single_pass:
                    # write ref summary and decoded summary to file, to eval with
                    # pyrouge later
                    # self.write_for_discriminator(
                    #     original_articles, original_abstracts, decoded_outputs)
                    counter += 1  # this is how many examples we've decoded
                    if counter % 10000 == 0:
                        print("Have decoded %s samples." %
                              (counter * FLAGS.batch_size))

                    for sent in original_abstracts:
                        ref_f.write(sent + "\n")
                    for sent in decoded_outputs:
                        dec_f.write(sent + "\n")
                    for artc, refe, hypo in zip(original_articles,
                                                original_abstracts,
                                                decoded_outputs):
                        ove_f.write("article: " + artc + "\n")
                        ove_f.write("reference: " + refe + "\n")
                        ove_f.write("hypothesis: " + hypo + "\n")
                        ove_f.write("\n")
                # else:
                #     print_results(articles_withunks, abstracts_withunks, decoded_outputs)
                #     # log output to screen
                #     self.write_for_attnvis(articles_withunks, abstracts_withunks,
                #                            decoded_words, best_hyps.attn_dists, best_hyps.p_gens)
                # write info to .json file for visualization tool

                # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we
                # can load a new checkpoint
                # t1 = time.time()
                # if t1-t0 > SECS_UNTIL_NEW_CKPT:
                #     tf.logging.info(
                #         'We\'ve been decoding with same checkpoint for %i \
                #         seconds. Time to load new checkpoint',
                #         t1-t0)
                #     _ = gen_utils.load_ckpt(self._saver, self._sess) # NOQA
                #     t0 = time.time()
        except KeyboardInterrupt as exc:
            print(exc)
            print("Have decoded %s samples." % (counter * FLAGS.batch_size))
            ref_f.close()
            dec_f.close()
            ove_f.close()
Example #31
0
    def decode(self):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
        self._sess.run(tf.assign(self._model.is_train, tf.constant(False, tf.bool)))
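        # switch the model to inference mode (is_train presumably gates dropout and other training-only behaviour)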

        gts, res, weights = {}, {}, {}
        examples = []
        # gts is the references dict, res holds the generated results, weights is the list of reference scores.
        for id, batch in tqdm(enumerate(self._batcher), desc='test'):  # 1 example repeated across batch

            original_query = batch.original_query
            original_description = batch.original_description  # string
            original_responses = batch.original_responses  # string

            # Run beam search to get best Hypothesis
            hyps = beam_search.run_beam_search(self._args, self._sess, self._model, self._vocab, batch)

            # Extract the output ids from the hypothesis and convert back to words
            result = []
            count = 0
            for hyp in hyps:
                output_ids = [int(t) for t in hyp.tokens[1:]]
                decoded_words = vocabulary.outputids2words(output_ids, self._vocab,
                                                           (batch.art_oovs[0] if self._args.pointer_gen else None))

                # Remove the [STOP] token from decoded_words, if necessary
                try:
                    fst_stop_idx = decoded_words.index(vocabulary.STOP_DECODING)  # index of the (first) [STOP] symbol
                    decoded_words = decoded_words[:fst_stop_idx]
                except ValueError:
                    decoded_words = decoded_words
                decoded_output = ' '.join(decoded_words)  # single string

                result.append(decoded_output)

            try:
                selected_response = result[0]
                selected_response = vocabulary.response2keywords(selected_response, self._vocab)
                selected_response = ' '.join(selected_response)
            except Exception:
                selected_response = ""

            #gts[id] = original_responses
            #res[id] = [selected_response]
            #weights[id]= original_scores


            # write results to file.
            example = {
                'query': original_query,
                'description': original_description,
                'responses': original_responses,
                'generate': result,
                'select_cmt': selected_response,
            }
            examples.append(example)

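            # stop after roughly 200 batches for a quick qualitative check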
            if id >= 200:
                break

        #self.evaluate(gts, res, weights)
        result_file = os.path.join(self._decode_dir, 'results.json')
        with open(result_file, 'w', encoding='utf8') as p:
            json.dump(examples, p, indent=2, ensure_ascii=False)