def main(unused_argv):
    """Score each human summary against the other human summaries with ROUGE.

    For every slot ``i`` in ``range(4)``, the ``i``-th human summary of each
    source file is written as the candidate and all remaining human summaries
    of that file as references. ROUGE is then evaluated on the written files
    and logged under ``log_dir/reference_<i>``.

    Args:
        unused_argv: Argument sequence; it must contain exactly one element.

    Raises:
        Exception: If ``unused_argv`` does not contain exactly one element.
    """
    if len(unused_argv) != 1:
        # Anything else means the flags were entered incorrectly.
        raise Exception("Problem with flags: %s" % unused_argv)

    dataset_dir = os.path.join(data_dir, FLAGS.dataset)
    source_files = sorted(glob.glob(dataset_dir + '/*'))

    for i in range(4):
        round_dir = os.path.join(log_dir, 'reference_' + str(i))
        ref_dir = os.path.join(round_dir, 'reference')
        dec_dir = os.path.join(round_dir, 'decoded')
        util.create_dirs(ref_dir)
        util.create_dirs(dec_dir)

        for source_idx, source_file in enumerate(source_files):
            summaries = [
                data.abstract2sents(text)
                for text in get_human_summary_texts(source_file)
            ]
            # Hold out summary i as the candidate; everything else is a reference.
            candidate = summaries[i]
            references = [
                other for pos, other in enumerate(summaries) if pos != i
            ]
            rouge_functions.write_for_rouge(references, candidate, source_idx,
                                            ref_dir, dec_dir)

        results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir)
        rouge_functions.rouge_log(results_dict, round_dir)
def process_example(sents, ex_idx, groundtruth_summ_sents):
    """Write one example's decoded words and reference summary for ROUGE.

    Args:
        sents: Iterable of sentence strings; each one is split on single
            spaces and the pieces are concatenated in order.
        ex_idx: Example index forwarded to ``write_for_rouge``.
        groundtruth_summ_sents: Reference summary, passed through unchanged.
    """
    final_decoded_words = [
        word for sent in sents for word in sent.split(' ')
    ]
    rouge_functions.write_for_rouge(
        groundtruth_summ_sents,
        None,
        ex_idx,
        rouge_ref_dir,
        rouge_dec_dir,
        decoded_words=final_decoded_words,
        log=False)
def evaluate_example(ex):
    """Build the summary for one example and write its ROUGE input files.

    With ``FLAGS.upper_bound`` set, the summary is assembled from the
    ground-truth source indices. Otherwise ``generate_summary`` produces it,
    and for the first examples (``example_idx <= 1``) an HTML page
    highlighting system and ground-truth selections is written as well.

    Args:
        ex: Five-element sequence ``(example, example_idx,
            qid_ssi_to_importances, _, _)``; the last two entries are ignored.

    Returns:
        Tuple ``(groundtruth_similar_source_indices_list,
        similar_source_indices_list, ssi_length_extractive)``.
    """
    example, example_idx, qid_ssi_to_importances, _, _ = ex
    print(example_idx)

    # Read the example from the dataset and tokenize its article sentences.
    (raw_article_sents, groundtruth_similar_source_indices_list,
     groundtruth_summary_text, doc_indices) = util.unpack_tf_example(
         example, names_to_types)
    article_sent_tokens = [
        util.process_sent(raw_sent) for raw_sent in raw_article_sents
    ]
    enforced_groundtruth_ssi_list = util.enforce_sentence_limit(
        groundtruth_similar_source_indices_list, sentence_limit)

    reference_sents = [
        line.strip() for line in groundtruth_summary_text.strip().split('\n')
    ]
    groundtruth_summ_sents = [reference_sents]
    groundtruth_summ_sent_tokens = [
        line.split(' ') for line in reference_sents
    ]

    if FLAGS.upper_bound:
        # Upper bound: take the ground-truth singletons/pairs directly.
        replaced_ssi_list = util.replace_empty_ssis(
            enforced_groundtruth_ssi_list, raw_article_sents)
        selected_article_sent_indices = util.flatten_list_of_lists(
            replaced_ssi_list)
        summary_sents = [
            ' '.join(tokens) for tokens in util.reorder(
                article_sent_tokens, selected_article_sent_indices)
        ]
        similar_source_indices_list = groundtruth_similar_source_indices_list
        ssi_length_extractive = len(similar_source_indices_list)
    else:
        # Extractive summary generated from the BERT output.
        (summary_sents, similar_source_indices_list, summary_sents_for_html,
         ssi_length_extractive) = generate_summary(
             article_sent_tokens, qid_ssi_to_importances, example_idx)
        similar_source_indices_list_trunc = similar_source_indices_list[:ssi_length_extractive]
        summary_sents_for_html_trunc = summary_sents_for_html[:ssi_length_extractive]

        if example_idx <= 1:
            summary_sent_tokens = [
                html_sent.split(' ')
                for html_sent in summary_sents_for_html_trunc
            ]
            extracted_sents_in_article_html = html_highlight_sents_in_article(
                summary_sent_tokens,
                similar_source_indices_list_trunc,
                article_sent_tokens,
                doc_indices=doc_indices)

            (groundtruth_ssi_list, lcs_paths_list,
             article_lcs_paths_list) = get_simple_source_indices_list(
                 groundtruth_summ_sent_tokens, article_sent_tokens, None,
                 sentence_limit, min_matched_tokens)
            groundtruth_highlighted_html = html_highlight_sents_in_article(
                groundtruth_summ_sent_tokens,
                groundtruth_ssi_list,
                article_sent_tokens,
                lcs_paths_list=lcs_paths_list,
                article_lcs_paths_list=article_lcs_paths_list,
                doc_indices=doc_indices)

            all_html = ('<u>System Summary</u><br><br>' +
                        extracted_sents_in_article_html +
                        '<u>Groundtruth Summary</u><br><br>' +
                        groundtruth_highlighted_html)
            ssi_functions.write_highlighted_html(all_html, html_dir,
                                                 example_idx)

    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents,
                                    example_idx, ref_dir, dec_dir)
    return (groundtruth_similar_source_indices_list,
            similar_source_indices_list, ssi_length_extractive)
def main(unused_argv):
    """Run baseline summarizers over the selected datasets and report ROUGE.

    For every (summarizer, dataset) pair, each example's article is
    summarized into 5 sentences, the lower-cased summary and the reference
    summaries are written for ROUGE, and the summary's source-sentence
    indices are collected for the sentence-selection evaluation. After all
    examples of a dataset are processed, the selection score and ROUGE
    results are printed and logged; all result strings are printed again at
    the very end.

    Args:
        unused_argv: Argument sequence; it must contain exactly one element.

    Raises:
        Exception: If ``unused_argv`` does not contain exactly one element.
    """
    print('Running statistics on %s' % FLAGS.dataset_name)
    if len(unused_argv) != 1:
        # Anything else means the flags were entered incorrectly.
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.summarizer == 'all':
        summary_methods = list(summarizers.keys())
    else:
        summary_methods = [FLAGS.summarizer]
    if FLAGS.dataset_name == 'all':
        dataset_names = datasets
    else:
        dataset_names = [FLAGS.dataset_name]

    # Maximum number of source sentences matched to one summary sentence.
    sentence_limit = 2

    sheets_strs = []
    for summary_method in summary_methods:
        summary_fn = summarizers[summary_method]
        for dataset_name in dataset_names:
            # The flag is overwritten with the dataset currently being processed.
            FLAGS.dataset_name = dataset_name

            if 'xsum' in dataset_name:
                original_dataset_name = 'xsum'
            elif 'cnn_dm' in dataset_name or 'duc_2004' in dataset_name:
                original_dataset_name = 'cnn_dm'
            else:
                original_dataset_name = ''
            vocab = Vocab('logs/vocab' + '_' + original_dataset_name,
                          50000)  # create a vocabulary

            source_dir = os.path.join(data_dir, dataset_name)
            source_pattern = source_dir + '/' + FLAGS.dataset_split + '*'
            source_files = sorted(glob.glob(source_pattern))

            # Progress-bar total: 1000 per file for cnn/newsroom/xsum,
            # otherwise one per file.
            if ('cnn' in dataset_name or 'newsroom' in dataset_name
                    or 'xsum' in dataset_name):
                total = len(source_files) * 1000
            else:
                total = len(source_files)
            example_generator = data.example_generator(
                source_pattern, True, False, should_check_valid=False)

            if dataset_name == 'duc_2004':
                # duc_2004 reads its reference abstracts from a second,
                # parallel example stream.
                abs_source_dir = os.path.join(
                    os.path.expanduser('~') + '/data/tf_data/with_coref',
                    dataset_name)
                abs_example_generator = data.example_generator(
                    abs_source_dir + '/' + FLAGS.dataset_split + '*',
                    True,
                    False,
                    should_check_valid=False)
                abs_names_to_types = [('abstract', 'string_list')]

            # Output locations depend only on (dataset, summarizer), so they
            # are set up once here rather than once per example. This also
            # keeps them defined when the generator yields nothing.
            log_dir = os.path.join(log_root,
                                   dataset_name + '_' + summary_method)
            dec_dir = os.path.join(log_dir, 'decoded')
            ref_dir = os.path.join(log_dir, 'reference')
            util.create_dirs(dec_dir)
            util.create_dirs(ref_dir)

            triplet_ssi_list = []
            for example_idx, example in enumerate(
                    tqdm(example_generator, total=total)):
                (raw_article_sents, groundtruth_similar_source_indices_list,
                 groundtruth_summary_text, corefs,
                 doc_indices) = util.unpack_tf_example(example, names_to_types)

                if dataset_name == 'duc_2004':
                    abs_example = next(abs_example_generator)
                    groundtruth_summary_texts = util.unpack_tf_example(
                        abs_example, abs_names_to_types)[0]
                    groundtruth_summ_sents_list = [[
                        sent.strip() for sent in data.abstract2sents(abstract)
                    ] for abstract in groundtruth_summary_texts]
                else:
                    # Single reference: one sentence per line of the text.
                    groundtruth_summ_sents_list = [[
                        sent.strip() for sent in
                        groundtruth_summary_text.strip().split('\n')
                    ]]

                article_sent_tokens = [
                    util.process_sent(sent) for sent in raw_article_sents
                ]
                if doc_indices is None:
                    doc_indices = [0] * len(
                        util.flatten_list_of_lists(article_sent_tokens))
                doc_indices = [int(doc_idx) for doc_idx in doc_indices]
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list,
                    FLAGS.sentence_limit)

                parser = PlaintextParser.from_string(
                    ' '.join(raw_article_sents), Tokenizer("english"))
                summarizer = summary_fn()
                # Summarize the document with 5 sentences.
                summary = summarizer(parser.document, 5)
                summary_tokenized = [
                    str(sentence).lower() for sentence in summary
                ]

                rouge_functions.write_for_rouge(groundtruth_summ_sents_list,
                                                summary_tokenized,
                                                example_idx,
                                                ref_dir,
                                                dec_dir,
                                                log=False)

                decoded_sent_tokens = [
                    sent.split() for sent in summary_tokenized
                ]
                sys_ssi_list, _, _ = get_simple_source_indices_list(
                    decoded_sent_tokens, article_sent_tokens, vocab,
                    sentence_limit, min_matched_tokens)
                triplet_ssi_list.append(
                    (groundtruth_similar_source_indices_list, sys_ssi_list,
                     -1))

            print('Evaluating Lambdamart model F1 score...')
            suffix = util.all_sent_selection_eval(triplet_ssi_list)
            print(suffix)

            results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir)
            print(("Results_dict: ", results_dict))
            sheets_str = rouge_functions.rouge_log(results_dict,
                                                   log_dir,
                                                   suffix=suffix)
            sheets_strs.append(dataset_name + '_' + summary_method + '\n' +
                               sheets_str)

    for sheets_str in sheets_strs:
        print(sheets_str + '\n')
def evaluate_example(ex):
    """Build the summary for one example and write its ROUGE input files.

    The summary comes from one of three sources: the ground-truth source
    indices (``FLAGS.upper_bound``), the leading article sentences
    (``FLAGS.lead``), or ``generate_summary``. In the last case, examples
    with ``example_idx <= 100`` also get an HTML page highlighting system
    and ground-truth selections.

    Args:
        ex: Five-element sequence ``(example, example_idx,
            qid_ssi_to_importances, _, _)``; the last two entries are ignored.

    Returns:
        Tuple ``(groundtruth_similar_source_indices_list,
        similar_source_indices_list, ssi_length_extractive)``.
    """
    example, example_idx, qid_ssi_to_importances, _, _ = ex
    print(example_idx)

    (raw_article_sents, groundtruth_similar_source_indices_list,
     groundtruth_summary_text, _corefs,
     doc_indices) = util.unpack_tf_example(example, names_to_types)
    article_sent_tokens = [
        util.process_sent(raw_sent) for raw_sent in raw_article_sents
    ]
    enforced_groundtruth_ssi_list = util.enforce_sentence_limit(
        groundtruth_similar_source_indices_list, sentence_limit)

    def _split_lines(text):
        # One stripped sentence per line of a reference text.
        return [line.strip() for line in text.strip().split('\n')]

    def _join_selected(ssi_list):
        # Flatten the index tuples and render the chosen sentences as strings.
        flat_indices = util.flatten_list_of_lists(ssi_list)
        return [
            ' '.join(tokens)
            for tokens in util.reorder(article_sent_tokens, flat_indices)
        ]

    if FLAGS.dataset_name == 'duc_2004':
        # duc_2004 supplies several reference texts per example.
        groundtruth_summ_sents = [
            _split_lines(gt_summ_text)
            for gt_summ_text in groundtruth_summary_text
        ]
    else:
        groundtruth_summ_sents = [_split_lines(groundtruth_summary_text)]
    groundtruth_summ_sent_tokens = [
        line.split(' ') for line in groundtruth_summ_sents[0]
    ]

    if FLAGS.upper_bound:
        replaced_ssi_list = util.replace_empty_ssis(
            enforced_groundtruth_ssi_list, raw_article_sents)
        summary_sents = _join_selected(replaced_ssi_list)
        similar_source_indices_list = groundtruth_similar_source_indices_list
        ssi_length_extractive = len(similar_source_indices_list)
    elif FLAGS.lead:
        lead_count = util.average_sents_for_dataset[FLAGS.dataset_name]
        lead_ssi_list = [(idx, ) for idx in range(lead_count)]
        # Make sure the sentence indices don't go past the total number of
        # sentences in the article.
        lead_ssi_list = lead_ssi_list[:len(raw_article_sents)]
        summary_sents = _join_selected(lead_ssi_list)
        similar_source_indices_list = lead_ssi_list
        ssi_length_extractive = len(similar_source_indices_list)
    else:
        (summary_sents, similar_source_indices_list, summary_sents_for_html,
         ssi_length_extractive) = generate_summary(
             article_sent_tokens, qid_ssi_to_importances, example_idx)
        similar_source_indices_list_trunc = similar_source_indices_list[:ssi_length_extractive]
        summary_sents_for_html_trunc = summary_sents_for_html[:ssi_length_extractive]

        if example_idx <= 100:
            summary_sent_tokens = [
                html_sent.split(' ')
                for html_sent in summary_sents_for_html_trunc
            ]
            extracted_sents_in_article_html = html_highlight_sents_in_article(
                summary_sent_tokens,
                similar_source_indices_list_trunc,
                article_sent_tokens,
                doc_indices=doc_indices)

            (groundtruth_ssi_list, lcs_paths_list,
             article_lcs_paths_list) = get_simple_source_indices_list(
                 groundtruth_summ_sent_tokens, article_sent_tokens, None,
                 sentence_limit, min_matched_tokens)
            groundtruth_highlighted_html = html_highlight_sents_in_article(
                groundtruth_summ_sent_tokens,
                groundtruth_ssi_list,
                article_sent_tokens,
                lcs_paths_list=lcs_paths_list,
                article_lcs_paths_list=article_lcs_paths_list,
                doc_indices=doc_indices)

            all_html = ('<u>System Summary</u><br><br>' +
                        extracted_sents_in_article_html +
                        '<u>Groundtruth Summary</u><br><br>' +
                        groundtruth_highlighted_html)
            write_highlighted_html(all_html, html_dir, example_idx)

    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents,
                                    example_idx, ref_dir, dec_dir)
    return (groundtruth_similar_source_indices_list,
            similar_source_indices_list, ssi_length_extractive)
def decode(self):
    """Decode examples until the data is exhausted, or decode indefinitely.

    In ``FLAGS.single_pass`` mode, every batch is decoded once, reference and
    decoded summaries are written for ROUGE (plus human-readable output for
    the first 1000 examples and optional attention-visualisation data), and
    ROUGE is evaluated when the batcher runs out of data. Otherwise results
    are printed, attention-visualisation data is written, and the latest
    checkpoint is reloaded whenever ``SECS_UNTIL_NEW_CKPT`` seconds have
    elapsed since the last load.

    Raises:
        AssertionError: If the batcher returns ``None`` while not in
            single-pass mode.
    """
    t0 = time.time()
    counter = 0
    total = len(glob.glob(self._batcher._data_path)) * 1000
    pbar = tqdm(total=total)
    # The loop only ever exits via ``return`` or an exception, so the
    # progress bar has to be closed in a ``finally`` clause to be closed at all.
    try:
        while True:
            batch = self._batcher.next_batch()  # 1 example repeated across batch
            if batch is None:  # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                logging.info("Output has been saved in %s and %s.",
                             self._rouge_ref_dir, self._rouge_dec_dir)
                if len(os.listdir(self._rouge_ref_dir)) != 0:
                    logging.info("Now starting ROUGE eval...")
                    results_dict = rouge_functions.rouge_eval(
                        self._rouge_ref_dir, self._rouge_dec_dir)
                    rouge_functions.rouge_log(results_dict, self._decode_dir)
                return

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            all_original_abstract_sents = batch.all_original_abstracts_sents[0]
            raw_article_sents = batch.raw_article_sents[0]

            article_withunks = data.show_art_oovs(original_article,
                                                  self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            decoded_words, decoded_output, best_hyp = decode_example(
                self._sess, self._model, self._vocab, batch, counter,
                self._batcher._hps)

            if FLAGS.single_pass:
                if counter < 1000:
                    self.write_for_human(raw_article_sents,
                                         all_original_abstract_sents,
                                         decoded_words, counter)
                # Write ref summary and decoded summary to file, to eval
                # with pyrouge later.
                rouge_functions.write_for_rouge(all_original_abstract_sents,
                                                None,
                                                counter,
                                                self._rouge_ref_dir,
                                                self._rouge_dec_dir,
                                                decoded_words=decoded_words)
                if FLAGS.attn_vis:
                    # Write info to .json file for visualization tool.
                    self.write_for_attnvis(article_withunks,
                                           abstract_withunks, decoded_words,
                                           best_hyp.attn_dists,
                                           best_hyp.p_gens, counter)
                counter += 1  # this is how many examples we've decoded
            else:
                print_results(article_withunks, abstract_withunks,
                              decoded_output)  # log output to screen
                # Write info to .json file for visualization tool.
                self.write_for_attnvis(article_withunks, abstract_withunks,
                                       decoded_words, best_hyp.attn_dists,
                                       best_hyp.p_gens, counter)

                # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so, load a
                # new checkpoint and restart the timer.
                t1 = time.time()
                if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                    logging.info(
                        'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                        t1 - t0)
                    _ = util.load_ckpt(self._saver, self._sess)
                    t0 = time.time()
            pbar.update(1)
    finally:
        pbar.close()
def decode_iteratively(self, example_generator, total, names_to_types,
                       ssi_list, hps):
    """Decode each example one source-sentence group at a time, then run ROUGE.

    For every example, each selected group of source sentences (a singleton
    or pair of indices) is turned into its own batch and decoded; decoded
    words are concatenated until at least 100 words have been produced.
    Reference and decoded summaries are written for ROUGE, and for the first
    1000 examples human-readable output and highlighted HTML are written as
    well. After the loop, ROUGE is evaluated if any reference files exist.

    Args:
        example_generator: Iterable of serialized examples.
        total: Expected number of examples (progress-bar total).
        names_to_types: Field specification passed to
            ``util.unpack_tf_example``.
        ssi_list: Per-example ``(gt_ssi, sys_ssi, ext_len)`` triples, or
            ``None`` for the upper-bound evaluation, where the selections
            come straight from the ground truth.
        hps: Hyper-parameters forwarded to ``create_batch`` and
            ``decode_example``.
    """
    min_matched_tokens = 2
    for example_idx, example in enumerate(
            tqdm(example_generator, total=total)):
        (raw_article_sents, groundtruth_similar_source_indices_list,
         groundtruth_summary_text) = util.unpack_tf_example(
             example, names_to_types)
        article_sent_tokens = [
            util.process_sent(sent) for sent in raw_article_sents
        ]
        groundtruth_summ_sents = [[
            sent.strip()
            for sent in groundtruth_summary_text.strip().split('\n')
        ]]

        if ssi_list is None:
            # Upper-bound evaluation: ssi_list comes straight from the
            # groundtruth.
            sys_ssi = groundtruth_similar_source_indices_list
            if FLAGS.singles_and_pairs == 'singles':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
            elif FLAGS.singles_and_pairs == 'both':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
            sys_ssi = util.replace_empty_ssis(sys_ssi, raw_article_sents)
        else:
            gt_ssi, sys_ssi, _ext_len = ssi_list[example_idx]
            if FLAGS.singles_and_pairs == 'singles':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list, 1)
                gt_ssi = util.enforce_sentence_limit(gt_ssi, 1)
            elif FLAGS.singles_and_pairs == 'both':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list, 2)
                gt_ssi = util.enforce_sentence_limit(gt_ssi, 2)
            if gt_ssi != groundtruth_similar_source_indices_list:
                # The example index is now actually substituted for %d; the
                # original printed the literal placeholder.
                print(
                    'Warning: Example %d has different groundtruth source indices: %s || %s'
                    % (example_idx,
                       str(groundtruth_similar_source_indices_list),
                       str(gt_ssi)))

        if FLAGS.dataset_name == 'xsum':
            # Only the first selection is decoded for xsum.
            sys_ssi = [sys_ssi[0]]

        final_decoded_words = []
        highlight_html_total = ''
        for ssi_idx, ssi in enumerate(sys_ssi):
            selected_raw_article_sents = util.reorder(raw_article_sents, ssi)
            selected_article_text = ' '.join([
                ' '.join(sent)
                for sent in util.reorder(article_sent_tokens, ssi)
            ])
            # One '0 ' entry per token of the selected text.
            selected_doc_indices_str = '0 ' * len(
                selected_article_text.split())
            if FLAGS.upper_bound:
                selected_groundtruth_summ_sent = [[
                    groundtruth_summ_sents[0][ssi_idx]
                ]]
            else:
                selected_groundtruth_summ_sent = groundtruth_summ_sents

            batch = create_batch(selected_article_text,
                                 selected_groundtruth_summ_sent,
                                 selected_doc_indices_str,
                                 selected_raw_article_sents, FLAGS.batch_size,
                                 hps, self._vocab)

            decoded_words, _decoded_output, _best_hyp = decode_example(
                self._sess, self._model, self._vocab, batch, example_idx, hps)
            final_decoded_words.extend(decoded_words)

            if example_idx < 1000:
                selected_article_sent_tokens = [
                    util.process_sent(sent)
                    for sent in selected_raw_article_sents
                ]
                highlight_summary_sent_tokens = [decoded_words]
                (highlight_ssi_list, lcs_paths_list,
                 highlight_smooth_article_lcs_paths_list
                 ) = ssi_functions.get_simple_source_indices_list(
                     highlight_summary_sent_tokens,
                     selected_article_sent_tokens, None, 2,
                     min_matched_tokens)
                highlighted_html = ssi_functions.html_highlight_sents_in_article(
                    highlight_summary_sent_tokens,
                    highlight_ssi_list,
                    selected_article_sent_tokens,
                    lcs_paths_list=lcs_paths_list,
                    article_lcs_paths_list=
                    highlight_smooth_article_lcs_paths_list)
                highlight_html_total += '<u>System Summary</u><br><br>' + highlighted_html + '<br><br>'

            # Stop decoding further selections once 100 words are reached.
            if len(final_decoded_words) >= 100:
                break

        if example_idx < 1000:
            self.write_for_human(raw_article_sents, groundtruth_summ_sents,
                                 final_decoded_words, example_idx)
            ssi_functions.write_highlighted_html(highlight_html_total,
                                                 self._highlight_dir,
                                                 example_idx)

        # Write ref summary and decoded summary to file, to eval with
        # pyrouge later.
        rouge_functions.write_for_rouge(groundtruth_summ_sents,
                                        None,
                                        example_idx,
                                        self._rouge_ref_dir,
                                        self._rouge_dec_dir,
                                        decoded_words=final_decoded_words,
                                        log=False)

    logging.info("Decoder has finished reading dataset for single_pass.")
    logging.info("Output has been saved in %s and %s.", self._rouge_ref_dir,
                 self._rouge_dec_dir)
    if len(os.listdir(self._rouge_ref_dir)) != 0:
        l_param = 100
        logging.info("Now starting ROUGE eval...")
        results_dict = rouge_functions.rouge_eval(self._rouge_ref_dir,
                                                  self._rouge_dec_dir,
                                                  l_param=l_param)
        rouge_functions.rouge_log(results_dict, self._decode_dir)
def decode_iteratively(self, example_generator, total, names_to_types,
                       ssi_list, hps):
    """Decode each example one source-sentence group at a time, then run ROUGE.

    For every example, each selected group of source sentences (with its
    article LCS paths) is turned into its own batch and decoded; decoded
    words are concatenated until at least 100 words have been produced. When
    ``FLAGS.first_intact`` is set, the first selection is copied verbatim
    instead of being decoded. Highlighted HTML comparing system and
    reference selections is written for examples 0-99 and 2000-2099, and
    reference/decoded summaries are written for ROUGE for every example.

    Args:
        example_generator: Iterable of serialized examples.
        total: Expected number of examples (progress-bar total).
        names_to_types: Field specification passed to
            ``util.unpack_tf_example``.
        ssi_list: Per-example ``(gt_ssi, sys_ssi, ext_len,
            sys_token_probs_list)`` tuples, or ``None`` for the upper-bound
            evaluation, where the selections come straight from the ground
            truth.
        hps: Hyper-parameters forwarded to ``create_batch`` and
            ``decode_example``.
    """
    attn_vis_idx = 0
    min_matched_tokens = 2
    for example_idx, example in enumerate(
            tqdm(example_generator, total=total)):
        (raw_article_sents, groundtruth_similar_source_indices_list,
         groundtruth_summary_text, _corefs,
         groundtruth_article_lcs_paths_list) = util.unpack_tf_example(
             example, names_to_types)
        article_sent_tokens = [
            util.process_sent(sent) for sent in raw_article_sents
        ]
        groundtruth_summ_sents = [[
            sent.strip()
            for sent in groundtruth_summary_text.strip().split('\n')
        ]]
        groundtruth_summ_sent_tokens = [
            sent.split(' ') for sent in groundtruth_summ_sents[0]
        ]
        # Highlighted HTML is only produced for these two index windows.
        write_html = example_idx < 100 or (2000 <= example_idx < 2100)

        if ssi_list is None:
            # Upper-bound evaluation: ssi_list comes straight from the
            # groundtruth.
            sys_ssi = groundtruth_similar_source_indices_list
            sys_alp_list = groundtruth_article_lcs_paths_list
            if FLAGS.singles_and_pairs == 'singles':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 1)
            elif FLAGS.singles_and_pairs == 'both':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 2)
            sys_ssi, sys_alp_list = util.replace_empty_ssis(
                sys_ssi, raw_article_sents, sys_alp_list=sys_alp_list)
        else:
            _gt_ssi, sys_ssi, _ext_len, sys_token_probs_list = ssi_list[
                example_idx]
            sys_alp_list = ssi_functions.list_labels_from_probs(
                sys_token_probs_list, FLAGS.tag_threshold)
            if FLAGS.singles_and_pairs == 'singles':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 1)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list, 1)
            elif FLAGS.singles_and_pairs == 'both':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 2)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list, 2)

        if FLAGS.dataset_name == 'xsum':
            # Only the first selection is decoded for xsum.
            sys_ssi = [sys_ssi[0]]

        final_decoded_words = []
        highlight_html_total = '<u>System Summary</u><br><br>'
        for ssi_idx, ssi in enumerate(sys_ssi):
            selected_article_lcs_paths = sys_alp_list[ssi_idx]
            ssi, selected_article_lcs_paths = util.make_ssi_chronological(
                ssi, selected_article_lcs_paths)
            selected_article_lcs_paths = [selected_article_lcs_paths]
            selected_raw_article_sents = util.reorder(raw_article_sents, ssi)
            selected_article_text = ' '.join([
                ' '.join(sent)
                for sent in util.reorder(article_sent_tokens, ssi)
            ])
            # One '0 ' entry per token of the selected text.
            selected_doc_indices_str = '0 ' * len(
                selected_article_text.split())
            if FLAGS.upper_bound:
                selected_groundtruth_summ_sent = [[
                    groundtruth_summ_sents[0][ssi_idx]
                ]]
            else:
                selected_groundtruth_summ_sent = groundtruth_summ_sents

            batch = create_batch(selected_article_text,
                                 selected_groundtruth_summ_sent,
                                 selected_doc_indices_str,
                                 selected_raw_article_sents,
                                 selected_article_lcs_paths, FLAGS.batch_size,
                                 hps, self._vocab)

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            article_withunks = data.show_art_oovs(original_article,
                                                  self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            if FLAGS.first_intact and ssi_idx == 0:
                # The first selection is copied verbatim, so there is no
                # decoder hypothesis for it. Mark that explicitly so the
                # attention dump below is skipped instead of reading an
                # unbound or stale ``best_hyp``.
                decoded_words = selected_article_text.strip().split()
                best_hyp = None
            else:
                decoded_words, _decoded_output, best_hyp = decode_example(
                    self._sess, self._model, self._vocab, batch, example_idx,
                    hps)
            final_decoded_words.extend(decoded_words)

            if write_html:
                selected_article_sent_tokens = [
                    util.process_sent(sent)
                    for sent in selected_raw_article_sents
                ]
                highlight_summary_sent_tokens = [decoded_words]
                (highlight_ssi_list, lcs_paths_list,
                 _highlight_article_lcs_paths_list,
                 highlight_smooth_article_lcs_paths_list
                 ) = ssi_functions.get_simple_source_indices_list(
                     highlight_summary_sent_tokens,
                     selected_article_sent_tokens, None, 2,
                     min_matched_tokens)
                highlighted_html = ssi_functions.html_highlight_sents_in_article(
                    highlight_summary_sent_tokens,
                    highlight_ssi_list,
                    selected_article_sent_tokens,
                    lcs_paths_list=lcs_paths_list,
                    article_lcs_paths_list=
                    highlight_smooth_article_lcs_paths_list)
                highlight_html_total += highlighted_html + '<br>'

            if FLAGS.attn_vis and example_idx < 200 and best_hyp is not None:
                # Write info to .json file for visualization tool.
                self.write_for_attnvis(article_withunks, abstract_withunks,
                                       decoded_words, best_hyp.attn_dists,
                                       best_hyp.p_gens, attn_vis_idx)
                attn_vis_idx += 1

            # Stop decoding further selections once 100 words are reached.
            if len(final_decoded_words) >= 100:
                break

        gt_ssi_list, gt_alp_list = util.replace_empty_ssis(
            groundtruth_similar_source_indices_list,
            raw_article_sents,
            sys_alp_list=groundtruth_article_lcs_paths_list)
        highlight_html_gt = '<u>Reference Summary</u><br><br>'
        for ssi_idx, ssi in enumerate(gt_ssi_list):
            selected_article_lcs_paths = gt_alp_list[ssi_idx]
            try:
                ssi, selected_article_lcs_paths = util.make_ssi_chronological(
                    ssi, selected_article_lcs_paths)
            except Exception:
                # Dump the offending values for debugging, then re-raise.
                util.print_vars(ssi, example_idx, selected_article_lcs_paths)
                raise
            selected_raw_article_sents = util.reorder(raw_article_sents, ssi)

            if write_html:
                selected_article_sent_tokens = [
                    util.process_sent(sent)
                    for sent in selected_raw_article_sents
                ]
                highlight_summary_sent_tokens = [
                    groundtruth_summ_sent_tokens[ssi_idx]
                ]
                (highlight_ssi_list, lcs_paths_list,
                 _highlight_article_lcs_paths_list,
                 highlight_smooth_article_lcs_paths_list
                 ) = ssi_functions.get_simple_source_indices_list(
                     highlight_summary_sent_tokens,
                     selected_article_sent_tokens, None, 2,
                     min_matched_tokens)
                highlighted_html = ssi_functions.html_highlight_sents_in_article(
                    highlight_summary_sent_tokens,
                    highlight_ssi_list,
                    selected_article_sent_tokens,
                    lcs_paths_list=lcs_paths_list,
                    article_lcs_paths_list=
                    highlight_smooth_article_lcs_paths_list)
                highlight_html_gt += highlighted_html + '<br>'

        if write_html:
            self.write_for_human(raw_article_sents, groundtruth_summ_sents,
                                 final_decoded_words, example_idx)
            highlight_html_total = ssi_functions.put_html_in_two_columns(
                highlight_html_total, highlight_html_gt)
            ssi_functions.write_highlighted_html(highlight_html_total,
                                                 self._highlight_dir,
                                                 example_idx)

        # Write ref summary and decoded summary to file, to eval with
        # pyrouge later.
        rouge_functions.write_for_rouge(groundtruth_summ_sents,
                                        None,
                                        example_idx,
                                        self._rouge_ref_dir,
                                        self._rouge_dec_dir,
                                        decoded_words=final_decoded_words,
                                        log=False)

    logging.info("Decoder has finished reading dataset for single_pass.")
    logging.info("Output has been saved in %s and %s.", self._rouge_ref_dir,
                 self._rouge_dec_dir)
    if len(os.listdir(self._rouge_ref_dir)) != 0:
        # The same length limit is used for every dataset.
        l_param = 100
        logging.info("Now starting ROUGE eval...")
        results_dict = rouge_functions.rouge_eval(self._rouge_ref_dir,
                                                  self._rouge_dec_dir,
                                                  l_param=l_param)
        rouge_functions.rouge_log(results_dict, self._decode_dir)
def main(unused_argv):
    """Run ROUGE on previously generated summaries.

    With ``FLAGS.exp_name == 'extractive'``, every method in
    ``summary_methods`` has its stored summaries lower-cased and tokenized,
    written in the ROUGE file layout under ``out_dir/<method>``, evaluated
    and logged; the per-method ``sheets_results.txt`` files are then
    concatenated and printed. For any other experiment name, the decoded
    summaries of that experiment are located under ``log_dir``, their
    average length is printed, and ROUGE is evaluated with
    ``FLAGS.l_param``.

    Args:
        unused_argv: Argument sequence; it must contain exactly one element.

    Raises:
        Exception: If ``unused_argv`` does not contain exactly one element.
    """
    if len(unused_argv) != 1:
        # Anything else means the flags were entered incorrectly.
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.exp_name == 'extractive':
        for summary_method in summary_methods:
            method_out_dir = os.path.join(out_dir, summary_method)
            method_dec_dir = os.path.join(method_out_dir, 'decoded')
            method_ref_dir = os.path.join(method_out_dir, 'reference')
            if not os.path.exists(method_dec_dir):
                os.makedirs(method_dec_dir)
            if not os.path.exists(method_ref_dir):
                os.makedirs(method_ref_dir)
            print((method_out_dir))

            method_dir = os.path.join(summaries_dir, summary_method)
            # Only files whose name starts with 'd' are summaries.
            file_names = sorted(
                [name for name in os.listdir(method_dir) if name[0] == 'd'])
            for art_idx, article_name in enumerate(tqdm(file_names)):
                article_path = os.path.join(method_dir, article_name)
                # Text mode: the tokenizer and the string joins below need
                # str, not bytes.
                with open(article_path) as f:
                    lines = f.readlines()
                tokenized_sents = [[
                    token.lower()
                    for token in nltk.tokenize.word_tokenize(line)
                ] for line in lines]
                sentences = [' '.join(sent) for sent in tokenized_sents]
                processed_summary = '\n'.join(sentences)

                out_name = '%06d_decoded.txt' % art_idx
                with open(os.path.join(method_dec_dir, out_name), 'w') as f:
                    f.write(processed_summary)

                # ``ref_dir`` here is the module-level reference directory.
                reference_files = glob.glob(
                    os.path.join(ref_dir, '%06d' % art_idx + '*'))
                abstract_sentences = []
                for ref_file in reference_files:
                    with open(ref_file) as f:
                        abstract_sentences.append(f.readlines())
                rouge_functions.write_for_rouge(abstract_sentences, sentences,
                                                art_idx, method_ref_dir,
                                                method_dec_dir)

            results_dict = rouge_functions.rouge_eval(ref_dir, method_dec_dir)
            rouge_functions.rouge_log(results_dict, method_out_dir)

        for summary_method in summary_methods:
            print(summary_method)

        all_results = ''
        for summary_method in summary_methods:
            sheet_results_file = os.path.join(out_dir, summary_method,
                                              "sheets_results.txt")
            with open(sheet_results_file) as f:
                all_results += f.read()
        print(all_results)
    else:
        log_root = os.path.join('logs', FLAGS.exp_name)
        ckpt_folder = util.find_largest_ckpt_folder(log_root)
        print(ckpt_folder)
        if ckpt_folder != 'decoded':
            summary_dir = os.path.join(log_dir, FLAGS.exp_name, ckpt_folder)
        else:
            summary_dir = os.path.join(log_dir, FLAGS.exp_name)

        # Deliberately NOT named ``ref_dir``/``dec_dir``: assigning those
        # names anywhere in this function would make them function-local and
        # break the extractive branch above, which reads the module-level
        # ``ref_dir``.
        exp_ref_dir = os.path.join(summary_dir, 'reference')
        exp_dec_dir = os.path.join(summary_dir, 'decoded')
        summary_files = glob.glob(os.path.join(exp_dec_dir, '*_decoded.txt'))

        lengths = []
        for summary_file in tqdm(summary_files):
            with open(summary_file) as f:
                summary = f.read()
            lengths.append(len(summary.strip().split()))
        print('Average summary length: %.2f' % np.mean(lengths))

        print('Evaluating on %d files' % len(os.listdir(exp_dec_dir)))
        results_dict = rouge_functions.rouge_eval(exp_ref_dir,
                                                  exp_dec_dir,
                                                  l_param=FLAGS.l_param)
        rouge_functions.rouge_log(results_dict, summary_dir)
def evaluate_example(ex):
    """Build the summary for one example and write its ROUGE input files.

    With ``FLAGS.upper_bound`` set, the summary is assembled from the
    ground-truth source indices and no token probabilities are available.
    Otherwise ``generate_summary`` produces the summary together with
    article LCS paths and token probabilities, and for examples 0-99 and
    2000-2099 an HTML page highlighting system and ground-truth selections
    is written as well.

    Args:
        ex: Four-element sequence ``(example, example_idx,
            qid_ssi_to_importances, qid_ssi_to_token_scores_and_mappings)``.

    Returns:
        Tuple ``(groundtruth_similar_source_indices_list,
        similar_source_indices_list, ssi_length_extractive,
        token_probs_list)``. ``token_probs_list`` is ``None`` in upper-bound
        mode.
    """
    example, example_idx, qid_ssi_to_importances, qid_ssi_to_token_scores_and_mappings = ex
    print(example_idx)

    (raw_article_sents, groundtruth_similar_source_indices_list,
     groundtruth_summary_text, _corefs,
     doc_indices) = util.unpack_tf_example(example, names_to_types)
    article_sent_tokens = [
        util.process_sent(sent) for sent in raw_article_sents
    ]
    enforced_groundtruth_ssi_list = util.enforce_sentence_limit(
        groundtruth_similar_source_indices_list, sentence_limit)
    groundtruth_summ_sents = [[
        sent.strip() for sent in groundtruth_summary_text.strip().split('\n')
    ]]
    groundtruth_summ_sent_tokens = [
        sent.split(' ') for sent in groundtruth_summ_sents[0]
    ]

    if FLAGS.upper_bound:
        replaced_ssi_list = util.replace_empty_ssis(
            enforced_groundtruth_ssi_list, raw_article_sents)
        selected_article_sent_indices = util.flatten_list_of_lists(
            replaced_ssi_list)
        summary_sents = [
            ' '.join(sent) for sent in util.reorder(
                article_sent_tokens, selected_article_sent_indices)
        ]
        similar_source_indices_list = groundtruth_similar_source_indices_list
        ssi_length_extractive = len(similar_source_indices_list)
        # No model ran on this path; without this assignment the return
        # statement below would raise NameError.
        token_probs_list = None
    else:
        (summary_sents, similar_source_indices_list, summary_sents_for_html,
         ssi_length_extractive, article_lcs_paths_list,
         token_probs_list) = generate_summary(
             article_sent_tokens, qid_ssi_to_importances, example_idx,
             qid_ssi_to_token_scores_and_mappings)
        similar_source_indices_list_trunc = similar_source_indices_list[:ssi_length_extractive]
        summary_sents_for_html_trunc = summary_sents_for_html[:ssi_length_extractive]

        if example_idx < 100 or (2000 <= example_idx < 2100):
            summary_sent_tokens = [
                sent.split(' ') for sent in summary_sents_for_html_trunc
            ]
            if FLAGS.tag_tokens and FLAGS.tag_loss_wt != 0:
                # Work on a copy so the highlighter cannot modify the paths
                # produced by generate_summary.
                lcs_paths_list_param = copy.deepcopy(article_lcs_paths_list)
            else:
                lcs_paths_list_param = None
            extracted_sents_in_article_html = html_highlight_sents_in_article(
                summary_sent_tokens,
                similar_source_indices_list_trunc,
                article_sent_tokens,
                doc_indices=doc_indices,
                lcs_paths_list=lcs_paths_list_param)

            (groundtruth_ssi_list, gt_lcs_paths_list,
             _gt_article_lcs_paths_list,
             gt_smooth_article_paths_list) = get_simple_source_indices_list(
                 groundtruth_summ_sent_tokens, article_sent_tokens, None,
                 sentence_limit, min_matched_tokens)
            groundtruth_highlighted_html = html_highlight_sents_in_article(
                groundtruth_summ_sent_tokens,
                groundtruth_ssi_list,
                article_sent_tokens,
                lcs_paths_list=gt_lcs_paths_list,
                article_lcs_paths_list=gt_smooth_article_paths_list,
                doc_indices=doc_indices)

            all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html + '<u>Groundtruth Summary</u><br><br>' + groundtruth_highlighted_html
            write_highlighted_html(all_html, html_dir, example_idx)

    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents,
                                    example_idx, ref_dir, dec_dir)
    return (groundtruth_similar_source_indices_list,
            similar_source_indices_list, ssi_length_extractive,
            token_probs_list)