def save_as_txt_file(ex):
    """Return the single field unpacked from one ``(example, example_idx)`` item.

    Despite the name, nothing is written to disk here: the file-writing code
    was commented out in the original, so the unpacked text is only returned.

    Args:
        ex: a two-item sequence ``(example, example_idx)``. The index is
            unpacked but not otherwise used.

    Returns:
        The one value produced by ``util.unpack_tf_example`` for the
        module-level ``names_to_types`` (which must describe exactly one
        field, otherwise the unpacking below raises ``ValueError``).
    """
    example, _example_idx = ex
    unpacked = util.unpack_tf_example(example, names_to_types)
    (article_text,) = unpacked
    return article_text
def main(unused_argv):
    """Count article and summary tokens and write the most common ones to disk.

    Reads every example of the requested split(s), tallies the tokens of the
    processed article sentences and of the abstracts (with the ``<s>`` and
    ``</s>`` markers dropped), then writes ``"<word> <count>"`` lines for the
    ``VOCAB_SIZE`` most common tokens to ``logs/vocab_<dataset_name>``.

    Args:
        unused_argv: leftover command-line arguments; must have length 1.

    Raises:
        Exception: if ``unused_argv`` does not have exactly one element.
    """
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.dataset_split == 'all':
        dataset_splits = ['test', 'val', 'train']
    else:
        dataset_splits = [FLAGS.dataset_split]

    sentence_markers = ('<s>', '</s>')
    vocab_counter = collections.Counter()

    for dataset_split in dataset_splits:
        source_dir = os.path.join(FLAGS.data_root, FLAGS.dataset_name)
        file_pattern = source_dir + '/' + dataset_split + '*'
        # Progress-bar estimate only: number of matching files times 1000.
        total = len(glob.glob(file_pattern)) * 1000
        example_generator = data.example_generator(
            file_pattern, True, False, should_check_valid=False)

        for example in tqdm(example_generator, total=total):
            raw_article_sents, _article, abstracts, _doc_indices = util.unpack_tf_example(
                example, names_to_types)
            article_sent_tokens = [
                util.process_sent(sent) for sent in raw_article_sents
            ]
            summary_sent_tokens = [
                [tok for tok in abstract.strip().split()
                 if tok not in sentence_markers]
                for abstract in abstracts
            ]
            all_tokens = (util.flatten_list_of_lists(article_sent_tokens)
                          + util.flatten_list_of_lists(summary_sent_tokens))
            vocab_counter.update(all_tokens)

    print("Writing vocab file...")
    vocab_path = os.path.join('logs', "vocab_" + FLAGS.dataset_name)
    with open(vocab_path, 'w') as writer:
        for word, count in vocab_counter.most_common(VOCAB_SIZE):
            writer.write(' '.join((word, str(count))) + '\n')
    print("Finished writing vocab file")
def text_generator(self, example_generator):
    """Generates article and abstract text from tf.Example.

    Each example is unpacked into its raw article sentences, its similar
    source indices (ssi) and its list of summary strings. Examples with an
    empty article are skipped; when ``self._hps.skip_with_less_than_3`` is
    set, so are examples whose article or first abstract has fewer than three
    whitespace-separated tokens.

    The generator ends normally when ``example_generator`` is exhausted. The
    previous ``while True`` / ``next()`` form let ``StopIteration`` escape
    from inside a generator, which Python 3.7+ (PEP 479) converts into a
    ``RuntimeError`` instead of a clean end of iteration.

    Args:
        example_generator: a generator of tf.Examples from file. See
            data.example_generator

    Yields:
        ``(article_text, abstract_texts, doc_indices_text, raw_article_sents,
        ssi)`` where ``abstract_texts`` is a list of strings (one per
        reference summary for ``duc_2004``, otherwise a single element).

    Raises:
        Exception: anything raised while unpacking an example is logged and
            re-raised unchanged.
    """
    names_to_types = [('raw_article_sents', 'string_list'),
                      ('similar_source_indices', 'delimited_list_of_tuples'),
                      ('summary_text', 'string_list')]

    for e in example_generator:  # e is a tf.Example
        try:
            raw_article_sents, ssi, groundtruth_summary_sents = util.unpack_tf_example(
                e, names_to_types)
            article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
            article_text = ' '.join(' '.join(sent) for sent in article_sent_tokens)

            if self._hps.dataset_name == 'duc_2004':
                # One entry per reference summary. Iterate the LIST of summary
                # strings: iterating the newline-joined string (as before)
                # walks it character by character.
                abstract_sentences = [
                    ['<s> ' + sent.strip() + ' </s>'
                     for sent in gt_summ_text.strip().split('\n')][:max_dec_sents]
                    for gt_summ_text in groundtruth_summary_sents
                ]
                abstract_texts = [' '.join(abs_sents) for abs_sents in abstract_sentences]
            else:
                groundtruth_summary_text = '\n'.join(groundtruth_summary_sents)
                abstract_sentences = [
                    '<s> ' + sent.strip() + ' </s>'
                    for sent in groundtruth_summary_text.strip().split('\n')
                ]
                abstract_sentences = abstract_sentences[:max_dec_sents]
                abstract_texts = [' '.join(abstract_sentences)]

            features = e.features.feature
            if 'doc_indices' not in features or len(features['doc_indices'].bytes_list.value) == 0:
                # No document indices stored: emit one '0 ' per article token.
                num_words = len(article_text.split())
                doc_indices_text = '0 ' * num_words
            else:
                doc_indices_text = features['doc_indices'].bytes_list.value[0]

            sentence_limit = 1 if self._hps.singles_and_pairs == 'singles' else 2
            ssi = util.enforce_sentence_limit(ssi, sentence_limit)
            ssi = ssi[:max_dec_sents]
            ssi = util.make_ssi_chronological(ssi)
        except Exception:
            logging.error('Failed to get article or abstract from example')
            raise

        if len(article_text) == 0:  # See https://github.com/abisee/pointer-generator/issues/1
            logging.warning('Found an example with empty article text. Skipping it.\n*********************************************')
        elif len(article_text.strip().split()) < 3 and self._hps.skip_with_less_than_3:
            print('Article has less than 3 tokens, so skipping\n*********************************************')
        elif len(abstract_texts[0].strip().split()) < 3 and self._hps.skip_with_less_than_3:
            print('Abstract has less than 3 tokens, so skipping\n*********************************************')
        else:
            yield (article_text, abstract_texts, doc_indices_text, raw_article_sents, ssi)
def is_valid_example(e, is_original=True):
    """Return whether tf.Example ``e`` carries non-empty article content.

    Args:
        e: the example to inspect.
        is_original: when true, the features are read directly from
            ``e.features.feature`` and the example is valid if the first
            ``'article'`` value is non-empty. When false, the example is
            unpacked through ``util.unpack_tf_example`` and is valid if it
            has at least one raw article sentence.

    Returns:
        ``False`` if reading the example raises ``ValueError`` or the content
        checked is empty; ``True`` otherwise.
    """
    abstract_texts = []
    raw_article_sents = []

    if is_original:
        try:
            features = e.features.feature
            # the article text was saved under the key 'article' in the data files
            article_text = features['article'].bytes_list.value[0]
            # the abstract text was saved under the key 'abstract' in the data files
            abstract_texts.extend(features['abstract'].bytes_list.value)
            if 'doc_indices' in features and len(features['doc_indices'].bytes_list.value) != 0:
                doc_indices_text = features['doc_indices'].bytes_list.value[0]
            else:
                doc_indices_text = '0 ' * len(article_text.split())
            # These values are read but not used for the verdict; the reads
            # are kept so the same feature accesses happen as before.
            raw_article_sents.extend(features['raw_article_sents'].bytes_list.value)
        except ValueError:
            return False
        # See https://github.com/abisee/pointer-generator/issues/1
        return len(article_text) != 0

    try:
        names_to_types = [('raw_article_sents', 'string_list'),
                          ('similar_source_indices', 'delimited_list_of_tuples'),
                          ('summary_text', 'string')]
        raw_article_sents, ssi, groundtruth_summary_text = util.unpack_tf_example(
            e, names_to_types)
        if len(raw_article_sents) == 0:
            return False
    except ValueError:
        return False
    return True
def decode_iteratively(self, example_generator, total, names_to_types, ssi_list, hps):
    """Decode each example one selected sentence group at a time, then run ROUGE.

    For every example the source-sentence groups (``ssi``) are decoded one by
    one with ``decode_example`` and the decoded words are concatenated until
    at least 100 words have been produced. Reference and decoded summaries are
    written for ROUGE for every example; human-readable output and highlighted
    HTML are written for the first 1000 examples only. After the last example,
    ROUGE is evaluated if the reference directory is not empty.

    Args:
        example_generator: iterable of examples accepted by
            ``util.unpack_tf_example``.
        total: expected number of examples (progress bar only).
        names_to_types: field description passed to
            ``util.unpack_tf_example``; must yield three fields (raw article
            sentences, groundtruth source indices, summary text).
        ssi_list: ``None`` to decode from the groundtruth source indices
            (upper bound), otherwise a sequence indexed by example whose
            items unpack to ``(gt_ssi, sys_ssi, ext_len)``.
        hps: hyperparameters forwarded to ``create_batch`` and
            ``decode_example``.
    """
    # 'singles' caps groups at one sentence, 'both' at two; any other flag
    # value leaves the groups untouched (same as the original if/elif).
    sentence_limit = {'singles': 1, 'both': 2}.get(FLAGS.singles_and_pairs)

    for example_idx, example in enumerate(tqdm(example_generator, total=total)):
        (raw_article_sents, groundtruth_similar_source_indices_list,
         groundtruth_summary_text) = util.unpack_tf_example(example, names_to_types)
        article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
        groundtruth_summ_sents = [[
            sent.strip() for sent in groundtruth_summary_text.strip().split('\n')
        ]]

        if ssi_list is None:
            # this is if we are doing the upper bound evaluation
            # (ssi_list comes straight from the groundtruth)
            sys_ssi = groundtruth_similar_source_indices_list
            if sentence_limit is not None:
                sys_ssi = util.enforce_sentence_limit(sys_ssi, sentence_limit)
            sys_ssi = util.replace_empty_ssis(sys_ssi, raw_article_sents)
        else:
            gt_ssi, sys_ssi, _ext_len = ssi_list[example_idx]
            if sentence_limit is not None:
                sys_ssi = util.enforce_sentence_limit(sys_ssi, sentence_limit)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list, sentence_limit)
                gt_ssi = util.enforce_sentence_limit(gt_ssi, sentence_limit)
            if gt_ssi != groundtruth_similar_source_indices_list:
                # The '%d' placeholder was previously never filled in, so the
                # warning printed a literal "%d" instead of the example index.
                print('Warning: Example %d has different groundtruth source indices: %s || %s'
                      % (example_idx, groundtruth_similar_source_indices_list, gt_ssi))

        if FLAGS.dataset_name == 'xsum':
            sys_ssi = [sys_ssi[0]]  # only the first group is decoded for xsum

        final_decoded_words = []
        final_decoded_output = ''
        best_hyps = []
        highlight_html_total = ''
        for ssi_idx, ssi in enumerate(sys_ssi):
            selected_raw_article_sents = util.reorder(raw_article_sents, ssi)
            selected_article_text = ' '.join(
                [' '.join(sent) for sent in util.reorder(article_sent_tokens, ssi)])
            selected_doc_indices_str = '0 ' * len(selected_article_text.split())
            if FLAGS.upper_bound:
                selected_groundtruth_summ_sent = [[groundtruth_summ_sents[0][ssi_idx]]]
            else:
                selected_groundtruth_summ_sent = groundtruth_summ_sents

            batch = create_batch(selected_article_text,
                                 selected_groundtruth_summ_sent,
                                 selected_doc_indices_str,
                                 selected_raw_article_sents,
                                 FLAGS.batch_size, hps, self._vocab)

            decoded_words, decoded_output, best_hyp = decode_example(
                self._sess, self._model, self._vocab, batch, example_idx, hps)
            best_hyps.append(best_hyp)
            final_decoded_words.extend(decoded_words)
            final_decoded_output += decoded_output

            if example_idx < 1000:
                min_matched_tokens = 2
                selected_article_sent_tokens = [
                    util.process_sent(sent) for sent in selected_raw_article_sents
                ]
                highlight_summary_sent_tokens = [decoded_words]
                (highlight_ssi_list, lcs_paths_list,
                 highlight_smooth_article_lcs_paths_list) = ssi_functions.get_simple_source_indices_list(
                    highlight_summary_sent_tokens,
                    selected_article_sent_tokens, None, 2, min_matched_tokens)
                highlighted_html = ssi_functions.html_highlight_sents_in_article(
                    highlight_summary_sent_tokens,
                    highlight_ssi_list,
                    selected_article_sent_tokens,
                    lcs_paths_list=lcs_paths_list,
                    article_lcs_paths_list=highlight_smooth_article_lcs_paths_list)
                highlight_html_total += '<u>System Summary</u><br><br>' + highlighted_html + '<br><br>'

            # Stop adding sentence groups once 100 words have been decoded.
            if len(final_decoded_words) >= 100:
                break

        if example_idx < 1000:
            self.write_for_human(raw_article_sents, groundtruth_summ_sents,
                                 final_decoded_words, example_idx)
            ssi_functions.write_highlighted_html(highlight_html_total,
                                                 self._highlight_dir,
                                                 example_idx)

        # write ref summary and decoded summary to file, to eval with pyrouge later
        rouge_functions.write_for_rouge(
            groundtruth_summ_sents, None, example_idx,
            self._rouge_ref_dir, self._rouge_dec_dir,
            decoded_words=final_decoded_words, log=False)

    logging.info("Decoder has finished reading dataset for single_pass.")
    logging.info("Output has been saved in %s and %s.",
                 self._rouge_ref_dir, self._rouge_dec_dir)
    if len(os.listdir(self._rouge_ref_dir)) != 0:
        l_param = 100
        logging.info("Now starting ROUGE eval...")
        results_dict = rouge_functions.rouge_eval(self._rouge_ref_dir,
                                                  self._rouge_dec_dir,
                                                  l_param=l_param)
        rouge_functions.rouge_log(results_dict, self._decode_dir)
def main(unused_argv):
    """Iterate over one dataset split and derive per-example sentence data.

    For each example this unpacks the fields, tokenizes the article
    sentences, normalizes ``doc_indices`` to a list of ints (all zeros when
    absent), computes relative sentence indices and applies
    ``FLAGS.sentence_limit`` to the groundtruth source indices. Nothing is
    returned or written within the visible body.

    Args:
        unused_argv: leftover command-line arguments; must have length 1.

    Raises:
        Exception: if ``unused_argv`` does not have exactly one element.
    """
    print('Running statistics on %s' % FLAGS.dataset_name)
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    source_dir = os.path.join(data_dir, FLAGS.dataset_name)
    file_pattern = source_dir + '/' + FLAGS.dataset_split + '*'
    num_files = len(sorted(glob.glob(file_pattern)))

    # Progress-bar estimate: for these datasets each file is counted as 1000
    # examples (presumably the chunk size -- confirm), otherwise one per file.
    scaled = any(key in FLAGS.dataset_name for key in ('cnn', 'newsroom', 'xsum'))
    total = num_files * 1000 if scaled else num_files

    example_generator = data.example_generator(
        file_pattern, True, False, should_check_valid=False)

    for example in tqdm(example_generator, total=total):
        (raw_article_sents, groundtruth_similar_source_indices_list,
         groundtruth_summary_text, corefs, doc_indices) = util.unpack_tf_example(
            example, names_to_types)
        article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
        groundtruth_summ_sents = [[
            sent.strip() for sent in groundtruth_summary_text.strip().split('\n')
        ]]

        if doc_indices is None:
            # No document indices stored: treat every token as document 0.
            doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
        doc_indices = [int(doc_idx) for doc_idx in doc_indices]

        rel_sent_indices, _, _ = preprocess_for_lambdamart_no_flags.get_rel_sent_indices(
            doc_indices, article_sent_tokens)
        groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
            groundtruth_similar_source_indices_list, FLAGS.sentence_limit)
def decode_iteratively(self, example_generator, total, names_to_types, ssi_list, hps):
    """Decode each example group by group, write visualizations, then run ROUGE.

    For every example, each system-selected sentence group (``ssi``) and its
    article LCS paths are put in chronological order, batched and decoded;
    decoded words are concatenated until at least 100 have been produced.
    For example indices in [0, 100) and [2000, 2100) a two-column HTML page
    (system summary vs. reference summary highlights) and human-readable
    output are written. ROUGE reference/decoded files are written for every
    example and ROUGE is evaluated after the loop if the reference directory
    is not empty.

    Args:
        example_generator: iterable of examples accepted by
            ``util.unpack_tf_example``.
        total: expected number of examples (progress bar only).
        names_to_types: field description passed to
            ``util.unpack_tf_example``; must yield five fields.
        ssi_list: ``None`` to decode from the groundtruth source indices
            (upper bound), otherwise a sequence indexed by example whose
            items unpack to ``(gt_ssi, sys_ssi, ext_len,
            sys_token_probs_list)``.
        hps: hyperparameters forwarded to ``create_batch`` and
            ``decode_example``.
    """
    # Separate counter for attention-visualization files: one is written per
    # decoded sentence group, not per example.
    attn_vis_idx = 0
    for example_idx, example in enumerate(
            tqdm(example_generator, total=total)):
        raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, groundtruth_article_lcs_paths_list = util.unpack_tf_example(
            example, names_to_types)
        article_sent_tokens = [
            util.process_sent(sent) for sent in raw_article_sents
        ]
        # A single-element list wrapping the list of reference sentences.
        groundtruth_summ_sents = [[
            sent.strip()
            for sent in groundtruth_summary_text.strip().split('\n')
        ]]
        groundtruth_summ_sent_tokens = [
            sent.split(' ') for sent in groundtruth_summ_sents[0]
        ]

        if ssi_list is None:  # this is if we are doing the upper bound evaluation (ssi_list comes straight from the groundtruth)
            sys_ssi = groundtruth_similar_source_indices_list
            sys_alp_list = groundtruth_article_lcs_paths_list
            if FLAGS.singles_and_pairs == 'singles':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 1)
            elif FLAGS.singles_and_pairs == 'both':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 2)
            sys_ssi, sys_alp_list = util.replace_empty_ssis(
                sys_ssi, raw_article_sents, sys_alp_list=sys_alp_list)
        else:
            gt_ssi, sys_ssi, ext_len, sys_token_probs_list = ssi_list[
                example_idx]
            # Token probabilities are thresholded into labels with FLAGS.tag_threshold.
            sys_alp_list = ssi_functions.list_labels_from_probs(
                sys_token_probs_list, FLAGS.tag_threshold)
            if FLAGS.singles_and_pairs == 'singles':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 1)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list, 1)
                gt_ssi = util.enforce_sentence_limit(gt_ssi, 1)
            elif FLAGS.singles_and_pairs == 'both':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 2)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list, 2)
                gt_ssi = util.enforce_sentence_limit(gt_ssi, 2)
            # Disabled consistency check between the two groundtruth sources:
            # if gt_ssi != groundtruth_similar_source_indices_list:
            #     raise Exception('Example %d has different groundtruth source indices: ' + str(groundtruth_similar_source_indices_list) + ' || ' + str(gt_ssi))
        if FLAGS.dataset_name == 'xsum':
            # Only the first sentence group is decoded for xsum.
            sys_ssi = [sys_ssi[0]]

        final_decoded_words = []
        final_decoded_outpus = ''
        best_hyps = []
        highlight_html_total = '<u>System Summary</u><br><br>'
        for ssi_idx, ssi in enumerate(sys_ssi):
            # selected_article_lcs_paths = None
            selected_article_lcs_paths = sys_alp_list[ssi_idx]
            ssi, selected_article_lcs_paths = util.make_ssi_chronological(
                ssi, selected_article_lcs_paths)
            # create_batch receives the paths wrapped in a one-element list.
            selected_article_lcs_paths = [selected_article_lcs_paths]
            selected_raw_article_sents = util.reorder(
                raw_article_sents, ssi)
            selected_article_text = ' '.join([
                ' '.join(sent)
                for sent in util.reorder(article_sent_tokens, ssi)
            ])
            # One '0 ' per token of the selected article text.
            selected_doc_indices_str = '0 ' * len(
                selected_article_text.split())
            if FLAGS.upper_bound:
                selected_groundtruth_summ_sent = [[
                    groundtruth_summ_sents[0][ssi_idx]
                ]]
            else:
                selected_groundtruth_summ_sent = groundtruth_summ_sents

            batch = create_batch(selected_article_text,
                                 selected_groundtruth_summ_sent,
                                 selected_doc_indices_str,
                                 selected_raw_article_sents,
                                 selected_article_lcs_paths,
                                 FLAGS.batch_size, hps, self._vocab)

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string

            article_withunks = data.show_art_oovs(original_article,
                                                  self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab,
                (batch.art_oovs[0]
                 if FLAGS.pointer_gen else None))  # string

            if FLAGS.first_intact and ssi_idx == 0:
                # The first selected group is copied verbatim instead of decoded.
                decoded_words = selected_article_text.strip().split()
                decoded_output = selected_article_text
            else:
                decoded_words, decoded_output, best_hyp = decode_example(
                    self._sess, self._model, self._vocab, batch,
                    example_idx, hps)
                best_hyps.append(best_hyp)
            final_decoded_words.extend(decoded_words)
            final_decoded_outpus += decoded_output

            if example_idx < 100 or (example_idx >= 2000
                                     and example_idx < 2100):
                min_matched_tokens = 2
                selected_article_sent_tokens = [
                    util.process_sent(sent)
                    for sent in selected_raw_article_sents
                ]
                highlight_summary_sent_tokens = [decoded_words]
                highlight_ssi_list, lcs_paths_list, highlight_article_lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                    highlight_summary_sent_tokens,
                    selected_article_sent_tokens, None, 2,
                    min_matched_tokens)
                highlighted_html = ssi_functions.html_highlight_sents_in_article(
                    highlight_summary_sent_tokens,
                    highlight_ssi_list,
                    selected_article_sent_tokens,
                    lcs_paths_list=lcs_paths_list,
                    article_lcs_paths_list=
                    highlight_smooth_article_lcs_paths_list)
                highlight_html_total += highlighted_html + '<br>'

            # NOTE(review): when FLAGS.first_intact is set and ssi_idx == 0,
            # best_hyp is not assigned in this iteration, so this call would
            # use a stale value from an earlier iteration or fail with a
            # NameError -- confirm the intended behaviour.
            if FLAGS.attn_vis and example_idx < 200:
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words,
                    best_hyp.attn_dists, best_hyp.p_gens, attn_vis_idx
                )  # write info to .json file for visualization tool
                attn_vis_idx += 1

            # Stop adding sentence groups once 100 words have been decoded.
            if len(final_decoded_words) >= 100:
                break

        # Build the reference-summary column of the HTML from the groundtruth groups.
        gt_ssi_list, gt_alp_list = util.replace_empty_ssis(
            groundtruth_similar_source_indices_list,
            raw_article_sents,
            sys_alp_list=groundtruth_article_lcs_paths_list)
        highlight_html_gt = '<u>Reference Summary</u><br><br>'
        for ssi_idx, ssi in enumerate(gt_ssi_list):
            selected_article_lcs_paths = gt_alp_list[ssi_idx]
            try:
                ssi, selected_article_lcs_paths = util.make_ssi_chronological(
                    ssi, selected_article_lcs_paths)
            except:
                # Dump the offending values before re-raising the original error.
                util.print_vars(ssi, example_idx,
                                selected_article_lcs_paths)
                raise
            selected_raw_article_sents = util.reorder(
                raw_article_sents, ssi)

            if example_idx < 100 or (example_idx >= 2000
                                     and example_idx < 2100):
                min_matched_tokens = 2
                selected_article_sent_tokens = [
                    util.process_sent(sent)
                    for sent in selected_raw_article_sents
                ]
                highlight_summary_sent_tokens = [
                    groundtruth_summ_sent_tokens[ssi_idx]
                ]
                highlight_ssi_list, lcs_paths_list, highlight_article_lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                    highlight_summary_sent_tokens,
                    selected_article_sent_tokens, None, 2,
                    min_matched_tokens)
                highlighted_html = ssi_functions.html_highlight_sents_in_article(
                    highlight_summary_sent_tokens,
                    highlight_ssi_list,
                    selected_article_sent_tokens,
                    lcs_paths_list=lcs_paths_list,
                    article_lcs_paths_list=
                    highlight_smooth_article_lcs_paths_list)
                highlight_html_gt += highlighted_html + '<br>'

        if example_idx < 100 or (example_idx >= 2000
                                 and example_idx < 2100):
            self.write_for_human(raw_article_sents,
                                 groundtruth_summ_sents,
                                 final_decoded_words, example_idx)
            highlight_html_total = ssi_functions.put_html_in_two_columns(
                highlight_html_total, highlight_html_gt)
            ssi_functions.write_highlighted_html(highlight_html_total,
                                                 self._highlight_dir,
                                                 example_idx)

        # if example_idx % 100 == 0:
        #     attn_dir = os.path.join(self._decode_dir, 'attn_vis_data')
        #     attn_selections.process_attn_selections(attn_dir, self._decode_dir, self._vocab)

        rouge_functions.write_for_rouge(
            groundtruth_summ_sents,
            None,
            example_idx,
            self._rouge_ref_dir,
            self._rouge_dec_dir,
            decoded_words=final_decoded_words,
            log=False
        )  # write ref summary and decoded summary to file, to eval with pyrouge later
        # if FLAGS.attn_vis:
        #     self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens, example_idx) # write info to .json file for visualization tool
        # NOTE(review): this increment has no effect on the loop because
        # enumerate() rebinds example_idx at the start of each iteration.
        example_idx += 1  # this is how many examples we've decoded

    logging.info("Decoder has finished reading dataset for single_pass.")
    logging.info("Output has been saved in %s and %s.",
                 self._rouge_ref_dir, self._rouge_dec_dir)
    if len(os.listdir(self._rouge_ref_dir)) != 0:
        # NOTE(review): both branches assign the same value.
        if FLAGS.dataset_name == 'xsum':
            l_param = 100
        else:
            l_param = 100
        logging.info("Now starting ROUGE eval...")
        results_dict = rouge_functions.rouge_eval(self._rouge_ref_dir,
                                                  self._rouge_dec_dir,
                                                  l_param=l_param)
        rouge_functions.rouge_log(results_dict, self._decode_dir)
def convert_article_to_lambdamart_features(ex):
    """Build LambdaMART training instances for one article.

    Positive instances are the groundtruth sentence singles/pairs
    (relevance 1); negatives (relevance 0) are either one sampled negative
    per positive (module-level ``balance`` true) or every remaining
    combination (``balance`` false). The module-level ``importance`` flag
    selects between TF-IDF importances (one query id per example) and MMR
    scores per summary sentence (query id ``example_idx * 10 + ssi_idx``).

    Args:
        ex: a 4-item sequence ``(example, example_idx, single_feat_len,
            pair_feat_len)``.

    Returns:
        The sorted list of instances when the module-level ``lr`` flag is
        true; otherwise a string with one ``format_to_lambdamart`` line per
        instance.
    """
    # example_idx += 1
    # if num_instances != -1 and example_idx >= num_instances:
    #     break
    example, example_idx, single_feat_len, pair_feat_len = ex
    print(example_idx)
    raw_article_sents, similar_source_indices_list, summary_text = util.unpack_tf_example(example, names_to_types)
    article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
    summ_sent_tokens = [sent.strip().split() for sent in summary_text.strip().split('\n')]

    # sent_term_matrix = util.get_tfidf_matrix(raw_article_sents)
    article_text = ' '.join(raw_article_sents)
    sent_term_matrix = util.get_doc_substituted_tfidf_matrix(tfidf_vectorizer, raw_article_sents, article_text)
    # NOTE(review): doc_vector is computed but not used below.
    doc_vector = np.mean(sent_term_matrix, axis=0)

    out_str = ''
    # ssi_idx_cur_inst_id = defaultdict(int)
    instances = []

    if importance:
        importances = util.special_squash(util.get_tfidf_importances(tfidf_vectorizer, raw_article_sents))
        possible_pairs = [list(x) for x in list(itertools.combinations(list(range(len(raw_article_sents))), 2))]   # all pairs
        possible_singles = [[i] for i in range(len(raw_article_sents))]
        possible_combinations = possible_pairs + possible_singles
        positives = [ssi for ssi in similar_source_indices_list]
        # A combination counts as positive in either sentence order.
        negatives = [ssi for ssi in possible_combinations if not (ssi in positives or ssi[::-1] in positives)]

        negative_pairs = [x for x in possible_pairs if not (x in similar_source_indices_list or x[::-1] in similar_source_indices_list)]
        negative_singles = [x for x in possible_singles if not (x in similar_source_indices_list or x[::-1] in similar_source_indices_list)]
        # Shuffled index lists; .pop() below draws negatives without replacement.
        random_negative_pairs = np.random.permutation(len(negative_pairs)).tolist()
        random_negative_singles = np.random.permutation(len(negative_singles)).tolist()

        qid = example_idx
        for similar_source_indices in positives:
            # True sentence single/pair
            relevance = 1
            features = get_features(similar_source_indices, sent_term_matrix, article_sent_tokens, single_feat_len, pair_feat_len, importances)
            if features is None:
                continue
            instances.append(Lambdamart_Instance(features, relevance, qid, similar_source_indices))
            a=0

            if balance:
                # False sentence single/pair: sample one negative of the same kind
                is_pair = len(similar_source_indices) == 2
                if is_pair:
                    if len(random_negative_pairs) == 0:
                        continue
                    negative_indices = negative_pairs[random_negative_pairs.pop()]
                else:
                    if len(random_negative_singles) == 0:
                        continue
                    negative_indices = negative_singles[random_negative_singles.pop()]
                neg_relevance = 0
                neg_features = get_features(negative_indices, sent_term_matrix, article_sent_tokens, single_feat_len, pair_feat_len, importances)
                if neg_features is None:
                    continue
                instances.append(Lambdamart_Instance(neg_features, neg_relevance, qid, negative_indices))
        if not balance:
            for negative_indices in negatives:
                neg_relevance = 0
                neg_features = get_features(negative_indices, sent_term_matrix, article_sent_tokens, single_feat_len, pair_feat_len, importances)
                if neg_features is None:
                    continue
                instances.append(Lambdamart_Instance(neg_features, neg_relevance, qid, negative_indices))
    else:
        mmr_all = util.calc_MMR_all(raw_article_sents, article_sent_tokens, summ_sent_tokens, None) # the size is (# of summary sents, # of article sents)

        possible_pairs = [list(x) for x in list(itertools.combinations(list(range(len(raw_article_sents))), 2))]   # all pairs
        possible_singles = [[i] for i in range(len(raw_article_sents))]
        # negative_pairs = [x for x in possible_pairs if not (x in similar_source_indices_list or x[::-1] in similar_source_indices_list)]
        # negative_singles = [x for x in possible_singles if not (x in similar_source_indices_list or x[::-1] in similar_source_indices_list)]
        #
        # random_negative_pairs = np.random.permutation(len(negative_pairs)).tolist()
        # random_negative_singles = np.random.permutation(len(negative_singles)).tolist()

        # Every (sentence combination, summary sentence index) candidate.
        all_combinations = list(itertools.product(possible_pairs + possible_singles, list(range(len(summ_sent_tokens)))))
        positives = [(similar_source_indices, summ_sent_idx) for summ_sent_idx, similar_source_indices in enumerate(similar_source_indices_list)]
        negatives = [(ssi, ssi_idx) for ssi, ssi_idx in all_combinations if not ((ssi, ssi_idx) in positives or (ssi[::-1], ssi_idx) in positives)]

        for similar_source_indices, ssi_idx in positives:
            # True sentence single/pair
            relevance = 1
            qid = example_idx * 10 + ssi_idx
            features = get_features(similar_source_indices, sent_term_matrix, article_sent_tokens, single_feat_len, pair_feat_len, mmr_all[ssi_idx])
            if features is None:
                continue
            # inst_id = ssi_idx_cur_inst_id[ssi_idx]
            instances.append(Lambdamart_Instance(features, relevance, qid, similar_source_indices))
            # ssi_idx_cur_inst_id[ssi_idx] += 1
            a=0

            # NOTE(review): this balanced path looks stale. The definitions of
            # random_negative_pairs / random_negative_singles are commented
            # out above in this branch, get_features is called with one
            # argument fewer than everywhere else, and format_to_lambdamart
            # receives a single list here but (instance, single_feat_len)
            # below. Confirm whether importance=False with balance=True is
            # still supported before relying on it.
            if balance:
                # False sentence single/pair
                is_pair = len(similar_source_indices) == 2
                if is_pair:
                    if len(random_negative_pairs) == 0:
                        continue
                    negative_indices = possible_pairs[random_negative_pairs.pop()]
                else:
                    if len(random_negative_singles) == 0:
                        continue
                    negative_indices = possible_singles[random_negative_singles.pop()]
                neg_relevance = 0
                neg_features = get_features(negative_indices, sent_term_matrix, article_sent_tokens, single_feat_len, pair_feat_len)
                if neg_features is None:
                    continue
                neg_lambdamart_str = format_to_lambdamart([neg_features, neg_relevance, qid, negative_indices])
                out_str += neg_lambdamart_str + '\n'

        if not balance:
            for negative_indices, ssi_idx in negatives:
                neg_relevance = 0
                qid = example_idx * 10 + ssi_idx
                neg_features = get_features(negative_indices, sent_term_matrix, article_sent_tokens, single_feat_len, pair_feat_len, mmr_all[ssi_idx])
                if neg_features is None:
                    continue
                # inst_id = ssi_idx_cur_inst_id[ssi_idx]
                instances.append(Lambdamart_Instance(neg_features, neg_relevance, qid, negative_indices))
                # ssi_idx_cur_inst_id[ssi_idx] += 1

    # Order instances by query id, then by source indices, before ids are assigned.
    sorted_instances = sorted(instances, key=lambda x: (x.qid, x.source_indices))
    assign_inst_ids(sorted_instances)
    if lr:
        return sorted_instances
    else:
        for instance in sorted_instances:
            lambdamart_str = format_to_lambdamart(instance, single_feat_len)
            out_str += lambdamart_str + '\n'
        # print out_str
        return out_str
def _write_bert_instances(writer, example_generator, total, dataset_name, dataset_split):
    """Write one TSV row per positive/negative sentence single or pair of a split.

    Args:
        writer: binary file object the encoded rows are written to.
        example_generator: iterable of examples accepted by
            ``util.unpack_tf_example``.
        total: expected number of examples (progress bar only).
        dataset_name: name of the dataset being converted.
        dataset_split: output split name; ``'test'`` and ``'val_test'`` emit
            every candidate, any other split emits one sampled negative per
            positive.
    """
    emit_all_candidates = dataset_split in ('test', 'val_test')
    inst_id = 0
    for example_idx, example in enumerate(tqdm(example_generator, total=total)):
        (raw_article_sents, groundtruth_similar_source_indices_list,
         _groundtruth_summary_text, doc_indices) = util.unpack_tf_example(
            example, names_to_types)
        article_sent_tokens = [
            util.process_sent(sent, whitespace=True) for sent in raw_article_sents
        ]

        # Stored doc_indices are only kept for duc_2004; everything else is
        # treated as a single document.
        # NOTE(review): the original condition had a third clause,
        # "dataset_name != 'duc_2004' and len(doc_indices) != <token count>",
        # which could never be reached because the first clause already
        # covers it. It may have been meant as "== 'duc_2004'" -- confirm.
        if dataset_name != 'duc_2004' or doc_indices is None:
            doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
        doc_indices = [int(doc_idx) for doc_idx in doc_indices]
        rel_sent_indices, _, _ = ssi_functions.get_rel_sent_indices(
            doc_indices, article_sent_tokens)
        similar_source_indices_list = util.enforce_sentence_limit(
            groundtruth_similar_source_indices_list, FLAGS.sentence_limit)

        num_sents = len(raw_article_sents)
        possible_pairs = filter_pairs_by_sent_position(
            list(itertools.combinations(range(num_sents), 2)),  # all pairs
            rel_sent_indices=rel_sent_indices)
        possible_singles = [(i,) for i in range(num_sents)]
        positives = list(similar_source_indices_list)

        if emit_all_candidates:
            if FLAGS.singles_and_pairs == 'singles':
                possible_combinations = possible_singles
            else:
                possible_combinations = possible_pairs + possible_singles
            # A candidate counts as positive in either sentence order.
            negatives = [
                ssi for ssi in possible_combinations
                if not (ssi in positives or ssi[::-1] in positives)
            ]

            for ssi in positives:
                if len(ssi) == 0:
                    continue
                if chronological_ssi and len(ssi) >= 2 and ssi[0] > ssi[1]:
                    ssi = (min(ssi), max(ssi))
                writer.write(
                    get_string_bert_example(raw_article_sents, ssi, 1,
                                            example_idx, inst_id).encode())
                inst_id += 1
            for ssi in negatives:
                writer.write(
                    get_string_bert_example(raw_article_sents, ssi, 0,
                                            example_idx, inst_id).encode())
                inst_id += 1
        else:
            # Negatives may not contain any sentence used by a positive.
            positive_sents = set(util.flatten_list_of_lists(positives))
            negative_pairs = [
                pair for pair in possible_pairs
                if not any(i in positive_sents for i in pair)
            ]
            negative_singles = [
                sing for sing in possible_singles
                if sing[0] not in positive_sents
            ]
            # Shuffled index lists; .pop() draws negatives without replacement.
            random_negative_pairs = np.random.permutation(
                len(negative_pairs)).tolist()
            random_negative_singles = np.random.permutation(
                len(negative_singles)).tolist()

            for ssi in similar_source_indices_list:
                if len(ssi) == 0:
                    continue
                if chronological_ssi and len(ssi) >= 2 and ssi[0] > ssi[1]:
                    ssi = (min(ssi), max(ssi))
                is_pair = len(ssi) == 2
                writer.write(
                    get_string_bert_example(raw_article_sents, ssi, 1,
                                            example_idx, inst_id).encode())
                inst_id += 1

                # False sentence single/pair of the same kind as the positive
                if is_pair:
                    if not random_negative_pairs:
                        continue
                    negative_indices = negative_pairs[random_negative_pairs.pop()]
                else:
                    if not random_negative_singles:
                        continue
                    negative_indices = negative_singles[random_negative_singles.pop()]
                writer.write(
                    get_string_bert_example(raw_article_sents, negative_indices,
                                            0, example_idx, inst_id).encode())
                inst_id += 1


def main(unused_argv):
    """Convert the requested dataset(s) and split(s) into BERT input TSV files.

    For every dataset/split a file
    ``data/bert/<dataset>/<singles_and_pairs>/input/<split>.tsv`` is written
    with a header row followed by one row per sentence single or pair. The
    output file is opened in a ``with`` block so it is flushed and closed
    after each split, including when an error occurs (it was previously never
    closed).

    Side effects: sets ``FLAGS.sentence_limit`` from
    ``FLAGS.singles_and_pairs`` and overwrites ``FLAGS.dataset_name`` with the
    dataset currently being processed.

    Args:
        unused_argv: leftover command-line arguments; must have length 1.

    Raises:
        Exception: if ``unused_argv`` does not have exactly one element.
    """
    print('Running statistics on %s' % FLAGS.dataset_name)
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.singles_and_pairs == 'singles':
        FLAGS.sentence_limit = 1
    else:
        FLAGS.sentence_limit = 2

    if FLAGS.dataset_name == 'all':
        dataset_names = ['cnn_dm', 'xsum', 'duc_2004']
    else:
        dataset_names = [FLAGS.dataset_name]

    header_list = [
        'should_merge', 'sent1', 'sent2', 'example_idx', 'inst_id', 'ssi'
    ]

    for dataset_name in dataset_names:
        FLAGS.dataset_name = dataset_name
        source_dir = os.path.join(data_dir, dataset_name)

        if FLAGS.dataset_split == 'all':
            if dataset_name == 'duc_2004':
                dataset_splits = ['test']
            else:
                dataset_splits = ['test', 'val', 'train']
        else:
            dataset_splits = [FLAGS.dataset_split]

        for dataset_split in dataset_splits:
            # 'val_test' reads the 'val' source files but is written under its
            # own name (and uses the exhaustive test-style candidates).
            if dataset_split == 'val_test':
                source_dataset_split = 'val'
            else:
                source_dataset_split = dataset_split

            source_pattern = source_dir + '/' + source_dataset_split + '*'
            # Progress-bar estimate only: number of matching files times 1000.
            total = len(glob.glob(source_pattern)) * 1000
            example_generator = data.example_generator(
                source_pattern, True, False, should_check_valid=False)

            out_dir = os.path.join('data', 'bert', dataset_name,
                                   FLAGS.singles_and_pairs, 'input')
            util.create_dirs(out_dir)
            out_path = os.path.join(out_dir, dataset_split) + '.tsv'
            with open(out_path, 'wb') as writer:
                writer.write(('\t'.join(header_list) + '\n').encode())
                _write_bert_instances(writer, example_generator, total,
                                      dataset_name, dataset_split)
def main(unused_argv):
    """Prepare coreference inputs, or merge coreference output into the data.

    The behaviour depends on ``FLAGS.mode``:

    * ``'prepare'`` -- writes every article (sentences joined by spaces) to
      its own file under ``<coref_root>/<dataset>/to_coref`` and lists the
      written paths in ``corenlp_lists/all_<split>.txt``.
    * ``'create'`` -- re-reads the examples, attaches the JSON-encoded
      coreference information found under ``<coref_root>/<dataset>/processed``
      as a ``corefs`` feature (an empty JSON list when no file exists for an
      example) and writes the examples in shards of 1000 under
      ``<data_root>/with_coref/<dataset>``.

    Any other mode only creates the output directories.

    Args:
        unused_argv: leftover command-line arguments; must hold exactly one
            entry.

    Raises:
        Exception: if extra positional arguments were supplied.
    """
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    def _as_bytes(text):
        # The output files are opened in binary mode, so text strings must be
        # encoded first (required on Python 3). Byte strings pass through.
        return text if isinstance(text, bytes) else text.encode('utf-8')

    start_time = time.time()
    np.random.seed(random_seed)
    if FLAGS.dataset_name == 'all':
        datasets = dataset_names
    else:
        datasets = [FLAGS.dataset_name]

    for dataset in datasets:
        coref_dir = os.path.join(FLAGS.coref_root, dataset)
        to_coref_dir = os.path.join(coref_dir, 'to_coref')
        corenlp_lists_dir = os.path.join(coref_dir, 'corenlp_lists')
        data_coref_dir = os.path.join(FLAGS.data_root, 'with_coref', dataset)
        util.create_dirs(to_coref_dir)
        util.create_dirs(corenlp_lists_dir)
        util.create_dirs(data_coref_dir)
        source_dir = os.path.join(FLAGS.data_root, dataset)

        if FLAGS.dataset_split == 'all':
            dataset_splits = ['test', 'val', 'train']
        else:
            dataset_splits = [FLAGS.dataset_split]

        for dataset_split in dataset_splits:
            split_pattern = source_dir + '/' + dataset_split + '*'
            source_files = sorted(glob.glob(split_pattern))
            # Progress-bar estimate only: these datasets are counted as 1000
            # examples per source file, the others as one per file.
            total = len(source_files) * 1000 if (
                'cnn' in dataset or 'newsroom' in dataset
                or 'xsum' in dataset) else len(source_files)
            example_generator = data.example_generator(
                split_pattern, True, False, should_check_valid=False)

            if FLAGS.mode == 'prepare':
                corenlp_list = []
                out_idx = 0
                for example in tqdm(example_generator, total=total):
                    raw_article_sents, article, abstract, doc_indices = util.unpack_tf_example(
                        example, names_to_types)
                    if raw_article_sents is None:
                        continue
                    raw_article = ' '.join(raw_article_sents)
                    file_name = os.path.join(
                        to_coref_dir,
                        '%s_%06d.bin' % (dataset_split, out_idx))
                    with open(file_name, 'wb') as f:
                        f.write(_as_bytes(raw_article))
                    corenlp_list.append(file_name)
                    out_idx += 1
                # Write the file list once, after all articles are written,
                # instead of rewriting the growing list on every iteration.
                # Nothing is written when no article was produced.
                if corenlp_list:
                    list_path = os.path.join(
                        corenlp_lists_dir, 'all_' + dataset_split + '.txt')
                    with open(list_path, 'wb') as f:
                        f.write(_as_bytes('\n'.join(corenlp_list)))

            elif FLAGS.mode == 'create':
                process_coref_dir = os.path.join(coref_dir, 'processed')

                def _shard_path(idx):
                    # One output shard per 1000 written examples.
                    return os.path.join(
                        data_coref_dir,
                        dataset_split + '_{:05d}.bin'.format(idx // 1000))

                coref_files = sorted(
                    glob.glob(
                        os.path.join(process_coref_dir, dataset_split + '*')))
                # Map '<split>_<idx>.bin' -> path of its '.json' result file.
                coref_dict = {}
                for c in coref_files:
                    coref_dict[c.split('/')[-1].split('.json')[0]] = c
                print(len(coref_files), len(source_files))

                out_idx = 0
                writer = open(_shard_path(out_idx), 'wb')
                try:
                    for example in tqdm(example_generator, total=total):
                        raw_article_sents, article, abstract, doc_indices = util.unpack_tf_example(
                            example, names_to_types)
                        if raw_article_sents is None:
                            continue
                        raw_article_sents = [
                            sent for sent in raw_article_sents
                            if sent.strip() != ''
                        ]
                        if out_idx % 1000 == 0 and out_idx != 0:
                            # Roll over to the next shard.
                            writer.close()
                            writer = open(_shard_path(out_idx), 'wb')

                        file_name = '%s_%06d.bin' % (dataset_split, out_idx)
                        if file_name in coref_dict:
                            corefs = get_corefs(coref_dict[file_name])
                            fixed_corefs = fix_trailing_apostrophe_s(corefs)
                            corefs_relevant_info = remove_irrelevant(
                                fixed_corefs)
                            corefs_json = json.dumps(corefs_relevant_info)
                        else:
                            # No coref output for this example. Reset
                            # `corefs` too, so the call below neither fails
                            # with a NameError nor reuses the previous
                            # example's value.
                            corefs = []
                            corefs_json = json.dumps([])
                        example.features.feature[
                            'corefs'].bytes_list.value.extend(
                                [_as_bytes(corefs_json)])
                        # NOTE(review): `tf_example` is built but the original
                        # `example` (with the added feature) is what gets
                        # written -- confirm this call is still needed.
                        tf_example = convert_data.make_example(
                            article, abstract, doc_indices,
                            raw_article_sents, corefs)
                        convert_data.write_tf_example(example, writer)
                        out_idx += 1
                finally:
                    # Close the current shard even if an example fails.
                    writer.close()

    util.print_execution_time(start_time)
def write_to_lambdamart_examples_to_file(ex):
    """Unpack one example and hand it to ``get_instances``.

    Args:
        ex: a 5-tuple ``(example, example_idx, single_feat_len,
            pair_feat_len, singles_and_pairs)``.

    Returns:
        None. The function returns early when ``FLAGS.start_over`` is off and
        the per-example file ``<temp_in_dir>/<example_idx>.txt`` already
        exists; ``get_instances`` is only called when the module-level
        ``importance`` setting is truthy.
    """
    (example, example_idx, single_feat_len, pair_feat_len,
     singles_and_pairs) = ex
    print(example_idx)

    temp_in_path = os.path.join(temp_in_dir, '%06d.txt' % example_idx)
    resume = not FLAGS.start_over
    if resume and os.path.exists(temp_in_path):
        # Already produced by an earlier run.
        return

    (raw_article_sents, groundtruth_similar_source_indices_list,
     groundtruth_summary_text, corefs,
     doc_indices) = util.unpack_tf_example(example, names_to_types)
    article_sent_tokens = [
        util.process_sent(raw_sent) for raw_sent in raw_article_sents
    ]

    # One document index per article token; fall back to all zeros when the
    # indices are missing or do not line up with the token count.
    num_tokens = len(util.flatten_list_of_lists(article_sent_tokens))
    if doc_indices is None:
        doc_indices = [0] * num_tokens
    else:
        doc_indices = [int(doc_idx) for doc_idx in doc_indices]
        if len(doc_indices) != num_tokens:
            doc_indices = [0] * num_tokens

    rel_sent_indices, _, _ = get_rel_sent_indices(doc_indices,
                                                  article_sent_tokens)
    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
        groundtruth_similar_source_indices_list, sentence_limit)

    summary_sentences = [
        line.strip() for line in groundtruth_summary_text.strip().split('\n')
    ]
    groundtruth_summ_sents = [summary_sentences]
    groundtruth_summ_sent_tokens = [
        sentence.split(' ') for sentence in summary_sentences
    ]

    if FLAGS.dataset_name == 'duc_2004':
        first_k_indices = get_indices_of_first_k_sents_of_each_article(
            rel_sent_indices, FLAGS.first_k)
    else:
        # Every sentence of the article is a candidate.
        first_k_indices = list(range(len(raw_article_sents)))

    if not importance:
        return
    get_instances(example_idx, raw_article_sents, article_sent_tokens, corefs,
                  rel_sent_indices, first_k_indices, temp_in_path,
                  temp_out_path, single_feat_len, pair_feat_len,
                  singles_and_pairs)
def convert_article_to_lambdamart_features(ex):
    """Turn one example into LambdaMART training instances.

    Every ground-truth sentence single/pair becomes a positive instance
    (relevance 1). Negative instances (relevance 0) are drawn from the
    remaining candidate singles/pairs:

    * xsum with ``FLAGS.special_xsum_balance``: up to four random negatives
      per positive (all singles in 'singles' mode, otherwise two singles and
      two pairs);
    * otherwise, when the module-level ``balance`` is set: one random negative
      of the same arity per positive;
    * when ``balance`` is not set: additionally every candidate that is not a
      positive.

    Args:
        ex: a 6-tuple ``(example, example_idx, single_feat_len,
            pair_feat_len, singles_and_pairs, out_path)``.

    Returns:
        The sorted instance list when ``FLAGS.lr`` is set; otherwise ``None``
        after writing the formatted instances to
        ``<out_path>/<example_idx>.txt``.
    """
    example, example_idx, single_feat_len, pair_feat_len, singles_and_pairs, out_path = ex
    print(example_idx)
    raw_article_sents, similar_source_indices_list, summary_text, corefs, doc_indices = util.unpack_tf_example(
        example, names_to_types)
    article_sent_tokens = [
        util.process_sent(sent) for sent in raw_article_sents
    ]

    # One document index per article token; fall back to all zeros when the
    # indices are missing or do not line up with the token count.
    num_tokens = len(util.flatten_list_of_lists(article_sent_tokens))
    if doc_indices is None:
        doc_indices = [0] * num_tokens
    doc_indices = [int(doc_idx) for doc_idx in doc_indices]
    if len(doc_indices) != num_tokens:
        doc_indices = [0] * num_tokens
    rel_sent_indices, _, _ = util.get_rel_sent_indices(doc_indices,
                                                       article_sent_tokens)

    sentence_limit = 1 if FLAGS.singles_and_pairs == 'singles' else 2
    similar_source_indices_list = util.enforce_sentence_limit(
        similar_source_indices_list, sentence_limit)

    article_text = ' '.join(raw_article_sents)
    sent_term_matrix = util.get_doc_substituted_tfidf_matrix(
        tfidf_vectorizer, raw_article_sents, article_text, pca)

    instances = []
    # NOTE(review): `importances` stays undefined when `importance` is falsy,
    # and every get_features call below would then raise NameError -- confirm
    # whether this function is ever run with importance disabled.
    if importance:
        importances = util.special_squash(
            util.get_tfidf_importances(tfidf_vectorizer, raw_article_sents,
                                       pca))

    num_sents = len(raw_article_sents)
    possible_pairs = list(itertools.combinations(range(num_sents),
                                                 2))  # all pairs
    if FLAGS.use_pair_criteria:
        possible_pairs = filter_pairs_by_criteria(raw_article_sents,
                                                  possible_pairs, corefs)
    if FLAGS.sent_position_criteria:
        possible_pairs = filter_pairs_by_sent_position(
            possible_pairs, rel_sent_indices)
    possible_singles = [(i, ) for i in range(num_sents)]
    possible_combinations = possible_pairs + possible_singles

    positives = list(similar_source_indices_list)

    def _is_negative(candidate):
        # A candidate is negative when neither it nor its reversed order is a
        # ground-truth single/pair.
        return not (candidate in positives or candidate[::-1] in positives)

    negative_pairs = [x for x in possible_pairs if _is_negative(x)]
    negative_singles = [x for x in possible_singles if _is_negative(x)]
    # Shuffled index stacks; pop() draws negatives without replacement.
    random_negative_pairs = np.random.permutation(
        len(negative_pairs)).tolist()
    random_negative_singles = np.random.permutation(
        len(negative_singles)).tolist()

    qid = example_idx

    def _add_negative(negative_indices):
        # Append a relevance-0 instance unless no features can be computed.
        neg_features = get_features(negative_indices, sent_term_matrix,
                                    article_sent_tokens, rel_sent_indices,
                                    single_feat_len, pair_feat_len,
                                    importances, singles_and_pairs)
        if neg_features is not None:
            instances.append(
                Lambdamart_Instance(neg_features, 0, qid, negative_indices))

    for similar_source_indices in positives:
        # True sentence single/pair
        features = get_features(similar_source_indices, sent_term_matrix,
                                article_sent_tokens, rel_sent_indices,
                                single_feat_len, pair_feat_len, importances,
                                singles_and_pairs)
        if features is None:
            continue
        instances.append(
            Lambdamart_Instance(features, 1, qid, similar_source_indices))

        if FLAGS.dataset_name == 'xsum' and FLAGS.special_xsum_balance:
            num_negative = 4
            if FLAGS.singles_and_pairs == 'singles':
                num_neg_singles = num_negative
                num_neg_pairs = 0
            else:
                # Floor division keeps these ints: range() rejects the floats
                # that `/` produces under true division.
                num_neg_singles = num_negative // 2
                num_neg_pairs = num_negative // 2
            for _ in range(num_neg_singles):
                if not random_negative_singles:
                    break
                _add_negative(negative_singles[random_negative_singles.pop()])
            for _ in range(num_neg_pairs):
                if not random_negative_pairs:
                    break
                _add_negative(negative_pairs[random_negative_pairs.pop()])
        elif balance:
            # False sentence single/pair of the same arity as the positive
            if len(similar_source_indices) == 2:
                if not random_negative_pairs:
                    continue
                _add_negative(negative_pairs[random_negative_pairs.pop()])
            else:
                if not random_negative_singles:
                    continue
                _add_negative(negative_singles[random_negative_singles.pop()])

    if not balance:
        # Every non-positive candidate becomes a negative. The call now passes
        # rel_sent_indices like every other get_features call in this
        # function; it was missing here, shifting all later arguments.
        for negative_indices in possible_combinations:
            if _is_negative(negative_indices):
                _add_negative(negative_indices)

    sorted_instances = sorted(instances,
                              key=lambda x: (x.qid, x.source_indices))
    assign_inst_ids(sorted_instances)
    if FLAGS.lr:
        return sorted_instances

    out_str = ''.join(
        format_to_lambdamart(instance, single_feat_len) + '\n'
        for instance in sorted_instances)
    # The file is opened in binary mode, so a text string must be encoded
    # first (required on Python 3). Byte strings are written as-is.
    if not isinstance(out_str, bytes):
        out_str = out_str.encode('utf-8')
    with open(os.path.join(out_path, '%06d.txt' % example_idx), 'wb') as f:
        f.write(out_str)
def main(unused_argv):
    """Export sentence-fusion inputs, or score externally generated summaries.

    ``FLAGS.mode == 'write'``: for every example and every non-empty group of
    source-sentence indices, writes the selected article sentences to
    ``<split>.Ndocument`` and the example index to ``<split>.Nexampleidx``.
    For non-test splits the matching ground-truth summary sentence also goes
    to ``<split>.Nsummary``. For the test split the source indices come from a
    pickled list under ``logs/`` instead of the ground truth.

    ``FLAGS.mode == 'evaluate'``: reads the generated summary lines together
    with the example-index file, groups consecutive lines with the same
    example index, writes them next to the reference summaries for ROUGE and
    runs the ROUGE evaluation.

    Args:
        unused_argv: leftover command-line arguments; must hold exactly one
            entry.

    Raises:
        Exception: on extra positional arguments, on an unknown mode, when the
            pickled index list is shorter than the test set, or when the
            summary and example-index files differ in length.
    """
    print('Running statistics on %s' % FLAGS.dataset_name)
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    def _split_summary_sents(summary_text):
        # Stripped summary sentences, one inner list per reference summary.
        # For duc_2004 the text is iterated as a collection of summaries;
        # otherwise it is a single newline-separated string.
        if FLAGS.dataset_name == 'duc_2004':
            return [[
                sent.strip() for sent in gt_summ_text.strip().split('\n')
            ] for gt_summ_text in summary_text]
        return [[sent.strip() for sent in summary_text.strip().split('\n')]]

    out_dir = os.path.join(
        os.path.expanduser('~') + '/data/kaiqiang_data', FLAGS.dataset_name)

    if FLAGS.mode == 'write':
        util.create_dirs(out_dir)
        if FLAGS.dataset_name == 'duc_2004':
            dataset_splits = ['test']
        elif FLAGS.dataset_split == 'all':
            dataset_splits = ['test', 'val', 'train']
        else:
            dataset_splits = [FLAGS.dataset_split]

        for dataset_split in dataset_splits:
            is_test = dataset_split == 'test'
            if is_test:
                ssi_data_path = os.path.join(
                    'logs/%s_bert_both_sentemb_artemb_plushidden' %
                    FLAGS.dataset_name, 'ssi.pkl')
                print(util.bcolors.OKGREEN +
                      "Loading SSI from BERT at %s" % ssi_data_path +
                      util.bcolors.ENDC)
                # Pickle data is binary, so the file must be opened in 'rb'.
                # NOTE: only unpickle files produced by this project.
                with open(ssi_data_path, 'rb') as f:
                    ssi_triple_list = pickle.load(f)

            source_dir = os.path.join(data_dir, FLAGS.dataset_name)
            split_pattern = source_dir + '/' + dataset_split + '*'
            source_files = sorted(glob.glob(split_pattern))
            # Progress-bar estimate only.
            total = len(source_files) * 1000 if (
                'cnn' in FLAGS.dataset_name
                or 'newsroom' in FLAGS.dataset_name
                or 'xsum' in FLAGS.dataset_name) else len(source_files)
            example_generator = data.example_generator(
                split_pattern, True, False, should_check_valid=False)

            out_document_path = os.path.join(out_dir,
                                             dataset_split + '.Ndocument')
            out_summary_path = os.path.join(out_dir,
                                            dataset_split + '.Nsummary')
            out_example_idx_path = os.path.join(
                out_dir, dataset_split + '.Nexampleidx')

            # No summary file is written for the test split.
            doc_writer = open(out_document_path, 'w')
            sum_writer = None if is_test else open(out_summary_path, 'w')
            ex_idx_writer = open(out_example_idx_path, 'w')
            try:
                for example_idx, example in enumerate(
                        tqdm(example_generator, total=total)):
                    if FLAGS.num_instances != -1 and example_idx >= FLAGS.num_instances:
                        break
                    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(
                        example, names_to_types)
                    article_sent_tokens = [
                        util.process_sent(sent) for sent in raw_article_sents
                    ]
                    groundtruth_summ_sents = _split_summary_sents(
                        groundtruth_summary_text)
                    # NOTE(review): doc_indices is normalised here but not
                    # used further in this function.
                    if doc_indices is None:
                        doc_indices = [0] * len(
                            util.flatten_list_of_lists(article_sent_tokens))
                    doc_indices = [int(doc_idx) for doc_idx in doc_indices]

                    if is_test:
                        if example_idx >= len(ssi_triple_list):
                            raise Exception(
                                'Len of ssi list (%d) is less than number of examples (>=%d)'
                                % (len(ssi_triple_list), example_idx))
                        # Keep only the first `ssi_length_extractive` groups
                        # of the loaded entry.
                        ssi_length_extractive = ssi_triple_list[example_idx][2]
                        groundtruth_similar_source_indices_list = ssi_triple_list[
                            example_idx][1][:ssi_length_extractive]
                    else:
                        groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                            groundtruth_similar_source_indices_list,
                            FLAGS.sentence_limit)

                    for ssi_idx, ssi in enumerate(
                            groundtruth_similar_source_indices_list):
                        if len(ssi) == 0:
                            continue
                        my_article = ' '.join(
                            util.reorder(raw_article_sents, ssi))
                        doc_writer.write(my_article + '\n')
                        if sum_writer is not None:
                            sum_writer.write(
                                groundtruth_summ_sents[0][ssi_idx] + '\n')
                        ex_idx_writer.write(str(example_idx) + '\n')
            finally:
                # Close (and flush) the output files even when an example
                # fails; they were previously left open.
                doc_writer.close()
                if sum_writer is not None:
                    sum_writer.close()
                ex_idx_writer.close()

    elif FLAGS.mode == 'evaluate':
        summary_dir = '/home/logan/data/kaiqiang_data/logan_ACL/trained_on_' + FLAGS.train_dataset + '/' + FLAGS.dataset_name
        out_summary_path = os.path.join(summary_dir, 'test' + 'Summary.txt')
        out_example_idx_path = os.path.join(out_dir, 'test' + '.Nexampleidx')
        decode_dir = 'logs/kaiqiang_%s_trainedon%s' % (FLAGS.dataset_name,
                                                       FLAGS.train_dataset)
        rouge_ref_dir = os.path.join(decode_dir, 'reference')
        rouge_dec_dir = os.path.join(decode_dir, 'decoded')
        util.create_dirs(rouge_ref_dir)
        util.create_dirs(rouge_dec_dir)

        def num_lines_in_file(file_path):
            with open(file_path) as f:
                return sum(1 for _ in f)

        def process_example(sents, ex_idx, groundtruth_summ_sents):
            # Write one example's decoded words and references for ROUGE.
            final_decoded_words = []
            for sent in sents:
                final_decoded_words.extend(sent.split(' '))
            rouge_functions.write_for_rouge(groundtruth_summ_sents,
                                            None,
                                            ex_idx,
                                            rouge_ref_dir,
                                            rouge_dec_dir,
                                            decoded_words=final_decoded_words,
                                            log=False)

        num_lines_summary = num_lines_in_file(out_summary_path)
        num_lines_example_indices = num_lines_in_file(out_example_idx_path)
        if num_lines_summary != num_lines_example_indices:
            raise Exception(
                'Num lines summary != num lines example indices: (%d, %d)' %
                (num_lines_summary, num_lines_example_indices))

        source_dir = os.path.join(data_dir, FLAGS.dataset_name)
        example_generator = data.example_generator(source_dir + '/' + 'test' +
                                                   '*',
                                                   True,
                                                   False,
                                                   should_check_valid=False)

        def _next_reference_summary():
            # Reference sentences of the next example in the test data.
            example = next(example_generator)
            raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(
                example, names_to_types)
            return _split_summary_sents(groundtruth_summary_text)

        # NOTE(review): one example is pulled from the generator per group of
        # lines, which assumes every test example contributed at least one
        # line to the summary file -- confirm.
        prev_ex_idx = 0
        sents = []
        with open(out_summary_path) as summary_reader, \
                open(out_example_idx_path) as ex_idx_reader:
            for _ in tqdm(range(num_lines_summary)):
                line = summary_reader.readline()
                ex_idx = int(ex_idx_reader.readline())
                if ex_idx == prev_ex_idx:
                    sents.append(line)
                else:
                    # The buffered lines belong to the previous example, so
                    # they are written under prev_ex_idx. Using the new index
                    # labelled every group one example too late and made the
                    # last two groups overwrite each other.
                    process_example(sents, prev_ex_idx,
                                    _next_reference_summary())
                    prev_ex_idx = ex_idx
                    sents = [line]
        if sents:
            # Flush the final group.
            process_example(sents, prev_ex_idx, _next_reference_summary())

        print("Now starting ROUGE eval...")
        l_param = 100  # the same value was used for every dataset
        results_dict = rouge_functions.rouge_eval(rouge_ref_dir,
                                                  rouge_dec_dir,
                                                  l_param=l_param)
        rouge_functions.rouge_log(results_dict, decode_dir)
    else:
        raise Exception('mode flag was not evaluate or write.')
def evaluate_example(ex):
    """Produce the summary for one example and write it out for ROUGE.

    With ``FLAGS.upper_bound`` the summary is made of the article sentences
    selected by the ground-truth source indices. Otherwise
    ``generate_summary`` builds it, and for example indices 0-99 and 2000-2099
    an HTML page comparing system and ground-truth highlights is written as
    well.

    Args:
        ex: a 4-tuple ``(example, example_idx, qid_ssi_to_importances,
            qid_ssi_to_token_scores_and_mappings)``.

    Returns:
        A 4-tuple ``(groundtruth_similar_source_indices_list,
        similar_source_indices_list, ssi_length_extractive,
        token_probs_list)``. ``token_probs_list`` is ``None`` on the
        upper-bound path, where no summary is generated.
    """
    example, example_idx, qid_ssi_to_importances, qid_ssi_to_token_scores_and_mappings = ex
    print(example_idx)

    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
        example, names_to_types)
    article_sent_tokens = [
        util.process_sent(sent) for sent in raw_article_sents
    ]
    enforced_groundtruth_ssi_list = util.enforce_sentence_limit(
        groundtruth_similar_source_indices_list, sentence_limit)
    groundtruth_summ_sents = [[
        sent.strip() for sent in groundtruth_summary_text.strip().split('\n')
    ]]
    groundtruth_summ_sent_tokens = [
        sent.split(' ') for sent in groundtruth_summ_sents[0]
    ]

    if FLAGS.upper_bound:
        replaced_ssi_list = util.replace_empty_ssis(
            enforced_groundtruth_ssi_list, raw_article_sents)
        selected_article_sent_indices = util.flatten_list_of_lists(
            replaced_ssi_list)
        summary_sents = [
            ' '.join(sent) for sent in util.reorder(
                article_sent_tokens, selected_article_sent_indices)
        ]
        similar_source_indices_list = groundtruth_similar_source_indices_list
        ssi_length_extractive = len(similar_source_indices_list)
        # Nothing is generated on this path. Define the value so the return
        # statement below no longer raises NameError.
        token_probs_list = None
    else:
        summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive, \
            article_lcs_paths_list, token_probs_list = generate_summary(
                article_sent_tokens, qid_ssi_to_importances, example_idx,
                qid_ssi_to_token_scores_and_mappings)
        similar_source_indices_list_trunc = similar_source_indices_list[:
                                                                        ssi_length_extractive]
        summary_sents_for_html_trunc = summary_sents_for_html[:
                                                              ssi_length_extractive]

        # HTML visualisations are only written for two fixed index ranges.
        if example_idx < 100 or 2000 <= example_idx < 2100:
            summary_sent_tokens = [
                sent.split(' ') for sent in summary_sents_for_html_trunc
            ]
            if FLAGS.tag_tokens and FLAGS.tag_loss_wt != 0:
                lcs_paths_list_param = copy.deepcopy(article_lcs_paths_list)
            else:
                lcs_paths_list_param = None
            extracted_sents_in_article_html = html_highlight_sents_in_article(
                summary_sent_tokens,
                similar_source_indices_list_trunc,
                article_sent_tokens,
                doc_indices=doc_indices,
                lcs_paths_list=lcs_paths_list_param)

            groundtruth_ssi_list, gt_lcs_paths_list, gt_article_lcs_paths_list, gt_smooth_article_paths_list = get_simple_source_indices_list(
                groundtruth_summ_sent_tokens, article_sent_tokens, None,
                sentence_limit, min_matched_tokens)
            groundtruth_highlighted_html = html_highlight_sents_in_article(
                groundtruth_summ_sent_tokens,
                groundtruth_ssi_list,
                article_sent_tokens,
                lcs_paths_list=gt_lcs_paths_list,
                article_lcs_paths_list=gt_smooth_article_paths_list,
                doc_indices=doc_indices)

            all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html + '<u>Groundtruth Summary</u><br><br>' + groundtruth_highlighted_html
            write_highlighted_html(all_html, html_dir, example_idx)

    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents,
                                    example_idx, ref_dir, dec_dir)
    return (groundtruth_similar_source_indices_list,
            similar_source_indices_list, ssi_length_extractive,
            token_probs_list)
def main(unused_argv):
    """Write every article as one line of text, one file per split.

    For each requested dataset and split, the sentences of every example are
    joined with spaces and written, one article per line, to
    ``data/bert/<dataset>/article_embeddings/input_article/<split>.tsv``.
    At most ``FLAGS.num_instances`` examples are written per split unless that
    flag is -1.

    Note that ``FLAGS.dataset_name`` is overwritten with each dataset being
    processed.

    Args:
        unused_argv: leftover command-line arguments; must hold exactly one
            entry.

    Raises:
        Exception: if extra positional arguments were supplied.
    """
    print('Running statistics on %s' % FLAGS.dataset_name)
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.dataset_name == 'all':
        dataset_names = ['cnn_dm', 'xsum', 'duc_2004']
    else:
        dataset_names = [FLAGS.dataset_name]

    for dataset_name in dataset_names:
        FLAGS.dataset_name = dataset_name
        source_dir = os.path.join(data_dir, dataset_name)

        if FLAGS.dataset_split != 'all':
            dataset_splits = [FLAGS.dataset_split]
        elif dataset_name == 'duc_2004':
            dataset_splits = ['test']
        else:
            dataset_splits = ['test', 'val', 'train']

        for dataset_split in dataset_splits:
            split_pattern = source_dir + '/' + dataset_split + '*'
            source_files = sorted(glob.glob(split_pattern))
            total = len(source_files) * 1000  # progress-bar estimate only
            example_generator = data.example_generator(
                split_pattern, True, False, should_check_valid=False)

            out_dir = os.path.join('data', 'bert', dataset_name,
                                   'article_embeddings', 'input_article')
            util.create_dirs(out_dir)

            # The context manager closes (and flushes) the file for every
            # split; it was previously never closed.
            with open(os.path.join(out_dir, dataset_split) + '.tsv',
                      'wb') as writer:
                for example_idx, example in enumerate(
                        tqdm(example_generator, total=total)):
                    if FLAGS.num_instances != -1 and example_idx >= FLAGS.num_instances:
                        break
                    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
                        example, names_to_types)
                    line = ' '.join(raw_article_sents) + '\n'
                    # The file is opened in binary mode, so a text string must
                    # be encoded first (required on Python 3). Byte strings
                    # are written as-is.
                    if not isinstance(line, bytes):
                        line = line.encode('utf-8')
                    writer.write(line)
def main(unused_argv): print('Running statistics on %s' % FLAGS.dataset_name) if len(unused_argv ) != 1: # prints a message if you've entered flags incorrectly raise Exception("Problem with flags: %s" % unused_argv) if FLAGS.dataset_name == 'all': dataset_names = ['cnn_dm', 'xsum', 'duc_2004'] else: dataset_names = [FLAGS.dataset_name] if not os.path.exists(plot_data_file): all_lists_of_histogram_pairs = [] for dataset_name in dataset_names: FLAGS.dataset_name = dataset_name if dataset_name == 'duc_2004': dataset_splits = ['test'] elif FLAGS.dataset_split == 'all': dataset_splits = ['test', 'val', 'train'] else: dataset_splits = [FLAGS.dataset_split] ssi_list = [] for dataset_split in dataset_splits: ssi_path = os.path.join(ssi_dir, FLAGS.dataset_name, dataset_split + '_ssi.pkl') with open(ssi_path) as f: ssi_list.extend(pickle.load(f)) if FLAGS.dataset_name == 'duc_2004': for abstract_idx in [1, 2, 3]: ssi_path = os.path.join( ssi_dir, FLAGS.dataset_name, dataset_split + '_ssi_' + str(abstract_idx) + '.pkl') with open(ssi_path) as f: temp_ssi_list = pickle.load(f) ssi_list.extend(temp_ssi_list) ssi_2d = util.flatten_list_of_lists(ssi_list) num_extracted = [ len(ssi) for ssi in util.flatten_list_of_lists(ssi_list) ] hist_num_extracted = np.histogram(num_extracted, bins=6, range=(0, 5)) print(hist_num_extracted) print('Histogram of number of sentences merged: ' + util.hist_as_pdf_str(hist_num_extracted)) distances = [ abs(ssi[0] - ssi[1]) for ssi in ssi_2d if len(ssi) >= 2 ] print('Distance between sentences (mean, median): ', np.mean(distances), np.median(distances)) hist_dist = np.histogram(distances, bins=max(distances)) print('Histogram of distances: ' + util.hist_as_pdf_str(hist_dist)) summ_sent_idx_to_number_of_source_sents = [[], [], [], [], [], [], [], [], [], []] for ssi in ssi_list: for summ_sent_idx, source_indices in enumerate(ssi): if len(source_indices) == 0 or summ_sent_idx >= len( summ_sent_idx_to_number_of_source_sents): continue num_sents = 
len(source_indices) if num_sents > 2: num_sents = 2 summ_sent_idx_to_number_of_source_sents[ summ_sent_idx].append(num_sents) print( "Number of source sents for summary sentence indices (Is the first summary sent more likely to match with a singleton or a pair?):" ) for summ_sent_idx, list_of_numbers_of_source_sents in enumerate( summ_sent_idx_to_number_of_source_sents): if len(list_of_numbers_of_source_sents) == 0: percent_singleton = 0. else: percent_singleton = list_of_numbers_of_source_sents.count( 1) * 1. / len(list_of_numbers_of_source_sents) percent_pair = list_of_numbers_of_source_sents.count( 2) * 1. / len(list_of_numbers_of_source_sents) print str(percent_singleton) + '\t', print '' for summ_sent_idx, list_of_numbers_of_source_sents in enumerate( summ_sent_idx_to_number_of_source_sents): if len(list_of_numbers_of_source_sents) == 0: percent_pair = 0. else: percent_singleton = list_of_numbers_of_source_sents.count( 1) * 1. / len(list_of_numbers_of_source_sents) percent_pair = list_of_numbers_of_source_sents.count( 2) * 1. 
/ len(list_of_numbers_of_source_sents) print str(percent_pair) + '\t', print '' primary_pos = [ssi[0] for ssi in ssi_2d if len(ssi) >= 1] secondary_pos = [ssi[1] for ssi in ssi_2d if len(ssi) >= 2] all_pos = [max(ssi) for ssi in ssi_2d if len(ssi) >= 1] # if FLAGS.dataset_name != 'duc_2004': # plot_positions(primary_pos, secondary_pos, all_pos) if FLAGS.dataset_split == 'all': glob_string = '*.bin' else: glob_string = dataset_splits[0] print('Loading TFIDF vectorizer') with open(tfidf_vec_path, 'rb') as f: tfidf_vectorizer = pickle.load(f) source_dir = os.path.join(data_dir, FLAGS.dataset_name) source_files = sorted( glob.glob(source_dir + '/' + glob_string + '*')) total = len(source_files) * 1000 if ( 'cnn' in FLAGS.dataset_name or 'newsroom' in FLAGS.dataset_name or 'xsum' in FLAGS.dataset_name) else len(source_files) example_generator = data.example_generator( source_dir + '/' + glob_string + '*', True, False, should_check_valid=False) all_possible_singles = 0 all_possible_pairs = [0] all_filtered_pairs = 0 all_all_combinations = 0 all_ssi_pairs = [0] ssi_pairs_with_shared_coref = [0] ssi_pairs_with_shared_word = [0] ssi_pairs_with_either_coref_or_word = [0] all_pairs_with_shared_coref = [0] all_pairs_with_shared_word = [0] all_pairs_with_either_coref_or_word = [0] actual_total = [0] rel_positions_primary = [] rel_positions_secondary = [] rel_positions_all = [] sent_lens = [] all_sent_lens = [] all_pos = [] y = [] normalized_positions_primary = [] normalized_positions_secondary = [] all_normalized_positions_primary = [] all_normalized_positions_secondary = [] normalized_positions_singles = [] normalized_positions_pairs_first = [] normalized_positions_pairs_second = [] primary_pos_duc = [] secondary_pos_duc = [] all_pos_duc = [] all_distances = [] distances_duc = [] tfidf_similarities = [] all_tfidf_similarities = [] average_mmrs = [] all_average_mmrs = [] for example_idx, example in enumerate( tqdm(example_generator, total=total)): # def 
process(example_idx_example): # # print '0' # example = example_idx_example if FLAGS.num_instances != -1 and example_idx >= FLAGS.num_instances: break raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example( example, names_to_types) article_sent_tokens = [ util.process_sent(sent) for sent in raw_article_sents ] article_text = ' '.join(raw_article_sents) groundtruth_summ_sents = [[ sent.strip() for sent in groundtruth_summary_text.strip().split('\n') ]] if doc_indices is None: doc_indices = [0] * len( util.flatten_list_of_lists(article_sent_tokens)) doc_indices = [int(doc_idx) for doc_idx in doc_indices] rel_sent_indices, doc_sent_indices, doc_sent_lens = preprocess_for_lambdamart_no_flags.get_rel_sent_indices( doc_indices, article_sent_tokens) groundtruth_similar_source_indices_list = util.enforce_sentence_limit( groundtruth_similar_source_indices_list, FLAGS.sentence_limit) sent_term_matrix = util.get_doc_substituted_tfidf_matrix( tfidf_vectorizer, raw_article_sents, article_text) sents_similarities = util.cosine_similarity( sent_term_matrix, sent_term_matrix) importances = util.special_squash( util.get_tfidf_importances(tfidf_vectorizer, raw_article_sents)) if FLAGS.dataset_name == 'duc_2004': first_k_indices = lambdamart_scores_to_summaries.get_indices_of_first_k_sents_of_each_article( rel_sent_indices, FLAGS.first_k) else: first_k_indices = [ idx for idx in range(len(raw_article_sents)) ] article_indices = list(range(len(raw_article_sents))) possible_pairs = [ x for x in list(itertools.combinations(article_indices, 2)) ] # all pairs # # # filtered_possible_pairs = preprocess_for_lambdamart_no_flags.filter_pairs_by_criteria(raw_article_sents, possible_pairs, corefs) # if FLAGS.dataset_name == 'duc_2004': # filtered_possible_pairs = [x for x in list(itertools.combinations(first_k_indices, 2))] # all pairs # else: # filtered_possible_pairs = 
preprocess_for_lambdamart_no_flags.filter_pairs_by_sent_position(possible_pairs) # # removed_pairs = list(set(possible_pairs) - set(filtered_possible_pairs)) # possible_singles = [(i,) for i in range(len(raw_article_sents))] # all_combinations = filtered_possible_pairs + possible_singles # # all_possible_singles += len(possible_singles) # all_possible_pairs[0] += len(possible_pairs) # all_filtered_pairs += len(filtered_possible_pairs) # all_all_combinations += len(all_combinations) # for ssi in groundtruth_similar_source_indices_list: # if len(ssi) > 0: # idx = rel_sent_indices[ssi[0]] # rel_positions_primary.append(idx) # rel_positions_all.append(idx) # if len(ssi) > 1: # idx = rel_sent_indices[ssi[1]] # rel_positions_secondary.append(idx) # rel_positions_all.append(idx) # # # # coref_pairs = preprocess_for_lambdamart_no_flags.get_coref_pairs(corefs) # # DO OVER LAP PAIRS BETTER # overlap_pairs = preprocess_for_lambdamart_no_flags.filter_by_overlap(article_sent_tokens, possible_pairs) # either_coref_or_word = list(set(list(coref_pairs) + overlap_pairs)) # # for ssi in groundtruth_similar_source_indices_list: # if len(ssi) == 2: # all_ssi_pairs[0] += 1 # do_share_coref = ssi in coref_pairs # do_share_words = ssi in overlap_pairs # if do_share_coref: # ssi_pairs_with_shared_coref[0] += 1 # if do_share_words: # ssi_pairs_with_shared_word[0] += 1 # if do_share_coref or do_share_words: # ssi_pairs_with_either_coref_or_word[0] += 1 # all_pairs_with_shared_coref[0] += len(coref_pairs) # all_pairs_with_shared_word[0] += len(overlap_pairs) # all_pairs_with_either_coref_or_word[0] += len(either_coref_or_word) if FLAGS.dataset_name == 'duc_2004': primary_pos_duc.extend([ rel_sent_indices[ssi[0]] for ssi in groundtruth_similar_source_indices_list if len(ssi) >= 1 ]) secondary_pos_duc.extend([ rel_sent_indices[ssi[1]] for ssi in groundtruth_similar_source_indices_list if len(ssi) >= 2 ]) all_pos_duc.extend([ max([rel_sent_indices[sent_idx] for sent_idx in ssi]) for ssi in 
groundtruth_similar_source_indices_list if len(ssi) >= 1 ]) for ssi in groundtruth_similar_source_indices_list: for sent_idx in ssi: sent_lens.append(len(article_sent_tokens[sent_idx])) if len(ssi) >= 1: orig_val = ssi[0] vals_to_add = get_integral_values_for_histogram( orig_val, rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents) normalized_positions_primary.extend(vals_to_add) if len(ssi) >= 2: orig_val = ssi[1] vals_to_add = get_integral_values_for_histogram( orig_val, rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents) normalized_positions_secondary.extend(vals_to_add) if FLAGS.dataset_name == 'duc_2004': distances_duc.append( abs(rel_sent_indices[ssi[1]] - rel_sent_indices[ssi[0]])) tfidf_similarities.append(sents_similarities[ssi[0], ssi[1]]) average_mmrs.append( (importances[ssi[0]] + importances[ssi[1]]) / 2) for ssi in groundtruth_similar_source_indices_list: if len(ssi) == 1: orig_val = ssi[0] vals_to_add = get_integral_values_for_histogram( orig_val, rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents) normalized_positions_singles.extend(vals_to_add) if len(ssi) >= 2: if doc_sent_indices[ssi[0]] != doc_sent_indices[ ssi[1]]: continue orig_val_first = min(ssi[0], ssi[1]) vals_to_add = get_integral_values_for_histogram( orig_val_first, rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents) normalized_positions_pairs_first.extend(vals_to_add) orig_val_second = max(ssi[0], ssi[1]) vals_to_add = get_integral_values_for_histogram( orig_val_second, rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents) normalized_positions_pairs_second.extend(vals_to_add) # all_normalized_positions_primary.extend(util.flatten_list_of_lists([get_integral_values_for_histogram(single[0], rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents) for single in possible_singles])) # 
all_normalized_positions_secondary.extend(util.flatten_list_of_lists([get_integral_values_for_histogram(pair[1], rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents) for pair in possible_pairs])) all_sent_lens.extend( [len(sent) for sent in article_sent_tokens]) all_distances.extend([ abs(rel_sent_indices[pair[1]] - rel_sent_indices[pair[0]]) for pair in possible_pairs ]) all_tfidf_similarities.extend([ sents_similarities[pair[0], pair[1]] for pair in possible_pairs ]) all_average_mmrs.extend([ (importances[pair[0]] + importances[pair[1]]) / 2 for pair in possible_pairs ]) # if FLAGS.dataset_name == 'duc_2004': # rel_pos_single = [rel_sent_indices[single[0]] for single in possible_singles] # rel_pos_pair = [[rel_sent_indices[pair[0]], rel_sent_indices[pair[1]]] for pair in possible_pairs] # all_pos.extend(rel_pos_single) # all_pos.extend([max(pair) for pair in rel_pos_pair]) # else: # all_pos.extend(util.flatten_list_of_lists(possible_singles)) # all_pos.extend([max(pair) for pair in possible_pairs]) # y.extend([1 if single in groundtruth_similar_source_indices_list else 0 for single in possible_singles]) # y.extend([1 if pair in groundtruth_similar_source_indices_list else 0 for pair in possible_pairs]) # actual_total[0] += 1 # # p = Pool(144) # # list(tqdm(p.imap(process, example_generator), total=total)) # # # print 'Possible_singles\tPossible_pairs\tFiltered_pairs\tAll_combinations: \n%.2f\t%.2f\t%.2f\t%.2f' % (all_possible_singles*1./actual_total, \ # # all_possible_pairs*1./actual_total, all_filtered_pairs*1./actual_total, all_all_combinations*1./actual_total) # # # # # print 'Relative positions of groundtruth source sentences in document:\nPrimary\tSecondary\tBoth\n%.2f\t%.2f\t%.2f' % (np.mean(rel_positions_primary), np.mean(rel_positions_secondary), np.mean(rel_positions_all)) # # # # print 'SSI Pair statistics:\nShare_coref\tShare_word\tShare_either\n%.2f\t%.2f\t%.2f' \ # # % (ssi_pairs_with_shared_coref[0]*100./all_ssi_pairs[0], 
ssi_pairs_with_shared_word[0]*100./all_ssi_pairs[0], ssi_pairs_with_either_coref_or_word[0]*100./all_ssi_pairs[0]) # # print 'All Pair statistics:\nShare_coref\tShare_word\tShare_either\n%.2f\t%.2f\t%.2f' \ # # % (all_pairs_with_shared_coref[0]*100./all_possible_pairs[0], all_pairs_with_shared_word[0]*100./all_possible_pairs[0], all_pairs_with_either_coref_or_word[0]*100./all_possible_pairs[0]) # # # hist_all_pos = np.histogram(all_pos, bins=max(all_pos)+1) # # print 'Histogram of all sent positions: ', util.hist_as_pdf_str(hist_all_pos) # # min_sent_len = min(sent_lens) # # hist_sent_lens = np.histogram(sent_lens, bins=max(sent_lens)-min_sent_len+1) # # print 'min, max sent lens:', min_sent_len, max(sent_lens) # # print 'Histogram of sent lens: ', util.hist_as_pdf_str(hist_sent_lens) # # min_all_sent_len = min(all_sent_lens) # # hist_all_sent_lens = np.histogram(all_sent_lens, bins=max(all_sent_lens)-min_all_sent_len+1) # # print 'min, max all sent lens:', min_all_sent_len, max(all_sent_lens) # # print 'Histogram of all sent lens: ', util.hist_as_pdf_str(hist_all_sent_lens) # # # print 'Pearsons r, p value', pearsonr(all_pos, y) # # fig, ax1 = plt.subplots(nrows=1) # # plt.scatter(all_pos, y) # # pp = PdfPages(os.path.join('stuff/plots', FLAGS.dataset_name + '_position_scatter.pdf')) # # plt.savefig(pp, format='pdf',bbox_inches='tight') # # plt.show() # # pp.close() # # # if FLAGS.dataset_name == 'duc_2004': # # plot_positions(primary_pos_duc, secondary_pos_duc, all_pos_duc) # # normalized_positions_all = normalized_positions_primary + normalized_positions_secondary # # plot_histogram(normalized_positions_primary, num_bins=100) # # plot_histogram(normalized_positions_secondary, num_bins=100) # # plot_histogram(normalized_positions_all, num_bins=100) # # sent_lens_together = [sent_lens, all_sent_lens] # # plot_histogram(sent_lens_together, pdf=True, start_at_0=True, max_val=70) # # if FLAGS.dataset_name == 'duc_2004': # distances = distances_duc # 
sent_distances_together = [distances, all_distances] # # plot_histogram(sent_distances_together, pdf=True, start_at_0=True, max_val=100) # # tfidf_similarities_together = [tfidf_similarities, all_tfidf_similarities] # # plot_histogram(tfidf_similarities_together, pdf=True, num_bins=100) # # average_mmrs_together = [average_mmrs, all_average_mmrs] # # plot_histogram(average_mmrs_together, pdf=True, num_bins=100) # # normalized_positions_primary_together = [normalized_positions_primary, bin_values] # normalized_positions_secondary_together = [normalized_positions_secondary, bin_values] # # plot_histogram(normalized_positions_primary_together, pdf=True, num_bins=100) # # plot_histogram(normalized_positions_secondary_together, pdf=True, num_bins=100) # # # list_of_hist_pairs = [ # { # 'lst': normalized_positions_primary_together, # 'pdf': True, # 'num_bins': 100, # 'y_lim': 3.9, # 'y_label': FLAGS.dataset_name, # 'x_label': 'Sent position (primary)' # }, # { # 'lst': normalized_positions_secondary_together, # 'pdf': True, # 'num_bins': 100, # 'y_lim': 3.9, # 'x_label': 'Sent position (secondary)' # }, # { # 'lst': sent_distances_together, # 'pdf': True, # 'start_at_0': True, # 'max_val': 100, # 'x_label': 'Sent distance' # }, # { # 'lst': sent_lens_together, # 'pdf': True, # 'start_at_0': True, # 'max_val': 70, # 'x_label': 'Sent length' # }, # { # 'lst': average_mmrs_together, # 'pdf': True, # 'num_bins': 100, # 'x_label': 'Average TF-IDF importance' # } # ] normalized_positions_pairs_together = [ normalized_positions_pairs_first, normalized_positions_pairs_second ] list_of_hist_pairs = [ { 'lst': [normalized_positions_singles], 'pdf': True, 'num_bins': 100, # 'y_lim': 3.9, 'x_lim': 1.0, 'y_label': FLAGS.dataset_name, 'x_label': 'Sent Position (Singles)', 'legend_labels': ['Primary'] }, { 'lst': normalized_positions_pairs_together, 'pdf': True, 'num_bins': 100, # 'y_lim': 3.9, 'x_lim': 1.0, 'x_label': 'Sent Position (Pairs)', 'legend_labels': ['Primary', 'Secondary'] 
} ] all_lists_of_histogram_pairs.append(list_of_hist_pairs) with open(plot_data_file, 'w') as f: cPickle.dump(all_lists_of_histogram_pairs, f) else: with open(plot_data_file) as f: all_lists_of_histogram_pairs = cPickle.load(f) plot_histograms(all_lists_of_histogram_pairs)
def load_and_evaluate_example(ex):
    """Unpack one work item and, if importance scoring is on, summarize it.

    Args:
        ex: 5-tuple ``(example, example_idx, single_feat_len, pair_feat_len,
            singles_and_pairs)``.

    Returns:
        None. The results of the calls below are bound to locals only.
    """
    example, example_idx, single_feat_len, pair_feat_len, singles_and_pairs = ex
    print(example_idx)

    (raw_article_sents, groundtruth_similar_source_indices_list,
     groundtruth_summary_text, corefs) = util.unpack_tf_example(
         example, names_to_types)
    article_sent_tokens = list(map(util.process_sent, raw_article_sents))

    # Single reference summary, wrapped in an outer list of length one.
    summary_lines = groundtruth_summary_text.strip().split('\n')
    groundtruth_summ_sents = [[line.strip() for line in summary_lines]]
    groundtruth_summ_sent_tokens = [
        sent.split(' ') for sent in groundtruth_summ_sents[0]
    ]

    # Per-example scratch files, named by the zero-padded example index.
    file_name = '%06d.txt' % example_idx
    temp_in_path = os.path.join(temp_in_dir, file_name)
    temp_out_path = os.path.join(temp_out_dir, file_name)

    if not importance:
        return
    summary_sents, similar_source_indices_list, summary_sents_for_html = generate_summary_importance(
        raw_article_sents, article_sent_tokens, corefs, temp_in_path,
        temp_out_path, single_feat_len, pair_feat_len, singles_and_pairs)
def text_generator(self, example_generator):
    """Generates article and abstract text from tf.Example.

    Args:
        example_generator: an iterator of tf.Examples from file. See
            data.example_generator

    Yields:
        Tuples ``(article_text, abstract_texts, doc_indices_text,
        raw_article_sents, ssi, article_lcs_paths_list)`` for every example
        that is not skipped. Examples with empty article text are always
        skipped; articles or first abstracts shorter than 3 tokens are
        skipped when ``self._hps.skip_with_less_than_3`` is set.

    Raises:
        Exception: any error raised while unpacking or processing an example
            is logged and re-raised unchanged.
    """
    # A ``for`` loop ends this generator cleanly once the input is exhausted.
    # Calling next() directly inside a generator body would leak StopIteration,
    # which Python 3.7+ converts into RuntimeError (PEP 479).
    for e in example_generator:  # e is a tf.Example
        try:
            names_to_types = [
                ('raw_article_sents', 'string_list'),
                ('similar_source_indices', 'delimited_list_of_tuples'),
                ('summary_text', 'string_list'),
                ('corefs', 'json'),
                ('article_lcs_paths_list', 'delimited_list_of_list_of_lists'),
            ]
            raw_article_sents, ssi, groundtruth_summary_sents, _, article_lcs_paths_list = util.unpack_tf_example(
                e, names_to_types)
            groundtruth_summary_text = '\n'.join(groundtruth_summary_sents)
            article_sent_tokens = [
                util.process_sent(sent) for sent in raw_article_sents
            ]
            article_text = ' '.join(
                [' '.join(sent) for sent in article_sent_tokens])

            if self._hps.dataset_name == 'duc_2004':
                # Build one abstract per unpacked entry. Iterating the joined
                # string instead would walk it character by character.
                # NOTE(review): assumes each 'summary_text' entry is one full
                # reference abstract for duc_2004 -- confirm against the data.
                abstract_sentences = [[
                    '<s> ' + sent.strip() + ' </s>'
                    for sent in gt_summ_text.strip().split('\n')
                ] for gt_summ_text in groundtruth_summary_sents]
                abstract_sentences = [
                    abs_sents[:max_dec_sents]
                    for abs_sents in abstract_sentences
                ]
                abstract_texts = [
                    ' '.join(abs_sents) for abs_sents in abstract_sentences
                ]
            else:
                abstract_sentences = [
                    '<s> ' + sent.strip() + ' </s>'
                    for sent in groundtruth_summary_text.strip().split('\n')
                ]
                abstract_sentences = abstract_sentences[:max_dec_sents]
                abstract_texts = [' '.join(abstract_sentences)]

            # Fall back to an all-zero index string (one '0 ' per article
            # token) when the example carries no doc_indices feature.
            if 'doc_indices' not in e.features.feature or len(
                    e.features.feature['doc_indices'].bytes_list.value) == 0:
                num_words = len(article_text.split())
                doc_indices_text = '0 ' * num_words
            else:
                doc_indices_text = e.features.feature[
                    'doc_indices'].bytes_list.value[0]

            sentence_limit = 1 if self._hps.singles_and_pairs == 'singles' else 2
            ssi = util.enforce_sentence_limit(ssi, sentence_limit)
            ssi = ssi[:max_dec_sents]
            article_lcs_paths_list = util.enforce_sentence_limit(
                article_lcs_paths_list, sentence_limit)
            article_lcs_paths_list = article_lcs_paths_list[:max_dec_sents]
            ssi, article_lcs_paths_list = util.make_ssi_chronological(
                ssi, article_lcs_paths_list)
        except Exception:
            logging.error('Failed to get article or abstract from example')
            raise

        if len(article_text) == 0:
            # See https://github.com/abisee/pointer-generator/issues/1
            logging.warning(
                'Found an example with empty article text. Skipping it.\n*********************************************'
            )
            continue
        if len(article_text.strip().split()
               ) < 3 and self._hps.skip_with_less_than_3:
            print(
                'Article has less than 3 tokens, so skipping\n*********************************************'
            )
            continue
        if len(abstract_texts[0].strip().split()
               ) < 3 and self._hps.skip_with_less_than_3:
            print(
                'Abstract has less than 3 tokens, so skipping\n*********************************************'
            )
            continue

        yield (article_text, abstract_texts, doc_indices_text,
               raw_article_sents, ssi, article_lcs_paths_list)
def main(unused_argv):
    """Cross-check stored source-sentence indices against the dataset.

    Configures the experiment flags, builds the vocabulary and hyperparameter
    tuple, then walks every example of the selected split and compares its
    groundtruth source indices with the ones saved in ``ssi.pkl``.

    Args:
        unused_argv: leftover command-line arguments; must have length 1.

    Raises:
        Exception: if extra arguments were passed, if ``single_pass`` is set
            outside decode mode, or if an example's groundtruth source
            indices differ from the pickled ones.
    """
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.singles_and_pairs == 'both':
        FLAGS.exp_name = FLAGS.exp_name + '_both'
        dataset_articles = _dataset_articles
    else:
        FLAGS.exp_name = FLAGS.exp_name + '_singles'
        dataset_articles = _dataset_articles + '_singles'
    my_log_dir = os.path.join(log_dir, FLAGS.ssi_exp_name)

    print('Running statistics on %s' % FLAGS.exp_name)

    if FLAGS.dataset_name != "":
        FLAGS.data_path = os.path.join(FLAGS.data_root, FLAGS.dataset_name,
                                       FLAGS.dataset_split + '*')
    dataset_dir = os.path.join(FLAGS.data_root, FLAGS.dataset_name)
    if not os.path.exists(dataset_dir) or len(os.listdir(dataset_dir)) == 0:
        print('No TF example data found at %s so creating it from raw data.'
              % dataset_dir)
        convert_data.process_dataset(FLAGS.dataset_name)

    logging.set_verbosity(logging.INFO)  # choose what level of logging you want
    logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name
    FLAGS.exp_name = FLAGS.exp_name if FLAGS.exp_name != '' else FLAGS.dataset_name
    FLAGS.actual_log_root = FLAGS.log_root
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)

    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

    # If in decode mode, set batch_size = beam_size
    # Reason: in decode mode, we decode one example at a time.
    # On each step, we have beam_size-many hypotheses in the beam, so we need
    # to make a batch of these hypotheses.
    if FLAGS.mode == 'decode':
        FLAGS.batch_size = FLAGS.beam_size

    # If single_pass=True, check we're in decode mode
    if FLAGS.single_pass and FLAGS.mode != 'decode':
        raise Exception("The single_pass flag should only be True in decode mode")

    # Make a namedtuple hps, containing the values of the hyperparameters that
    # the model needs
    hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag',
                   'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim',
                   'emb_dim', 'batch_size', 'max_dec_steps', 'max_enc_steps',
                   'coverage', 'cov_loss_wt', 'pointer_gen', 'lambdamart_input']
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val.value  # add it to the dict
    hps = namedtuple("HParams", list(hps_dict.keys()))(**hps_dict)

    tf.set_random_seed(113)  # a seed value for randomness

    # The model is configured with max_dec_steps=1 because we only ever run
    # one step of the decoder at a time (to do beam search). Note that the
    # batcher is initialized with max_dec_steps equal to e.g. 100 because the
    # batches need to contain the full summaries
    decode_model_hps = hps._replace(max_dec_steps=1)

    np.random.seed(random_seed)
    source_dir = os.path.join(data_dir, dataset_articles)
    source_files = sorted(glob.glob(source_dir + '/' + dataset_split + '*'))

    # Pickle data is bytes, so the file must be opened in binary mode.
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point this at ssi.pkl files produced by this project.
    with open(os.path.join(my_log_dir, 'ssi.pkl'), 'rb') as f:
        ssi_list = pickle.load(f)

    # Each membership test is spelled out: a bare 'cnn' operand is a
    # non-empty string and would make the whole condition always true.
    if 'cnn' in dataset_articles or 'newsroom' in dataset_articles:
        total = len(source_files) * 1000
    else:
        total = len(source_files)
    example_generator = data.example_generator(
        source_dir + '/' + dataset_split + '*', True, False,
        should_check_valid=False)

    for example_idx, example in enumerate(tqdm(example_generator, total=total)):
        raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs = util.unpack_tf_example(
            example, names_to_types)
        groundtruth_summ_sents = [[
            sent.strip()
            for sent in groundtruth_summary_text.strip().split('\n')
        ]]

        gt_ssi, _, _ = ssi_list[example_idx]
        if gt_ssi != groundtruth_similar_source_indices_list:
            raise Exception(
                'Example %d has different groundtruth source indices: ' % example_idx
                + str(groundtruth_similar_source_indices_list) + ' || ' + str(gt_ssi))

        # Report a mismatch between the number of summary sentences and the
        # number of source-index entries. groundtruth_summ_sents wraps the
        # sentences in an outer list of length one, so count the inner list.
        # NOTE(review): assumes one source-index entry per summary sentence
        # is the expected invariant -- confirm.
        if len(groundtruth_summ_sents[0]) != len(groundtruth_similar_source_indices_list):
            tqdm.write(
                'Example %d has different len groundtruth source indices from len summ sents: ' % example_idx
                + str(groundtruth_similar_source_indices_list) + ' || '
                + str(groundtruth_summ_sents))

    print('done')
def evaluate_example(ex):
    """Summarize one example, write its ROUGE inputs and return its indices.

    Depending on the flags the summary is the groundtruth selection
    (``FLAGS.upper_bound``), the leading sentences (``FLAGS.lead``) or the
    output of ``generate_summary``.

    Args:
        ex: 5-tuple; the first three items are the tf.Example, its index and
            the ``qid_ssi_to_importances`` object handed to
            ``generate_summary``. The last two items are ignored.

    Returns:
        ``(groundtruth_similar_source_indices_list,
        similar_source_indices_list, ssi_length_extractive)``.
    """
    example, example_idx, qid_ssi_to_importances, _, _ = ex
    print(example_idx)

    (raw_article_sents, gt_ssi_list, groundtruth_summary_text, corefs,
     doc_indices) = util.unpack_tf_example(example, names_to_types)
    article_sent_tokens = [util.process_sent(s) for s in raw_article_sents]
    enforced_gt_ssi_list = util.enforce_sentence_limit(gt_ssi_list,
                                                       sentence_limit)

    # duc_2004 is iterated as a collection of reference texts; every other
    # dataset is treated as a single reference text.
    if FLAGS.dataset_name == 'duc_2004':
        reference_texts = groundtruth_summary_text
    else:
        reference_texts = [groundtruth_summary_text]
    groundtruth_summ_sents = [
        [line.strip() for line in text.strip().split('\n')]
        for text in reference_texts
    ]
    groundtruth_summ_sent_tokens = [
        line.split(' ') for line in groundtruth_summ_sents[0]
    ]

    if FLAGS.upper_bound or FLAGS.lead:
        if FLAGS.upper_bound:
            picked_ssi_list = util.replace_empty_ssis(enforced_gt_ssi_list,
                                                      raw_article_sents)
            similar_source_indices_list = gt_ssi_list
        else:
            num_lead = util.average_sents_for_dataset[FLAGS.dataset_name]
            # make sure the sentence indices don't go past the total number
            # of sentences in the article
            picked_ssi_list = [
                (idx, ) for idx in range(num_lead)
            ][:len(raw_article_sents)]
            similar_source_indices_list = picked_ssi_list
        flat_indices = util.flatten_list_of_lists(picked_ssi_list)
        summary_sents = [
            ' '.join(tokens)
            for tokens in util.reorder(article_sent_tokens, flat_indices)
        ]
        ssi_length_extractive = len(similar_source_indices_list)
    else:
        summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive = generate_summary(
            article_sent_tokens, qid_ssi_to_importances, example_idx)
        # summary_sents_for_html only exists on this path, so the HTML
        # visualization stays inside it.
        ssi_trunc = similar_source_indices_list[:ssi_length_extractive]
        html_sents_trunc = summary_sents_for_html[:ssi_length_extractive]
        if example_idx <= 100:
            system_html = html_highlight_sents_in_article(
                [sent.split(' ') for sent in html_sents_trunc],
                ssi_trunc,
                article_sent_tokens,
                doc_indices=doc_indices)
            groundtruth_ssi_list, lcs_paths_list, article_lcs_paths_list = get_simple_source_indices_list(
                groundtruth_summ_sent_tokens, article_sent_tokens, None,
                sentence_limit, min_matched_tokens)
            reference_html = html_highlight_sents_in_article(
                groundtruth_summ_sent_tokens,
                groundtruth_ssi_list,
                article_sent_tokens,
                lcs_paths_list=lcs_paths_list,
                article_lcs_paths_list=article_lcs_paths_list,
                doc_indices=doc_indices)
            all_html = ('<u>System Summary</u><br><br>' + system_html
                        + '<u>Groundtruth Summary</u><br><br>' + reference_html)
            write_highlighted_html(all_html, html_dir, example_idx)

    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents,
                                    example_idx, ref_dir, dec_dir)
    return (gt_ssi_list, similar_source_indices_list, ssi_length_extractive)
def _stripped_nonempty_lines(text):
    """Return the stripped lines of ``text``, dropping blank ones."""
    return [line.strip() for line in text.split('\n') if line.strip() != '']


def _read_tab_joined_files(paths, unfix_brackets=False):
    """Read each file and return its lines joined by tabs, one entry per file."""
    joined = []
    for path in paths:
        with open(path) as f:
            text = f.read()
        if unfix_brackets:
            text = util.unfix_bracket_tokens_in_sent(text)
        joined.append('\t'.join(text.split('\n')))
    return joined


def _write_system_summaries(processed_dir, summaries):
    """Write summaries.txt (punctuation-fixed) and summaries_tokenized.txt."""
    with open(os.path.join(processed_dir, 'summaries.txt'), 'w') as writer, \
            open(os.path.join(processed_dir, 'summaries_tokenized.txt'),
                 'w') as writer_tokenized:
        for summ in summaries:
            writer_tokenized.write(summ + '\n')
            writer.write(fix_punctuations(summ) + '\n')


def main(unused_argv):
    """Bring every system's summaries into a common order and file layout.

    Optionally dumps the reference summaries and articles from the TF example
    files first. Then, for each entry of ``systems``, the raw output is read,
    reordered with ``reorder_list_like`` and written to
    ``<processed_root>/<system>/summaries.txt`` (after ``fix_punctuations``)
    and ``summaries_tokenized.txt`` (unchanged).

    The 'reference' entry defines ``reference_summaries``, which later
    entries read, so it has to come before them in ``systems``.

    Args:
        unused_argv: leftover command-line arguments; must have length 1.

    Raises:
        Exception: if extra command-line arguments were passed.
    """
    print('Running statistics on %s' % FLAGS.dataset_name)
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    util.create_dirs(processed_root)
    util.create_dirs(os.path.join(raw_root, 'reference'))
    util.create_dirs(os.path.join(processed_root, 'article'))

    source_dir = os.path.join(data_dir, FLAGS.dataset_name)
    source_files = sorted(
        glob.glob(source_dir + '/' + FLAGS.dataset_split + '*'))
    total = len(source_files) * 1000 if (
        'cnn' in FLAGS.dataset_name or 'newsroom' in FLAGS.dataset_name
        or 'xsum' in FLAGS.dataset_name) else len(source_files)
    example_generator = data.example_generator(
        source_dir + '/' + FLAGS.dataset_split + '*', True, False,
        should_check_valid=False)

    # Always defined so the 'novel' branch below can reference it even when
    # the preprocessing pass is switched off.
    reference_articles = []
    if preprocess_article_and_human_summaries:
        # One ``with`` closes (and therefore flushes) all three outputs
        # before the per-system pass reads the reference summaries back.
        with open(os.path.join(raw_root, 'reference', 'summaries.txt'),
                  'w') as writer, \
                open(os.path.join(processed_root, 'article', 'articles.txt'),
                     'w') as writer_article, \
                open(os.path.join(processed_root, 'article',
                                  'articles_tokenized.txt'),
                     'w') as writer_tokenized_article:
            for example_idx, example in enumerate(
                    tqdm(example_generator, total=total)):
                if FLAGS.num_instances != -1 and example_idx >= FLAGS.num_instances:
                    break
                raw_article_sents, _, groundtruth_summary_text, _, _ = util.unpack_tf_example(
                    example, names_to_types)
                groundtruth_summ_sents = [
                    util.unfix_bracket_tokens_in_sent(sent.strip())
                    for sent in groundtruth_summary_text.strip().split('\n')
                ]
                writer.write('\t'.join(groundtruth_summ_sents) + '\n')
                reference_article = '\t'.join([
                    util.unfix_bracket_tokens_in_sent(sent.strip())
                    for sent in raw_article_sents
                ])
                reference_articles.append(reference_article)
                writer_article.write(
                    fix_punctuations(reference_article) + '\n')
                writer_tokenized_article.write(reference_article + '\n')

    for system in systems:
        print('Processing ' + system + '...')
        raw_dir = os.path.join(raw_root, system)
        processed_dir = os.path.join(processed_root, system)
        util.create_dirs(processed_dir)

        if system == 'reference':
            with open(os.path.join(raw_dir, 'summaries.txt')) as f:
                text = f.read()
            with open(os.path.join(processed_dir, 'summaries.txt'),
                      'w') as writer:
                writer.write(fix_punctuations(text))
            reference_summaries = _stripped_nonempty_lines(text)
            with open(os.path.join(processed_dir, 'summaries_tokenized.txt'),
                      'w') as writer_tokenized:
                writer_tokenized.write(text + '\n')

        elif system == 'abs-rl-rerank':
            decoded_files = sorted(
                glob.glob(
                    os.path.join(raw_dir, 'rnn-ext_abs_rl_rerank', 'decoded',
                                 '*.dec')))
            sys_ref_files = sorted(
                glob.glob(os.path.join(raw_dir, 'reference', '*.ref')))
            summaries = _read_tab_joined_files(decoded_files,
                                               unfix_brackets=True)
            sys_ref_summaries = _read_tab_joined_files(sys_ref_files,
                                                       unfix_brackets=True)
            reordered_summaries = reorder_list_like(
                summaries, sys_ref_summaries, reference_summaries)
            _write_system_summaries(processed_dir, reordered_summaries)

        elif system == 'pg':
            decoded_files = sorted(
                glob.glob(
                    os.path.join(raw_dir, 'pointer-gen-cov',
                                 '*_decoded.txt')))
            summaries = _read_tab_joined_files(tqdm(decoded_files))
            ref_files = sorted(
                glob.glob(
                    os.path.join(raw_dir, 'reference', '*_reference.txt')))
            sys_ref_summaries = _read_tab_joined_files(tqdm(ref_files))
            reordered_summaries = reorder_list_like(
                summaries, sys_ref_summaries, reference_summaries)
            _write_system_summaries(processed_dir, reordered_summaries)

        elif system == 'bottom-up':
            with open(
                    os.path.join(raw_dir,
                                 'bottom_up_cnndm_015_threshold.out')) as f:
                text_with_slash_t = f.read()
            text_with_slash_t = util.unfix_bracket_tokens_in_sent(
                text_with_slash_t)
            summaries = _stripped_nonempty_lines(
                slash_t_to_tab_separated(text_with_slash_t))
            with open(os.path.join(raw_dir,
                                   'test.txt.tgt.tagged.shuf.noslash')) as f:
                text_with_slash_t = f.read()
            sys_ref_summaries = _stripped_nonempty_lines(
                slash_t_to_tab_separated(text_with_slash_t))
            reordered_summaries = reorder_list_like(
                summaries, sys_ref_summaries, reference_summaries)
            _write_system_summaries(processed_dir, reordered_summaries)

        elif system == 'dca':
            with open(os.path.join(raw_dir, 'cnndm_m6_m7.txt')) as f:
                text = f.read()
            lines = text.split('\n')
            summary_texts = []
            sys_ref_summary_texts = []
            for line in tqdm(lines[1:]):  # lines[0] is not a data row
                if line.strip() == '':
                    continue
                # Raises ValueError on a row without exactly three columns.
                sys_ref_summary, _, summary = line.split('\t')
                summary_texts.append(summary.replace('u . s .', 'u.s.'))
                sys_ref_summary_texts.append(
                    sys_ref_summary.replace('u . s .', 'u.s.'))
            summaries = [get_sents(summary) for summary in tqdm(summary_texts)]
            sys_ref_summaries = [
                get_sents(sys_ref_summary)
                for sys_ref_summary in tqdm(sys_ref_summary_texts)
            ]
            reordered_summaries = reorder_list_like(
                summaries, sys_ref_summaries, reference_summaries)
            _write_system_summaries(processed_dir, reordered_summaries)

        elif system == 'novel':
            with open(os.path.join(raw_dir, 'rl-novelty-lm.out')) as f:
                text = f.read()
            summary_texts = []
            sys_article_texts = []
            for line in tqdm(text.split('\n')):
                if line.strip() == '':
                    continue
                obj = json.loads(line)
                article = obj['article']
                summary = obj['prediction']
                summary_texts.append(
                    util.unfix_bracket_tokens_in_sent(summary))
                sys_article_texts.append(
                    util.unfix_bracket_tokens_in_sent(article))
            summaries = [
                get_sents(summary)
                for summary in tqdm(summary_texts, total=11490)
            ]
            sys_articles = [
                get_sents(article)
                for article in tqdm(sys_article_texts, total=11490)
            ]
            # This system is aligned by article text, not reference summary.
            reordered_summaries = reorder_list_like(summaries, sys_articles,
                                                    reference_articles)
            _write_system_summaries(processed_dir, reordered_summaries)
def main(unused_argv):
    """Run the selected summarizers over the selected datasets and score them.

    For every (summarizer, dataset) combination each article is summarized
    with 5 sentences, the result is written out for ROUGE, and the source
    sentences of the system summary are recorded next to the groundtruth
    ones. After a dataset is processed, the sentence-selection score and
    ROUGE results are computed and collected; all collected result strings
    are printed at the end.

    Args:
        unused_argv: leftover command-line arguments; must have length 1.

    Raises:
        Exception: if extra command-line arguments were passed.
    """
    print('Running statistics on %s' % FLAGS.dataset_name)
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.summarizer == 'all':
        summary_methods = list(summarizers.keys())
    else:
        summary_methods = [FLAGS.summarizer]
    if FLAGS.dataset_name == 'all':
        dataset_names = datasets
    else:
        dataset_names = [FLAGS.dataset_name]

    # Limit passed to get_simple_source_indices_list for system summaries.
    sentence_limit = 2
    sheets_strs = []
    for summary_method in summary_methods:
        summary_fn = summarizers[summary_method]
        for dataset_name in dataset_names:
            FLAGS.dataset_name = dataset_name

            if 'xsum' in dataset_name:
                original_dataset_name = 'xsum'
            elif 'cnn_dm' in dataset_name or 'duc_2004' in dataset_name:
                original_dataset_name = 'cnn_dm'
            else:
                original_dataset_name = ''
            vocab = Vocab('logs/vocab' + '_' + original_dataset_name,
                          50000)  # create a vocabulary

            source_dir = os.path.join(data_dir, dataset_name)
            source_files = sorted(
                glob.glob(source_dir + '/' + FLAGS.dataset_split + '*'))
            total = len(source_files) * 1000 if (
                'cnn' in dataset_name or 'newsroom' in dataset_name
                or 'xsum' in dataset_name) else len(source_files)
            example_generator = data.example_generator(
                source_dir + '/' + FLAGS.dataset_split + '*', True, False,
                should_check_valid=False)

            if dataset_name == 'duc_2004':
                # The abstracts for duc_2004 are read from a second set of
                # example files, consumed in lockstep with the main one.
                abs_source_dir = os.path.join(
                    os.path.expanduser('~') + '/data/tf_data/with_coref',
                    dataset_name)
                abs_example_generator = data.example_generator(
                    abs_source_dir + '/' + FLAGS.dataset_split + '*', True,
                    False, should_check_valid=False)
                abs_names_to_types = [('abstract', 'string_list')]

            # Output locations depend only on the dataset/method pair, so
            # they are set up once here. That also keeps them defined for the
            # evaluation below when the split yields no examples.
            log_dir = os.path.join(log_root,
                                   dataset_name + '_' + summary_method)
            dec_dir = os.path.join(log_dir, 'decoded')
            ref_dir = os.path.join(log_dir, 'reference')
            util.create_dirs(dec_dir)
            util.create_dirs(ref_dir)

            triplet_ssi_list = []
            for example_idx, example in enumerate(
                    tqdm(example_generator, total=total)):
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
                    example, names_to_types)

                if dataset_name == 'duc_2004':
                    abs_example = next(abs_example_generator)
                    groundtruth_summary_texts = util.unpack_tf_example(
                        abs_example, abs_names_to_types)
                    groundtruth_summary_texts = groundtruth_summary_texts[0]
                    groundtruth_summ_sents_list = [[
                        sent.strip()
                        for sent in data.abstract2sents(abstract)
                    ] for abstract in groundtruth_summary_texts]
                else:
                    groundtruth_summ_sents_list = [[
                        sent.strip()
                        for sent in groundtruth_summary_text.strip().split('\n')
                    ]]

                article_sent_tokens = [
                    util.process_sent(sent) for sent in raw_article_sents
                ]
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list,
                    FLAGS.sentence_limit)

                parser = PlaintextParser.from_string(
                    ' '.join(raw_article_sents), Tokenizer("english"))
                summarizer = summary_fn()
                # Summarize the document with 5 sentences
                summary = summarizer(parser.document, 5)
                summary_tokenized = [
                    str(sentence).lower() for sentence in summary
                ]

                rouge_functions.write_for_rouge(groundtruth_summ_sents_list,
                                                summary_tokenized,
                                                example_idx, ref_dir, dec_dir,
                                                log=False)

                decoded_sent_tokens = [
                    sent.split() for sent in summary_tokenized
                ]
                sys_ssi_list, _, _ = get_simple_source_indices_list(
                    decoded_sent_tokens, article_sent_tokens, vocab,
                    sentence_limit, min_matched_tokens)
                triplet_ssi_list.append(
                    (groundtruth_similar_source_indices_list, sys_ssi_list,
                     -1))

            print('Evaluating Lambdamart model F1 score...')
            suffix = util.all_sent_selection_eval(triplet_ssi_list)
            print(suffix)

            results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir)
            print(("Results_dict: ", results_dict))
            sheets_str = rouge_functions.rouge_log(results_dict, log_dir,
                                                   suffix=suffix)
            sheets_strs.append(dataset_name + '_' + summary_method + '\n'
                               + sheets_str)

    for sheets_str in sheets_strs:
        print(sheets_str + '\n')
def evaluate_example(ex):
    """Build the summary for one example and write it out for ROUGE scoring.

    Depending on ``FLAGS.upper_bound`` the summary is assembled either from
    the groundtruth source indices or from ``generate_summary``. For the first
    two examples of the generated path an HTML page comparing the system and
    groundtruth highlights is also written.

    Args:
        ex: a 5-tuple ``(example, example_idx, qid_ssi_to_importances, _, _)``;
            the last two entries are ignored.

    Returns:
        A tuple ``(groundtruth_similar_source_indices_list,
        similar_source_indices_list, ssi_length_extractive)``.
    """
    example, example_idx, qid_ssi_to_importances, _, _ = ex
    print(example_idx)

    # Unpack the fields named in names_to_types from the example.
    unpacked_fields = util.unpack_tf_example(example, names_to_types)
    raw_article_sents, gt_ssi_list, gt_summary_text, doc_indices = unpacked_fields

    article_sent_tokens = [
        util.process_sent(raw_sent) for raw_sent in raw_article_sents
    ]
    limited_gt_ssi_list = util.enforce_sentence_limit(gt_ssi_list, sentence_limit)

    # write_for_rouge receives a list of reference summaries; there is one here.
    reference_sents = [
        line.strip() for line in gt_summary_text.strip().split('\n')
    ]
    groundtruth_summ_sents = [reference_sents]
    reference_sent_tokens = [sent.split(' ') for sent in reference_sents]

    if not FLAGS.upper_bound:
        # Extractive summary produced from the scored singletons/pairs.
        generated = generate_summary(
            article_sent_tokens, qid_ssi_to_importances, example_idx)
        (summary_sents, system_ssi_list,
         html_summary_sents, ssi_length_extractive) = generated

        system_ssi_trunc = system_ssi_list[:ssi_length_extractive]
        html_summary_sents_trunc = html_summary_sents[:ssi_length_extractive]

        if example_idx <= 1:
            system_sent_tokens = [
                sent.split(' ') for sent in html_summary_sents_trunc
            ]
            system_html = html_highlight_sents_in_article(
                system_sent_tokens,
                system_ssi_trunc,
                article_sent_tokens,
                doc_indices=doc_indices)

            simple_ssi_result = get_simple_source_indices_list(
                reference_sent_tokens, article_sent_tokens, None,
                sentence_limit, min_matched_tokens)
            simple_gt_ssi_list, lcs_paths_list, article_lcs_paths_list = simple_ssi_result

            reference_html = html_highlight_sents_in_article(
                reference_sent_tokens,
                simple_gt_ssi_list,
                article_sent_tokens,
                lcs_paths_list=lcs_paths_list,
                article_lcs_paths_list=article_lcs_paths_list,
                doc_indices=doc_indices)

            all_html = ('<u>System Summary</u><br><br>' + system_html
                        + '<u>Groundtruth Summary</u><br><br>' + reference_html)
            ssi_functions.write_highlighted_html(all_html, html_dir, example_idx)
    else:
        # Upper bound: use the groundtruth singletons/pairs directly.
        filled_ssi_list = util.replace_empty_ssis(
            limited_gt_ssi_list, raw_article_sents)
        chosen_sent_indices = util.flatten_list_of_lists(filled_ssi_list)
        chosen_sents = util.reorder(article_sent_tokens, chosen_sent_indices)
        summary_sents = [' '.join(tokens) for tokens in chosen_sents]
        system_ssi_list = gt_ssi_list
        ssi_length_extractive = len(system_ssi_list)

    rouge_functions.write_for_rouge(
        groundtruth_summ_sents, summary_sents, example_idx, ref_dir, dec_dir)
    return (gt_ssi_list, system_ssi_list, ssi_length_extractive)
def main(unused_argv):
    """Write the raw article text of each example to a per-split ``.tsv`` file.

    For every selected dataset and split, the examples are read with
    ``data.example_generator`` and each article (its raw sentences joined by
    single spaces) is written as one UTF-8 encoded line to
    ``data/bert/<dataset_name>/article_embeddings/input_article/<split>.tsv``.

    Args:
        unused_argv: leftover command-line arguments; must contain exactly
            one element.

    Raises:
        Exception: if ``unused_argv`` does not contain exactly one element.
    """
    print('Running statistics on %s' % FLAGS.dataset_name)

    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.dataset_name == 'all':
        dataset_names = ['cnn_dm', 'xsum', 'duc_2004']
    else:
        dataset_names = [FLAGS.dataset_name]

    for dataset_name in dataset_names:
        # The flag itself is rewritten on every iteration, as before.
        FLAGS.dataset_name = dataset_name
        source_dir = os.path.join(data_dir, dataset_name)

        if FLAGS.dataset_split == 'all':
            if dataset_name == 'duc_2004':
                dataset_splits = ['test']
            else:
                dataset_splits = ['test', 'val', 'train']
        else:
            dataset_splits = [FLAGS.dataset_split]

        for dataset_split in dataset_splits:
            source_pattern = source_dir + '/' + dataset_split + '*'
            source_files = sorted(glob.glob(source_pattern))
            # Progress-bar estimate: 1000 examples per matched source file.
            total = len(source_files) * 1000
            example_generator = data.example_generator(
                source_pattern, True, False, should_check_valid=False)

            out_dir = os.path.join('data', 'bert', dataset_name,
                                   'article_embeddings', 'input_article')
            util.create_dirs(out_dir)
            out_path = os.path.join(out_dir, dataset_split) + '.tsv'

            # The context manager guarantees the file is flushed and closed
            # after each split, including when an exception interrupts the loop.
            with open(out_path, 'wb') as writer:
                for example_idx, example in enumerate(
                        tqdm(example_generator, total=total)):
                    if FLAGS.num_instances != -1 and example_idx >= FLAGS.num_instances:
                        break
                    # Only the article sentences are needed; the remaining
                    # three unpacked fields are ignored.
                    raw_article_sents, _, _, _ = util.unpack_tf_example(
                        example, names_to_types)
                    article = ' '.join(raw_article_sents)
                    writer.write((article + '\n').encode())
def load_and_evaluate_example(ex):
    """Summarize one example, optionally write highlight HTML, and write ROUGE files.

    When the module-level ``importance`` setting is truthy, the summary comes
    from ``generate_summary_importance``, which also supplies the source
    indices and HTML sentences used for highlighting. Otherwise
    ``generate_summary`` supplies only the summary sentences, so no highlight
    HTML is written for that example.

    Args:
        ex: a 4-tuple ``(example, example_idx, single_feat_len, pair_feat_len)``.
    """
    example, example_idx, single_feat_len, pair_feat_len = ex
    print(example_idx)

    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text = util.unpack_tf_example(
        example, names_to_types)
    article_sent_tokens = [
        util.process_sent(sent) for sent in raw_article_sents
    ]
    # write_for_rouge receives a list of reference summaries; there is one here.
    groundtruth_summ_sents = [[
        sent.strip() for sent in groundtruth_summary_text.strip().split('\n')
    ]]
    groundtruth_summ_sent_tokens = [
        sent.split(' ') for sent in groundtruth_summ_sents[0]
    ]

    # Per-example scratch paths handed to the summary generators.
    temp_in_path = os.path.join(temp_in_dir, '%06d.txt' % example_idx)
    temp_out_path = os.path.join(temp_out_dir, '%06d.txt' % example_idx)

    if importance:
        summary_sents, similar_source_indices_list, summary_sents_for_html = generate_summary_importance(
            raw_article_sents, article_sent_tokens, temp_in_path,
            temp_out_path, single_feat_len, pair_feat_len)
        has_highlight_data = True
    else:
        summary_sents = generate_summary(raw_article_sents,
                                         article_sent_tokens, temp_in_path,
                                         temp_out_path, single_feat_len,
                                         pair_feat_len)
        # This path binds neither the source indices nor the HTML sentences;
        # reading them below used to raise UnboundLocalError.
        has_highlight_data = False

    if has_highlight_data and example_idx <= 100:
        summary_sent_tokens = [
            sent.split(' ') for sent in summary_sents_for_html
        ]
        extracted_sents_in_article_html = html_highlight_sents_in_article(
            summary_sent_tokens, similar_source_indices_list,
            article_sent_tokens)
        write_highlighted_html(extracted_sents_in_article_html, html_dir,
                               example_idx)

        groundtruth_similar_source_indices_list, lcs_paths_list, article_lcs_paths_list = get_simple_source_indices_list(
            groundtruth_summ_sent_tokens, article_sent_tokens, None,
            sentence_limit, min_matched_tokens)
        groundtruth_highlighted_html = html_highlight_sents_in_article(
            groundtruth_summ_sent_tokens,
            groundtruth_similar_source_indices_list,
            article_sent_tokens,
            lcs_paths_list=lcs_paths_list,
            article_lcs_paths_list=article_lcs_paths_list)

        all_html = ('<u>System Summary</u><br><br>'
                    + extracted_sents_in_article_html
                    + '<u>Groundtruth Summary</u><br><br>'
                    + groundtruth_highlighted_html)
        write_highlighted_html(all_html, html_dir, example_idx)

    rouge_eval_references.write_for_rouge(groundtruth_summ_sents,
                                          summary_sents, example_idx, ref_dir,
                                          dec_dir)