def text_generator(self, example_generator):
    """Generates article and abstract text from tf.Example.

    Args:
        example_generator: a generator of tf.Examples from file. See data.example_generator
    """
    while True:
        e = next(example_generator)  # e is a tf.Example
        try:
            names_to_types = [('raw_article_sents', 'string_list'),
                              ('similar_source_indices', 'delimited_list_of_tuples'),
                              ('summary_text', 'string_list')]
            if self._hps.dataset_name == 'duc_2004':
                names_to_types[2] = ('summary_text', 'string_list')
            raw_article_sents, ssi, groundtruth_summary_sents = util.unpack_tf_example(e, names_to_types)
            groundtruth_summary_text = '\n'.join(groundtruth_summary_sents)
            article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
            article_text = ' '.join([' '.join(sent) for sent in article_sent_tokens])
            if self._hps.dataset_name == 'duc_2004':
                # DUC-2004 examples carry multiple reference summaries, one per list element.
                abstract_sentences = [['<s> ' + sent.strip() + ' </s>' for sent in gt_summ_text.strip().split('\n')]
                                      for gt_summ_text in groundtruth_summary_sents]
                abstract_sentences = [abs_sents[:max_dec_sents] for abs_sents in abstract_sentences]
                abstract_texts = [' '.join(abs_sents) for abs_sents in abstract_sentences]
            else:
                abstract_sentences = ['<s> ' + sent.strip() + ' </s>' for sent in groundtruth_summary_text.strip().split('\n')]
                abstract_sentences = abstract_sentences[:max_dec_sents]
                abstract_texts = [' '.join(abstract_sentences)]
            if 'doc_indices' not in e.features.feature or len(e.features.feature['doc_indices'].bytes_list.value) == 0:
                num_words = len(article_text.split())
                doc_indices_text = '0 ' * num_words
            else:
                doc_indices_text = e.features.feature['doc_indices'].bytes_list.value[0]
            sentence_limit = 1 if self._hps.singles_and_pairs == 'singles' else 2
            ssi = util.enforce_sentence_limit(ssi, sentence_limit)
            ssi = ssi[:max_dec_sents]
            ssi = util.make_ssi_chronological(ssi)
        except Exception:
            logging.error('Failed to get article or abstract from example')
            raise
        if len(article_text) == 0:  # See https://github.com/abisee/pointer-generator/issues/1
            logging.warning('Found an example with empty article text. Skipping it.\n*********************************************')
        elif len(article_text.strip().split()) < 3 and self._hps.skip_with_less_than_3:
            print('Article has less than 3 tokens, so skipping\n*********************************************')
        elif len(abstract_texts[0].strip().split()) < 3 and self._hps.skip_with_less_than_3:
            print('Abstract has less than 3 tokens, so skipping\n*********************************************')
        else:
            yield (article_text, abstract_texts, doc_indices_text, raw_article_sents, ssi)
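# A minimal usage sketch of text_generator, assuming `batcher` is an instance of the class
# defining it and `data_path` points at the tf.Example files; the positional flags passed to
# data.example_generator follow the call pattern used elsewhere in this repo. These names are
# illustrative assumptions, not part of the original module.
def _peek_one_example(batcher, data_path):
    gen = batcher.text_generator(data.example_generator(data_path, True, False, should_check_valid=False))
    article_text, abstract_texts, doc_indices_text, raw_article_sents, ssi = next(gen)
    # article_text is the space-joined tokenized article; abstract_texts holds one string per
    # reference summary, wrapped in <s> ... </s> sentence tags.
    print('article tokens:', len(article_text.split()), 'reference summaries:', len(abstract_texts))
    return article_text, abstract_texts, doc_indices_text, raw_article_sents, ssi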
def write_to_lambdamart_examples_to_file(ex):
    example, example_idx, single_feat_len, pair_feat_len, singles_and_pairs = ex
    print(example_idx)
    # example_idx += 1
    temp_in_path = os.path.join(temp_in_dir, '%06d.txt' % example_idx)
    if not FLAGS.start_over and os.path.exists(temp_in_path):
        return
    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(example, names_to_types)
    article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
    if doc_indices is None:
        doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
    doc_indices = [int(doc_idx) for doc_idx in doc_indices]
    if len(doc_indices) != len(util.flatten_list_of_lists(article_sent_tokens)):
        doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
    rel_sent_indices, _, _ = get_rel_sent_indices(doc_indices, article_sent_tokens)
    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, sentence_limit)
    groundtruth_summ_sents = [[sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
    groundtruth_summ_sent_tokens = [sent.split(' ') for sent in groundtruth_summ_sents[0]]
    # summ_sent_tokens = [sent.strip().split() for sent in summary_text.strip().split('\n')]
    if FLAGS.dataset_name == 'duc_2004':
        first_k_indices = get_indices_of_first_k_sents_of_each_article(rel_sent_indices, FLAGS.first_k)
    else:
        first_k_indices = [idx for idx in range(len(raw_article_sents))]
    if importance:
        get_instances(example_idx, raw_article_sents, article_sent_tokens, corefs, rel_sent_indices, first_k_indices,
                      temp_in_path, temp_out_path, single_feat_len, pair_feat_len, singles_and_pairs)
def evaluate_example(ex):
    example, example_idx, qid_ssi_to_importances, _, _ = ex
    print(example_idx)
    # Read example from dataset
    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(example, names_to_types)
    article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
    enforced_groundtruth_ssi_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, sentence_limit)
    groundtruth_summ_sents = [[sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
    groundtruth_summ_sent_tokens = [sent.split(' ') for sent in groundtruth_summ_sents[0]]
    if FLAGS.upper_bound:
        # If upper bound, then get the groundtruth singletons/pairs
        replaced_ssi_list = util.replace_empty_ssis(enforced_groundtruth_ssi_list, raw_article_sents)
        selected_article_sent_indices = util.flatten_list_of_lists(replaced_ssi_list)
        summary_sents = [' '.join(sent) for sent in util.reorder(article_sent_tokens, selected_article_sent_indices)]
        similar_source_indices_list = groundtruth_similar_source_indices_list
        ssi_length_extractive = len(similar_source_indices_list)
    else:
        # Generates summary based on BERT output. This is an extractive summary.
        summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive = generate_summary(article_sent_tokens, qid_ssi_to_importances, example_idx)
        similar_source_indices_list_trunc = similar_source_indices_list[:ssi_length_extractive]
        summary_sents_for_html_trunc = summary_sents_for_html[:ssi_length_extractive]
        if example_idx <= 1:
            summary_sent_tokens = [sent.split(' ') for sent in summary_sents_for_html_trunc]
            extracted_sents_in_article_html = html_highlight_sents_in_article(summary_sent_tokens, similar_source_indices_list_trunc, article_sent_tokens, doc_indices=doc_indices)
            groundtruth_ssi_list, lcs_paths_list, article_lcs_paths_list = get_simple_source_indices_list(groundtruth_summ_sent_tokens, article_sent_tokens, None, sentence_limit, min_matched_tokens)
            groundtruth_highlighted_html = html_highlight_sents_in_article(groundtruth_summ_sent_tokens, groundtruth_ssi_list, article_sent_tokens, lcs_paths_list=lcs_paths_list, article_lcs_paths_list=article_lcs_paths_list, doc_indices=doc_indices)
            all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html + '<u>Groundtruth Summary</u><br><br>' + groundtruth_highlighted_html
            ssi_functions.write_highlighted_html(all_html, html_dir, example_idx)
    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents, example_idx, ref_dir, dec_dir)
    return (groundtruth_similar_source_indices_list, similar_source_indices_list, ssi_length_extractive)
def main(unused_argv):
    print('Running statistics on %s' % FLAGS.dataset_name)
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)
    source_dir = os.path.join(data_dir, FLAGS.dataset_name)
    source_files = sorted(glob.glob(source_dir + '/' + FLAGS.dataset_split + '*'))
    total = len(source_files) * 1000 if ('cnn' in FLAGS.dataset_name or 'newsroom' in FLAGS.dataset_name or 'xsum' in FLAGS.dataset_name) else len(source_files)
    example_generator = data.example_generator(source_dir + '/' + FLAGS.dataset_split + '*', True, False, should_check_valid=False)
    for example_idx, example in enumerate(tqdm(example_generator, total=total)):
        raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(example, names_to_types)
        article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
        groundtruth_summ_sents = [[sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
        if doc_indices is None:
            doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
        doc_indices = [int(doc_idx) for doc_idx in doc_indices]
        rel_sent_indices, _, _ = preprocess_for_lambdamart_no_flags.get_rel_sent_indices(doc_indices, article_sent_tokens)
        groundtruth_similar_source_indices_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, FLAGS.sentence_limit)
def main(unused_argv):
    print('Running statistics on %s' % FLAGS.dataset_name)
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)
    if FLAGS.singles_and_pairs == 'singles':
        FLAGS.sentence_limit = 1
    else:
        FLAGS.sentence_limit = 2
    if FLAGS.dataset_name == 'all':
        dataset_names = ['cnn_dm', 'xsum', 'duc_2004']
    else:
        dataset_names = [FLAGS.dataset_name]
    for dataset_name in dataset_names:
        FLAGS.dataset_name = dataset_name
        source_dir = os.path.join(data_dir, dataset_name)
        if FLAGS.dataset_split == 'all':
            if dataset_name == 'duc_2004':
                dataset_splits = ['test']
            else:
                # dataset_splits = ['val_test', 'test', 'val', 'train']
                dataset_splits = ['test', 'val', 'train']
        else:
            dataset_splits = [FLAGS.dataset_split]
        for dataset_split in dataset_splits:
            if dataset_split == 'val_test':
                source_dataset_split = 'val'
            else:
                source_dataset_split = dataset_split
            source_files = sorted(glob.glob(source_dir + '/' + source_dataset_split + '*'))
            total = len(source_files) * 1000
            example_generator = data.example_generator(source_dir + '/' + source_dataset_split + '*', True, False, should_check_valid=False)
            out_dir = os.path.join('data', 'bert', dataset_name, FLAGS.singles_and_pairs, 'input')
            util.create_dirs(out_dir)
            writer = open(os.path.join(out_dir, dataset_split) + '.tsv', 'wb')
            header_list = ['should_merge', 'sent1', 'sent2', 'example_idx', 'inst_id', 'ssi']
            writer.write(('\t'.join(header_list) + '\n').encode())
            inst_id = 0
            for example_idx, example in enumerate(tqdm(example_generator, total=total)):
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(example, names_to_types)
                article_sent_tokens = [util.process_sent(sent, whitespace=True) for sent in raw_article_sents]
                groundtruth_summ_sents = [[sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
                if dataset_name != 'duc_2004' or doc_indices is None or (dataset_name != 'duc_2004' and len(doc_indices) != len(util.flatten_list_of_lists(article_sent_tokens))):
                    doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
                doc_indices = [int(doc_idx) for doc_idx in doc_indices]
                rel_sent_indices, _, _ = ssi_functions.get_rel_sent_indices(doc_indices, article_sent_tokens)
                similar_source_indices_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, FLAGS.sentence_limit)
                possible_pairs = [x for x in list(itertools.combinations(list(range(len(raw_article_sents))), 2))]  # all pairs
                possible_pairs = filter_pairs_by_sent_position(possible_pairs, rel_sent_indices=rel_sent_indices)
                possible_singles = [(i,) for i in range(len(raw_article_sents))]
                positives = [ssi for ssi in similar_source_indices_list]
                if dataset_split == 'test' or dataset_split == 'val_test':
                    if FLAGS.singles_and_pairs == 'singles':
                        possible_combinations = possible_singles
                    else:
                        possible_combinations = possible_pairs + possible_singles
                    negatives = [ssi for ssi in possible_combinations if not (ssi in positives or ssi[::-1] in positives)]
                    for ssi_idx, ssi in enumerate(positives):
                        if len(ssi) == 0:
                            continue
                        if chronological_ssi and len(ssi) >= 2:
                            if ssi[0] > ssi[1]:
                                ssi = (min(ssi), max(ssi))
                        writer.write(get_string_bert_example(raw_article_sents, ssi, 1, example_idx, inst_id).encode())
                        inst_id += 1
                    for ssi in negatives:
                        writer.write(get_string_bert_example(raw_article_sents, ssi, 0, example_idx, inst_id).encode())
                        inst_id += 1
                else:
                    positive_sents = list(set(util.flatten_list_of_lists(positives)))
                    negative_pairs = [pair for pair in possible_pairs if not any(i in positive_sents for i in pair)]
                    negative_singles = [sing for sing in possible_singles if not sing[0] in positive_sents]
                    random_negative_pairs = np.random.permutation(len(negative_pairs)).tolist()
                    random_negative_singles = np.random.permutation(len(negative_singles)).tolist()
                    for ssi in similar_source_indices_list:
                        if len(ssi) == 0:
                            continue
                        if chronological_ssi and len(ssi) >= 2:
                            if ssi[0] > ssi[1]:
                                ssi = (min(ssi), max(ssi))
                        is_pair = len(ssi) == 2
                        writer.write(get_string_bert_example(raw_article_sents, ssi, 1, example_idx, inst_id).encode())
                        inst_id += 1
                        # False sentence single/pair
                        if is_pair:
                            if len(random_negative_pairs) == 0:
                                continue
                            negative_indices = negative_pairs[random_negative_pairs.pop()]
                        else:
                            if len(random_negative_singles) == 0:
                                continue
                            negative_indices = negative_singles[random_negative_singles.pop()]
                        article_lcs_paths = None
                        writer.write(get_string_bert_example(raw_article_sents, negative_indices, 0, example_idx, inst_id).encode())
                        inst_id += 1
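# The loop above writes one tab-separated row per candidate singleton/pair, matching the header
# ['should_merge', 'sent1', 'sent2', 'example_idx', 'inst_id', 'ssi']. The sketch below shows
# the row shape get_string_bert_example presumably produces; it is an illustration under that
# assumption, not the repo's actual helper.
def _string_bert_example_sketch(raw_article_sents, ssi, label, example_idx, inst_id):
    sent1 = raw_article_sents[ssi[0]]
    sent2 = raw_article_sents[ssi[1]] if len(ssi) == 2 else ''
    ssi_str = ' '.join(str(idx) for idx in ssi)
    # should_merge \t sent1 \t sent2 \t example_idx \t inst_id \t ssi
    return '\t'.join([str(label), sent1, sent2, str(example_idx), str(inst_id), ssi_str]) + '\n'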
def main(unused_argv):
    print('Running statistics on %s' % FLAGS.dataset_name)
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)
    if FLAGS.summarizer == 'all':
        summary_methods = list(summarizers.keys())
    else:
        summary_methods = [FLAGS.summarizer]
    if FLAGS.dataset_name == 'all':
        dataset_names = datasets
    else:
        dataset_names = [FLAGS.dataset_name]
    sheets_strs = []
    for summary_method in summary_methods:
        summary_fn = summarizers[summary_method]
        for dataset_name in dataset_names:
            FLAGS.dataset_name = dataset_name
            original_dataset_name = 'xsum' if 'xsum' in dataset_name else 'cnn_dm' if 'cnn_dm' in dataset_name or 'duc_2004' in dataset_name else ''
            vocab = Vocab('logs/vocab' + '_' + original_dataset_name, 50000)  # create a vocabulary
            source_dir = os.path.join(data_dir, dataset_name)
            source_files = sorted(glob.glob(source_dir + '/' + FLAGS.dataset_split + '*'))
            total = len(source_files) * 1000 if ('cnn' in dataset_name or 'newsroom' in dataset_name or 'xsum' in dataset_name) else len(source_files)
            example_generator = data.example_generator(source_dir + '/' + FLAGS.dataset_split + '*', True, False, should_check_valid=False)
            if dataset_name == 'duc_2004':
                abs_source_dir = os.path.join(os.path.expanduser('~') + '/data/tf_data/with_coref', dataset_name)
                abs_example_generator = data.example_generator(abs_source_dir + '/' + FLAGS.dataset_split + '*', True, False, should_check_valid=False)
                abs_names_to_types = [('abstract', 'string_list')]
            triplet_ssi_list = []
            for example_idx, example in enumerate(tqdm(example_generator, total=total)):
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(example, names_to_types)
                if dataset_name == 'duc_2004':
                    abs_example = next(abs_example_generator)
                    groundtruth_summary_texts = util.unpack_tf_example(abs_example, abs_names_to_types)
                    groundtruth_summary_texts = groundtruth_summary_texts[0]
                    groundtruth_summ_sents_list = [[sent.strip() for sent in data.abstract2sents(abstract)] for abstract in groundtruth_summary_texts]
                else:
                    groundtruth_summary_texts = [groundtruth_summary_text]
                    groundtruth_summ_sents_list = []
                    for groundtruth_summary_text in groundtruth_summary_texts:
                        groundtruth_summ_sents = [sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]
                        groundtruth_summ_sents_list.append(groundtruth_summ_sents)
                article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
                if doc_indices is None:
                    doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
                doc_indices = [int(doc_idx) for doc_idx in doc_indices]
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, FLAGS.sentence_limit)
                log_dir = os.path.join(log_root, dataset_name + '_' + summary_method)
                dec_dir = os.path.join(log_dir, 'decoded')
                ref_dir = os.path.join(log_dir, 'reference')
                util.create_dirs(dec_dir)
                util.create_dirs(ref_dir)
                parser = PlaintextParser.from_string(' '.join(raw_article_sents), Tokenizer("english"))
                summarizer = summary_fn()
                summary = summarizer(parser.document, 5)  # Summarize the document with 5 sentences
                summary = [str(sentence) for sentence in summary]
                summary_tokenized = []
                for sent in summary:
                    summary_tokenized.append(sent.lower())
                rouge_functions.write_for_rouge(groundtruth_summ_sents_list, summary_tokenized, example_idx, ref_dir, dec_dir, log=False)
                decoded_sent_tokens = [sent.split() for sent in summary_tokenized]
                sentence_limit = 2
                sys_ssi_list, _, _ = get_simple_source_indices_list(decoded_sent_tokens, article_sent_tokens, vocab, sentence_limit, min_matched_tokens)
                triplet_ssi_list.append((groundtruth_similar_source_indices_list, sys_ssi_list, -1))
            print('Evaluating Lambdamart model F1 score...')
            suffix = util.all_sent_selection_eval(triplet_ssi_list)
            print(suffix)
            results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir)
            print("Results_dict: ", results_dict)
            sheets_str = rouge_functions.rouge_log(results_dict, log_dir, suffix=suffix)
            sheets_strs.append(dataset_name + '_' + summary_method + '\n' + sheets_str)
    for sheets_str in sheets_strs:
        print(sheets_str + '\n')
def convert_article_to_lambdamart_features(ex):
    example, example_idx, single_feat_len, pair_feat_len, singles_and_pairs, out_path = ex
    print(example_idx)
    raw_article_sents, similar_source_indices_list, summary_text, corefs, doc_indices = util.unpack_tf_example(example, names_to_types)
    article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
    if doc_indices is None:
        doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
    doc_indices = [int(doc_idx) for doc_idx in doc_indices]
    if len(doc_indices) != len(util.flatten_list_of_lists(article_sent_tokens)):
        doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
    rel_sent_indices, _, _ = util.get_rel_sent_indices(doc_indices, article_sent_tokens)
    if FLAGS.singles_and_pairs == 'singles':
        sentence_limit = 1
    else:
        sentence_limit = 2
    similar_source_indices_list = util.enforce_sentence_limit(similar_source_indices_list, sentence_limit)
    summ_sent_tokens = [sent.strip().split() for sent in summary_text.strip().split('\n')]
    # sent_term_matrix = util.get_tfidf_matrix(raw_article_sents)
    article_text = ' '.join(raw_article_sents)
    sent_term_matrix = util.get_doc_substituted_tfidf_matrix(tfidf_vectorizer, raw_article_sents, article_text, pca)
    doc_vector = np.mean(sent_term_matrix, axis=0)
    out_str = ''
    instances = []
    if importance:
        importances = util.special_squash(util.get_tfidf_importances(tfidf_vectorizer, raw_article_sents, pca))
        possible_pairs = [x for x in list(itertools.combinations(list(range(len(raw_article_sents))), 2))]  # all pairs
        if FLAGS.use_pair_criteria:
            possible_pairs = filter_pairs_by_criteria(raw_article_sents, possible_pairs, corefs)
        if FLAGS.sent_position_criteria:
            possible_pairs = filter_pairs_by_sent_position(possible_pairs, rel_sent_indices)
        possible_singles = [(i,) for i in range(len(raw_article_sents))]
        possible_combinations = possible_pairs + possible_singles
        positives = [ssi for ssi in similar_source_indices_list]
        negatives = [ssi for ssi in possible_combinations if not (ssi in positives or ssi[::-1] in positives)]
        negative_pairs = [x for x in possible_pairs if not (x in similar_source_indices_list or x[::-1] in similar_source_indices_list)]
        negative_singles = [x for x in possible_singles if not (x in similar_source_indices_list or x[::-1] in similar_source_indices_list)]
        random_negative_pairs = np.random.permutation(len(negative_pairs)).tolist()
        random_negative_singles = np.random.permutation(len(negative_singles)).tolist()
        qid = example_idx
        for similar_source_indices in positives:
            # True sentence single/pair
            relevance = 1
            features = get_features(similar_source_indices, sent_term_matrix, article_sent_tokens, rel_sent_indices, single_feat_len, pair_feat_len, importances, singles_and_pairs)
            if features is None:
                continue
            instances.append(Lambdamart_Instance(features, relevance, qid, similar_source_indices))
            a = 0
            if FLAGS.dataset_name == 'xsum' and FLAGS.special_xsum_balance:
                neg_relevance = 0
                num_negative = 4
                if FLAGS.singles_and_pairs == 'singles':
                    num_neg_singles = num_negative
                    num_neg_pairs = 0
                else:
                    # integer division so range() receives ints
                    num_neg_singles = num_negative // 2
                    num_neg_pairs = num_negative // 2
                for _ in range(num_neg_singles):
                    if len(random_negative_singles) == 0:
                        continue
                    negative_indices = negative_singles[random_negative_singles.pop()]
                    neg_features = get_features(negative_indices, sent_term_matrix, article_sent_tokens, rel_sent_indices, single_feat_len, pair_feat_len, importances, singles_and_pairs)
                    if neg_features is None:
                        continue
                    instances.append(Lambdamart_Instance(neg_features, neg_relevance, qid, negative_indices))
                for _ in range(num_neg_pairs):
                    if len(random_negative_pairs) == 0:
                        continue
                    negative_indices = negative_pairs[random_negative_pairs.pop()]
                    neg_features = get_features(negative_indices, sent_term_matrix, article_sent_tokens, rel_sent_indices, single_feat_len, pair_feat_len, importances, singles_and_pairs)
                    if neg_features is None:
                        continue
                    instances.append(Lambdamart_Instance(neg_features, neg_relevance, qid, negative_indices))
            elif balance:
                # False sentence single/pair
                is_pair = len(similar_source_indices) == 2
                if is_pair:
                    if len(random_negative_pairs) == 0:
                        continue
                    negative_indices = negative_pairs[random_negative_pairs.pop()]
                else:
                    if len(random_negative_singles) == 0:
                        continue
                    negative_indices = negative_singles[random_negative_singles.pop()]
                neg_relevance = 0
                neg_features = get_features(negative_indices, sent_term_matrix, article_sent_tokens, rel_sent_indices, single_feat_len, pair_feat_len, importances, singles_and_pairs)
                if neg_features is None:
                    continue
                instances.append(Lambdamart_Instance(neg_features, neg_relevance, qid, negative_indices))
        if not balance:
            for negative_indices in negatives:
                neg_relevance = 0
                # rel_sent_indices passed here to match the get_features signature used above
                neg_features = get_features(negative_indices, sent_term_matrix, article_sent_tokens, rel_sent_indices, single_feat_len, pair_feat_len, importances, singles_and_pairs)
                if neg_features is None:
                    continue
                instances.append(Lambdamart_Instance(neg_features, neg_relevance, qid, negative_indices))
    sorted_instances = sorted(instances, key=lambda x: (x.qid, x.source_indices))
    assign_inst_ids(sorted_instances)
    if FLAGS.lr:
        return sorted_instances
    else:
        for instance in sorted_instances:
            lambdamart_str = format_to_lambdamart(instance, single_feat_len)
            out_str += lambdamart_str + '\n'
        with open(os.path.join(out_path, '%06d.txt' % example_idx), 'w') as f:
            f.write(out_str)
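# A minimal sketch of the kind of line format_to_lambdamart presumably emits: the standard
# RankLib/SVMrank "<relevance> qid:<qid> 1:<v1> 2:<v2> ..." feature-vector format built from a
# Lambdamart_Instance. The real helper also receives single_feat_len (likely for padding
# singleton vs. pair features) and may add metadata; this is only an illustration under those
# assumptions, not the repo's actual implementation.
def _format_to_lambdamart_sketch(instance):
    feature_str = ' '.join('%d:%f' % (i + 1, v) for i, v in enumerate(instance.features))
    return '%d qid:%d %s' % (instance.relevance, instance.qid, feature_str)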
def evaluate_example(ex):
    example, example_idx, qid_ssi_to_importances, _, _ = ex
    print(example_idx)
    # example_idx += 1
    qid = example_idx
    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(example, names_to_types)
    article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
    enforced_groundtruth_ssi_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, sentence_limit)
    if FLAGS.dataset_name == 'duc_2004':
        groundtruth_summ_sents = [[sent.strip() for sent in gt_summ_text.strip().split('\n')] for gt_summ_text in groundtruth_summary_text]
    else:
        groundtruth_summ_sents = [[sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
    groundtruth_summ_sent_tokens = [sent.split(' ') for sent in groundtruth_summ_sents[0]]
    if FLAGS.upper_bound:
        replaced_ssi_list = util.replace_empty_ssis(enforced_groundtruth_ssi_list, raw_article_sents)
        selected_article_sent_indices = util.flatten_list_of_lists(replaced_ssi_list)
        summary_sents = [' '.join(sent) for sent in util.reorder(article_sent_tokens, selected_article_sent_indices)]
        similar_source_indices_list = groundtruth_similar_source_indices_list
        ssi_length_extractive = len(similar_source_indices_list)
    elif FLAGS.lead:
        lead_ssi_list = [(idx,) for idx in list(range(util.average_sents_for_dataset[FLAGS.dataset_name]))]
        # make sure the sentence indices don't go past the total number of sentences in the article
        lead_ssi_list = lead_ssi_list[:len(raw_article_sents)]
        selected_article_sent_indices = util.flatten_list_of_lists(lead_ssi_list)
        summary_sents = [' '.join(sent) for sent in util.reorder(article_sent_tokens, selected_article_sent_indices)]
        similar_source_indices_list = lead_ssi_list
        ssi_length_extractive = len(similar_source_indices_list)
    else:
        summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive = generate_summary(article_sent_tokens, qid_ssi_to_importances, example_idx)
        similar_source_indices_list_trunc = similar_source_indices_list[:ssi_length_extractive]
        summary_sents_for_html_trunc = summary_sents_for_html[:ssi_length_extractive]
        if example_idx <= 100:
            summary_sent_tokens = [sent.split(' ') for sent in summary_sents_for_html_trunc]
            extracted_sents_in_article_html = html_highlight_sents_in_article(summary_sent_tokens, similar_source_indices_list_trunc, article_sent_tokens, doc_indices=doc_indices)
            # write_highlighted_html(extracted_sents_in_article_html, html_dir, example_idx)
            groundtruth_ssi_list, lcs_paths_list, article_lcs_paths_list = get_simple_source_indices_list(groundtruth_summ_sent_tokens, article_sent_tokens, None, sentence_limit, min_matched_tokens)
            groundtruth_highlighted_html = html_highlight_sents_in_article(groundtruth_summ_sent_tokens, groundtruth_ssi_list, article_sent_tokens, lcs_paths_list=lcs_paths_list, article_lcs_paths_list=article_lcs_paths_list, doc_indices=doc_indices)
            all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html + '<u>Groundtruth Summary</u><br><br>' + groundtruth_highlighted_html
            write_highlighted_html(all_html, html_dir, example_idx)
    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents, example_idx, ref_dir, dec_dir)
    return (groundtruth_similar_source_indices_list, similar_source_indices_list, ssi_length_extractive)
def text_generator(self, example_generator):
    """Generates article and abstract text from tf.Example.

    Args:
        example_generator: a generator of tf.Examples from file. See data.example_generator
    """
    while True:
        e = next(example_generator)  # e is a tf.Example
        abstract_texts = []
        raw_article_sents = []
        try:
            names_to_types = [('raw_article_sents', 'string_list'),
                              ('similar_source_indices', 'delimited_list_of_tuples'),
                              ('summary_text', 'string_list'),
                              ('corefs', 'json'),
                              ('article_lcs_paths_list', 'delimited_list_of_list_of_lists')]
            if self._hps.dataset_name == 'duc_2004':
                names_to_types[2] = ('summary_text', 'string_list')
            raw_article_sents, ssi, groundtruth_summary_sents, corefs, article_lcs_paths_list = util.unpack_tf_example(e, names_to_types)
            groundtruth_summary_text = '\n'.join(groundtruth_summary_sents)
            article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
            article_text = ' '.join([' '.join(sent) for sent in article_sent_tokens])
            if self._hps.dataset_name == 'duc_2004':
                # DUC-2004 examples carry multiple reference summaries, one per list element.
                abstract_sentences = [['<s> ' + sent.strip() + ' </s>' for sent in gt_summ_text.strip().split('\n')]
                                      for gt_summ_text in groundtruth_summary_sents]
                abstract_sentences = [abs_sents[:max_dec_sents] for abs_sents in abstract_sentences]
                abstract_texts = [' '.join(abs_sents) for abs_sents in abstract_sentences]
            else:
                abstract_sentences = ['<s> ' + sent.strip() + ' </s>' for sent in groundtruth_summary_text.strip().split('\n')]
                abstract_sentences = abstract_sentences[:max_dec_sents]
                abstract_texts = [' '.join(abstract_sentences)]
            if 'doc_indices' not in e.features.feature or len(e.features.feature['doc_indices'].bytes_list.value) == 0:
                num_words = len(article_text.split())
                doc_indices_text = '0 ' * num_words
            else:
                doc_indices_text = e.features.feature['doc_indices'].bytes_list.value[0]
            sentence_limit = 1 if self._hps.singles_and_pairs == 'singles' else 2
            ssi = util.enforce_sentence_limit(ssi, sentence_limit)
            ssi = ssi[:max_dec_sents]
            article_lcs_paths_list = util.enforce_sentence_limit(article_lcs_paths_list, sentence_limit)
            article_lcs_paths_list = article_lcs_paths_list[:max_dec_sents]
            ssi, article_lcs_paths_list = util.make_ssi_chronological(ssi, article_lcs_paths_list)
        except Exception:
            logging.error('Failed to get article or abstract from example')
            raise
        if len(article_text) == 0:  # See https://github.com/abisee/pointer-generator/issues/1
            logging.warning('Found an example with empty article text. Skipping it.\n*********************************************')
        elif len(article_text.strip().split()) < 3 and self._hps.skip_with_less_than_3:
            print('Article has less than 3 tokens, so skipping\n*********************************************')
        elif len(abstract_texts[0].strip().split()) < 3 and self._hps.skip_with_less_than_3:
            print('Abstract has less than 3 tokens, so skipping\n*********************************************')
        else:
            yield (article_text, abstract_texts, doc_indices_text, raw_article_sents, ssi, article_lcs_paths_list)
def main(unused_argv):
    print('Running statistics on %s' % FLAGS.dataset_name)
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)
    out_dir = os.path.join(os.path.expanduser('~') + '/data/kaiqiang_data', FLAGS.dataset_name)
    if FLAGS.mode == 'write':
        util.create_dirs(out_dir)
        if FLAGS.dataset_name == 'duc_2004':
            dataset_splits = ['test']
        elif FLAGS.dataset_split == 'all':
            dataset_splits = ['test', 'val', 'train']
        else:
            dataset_splits = [FLAGS.dataset_split]
        for dataset_split in dataset_splits:
            if dataset_split == 'test':
                ssi_data_path = os.path.join('logs/%s_bert_both_sentemb_artemb_plushidden' % FLAGS.dataset_name, 'ssi.pkl')
                print(util.bcolors.OKGREEN + "Loading SSI from BERT at %s" % ssi_data_path + util.bcolors.ENDC)
                with open(ssi_data_path, 'rb') as f:
                    ssi_triple_list = pickle.load(f)
            source_dir = os.path.join(data_dir, FLAGS.dataset_name)
            source_files = sorted(glob.glob(source_dir + '/' + dataset_split + '*'))
            total = len(source_files) * 1000 if ('cnn' in FLAGS.dataset_name or 'newsroom' in FLAGS.dataset_name or 'xsum' in FLAGS.dataset_name) else len(source_files)
            example_generator = data.example_generator(source_dir + '/' + dataset_split + '*', True, False, should_check_valid=False)
            out_document_path = os.path.join(out_dir, dataset_split + '.Ndocument')
            out_summary_path = os.path.join(out_dir, dataset_split + '.Nsummary')
            out_example_idx_path = os.path.join(out_dir, dataset_split + '.Nexampleidx')
            doc_writer = open(out_document_path, 'w')
            if dataset_split != 'test':
                sum_writer = open(out_summary_path, 'w')
            ex_idx_writer = open(out_example_idx_path, 'w')
            for example_idx, example in enumerate(tqdm(example_generator, total=total)):
                if FLAGS.num_instances != -1 and example_idx >= FLAGS.num_instances:
                    break
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(example, names_to_types)
                article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
                if FLAGS.dataset_name == 'duc_2004':
                    groundtruth_summ_sents = [[sent.strip() for sent in gt_summ_text.strip().split('\n')] for gt_summ_text in groundtruth_summary_text]
                else:
                    groundtruth_summ_sents = [[sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
                if doc_indices is None:
                    doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
                doc_indices = [int(doc_idx) for doc_idx in doc_indices]
                # rel_sent_indices, _, _ = preprocess_for_lambdamart_no_flags.get_rel_sent_indices(doc_indices, article_sent_tokens)
                if dataset_split == 'test':
                    if example_idx >= len(ssi_triple_list):
                        raise Exception('Len of ssi list (%d) is less than number of examples (>=%d)' % (len(ssi_triple_list), example_idx))
                    ssi_length_extractive = ssi_triple_list[example_idx][2]
                    if ssi_length_extractive > 1:
                        a = 0
                    ssi = ssi_triple_list[example_idx][1]
                    ssi = ssi[:ssi_length_extractive]
                    groundtruth_similar_source_indices_list = ssi
                else:
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, FLAGS.sentence_limit)
                for ssi_idx, ssi in enumerate(groundtruth_similar_source_indices_list):
                    if len(ssi) == 0:
                        continue
                    my_article = ' '.join(util.reorder(raw_article_sents, ssi))
                    doc_writer.write(my_article + '\n')
                    if dataset_split != 'test':
                        sum_writer.write(groundtruth_summ_sents[0][ssi_idx] + '\n')
                    ex_idx_writer.write(str(example_idx) + '\n')
    elif FLAGS.mode == 'evaluate':
        summary_dir = '/home/logan/data/kaiqiang_data/logan_ACL/trained_on_' + FLAGS.train_dataset + '/' + FLAGS.dataset_name
        out_summary_path = os.path.join(summary_dir, 'test' + 'Summary.txt')
        out_example_idx_path = os.path.join(out_dir, 'test' + '.Nexampleidx')
        decode_dir = 'logs/kaiqiang_%s_trainedon%s' % (FLAGS.dataset_name, FLAGS.train_dataset)
        rouge_ref_dir = os.path.join(decode_dir, 'reference')
        rouge_dec_dir = os.path.join(decode_dir, 'decoded')
        util.create_dirs(rouge_ref_dir)
        util.create_dirs(rouge_dec_dir)

        def num_lines_in_file(file_path):
            with open(file_path) as f:
                num_lines = sum(1 for line in f)
            return num_lines

        def process_example(sents, ex_idx, groundtruth_summ_sents):
            final_decoded_words = []
            for sent in sents:
                final_decoded_words.extend(sent.split(' '))
            rouge_functions.write_for_rouge(groundtruth_summ_sents, None, ex_idx, rouge_ref_dir, rouge_dec_dir, decoded_words=final_decoded_words, log=False)

        num_lines_summary = num_lines_in_file(out_summary_path)
        num_lines_example_indices = num_lines_in_file(out_example_idx_path)
        if num_lines_summary != num_lines_example_indices:
            raise Exception('Num lines summary != num lines example indices: (%d, %d)' % (num_lines_summary, num_lines_example_indices))
        source_dir = os.path.join(data_dir, FLAGS.dataset_name)
        example_generator = data.example_generator(source_dir + '/' + 'test' + '*', True, False, should_check_valid=False)
        sum_writer = open(out_summary_path)
        ex_idx_writer = open(out_example_idx_path)
        prev_ex_idx = 0
        sents = []
        for line_idx in tqdm(range(num_lines_summary)):
            line = sum_writer.readline()
            ex_idx = int(ex_idx_writer.readline())
            if ex_idx == prev_ex_idx:
                sents.append(line)
            else:
                example = next(example_generator)
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(example, names_to_types)
                if FLAGS.dataset_name == 'duc_2004':
                    groundtruth_summ_sents = [[sent.strip() for sent in gt_summ_text.strip().split('\n')] for gt_summ_text in groundtruth_summary_text]
                else:
                    groundtruth_summ_sents = [[sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
                process_example(sents, ex_idx, groundtruth_summ_sents)
                prev_ex_idx = ex_idx
                sents = [line]
        example = next(example_generator)
        raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(example, names_to_types)
        if FLAGS.dataset_name == 'duc_2004':
            groundtruth_summ_sents = [[sent.strip() for sent in gt_summ_text.strip().split('\n')] for gt_summ_text in groundtruth_summary_text]
        else:
            groundtruth_summ_sents = [[sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
        process_example(sents, ex_idx, groundtruth_summ_sents)
        print("Now starting ROUGE eval...")
        if FLAGS.dataset_name == 'xsum':
            l_param = 100
        else:
            l_param = 100
        results_dict = rouge_functions.rouge_eval(rouge_ref_dir, rouge_dec_dir, l_param=l_param)
        rouge_functions.rouge_log(results_dict, decode_dir)
    else:
        raise Exception('mode flag was not evaluate or write.')
def decode_iteratively(self, example_generator, total, names_to_types, ssi_list, hps):
    for example_idx, example in enumerate(tqdm(example_generator, total=total)):
        raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text = util.unpack_tf_example(example, names_to_types)
        article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
        groundtruth_summ_sents = [[sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]

        if ssi_list is None:
            # this is if we are doing the upper bound evaluation (ssi_list comes straight from the groundtruth)
            sys_ssi = groundtruth_similar_source_indices_list
            if FLAGS.singles_and_pairs == 'singles':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
            elif FLAGS.singles_and_pairs == 'both':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
            sys_ssi = util.replace_empty_ssis(sys_ssi, raw_article_sents)
        else:
            gt_ssi, sys_ssi, ext_len = ssi_list[example_idx]
            if FLAGS.singles_and_pairs == 'singles':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, 1)
                gt_ssi = util.enforce_sentence_limit(gt_ssi, 1)
            elif FLAGS.singles_and_pairs == 'both':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, 2)
                gt_ssi = util.enforce_sentence_limit(gt_ssi, 2)
            if gt_ssi != groundtruth_similar_source_indices_list:
                print('Warning: Example %d has different groundtruth source indices: %s || %s' % (example_idx, groundtruth_similar_source_indices_list, gt_ssi))
        if FLAGS.dataset_name == 'xsum':
            sys_ssi = [sys_ssi[0]]

        final_decoded_words = []
        final_decoded_outputs = ''
        best_hyps = []
        highlight_html_total = ''
        for ssi_idx, ssi in enumerate(sys_ssi):
            selected_raw_article_sents = util.reorder(raw_article_sents, ssi)
            selected_article_text = ' '.join([' '.join(sent) for sent in util.reorder(article_sent_tokens, ssi)])
            selected_doc_indices_str = '0 ' * len(selected_article_text.split())
            if FLAGS.upper_bound:
                selected_groundtruth_summ_sent = [[groundtruth_summ_sents[0][ssi_idx]]]
            else:
                selected_groundtruth_summ_sent = groundtruth_summ_sents
            batch = create_batch(selected_article_text, selected_groundtruth_summ_sent, selected_doc_indices_str, selected_raw_article_sents, FLAGS.batch_size, hps, self._vocab)
            decoded_words, decoded_output, best_hyp = decode_example(self._sess, self._model, self._vocab, batch, example_idx, hps)
            best_hyps.append(best_hyp)
            final_decoded_words.extend(decoded_words)
            final_decoded_outputs += decoded_output

            if example_idx < 1000:
                min_matched_tokens = 2
                selected_article_sent_tokens = [util.process_sent(sent) for sent in selected_raw_article_sents]
                highlight_summary_sent_tokens = [decoded_words]
                highlight_ssi_list, lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                    highlight_summary_sent_tokens, selected_article_sent_tokens, None, 2, min_matched_tokens)
                highlighted_html = ssi_functions.html_highlight_sents_in_article(
                    highlight_summary_sent_tokens, highlight_ssi_list, selected_article_sent_tokens,
                    lcs_paths_list=lcs_paths_list, article_lcs_paths_list=highlight_smooth_article_lcs_paths_list)
                highlight_html_total += '<u>System Summary</u><br><br>' + highlighted_html + '<br><br>'

            if len(final_decoded_words) >= 100:
                break

        if example_idx < 1000:
            self.write_for_human(raw_article_sents, groundtruth_summ_sents, final_decoded_words, example_idx)
            ssi_functions.write_highlighted_html(highlight_html_total, self._highlight_dir, example_idx)

        # write ref summary and decoded summary to file, to eval with pyrouge later
        rouge_functions.write_for_rouge(groundtruth_summ_sents, None, example_idx, self._rouge_ref_dir, self._rouge_dec_dir, decoded_words=final_decoded_words, log=False)
        example_idx += 1  # this is how many examples we've decoded

    logging.info("Decoder has finished reading dataset for single_pass.")
    logging.info("Output has been saved in %s and %s.", self._rouge_ref_dir, self._rouge_dec_dir)
    if len(os.listdir(self._rouge_ref_dir)) != 0:
        l_param = 100
        logging.info("Now starting ROUGE eval...")
        results_dict = rouge_functions.rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir, l_param=l_param)
        rouge_functions.rouge_log(results_dict, self._decode_dir)
def decode_iteratively(self, example_generator, total, names_to_types, ssi_list, hps):
    attn_vis_idx = 0
    for example_idx, example in enumerate(tqdm(example_generator, total=total)):
        raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, groundtruth_article_lcs_paths_list = util.unpack_tf_example(example, names_to_types)
        article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
        groundtruth_summ_sents = [[sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
        groundtruth_summ_sent_tokens = [sent.split(' ') for sent in groundtruth_summ_sents[0]]

        if ssi_list is None:
            # this is if we are doing the upper bound evaluation (ssi_list comes straight from the groundtruth)
            sys_ssi = groundtruth_similar_source_indices_list
            sys_alp_list = groundtruth_article_lcs_paths_list
            if FLAGS.singles_and_pairs == 'singles':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 1)
            elif FLAGS.singles_and_pairs == 'both':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 2)
            sys_ssi, sys_alp_list = util.replace_empty_ssis(sys_ssi, raw_article_sents, sys_alp_list=sys_alp_list)
        else:
            gt_ssi, sys_ssi, ext_len, sys_token_probs_list = ssi_list[example_idx]
            sys_alp_list = ssi_functions.list_labels_from_probs(sys_token_probs_list, FLAGS.tag_threshold)
            if FLAGS.singles_and_pairs == 'singles':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 1)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, 1)
                gt_ssi = util.enforce_sentence_limit(gt_ssi, 1)
            elif FLAGS.singles_and_pairs == 'both':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 2)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, 2)
                gt_ssi = util.enforce_sentence_limit(gt_ssi, 2)
            # if gt_ssi != groundtruth_similar_source_indices_list:
            #     raise Exception('Example %d has different groundtruth source indices: ' + str(groundtruth_similar_source_indices_list) + ' || ' + str(gt_ssi))
        if FLAGS.dataset_name == 'xsum':
            sys_ssi = [sys_ssi[0]]

        final_decoded_words = []
        final_decoded_outputs = ''
        best_hyps = []
        highlight_html_total = '<u>System Summary</u><br><br>'
        for ssi_idx, ssi in enumerate(sys_ssi):
            # selected_article_lcs_paths = None
            selected_article_lcs_paths = sys_alp_list[ssi_idx]
            ssi, selected_article_lcs_paths = util.make_ssi_chronological(ssi, selected_article_lcs_paths)
            selected_article_lcs_paths = [selected_article_lcs_paths]
            selected_raw_article_sents = util.reorder(raw_article_sents, ssi)
            selected_article_text = ' '.join([' '.join(sent) for sent in util.reorder(article_sent_tokens, ssi)])
            selected_doc_indices_str = '0 ' * len(selected_article_text.split())
            if FLAGS.upper_bound:
                selected_groundtruth_summ_sent = [[groundtruth_summ_sents[0][ssi_idx]]]
            else:
                selected_groundtruth_summ_sent = groundtruth_summ_sents
            batch = create_batch(selected_article_text, selected_groundtruth_summ_sent, selected_doc_indices_str, selected_raw_article_sents, selected_article_lcs_paths, FLAGS.batch_size, hps, self._vocab)

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            article_withunks = data.show_art_oovs(original_article, self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            if FLAGS.first_intact and ssi_idx == 0:
                decoded_words = selected_article_text.strip().split()
                decoded_output = selected_article_text
            else:
                decoded_words, decoded_output, best_hyp = decode_example(self._sess, self._model, self._vocab, batch, example_idx, hps)
                best_hyps.append(best_hyp)
            final_decoded_words.extend(decoded_words)
            final_decoded_outputs += decoded_output

            if example_idx < 100 or (example_idx >= 2000 and example_idx < 2100):
                min_matched_tokens = 2
                selected_article_sent_tokens = [util.process_sent(sent) for sent in selected_raw_article_sents]
                highlight_summary_sent_tokens = [decoded_words]
                highlight_ssi_list, lcs_paths_list, highlight_article_lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                    highlight_summary_sent_tokens, selected_article_sent_tokens, None, 2, min_matched_tokens)
                highlighted_html = ssi_functions.html_highlight_sents_in_article(
                    highlight_summary_sent_tokens, highlight_ssi_list, selected_article_sent_tokens,
                    lcs_paths_list=lcs_paths_list, article_lcs_paths_list=highlight_smooth_article_lcs_paths_list)
                highlight_html_total += highlighted_html + '<br>'

            if FLAGS.attn_vis and example_idx < 200:
                # write info to .json file for visualization tool
                self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens, attn_vis_idx)
                attn_vis_idx += 1

            if len(final_decoded_words) >= 100:
                break

        gt_ssi_list, gt_alp_list = util.replace_empty_ssis(groundtruth_similar_source_indices_list, raw_article_sents, sys_alp_list=groundtruth_article_lcs_paths_list)
        highlight_html_gt = '<u>Reference Summary</u><br><br>'
        for ssi_idx, ssi in enumerate(gt_ssi_list):
            selected_article_lcs_paths = gt_alp_list[ssi_idx]
            try:
                ssi, selected_article_lcs_paths = util.make_ssi_chronological(ssi, selected_article_lcs_paths)
            except Exception:
                util.print_vars(ssi, example_idx, selected_article_lcs_paths)
                raise
            selected_raw_article_sents = util.reorder(raw_article_sents, ssi)

            if example_idx < 100 or (example_idx >= 2000 and example_idx < 2100):
                min_matched_tokens = 2
                selected_article_sent_tokens = [util.process_sent(sent) for sent in selected_raw_article_sents]
                highlight_summary_sent_tokens = [groundtruth_summ_sent_tokens[ssi_idx]]
                highlight_ssi_list, lcs_paths_list, highlight_article_lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                    highlight_summary_sent_tokens, selected_article_sent_tokens, None, 2, min_matched_tokens)
                highlighted_html = ssi_functions.html_highlight_sents_in_article(
                    highlight_summary_sent_tokens, highlight_ssi_list, selected_article_sent_tokens,
                    lcs_paths_list=lcs_paths_list, article_lcs_paths_list=highlight_smooth_article_lcs_paths_list)
                highlight_html_gt += highlighted_html + '<br>'

        if example_idx < 100 or (example_idx >= 2000 and example_idx < 2100):
            self.write_for_human(raw_article_sents, groundtruth_summ_sents, final_decoded_words, example_idx)
            highlight_html_total = ssi_functions.put_html_in_two_columns(highlight_html_total, highlight_html_gt)
            ssi_functions.write_highlighted_html(highlight_html_total, self._highlight_dir, example_idx)

        # if example_idx % 100 == 0:
        #     attn_dir = os.path.join(self._decode_dir, 'attn_vis_data')
        #     attn_selections.process_attn_selections(attn_dir, self._decode_dir, self._vocab)

        # write ref summary and decoded summary to file, to eval with pyrouge later
        rouge_functions.write_for_rouge(groundtruth_summ_sents, None, example_idx, self._rouge_ref_dir, self._rouge_dec_dir, decoded_words=final_decoded_words, log=False)
        # if FLAGS.attn_vis:
        #     self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens, example_idx)  # write info to .json file for visualization tool
        example_idx += 1  # this is how many examples we've decoded

    logging.info("Decoder has finished reading dataset for single_pass.")
    logging.info("Output has been saved in %s and %s.", self._rouge_ref_dir, self._rouge_dec_dir)
    if len(os.listdir(self._rouge_ref_dir)) != 0:
        if FLAGS.dataset_name == 'xsum':
            l_param = 100
        else:
            l_param = 100
        logging.info("Now starting ROUGE eval...")
        results_dict = rouge_functions.rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir, l_param=l_param)
        rouge_functions.rouge_log(results_dict, self._decode_dir)
def evaluate_example(ex):
    example, example_idx, qid_ssi_to_importances, qid_ssi_to_token_scores_and_mappings = ex
    print(example_idx)
    # example_idx += 1
    qid = example_idx
    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(example, names_to_types)
    article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
    enforced_groundtruth_ssi_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, sentence_limit)
    groundtruth_summ_sents = [[sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
    groundtruth_summ_sent_tokens = [sent.split(' ') for sent in groundtruth_summ_sents[0]]
    if FLAGS.upper_bound:
        replaced_ssi_list = util.replace_empty_ssis(enforced_groundtruth_ssi_list, raw_article_sents)
        selected_article_sent_indices = util.flatten_list_of_lists(replaced_ssi_list)
        summary_sents = [' '.join(sent) for sent in util.reorder(article_sent_tokens, selected_article_sent_indices)]
        similar_source_indices_list = groundtruth_similar_source_indices_list
        ssi_length_extractive = len(similar_source_indices_list)
        # These are only produced by generate_summary below; define them so the return
        # statement doesn't fail on the upper-bound path.
        article_lcs_paths_list = None
        token_probs_list = None
    else:
        summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive, \
            article_lcs_paths_list, token_probs_list = generate_summary(article_sent_tokens, qid_ssi_to_importances, example_idx, qid_ssi_to_token_scores_and_mappings)
        similar_source_indices_list_trunc = similar_source_indices_list[:ssi_length_extractive]
        summary_sents_for_html_trunc = summary_sents_for_html[:ssi_length_extractive]
        if example_idx < 100 or (example_idx >= 2000 and example_idx < 2100):
            summary_sent_tokens = [sent.split(' ') for sent in summary_sents_for_html_trunc]
            if FLAGS.tag_tokens and FLAGS.tag_loss_wt != 0:
                lcs_paths_list_param = copy.deepcopy(article_lcs_paths_list)
            else:
                lcs_paths_list_param = None
            extracted_sents_in_article_html = html_highlight_sents_in_article(summary_sent_tokens, similar_source_indices_list_trunc, article_sent_tokens, doc_indices=doc_indices, lcs_paths_list=lcs_paths_list_param)
            # write_highlighted_html(extracted_sents_in_article_html, html_dir, example_idx)
            groundtruth_ssi_list, gt_lcs_paths_list, gt_article_lcs_paths_list, gt_smooth_article_paths_list = get_simple_source_indices_list(groundtruth_summ_sent_tokens, article_sent_tokens, None, sentence_limit, min_matched_tokens)
            groundtruth_highlighted_html = html_highlight_sents_in_article(groundtruth_summ_sent_tokens, groundtruth_ssi_list, article_sent_tokens, lcs_paths_list=gt_lcs_paths_list, article_lcs_paths_list=gt_smooth_article_paths_list, doc_indices=doc_indices)
            all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html + '<u>Groundtruth Summary</u><br><br>' + groundtruth_highlighted_html
            write_highlighted_html(all_html, html_dir, example_idx)
    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents, example_idx, ref_dir, dec_dir)
    return (groundtruth_similar_source_indices_list, similar_source_indices_list, ssi_length_extractive, token_probs_list)
def main(unused_argv): print('Running statistics on %s' % FLAGS.dataset_name) if len(unused_argv ) != 1: # prints a message if you've entered flags incorrectly raise Exception("Problem with flags: %s" % unused_argv) if FLAGS.dataset_name == 'all': dataset_names = ['cnn_dm', 'xsum', 'duc_2004'] else: dataset_names = [FLAGS.dataset_name] if not os.path.exists(plot_data_file): all_lists_of_histogram_pairs = [] for dataset_name in dataset_names: FLAGS.dataset_name = dataset_name if dataset_name == 'duc_2004': dataset_splits = ['test'] elif FLAGS.dataset_split == 'all': dataset_splits = ['test', 'val', 'train'] else: dataset_splits = [FLAGS.dataset_split] ssi_list = [] for dataset_split in dataset_splits: ssi_path = os.path.join(ssi_dir, FLAGS.dataset_name, dataset_split + '_ssi.pkl') with open(ssi_path) as f: ssi_list.extend(pickle.load(f)) if FLAGS.dataset_name == 'duc_2004': for abstract_idx in [1, 2, 3]: ssi_path = os.path.join( ssi_dir, FLAGS.dataset_name, dataset_split + '_ssi_' + str(abstract_idx) + '.pkl') with open(ssi_path) as f: temp_ssi_list = pickle.load(f) ssi_list.extend(temp_ssi_list) ssi_2d = util.flatten_list_of_lists(ssi_list) num_extracted = [ len(ssi) for ssi in util.flatten_list_of_lists(ssi_list) ] hist_num_extracted = np.histogram(num_extracted, bins=6, range=(0, 5)) print(hist_num_extracted) print('Histogram of number of sentences merged: ' + util.hist_as_pdf_str(hist_num_extracted)) distances = [ abs(ssi[0] - ssi[1]) for ssi in ssi_2d if len(ssi) >= 2 ] print('Distance between sentences (mean, median): ', np.mean(distances), np.median(distances)) hist_dist = np.histogram(distances, bins=max(distances)) print('Histogram of distances: ' + util.hist_as_pdf_str(hist_dist)) summ_sent_idx_to_number_of_source_sents = [[], [], [], [], [], [], [], [], [], []] for ssi in ssi_list: for summ_sent_idx, source_indices in enumerate(ssi): if len(source_indices) == 0 or summ_sent_idx >= len( summ_sent_idx_to_number_of_source_sents): continue num_sents = len(source_indices) if num_sents > 2: num_sents = 2 summ_sent_idx_to_number_of_source_sents[ summ_sent_idx].append(num_sents) print( "Number of source sents for summary sentence indices (Is the first summary sent more likely to match with a singleton or a pair?):" ) for summ_sent_idx, list_of_numbers_of_source_sents in enumerate( summ_sent_idx_to_number_of_source_sents): if len(list_of_numbers_of_source_sents) == 0: percent_singleton = 0. else: percent_singleton = list_of_numbers_of_source_sents.count( 1) * 1. / len(list_of_numbers_of_source_sents) percent_pair = list_of_numbers_of_source_sents.count( 2) * 1. / len(list_of_numbers_of_source_sents) print str(percent_singleton) + '\t', print '' for summ_sent_idx, list_of_numbers_of_source_sents in enumerate( summ_sent_idx_to_number_of_source_sents): if len(list_of_numbers_of_source_sents) == 0: percent_pair = 0. else: percent_singleton = list_of_numbers_of_source_sents.count( 1) * 1. / len(list_of_numbers_of_source_sents) percent_pair = list_of_numbers_of_source_sents.count( 2) * 1. 
/ len(list_of_numbers_of_source_sents) print str(percent_pair) + '\t', print '' primary_pos = [ssi[0] for ssi in ssi_2d if len(ssi) >= 1] secondary_pos = [ssi[1] for ssi in ssi_2d if len(ssi) >= 2] all_pos = [max(ssi) for ssi in ssi_2d if len(ssi) >= 1] # if FLAGS.dataset_name != 'duc_2004': # plot_positions(primary_pos, secondary_pos, all_pos) if FLAGS.dataset_split == 'all': glob_string = '*.bin' else: glob_string = dataset_splits[0] print('Loading TFIDF vectorizer') with open(tfidf_vec_path, 'rb') as f: tfidf_vectorizer = pickle.load(f) source_dir = os.path.join(data_dir, FLAGS.dataset_name) source_files = sorted( glob.glob(source_dir + '/' + glob_string + '*')) total = len(source_files) * 1000 if ( 'cnn' in FLAGS.dataset_name or 'newsroom' in FLAGS.dataset_name or 'xsum' in FLAGS.dataset_name) else len(source_files) example_generator = data.example_generator( source_dir + '/' + glob_string + '*', True, False, should_check_valid=False) all_possible_singles = 0 all_possible_pairs = [0] all_filtered_pairs = 0 all_all_combinations = 0 all_ssi_pairs = [0] ssi_pairs_with_shared_coref = [0] ssi_pairs_with_shared_word = [0] ssi_pairs_with_either_coref_or_word = [0] all_pairs_with_shared_coref = [0] all_pairs_with_shared_word = [0] all_pairs_with_either_coref_or_word = [0] actual_total = [0] rel_positions_primary = [] rel_positions_secondary = [] rel_positions_all = [] sent_lens = [] all_sent_lens = [] all_pos = [] y = [] normalized_positions_primary = [] normalized_positions_secondary = [] all_normalized_positions_primary = [] all_normalized_positions_secondary = [] normalized_positions_singles = [] normalized_positions_pairs_first = [] normalized_positions_pairs_second = [] primary_pos_duc = [] secondary_pos_duc = [] all_pos_duc = [] all_distances = [] distances_duc = [] tfidf_similarities = [] all_tfidf_similarities = [] average_mmrs = [] all_average_mmrs = [] for example_idx, example in enumerate( tqdm(example_generator, total=total)): # def process(example_idx_example): # # print '0' # example = example_idx_example if FLAGS.num_instances != -1 and example_idx >= FLAGS.num_instances: break raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example( example, names_to_types) article_sent_tokens = [ util.process_sent(sent) for sent in raw_article_sents ] article_text = ' '.join(raw_article_sents) groundtruth_summ_sents = [[ sent.strip() for sent in groundtruth_summary_text.strip().split('\n') ]] if doc_indices is None: doc_indices = [0] * len( util.flatten_list_of_lists(article_sent_tokens)) doc_indices = [int(doc_idx) for doc_idx in doc_indices] rel_sent_indices, doc_sent_indices, doc_sent_lens = preprocess_for_lambdamart_no_flags.get_rel_sent_indices( doc_indices, article_sent_tokens) groundtruth_similar_source_indices_list = util.enforce_sentence_limit( groundtruth_similar_source_indices_list, FLAGS.sentence_limit) sent_term_matrix = util.get_doc_substituted_tfidf_matrix( tfidf_vectorizer, raw_article_sents, article_text) sents_similarities = util.cosine_similarity( sent_term_matrix, sent_term_matrix) importances = util.special_squash( util.get_tfidf_importances(tfidf_vectorizer, raw_article_sents)) if FLAGS.dataset_name == 'duc_2004': first_k_indices = lambdamart_scores_to_summaries.get_indices_of_first_k_sents_of_each_article( rel_sent_indices, FLAGS.first_k) else: first_k_indices = [ idx for idx in range(len(raw_article_sents)) ] article_indices = list(range(len(raw_article_sents))) possible_pairs = [ x for x in 
                # filtered_possible_pairs = preprocess_for_lambdamart_no_flags.filter_pairs_by_criteria(raw_article_sents, possible_pairs, corefs)
                # if FLAGS.dataset_name == 'duc_2004':
                #     filtered_possible_pairs = [x for x in list(itertools.combinations(first_k_indices, 2))]  # all pairs
                # else:
                #     filtered_possible_pairs = preprocess_for_lambdamart_no_flags.filter_pairs_by_sent_position(possible_pairs)
                # removed_pairs = list(set(possible_pairs) - set(filtered_possible_pairs))
                # possible_singles = [(i,) for i in range(len(raw_article_sents))]
                # all_combinations = filtered_possible_pairs + possible_singles
                # all_possible_singles += len(possible_singles)
                # all_possible_pairs[0] += len(possible_pairs)
                # all_filtered_pairs += len(filtered_possible_pairs)
                # all_all_combinations += len(all_combinations)
                # for ssi in groundtruth_similar_source_indices_list:
                #     if len(ssi) > 0:
                #         idx = rel_sent_indices[ssi[0]]
                #         rel_positions_primary.append(idx)
                #         rel_positions_all.append(idx)
                #     if len(ssi) > 1:
                #         idx = rel_sent_indices[ssi[1]]
                #         rel_positions_secondary.append(idx)
                #         rel_positions_all.append(idx)
                # coref_pairs = preprocess_for_lambdamart_no_flags.get_coref_pairs(corefs)
                # DO OVER LAP PAIRS BETTER
                # overlap_pairs = preprocess_for_lambdamart_no_flags.filter_by_overlap(article_sent_tokens, possible_pairs)
                # either_coref_or_word = list(set(list(coref_pairs) + overlap_pairs))
                # for ssi in groundtruth_similar_source_indices_list:
                #     if len(ssi) == 2:
                #         all_ssi_pairs[0] += 1
                #         do_share_coref = ssi in coref_pairs
                #         do_share_words = ssi in overlap_pairs
                #         if do_share_coref:
                #             ssi_pairs_with_shared_coref[0] += 1
                #         if do_share_words:
                #             ssi_pairs_with_shared_word[0] += 1
                #         if do_share_coref or do_share_words:
                #             ssi_pairs_with_either_coref_or_word[0] += 1
                # all_pairs_with_shared_coref[0] += len(coref_pairs)
                # all_pairs_with_shared_word[0] += len(overlap_pairs)
                # all_pairs_with_either_coref_or_word[0] += len(either_coref_or_word)

                if FLAGS.dataset_name == 'duc_2004':
                    primary_pos_duc.extend([rel_sent_indices[ssi[0]] for ssi in groundtruth_similar_source_indices_list if len(ssi) >= 1])
                    secondary_pos_duc.extend([rel_sent_indices[ssi[1]] for ssi in groundtruth_similar_source_indices_list if len(ssi) >= 2])
                    all_pos_duc.extend([max([rel_sent_indices[sent_idx] for sent_idx in ssi])
                                        for ssi in groundtruth_similar_source_indices_list if len(ssi) >= 1])

                for ssi in groundtruth_similar_source_indices_list:
                    for sent_idx in ssi:
                        sent_lens.append(len(article_sent_tokens[sent_idx]))
                    if len(ssi) >= 1:
                        orig_val = ssi[0]
                        vals_to_add = get_integral_values_for_histogram(
                            orig_val, rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents)
                        normalized_positions_primary.extend(vals_to_add)
                    if len(ssi) >= 2:
                        orig_val = ssi[1]
                        vals_to_add = get_integral_values_for_histogram(
                            orig_val, rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents)
                        normalized_positions_secondary.extend(vals_to_add)
                        if FLAGS.dataset_name == 'duc_2004':
                            distances_duc.append(abs(rel_sent_indices[ssi[1]] - rel_sent_indices[ssi[0]]))
                        tfidf_similarities.append(sents_similarities[ssi[0], ssi[1]])
                        average_mmrs.append((importances[ssi[0]] + importances[ssi[1]]) / 2)

                for ssi in groundtruth_similar_source_indices_list:
                    if len(ssi) == 1:
                        orig_val = ssi[0]
                        vals_to_add = get_integral_values_for_histogram(
                            orig_val, rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents)
                        normalized_positions_singles.extend(vals_to_add)
                    if len(ssi) >= 2:
                        if doc_sent_indices[ssi[0]] != doc_sent_indices[ssi[1]]:
                            continue
                        orig_val_first = min(ssi[0], ssi[1])
                        vals_to_add = get_integral_values_for_histogram(
                            orig_val_first, rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents)
                        normalized_positions_pairs_first.extend(vals_to_add)
                        orig_val_second = max(ssi[0], ssi[1])
                        vals_to_add = get_integral_values_for_histogram(
                            orig_val_second, rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents)
                        normalized_positions_pairs_second.extend(vals_to_add)

                # all_normalized_positions_primary.extend(util.flatten_list_of_lists([get_integral_values_for_histogram(single[0], rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents) for single in possible_singles]))
                # all_normalized_positions_secondary.extend(util.flatten_list_of_lists([get_integral_values_for_histogram(pair[1], rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents) for pair in possible_pairs]))
                all_sent_lens.extend([len(sent) for sent in article_sent_tokens])
                all_distances.extend([abs(rel_sent_indices[pair[1]] - rel_sent_indices[pair[0]]) for pair in possible_pairs])
                all_tfidf_similarities.extend([sents_similarities[pair[0], pair[1]] for pair in possible_pairs])
                all_average_mmrs.extend([(importances[pair[0]] + importances[pair[1]]) / 2 for pair in possible_pairs])

                # if FLAGS.dataset_name == 'duc_2004':
                #     rel_pos_single = [rel_sent_indices[single[0]] for single in possible_singles]
                #     rel_pos_pair = [[rel_sent_indices[pair[0]], rel_sent_indices[pair[1]]] for pair in possible_pairs]
                #     all_pos.extend(rel_pos_single)
                #     all_pos.extend([max(pair) for pair in rel_pos_pair])
                # else:
                #     all_pos.extend(util.flatten_list_of_lists(possible_singles))
                #     all_pos.extend([max(pair) for pair in possible_pairs])
                # y.extend([1 if single in groundtruth_similar_source_indices_list else 0 for single in possible_singles])
                # y.extend([1 if pair in groundtruth_similar_source_indices_list else 0 for pair in possible_pairs])
                # actual_total[0] += 1

            # p = Pool(144)
            # list(tqdm(p.imap(process, example_generator), total=total))
            # print 'Possible_singles\tPossible_pairs\tFiltered_pairs\tAll_combinations: \n%.2f\t%.2f\t%.2f\t%.2f' % (all_possible_singles*1./actual_total, all_possible_pairs*1./actual_total, all_filtered_pairs*1./actual_total, all_all_combinations*1./actual_total)
            # print 'Relative positions of groundtruth source sentences in document:\nPrimary\tSecondary\tBoth\n%.2f\t%.2f\t%.2f' % (np.mean(rel_positions_primary), np.mean(rel_positions_secondary), np.mean(rel_positions_all))
            # print 'SSI Pair statistics:\nShare_coref\tShare_word\tShare_either\n%.2f\t%.2f\t%.2f' % (ssi_pairs_with_shared_coref[0]*100./all_ssi_pairs[0], ssi_pairs_with_shared_word[0]*100./all_ssi_pairs[0], ssi_pairs_with_either_coref_or_word[0]*100./all_ssi_pairs[0])
            # print 'All Pair statistics:\nShare_coref\tShare_word\tShare_either\n%.2f\t%.2f\t%.2f' % (all_pairs_with_shared_coref[0]*100./all_possible_pairs[0], all_pairs_with_shared_word[0]*100./all_possible_pairs[0], all_pairs_with_either_coref_or_word[0]*100./all_possible_pairs[0])
            # hist_all_pos = np.histogram(all_pos, bins=max(all_pos)+1)
            # print 'Histogram of all sent positions: ', util.hist_as_pdf_str(hist_all_pos)
            # min_sent_len = min(sent_lens)
            # hist_sent_lens = np.histogram(sent_lens, bins=max(sent_lens)-min_sent_len+1)
            # print 'min, max sent lens:', min_sent_len, max(sent_lens)
            # print 'Histogram of sent lens: ', util.hist_as_pdf_str(hist_sent_lens)
            # min_all_sent_len = min(all_sent_lens)
            # hist_all_sent_lens = np.histogram(all_sent_lens, bins=max(all_sent_lens)-min_all_sent_len+1)
            # print 'min, max all sent lens:', min_all_sent_len, max(all_sent_lens)
            # print 'Histogram of all sent lens: ', util.hist_as_pdf_str(hist_all_sent_lens)
            # print 'Pearsons r, p value', pearsonr(all_pos, y)
            # fig, ax1 = plt.subplots(nrows=1)
            # plt.scatter(all_pos, y)
            # pp = PdfPages(os.path.join('stuff/plots', FLAGS.dataset_name + '_position_scatter.pdf'))
            # plt.savefig(pp, format='pdf', bbox_inches='tight')
            # plt.show()
            # pp.close()
            # if FLAGS.dataset_name == 'duc_2004':
            #     plot_positions(primary_pos_duc, secondary_pos_duc, all_pos_duc)
            # normalized_positions_all = normalized_positions_primary + normalized_positions_secondary
            # plot_histogram(normalized_positions_primary, num_bins=100)
            # plot_histogram(normalized_positions_secondary, num_bins=100)
            # plot_histogram(normalized_positions_all, num_bins=100)
            # sent_lens_together = [sent_lens, all_sent_lens]
            # plot_histogram(sent_lens_together, pdf=True, start_at_0=True, max_val=70)
            # if FLAGS.dataset_name == 'duc_2004':
            #     distances = distances_duc
            # sent_distances_together = [distances, all_distances]
            # plot_histogram(sent_distances_together, pdf=True, start_at_0=True, max_val=100)
            # tfidf_similarities_together = [tfidf_similarities, all_tfidf_similarities]
            # plot_histogram(tfidf_similarities_together, pdf=True, num_bins=100)
            # average_mmrs_together = [average_mmrs, all_average_mmrs]
            # plot_histogram(average_mmrs_together, pdf=True, num_bins=100)
            # normalized_positions_primary_together = [normalized_positions_primary, bin_values]
            # normalized_positions_secondary_together = [normalized_positions_secondary, bin_values]
            # plot_histogram(normalized_positions_primary_together, pdf=True, num_bins=100)
            # plot_histogram(normalized_positions_secondary_together, pdf=True, num_bins=100)
            # list_of_hist_pairs = [
            #     {'lst': normalized_positions_primary_together, 'pdf': True, 'num_bins': 100, 'y_lim': 3.9, 'y_label': FLAGS.dataset_name, 'x_label': 'Sent position (primary)'},
            #     {'lst': normalized_positions_secondary_together, 'pdf': True, 'num_bins': 100, 'y_lim': 3.9, 'x_label': 'Sent position (secondary)'},
            #     {'lst': sent_distances_together, 'pdf': True, 'start_at_0': True, 'max_val': 100, 'x_label': 'Sent distance'},
            #     {'lst': sent_lens_together, 'pdf': True, 'start_at_0': True, 'max_val': 70, 'x_label': 'Sent length'},
            #     {'lst': average_mmrs_together, 'pdf': True, 'num_bins': 100, 'x_label': 'Average TF-IDF importance'}
            # ]

            normalized_positions_pairs_together = [normalized_positions_pairs_first, normalized_positions_pairs_second]
            list_of_hist_pairs = [
                {
                    'lst': [normalized_positions_singles],
                    'pdf': True,
                    'num_bins': 100,
                    # 'y_lim': 3.9,
                    'x_lim': 1.0,
                    'y_label': FLAGS.dataset_name,
                    'x_label': 'Sent Position (Singles)',
                    'legend_labels': ['Primary']
                },
                {
                    'lst': normalized_positions_pairs_together,
                    'pdf': True,
                    'num_bins': 100,
                    # 'y_lim': 3.9,
                    'x_lim': 1.0,
                    'x_label': 'Sent Position (Pairs)',
                    'legend_labels': ['Primary', 'Secondary']
                }
            ]
            all_lists_of_histogram_pairs.append(list_of_hist_pairs)

        with open(plot_data_file, 'w') as f:
            cPickle.dump(all_lists_of_histogram_pairs, f)
    else:
        with open(plot_data_file) as f:
            all_lists_of_histogram_pairs = cPickle.load(f)

    plot_histograms(all_lists_of_histogram_pairs)
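
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original pipeline): the statistics above
# print histograms through util.hist_as_pdf_str(), which is assumed here to
# render a numpy histogram as normalized proportions. The helper below is only
# a guess at that behavior, included to clarify the intended output format;
# the real util.hist_as_pdf_str may differ.
def _sketch_hist_as_pdf_str(hist):
    """Format a numpy histogram tuple (counts, bin_edges) as bin: proportion pairs."""
    counts, bin_edges = hist
    total = float(sum(counts))
    if total == 0:
        return '(empty histogram)'
    proportions = [count / total for count in counts]
    return '  '.join('%.1f-%.1f: %.3f' % (bin_edges[i], bin_edges[i + 1], p)
                     for i, p in enumerate(proportions))

# Example usage (hypothetical numbers of merged source sentences per summary sentence):
#   num_extracted = [1, 2, 1, 1, 2, 0]
#   print(_sketch_hist_as_pdf_str(np.histogram(num_extracted, bins=6, range=(0, 5))))
# ---------------------------------------------------------------------------
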
def main(unused_argv):
    print('Running statistics on %s' % FLAGS.exp_name)
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)
    if FLAGS.all_actions:
        FLAGS.sent_dataset = True
        FLAGS.ssi_dataset = True
        FLAGS.print_output = True
        FLAGS.highlight = True

    original_dataset_name = 'xsum' if 'xsum' in FLAGS.dataset_name else 'cnn_dm' if (
        'cnn_dm' in FLAGS.dataset_name or 'duc_2004' in FLAGS.dataset_name) else ''
    vocab = Vocab(FLAGS.vocab_path + '_' + original_dataset_name, FLAGS.vocab_size)  # create a vocabulary
    source_dir = os.path.join(data_dir, FLAGS.dataset_name)
    util.create_dirs(html_dir)

    if FLAGS.dataset_split == 'all':
        if FLAGS.dataset_name == 'duc_2004':
            dataset_splits = ['test']
        else:
            dataset_splits = ['test', 'val', 'train']
    else:
        dataset_splits = [FLAGS.dataset_split]

    for dataset_split in dataset_splits:
        source_files = sorted(glob.glob(source_dir + '/' + dataset_split + '*'))
        if FLAGS.exp_name == 'reference':
            # summary_dir = log_dir + default_exp_name + '/decode_test_' + str(max_enc_steps) + \
            #               'maxenc_4beam_' + str(min_dec_steps) + 'mindec_' + str(max_dec_steps) + 'maxdec_ckpt-238410/reference'
            # summary_files = sorted(glob.glob(summary_dir + '/*_reference.A.txt'))
            summary_dir = source_dir
            summary_files = source_files
        else:
            if FLAGS.exp_name == 'cnn_dm':
                summary_dir = log_dir + FLAGS.exp_name + '/decode_test_400maxenc_4beam_35mindec_100maxdec_ckpt-238410/decoded'
            else:
                ckpt_folder = util.find_largest_ckpt_folder(log_dir + FLAGS.exp_name)
                summary_dir = log_dir + FLAGS.exp_name + '/' + ckpt_folder + '/decoded'
                # summary_dir = log_dir + FLAGS.exp_name + '/decode_test_' + str(max_enc_steps) + \
                #               'maxenc_4beam_' + str(min_dec_steps) + 'mindec_' + str(max_dec_steps) + 'maxdec_ckpt-238410/decoded'
            summary_files = sorted(glob.glob(summary_dir + '/*'))
        if len(summary_files) == 0:
            raise Exception('No files found in %s' % summary_dir)
        example_generator = data.example_generator(source_dir + '/' + dataset_split + '*',
                                                   True, False, is_original=True)
        pros = {'annotators': 'dcoref', 'outputFormat': 'json', 'timeout': '5000000'}

        all_merge_examples = []
        num_extracted_list = []
        distances = []
        relative_distances = []
        html_str = ''
        extracted_sents_in_article_html = ''
        name = FLAGS.dataset_name + '_' + FLAGS.exp_name
        if FLAGS.coreference_replacement:
            name += '_coref'
        highlight_file_name = os.path.join(html_dir, FLAGS.dataset_name + '_' + FLAGS.exp_name)
        if FLAGS.consider_stopwords:
            highlight_file_name += '_stopwords'
        if FLAGS.highlight:
            extracted_sents_in_article_html_file = open(highlight_file_name + '_extracted_sents.html', 'wb')
        if FLAGS.kaiqiang:
            kaiqiang_article_texts = []
            kaiqiang_abstract_texts = []
            util.create_dirs(kaiqiang_dir)
            kaiqiang_article_file = open(os.path.join(
                kaiqiang_dir, FLAGS.dataset_name + '_' + dataset_split + '_' + str(FLAGS.min_matched_tokens) + '_articles.txt'), 'wb')
            kaiqiang_abstract_file = open(os.path.join(
                kaiqiang_dir, FLAGS.dataset_name + '_' + dataset_split + '_' + str(FLAGS.min_matched_tokens) + '_abstracts.txt'), 'wb')
        if FLAGS.ssi_dataset:
            if FLAGS.tag_tokens:
                with_coref_and_ssi_dir = lambdamart_dir + '_and_tag_tokens'
            else:
                with_coref_and_ssi_dir = lambdamart_dir
            lambdamart_out_dir = os.path.join(with_coref_and_ssi_dir, FLAGS.dataset_name)
            if FLAGS.sentence_limit == 1:
                lambdamart_out_dir += '_singles'
            if FLAGS.consider_stopwords:
                lambdamart_out_dir += '_stopwords'
            lambdamart_out_full_dir = os.path.join(lambdamart_out_dir, 'all')
            util.create_dirs(lambdamart_out_full_dir)
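            # All SSI training examples for this split are first written to a single
            # <split>.bin file; util.chunk_file() breaks it into smaller chunked files
            # after the example loop below finishes.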
            lambdamart_writer = open(os.path.join(lambdamart_out_full_dir, dataset_split + '.bin'), 'wb')

        simple_similar_source_indices_list_plus_empty = []
        example_idx = -1
        instance_idx = 0
        total = len(source_files) * 1000 if ('cnn' in FLAGS.dataset_name or 'newsroom' in FLAGS.dataset_name
                                             or 'xsum' in FLAGS.dataset_name) else len(source_files)
        random_choices = None
        if FLAGS.randomize:
            if FLAGS.dataset_name == 'cnn_dm':
                list_order = np.random.permutation(11490)
                random_choices = list_order[:FLAGS.num_instances]

        for example in tqdm(example_generator, total=total):
            example_idx += 1
            if FLAGS.num_instances != -1 and instance_idx >= FLAGS.num_instances:
                break
            if random_choices is not None and example_idx not in random_choices:
                continue
            # for file_idx in tqdm(range(len(source_files))):
            #     example = get_tf_example(source_files[file_idx])
            article_text = example.features.feature['article'].bytes_list.value[0].decode().lower()
            if FLAGS.exp_name == 'reference':
                summary_text, all_summary_texts = get_summary_from_example(example)
            else:
                summary_text = get_summary_text(summary_files[example_idx])
            article_tokens = split_into_tokens(article_text)

            # Prefer the pre-split sentences stored in the example; otherwise fall back to NLTK.
            if 'raw_article_sents' in example.features.feature and len(
                    example.features.feature['raw_article_sents'].bytes_list.value) > 0:
                raw_article_sents = example.features.feature['raw_article_sents'].bytes_list.value
                raw_article_sents = [sent.decode() for sent in raw_article_sents if sent.decode().strip() != '']
                article_sent_tokens = [util.process_sent(sent, whitespace=True) for sent in raw_article_sents]
            else:
                # article_text = util.to_unicode(article_text)
                # sent_pros = {'annotators': 'ssplit', 'outputFormat': 'json', 'timeout': '5000000'}
                # sents_result_dict = nlp.annotate(str(article_text), properties=sent_pros)
                # article_sent_tokens = [[token['word'] for token in sent['tokens']] for sent in sents_result_dict['sentences']]
                raw_article_sents = nltk.tokenize.sent_tokenize(article_text)
                article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
            if FLAGS.top_n_sents != -1:
                article_sent_tokens = article_sent_tokens[:FLAGS.top_n_sents]
                raw_article_sents = raw_article_sents[:FLAGS.top_n_sents]
            article_sents = [' '.join(sent) for sent in article_sent_tokens]
            try:
                article_tokens_string = str(' '.join(article_sents))
            except:
                try:
                    article_tokens_string = str(' '.join([sent.decode('latin-1') for sent in article_sents]))
                except:
                    raise
            if len(article_sent_tokens) == 0:
                continue

            summary_sent_tokens = split_into_sent_tokens(summary_text)
            if 'doc_indices' in example.features.feature and len(
                    example.features.feature['doc_indices'].bytes_list.value) > 0:
                doc_indices_str = example.features.feature['doc_indices'].bytes_list.value[0].decode()
                if '1' in doc_indices_str:
                    doc_indices = [int(x) for x in doc_indices_str.strip().split()]
                    rel_sent_positions = importance_features.get_sent_indices(article_sent_tokens, doc_indices)
                else:
                    num_tokens_total = sum([len(sent) for sent in article_sent_tokens])
                    rel_sent_positions = list(range(len(raw_article_sents)))
                    doc_indices = [0] * num_tokens_total
            else:
                rel_sent_positions = None
                doc_indices = None
                doc_indices_str = None
            if 'corefs' in example.features.feature and len(
                    example.features.feature['corefs'].bytes_list.value) > 0:
                corefs_str = example.features.feature['corefs'].bytes_list.value[0]
                corefs = json.loads(corefs_str)
            # summary_sent_tokens = limit_to_n_tokens(summary_sent_tokens, 100)

            similar_source_indices_list_plus_empty = []
            simple_similar_source_indices, lcs_paths_list, article_lcs_paths_list, smooth_article_paths_list = ssi_functions.get_simple_source_indices_list(
                summary_sent_tokens, article_sent_tokens, vocab, FLAGS.sentence_limit, FLAGS.min_matched_tokens,
                not FLAGS.consider_stopwords, lemmatize=FLAGS.lemmatize, multiple_ssi=FLAGS.multiple_ssi)
            article_paths_parameter = article_lcs_paths_list if FLAGS.tag_tokens else None
            article_paths_parameter = smooth_article_paths_list if FLAGS.smart_tags else article_paths_parameter
            restricted_source_indices = util.enforce_sentence_limit(simple_similar_source_indices, FLAGS.sentence_limit)

            for summ_sent_idx, summ_sent in enumerate(summary_sent_tokens):
                if FLAGS.sent_dataset:
                    if len(restricted_source_indices[summ_sent_idx]) == 0:
                        continue
                    merge_example = get_merge_example(restricted_source_indices[summ_sent_idx], article_sent_tokens,
                                                      summ_sent, corefs, article_paths_parameter[summ_sent_idx])
                    all_merge_examples.append(merge_example)
            simple_similar_source_indices_list_plus_empty.append(simple_similar_source_indices)

            if FLAGS.ssi_dataset:
                summary_text_to_save = [s for s in all_summary_texts] if FLAGS.dataset_name == 'duc_2004' else summary_text
                write_lambdamart_example(simple_similar_source_indices, raw_article_sents, summary_text_to_save,
                                         corefs_str, doc_indices_str, article_paths_parameter, lambdamart_writer)

            if FLAGS.highlight:
                highlight_article_lcs_paths_list = smooth_article_paths_list if FLAGS.smart_tags else article_lcs_paths_list
                # simple_ssi_plus_empty = [[s[0] for s in sim_source_ind] for sim_source_ind in simple_similar_source_indices]
                extracted_sents_in_article_html = ssi_functions.html_highlight_sents_in_article(
                    summary_sent_tokens, simple_similar_source_indices, article_sent_tokens, doc_indices,
                    lcs_paths_list, highlight_article_lcs_paths_list)
                extracted_sents_in_article_html_file.write(extracted_sents_in_article_html.encode())
            instance_idx += 1

        if FLAGS.ssi_dataset:
            lambdamart_writer.close()
            if FLAGS.dataset_name == 'cnn_dm' or FLAGS.dataset_name == 'newsroom' or FLAGS.dataset_name == 'xsum':
                chunk_size = 1000
            else:
                chunk_size = 1
            util.chunk_file(dataset_split, lambdamart_out_full_dir, lambdamart_out_dir, chunk_size=chunk_size)

        if FLAGS.sent_dataset:
            with_coref_dir = data_dir + '_and_tag_tokens' if FLAGS.tag_tokens else data_dir
            out_dir = os.path.join(with_coref_dir, FLAGS.dataset_name + '_sent')
            if FLAGS.sentence_limit == 1:
                out_dir += '_singles'
            if FLAGS.consider_stopwords:
                out_dir += '_stopwords'
            if FLAGS.coreference_replacement:
                out_dir += '_coref'
            if FLAGS.top_n_sents != -1:
                out_dir += '_n=' + str(FLAGS.top_n_sents)
            util.create_dirs(out_dir)
            convert_data.write_with_generator(iter(all_merge_examples), len(all_merge_examples), out_dir, dataset_split)

        if FLAGS.print_output:
            # html_str = FLAGS.dataset + ' | ' + FLAGS.exp_name + '<br><br><br>' + html_str
            # save_fusions_to_file(html_str)
            ssi_path = os.path.join(ssi_dir, FLAGS.dataset_name)
            if FLAGS.consider_stopwords:
                ssi_path += '_stopwords'
            util.create_dirs(ssi_path)
            if FLAGS.dataset_name == 'duc_2004' and FLAGS.abstract_idx != 0:
                abstract_idx_str = '_%d' % FLAGS.abstract_idx
            else:
                abstract_idx_str = ''
            with open(os.path.join(ssi_path, dataset_split + '_ssi' + abstract_idx_str + '.pkl'), 'wb') as f:
                pickle.dump(simple_similar_source_indices_list_plus_empty, f)

        if FLAGS.kaiqiang:
            # kaiqiang_article_file.write('\n'.join(kaiqiang_article_texts))
            # kaiqiang_abstract_file.write('\n'.join(kaiqiang_abstract_texts))
            kaiqiang_article_file.close()
            kaiqiang_abstract_file.close()
        if FLAGS.highlight:
            extracted_sents_in_article_html_file.close()
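
# ---------------------------------------------------------------------------
# Illustrative sketch (not the original write_lambdamart_example): the .bin
# files written above are read back by data.example_generator(), so this
# assumes the same length-prefixed tf.Example format used by the
# pointer-generator data pipeline (8-byte length, then the serialized proto).
# The feature names mirror the ones unpacked elsewhere in this codebase; any
# other details are assumptions, not the project's actual serialization code.
import struct

from tensorflow.core.example import example_pb2


def _sketch_write_tf_example(writer, raw_article_sents, summary_text):
    """Serialize one example and append it to an already-open binary file handle."""
    tf_example = example_pb2.Example()
    for sent in raw_article_sents:
        tf_example.features.feature['raw_article_sents'].bytes_list.value.extend([sent.encode()])
    tf_example.features.feature['summary_text'].bytes_list.value.extend([summary_text.encode()])
    tf_example_str = tf_example.SerializeToString()
    str_len = len(tf_example_str)
    # Length prefix first, then the raw serialized proto bytes.
    writer.write(struct.pack('q', str_len))
    writer.write(struct.pack('%ds' % str_len, tf_example_str))

# Example usage (hypothetical data):
#   with open('val.bin', 'wb') as writer:
#       _sketch_write_tf_example(writer, ['first sentence .', 'second sentence .'], 'a short summary .')
# ---------------------------------------------------------------------------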