示例#1
0
def output_to_classification_batch(output, batch, batcher, cla_batcher, cc):
    """Turn generator outputs into a classifier batch and score each row with BLEU.

    For every row ``i`` in ``range(FLAGS.batch_size)`` the ids in ``output[i]``
    are mapped to words with ``batcher._vocab``, cut at the first
    ``data.STOP_DECODING`` token, joined into one string and stripped of
    ``[UNK]`` markers.  Each string becomes a ``bc.Example`` labelled with
    ``batch.score``.  A row that ends up empty is replaced by ``"."`` and gets a
    BLEU score of 0.

    Args:
        output: per-row token id sequences, indexed as ``output[i]``.
        batch: generator batch; ``original_reviews[i]`` is the BLEU reference and
            ``score`` is attached to every new example.
        batcher: generator batcher whose ``_vocab`` decodes ``output``.
        cla_batcher: classifier batcher supplying ``_vocab`` and ``_hps``.
        cc: object whose ``method1`` is used as the BLEU smoothing function
            (presumably an nltk ``SmoothingFunction`` -- confirm).

    Returns:
        ``(bc.Batch, bleu_scores)`` with one BLEU score per row.
    """
    cla_vocab = cla_batcher._vocab
    cla_hps = cla_batcher._hps
    examples = []
    bleu_scores = []

    for row in range(FLAGS.batch_size):
        token_ids = [int(tok) for tok in output[row]]
        words = data.outputids2words(token_ids, batcher._vocab, None)
        # Keep only the tokens before the first [STOP] symbol, if there is one.
        if data.STOP_DECODING in words:
            words = words[:words.index(data.STOP_DECODING)]

        sentence = ' '.join(words).strip()
        sentence = sentence.replace("[UNK] ", "").replace("[UNK]", "")

        if not sentence.strip():
            # Nothing usable was generated: score 0 and feed a placeholder.
            bleu_scores.append(0)
            text = "."
        else:
            reference = [batch.original_reviews[row].split()]
            bleu_scores.append(
                sentence_bleu(reference, sentence.split(), smoothing_function=cc.method1))
            text = sentence
        examples.append(bc.Example(text, batch.score, cla_vocab, cla_hps))

    return bc.Batch(examples, cla_hps, cla_vocab), bleu_scores
示例#2
0
    def decodeOneSample(self, batches):
        """Decode the first example of ``batches[0]``, evaluate it and emit visualisation output.

        Runs beam search on the batch and converts the best hypothesis (without
        its first token) back to words, truncated at the first
        ``data.STOP_DECODING``.  It then writes the result for ROUGE (index 0),
        calls ``self.rouge_eval()``, prints the article/reference/decoded triple
        and writes the attention-visualisation data.

        Args:
            batches: sequence of batches; only ``batches[0]`` is used.
        """
        first = batches[0]
        article = first.original_articles[0]
        abstract = first.original_abstracts[0]
        abstract_sents = first.original_abstracts_sents[0]

        article_unk = data.show_art_oovs(article, self._vocab)
        abstract_unk = data.show_abs_oovs(abstract, self._vocab, first.art_oovs[0])

        hyp = beam_search.run_beam_search(self._session, self._model,
                                          self._vocab, first, self._hps)

        # tokens[0] is skipped (presumably the start-of-decoding token).
        ids = [int(tok) for tok in hyp.tokens[1:]]
        words = data.outputids2words(ids, self._vocab, first.art_oovs[0])
        if data.STOP_DECODING in words:
            words = words[:words.index(data.STOP_DECODING)]

        joined = ' '.join(words)

        self.write_for_rouge(abstract_sents, words, 0, article)
        self.rouge_eval()
        print_results(article_unk, abstract_unk, joined)
        self.write_for_attnvis(article_unk, abstract_unk, words,
                               hyp.attn_dists, hyp.p_gens)
示例#3
0
    def decode(self):
        """Decode every batch from ``self.batcher`` and run ROUGE on the results.

        Each batch's best beam-search summary is converted to words (first token
        dropped, truncated at the first ``data.STOP_DECODING``).  The words are
        written next to the reference summary with ``write_for_rouge``, and a
        timing line is printed every 10 examples.  When ``next_batch()`` returns
        ``None``, ROUGE is computed over the written files and logged to
        ``self._decode_dir``.
        """
        tic = time.time()
        done = 0
        while True:
            batch = self.batcher.next_batch()
            if batch is None:
                break

            summary = self.beam_search(batch)

            oovs = batch.art_oovs[0] if config.pointer_gen else None
            ids = [int(tok) for tok in summary.tokens[1:]]
            words = data.outputids2words(ids, self.vocab, oovs)
            # Cut the summary at the first [STOP] token, if present.
            if data.STOP_DECODING in words:
                words = words[:words.index(data.STOP_DECODING)]

            write_for_rouge(batch.original_abstracts_sents[0], words, done,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            done += 1
            if done % 10 == 0:
                print('%d example in %d sec' % (done, time.time() - tic))
                tic = time.time()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        rouge_log(rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir),
                  self._decode_dir)
示例#4
0
	def decode(self):
		"""Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals.

		Single-pass mode: every decoded summary is written for ROUGE, and its
		attention distributions, p_gen values and decoded words are appended to
		``all_attn_dists`` / ``all_pgens`` / ``dec_words``.  When the batcher is
		exhausted those three lists are pickled to
		``<self._decode_dir>/attn_dist_p_gens_data.pkl`` and ``self.get_metrics``
		is run on the reference/decoded directories.

		Otherwise each result is printed and written for the attention
		visualiser, and the latest checkpoint is reloaded once more than
		SECS_UNTIL_NEW_CKPT seconds have passed.
		"""
		# NOTE(review): all_attn_dists, all_pgens and dec_words are not defined in
		# this method; presumably they are module-level lists -- confirm.
		t0 = time.time()
		counter = 0
		while True:
			batch = self._batcher.next_batch()  # 1 example repeated across batch
			if batch is None:  # finished decoding dataset in single_pass mode
				d = [all_attn_dists, all_pgens, dec_words]
				output_fname = os.path.join(self._decode_dir, 'attn_dist_p_gens_data.pkl')
				# Close (and flush) the pickle file deterministically instead of
				# leaking the handle returned by a bare open().
				with open(output_fname, 'wb') as pkl_file:
					pickle.dump(d, pkl_file)
				assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
				tf.logging.info("Decoder has finished reading dataset for single_pass.")
				tf.logging.info("Output has been saved in %s and %s. Now starting eval...", self._rouge_ref_dir, self._rouge_dec_dir)
				self.get_metrics(self._rouge_ref_dir, self._rouge_dec_dir)
				return

			oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
			original_article = batch.original_articles[0]  # string
			original_abstract = batch.original_abstracts[0]  # string
			original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

			article_withunks = data.show_art_oovs(original_article, self._vocab)  # string
			abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, oovs)  # string

			# Run beam search to get best Hypothesis
			best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)

			# Extract the output ids (skipping the first token) and convert back to words
			output_ids = [int(t) for t in best_hyp.tokens[1:]]
			decoded_words = data.outputids2words(output_ids, self._vocab, oovs)

			# Remove the [STOP] token and everything after it, if present
			if data.STOP_DECODING in decoded_words:
				decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
			decoded_output = ' '.join(decoded_words)  # single string

			if FLAGS.single_pass:
				# write ref summary and decoded summary to file, to eval with pyrouge later
				self.write_for_rouge(original_abstract_sents, decoded_words, counter)
				counter += 1  # this is how many examples we've decoded
				all_attn_dists.append(best_hyp.attn_dists)
				all_pgens.append(best_hyp.p_gens)
				dec_words.append(decoded_words)
			else:
				print_results(article_withunks, abstract_withunks, decoded_output)  # log output to screen
				self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens)  # write info to .json file for visualization tool

				# Check if SECS_UNTIL_NEW_CKPT has elapsed; if so reload the latest checkpoint
				t1 = time.time()
				if t1 - t0 > SECS_UNTIL_NEW_CKPT:
					tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', t1 - t0)
					_ = util.load_ckpt(self._saver, self._sess)
					t0 = time.time()
示例#5
0
    def decode_one_question(self, batch):
        """Decode a single batch and return the first example's article, reference and prediction.

        Args:
            batch: batch to decode; only example 0 of the model output is used.

        Returns:
            ``(original_article, original_abstract, decoded_text)`` for example 0.
            ``decoded_text`` is the decoded words, cut at the first
            ``data.STOP_DECODING``, joined by single spaces.
        """
        # Opening with mode "w" truncates <log_root>/output.txt.  Nothing is written
        # here, so the handle is closed immediately instead of being leaked.
        # NOTE(review): the truncation looks like a leftover -- confirm it is wanted.
        output_path = os.path.join(FLAGS.log_root, "output.txt")
        with codecs.open(output_path, "w", "utf8"):
            pass
        print(self._batcher.c_index)

        result = self.eval_one_batch(self._sess, self._model, self._vocab,
                                     batch)

        i = 0  # only the first example of the batch is decoded
        out_words = data.outputids2words(result[i], self._model._vocab_out,
                                         batch.art_oovs[i])
        if data.STOP_DECODING in out_words:
            out_words = out_words[:out_words.index(data.STOP_DECODING)]
        output_now = " ".join(out_words)

        return batch.original_articles[i], batch.original_abstracts[i], output_now
示例#6
0
    def predict(self, input_data):
        """Decode one batch with beam search and return the best hypothesis as a string.

        Args:
            input_data: the batch to decode.  When ``FLAGS.pointer_gen`` is set,
                its ``art_oovs[0]`` is used to map copied-token ids back to words.

        Returns:
            The best hypothesis' words (first token dropped, truncated at the
            first ``data.STOP_DECODING``) joined with single spaces.  Words that
            are not ``str`` are decoded as UTF-8 first.
        """
        batch = input_data

        best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)

        # Extract the output ids from the hypothesis and convert back to words
        output_ids = [int(t) for t in best_hyp.tokens[1:]]
        oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
        decoded_words = data.outputids2words(output_ids, self._vocab, oovs)

        # Remove the [STOP] token and everything after it, if present
        if data.STOP_DECODING in decoded_words:
            decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]

        # Some vocab entries may be bytes.  isinstance() keeps str subclasses
        # as-is instead of passing them to str(..., encoding=...), which raises.
        decoded_words = [w if isinstance(w, str) else str(w, encoding='utf-8')
                         for w in decoded_words]

        return " ".join(decoded_words)  # single string
示例#7
0
  def process_one_article(self, original_article_sents, original_abstract_sents,
                          original_selected_ids, output_ids, oovs,
                          attn_dists, p_gens, log_probs, counter):
    """Turn one decoded id sequence into text and, in single_pass mode, save it.

    Nothing is written unless ``FLAGS.single_pass`` is set.  When it is set:

    * the decoded sentences are always written for ROUGE;
    * the attention-visualisation data is also written when
      ``FLAGS.decode_method == 'beam'`` and ``FLAGS.save_vis`` is set;
    * ``self.save_result`` is called when ``FLAGS.save_pkl`` is set.

    Output is verbose unless ``FLAGS.mode == 'eval'``.

    Args:
      original_article_sents: article sentences (joined with spaces for visualisation).
      original_abstract_sents: reference summary sentences.
      original_selected_ids: passed through to ``self.save_result``.
      output_ids: decoded token ids.
      oovs: article OOV words used to map copied-token ids back to words.
      attn_dists, p_gens, log_probs: passed through to ``self.write_for_attnvis``.
      counter: index of this example, used when writing result files.
    """
    decoded_words = data.outputids2words(output_ids, self._vocab, oovs)
    # Remove the [STOP] token and everything after it, if present
    if data.STOP_DECODING in decoded_words:
      decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
    decoded_sents = data.words2sents(decoded_words)

    if not FLAGS.single_pass:
      return

    verbose = FLAGS.mode != 'eval'
    # write ref summary and decoded summary to file, to eval with pyrouge later
    self.write_for_rouge(original_abstract_sents, decoded_sents, counter, verbose)
    if FLAGS.decode_method == 'beam' and FLAGS.save_vis:
      article_withunks = data.show_art_oovs(' '.join(original_article_sents), self._vocab)  # string
      abstract_withunks = data.show_abs_oovs(' '.join(original_abstract_sents), self._vocab, oovs)
      self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words,
                             attn_dists, p_gens, log_probs, counter, verbose)
    if FLAGS.save_pkl:
      self.save_result(original_article_sents, original_abstract_sents,
                       original_selected_ids, decoded_sents, counter, verbose)
  def decode(self):
    """Decode examples until data is exhausted (if single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals.

    In single_pass mode each decoded summary is written for ROUGE.  When the
    batcher returns None, ROUGE is computed and logged to ``self._decode_dir``.
    Otherwise each result is printed and written for the attention visualiser,
    and the latest checkpoint is reloaded once more than SECS_UNTIL_NEW_CKPT
    seconds have passed.  The per-batch decode time is printed after every batch.
    """
    t0 = time.time()
    counter = 0
    step = 0
    while True:
      step += 1
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None:  # finished decoding dataset in single_pass mode
        assert self._model.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
        return

      t_start_decode = time.time()

      oovs = batch.art_oovs[0] if self._model.pointer_gen else None
      original_article = batch.original_articles[0]  # string
      original_abstract = batch.original_abstracts[0]  # string
      original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

      article_withunks = data.show_art_oovs(original_article, self._vocab)  # string
      abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, oovs)  # string

      # Run the model's decoder to get the best Hypothesis
      best_hyp = self._model.model_decode(batch, self._sess, self._vocab)
      print(best_hyp.tokens)

      # Extract the output ids (skipping the first token) and convert back to words
      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      decoded_words = data.outputids2words(output_ids, self._vocab, oovs)

      # Remove the [STOP] token and everything after it, if present
      if data.STOP_DECODING in decoded_words:
        decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
      decoded_output = ' '.join(decoded_words)  # single string

      if self._model.single_pass:
        self.write_for_rouge(original_abstract_sents, decoded_words, counter)  # write ref summary and decoded summary to file, to eval with pyrouge later
        counter += 1  # this is how many examples we've decoded
      else:
        print_results(article_withunks, abstract_withunks, decoded_output)  # log output to screen
        self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens)  # write info to .json file for visualization tool

        # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so reload the latest checkpoint
        t1 = time.time()
        if t1 - t0 > SECS_UNTIL_NEW_CKPT:
          # print() does not apply %-formatting to extra arguments, so format explicitly.
          print('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint' % (t1 - t0))
          _ = util.load_checkpoint(self._saver, self._sess)
          t0 = time.time()

      t_end_decode = time.time()
      print('decode {}-th batch requires {} seconds'.format(step, int(t_end_decode - t_start_decode)))
示例#9
0
    def decode(self, ckpt_file=None):
        """Decode every batch with a loaded checkpoint and write one line per example.

        Each decoded summary (first token dropped, cut at the first
        ``data.STOP_DECODING``) is post-processed by
        ``self.depreciated_processing``.  The result is written as one line to
        the file at ``self._decode_dir``.  A progress message is printed every
        100 examples.

        Args:
            ckpt_file: optional checkpoint, forwarded to ``misc_utils.load_ckpt``
                together with ``self._ckpt_dir`` (presumably ``None`` selects the
                latest checkpoint -- confirm).
        """
        FLAGS = self._FLAGS

        misc_utils.load_ckpt(self._saver, self._sess, self._ckpt_dir,
                             ckpt_file)

        counter = 0
        # The context manager closes the output file even if decoding raises.
        with open(self._decode_dir, "w") as f:
            while True:
                batch = self._batcher.next_batch()  # 1 example repeated across batch
                if batch is None:  # finished decoding dataset in single_pass mode
                    tf.logging.info(
                        "Decoder has finished reading dataset for single_pass.")
                    break

                oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None

                # Run beam search to get best Hypothesis
                best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                       self._vocab, batch, FLAGS)

                # Extract the output ids (skipping the first token) and convert
                # back to words
                output_ids = [int(t) for t in best_hyp.tokens[1:]]
                decoded_words = data.outputids2words(output_ids, self._vocab, oovs)

                # Remove the [STOP] token and everything after it, if present
                if data.STOP_DECODING in decoded_words:
                    decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]

                # No ROUGE files are written here; only the processed text is kept.
                f.write(self.depreciated_processing(decoded_words) + "\n")

                counter += 1
                if counter % 100 == 0:
                    print("%d sentences decoded" % counter)
示例#10
0
    def generator_train_negative_example(self, max_steps=1000):
        """Decode training batches and write each generated review as a negative example.

        For each of the first ``max_steps`` batches of ``self.batches`` (fewer if
        there are fewer batches), ``self._model.run_eval_given_step`` is run.
        For every example in the batch, the generated sentences are assembled
        into one review:

        * each sentence is cut at the first ``data.STOP_DECODING``; sentences
          shorter than 2 tokens are dropped;
        * a sentence whose distinct tokens overlap the previously kept sentence
          by more than half is dropped as a near-repeat;
        * a ``.`` is appended when the sentence does not end in ``.``, ``!`` or ``?``.

        The joined review is cut at the first ``data.STOP_DECODING_DOCUMENT``.
        ``[UNK]`` tokens and runs of two or more ``"! "`` / ``". "`` are then
        removed, and the result is passed to ``self.write_negtive_to_json``
        together with the original review.

        Args:
            max_steps: maximum number of batches to decode (default 1000, the
                previously hard-coded limit).
        """
        batches = self.batches
        counter = 0  # this is how many examples we've decoded

        # Bounding by len(batches) avoids the IndexError the old fixed
        # `while step < 1000` loop raised when fewer batches were available.
        for step in range(min(max_steps, len(batches))):
            batch = batches[step]
            decode_result = self._model.run_eval_given_step(self._sess, batch)

            for i in range(FLAGS.batch_size):
                original_review = batch.original_review_output[i]  # string
                sentences = []

                for j in range(FLAGS.max_dec_sen_num):
                    # Drop the first id of each sentence (presumably a start token).
                    output_ids = [int(t) for t in decode_result['generated'][i][j]][1:]
                    words = data.outputids2words(output_ids, self._vocab, None)
                    # Remove the [STOP] token and everything after it, if present
                    if data.STOP_DECODING in words:
                        words = words[:words.index(data.STOP_DECODING)]

                    if len(words) < 2:
                        continue
                    if sentences:
                        # Skip near-repeats of the previously kept sentence.
                        prev_tokens = set(sentences[-1].split())
                        cur_tokens = set(words)
                        if len(prev_tokens & cur_tokens) > 0.5 * len(cur_tokens):
                            continue
                    if words[-1] not in ('.', '!', '?'):
                        words.append('.')
                    sentences.append(' '.join(words).strip())

                review = ' '.join(sentences).strip()
                stop_idx = review.find(data.STOP_DECODING_DOCUMENT)
                if stop_idx != -1:
                    review = review[:stop_idx]
                review = review.replace("[UNK] ", "").replace("[UNK]", "")
                review, _ = re.subn(r"(! ){2,}", "", review)
                review, _ = re.subn(r"(\. ){2,}", "", review)

                # NOTE(review): the test-set variant passes an extra leading
                # argument (the original review inputs) to write_negtive_to_json;
                # confirm which signature is correct.
                self.write_negtive_to_json(original_review, review, counter,
                                           self.train_sample_whole_positive_dir,
                                           self.train_sample_whole_negative_dir)

                counter += 1
    def generator_test_negative_example(self):
        """Greedily decode the test batches and write each generated review as a negative example.

        For every batch in ``self.test_batches``, ``self._model.greedy_example``
        is run.  Each example's id sequence is turned into a review:

        * a leading ``[STOPDOC]`` token is dropped;
        * the sequence is cut at the first ``data.STOP_DECODING``; outputs
          shorter than 2 tokens give an empty review;
        * a ``.`` is appended unless it already ends in ``.``, ``!`` or ``?``.

        The review is cut at the first ``data.STOP_DECODING_DOCUMENT``.
        ``[UNK]`` tokens and runs of two or more ``"! "`` / ``". "`` are then
        removed, and the result is written with ``self.write_negtive_to_json``.
        """
        counter = 0  # this is how many examples we've decoded
        batches = self.test_batches
        print(len(batches))

        for batch in batches:
            decode_result = self._model.greedy_example(self._sess, batch)

            for i in range(FLAGS.batch_size):
                original_review = batch.original_review_output[i]  # string
                review = ''

                # Every element of 'Greedy_outputs'[i] is converted with int(), so
                # it is one flat id sequence per example.  The former loop over
                # range(FLAGS.max_dec_sen_num) never used its index: it re-decoded
                # this same sequence each time, and the overlap check discarded
                # every repeat.  Decode once; nothing is decoded when
                # max_dec_sen_num is 0, as before.
                if FLAGS.max_dec_sen_num > 0:
                    output_ids = [int(t) for t in decode_result['Greedy_outputs'][i]]
                    words = data.outputids2words(output_ids, self._vocab, None)
                    # Guard against an empty decode before peeking at words[0].
                    if words and words[0] == '[STOPDOC]':
                        words = words[1:]
                    # Remove the [STOP] token and everything after it, if present
                    if data.STOP_DECODING in words:
                        words = words[:words.index(data.STOP_DECODING)]
                    if len(words) >= 2:
                        if words[-1] not in ('.', '!', '?'):
                            words.append('.')
                        review = ' '.join(words).strip()

                stop_idx = review.find(data.STOP_DECODING_DOCUMENT)
                if stop_idx != -1:
                    review = review[:stop_idx]
                review = review.replace("[UNK] ", "").replace("[UNK]", "")
                review, _ = re.subn(r"(! ){2,}", "", review)
                review, _ = re.subn(r"(\. ){2,}", "", review)

                self.write_negtive_to_json(batch.original_review_inputs[i], original_review, review, counter,
                                           self.test_sample_whole_positive_dir,
                                           self.test_sample_whole_negative_dir)

                counter += 1
示例#12
0
    def pair_wise_decode(self):
        """Decode all pair-wise batches, write predictions to a file and print BLEU / accuracy.

        Each decoded example (cut at the first ``data.STOP_DECODING``) is written
        to ``<FLAGS.data_path>/output.txt`` as ``article<TAB>prediction``.  Once
        ``next_pairwised_decode_batch()`` returns None, BLEU and accuracy of the
        predictions are computed against the reference abstracts and printed.
        """
        f = os.path.join(FLAGS.data_path, "output.txt")
        output_result = []
        list_of_reference = []
        # The context manager closes the output file even if decoding raises.
        with codecs.open(f, "w", "utf8") as outputfile:
            while True:
                batch = self._batcher.next_pairwised_decode_batch()  # 1 example repeated across batch
                if batch is None:  # finished decoding dataset in single_pass mode
                    logging.info("eval_finished")
                    break
                print(self._batcher.c_index)

                result = self.eval_one_batch(self._sess, self._model, self._vocab,
                                             batch)

                for i, instance in enumerate(result):
                    if i == len(batch.art_oovs):
                        break
                    # Presumably real_length counts the non-padding examples -- confirm.
                    if i >= batch.real_length:
                        print("eval done with {} isntances".format(
                            len(output_result)))
                        break
                    out_words = data.outputids2words(instance,
                                                     self._model._vocab_out,
                                                     batch.art_oovs[i])
                    if data.STOP_DECODING in out_words:
                        out_words = out_words[:out_words.index(data.STOP_DECODING)]

                    output_now = " ".join(out_words)
                    output_result.append(output_now)
                    list_of_reference.append([batch.original_abstracts[i].strip()])

                    outputfile.write(batch.original_articles[i] + '\t' +
                                     output_now + '\n')

        bleu = matrix.bleu_score(list_of_reference, output_result)
        acc = matrix.compute_acc(list_of_reference, output_result)

        print("bleu : {}   acc : {}".format(bleu, acc))
        return
示例#13
0
def compute_reward(batch, decode_batch, vocab, mode, use_cuda):
    """Score sampled decodes with ROUGE-2 and turn the scores into rewards.

    Args:
        batch: input batch; ``batch.original_abstracts`` supplies the reference
            strings and ``batch.art_oovs`` the per-example OOV lists (used only
            when ``config.pointer_gen`` is set).
        decode_batch: tensor of sampled ids indexed ``[batch, sample, step]``
            (B x S x L). The first step of every sample is dropped before the
            ids are mapped back to words.
        vocab: vocabulary passed to ``data.outputids2words``.
        mode: ``'MLE'`` returns all-ones weights and a zero baseline without
            scoring anything; any other value returns baseline-subtracted
            rewards clamped at zero.
        use_cuda: move the returned weight/reward tensors to the GPU.

    Returns:
        ``(rewards, baseline)``: ``rewards`` is a
        ``(config.batch_size, config.sample_size)`` tensor; ``baseline`` is the
        mean of the per-example average ROUGE-2 rewards (``torch.zeros(1)`` in
        MLE mode).
    """
    ones = torch.ones((config.batch_size, config.sample_size))
    if use_cuda:
        ones = ones.cuda()
    if mode == 'MLE':
        # The rewards are not used in MLE mode, so skip the per-sample ROUGE work.
        return ones, torch.zeros(1)

    target_sents = batch.original_abstracts  # list of reference strings
    # Back to CPU; drop the first decoding step of every sample.
    output_ids = decode_batch.cpu().numpy()[:, :, 1:]

    all_rewards = torch.zeros((config.batch_size, config.sample_size))  # B x S
    if use_cuda:
        all_rewards = all_rewards.cuda()

    for i in range(config.batch_size):
        art_oovs = batch.art_oovs[i] if config.pointer_gen else None
        for j in range(config.sample_size):
            words = data.outputids2words(list(output_ids[i, j, :]), vocab, art_oovs)
            # Remove the [STOP] token (and everything after it), if present.
            if data.STOP_DECODING in words:
                words = words[:words.index(data.STOP_DECODING)]
            all_rewards[i, j] = rouge_2(target_sents[i], ' '.join(words))

    # Per-example mean reward, broadcast back to B x S.
    batch_avg_reward = torch.mean(all_rewards, dim=1, keepdim=True) * ones
    # If every sample already equals its row mean, keep the raw rewards
    # instead of collapsing them all to zero.
    if not torch.equal(all_rewards, batch_avg_reward):
        all_rewards = all_rewards - batch_avg_reward
    # Only samples that beat the baseline get a positive reward.
    all_rewards = all_rewards.clamp(min=0)
    return all_rewards, batch_avg_reward.mean()
  def _decode(self):
    """Decode batches with beam search, one example per batch.

    In ``FLAGS.onetime`` mode each decoded summary is written with
    ``_write_for_rouge``. When the batcher is exhausted, ROUGE is computed over
    ``self.rouge_ref_dir`` / ``self.rouge_dec_dir`` and logged to
    ``self.decode_dir``, and the method returns.

    Otherwise each result is printed and written for attention visualisation,
    and the checkpoint is reloaded whenever ``SECS_UNTIL_NEW_CKPT`` seconds
    have passed. Exhausting the batcher in that mode trips the assertion.
    """
    t0 = time.time()
    counter = 0  # examples written for ROUGE so far
    while True:
      batch = self.batcher._next_batch()
      if batch is None:
        assert FLAGS.onetime, "Dataset exhausted, but we are not in onetime mode"
        print('INFO: Decoder has finished reading dataset for onetime.')
        print('INFO: Output has been saved in {} and {}, start ROUGE eval...'.format(self.rouge_ref_dir, self.rouge_dec_dir))
        results_dict = rouge_eval(self.rouge_ref_dir, self.rouge_dec_dir)
        rouge_log(results_dict, self.decode_dir)
        return

      original_article = batch.original_articles[0]
      original_abstract = batch.original_abstracts[0]
      art_oovs = batch.art_oovs[0] if FLAGS.pointer else None

      article_withunks = data.show_art_oovs(original_article, self.vocab)
      abstract_withunks = data.show_abs_oovs(original_abstract, self.vocab, art_oovs)

      best_hyp = beam_search.run_beam_search(self.sess, self.model, self.vocab, batch)

      # Skip the first hypothesis token, then map ids back to words.
      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      decoded_words = data.outputids2words(output_ids, self.vocab, art_oovs)

      # Truncate at the first end-of-decoding token, if any.
      if data.DECODING_END in decoded_words:
        decoded_words = decoded_words[:decoded_words.index(data.DECODING_END)]
      decoded_output = ' '.join(decoded_words)

      if FLAGS.onetime:
        self._write_for_rouge(original_abstract, decoded_words, counter)
        counter += 1
        continue

      print("")
      print('INFO: ARTICLE: {}'.format(article_withunks))
      print('INFO: REFERENCE SUMMARY: {}'.format(abstract_withunks))
      print('INFO: GENERATED SUMMARY: {}'.format(decoded_output))
      print("")
      self._write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.pointers)

      t1 = time.time()
      if t1 - t0 > SECS_UNTIL_NEW_CKPT:
        print('INFO: Decoding for {} seconds, loading new checkpoint'.format(t1 - t0))
        while True:
          try:
            # NOTE(review): `train_dir` is not defined in this method; it must
            # be a module-level name -- confirm.
            ckpt_state = tf.train.get_checkpoint_state(train_dir)
            print('INFO: Loading checkpoint {}'.format(ckpt_state.model_checkpoint_path))
            self.saver.restore(self.sess, ckpt_state.model_checkpoint_path)
            break
          except Exception:
            # Not a bare except: KeyboardInterrupt/SystemExit must still be able
            # to break out of this retry loop.
            print('ERROR: Failed to restore checkpoint: {}, sleep for {} secs'.format(train_dir, 10))
            time.sleep(10)
        t0 = time.time()
    def generator_sample_example(self, positive_dir, negative_dir, num_batch):
        """Decode ``num_batch`` batches and write the generated reviews as negatives.

        Both ``positive_dir`` and ``negative_dir`` are recreated empty first.

        Each generated sentence is processed as follows:

        * Its first id is skipped and it is cut at the first [STOP] token.
        * It is dropped if it has fewer than 2 words, or if its distinct words
          overlap the previously kept sentence by more than half.

        The kept sentences are joined and cut at the first
        ``data.STOP_DECODING_DOCUMENT`` marker. [UNK] tokens are removed and
        runs of "! " / ". " are collapsed. The result is written with
        ``self.write_negtive_temp_to_json``.

        ``self.current_batch`` advances (wrapping around) once per decoded
        batch. Finally, diversity is evaluated over ``negative_dir``.
        """
        self.temp_positive_dir = positive_dir
        self.temp_negative_dir = negative_dir

        # Recreate both directories empty (same end state as the previous
        # mkdir -> rmtree -> mkdir sequence).
        for temp_dir in (self.temp_positive_dir, self.temp_negative_dir):
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
            os.mkdir(temp_dir)

        counter = 0  # number of examples written so far
        for _ in range(num_batch):
            batch = self.batches[self.current_batch]
            decode_result = self._model.run_eval_given_step(self._sess, batch)
            for i in range(FLAGS.batch_size):
                original_review = batch.original_review_output[i]
                kept_sentences = []
                for j in range(FLAGS.max_dec_sen_num):
                    # Skip the first id of each generated sentence.
                    output_ids = [int(t) for t in decode_result['generated'][i][j]][1:]
                    decoded_words = data.outputids2words(output_ids, self._vocab, None)
                    # Remove the [STOP] token (and everything after it), if present.
                    if data.STOP_DECODING in decoded_words:
                        decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
                    if len(decoded_words) < 2:
                        continue
                    # Drop a sentence that mostly repeats the previously kept one.
                    if kept_sentences:
                        prev_words = set(kept_sentences[-1].split())
                        cur_words = set(decoded_words)
                        if len(prev_words & cur_words) > 0.5 * len(cur_words):
                            continue
                    kept_sentences.append(' '.join(decoded_words).strip())

                review = ' '.join(kept_sentences).strip()
                stop_idx = review.find(data.STOP_DECODING_DOCUMENT)
                if stop_idx != -1:
                    review = review[:stop_idx]
                review = review.replace("[UNK] ", "").replace("[UNK]", "")
                review = re.sub(r"(! ){2,}", "! ", review)
                review = re.sub(r"(\. ){2,}", ". ", review)
                self.write_negtive_temp_to_json(original_review, review, counter)
                counter += 1

            # Move to the next batch, wrapping around at the end.
            self.current_batch += 1
            if self.current_batch >= len(self.batches):
                self.current_batch = 0

        eva = Evaluate()
        eva.diversity_evaluate(negative_dir + "/*")
示例#16
0
  def decode(self):
    """Beam-search decode every batch and return the generated summaries.

    Beam search runs with ``self.beam_size``, ``self.max_dec_steps`` and
    ``self.min_dec_steps``. In ``self.single_pass`` mode each decoded summary
    is appended to the returned list and the running count is printed. When
    the batcher is exhausted, the list is returned; ROUGE evaluation is
    intentionally disabled.

    Outside single-pass mode the decoded output is discarded. Exhausting the
    batcher in that mode trips the assertion. No checkpoint reloading happens
    here.

    Returns:
      list of decoded summary strings, in batch order.
    """
    out_num = 0  # number of summaries collected so far
    summaries = []
    while True:
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None:  # finished decoding dataset in single_pass mode
        assert self.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        # print() with a single argument behaves the same on Python 2 and 3.
        print(out_num)
        return summaries

      # Run beam search to get the best hypothesis.
      best_hyp = beam_search.run_beam_search(
          self._sess, self._model, self._vocab, batch,
          self.beam_size, self.max_dec_steps, self.min_dec_steps)

      # Skip the first hypothesis token, then map ids back to words.
      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      decoded_words = data.outputids2words(
          output_ids, self._vocab, (batch.art_oovs[0] if self.pointer_gen else None))

      # Remove the [STOP] token (and everything after it), if present.
      if data.STOP_DECODING in decoded_words:
        decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
      decoded_output = ' '.join(decoded_words)  # single string

      if self.single_pass:
        summaries.append(decoded_output)
        out_num += 1
        print(out_num)
示例#17
0
    def decode(self):
        """Decode the whole dataset with beam search and score it with BLEU.

        Batches are pulled from ``self.batcher`` until it returns ``None``. For
        each batch:

        * The best summary's ids (minus the first token) are mapped back to
          words and cut at the first [STOP] token.
        * The words are scored with bigram ``sentence_bleu`` (weights 0.5/0.5)
          against the first reference sentence.
        * The result is written out with ``write_for_result``.

        A timing line is printed every 1000 examples. For
        ``self.data_class == 'val'`` the mean BLEU is printed and appended to
        ``self._result_path``.
        """
        tick = time.time()
        scores = []
        for done, batch in enumerate(iter(self.batcher.next_batch, None), 1):
            hypothesis = self.beam_search(batch)

            token_ids = [int(t) for t in hypothesis.tokens[1:]]
            oovs = batch.art_oovs[0] if config.pointer_gen else None
            words = data.outputids2words(token_ids, self.vocab, oovs)
            # Remove the [STOP] token (and everything after it), if present.
            if data.STOP_DECODING in words:
                words = words[:words.index(data.STOP_DECODING)]

            article = batch.original_articles[0]
            reference_sents = batch.original_abstracts_sents[0]
            # NOTE(review): only the first reference sentence is used for BLEU.
            reference = reference_sents[0].strip().split()
            scores.append(nltk.translate.bleu_score.sentence_bleu(
                [reference], words, weights=(0.5, 0.5)))

            write_for_result(article, reference_sents, words,
                             self._result_path, self.data_class)

            if done % 1000 == 0:
                print('%d example in %d sec' % (done, time.time() - tick))
                tick = time.time()

        # ROUGE scoring is disabled because it needs `pyrouge`. To enable it,
        # call write_for_rouge(...) per example, then
        # rouge_log(rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir),
        #           self._decode_dir).

        if self.data_class == 'val':
            average = np.mean(scores)
            print('Average BLEU score:', average)
            with open(self._result_path, "a") as f:
                print('Average BLEU score:', average, file=f)
示例#18
0
  def decode(self):
    """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals.

    When FLAGS.ac_training is set, beam search also receives the DQN model,
    session and graph. In single-pass mode each decoded summary is written for
    ROUGE with a running index that starts at FLAGS.decode_after. Once the
    batcher is exhausted, ROUGE is computed and logged and the method returns.
    Otherwise results are printed and written for attention visualisation.
    """
    t0 = time.time()  # time of the last checkpoint (re)load
    counter = FLAGS.decode_after  # presumably lets a resumed run continue numbering -- confirm
    while True:
      # NOTE(review): this resets the default graph on every iteration.
      # self._sess / self._model keep references to their own graph, so this
      # appears to affect only ops created afterwards -- confirm it is intended.
      tf.reset_default_graph()
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None: # finished decoding dataset in single_pass mode
        assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
        return

      original_article = batch.original_articles[0]  # string
      original_abstract = batch.original_abstracts[0]  # string
      original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

      article_withunks = data.show_art_oovs(original_article, self._vocab) # string
      abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string

      # Run beam search to get best Hypothesis (with the DQN critic when AC training is on)
      if FLAGS.ac_training:
        best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch, self._dqn, self._dqn_sess, self._dqn_graph)
      else:
        best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)
      # Extract the output ids from the hypothesis (skipping the first token) and convert back to words
      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))

      # Remove the [STOP] token from decoded_words, if necessary
      try:
        fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
        decoded_words = decoded_words[:fst_stop_idx]
      except ValueError:
        decoded_words = decoded_words
      decoded_output = ' '.join(decoded_words) # single string

      if FLAGS.single_pass:
        self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
        counter += 1 # index of the next example to write
      else:
        print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen
        self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool

        # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so reload the checkpoint in place and keep decoding
        t1 = time.time()
        if t1-t0 > SECS_UNTIL_NEW_CKPT:
          tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', t1-t0)
          _ = util.load_ckpt(self._saver, self._sess, FLAGS.decode_from)
          t0 = time.time()
示例#19
0
    def decode2(self, batcher):
        """Decode a single batch from ``batcher`` and return the summary string.

        Only one batch is decoded. Its best beam-search hypothesis is mapped
        back to words, cut at the first [STOP] token, shown with
        ``print_results`` and returned as one space-joined string.

        If the batcher is already exhausted (allowed only in
        ``FLAGS.single_pass`` mode), ROUGE is run on ``self._rouge_ref_dir`` /
        ``self._rouge_dec_dir`` instead and ``None`` is returned.

        Args:
            batcher: object whose ``next_batch()`` yields the batch to decode
                (presumably one example repeated across the batch -- confirm).

        Returns:
            The decoded summary, or ``None`` when no batch was available.
        """
        batch = batcher.next_batch()
        if batch is None:  # finished decoding dataset in single_pass mode
            assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
            tf.logging.info(
                "Decoder has finished reading dataset for single_pass.")
            tf.logging.info(
                "Output has been saved in %s and %s. Now starting ROUGE eval...",
                self._rouge_ref_dir, self._rouge_dec_dir)
            results_dict = rouge_eval(self._rouge_ref_dir,
                                      self._rouge_dec_dir)
            rouge_log(results_dict, self._decode_dir)
            return None

        art_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
        article_withunks = data.show_art_oovs(batch.original_articles[0],
                                              self._vocab)  # string
        abstract_withunks = data.show_abs_oovs(batch.original_abstracts[0],
                                               self._vocab, art_oovs)  # string

        # Run beam search to get the best hypothesis.
        best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                               self._vocab, batch)

        # Skip the first hypothesis token, then map ids back to words.
        output_ids = [int(t) for t in best_hyp.tokens[1:]]
        decoded_words = data.outputids2words(output_ids, self._vocab, art_oovs)

        # Remove the [STOP] token (and everything after it), if present.
        if data.STOP_DECODING in decoded_words:
            decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
        decoded_output = ' '.join(decoded_words)  # single string

        print_results(article_withunks, abstract_withunks,
                      decoded_output)  # log output to screen
        return decoded_output
    def decode(self, batch):
        """Return the best beam-search summary for ``batch`` as one string.

        The hypothesis ids (minus the first token) are mapped back to words,
        using the batch's first OOV list when ``config.pointer_gen`` is set,
        and truncated at the first [STOP] token.

        NOTE(review): the words are concatenated with no separator. This is
        presumably intended for character-level output; confirm.
        """
        hypothesis = self.beam_search(batch)
        token_ids = list(map(int, hypothesis.tokens[1:]))
        oovs = batch.art_oovs[0] if config.pointer_gen else None
        words = data.outputids2words(token_ids, self.vocab, oovs)
        if data.STOP_DECODING in words:
            words = words[:words.index(data.STOP_DECODING)]
        return "".join(words)
示例#21
0
    def seq_output_to_batch(self, decode_result_seq, batch):
        """Append each example's generated sentence to ``batch`` as a new encoder sentence.

        For every example ``i`` the generated ids are processed as follows:

        * They are mapped to words and cut at the first [STOP] token.
        * The text is cut at the first ``data.STOP_DECODING_DOCUMENT`` marker
          and cleaned: [UNK] tokens are removed and runs of "! " / ". " are
          collapsed.
        * The text is re-tokenised, truncated to ``FLAGS.max_enc_steps`` and
          padded with the PAD id.

        The ids are written into the next sentence slot of
        ``batch.enc_batch[i]``, and ``batch.enc_lens`` and
        ``batch.enc_sen_lens`` are updated in place.

        Returns:
            The same ``batch`` object, mutated.
        """
        pad_id = self._vocab.word2id(data.PAD_TOKEN)
        for i in range(FLAGS.batch_size):
            output_ids = [int(t) for t in decode_result_seq['generated'][i]]
            decoded_words = data.outputids2words(output_ids, self._vocab, None)

            # Remove the [STOP] token (and everything after it), if present.
            if data.STOP_DECODING in decoded_words:
                decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
            decoded_output = ' '.join(decoded_words).strip()  # single string

            stop_idx = decoded_output.find(data.STOP_DECODING_DOCUMENT)
            if stop_idx != -1:
                decoded_output = decoded_output[:stop_idx]
            # Clean up unconditionally. This was previously mis-indented under
            # `except ValueError`, so it only ran when no document-stop marker
            # was found.
            decoded_output = decoded_output.replace("[UNK] ", "")
            decoded_output = decoded_output.replace("[UNK]", "")
            decoded_output = re.sub(r"(! ){2,}", "! ", decoded_output)
            decoded_output = re.sub(r"(\. ){2,}", ". ", decoded_output)

            sen_words = tokenize.word_tokenize(decoded_output.strip())
            sen_words = sen_words[:FLAGS.max_enc_steps]
            # OOVs are represented by the id for the UNK token.
            enc_ids = [self._vocab.word2id(w) for w in sen_words]

            batch.enc_lens[i] = batch.enc_lens[i] + 1
            slot = batch.enc_lens[i] - 1  # index of the newly appended sentence
            batch.enc_sen_lens[i][slot] = len(enc_ids)

            enc_ids += [pad_id] * (FLAGS.max_enc_steps - len(enc_ids))
            batch.enc_batch[i, slot, :] = enc_ids

        return batch
示例#22
0
def tokens_to_continuous_text(tokens, vocab, art_oovs):
    """Convert token ids to a space-joined string and check it round-trips.

    Args:
        tokens: ids passed to ``data.outputids2words``.
        vocab: vocabulary for the id -> word mapping.
        art_oovs: article OOV list for extended-vocabulary ids (or None).

    Returns:
        The words joined by single spaces.

    Raises:
        Exception: if splitting the joined text on ' ' does not give back one
            piece per word, which happens when some word itself contains a
            space. The word/piece pairs are printed first for debugging.
    """
    words = data.outputids2words(tokens, vocab, art_oovs)
    text = ' '.join(words)
    split_text = text.split(' ')
    if len(split_text) != len(words):
        for word, piece in zip(words, split_text):
            try:
                print('%s\t%s' % (word, piece))
            except Exception:  # e.g. encoding errors when printing to a terminal
                print('FAIL\tFAIL')
        # Report the count that was actually compared (split on ' ', not split()).
        raise Exception('text (' + str(len(split_text)) +
                        ') does not have the same number of tokens as words (' + str(len(words)) + ')')

    return text
示例#23
0
    def output_to_batch(self, current_batch, result):
        """Build an SRL batch pairing each original review with its generated text.

        For every example the generated ids are mapped to words and cut at the
        first [STOP] token and at the first ``data.STOP_DECODING_DOCUMENT``
        marker. The text is then cleaned: [UNK] tokens are removed and runs of
        "! " / ". " are collapsed to a single occurrence.

        Examples whose cleaned output is empty get the placeholder text "was"
        and mask value 0. All others get mask value 1.

        Returns:
            ``(Srl_Batch, decode_mask)``, where ``decode_mask`` holds one 0/1
            entry per example.
        """
        srl_vocab = self._srl_batcher._vocab
        srl_hps = self._srl_batcher._hps
        srl_example_list = []
        decode_mask = []

        for i in range(FLAGS.batch_size):
            output_ids = [int(t) for t in result['generated'][i]]
            decoded_words = data.outputids2words(output_ids,
                                                 self._batcher._vocab, None)
            # Remove the [STOP] token (and everything after it), if present.
            if data.STOP_DECODING in decoded_words:
                decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
            decoded_output = ' '.join(decoded_words).strip()  # single string

            stop_idx = decoded_output.find(data.STOP_DECODING_DOCUMENT)
            if stop_idx != -1:
                decoded_output = decoded_output[:stop_idx]
            decoded_output = decoded_output.replace("[UNK] ", "")
            decoded_output = decoded_output.replace("[UNK]", "")
            # Collapse repeated "! " / ". " to one occurrence. Replacing the run
            # with "" deleted the punctuation entirely, while a single "! "
            # was kept.
            decoded_output = re.sub(r"(! ){2,}", "! ", decoded_output)
            decoded_output = re.sub(r"(\. ){2,}", ". ", decoded_output)

            review = current_batch.original_review_outputs[i]
            if decoded_output.strip() == "":
                # Placeholder text so the SRL example is never empty; flagged with mask 0.
                srl_example_list.append(Srl_Example(review, "was", srl_vocab, srl_hps))
                decode_mask.append(0)
            else:
                srl_example_list.append(Srl_Example(review, decoded_output, srl_vocab, srl_hps))
                decode_mask.append(1)

        return Srl_Batch(srl_example_list, srl_hps, srl_vocab), decode_mask
示例#24
0
    def decode_one_batch(self, batch, withRouge=True):
        """Beam-search decode ``batch`` and record or display the result.

        In ``FLAGS.single_pass`` mode the result is written with
        ``write_for_flink`` (uuid, article, decoded words, reference sentences)
        and ``self._counter`` is incremented. Otherwise the article, reference
        and decoded summary are printed, and the attention data is written for
        visualisation. ``withRouge`` is currently unused.
        """
        oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
        article = batch.original_articles[0]  # string
        abstract = batch.original_abstracts[0]  # string
        abstract_sents = batch.original_abstracts_sents[0]  # list of strings
        example_id = batch.uuids[0]  # string

        article_unk = data.show_art_oovs(article, self._vocab)
        abstract_unk = data.show_abs_oovs(abstract, self._vocab, oovs)

        hypothesis = beam_search.run_beam_search(self._sess, self._model,
                                                 self._vocab, batch)
        token_ids = [int(tok) for tok in hypothesis.tokens[1:]]
        words = data.outputids2words(token_ids, self._vocab, oovs)
        if data.STOP_DECODING in words:
            words = words[:words.index(data.STOP_DECODING)]

        if not FLAGS.single_pass:
            print_results(article_unk, abstract_unk, ' '.join(words))
            self.write_for_attnvis(article_unk, abstract_unk, words,
                                   hypothesis.attn_dists, hypothesis.p_gens)
            return
        self.write_for_flink(example_id, article, words, abstract_sents)
        self._counter += 1  # number of examples decoded so far
def decode_example(sess, model, vocab, batch, counter, hps):
    """Beam-search decode one batch and return its best hypothesis.

    Returns:
        ``(decoded_words, decoded_output, best_hyp)``:

        * ``decoded_words``: the word list (first token skipped) cut at the
          first [STOP] token.
        * ``decoded_output``: the same words joined by single spaces.
        * ``best_hyp``: the hypothesis returned by
          ``beam_search.run_beam_search``.
    """
    best_hyp = beam_search.run_beam_search(sess, model, vocab, batch, counter, hps)
    token_ids = list(map(int, best_hyp.tokens[1:]))
    article_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
    decoded_words = data.outputids2words(token_ids, vocab, article_oovs)
    if data.STOP_DECODING in decoded_words:
        decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
    return decoded_words, ' '.join(decoded_words), best_hyp
示例#26
0
 def _calc_topic_loss(self, greedy_search_samples, original_topic_dist,
                      art_oovs):
     """Return the mean topic similarity between sampled summaries and articles.

     For each example in the batch:

     * The sample is mapped to words (with that example's ``art_oovs``), cut
       at the first [STOP] token and joined into a string. Samples with fewer
       than 2 words are replaced by "xxx".
     * Its ``lda_distribution`` is compared with the article's entry in
       ``original_topic_dist`` via ``cal_similarity``.

     The similarities are averaged with ``np.mean``.
     """
     def sample_to_text(idx):
         words = data.outputids2words(greedy_search_samples[idx], self._vocab,
                                      art_oovs[idx])
         if data.STOP_DECODING in words:
             words = words[:words.index(data.STOP_DECODING)]
         # Too-short samples get a placeholder before topic inference.
         return ' '.join(words) if len(words) >= 2 else "xxx"

     return np.mean([
         cal_similarity(lda_distribution(sample_to_text(idx)),
                        original_topic_dist[idx])
         for idx in range(self._hps.batch_size)
     ])
  def decode(self):
    """Beam-search decode batches until the batcher returns ``None``.

    In ``FLAGS.single_pass`` mode each decoded word list is passed to
    ``write_for_rouge`` with a running example index.

    Otherwise the article and decoded summary are shown with
    ``print_results``. Once more than ``SECS_UNTIL_NEW_CKPT`` seconds have
    passed since the last (re)load, the latest checkpoint is reloaded with
    ``util.load_ckpt``.

    Returns as soon as the batcher is exhausted; no ROUGE evaluation runs here.
    """
    loaded_at = time.time()
    n_decoded = 0  # examples written in single-pass mode
    while True:
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None:
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        return

      article_withunks = data.show_art_oovs(batch.original_articles[0], self._vocab)

      best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)
      token_ids = [int(t) for t in best_hyp.tokens[1:]]
      oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
      words = data.outputids2words(token_ids, self._vocab, oovs)
      if data.STOP_DECODING in words:
        words = words[:words.index(data.STOP_DECODING)]

      if FLAGS.single_pass:
        self.write_for_rouge(words, n_decoded)
        n_decoded += 1
        continue

      print_results(article_withunks, ' '.join(words))
      elapsed = time.time() - loaded_at
      if elapsed > SECS_UNTIL_NEW_CKPT:
        tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', elapsed)
        _ = util.load_ckpt(self._saver, self._sess)
        loaded_at = time.time()
示例#28
0
  def decode(self):
    """Beam-search decode every batch, printing the results, then run ROUGE.

    For each batch the following are printed:

    * the decoded word list;
    * the summary, cut at the first [STOP] token and joined into a string;
    * the reference sentences (``batch.original_summaries_sents[0]``).

    In ``FLAGS.single_pass`` mode the pair is also written for ROUGE. When the
    batcher returns ``None`` (allowed only in single-pass mode), ROUGE is
    evaluated and logged, and the method returns.
    """
    n_written = 0  # examples written for ROUGE so far
    while True:
      batch = self._batcher.next_batch()

      if batch is None:  # finished decoding dataset in single_pass mode
        assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir), self._decode_dir)
        return

      hypothesis = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)
      token_ids = [int(t) for t in hypothesis.tokens[1:]]
      oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None
      words = data.outputids2words(token_ids, self._vocab, oovs)
      print(words)
      if data.STOP_DECODING in words:
        words = words[:words.index(data.STOP_DECODING)]
      summary = ' '.join(words)
      print(summary)

      reference_sents = batch.original_summaries_sents[0]  # list of strings
      print(reference_sents)
      if FLAGS.single_pass:
        self.write_for_rouge(reference_sents, words, n_written)
        n_written += 1
    '''
示例#29
0
    def run_beam_decoder(self, batch):
        """Beam-search decode ``batch`` and return display strings.

        As a debugging aid, prints the raw best-hypothesis tokens and the
        decoded word list before [STOP] truncation.

        Returns:
            ``(article_withunks, abstract_withunks, decoded_output)`` for the
            first example of the batch.
        """
        # Keep the original's strict `== True` comparison.
        use_oovs = self.hp.pointer_gen == True
        oovs = batch.art_oovs[0] if use_oovs else None

        article_withunks = data.show_art_oovs(batch.original_articles[0], self.vocab)
        abstract_withunks = data.show_abs_oovs(batch.original_abstracts[0], self.vocab, oovs)

        best = beam_search.run_beam_search(self.sess, self.model, self.vocab, batch)
        print(best.tokens)
        words = data.outputids2words([int(t) for t in best.tokens[1:]], self.vocab, oovs)
        print(words)

        if data.STOP_DECODING in words:
            words = words[:words.index(data.STOP_DECODING)]
        return article_withunks, abstract_withunks, ' '.join(words)
示例#30
0
    def decode(self):
        """Decode the whole dataset, write it for ROUGE and report mean BLEU.

        Batches are pulled from ``self.batcher`` until it returns ``None``. For
        each batch:

        * The best summary's ids (minus the first token) are mapped back to
          words and cut at the first [STOP] token.
        * The words are scored with bigram ``sentence_bleu`` (weights 0.5/0.5)
          against the first reference sentence.
        * The words are written with ``write_for_rouge`` under a running index.

        A timing line is printed every 1000 examples, and the average BLEU is
        printed at the end.
        """
        tick = time.time()
        scores = []
        for index, batch in enumerate(iter(self.batcher.next_batch, None)):
            hypothesis = self.beam_search(batch)

            token_ids = [int(t) for t in hypothesis.tokens[1:]]
            oovs = batch.art_oovs[0] if config.pointer_gen else None
            words = data.outputids2words(token_ids, self.vocab, oovs)
            # Remove the [STOP] token (and everything after it), if present.
            if data.STOP_DECODING in words:
                words = words[:words.index(data.STOP_DECODING)]

            reference_sents = batch.original_abstracts_sents[0]
            # NOTE(review): only the first reference sentence is used for BLEU.
            reference = reference_sents[0].strip().split()
            scores.append(nltk.translate.bleu_score.sentence_bleu(
                [reference], words, weights=(0.5, 0.5)))

            write_for_rouge(reference_sents, words, index,
                            self._rouge_ref_dir, self._rouge_dec_dir)

            done = index + 1
            if done % 1000 == 0:
                print('%d example in %d sec' % (done, time.time() - tick))
                tick = time.time()

        print('Average BLEU score:', np.mean(scores))
        '''
def run_flip(model, batcher, vocab):
  """Repeatedly decode batches from the latest checkpoint and log the results.

  Each iteration does the following:
  - reload a checkpoint via ``util.load_ckpt``;
  - fetch the next batch and run ``model.run_decode`` on it;
  - for every example, log the reference summary and the generated summary,
    cutting the generated ids at the position of the first [STOP] in that
    example's ``target_batch`` row;
  - log the loss and id statistics.

  The loop never terminates. Nothing is saved and no summaries are written.

  Args:
    model: model exposing ``build_graph()``, ``run_decode()`` and ``_vocab``.
    batcher: object providing ``next_batch()``.
    vocab: unused; kept for interface compatibility.
  """
  model.build_graph() # build the graph
  saver = tf.train.Saver(max_to_keep=3) # only used below to restore checkpoints
  sess = tf.Session(config=util.get_config())
  # Loop-invariant: id of the [STOP] token in the model's vocabulary.
  stop_id = model._vocab.word2id('[STOP]')

  while True:
    _ = util.load_ckpt(saver, sess) # load a checkpoint
    batch = batcher.next_batch() # get the next batch

    t0 = time.time()
    results = model.run_decode(sess, batch)
    t1 = time.time()
    tf.logging.info('seconds for batch: %.2f', t1 - t0)
    loss = results['loss']

    for b in range(len(batch.original_abstracts)):
      original_abstract = batch.original_abstracts[b] # string
      # Length to keep: index of the first [STOP] in the target row, or the
      # full row length when the target contains no [STOP].
      stop_positions = np.where(batch.target_batch[b, :] == stop_id)[0]
      if len(stop_positions):
        fst_stop_idx = stop_positions[0]
      else:
        fst_stop_idx = len(batch.target_batch[b, :])

      abstract_withunks = data.show_abs_oovs(original_abstract, model._vocab, None)  # string
      tf.logging.info('REFERENCE SUMMARY: %s', abstract_withunks)

      output_ids = [int(x) for x in results['ids'][:, b, 0]][:fst_stop_idx]
      decoded_words = data.outputids2words(output_ids, model._vocab, None)
      decoded_output = ' '.join(decoded_words)
      tf.logging.info('GENERATED SUMMARY: %s', decoded_output)
      # max() of an empty list raises ValueError; this happens when the
      # target starts with [STOP]. Log -1 in that case.
      max_flip = max(output_ids) if output_ids else -1
      tf.logging.info('loss %.2f max flip %d dec %d target %d', loss, max_flip,
                      batch.dec_batch.max(), batch.target_batch.max())
示例#32
0
    def decode(self):
        """Beam-search decode a single batch and write the summary to stdout.

        One batch is fetched from ``self._batcher``; per the original comment,
        it is one example repeated across the batch. The best beam-search
        hypothesis is converted back to words: the first token is dropped, and
        so is everything from the first [STOP] symbol on. The words are written
        space-separated to ``sys.stdout`` with no trailing newline.

        Returns:
            None. If ``next_batch()`` returns None (no data), nothing is
            decoded or written.
        """
        batch = self._batcher.next_batch()  # 1 example repeated across batch
        if batch is None:
            # Without this guard the attribute accesses below raise
            # AttributeError.
            return

        # Per-article OOV list passed to outputids2words in pointer-gen mode.
        art_oovs = batch.art_oovs[0] if FLAGS.pointer_gen else None

        # Run beam search to get the best hypothesis.
        best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                               self._vocab, batch)

        # Skip the first token and map the remaining ids back to words.
        output_ids = [int(t) for t in best_hyp.tokens[1:]]
        decoded_words = data.outputids2words(output_ids, self._vocab, art_oovs)

        # Keep only the words before the first [STOP] symbol, if present.
        if data.STOP_DECODING in decoded_words:
            decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]

        sys.stdout.write(' '.join(decoded_words))
示例#33
0
    def compute_BLEU(self, train_step, max_batches=100):
        """Decode test batches and collect cleaned hypotheses and references.

        Runs ``self._model.run_eval_given_step`` on at most ``max_batches``
        batches of ``self.test_batches``, or fewer if there are fewer batches.
        For every example, a generated review is built from the decoded
        sentences:
        - ids are mapped back to words and each sentence is cut at the first
          STOP token;
        - sentences shorter than two words are skipped, as are sentences that
          share more than half of their distinct words with the previously
          kept sentence;
        - missing final punctuation is added;
        - the joined text is cut at ``data.STOP_DECODING_DOCUMENT``;
        - ``[UNK]`` tokens are removed;
        - runs of repeated ``"! "`` / ``". "`` are collapsed to one.

        NOTE(review): the BLEU computation itself is not implemented. It only
        existed as commented-out NLTK ``corpus_bleu`` code, so the collected
        hypotheses/references are currently unused and the method always
        returns 0.

        Args:
            train_step: unused; kept for interface compatibility.
            max_batches: maximum number of test batches to decode (default 100,
                the previously hard-coded value).

        Returns:
            0, as a placeholder for the BLEU score.
        """
        t0 = time.time()
        batches = self.test_batches
        list_hop = []  # generated reviews (hypotheses)
        list_ref = []  # reference reviews

        # Bounded by len(batches): a fixed 100 raised IndexError whenever
        # fewer test batches were available.
        for step in range(min(max_batches, len(batches))):
            batch = batches[step]
            decode_result = self._model.run_eval_given_step(self._sess, batch)

            for i in range(FLAGS.batch_size):
                sentences = []
                for j in range(FLAGS.max_dec_sen_num):
                    # Drop the first id of each generated sentence.
                    output_ids = [int(t) for t in decode_result['generated'][i][j]][1:]
                    decoded_words = data.outputids2words(output_ids, self._vocab, None)
                    # Remove the [STOP] token and everything after it.
                    if data.STOP_DECODING in decoded_words:
                        decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]

                    if len(decoded_words) < 2:
                        continue
                    # Skip a sentence that mostly repeats the last kept one.
                    if sentences:
                        prev_words = set(sentences[-1].split())
                        cur_words = set(decoded_words)
                        if len(prev_words & cur_words) > 0.5 * len(cur_words):
                            continue
                    if decoded_words[-1] not in ('.', '!', '?'):
                        decoded_words.append('.')
                    sentences.append(' '.join(decoded_words).strip())

                review = ' '.join(sentences).strip()
                stop_idx = review.find(data.STOP_DECODING_DOCUMENT)
                if stop_idx != -1:
                    review = review[:stop_idx]
                review = review.replace("[UNK] ", "").replace("[UNK]", "")
                # Collapse repeated punctuation to a single mark. Replacing the
                # run with "" deleted the sentence boundary and merged
                # neighbouring sentences.
                review = re.sub(r"(! ){2,}", "! ", review)
                review = re.sub(r"(\. ){2,}", ". ", review)

                list_hop.append(review)
                list_ref.append(batch.original_review_output[i])

        t1 = time.time()
        tf.logging.info('seconds for test generator: %.3f ', (t1 - t0))
        return 0
示例#34
0
    def generator_sample_example(self, positive_dir, negative_dir, num_batch):
        """Generate reviews for ``num_batch`` batches and write them to JSON.

        Setup: both directories are emptied (created if missing) and remembered
        as ``self.temp_positive_dir`` / ``self.temp_negative_dir``.

        Generation: batches are taken from ``self.batches`` starting at
        ``self.current_batch``, which wraps around to 0 after the last batch.
        Each batch goes through ``self._model.run_eval_given_step``. For every
        example, the decoded sentences are cleaned:
        - each sentence is cut at the first STOP token;
        - sentences shorter than two words, or sharing more than half of their
          distinct words with the previously kept sentence, are dropped;
        - the joined text is cut at ``data.STOP_DECODING_DOCUMENT``;
        - ``[UNK]`` is removed;
        - repeated ``"! "`` / ``". "`` runs are collapsed.
        The result is passed, with the reference review and a running counter,
        to ``self.write_negtive_temp_to_json``.

        Finally ``Evaluate().diversity_evaluate`` is run on the files in
        ``negative_dir``.

        NOTE(review): unlike the other post-processing paths, no final
        punctuation is appended to sentences here; confirm this is intended.
        """
        self.temp_positive_dir = positive_dir
        self.temp_negative_dir = negative_dir

        # Empty both directories. Remove first, then create, so this also works
        # when the two paths are equal or nested: the previous
        # mkdir/rmtree/rmtree sequence raised FileNotFoundError for equal paths.
        for directory in (self.temp_negative_dir, self.temp_positive_dir):
            if os.path.exists(directory):
                shutil.rmtree(directory)
        for directory in (self.temp_positive_dir, self.temp_negative_dir):
            if not os.path.exists(directory):
                os.makedirs(directory)

        counter = 0
        for _ in range(num_batch):
            batch = self.batches[self.current_batch]
            decode_result = self._model.run_eval_given_step(self._sess, batch)

            for i in range(FLAGS.batch_size):
                original_review = batch.original_review_output[i]
                sentences = []

                for j in range(FLAGS.max_dec_sen_num):
                    # Drop the first id of each generated sentence.
                    output_ids = [int(t) for t in decode_result['generated'][i][j]][1:]
                    decoded_words = data.outputids2words(output_ids, self._vocab, None)
                    # Remove the [STOP] token and everything after it.
                    if data.STOP_DECODING in decoded_words:
                        decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]

                    if len(decoded_words) < 2:
                        continue
                    # Skip a sentence that mostly repeats the last kept one.
                    if sentences:
                        prev_words = set(sentences[-1].split())
                        cur_words = set(decoded_words)
                        if len(prev_words & cur_words) > 0.5 * len(cur_words):
                            continue
                    sentences.append(' '.join(decoded_words).strip())

                review = ' '.join(sentences).strip()
                stop_idx = review.find(data.STOP_DECODING_DOCUMENT)
                if stop_idx != -1:
                    review = review[:stop_idx]
                review = review.replace("[UNK] ", "").replace("[UNK]", "")
                review = re.sub(r"(! ){2,}", "! ", review)
                review = re.sub(r"(\. ){2,}", ". ", review)
                self.write_negtive_temp_to_json(original_review, review, counter)

                counter += 1  # number of examples decoded so far

            # Advance cyclically through the stored batches.
            self.current_batch += 1
            if self.current_batch >= len(self.batches):
                self.current_batch = 0

        eva = Evaluate()
        eva.diversity_evaluate(negative_dir + "/*")
示例#35
0
文件: main.py 项目: dbolshak/DPGAN
def _clean_generated_review(text, stop_token):
    """Return *text* cut at *stop_token*, without [UNK], punctuation runs collapsed.

    Collapses runs of two or more ``"! "`` or ``". "`` to a single one.

    >>> _clean_generated_review("nice ! ! ! great . . ok [UNK] now", "<d>")
    'nice ! great . ok now'
    """
    stop_idx = text.find(stop_token)
    if stop_idx != -1:
        text = text[:stop_idx]
    text = text.replace("[UNK] ", "").replace("[UNK]", "")
    # Collapse (rather than delete) repeated punctuation. Deleting the whole
    # run removed the sentence boundary and merged neighbouring sentences.
    text = re.sub(r"(! ){2,}", "! ", text)
    text = re.sub(r"(\. ){2,}", ". ", text)
    return text


def output_to_batch(current_batch, result, batcher, dis_batcher):
    """Convert generator output into a generator batch and a discriminator batch.

    For each of the ``FLAGS.batch_size`` examples, the sentences in
    ``result['generated'][i]`` are decoded and filtered:
    - each sentence is cut at the first STOP token;
    - sentences shorter than two words, or sharing more than half of their
      distinct words with the previously kept sentence, are dropped;
    - a '.' is appended when no final punctuation is present.

    The sentences are then joined and cleaned by ``_clean_generated_review``.

    If the cleaned review is non-empty, it becomes a generator ``Example``
    (paired with ``current_batch.original_review_inputs[i]``) and a
    discriminator ``bd.Example`` labelled 1. Otherwise the reference review
    ``current_batch.original_review_output[i]`` is used instead, with
    discriminator label -0.0001.

    Returns:
        ``(Batch, bd.Batch)`` built from the generator and discriminator
        examples.
    """
    example_list = []
    db_example_list = []

    for i in range(FLAGS.batch_size):
        encode_words = current_batch.original_review_inputs[i]
        sentences = []

        for j in range(FLAGS.max_dec_sen_num):
            # Drop the first id of each generated sentence.
            output_ids = [int(t) for t in result['generated'][i][j]][1:]
            decoded_words = data.outputids2words(output_ids, batcher._vocab, None)
            # Remove the [STOP] token and everything after it.
            if data.STOP_DECODING in decoded_words:
                decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
            if len(decoded_words) < 2:
                continue
            # Skip a sentence that mostly repeats the last kept one.
            if sentences:
                prev_words = set(sentences[-1].split())
                cur_words = set(decoded_words)
                if len(prev_words & cur_words) > 0.5 * len(cur_words):
                    continue
            if decoded_words[-1] not in ('.', '!', '?'):
                decoded_words.append('.')
            sentences.append(' '.join(decoded_words).strip())

        review = _clean_generated_review(' '.join(sentences).strip(),
                                         data.STOP_DECODING_DOCUMENT)

        if not review.strip():
            # Nothing usable was generated: fall back to the reference review.
            reference = current_batch.original_review_output[i]
            new_dis_example = bd.Example(reference, -0.0001, dis_batcher._vocab, dis_batcher._hps)
            new_example = Example(reference, batcher._vocab, batcher._hps, encode_words)
        else:
            new_dis_example = bd.Example(review, 1, dis_batcher._vocab, dis_batcher._hps)
            new_example = Example(review, batcher._vocab, batcher._hps, encode_words)
        example_list.append(new_example)
        db_example_list.append(new_dis_example)

    return (Batch(example_list, batcher._hps, batcher._vocab),
            bd.Batch(db_example_list, dis_batcher._hps, dis_batcher._vocab))