Beispiel #1
0
  def process_one_article(self, original_article_sents, original_abstract_sents,
                          original_selected_ids, output_ids, oovs,
                          attn_dists, p_gens, log_probs, counter):
    """Turn one decoded id sequence into text and write the enabled outputs.

    The ids are mapped back to words, everything from the first [STOP] token
    onwards is dropped, and the words are grouped into sentences. Nothing is
    written unless FLAGS.single_pass is set.

    Args:
      original_article_sents: article sentences; joined with spaces for the
        attention visualization and passed as-is to save_result.
      original_abstract_sents: reference abstract sentences.
      original_selected_ids: forwarded unchanged to save_result.
      output_ids: decoder output ids to convert to words.
      oovs: forwarded to the id-to-word and abstract-OOV helpers.
      attn_dists, p_gens, log_probs: forwarded unchanged to write_for_attnvis.
      counter: forwarded to every writer (presumably the example index --
        confirm against the callers).
    """
    words = data.outputids2words(output_ids, self._vocab, oovs)
    # Keep only the words that precede the first [STOP] symbol, if there is one.
    if data.STOP_DECODING in words:
      words = words[:words.index(data.STOP_DECODING)]
    decoded_output = ' '.join(words)  # single-string form (not used below)
    sentences = data.words2sents(words)

    if not FLAGS.single_pass:
      return

    # Every mode except 'eval' writes verbosely.
    verbose = FLAGS.mode != 'eval'

    # Reference and decoded summaries go to file for a later pyrouge run.
    self.write_for_rouge(original_abstract_sents, sentences, counter, verbose)

    if FLAGS.decode_method == 'beam' and FLAGS.save_vis:
      article_text = ' '.join(original_article_sents)
      abstract_text = ' '.join(original_abstract_sents)
      article_withunks = data.show_art_oovs(article_text, self._vocab)
      abstract_withunks = data.show_abs_oovs(abstract_text, self._vocab, oovs)
      self.write_for_attnvis(article_withunks, abstract_withunks, words,
                             attn_dists, p_gens, log_probs, counter, verbose)

    if FLAGS.save_pkl:
      self.save_result(original_article_sents, original_abstract_sents,
                       original_selected_ids, sentences, counter, verbose)
Beispiel #2
0
    def decodeOneSample(self, batches):
        """Decode the first example of the first batch, then score and report it.

        Beam search is run on batches[0]; the best hypothesis (minus its first
        token) is converted to words and cut at the first [STOP] symbol. The
        result is written for ROUGE with index 0, rouge_eval() is invoked, the
        article/reference/decoded texts are printed, and the attention data is
        written via write_for_attnvis.

        Args:
            batches: indexable collection of batches; only element 0 is used.
        """
        first_batch = batches[0]
        article = first_batch.original_articles[0]
        abstract = first_batch.original_abstracts[0]
        abstract_sents = first_batch.original_abstracts_sents[0]

        # Versions of the texts with out-of-vocabulary words marked.
        article_withunks = data.show_art_oovs(article, self._vocab)
        abstract_withunks = data.show_abs_oovs(abstract, self._vocab,
                                               first_batch.art_oovs[0])

        hypothesis = beam_search.run_beam_search(
            self._session, self._model, self._vocab, first_batch, self._hps)

        # The first token of the hypothesis is skipped before conversion.
        token_ids = list(map(int, hypothesis.tokens[1:]))
        words = data.outputids2words(token_ids, self._vocab,
                                     first_batch.art_oovs[0])
        # Truncate at the first [STOP] symbol when one was produced.
        if data.STOP_DECODING in words:
            words = words[:words.index(data.STOP_DECODING)]

        summary = ' '.join(words)

        self.write_for_rouge(abstract_sents, words, 0, article)
        self.rouge_eval()
        print_results(article_withunks, abstract_withunks, summary)
        self.write_for_attnvis(article_withunks, abstract_withunks, words,
                               hypothesis.attn_dists, hypothesis.p_gens)
Beispiel #3
0
	def decode(self):
		"""Decode batches until the batcher is exhausted, or indefinitely.

		In single_pass mode each decoded example is written for ROUGE and its
		attention distributions, p_gens and decoded words are accumulated; when
		the batcher returns None the accumulated data is pickled to
		'attn_dist_p_gens_data.pkl' in the decode directory and get_metrics is
		run on the ROUGE directories. Outside single_pass mode each result is
		printed and written for the attention visualizer, and the checkpoint is
		reloaded once SECS_UNTIL_NEW_CKPT seconds have elapsed.
		"""
		t0 = time.time()
		counter = 0
		while True:
			batch = self._batcher.next_batch()  # 1 example repeated across batch
			if batch is None: # finished decoding dataset in single_pass mode
				# NOTE(review): all_attn_dists, all_pgens and dec_words are not
				# defined inside this method; presumably module-level lists -- confirm.
				d = [all_attn_dists,all_pgens,dec_words]
				output_fname = os.path.join(self._decode_dir, 'attn_dist_p_gens_data.pkl')
				# Use a context manager so the pickle file is flushed and closed
				# deterministically instead of leaking the handle.
				with open(output_fname, 'wb') as pkl_file:
					pickle.dump(d, pkl_file)
				assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
				tf.logging.info("Decoder has finished reading dataset for single_pass.")
				tf.logging.info("Output has been saved in %s and %s. Now starting eval...", self._rouge_ref_dir, self._rouge_dec_dir)
				# Compute metrics over the reference and decoded directories
				self.get_metrics(self._rouge_ref_dir,self._rouge_dec_dir)
				return

			original_article = batch.original_articles[0]  # string
			original_query = batch.original_queries[0]  # string
			original_abstract = batch.original_abstracts[0]  # string
			original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

			article_withunks = data.show_art_oovs(original_article, self._vocab) # string
			abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string

			# Run beam search to get best Hypothesis
			best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)

			# Extract the output ids from the hypothesis and convert back to words
			output_ids = [int(t) for t in best_hyp.tokens[1:]]
			decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))

			# Remove the [STOP] token (and anything after it) from decoded_words, if present
			try:
				fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
				decoded_words = decoded_words[:fst_stop_idx]
			except ValueError:
				pass # no [STOP] symbol: keep all words
			decoded_output = ' '.join(decoded_words) # single string

			if FLAGS.single_pass:
				self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
				counter += 1 # this is how many examples we've decoded
				all_attn_dists.append(best_hyp.attn_dists)
				all_pgens.append(best_hyp.p_gens)
				dec_words.append(decoded_words)
			else:
				print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen
				self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool

				# Check if SECS_UNTIL_NEW_CKPT has elapsed; if so reload the checkpoint in place and restart the timer
				t1 = time.time()
				if t1-t0 > SECS_UNTIL_NEW_CKPT:
					tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', t1-t0)
					_ = util.load_ckpt(self._saver, self._sess)
					t0 = time.time()
Beispiel #4
0
    def decode(self):
        """Decode every batch and score the decoded relations against references.

        For each batch, beam search yields the best hypothesis; its tokens
        (minus the first) are mapped back to words and cut at the first [STOP]
        symbol. The decoded string and the reference abstract are accumulated,
        the result is printed, and attention data is written via
        write_for_attnvis. When the batcher returns None, precision, recall and
        F1 over the accumulated pairs are logged and the method returns.
        """
        predicted = []
        references = []
        while True:
            # 1 example repeated across batch
            batch = self._batcher.next_batch()
            if batch is None:
                # Dataset exhausted: score everything gathered so far.
                precision, recall, f1 = calculate_measure_cmp(predicted,
                                                              references)
                tf.logging.info("p: %.4f, r: %.4f, f1: %.4f",
                                precision, recall, f1)

                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir, self._rouge_dec_dir)
                return

            article = batch.original_articles[0]
            abstract = batch.original_abstracts[0]

            # Versions of the texts with out-of-vocabulary words marked.
            article_withunks = data.show_art_oovs(article, self._vocab)
            abstract_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
            abstract_withunks = data.show_abs_oovs(abstract, self._vocab,
                                                   abstract_oovs)

            hypothesis = beam_search.run_beam_search(self._sess, self._model,
                                                     self._vocab, batch)

            # Skip the first token, then map the remaining ids back to words.
            token_ids = list(map(int, hypothesis.tokens[1:]))
            decode_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
            words = data.outputids2words_webnlg(token_ids, self._vocab,
                                                decode_oovs)

            # Drop the first [STOP] symbol and everything after it.
            if data.STOP_DECODING in words:
                words = words[:words.index(data.STOP_DECODING)]
            decoded_text = ' '.join(words)

            predicted.append(decoded_text)
            references.append(abstract)

            # Log to screen, then write the .json data for the visualizer.
            print_results(article_withunks, abstract_withunks, decoded_text)
            self.write_for_attnvis(article_withunks, abstract_withunks, words,
                                   hypothesis.attn_dists, hypothesis.p_gens)
  def decode(self):
    """Decode batches until the batcher is exhausted, or indefinitely.

    In single_pass mode (self._model.single_pass) each decoded example is
    written for ROUGE; when the batcher returns None, ROUGE is evaluated and
    logged and the method returns. Otherwise each result is printed and written
    for the attention visualizer, and the checkpoint is reloaded once
    SECS_UNTIL_NEW_CKPT seconds have elapsed. The wall-clock time of every
    decoded batch is printed.
    """
    t0 = time.time()
    counter = 0
    step = 0
    while True:
      step += 1
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None: # finished decoding dataset in single_pass mode
        assert self._model.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
        return

      t_start_decode = time.time()

      original_article = batch.original_articles[0]  # string
      original_abstract = batch.original_abstracts[0]  # string
      original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

      article_withunks = data.show_art_oovs(original_article, self._vocab) # string
      abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if self._model.pointer_gen else None)) # string

      # Let the model decode the batch and return its best hypothesis
      best_hyp = self._model.model_decode(batch, self._sess, self._vocab)
      print(best_hyp.tokens)

      # Extract the output ids from the hypothesis and convert back to words
      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if self._model.pointer_gen else None))

      # Remove the [STOP] token (and anything after it) from decoded_words, if present
      try:
        fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
        decoded_words = decoded_words[:fst_stop_idx]
      except ValueError:
        pass # no [STOP] symbol: keep all words
      decoded_output = ' '.join(decoded_words) # single string

      if self._model.single_pass:
        self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
        counter += 1 # this is how many examples we've decoded
      else:
        print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen
        self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool

        # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so reload the checkpoint in place and restart the timer
        t1 = time.time()
        if t1-t0 > SECS_UNTIL_NEW_CKPT:
          # print() does not interpolate %-placeholders from extra arguments,
          # so the message must be formatted explicitly.
          print('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint' % (t1-t0))
          _ = util.load_checkpoint(self._saver, self._sess)
          t0 = time.time()

      t_end_decode = time.time()
      print('decode {}-th batch requires {} seconds'.format(step, int(t_end_decode - t_start_decode)))
Beispiel #6
0
    def decode(self, ckpt_file=None):
        """Decode the whole dataset, writing one processed line per example.

        A checkpoint is restored first. Then, for every batch until the batcher
        returns None, beam search produces the best hypothesis, its tokens
        (minus the first) are mapped to words and cut at the first [STOP]
        symbol, and the output of depreciated_processing on those words is
        written as one line to the file at self._decode_dir. Progress is
        printed every 100 examples.

        Args:
            ckpt_file: forwarded to misc_utils.load_ckpt together with the
                checkpoint directory; defaults to None.
        """
        FLAGS = self._FLAGS

        # Restore model weights before decoding.
        misc_utils.load_ckpt(self._saver, self._sess, self._ckpt_dir,
                             ckpt_file)

        counter = 0
        # NOTE(review): self._decode_dir is opened as a file despite its name --
        # confirm it really holds a file path.
        # The context manager guarantees the output file is closed even when
        # decoding raises part-way through the dataset.
        with open(self._decode_dir, "w") as f:
            while True:
                batch = self._batcher.next_batch(
                )  # 1 example repeated across batch
                if batch is None:  # finished decoding dataset in single_pass mode
                    tf.logging.info(
                        "Decoder has finished reading dataset for single_pass.")
                    break

                # NOTE(review): the next five values are never used below; the
                # calls are kept in case the helpers have side effects -- confirm
                # and remove if they are pure.
                original_article = batch.original_articles[0]  # string
                original_abstract = batch.original_abstracts[0]  # string
                original_abstract_sents = batch.original_abstracts_sents[
                    0]  # list of strings

                article_withunks = data.show_art_oovs(original_article,
                                                      self._vocab)  # string
                abstract_withunks = data.show_abs_oovs(
                    original_abstract, self._vocab,
                    (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

                # Run beam search to get best Hypothesis
                best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                       self._vocab, batch, FLAGS)

                # Extract the output ids from the hypothesis and convert back to
                # words
                output_ids = [int(t) for t in best_hyp.tokens[1:]]
                decoded_words = data.outputids2words(
                    output_ids, self._vocab,
                    (batch.art_oovs[0] if FLAGS.pointer_gen else None))

                # Remove the [STOP] token (and anything after it), if present
                try:
                    # index of the (first) [STOP] symbol
                    fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                    decoded_words = decoded_words[:fst_stop_idx]
                except ValueError:
                    pass  # no [STOP] symbol: keep all words

                # One processed line per decoded example.
                processed = self.depreciated_processing(decoded_words)
                f.write(processed + "\n")

                counter += 1
                if counter % 100 == 0:
                    print("%d sentences decoded" % counter)
Beispiel #7
0
    def pair_wise_decode(self):
        """Decode all pairwise batches, write the outputs and print BLEU/accuracy.

        Each batch is evaluated with eval_one_batch. For every instance in the
        result (bounded by len(batch.art_oovs) and batch.real_length) the ids
        are mapped to words, cut at the first [STOP] symbol, and written to
        'output.txt' under FLAGS.data_path as "<article>\\t<output>". After the
        batcher returns None, BLEU and accuracy over all collected outputs are
        printed.
        """
        f = os.path.join(FLAGS.data_path, "output.txt")
        output_result = []
        list_of_reference = []
        # The context manager closes the output file on every exit path,
        # including exceptions raised while decoding.
        with codecs.open(f, "w", "utf8") as outputfile:
            while True:
                batch = self._batcher.next_pairwised_decode_batch(
                )  # 1 example repeated across batch
                if batch is None:  # finished decoding dataset in single_pass mode
                    logging.info("eval_finished")
                    break
                print(self._batcher.c_index)
                # NOTE(review): the next five values are never used below; the
                # calls are kept in case the helpers have side effects -- confirm
                # and remove if they are pure.
                original_article = batch.original_articles[0]  # string
                original_abstract = batch.original_abstracts[0]  # string
                original_abstract_sents = batch.original_abstracts_sents[
                    0]  # list of strings

                article_withunks = data.show_art_oovs(original_article,
                                                      self._vocab)  # string
                abstract_withunks = data.show_abs_oovs(
                    original_abstract, self._vocab,
                    (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

                # Evaluate the whole batch in one call
                result = self.eval_one_batch(self._sess, self._model,
                                             self._vocab, batch)

                for i, instance in enumerate(result):
                    # Stop once the result runs past the per-example OOV lists.
                    if i == len(batch.art_oovs):
                        break
                    # Stop at batch.real_length (presumably the count of real,
                    # non-padding examples -- confirm against the batcher).
                    if i >= batch.real_length:
                        print("eval done with {} isntances".format(
                            len(output_result)))
                        break
                    out_words = data.outputids2words(instance,
                                                     self._model._vocab_out,
                                                     batch.art_oovs[i])
                    # Drop the first [STOP] symbol and everything after it.
                    if data.STOP_DECODING in out_words:
                        out_words = out_words[:out_words.index(data.STOP_DECODING)]

                    output_now = " ".join(out_words)
                    output_result.append(output_now)

                    refer = batch.original_abstracts[i].strip()
                    list_of_reference.append([refer])

                    outputfile.write(batch.original_articles[i] + '\t' +
                                     output_now + '\n')

        bleu = matrix.bleu_score(list_of_reference, output_result)
        acc = matrix.compute_acc(list_of_reference, output_result)

        print("bleu : {}   acc : {}".format(bleu, acc))
        return
  def _decode(self):
    """Decode batches until the batcher is exhausted, or indefinitely.

    In onetime mode (FLAGS.onetime) each decoded example is written for ROUGE;
    when the batcher returns None, ROUGE is evaluated and logged and the method
    returns. Otherwise each result is printed and written for the attention
    visualizer, and once SECS_UNTIL_NEW_CKPT seconds have elapsed the latest
    checkpoint is restored, retrying every 10 seconds until it succeeds.
    """
    t0 = time.time()
    counter = 0
    while True:
      batch = self.batcher._next_batch()
      if batch is None:
        assert FLAGS.onetime, "Dataset exhausted, but we are not in onetime mode"
        print('INFO: Decoder has finished reading dataset for onetime.')
        print('INFO: Output has been saved in {} and {}, start ROUGE eval...'.format(self.rouge_ref_dir, self.rouge_dec_dir))
        results_dict = rouge_eval(self.rouge_ref_dir, self.rouge_dec_dir)
        rouge_log(results_dict, self.decode_dir)
        return

      original_article = batch.original_articles[0]
      original_abstract = batch.original_abstracts[0]

      # Versions of the texts with out-of-vocabulary words marked
      article_withunks = data.show_art_oovs(original_article, self.vocab)
      abstract_withunks = data.show_abs_oovs(original_abstract, self.vocab, (batch.art_oovs[0] if FLAGS.pointer else None))

      best_hyp = beam_search.run_beam_search(self.sess, self.model, self.vocab, batch)

      # Skip the first token, then map the remaining ids back to words
      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      decoded_words = data.outputids2words(output_ids, self.vocab, (batch.art_oovs[0] if FLAGS.pointer else None))

      # Drop the first end-of-decoding symbol and everything after it, if present
      try:
        fst_stop_idx = decoded_words.index(data.DECODING_END)
        decoded_words = decoded_words[:fst_stop_idx]
      except ValueError:
        pass  # no end symbol: keep all words
      decoded_output = ' '.join(decoded_words)

      if FLAGS.onetime:
        self._write_for_rouge(original_abstract, decoded_words, counter)
        counter += 1
      else:
        print("")
        print('INFO: ARTICLE: {}'.format(article_withunks))
        print('INFO: REFERENCE SUMMARY: {}'.format(abstract_withunks))
        print('INFO: GENERATED SUMMARY: {}'.format(decoded_output))
        print("")
        self._write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.pointers)

        t1 = time.time()
        if t1-t0 > SECS_UNTIL_NEW_CKPT:
          print('INFO: Decoding for {} seconds, loading new checkpoint'.format(t1-t0))
          while True:
            try:
              ckpt_state = tf.train.get_checkpoint_state(train_dir)
              print('INFO: Loading checkpoint {}'.format(ckpt_state.model_checkpoint_path))
              self.saver.restore(self.sess, ckpt_state.model_checkpoint_path)
              break
            except Exception:
              # Catch Exception rather than everything: a bare except would
              # also swallow KeyboardInterrupt/SystemExit and make this retry
              # loop impossible to interrupt.
              print('ERROR: Failed to restore checkpoint: {}, sleep for {} secs'.format(train_dir, 10))
              time.sleep(10)
          t0 = time.time()
Beispiel #9
0
  def decode(self):
    """Decode batches until the batcher is exhausted, or indefinitely.

    The default TensorFlow graph is reset at the start of every iteration. In
    single_pass mode each decoded example is written for ROUGE, numbered from
    FLAGS.decode_after; when the batcher returns None, ROUGE is evaluated and
    logged and the method returns. Otherwise each result is printed and written
    for the attention visualizer, and the checkpoint is reloaded once
    SECS_UNTIL_NEW_CKPT seconds have elapsed.
    """
    last_load_time = time.time()
    example_idx = FLAGS.decode_after
    while True:
      tf.reset_default_graph()
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None:
        # Dataset exhausted: only legal in single_pass mode.
        assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir), self._decode_dir)
        return

      article = batch.original_articles[0]
      abstract = batch.original_abstracts[0]
      abstract_sents = batch.original_abstracts_sents[0]

      # Versions of the texts with out-of-vocabulary words marked.
      article_withunks = data.show_art_oovs(article, self._vocab)
      abstract_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
      abstract_withunks = data.show_abs_oovs(abstract, self._vocab, abstract_oovs)

      # With ac_training the DQN objects are handed to beam search as well.
      if FLAGS.ac_training:
        hypothesis = beam_search.run_beam_search(
            self._sess, self._model, self._vocab, batch,
            self._dqn, self._dqn_sess, self._dqn_graph)
      else:
        hypothesis = beam_search.run_beam_search(
            self._sess, self._model, self._vocab, batch)

      # Skip the first token, then map the remaining ids back to words.
      token_ids = list(map(int, hypothesis.tokens[1:]))
      decode_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
      words = data.outputids2words(token_ids, self._vocab, decode_oovs)

      # Drop the first [STOP] symbol and everything after it.
      if data.STOP_DECODING in words:
        words = words[:words.index(data.STOP_DECODING)]
      decoded_text = ' '.join(words)

      if FLAGS.single_pass:
        # Reference and decoded summaries go to file for a later pyrouge run.
        self.write_for_rouge(abstract_sents, words, example_idx)
        example_idx += 1
        continue

      print_results(article_withunks, abstract_withunks, decoded_text)
      self.write_for_attnvis(article_withunks, abstract_withunks, words,
                             hypothesis.attn_dists, hypothesis.p_gens)

      # Reload the checkpoint once the current one has been used long enough.
      elapsed = time.time() - last_load_time
      if elapsed > SECS_UNTIL_NEW_CKPT:
        tf.logging.info("We've been decoding with same checkpoint for %i seconds. Time to load new checkpoint", elapsed)
        _ = util.load_ckpt(self._saver, self._sess, FLAGS.decode_from)
        last_load_time = time.time()
Beispiel #10
0
    def decode2(self, batcher):
        """Decode a single batch from `batcher` and return the decoded string.

        If the batcher is exhausted (next_batch() returns None), ROUGE is
        evaluated on the reference/decoded directories, logged, and None is
        returned. Otherwise beam search is run on the batch, the best
        hypothesis is converted to words and cut at the first [STOP] symbol,
        the result is printed, and the space-joined words are returned.

        Args:
            batcher: object whose next_batch() supplies the batch to decode.

        Returns:
            The decoded summary as a single string, or None when the batcher
            had no batch left.
        """
        batch = batcher.next_batch()  # 1 example repeated across batch
        if batch is None:
            # Dataset exhausted: only legal in single_pass mode.
            assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
            tf.logging.info(
                "Decoder has finished reading dataset for single_pass.")
            tf.logging.info(
                "Output has been saved in %s and %s. Now starting ROUGE eval...",
                self._rouge_ref_dir, self._rouge_dec_dir)
            rouge_log(rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir),
                      self._decode_dir)
            return

        article = batch.original_articles[0]
        abstract = batch.original_abstracts[0]
        abstract_sents = batch.original_abstracts_sents[0]  # not used below

        # Versions of the texts with out-of-vocabulary words marked.
        article_withunks = data.show_art_oovs(article, self._vocab)
        abstract_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
        abstract_withunks = data.show_abs_oovs(abstract, self._vocab,
                                               abstract_oovs)

        hypothesis = beam_search.run_beam_search(self._sess, self._model,
                                                 self._vocab, batch)

        # Skip the first token, then map the remaining ids back to words.
        token_ids = list(map(int, hypothesis.tokens[1:]))
        decode_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
        words = data.outputids2words(token_ids, self._vocab, decode_oovs)

        # Drop the first [STOP] symbol and everything after it.
        if data.STOP_DECODING in words:
            words = words[:words.index(data.STOP_DECODING)]
        decoded_text = ' '.join(words)

        print_results(article_withunks, abstract_withunks, decoded_text)
        return decoded_text
Beispiel #11
0
    def decode_one_batch(self, batch, withRouge=True):
        """Decode one batch with beam search and emit the result.

        The best hypothesis (minus its first token) is mapped back to words and
        cut at the first [STOP] symbol. In single_pass mode the result is
        handed to write_for_flink and self._counter is incremented; otherwise
        it is printed and written for the attention visualizer.

        Args:
            batch: batch to decode; only its first example's texts, uuid and
                OOV list are read here.
            withRouge: accepted but not used by this method.
        """
        article = batch.original_articles[0]
        abstract = batch.original_abstracts[0]
        abstract_sents = batch.original_abstracts_sents[0]
        uuid = batch.uuids[0]

        # Versions of the texts with out-of-vocabulary words marked.
        article_withunks = data.show_art_oovs(article, self._vocab)
        abstract_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
        abstract_withunks = data.show_abs_oovs(abstract, self._vocab,
                                               abstract_oovs)

        hypothesis = beam_search.run_beam_search(self._sess, self._model,
                                                 self._vocab, batch)

        # Skip the first token, then map the remaining ids back to words.
        token_ids = list(map(int, hypothesis.tokens[1:]))
        decode_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
        words = data.outputids2words(token_ids, self._vocab, decode_oovs)

        # Drop the first [STOP] symbol and everything after it.
        if data.STOP_DECODING in words:
            words = words[:words.index(data.STOP_DECODING)]
        decoded_text = ' '.join(words)

        if not FLAGS.single_pass:
            # Log to screen, then write the .json data for the visualizer.
            print_results(article_withunks, abstract_withunks, decoded_text)
            self.write_for_attnvis(article_withunks, abstract_withunks, words,
                                   hypothesis.attn_dists, hypothesis.p_gens)
            return

        self.write_for_flink(uuid, article, words, abstract_sents)
        self._counter += 1  # number of examples decoded so far
    def run_beam_decoder(self, batch):
        """Run beam search on `batch` and return the texts for its first example.

        The hypothesis tokens and the decoded words (before truncation) are
        printed. The words are then cut at the first [STOP] symbol.

        Args:
            batch: batch to decode; only its first example's texts and OOV
                list are read here.

        Returns:
            A tuple (article_withunks, abstract_withunks, decoded_output) where
            decoded_output is the space-joined decoded words.
        """
        article = batch.original_articles[0]
        abstract = batch.original_abstracts[0]

        # Versions of the texts with out-of-vocabulary words marked.
        article_withunks = data.show_art_oovs(article, self.vocab)
        abstract_withunks = data.show_abs_oovs(
            abstract, self.vocab,
            (batch.art_oovs[0] if self.hp.pointer_gen == True else None))

        hypothesis = beam_search.run_beam_search(self.sess, self.model,
                                                 self.vocab, batch)
        print(hypothesis.tokens)

        # Skip the first token, then map the remaining ids back to words.
        token_ids = list(map(int, hypothesis.tokens[1:]))
        words = data.outputids2words(
            token_ids, self.vocab,
            (batch.art_oovs[0] if self.hp.pointer_gen == True else None))
        print(words)

        # Drop the first [STOP] symbol and everything after it.
        if data.STOP_DECODING in words:
            words = words[:words.index(data.STOP_DECODING)]

        return article_withunks, abstract_withunks, ' '.join(words)
def run_flip(model, batcher, vocab):
  """Repeatedly decode batches with a freshly loaded checkpoint and log the results.

  Every iteration reloads a checkpoint, runs ``model.run_decode`` on the next
  batch and, for each example of the batch, logs the reference summary, the
  generated summary and a few statistics.  The loop never terminates on its
  own; nothing is returned and no checkpoint is written.

  Args:
    model: model object providing ``build_graph``, ``run_decode`` and ``_vocab``.
    batcher: object whose ``next_batch()`` returns the batch to decode.
    vocab: unused here; the vocabulary is read from ``model._vocab``.
  """
  model.build_graph() # build the graph
  saver = tf.train.Saver(max_to_keep=3) # only used to restore checkpoints in this function
  sess = tf.Session(config=util.get_config())

  while True:
    _ = util.load_ckpt(saver, sess) # load a new checkpoint
    batch = batcher.next_batch() # get the next batch

    # time one decode step on the batch
    t0 = time.time()
    results = model.run_decode(sess, batch)
    t1 = time.time()
    tf.logging.info('seconds for batch: %.2f', t1-t0)
    stop_id = model._vocab.word2id('[STOP]')

    for b in range(len(batch.original_abstracts)):
      original_abstract = batch.original_abstracts[b] # string

      # Position of the first [STOP] in the target row, or the row length if there is none.
      stop_positions = np.where(batch.target_batch[b,:] == stop_id)[0]
      if len(stop_positions):
        fst_stop_idx = stop_positions[0]
      else:
        fst_stop_idx = len(batch.target_batch[b,:])

      abstract_withunks = data.show_abs_oovs(original_abstract, model._vocab, None)  # string
      tf.logging.info('REFERENCE SUMMARY: %s', abstract_withunks)

      # The generated ids are cut to the length of the reference (up to its first [STOP]).
      output_ids = [int(x) for x in results['ids'][:, b, 0]]
      output_ids = output_ids[:fst_stop_idx]
      decoded_words = data.outputids2words(output_ids, model._vocab, None)
      decoded_output = ' '.join(decoded_words)
      tf.logging.info('GENERATED SUMMARY: %s', decoded_output)
      loss = results['loss']
      # max() raises ValueError on an empty list, which happens when the target
      # starts with [STOP]; report -1 in that case instead of crashing the loop.
      max_flip = max(output_ids) if output_ids else -1
      tf.logging.info('loss %.2f max flip %d dec %d target %d', loss, max_flip, batch.dec_batch.max(), batch.target_batch.max())
Beispiel #14
0
    def decode(self):
        """Beam-search decode the first example of the next batch and write it to stdout.

        The decoded words are cut before the first ``data.STOP_DECODING`` token
        (when present), joined with single spaces and written to ``sys.stdout``
        without a trailing newline.  Nothing is returned.
        """
        batch = self._batcher.next_batch()

        source_text = batch.original_articles[0]
        reference_text = batch.original_abstracts[0]

        # OOV-annotated views of the inputs; they are computed but not used further here.
        article_withunks = data.show_art_oovs(source_text, self._vocab)
        art_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
        abstract_withunks = data.show_abs_oovs(reference_text, self._vocab, art_oovs)

        # Best hypothesis from beam search; its first token is skipped when
        # converting the ids back to words.
        best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                               self._vocab, batch)
        token_ids = [int(tok) for tok in best_hyp.tokens[1:]]
        words = data.outputids2words(token_ids, self._vocab, art_oovs)

        # Drop the stop token and everything after it, if it was produced.
        if data.STOP_DECODING in words:
            words = words[:words.index(data.STOP_DECODING)]

        sys.stdout.write(' '.join(words))
Beispiel #15
0
    def decode(self, batches):
        """Beam-search decode at most 10000 batches, then run the ROUGE evaluation.

        For each batch the first example is decoded, the output is cut before
        the first ``data.STOP_DECODING`` token (when present) and handed to
        ``self.write_for_rouge`` together with the reference sentences, the
        running index and the original article.  ``self.rouge_eval()`` is
        called once the loop is over.
        """
        max_examples = 10000
        for counter, batch in enumerate(batches):
            if counter >= max_examples:
                break

            source_text = batch.original_articles[0]
            reference_text = batch.original_abstracts[0]
            reference_sents = batch.original_abstracts_sents[0]

            # OOV-annotated views of the inputs; computed but not used further here.
            article_withunks = data.show_art_oovs(source_text, self._vocab)
            art_oovs = batch.art_oovs[0]
            abstract_withunks = data.show_abs_oovs(reference_text, self._vocab,
                                                   art_oovs)

            best_hypothesis = beam_search.run_beam_search(
                self._session, self._model, self._vocab, batch, self._hps)

            # The first token of the hypothesis is skipped.
            token_ids = [int(tok) for tok in best_hypothesis.tokens[1:]]
            words = data.outputids2words(token_ids, self._vocab, art_oovs)
            if data.STOP_DECODING in words:
                # keep only the words in front of the first stop token
                words = words[:words.index(data.STOP_DECODING)]

            decoded_output = ' '.join(words)  # single string, not used further

            self.write_for_rouge(reference_sents, words, counter, source_text)

        self.rouge_eval()
Beispiel #16
0
    def decode(self):
        """Run beam-search decoding over the batches delivered by the batcher.

        With ``FLAGS.generate`` set, no reference abstract is used: every
        summary is logged, split into sentences at "." tokens, made HTML safe
        and written to ``<result_dir>/NNNNNN_summary.txt`` (numbering starts at
        1).  The method returns once the batcher yields ``None``.

        Otherwise each summary is either written for ROUGE
        (``FLAGS.single_pass``) or logged and dumped for the attention
        visualizer, reloading the checkpoint whenever more than
        ``SECS_UNTIL_NEW_CKPT`` seconds have passed.  When the batcher yields
        ``None`` the ROUGE evaluation is run and the method returns.
        """
        if FLAGS.generate:
            counter = 0
            while True:
                batch = self._batcher.next_batch()
                if batch is None:  # dataset exhausted (only legal in single_pass mode)
                    assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                    tf.logging.info(
                        "Decoder has finished reading dataset for single_pass.")
                    return

                source_text = batch.original_articles[0]
                article_withunks = data.show_art_oovs(source_text, self._vocab)

                best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                       self._vocab, batch)

                # Convert the hypothesis (minus its first token) back to words.
                token_ids = [int(tok) for tok in best_hyp.tokens[1:]]
                art_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
                words = data.outputids2words(token_ids, self._vocab, art_oovs)

                # Cut at the first stop token, if one was produced.
                if data.STOP_DECODING in words:
                    words = words[:words.index(data.STOP_DECODING)]
                decoded_output = ' '.join(words)

                counter += 1
                rule = "---------------------------------------------------------------------------"
                print(rule)
                tf.logging.info('ARTICLE:  %s', article_withunks)
                tf.logging.info('GENERATED SUMMARY: %s', decoded_output)
                print(rule)

                # Split the word list into sentences; each sentence runs up to
                # and including a "." token, a trailing remainder forms its own
                # sentence.
                sentences = []
                while words:
                    cut = words.index(".") if "." in words else len(words)
                    sentences.append(' '.join(words[:cut + 1]))
                    words = words[cut + 1:]

                # The output is made HTML safe, as is done for pyrouge input.
                sentences = [make_html_safe(sent) for sent in sentences]

                # One sentence per line, no newline after the last one.
                result_file = os.path.join(self._result_dir,
                                           "%06d_summary.txt" % counter)
                with open(result_file, "w") as f:
                    f.write("\n".join(sentences))
        else:
            t0 = time.time()
            counter = 0
            while True:
                batch = self._batcher.next_batch()
                if batch is None:  # dataset exhausted (only legal in single_pass mode)
                    assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                    tf.logging.info(
                        "Decoder has finished reading dataset for single_pass.")
                    tf.logging.info(
                        "Output has been saved in %s and %s. Now starting ROUGE eval...",
                        self._rouge_ref_dir, self._rouge_dec_dir)
                    results_dict = rouge_eval(self._rouge_ref_dir,
                                              self._rouge_dec_dir)
                    rouge_log(results_dict, self._decode_dir)
                    return

                source_text = batch.original_articles[0]
                reference_text = batch.original_abstracts[0]
                reference_sents = batch.original_abstracts_sents[0]

                article_withunks = data.show_art_oovs(source_text, self._vocab)
                art_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
                abstract_withunks = data.show_abs_oovs(reference_text,
                                                       self._vocab, art_oovs)

                best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                       self._vocab, batch)

                # Convert the hypothesis (minus its first token) back to words.
                token_ids = [int(tok) for tok in best_hyp.tokens[1:]]
                words = data.outputids2words(token_ids, self._vocab, art_oovs)

                # Cut at the first stop token, if one was produced.
                if data.STOP_DECODING in words:
                    words = words[:words.index(data.STOP_DECODING)]
                decoded_output = ' '.join(words)

                if FLAGS.single_pass:
                    # reference and decoded summary go to files for pyrouge
                    self.write_for_rouge(reference_sents, words, counter)
                    counter += 1  # number of examples decoded so far
                    continue

                print_results(article_withunks, abstract_withunks,
                              decoded_output)  # log output to screen
                # write info to .json file for visualization tool
                self.write_for_attnvis(article_withunks, abstract_withunks,
                                       words, best_hyp.attn_dists,
                                       best_hyp.p_gens)

                # Reload the checkpoint once SECS_UNTIL_NEW_CKPT has elapsed.
                t1 = time.time()
                if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                    tf.logging.info(
                        "We've been decoding with same checkpoint for %i seconds. Time to load new checkpoint",
                        t1 - t0)
                    _ = util.load_ckpt(self._saver, self._sess)
                    t0 = time.time()
Beispiel #17
0
    def decode(self):
        """Beam-search decode batches and optionally score them with ROUGE and BLEU.

        Batches are decoded until the batcher yields ``None``.  In
        ``FLAGS.single_pass`` mode each result is logged and, depending on
        ``FLAGS.decode_rouge`` / ``FLAGS.decode_bleu``, written out for the
        respective metric; otherwise it is logged, dumped for the attention
        visualizer and the checkpoint is reloaded whenever more than
        ``SECS_UNTIL_NEW_CKPT`` seconds have passed.  When the data is
        exhausted the enabled metrics are computed and logged.
        """
        t0 = time.time()
        counter = 0
        if FLAGS.decode_bleu:
            # Remove BLEU files that already exist before decoding starts.
            ref_file = os.path.join(self._bleu_dec_dir, "reference.txt")
            decoded_file = os.path.join(self._bleu_dec_dir, "decoded.txt")
            for stale_path in (decoded_file, ref_file):
                if os.path.exists(stale_path):
                    tf.logging.info('正在删除 %s', stale_path)
                    os.remove(stale_path)

        while True:
            batch = self._batcher.next_batch()
            if batch is None:  # dataset exhausted (only legal in single_pass mode)
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass.")

                if FLAGS.decode_rouge:
                    tf.logging.info(
                        "Output has been saved in %s and %s. Now starting ROUGE eval...",
                        self._rouge_ref_dir, self._rouge_dec_dir)
                    try:
                        rouge_start = time.time()
                        results_dict = rouge_eval(self._rouge_ref_dir,
                                                  self._rouge_dec_dir)
                        rouge_log(results_dict, self._decode_dir)
                        rouge_secs = time.time() - rouge_start
                        tf.logging.info(
                            'calculate Rouge score cost %d seconds', rouge_secs)
                    except Exception as e:
                        # A ROUGE failure is logged and does not abort the BLEU step below.
                        tf.logging.error('计算ROUGE出错 %s', e)

                if FLAGS.decode_bleu:
                    ref_file = os.path.join(self._bleu_dec_dir,
                                            "reference.txt")
                    decoded_file = os.path.join(self._bleu_dec_dir,
                                                "decoded.txt")

                    bleu_start = time.time()
                    bleu, bleu1, bleu2, bleu3, bleu4 = calcu_bleu(
                        decoded_file, ref_file)
                    sys_bleu = sys_bleu_file(decoded_file, ref_file)
                    sys_bleu_perl = sys_bleu_perl_file(decoded_file, ref_file)
                    bleu_secs = time.time() - bleu_start

                    banner = (bcolors.HEADER +
                              '-----------BLEU SCORE-----------' +
                              bcolors.ENDC)
                    tf.logging.info(banner)
                    tf.logging.info(
                        bcolors.OKGREEN + '%f \t %f \t %f \t %f \t %f' +
                        bcolors.ENDC, bleu, bleu1, bleu2, bleu3, bleu4)
                    tf.logging.info(
                        bcolors.OKGREEN + 'sys_bleu %f' + bcolors.ENDC,
                        sys_bleu)
                    tf.logging.info(
                        bcolors.OKGREEN + 'sys_bleu_perl %s' + bcolors.ENDC,
                        sys_bleu_perl)
                    tf.logging.info(banner)
                    tf.logging.info('calculate BLEU score cost %d seconds',
                                    bleu_secs)
                break

            source_text = batch.original_articles[0]
            reference_text = batch.original_abstracts[0]
            reference_sents = batch.original_abstracts_sents[0]

            article_withunks = data.show_art_oovs(source_text, self._vocab)
            art_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
            abstract_withunks = data.show_abs_oovs(reference_text, self._vocab,
                                                   art_oovs)

            best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                   self._vocab, batch)

            # Convert the hypothesis (minus its first token) back to words.
            token_ids = [int(tok) for tok in best_hyp.tokens[1:]]
            words = data.outputids2words(token_ids, self._vocab, art_oovs)

            # Cut at the first stop token, if one was produced.
            if data.STOP_DECODING in words:
                words = words[:words.index(data.STOP_DECODING)]
            decoded_output = ''.join(words)  # joined without separators

            if FLAGS.single_pass:
                print_results(article_withunks, abstract_withunks,
                              decoded_output, counter)  # log output to screen
                if FLAGS.decode_rouge:
                    # reference and decoded summary go to files for pyrouge
                    self.write_for_rouge(reference_sents, words, counter)
                if FLAGS.decode_bleu:
                    self.write_for_bleu(reference_sents, words)
                counter += 1  # number of examples decoded so far
                continue

            print_results(article_withunks, abstract_withunks,
                          decoded_output)  # log output to screen
            # write info to .json file for visualization tool
            self.write_for_attnvis(article_withunks, abstract_withunks, words,
                                   best_hyp.attn_dists, best_hyp.p_gens)

            # Reload the checkpoint once SECS_UNTIL_NEW_CKPT has elapsed.
            t1 = time.time()
            if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                tf.logging.info(
                    "We've been decoding with same checkpoint for %i seconds. Time to load new checkpoint",
                    t1 - t0)
                _ = util.load_ckpt(self._saver, self._sess)
                t0 = time.time()
Beispiel #18
0
  def decode(self):
    """Decode a single batch with beam search, clean up the output and emit it.

    The next batch is fetched once (there is no loop).  If it is ``None`` the
    ROUGE evaluation is run and the method returns.  Otherwise the first
    example is decoded, the decoded words are joined and passed through a
    series of character-stripping substitutions, and the cleaned string is
    printed.  In ``FLAGS.single_pass`` mode the string is handed to
    ``self.write_for_rouge``; otherwise it is logged, dumped for the attention
    visualizer and the checkpoint is reloaded if more than
    ``SECS_UNTIL_NEW_CKPT`` seconds have passed since the call started.
    """
    t0 = time.time()
    counter = 0
    batch = self._batcher.next_batch()  # 1 example repeated across batch
    if batch is None: # finished decoding dataset in single_pass mode
      assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
      tf.logging.info("Decoder has finished reading dataset for single_pass.")
      tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
      results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
      rouge_log(results_dict, self._decode_dir)
      return

    original_article = batch.original_articles[0]  # string
    original_abstract = batch.original_abstracts[0]  # string
    original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

    article_withunks = data.show_art_oovs(original_article, self._vocab) # string
    abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string
    article_withunks = re.sub(" ","",article_withunks)  # drop all spaces from the article
    # Run beam search to get best Hypothesis
    best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)

    # Extract the output ids from the hypothesis and convert back to words
    output_ids = [int(t) for t in best_hyp.tokens[1:]]
    decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))

    # Remove the [STOP] token from decoded_words, if necessary
    if data.STOP_DECODING in decoded_words:
      decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
    decoded_output = ' '.join([str(w) for w in decoded_words]) # single string

    # NOTE(review): the bracketed patterns below are regex character classes, so
    # they delete every single occurrence of the listed characters (e.g. each
    # "U", "N", "K", "+", "(", ")" and space), not whole tokens such as "[UNK]".
    # Kept as-is; confirm whether token-level removal was intended.
    # Raw strings are used so that "\." / "\[" / "\ " are not invalid string escapes.
    decoded_output = re.sub(r"[UNK+(+)+\ ]","",decoded_output)
    decoded_output = re.sub(r"\.","",decoded_output)
    if "2002" not in decoded_output:
      print("?")
      decoded_output = re.sub("2","",decoded_output)
    decoded_output = re.sub(r"\[","",decoded_output)
    decoded_output = re.sub(r"\]","",decoded_output)
    decoded_output = re.sub("[养生+_社会_+社会频道+光明网+(+)+【+】+组图+万象+时尚]","",decoded_output)
    # Strip, from both ends, every character equal to one of the last two
    # characters.  Indexing [-1] / [-2] raises IndexError on a string shorter
    # than two characters, so such strings are left untouched.
    if len(decoded_output) >= 2:
      decoded_output = decoded_output.strip(decoded_output[-1]+decoded_output[-2])
    decoded_words = decoded_output  # from here on a string, not a list of words
    print(decoded_output)

    if FLAGS.single_pass:
      self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
      counter += 1 # this is how many examples we've decoded
    else:
      print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen
      self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool

      # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so load a new checkpoint
      t1 = time.time()
      if t1-t0 > SECS_UNTIL_NEW_CKPT:
        tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', t1-t0)
        _ = util.load_ckpt(self._saver, self._sess)
        t0 = time.time()
    def decode(self, output_dir=None):
        """Beam-search decode batches, collect statistics and write the results.

        Batches are decoded until the batcher yields ``None``; at that point
        the ROUGE evaluation is run and, if ``FLAGS.save_values`` is set, the
        accumulated output text, beam-search dump and per-example decoded
        words are passed to ``save_decode_data.save_data_once``.

        For every batch the best hypothesis (``all_hyp[0]``) is converted to
        words, running overlap / length statistics are updated on ``self``,
        and the result is written for ROUGE and the attention visualizer
        (``FLAGS.single_pass``) or logged and visualized with periodic
        checkpoint reloads (otherwise).

        Args:
            output_dir: not used by this method.
        """
        t0 = time.time()
        counter = 0
        idx = 0

        # pieces collected during decoding; joined once at the very end
        output_parts = []
        beam_search_parts = []
        metadata = []

        # evaluate over the whole data set
        while True:

            print("[%d]" % idx)

            batch = self._batcher.next_batch(
            )  # 1 example repeated across batch

            if batch is None:  # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir, self._rouge_dec_dir)
                results_dict = rouge_eval(self._rouge_ref_dir,
                                          self._rouge_dec_dir)
                rouge_log(results_dict, self._decode_dir)
                # Leave the loop instead of returning: a `return` here made the
                # final save below unreachable.
                break

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            original_abstract_sents = batch.original_abstracts_sents[
                0]  # list of strings

            article_withunks = data.show_art_oovs(original_article,
                                                  self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            # Run beam search to get all the hypotheses
            all_hyp = beam_search.run_beam_search(
                self._sess, self._model, self._vocab, batch, counter,
                self._lm_model, self._lm_word2idx, self._lm_idx2word
            )  #TODO changed the method signature just to look at the outputs of beam search

            if FLAGS.save_values:
                # dump the words of every hypothesis, one line per example
                for h in all_hyp:
                    output_ids = [int(t) for t in h.tokens[1:]]
                    beam_search_parts.append(str(
                        data.outputids2words(output_ids, self._vocab,
                                             (batch.art_oovs[0]
                                              if FLAGS.pointer_gen else None))))
                beam_search_parts.append("\n")

            # The first hypothesis is used as the best one
            best_hyp = all_hyp[0]

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))
            metadata.append(decoded_words)  # stored before the [STOP] cut below

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(
                    data.STOP_DECODING)  # index of the (first) [STOP] symbol
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] produced; keep all words
            decoded_output = ' '.join(decoded_words)  # single string

            ###########################
            # statistics of the best hypothesis

            hyp_stat = ""

            # log prob
            hyp_stat += "\navg log prob: %s.\n" % best_hyp.avg_log_prob

            print_statistics.get_overlap(article_withunks.strip(),
                                         decoded_output.strip(),
                                         match_count=self.overlap_dict)
            hyp_stat += "word overlap: "
            # items() instead of iteritems(): the latter does not exist on Python 3
            for key, value in self.overlap_dict.items():
                hyp_stat += "\n%d-gram avg overlap: %d" % (key, value /
                                                           (counter + 1))

            # num sentences and avg length
            self.total_nsentence += len(decoded_output.strip().split(
                "."))  # sentences are separated by "."
            self.total_length += len(decoded_output.strip().split())
            avg_nsentence, avg_length = self.total_nsentence / (
                counter + 1), self.total_length / (counter + 1)

            hyp_stat += "\nnum sentences: %s. avg len: %s.\n" % (avg_nsentence,
                                                                 avg_length)

            if FLAGS.print_info:
                print(hyp_stat)
            ###########################

            # saves data into numpy files for analysis
            if FLAGS.save_values:
                save_decode_data.save_data_iteration(self._decode_dir, counter,
                                                     best_hyp)

            if FLAGS.single_pass:
                self.write_for_rouge(
                    original_abstract_sents, decoded_words, counter
                )  # write ref summary and decoded summary to file, to eval with pyrouge later
                # collect all the output combined, to be written to a file at the end
                if FLAGS.print_info:
                    output = '\nARTICLE:  %s\n REFERENCE SUMMARY: %s\n' 'GENERATED SUMMARY: %s\n' % (
                        article_withunks, abstract_withunks, decoded_output)
                    print(output)
                    output_parts.append(output)
                # extended call that also stores log-prob and p_gen information
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words,
                    best_hyp.attn_dists, best_hyp.p_gens, counter,
                    best_hyp.log_prob, best_hyp.avg_log_prob,
                    best_hyp.average_pgen)
                counter += 1  # this is how many examples we've decoded

            else:  # original author's note: this branch may lag behind the single_pass branch
                print_results(article_withunks, abstract_withunks,
                              decoded_output)  # log output to screen
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words,
                    best_hyp.attn_dists, best_hyp.p_gens,
                    counter)  # write info to .json file for visualization tool

                # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so load a new checkpoint
                t1 = time.time()
                if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                    tf.logging.info(
                        'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                        t1 - t0)
                    _ = util.load_ckpt(self._saver, self._sess)
                    t0 = time.time()

            idx += 1

        # save the entire output and the beam-search dump as strings in one go
        if FLAGS.save_values:
            save_decode_data.save_data_once(self._decode_dir,
                                            FLAGS.result_path,
                                            "".join(output_parts),
                                            "".join(beam_search_parts),
                                            metadata)
    def decode(self):
        """Decode batches with beam search, picking the hypothesis that scores best on topic.

        For every batch, all hypotheses returned by ``beam_search.run_beam_search``
        are converted to words and scored with ``self.get_score`` against the
        batch's original topic. The emitted word list is the concatenation of
        every hypothesis, then the first hypothesis, then the highest-scoring
        hypothesis, with a ``'\\n@next\\n'`` marker after each part except the last.

        In single_pass mode the result is passed to ``write_for_rouge`` and the
        method returns once the batcher yields ``None``. Otherwise the loop runs
        indefinitely: results are printed, attention data is written for the
        visualization tool, and the checkpoint is reloaded whenever more than
        SECS_UNTIL_NEW_CKPT seconds have elapsed since the last (re)load.

        Raises:
            AssertionError: if the batcher is exhausted outside single_pass mode.
        """
        t0 = time.time()
        counter = 0
        while True:
            batch = self._batcher.next_batch()  # 1 example repeated across batch
            if batch is None:  # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir, self._rouge_dec_dir)
                # NOTE(review): the ROUGE evaluation itself is disabled here,
                # despite what the log message above says.
                # results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
                # rouge_log(results_dict, self._decode_dir)
                return

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings
            original_topic = batch.original_topics[0]

            # In-article OOV words are only supplied when the pointer-generator is on.
            art_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None

            article_withunks = data.show_art_oovs(original_article, self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab, art_oovs)  # string

            def hyp_to_words(hyp):
                """Convert a hypothesis' tokens (minus the first one) to words."""
                ids = [int(t) for t in hyp.tokens[1:]]
                return data.outputids2words(ids, self._vocab, art_oovs)

            # Run beam search to get all candidate hypotheses
            best_hyps = beam_search.run_beam_search(self._sess, self._model,
                                                    self._vocab, batch)
            standard_words = hyp_to_words(best_hyps[0])

            # Collect every hypothesis' words and its topic score.
            all_sencent = []
            score = []
            for hyp in best_hyps:
                tmp_decoded = hyp_to_words(hyp)
                all_sencent.extend(self.removes(tmp_decoded))
                all_sencent.append('\n@next\n')
                score.append(self.get_score(original_topic, tmp_decoded))
            all_sencent.extend(self.removes(standard_words))
            all_sencent.append('\n@next\n')

            # The first hypothesis with the maximum score wins.
            max_index = score.index(max(score))
            selected_hyp = best_hyps[max_index]
            all_sencent.extend(self.removes(hyp_to_words(selected_hyp)))
            decoded_words = all_sencent
            # NOTE(review): no explicit [STOP] stripping happens here; presumably
            # self.removes takes care of it -- confirm.

            decoded_output = ' '.join(decoded_words)  # single string

            if FLAGS.single_pass:
                # write ref summary and decoded summary to file, to eval with pyrouge later
                self.write_for_rouge(original_abstract_sents, decoded_words, counter)
                counter += 1  # this is how many examples we've decoded
            else:
                print_results(article_withunks, abstract_withunks,
                              decoded_output)  # log output to screen
                # Visualize the hypothesis that was actually selected, not
                # whichever hypothesis the scoring loop happened to end on.
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words,
                    selected_hyp.attn_dists, selected_hyp.topic_attn_dists,
                    selected_hyp.p_gens)  # write info to .json file for visualization tool

                # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so load a new checkpoint
                t1 = time.time()
                if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                    tf.logging.info(
                        'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                        t1 - t0)
                    _ = util.load_ckpt(self._saver, self._sess)
                    t0 = time.time()
Beispiel #21
0
    def decode(self):
        """Beam-search decode every batch supplied by the batcher.

        In single_pass mode each decoded answer is handed to ``write_for_eval``
        and the method returns once the batcher yields ``None``. Otherwise the
        loop runs indefinitely: results are printed, attention data is written
        for the visualization tool, and the checkpoint is reloaded whenever more
        than SECS_UNTIL_NEW_CKPT seconds have passed since the last (re)load.
        """
        last_ckpt_time = time.time()
        num_decoded = 0
        while True:
            batch = self._batcher.next_batch()  # 1 example repeated across batch
            if batch is None:
                # Running out of data is only legal in single_pass mode.
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info("Decoder has finished reading dataset for single_pass.")
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir, self._rouge_dec_dir)
                return

            original_reviews = batch.original_reviews[0]  # list
            original_answer = batch.original_answers[0]  # string
            original_answers_sents = batch.original_answers_sents[0]  # list of strings
            original_question = batch.original_questions[0]
            y_target = batch.y_target_batch[0]

            # Render reviews, question and reference answer via the data helpers.
            review_withunks = []
            for review in original_reviews:
                review_withunks.append(data.show_art_oovs(review, self._vocab))
            question_withunks = data.show_art_oovs(original_question, self._vocab)
            oovs = batch.oovs[0] if FLAGS.pointer_gen else None
            answer_withunks = data.show_abs_oovs(original_answer, self._vocab, oovs)

            # Beam search returns the best hypothesis plus a prediction.
            best_hyp, y_pred = beam_search.run_beam_search(
                self._sess, self._model, self._vocab, batch)

            # Skip the first token of the hypothesis and map ids back to words.
            output_ids = [int(token) for token in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words(output_ids, self._vocab, oovs)

            # Cut everything from the first [STOP] symbol onwards, if present.
            if data.STOP_DECODING in decoded_words:
                stop_at = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:stop_at]
            decoded_output = ' '.join(decoded_words)  # single string

            if FLAGS.single_pass:
                # Persist reference and decoded output for later evaluation.
                self.write_for_eval(original_answers_sents, decoded_words,
                                    original_question, y_target, y_pred,
                                    num_decoded)
                num_decoded += 1  # number of examples decoded so far
                continue

            # Interactive mode: log to screen and dump attention data (.json).
            print_results(review_withunks, answer_withunks, decoded_output,
                          original_question, y_target, y_pred)
            self.write_for_attnvis(question_withunks, review_withunks,
                                   decoded_words, best_hyp.r_attn_dists,
                                   best_hyp.p_gens)

            # Reload the checkpoint once SECS_UNTIL_NEW_CKPT has elapsed.
            now = time.time()
            elapsed = now - last_ckpt_time
            if elapsed > SECS_UNTIL_NEW_CKPT:
                tf.logging.info(
                    'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                    elapsed)
                _ = util.load_ckpt(self._saver, self._sess)
                last_ckpt_time = time.time()
Beispiel #22
0
    def decode(self):
        """
        Decode examples with beam search, tracking the running mean of the best scores.

        In single_pass mode each decoded example is passed to ``write_for_rouge``
        and the method returns once the batcher yields ``None``. Otherwise the
        result is printed, attention data is written for the visualization tool,
        and the loop blocks on ``raw_input()`` before decoding the next example.
        Unlike the docstring this replaces suggested, no checkpoint is reloaded
        anywhere in this loop.

        Raises:
            AssertionError: if the batcher is exhausted outside single_pass mode.
        """
        counter = 0
        scores = []

        while True:
            batch = self._batcher.next_batch()  # 1 example repeated across batch
            if batch is None: # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info("Decoder has finished reading dataset for single_pass.")
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir,
                    self._rouge_dec_dir,
                )
                # An empty dataset leaves `scores` empty; avoid dividing by zero.
                if scores:
                    tf.logging.info("Mean score: %s", sum(scores) / len(scores))
                else:
                    tf.logging.info("Mean score: n/a (no examples were decoded)")
                return

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string

            article_withunks = data.show_art_oovs(original_article, self._vocab) # string
            abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, batch.art_oovs[0]) # string

            # Run beam search to get best Hypothesis (and its score), timing the call
            t_beam = time.time()
            best_hyp, best_score = beam_search.run_beam_search(
                self._sess, self._model, self._vocab, batch, FLAGS.beam_size, FLAGS.max_dec_steps,
                FLAGS.min_dec_steps, FLAGS.trace_path
            )
            scores.append(best_score)
            tf.logging.info("Time to decode one example: %f", time.time() - t_beam)
            # `scores` is non-empty here, so the running mean is always defined.
            tf.logging.info("Mean score: %s", sum(scores) / len(scores))

            # The hypothesis already carries token strings; drop the first one
            decoded_words = best_hyp.token_strings[1:]

            # Remove the [STOP] token (and everything after it) from decoded_words, if present
            if data.STOP_DECODING in decoded_words:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
                decoded_words = decoded_words[:fst_stop_idx]
            decoded_output = ' '.join(decoded_words) # single string

            if FLAGS.single_pass:
                self.write_for_rouge(original_abstract, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
                counter += 1 # this is how many examples we've decoded
            else:
                # log output to screen
                print_results(
                    article_withunks, abstract_withunks, decoded_output, best_hyp, [best_score]
                )
                # write info to .json file for visualization tool
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists,
                    best_hyp.p_gens, best_hyp.log_probs
                )

                # Pause until the user presses Enter before decoding the next example.
                # NOTE(review): raw_input is the Python 2 builtin; under Python 3
                # this would have to become input().
                raw_input()
Beispiel #23
0
    def decode(self):
        """
        Beam-search decode titles into summarizations, one batch at a time.

        In single_pass mode every decoded example is passed to ``write_result``
        and the method returns once the batcher yields ``None``, logging the
        total elapsed time. Otherwise the loop runs indefinitely: results are
        printed, attention data is written for the visualization tool, and the
        checkpoint is reloaded whenever more than SECS_UNTIL_NEW_CKPT seconds
        have passed since the last (re)load.
        """
        start_time = time.time()
        last_ckpt_time = start_time
        num_decoded = 0
        while True:
            batch = self._batcher.next_batch()  # 1 example repeated across batch
            if batch is None:
                # Running out of data is only legal in single_pass mode.
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass, using %d seconds.",
                    time.time() - start_time)
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir, self._rouge_dec_dir)
                # TODO: ROUGE evaluation (rouge_eval / rouge_log) is not wired
                # up for this variant yet.
                return

            original_title = batch.original_titles[0]  # string
            original_summarization = batch.original_summarizations[0]  # string

            title_withunks = data.show_art_oovs(original_title, self._vocab)
            art_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
            abstract_withunks = data.show_abs_oovs(original_summarization,
                                                   self._vocab, art_oovs)

            # Run beam search to get the best hypothesis.
            best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                   self._vocab, batch)

            # Skip the first token of the hypothesis and map ids back to words.
            output_ids = [int(token) for token in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words(output_ids, self._vocab,
                                                 art_oovs)

            # Cut everything from the first end-of-sequence marker onwards.
            if data.MARK_EOS in decoded_words:
                eos_at = decoded_words.index(data.MARK_EOS)
                decoded_words = decoded_words[:eos_at]
            # Words are concatenated without any separator.
            decoded_output = ''.join(decoded_words)

            if FLAGS.single_pass:
                # TODO: needs checking. Writes the title, the reference
                # summarization and the decoded words for this example.
                self.write_result(original_title, original_summarization,
                                  decoded_words, num_decoded)
                num_decoded += 1  # number of examples decoded so far
                continue

            # Interactive mode: log to screen and dump attention data (.json).
            print_results(title_withunks, abstract_withunks, decoded_output)
            self.write_for_attnvis(title_withunks, abstract_withunks,
                                   decoded_words, best_hyp.attn_dists)

            # Reload the checkpoint once SECS_UNTIL_NEW_CKPT has elapsed.
            now = time.time()
            elapsed = now - last_ckpt_time
            if elapsed > SECS_UNTIL_NEW_CKPT:
                tf.logging.info(
                    'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                    elapsed)
                _ = util.load_ckpt(self._saver, self._sess)
                last_ckpt_time = time.time()
Beispiel #24
0
    def decode(self):
        """Beam-search decode batches, with output handling chosen by FLAGS.eval_type.

        In single_pass mode each decoded summary is written with
        ``write_data_for_medsumm_eval`` (eval_type "medsumm") and/or
        ``write_for_rouge`` (eval_type "cnn"). When the batcher yields ``None``
        the "cnn" path runs ``rouge_eval``/``rouge_log`` and the "medsumm" path
        dumps ``self._generated_answers`` as JSON to FLAGS.generated_data_file,
        then the method returns. Outside single_pass mode the loop runs
        indefinitely: results are printed, attention data is written for the
        visualization tool, and the checkpoint is reloaded whenever more than
        SECS_UNTIL_NEW_CKPT seconds have passed since the last (re)load.
        """
        last_ckpt_time = time.time()
        num_decoded = 0
        while True:
            batch = self._batcher.next_batch()  # 1 example repeated across batch
            if batch is None:
                # Running out of data is only legal in single_pass mode.
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info("Decoder has finished reading dataset for single_pass.")
                if FLAGS.eval_type == "cnn":
                    tf.logging.info(
                        "Output has been saved in %s and %s. Now starting ROUGE eval...",
                        self._rouge_ref_dir, self._rouge_dec_dir)
                    rouge_results = rouge_eval(self._rouge_ref_dir,
                                               self._rouge_dec_dir)
                    rouge_log(rouge_results, self._decode_dir)
                if FLAGS.eval_type == "medsumm":
                    tf.logging.info("Writing generated answer summaries to file...")
                    with open(FLAGS.generated_data_file, "w",
                              encoding="utf-8") as out_file:
                        json.dump(self._generated_answers, out_file, indent=4)
                return

            question = batch.questions[0]
            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

            article_withunks = data.show_art_oovs(original_article, self._vocab)
            art_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
            abstract_withunks = data.show_abs_oovs(original_abstract,
                                                   self._vocab, art_oovs)

            # Run beam search to get the best hypothesis.
            best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                   self._vocab, batch)

            # Skip the first token of the hypothesis and map ids back to words.
            output_ids = [int(token) for token in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words(output_ids, self._vocab,
                                                 art_oovs)

            # Cut everything from the first [STOP] symbol onwards, if present.
            if data.STOP_DECODING in decoded_words:
                stop_at = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:stop_at]
            decoded_output = ' '.join(decoded_words)  # single string

            if FLAGS.single_pass:
                if FLAGS.eval_type == "medsumm":
                    # Save reference and generated summary for the answer
                    # summarization evaluation.
                    self.write_data_for_medsumm_eval(original_abstract_sents,
                                                     decoded_words, question,
                                                     num_decoded)
                    tf.logging.info("saving summary %i", num_decoded)
                if FLAGS.eval_type == "cnn":
                    # Write reference and decoded summary for pyrouge.
                    self.write_for_rouge(original_abstract_sents,
                                         decoded_words, num_decoded)
                num_decoded += 1  # number of examples decoded so far
                continue

            # Interactive mode: log to screen and dump attention data (.json).
            print_results(article_withunks, abstract_withunks, decoded_output)
            self.write_for_attnvis(article_withunks, abstract_withunks,
                                   decoded_words, best_hyp.attn_dists,
                                   best_hyp.p_gens)

            # Reload the checkpoint once SECS_UNTIL_NEW_CKPT has elapsed.
            now = time.time()
            elapsed = now - last_ckpt_time
            if elapsed > SECS_UNTIL_NEW_CKPT:
                tf.logging.info(
                    'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                    elapsed)
                _ = util.load_ckpt(self._saver, self._sess)
                last_ckpt_time = time.time()
Beispiel #25
0
    def decode(self):
        """Beam-search decode batches while collecting sticker logits and one-hot targets.

        For every batch the per-example sticker logits returned by beam search
        and the one-hot encoding of ``batch.dec_pic_target`` are flattened into
        two running lists. In single_pass mode each decoded summary is written
        with ``write_for_rouge``; when the batcher yields ``None`` the collected
        lists are passed to ``evaluation``, ROUGE is computed and logged, and the
        method returns. Outside single_pass mode the loop runs indefinitely:
        results are printed, attention data is written for the visualization
        tool, and the checkpoint is reloaded whenever more than
        SECS_UNTIL_NEW_CKPT seconds have passed since the last (re)load.
        """
        last_ckpt_time = time.time()
        num_decoded = 0
        collected_logits = []
        collected_labels = []
        while True:
            batch = self._batcher.next_batch()  # 1 example repeated across batch
            if batch is None:
                # Running out of data is only legal in single_pass mode.
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info("Decoder has finished reading dataset for single_pass.")
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir, self._rouge_dec_dir)
                evaluation(collected_logits, collected_labels,
                           FLAGS.max_side_steps)
                rouge_results = rouge_eval(self._rouge_ref_dir,
                                           self._rouge_dec_dir)
                rouge_log(rouge_results, self._decode_dir)
                return

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

            article_withunks = data.show_art_oovs(original_article, self._vocab)
            art_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
            abstract_withunks = data.show_abs_oovs(original_abstract,
                                                   self._vocab, art_oovs)

            # Beam search also returns a picture prediction and sticker logits.
            best_hyp, pic, sticker_logits = beam_search.run_beam_search(
                self._sess, self._model, self._vocab, batch)

            # Index an identity matrix by the targets to one-hot encode them.
            identity = np.eye(FLAGS.max_side_steps)
            one_hot_targets = identity[batch.dec_pic_target]
            for row in range(FLAGS.batch_size):
                paired = zip(sticker_logits[row].tolist(),
                             one_hot_targets[row].tolist())
                for logit, label in paired:
                    collected_logits.append(logit)
                    collected_labels.append(label)

            # Skip the first token of the hypothesis and map ids back to words.
            output_ids = [int(token) for token in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words(output_ids, self._vocab,
                                                 art_oovs)

            # Cut everything from the first [STOP] symbol onwards, if present.
            if data.STOP_DECODING in decoded_words:
                stop_at = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:stop_at]
            decoded_output = ' '.join(decoded_words)  # single string

            if FLAGS.single_pass:
                # Write reference and decoded summary (plus predicted and
                # target picture) for later evaluation with pyrouge.
                self.write_for_rouge(original_abstract_sents, decoded_words,
                                     num_decoded, pic, batch.dec_pic_target[0])
                num_decoded += 1  # number of examples decoded so far
                continue

            # Interactive mode: log to screen and dump attention data (.json).
            print_results(article_withunks, abstract_withunks, decoded_output)
            self.write_for_attnvis(article_withunks, abstract_withunks,
                                   decoded_words, best_hyp.attn_dists,
                                   best_hyp.p_gens)

            # Reload the checkpoint once SECS_UNTIL_NEW_CKPT has elapsed.
            now = time.time()
            elapsed = now - last_ckpt_time
            if elapsed > SECS_UNTIL_NEW_CKPT:
                tf.logging.info(
                    'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                    elapsed)
                _ = util.load_ckpt(self._saver, self._sess)
                last_ckpt_time = time.time()
Beispiel #26
0
    def decode_iteratively(self, example_generator, total, names_to_types,
                           ssi_list, hps):
        attn_vis_idx = 0
        for example_idx, example in enumerate(
                tqdm(example_generator, total=total)):
            raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, groundtruth_article_lcs_paths_list = util.unpack_tf_example(
                example, names_to_types)
            article_sent_tokens = [
                util.process_sent(sent) for sent in raw_article_sents
            ]
            groundtruth_summ_sents = [[
                sent.strip()
                for sent in groundtruth_summary_text.strip().split('\n')
            ]]
            groundtruth_summ_sent_tokens = [
                sent.split(' ') for sent in groundtruth_summ_sents[0]
            ]

            if ssi_list is None:  # this is if we are doing the upper bound evaluation (ssi_list comes straight from the groundtruth)
                sys_ssi = groundtruth_similar_source_indices_list
                sys_alp_list = groundtruth_article_lcs_paths_list
                if FLAGS.singles_and_pairs == 'singles':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                    sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 1)
                elif FLAGS.singles_and_pairs == 'both':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                    sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 2)
                sys_ssi, sys_alp_list = util.replace_empty_ssis(
                    sys_ssi, raw_article_sents, sys_alp_list=sys_alp_list)
            else:
                gt_ssi, sys_ssi, ext_len, sys_token_probs_list = ssi_list[
                    example_idx]
                sys_alp_list = ssi_functions.list_labels_from_probs(
                    sys_token_probs_list, FLAGS.tag_threshold)
                if FLAGS.singles_and_pairs == 'singles':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                    sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 1)
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                        groundtruth_similar_source_indices_list, 1)
                    gt_ssi = util.enforce_sentence_limit(gt_ssi, 1)
                elif FLAGS.singles_and_pairs == 'both':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                    sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 2)
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                        groundtruth_similar_source_indices_list, 2)
                    gt_ssi = util.enforce_sentence_limit(gt_ssi, 2)
                # if gt_ssi != groundtruth_similar_source_indices_list:
                #     raise Exception('Example %d has different groundtruth source indices: ' + str(groundtruth_similar_source_indices_list) + ' || ' + str(gt_ssi))
                if FLAGS.dataset_name == 'xsum':
                    sys_ssi = [sys_ssi[0]]

            final_decoded_words = []
            final_decoded_outpus = ''
            best_hyps = []
            highlight_html_total = '<u>System Summary</u><br><br>'
            for ssi_idx, ssi in enumerate(sys_ssi):
                # selected_article_lcs_paths = None
                selected_article_lcs_paths = sys_alp_list[ssi_idx]
                ssi, selected_article_lcs_paths = util.make_ssi_chronological(
                    ssi, selected_article_lcs_paths)
                selected_article_lcs_paths = [selected_article_lcs_paths]
                selected_raw_article_sents = util.reorder(
                    raw_article_sents, ssi)
                selected_article_text = ' '.join([
                    ' '.join(sent)
                    for sent in util.reorder(article_sent_tokens, ssi)
                ])
                selected_doc_indices_str = '0 ' * len(
                    selected_article_text.split())
                if FLAGS.upper_bound:
                    selected_groundtruth_summ_sent = [[
                        groundtruth_summ_sents[0][ssi_idx]
                    ]]
                else:
                    selected_groundtruth_summ_sent = groundtruth_summ_sents

                batch = create_batch(selected_article_text,
                                     selected_groundtruth_summ_sent,
                                     selected_doc_indices_str,
                                     selected_raw_article_sents,
                                     selected_article_lcs_paths,
                                     FLAGS.batch_size, hps, self._vocab)

                original_article = batch.original_articles[0]  # string
                original_abstract = batch.original_abstracts[0]  # string
                article_withunks = data.show_art_oovs(original_article,
                                                      self._vocab)  # string
                abstract_withunks = data.show_abs_oovs(
                    original_abstract, self._vocab,
                    (batch.art_oovs[0]
                     if FLAGS.pointer_gen else None))  # string
                # article_withunks = data.show_art_oovs(original_article, self._vocab) # string
                # abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string

                if FLAGS.first_intact and ssi_idx == 0:
                    decoded_words = selected_article_text.strip().split()
                    decoded_output = selected_article_text
                else:
                    decoded_words, decoded_output, best_hyp = decode_example(
                        self._sess, self._model, self._vocab, batch,
                        example_idx, hps)
                    best_hyps.append(best_hyp)
                final_decoded_words.extend(decoded_words)
                final_decoded_outpus += decoded_output

                if example_idx < 100 or (example_idx >= 2000
                                         and example_idx < 2100):
                    min_matched_tokens = 2
                    selected_article_sent_tokens = [
                        util.process_sent(sent)
                        for sent in selected_raw_article_sents
                    ]
                    highlight_summary_sent_tokens = [decoded_words]
                    highlight_ssi_list, lcs_paths_list, highlight_article_lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                        highlight_summary_sent_tokens,
                        selected_article_sent_tokens, None, 2,
                        min_matched_tokens)
                    highlighted_html = ssi_functions.html_highlight_sents_in_article(
                        highlight_summary_sent_tokens,
                        highlight_ssi_list,
                        selected_article_sent_tokens,
                        lcs_paths_list=lcs_paths_list,
                        article_lcs_paths_list=
                        highlight_smooth_article_lcs_paths_list)
                    highlight_html_total += highlighted_html + '<br>'

                if FLAGS.attn_vis and example_idx < 200:
                    self.write_for_attnvis(
                        article_withunks, abstract_withunks, decoded_words,
                        best_hyp.attn_dists, best_hyp.p_gens, attn_vis_idx
                    )  # write info to .json file for visualization tool
                    attn_vis_idx += 1

                if len(final_decoded_words) >= 100:
                    break

            gt_ssi_list, gt_alp_list = util.replace_empty_ssis(
                groundtruth_similar_source_indices_list,
                raw_article_sents,
                sys_alp_list=groundtruth_article_lcs_paths_list)
            highlight_html_gt = '<u>Reference Summary</u><br><br>'
            for ssi_idx, ssi in enumerate(gt_ssi_list):
                selected_article_lcs_paths = gt_alp_list[ssi_idx]
                try:
                    ssi, selected_article_lcs_paths = util.make_ssi_chronological(
                        ssi, selected_article_lcs_paths)
                except:
                    util.print_vars(ssi, example_idx,
                                    selected_article_lcs_paths)
                    raise
                selected_raw_article_sents = util.reorder(
                    raw_article_sents, ssi)

                if example_idx < 100 or (example_idx >= 2000
                                         and example_idx < 2100):
                    min_matched_tokens = 2
                    selected_article_sent_tokens = [
                        util.process_sent(sent)
                        for sent in selected_raw_article_sents
                    ]
                    highlight_summary_sent_tokens = [
                        groundtruth_summ_sent_tokens[ssi_idx]
                    ]
                    highlight_ssi_list, lcs_paths_list, highlight_article_lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                        highlight_summary_sent_tokens,
                        selected_article_sent_tokens, None, 2,
                        min_matched_tokens)
                    highlighted_html = ssi_functions.html_highlight_sents_in_article(
                        highlight_summary_sent_tokens,
                        highlight_ssi_list,
                        selected_article_sent_tokens,
                        lcs_paths_list=lcs_paths_list,
                        article_lcs_paths_list=
                        highlight_smooth_article_lcs_paths_list)
                    highlight_html_gt += highlighted_html + '<br>'

            if example_idx < 100 or (example_idx >= 2000
                                     and example_idx < 2100):
                self.write_for_human(raw_article_sents, groundtruth_summ_sents,
                                     final_decoded_words, example_idx)
                highlight_html_total = ssi_functions.put_html_in_two_columns(
                    highlight_html_total, highlight_html_gt)
                ssi_functions.write_highlighted_html(highlight_html_total,
                                                     self._highlight_dir,
                                                     example_idx)

            # if example_idx % 100 == 0:
            #     attn_dir = os.path.join(self._decode_dir, 'attn_vis_data')
            #     attn_selections.process_attn_selections(attn_dir, self._decode_dir, self._vocab)

            rouge_functions.write_for_rouge(
                groundtruth_summ_sents,
                None,
                example_idx,
                self._rouge_ref_dir,
                self._rouge_dec_dir,
                decoded_words=final_decoded_words,
                log=False
            )  # write ref summary and decoded summary to file, to eval with pyrouge later
            # if FLAGS.attn_vis:
            #     self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens, example_idx) # write info to .json file for visualization tool
            example_idx += 1  # this is how many examples we've decoded

        logging.info("Decoder has finished reading dataset for single_pass.")
        logging.info("Output has been saved in %s and %s.",
                     self._rouge_ref_dir, self._rouge_dec_dir)
        if len(os.listdir(self._rouge_ref_dir)) != 0:
            if FLAGS.dataset_name == 'xsum':
                l_param = 100
            else:
                l_param = 100
            logging.info("Now starting ROUGE eval...")
            results_dict = rouge_functions.rouge_eval(self._rouge_ref_dir,
                                                      self._rouge_dec_dir,
                                                      l_param=l_param)
            rouge_functions.rouge_log(results_dict, self._decode_dir)
Beispiel #27
0
    def decode(self):
        """Decode every batch from the batcher, dump the results to a TSV file and print BLEU / accuracy.

        For each batch, ``eval_one_batch_with_candidate`` supplies the decoded
        id sequences together with per-instance candidates. Every instance is
        converted back to words, cut at the first [STOP] token and written as
        one tab-separated line (article, reference abstract, decoded output,
        candidates joined by ``_||_``) to ``<FLAGS.log_root>/output.txt``.

        Once the batcher returns ``None`` the scores computed by
        ``matrix.bleu_score`` and ``matrix.compute_acc`` over all collected
        outputs are printed. Nothing is returned.
        """
        output_path = os.path.join(FLAGS.log_root, "output.txt")
        output_result = []  # decoded strings, one per written instance
        list_of_reference = []  # one single-element reference list per instance

        # `with` guarantees the output file is closed even if decoding raises
        # part-way through the dataset.
        with codecs.open(output_path, "w", "utf8") as outputfile:
            while True:
                batch = self._batcher.next_batch()
                if batch is None:  # batcher has no more data
                    logging.info("eval_finished")
                    break
                print(self._batcher.c_index)

                result, all_candidate = self.eval_one_batch_with_candidate(
                    self._sess, self._model, self._vocab, batch)

                for i, instance in enumerate(result):
                    # No OOV list exists for this index, so it cannot be decoded.
                    if i == len(batch.art_oovs):
                        break
                    # NOTE(review): presumably instances past real_length are
                    # padding in the final batch -- confirm against the batcher.
                    if i >= batch.real_length:
                        print("eval done with {} isntances".format(
                            len(output_result)))
                        break

                    out_words = data.outputids2words(instance,
                                                     self._model._vocab_out,
                                                     batch.art_oovs[i])
                    # Drop the first [STOP] token and everything after it.
                    if data.STOP_DECODING in out_words:
                        out_words = out_words[:out_words.index(data.STOP_DECODING)]

                    candidates_value = "_||_".join(
                        self.get_condidate_predicate(out_words,
                                                     all_candidate[i],
                                                     batch.art_oovs[i]))

                    output_now = " ".join(out_words)
                    output_result.append(output_now)
                    list_of_reference.append(
                        [batch.original_abstracts[i].strip()])

                    outputfile.write(batch.original_articles[i] + '\t' +
                                     batch.original_abstracts[i] + '\t' +
                                     output_now + '\t' + candidates_value + '\n')

        bleu = matrix.bleu_score(list_of_reference, output_result)
        acc = matrix.compute_acc(list_of_reference, output_result)

        print("bleu : {}   acc : {}".format(bleu, acc))
Beispiel #28
0
    def decode(self):
        """Decode the dataset once, writing the top-k beam hypotheses of each example to a TSV file.

        For every batch from ``next_single_decode_batch`` beam search is run
        with ``FLAGS.best_k_hyp``; each returned hypothesis is converted back
        to words, cut at the first [STOP] token and written as one
        tab-separated line (article, reference abstract, decoded output,
        ``avg_log_prob``, ``len``) to ``<FLAGS.log_root>/output.txt``.

        Returns when the batcher yields ``None``. At that point the method
        asserts ``FLAGS.single_pass``, so an AssertionError is raised if the
        dataset runs out in any other mode.
        """
        output_path = os.path.join(FLAGS.log_root, "output.txt")

        # `with` closes the file on every exit path, including a failed assert
        # or an exception raised inside beam search.
        with codecs.open(output_path, "w", "utf8") as outputfile:
            while True:
                batch = self._batcher.next_single_decode_batch()  # 1 example repeated across batch
                if batch is None:  # finished decoding dataset in single_pass mode
                    assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                    tf.logging.info("Decoder has finished reading dataset for single_pass.")
                    return
                print(self._batcher.c_index)

                original_article = batch.original_articles[0]  # string
                original_abstract = batch.original_abstracts[0]  # string
                # In-article OOV words are only needed in pointer-generator mode.
                art_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None

                # Run beam search to get the best k hypotheses
                best_k_hyp = beam_search.run_beam_search(
                    self._sess, self._model, self._vocab, batch, FLAGS.best_k_hyp)
                for best_hyp in best_k_hyp:
                    # Skip the first token of the hypothesis and map ids back to words
                    output_ids = [int(t) for t in best_hyp.tokens[1:]]
                    decoded_words = data.outputids2words(output_ids, self._vocab, art_oovs)

                    # Remove the (first) [STOP] token and everything after it, if present
                    if data.STOP_DECODING in decoded_words:
                        decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
                    decoded_output = ' '.join(decoded_words)  # single string

                    # Results always go to the file; flushing per hypothesis keeps
                    # partial output on disk if the run is interrupted.
                    outputfile.write(original_article + "\t" + original_abstract + "\t" + decoded_output + "\t" + str(
                        best_hyp.avg_log_prob) + "\t" + str(best_hyp.len) + "\n")
                    outputfile.flush()
Beispiel #29
0
  def decode(self):
    """Decode examples until the data is exhausted, or decode indefinitely with periodic checkpoint reloads.

    Three modes, selected by flags:

    * ``FLAGS.api_mode``: collect each original article, its decoded summary
      string and its decoded tokens; when the batcher is exhausted, pickle
      them as a dict with keys ``articles``, ``summaries`` and
      ``summaries_tokens`` to ``FLAGS.pickle_file``.
    * ``FLAGS.single_pass``: write reference and decoded summaries via
      ``write_for_rouge``; when the batcher is exhausted, run ROUGE
      evaluation and log the results.
    * otherwise: print each result, write attention-visualization data and
      reload the checkpoint once ``SECS_UNTIL_NEW_CKPT`` seconds have passed.

    Raises AssertionError if the batcher runs out while ``FLAGS.single_pass``
    is unset, or if API mode finishes without ``FLAGS.pickle_file``.
    """
    # Accumulators for API mode (left empty in the other modes).
    articles = []
    summaries = []
    summaries_tokens = []

    t0 = time.time()
    counter = 0
    while True:
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None: # finished decoding dataset in single_pass mode
        assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        if not FLAGS.api_mode:
          # Only this mode writes ROUGE files, so only here is the message accurate.
          tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
          results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
          rouge_log(results_dict, self._decode_dir)
        else:
          assert FLAGS.pickle_file, "Pickle path not specified"
          decoder_final_output = {'articles':articles, 'summaries':summaries, 'summaries_tokens': summaries_tokens}
          # Close the file deterministically so the pickle is fully flushed to disk.
          with open(FLAGS.pickle_file, "wb") as pickle_out:
            pickle.dump(decoder_final_output, pickle_out)
        return

      original_article = batch.original_articles[0]  # string
      original_abstract = batch.original_abstracts[0]  # string
      original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

      article_withunks = data.show_art_oovs(original_article, self._vocab) # string
      abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string

      # Run beam search to get best Hypothesis
      best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)

      # Extract the output ids from the hypothesis (skipping its first token) and convert back to words
      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))

      # Remove the (first) [STOP] token and everything after it, if present
      if data.STOP_DECODING in decoded_words:
        decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
      decoded_output = ' '.join(decoded_words) # single string

      if FLAGS.api_mode:
        # TODO add status printer:
        articles.append(original_article)
        summaries.append(decoded_output)
        summaries_tokens.append(decoded_words)
        counter += 1
        if not counter % 25:  # progress message every 25 articles
          print(f"{counter} articles summarized")
      elif FLAGS.single_pass:
        self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
        counter += 1 # this is how many examples we've decoded
      else:
        print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen
        self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool

        # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so reload the checkpoint and restart the timer
        t1 = time.time()
        if t1-t0 > SECS_UNTIL_NEW_CKPT:
          tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', t1-t0)
          _ = util.load_ckpt(self._saver, self._sess)
          t0 = time.time()
    def decode(self):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, reloading the latest checkpoint at regular intervals.

        In single_pass mode each example's reference and decoded summaries
        are written for ROUGE (plus a human-readable dump for the first 1000
        examples and, if ``FLAGS.attn_vis`` is set, attention-visualization
        data). When the batcher is exhausted, ROUGE evaluation runs provided
        the reference directory is non-empty.

        Otherwise each result is printed, attention-visualization data is
        written, and the checkpoint is reloaded once ``SECS_UNTIL_NEW_CKPT``
        seconds have elapsed.

        Raises AssertionError if the batcher runs out while
        ``FLAGS.single_pass`` is unset.
        """
        t0 = time.time()
        counter = 0
        # NOTE(review): the total assumes 1000 examples per file matched by the
        # data path glob -- confirm against how the data files are chunked.
        total = len(glob.glob(self._batcher._data_path)) * 1000
        pbar = tqdm(total=total)
        # The loop only ends via `return` or an exception, so the progress bar
        # has to be closed in `finally` to be closed at all.
        try:
            while True:
                batch = self._batcher.next_batch(
                )  # 1 example repeated across batch
                if batch is None:  # finished decoding dataset in single_pass mode
                    assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                    logging.info(
                        "Decoder has finished reading dataset for single_pass.")
                    logging.info("Output has been saved in %s and %s.",
                                 self._rouge_ref_dir, self._rouge_dec_dir)
                    # Skip ROUGE when no reference files were written.
                    if os.listdir(self._rouge_ref_dir):
                        logging.info("Now starting ROUGE eval...")
                        results_dict = rouge_functions.rouge_eval(
                            self._rouge_ref_dir, self._rouge_dec_dir)
                        rouge_functions.rouge_log(results_dict,
                                                  self._decode_dir)
                    return

                original_article = batch.original_articles[0]  # string
                original_abstract = batch.original_abstracts[0]  # string
                all_original_abstract_sents = batch.all_original_abstracts_sents[0]
                raw_article_sents = batch.raw_article_sents[0]

                article_withunks = data.show_art_oovs(original_article,
                                                      self._vocab)  # string
                abstract_withunks = data.show_abs_oovs(
                    original_abstract, self._vocab,
                    (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

                decoded_words, decoded_output, best_hyp = decode_example(
                    self._sess, self._model, self._vocab, batch, counter,
                    self._batcher._hps)

                if FLAGS.single_pass:
                    # Human-readable dumps are limited to the first 1000 examples.
                    if counter < 1000:
                        self.write_for_human(raw_article_sents,
                                             all_original_abstract_sents,
                                             decoded_words, counter)
                    rouge_functions.write_for_rouge(
                        all_original_abstract_sents,
                        None,
                        counter,
                        self._rouge_ref_dir,
                        self._rouge_dec_dir,
                        decoded_words=decoded_words
                    )  # write ref summary and decoded summary to file, to eval with pyrouge later
                    if FLAGS.attn_vis:
                        self.write_for_attnvis(
                            article_withunks, abstract_withunks, decoded_words,
                            best_hyp.attn_dists, best_hyp.p_gens, counter
                        )  # write info to .json file for visualization tool

                    counter += 1  # this is how many examples we've decoded
                else:
                    print_results(article_withunks, abstract_withunks,
                                  decoded_output)  # log output to screen
                    self.write_for_attnvis(
                        article_withunks, abstract_withunks, decoded_words,
                        best_hyp.attn_dists, best_hyp.p_gens,
                        counter)  # write info to .json file for visualization tool

                    # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so reload the checkpoint and restart the timer
                    t1 = time.time()
                    if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                        logging.info(
                            'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                            t1 - t0)
                        _ = util.load_ckpt(self._saver, self._sess)
                        t0 = time.time()
                pbar.update(1)
        finally:
            pbar.close()
def run_decode(model, batcher, vocab):
    """Build the model graph, restore the latest checkpoint and decode batches from `batcher`.

    In single_pass mode a fresh decode directory (named after the decoding
    flags and the checkpoint) is created with `reference` and `decoded`
    sub-directories; every decoded example is written there and ROUGE is
    evaluated once the batcher is exhausted. Otherwise results are printed
    and the checkpoint is reloaded once SECS_UNTIL_NEW_CKPT seconds have
    elapsed.

    Args:
        model: object providing `build_graph()` and
            `run_beam_decode_step(sess, batch, vocab)`.
        batcher: object whose `next_batch()` returns a batch, or None when
            the dataset is exhausted.
        vocab: vocabulary used to map output ids back to words.

    Raises:
        Exception: if the single_pass decode directory already exists.
        AssertionError: if the batcher runs out while FLAGS.single_pass is unset.
    """
    print("build graph...")
    model.build_graph()
    sess = tf.Session(config=util.get_config())
    saver = tf.train.Saver()
    ckpt_path = util.load_ckpt(saver, sess)

    if FLAGS.single_pass:
        ckpt_name = "ckpt-" + ckpt_path.split('-')[-1]
        dirname = "decode_maxenc_%ibeam_%imindec_%imaxdec_%i" % (FLAGS.max_enc_steps, FLAGS.beam_size, FLAGS.min_dec_steps, FLAGS.max_dec_steps)
        decode_dir = os.path.join(FLAGS.log_root, dirname + ckpt_name)
        # Refuse to mix new output with results from an earlier run.
        if os.path.exists(decode_dir):
            raise Exception('single_pass decode directory %s should not exist' % decode_dir)
    else:
        decode_dir = os.path.join(FLAGS.log_root, 'decode')
    if not os.path.exists(decode_dir):
        os.mkdir(decode_dir)

    if FLAGS.single_pass:
        rouge_ref_dir = os.path.join(decode_dir, "reference")
        if not os.path.exists(rouge_ref_dir):
            os.mkdir(rouge_ref_dir)
        rouge_dec_dir = os.path.join(decode_dir, "decoded")
        if not os.path.exists(rouge_dec_dir):
            os.mkdir(rouge_dec_dir)

    counter = 0
    t0 = time.time()
    while True:
        batch = batcher.next_batch()  # 1 example repeated across batch
        if batch is None:  # finished decoding dataset in single_pass mode
            # The assert also guarantees the rouge directories below are defined.
            assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
            print("Decoder has finished reading dataset for single_pass.")
            print("Output has been saved in %s and %s. Now starting ROUGE eval..." % (rouge_ref_dir,
                                                                                      rouge_dec_dir))
            results_dict = rouge_eval(rouge_ref_dir, rouge_dec_dir)
            rouge_log(results_dict, decode_dir)
            return

        original_article = batch.original_articles[0]  # string
        original_abstract = batch.original_abstracts[0]  # string
        original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

        article_withunks = data.show_art_oovs(original_article, vocab)  # string
        abstract_withunks = data.show_abs_oovs(original_abstract, vocab, None)  # string

        # Run the model's beam decode step and map the output ids back to words
        output = model.run_beam_decode_step(sess, batch, vocab)
        output_ids = [int(t) for t in output]
        decoded_words = data.outputids2words(output_ids, vocab, None)

        # Remove the (first) [STOP] token and everything after it, if present
        if data.STOP_DECODING in decoded_words:
            decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
        decoded_output = ' '.join(decoded_words)  # single string

        if FLAGS.single_pass:
            write_for_rouge(original_abstract_sents, decoded_words, counter, rouge_ref_dir, rouge_dec_dir)  # write ref summary and decoded summary to file, to eval with pyrouge later
            counter += 1  # this is how many examples we've decoded
        else:
            print_results(article_withunks, abstract_withunks, decoded_output)  # log output to screen
            # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so reload the checkpoint and restart the timer
            t1 = time.time()
            if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                                t1 - t0)
                _ = util.load_ckpt(saver, sess)
                t0 = time.time()