def get_importances(model, batch, enc_states, vocab, sess, hps):
    """Return squashed per-sentence importance scores for the (single) article in `batch`, or None if PG-MMR is disabled."""
    if FLAGS.pg_mmr:
        enc_sentences, enc_tokens = batch.tokenized_sents[0], batch.word_ids_sents[0]
        if FLAGS.importance_fn == 'oracle':
            human_tokens = get_tokens_for_human_summaries(batch, vocab)  # list (of 4 human summaries) of list of token ids
            metric = 'recall'
            importances_hat = rouge_l_similarity(enc_tokens, human_tokens, vocab, metric=metric)
        elif FLAGS.importance_fn == 'svr':
            with open(os.path.join(FLAGS.actual_log_root, 'svr.pickle'), 'rb') as f:
                svr_model = cPickle.load(f)
            enc_sent_indices = importance_features.get_sent_indices(enc_sentences, batch.doc_indices[0])
            sent_representations_separate = importance_features.get_separate_enc_states(model, sess, enc_sentences, vocab, hps)
            importances_hat = get_svr_importances(enc_states[0], enc_sentences, enc_sent_indices, svr_model, sent_representations_separate)
        elif FLAGS.importance_fn == 'tfidf':
            importances_hat = get_tfidf_importances(batch.raw_article_sents[0])
        importances = util.special_squash(importances_hat)
    else:
        importances = None
    return importances
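# Illustrative usage sketch (not part of the original pipeline): shows how
# get_importances() could be called at decode time. `batcher` is a hypothetical
# Batcher instance; model.run_encoder(sess, batch) is the same call used in
# calc_importance_features() below, and numpy is assumed to be imported as np
# elsewhere in this file (it is used in main()).
def _example_get_importances_usage(model, batcher, vocab, sess, hps):
    batch = batcher.next_batch()                    # 1 example repeated across batch
    enc_states, _ = model.run_encoder(sess, batch)  # encoder states for the batch
    importances = get_importances(model, batch, enc_states, vocab, sess, hps)
    if importances is None:                         # None unless FLAGS.pg_mmr is set
        return None
    return int(np.argmax(importances))              # index of the most important article sentence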
def calc_importance_features(self, data_path, hps, model_save_path, docs_desired):
    """Calculate sentence-level features and save as a dataset"""
    data_path_filter_name = os.path.basename(data_path)
    if 'train' in data_path_filter_name:
        data_split = 'train'
    elif 'val' in data_path_filter_name:
        data_split = 'val'
    elif 'test' in data_path_filter_name:
        data_split = 'test'
    else:
        data_split = 'feats'
    if 'cnn-dailymail' in data_path:
        inst_per_file = 1000
    else:
        inst_per_file = 1
    filelist = glob.glob(data_path)
    num_documents_desired = docs_desired
    pbar = tqdm(initial=0, total=num_documents_desired)

    instances = []
    sentences = []
    counter = 0
    doc_counter = 0
    file_counter = 0
    while True:
        batch = self._batcher.next_batch()  # 1 example repeated across batch
        if doc_counter >= num_documents_desired:
            save_path = os.path.join(model_save_path, data_split + '_%06d' % file_counter)
            with open(save_path, 'wb') as f:
                cPickle.dump(instances, f)
            print('Saved features at %s' % save_path)
            return
        if batch is None:  # finished decoding dataset in single_pass mode
            raise Exception("We haven't reached the num docs desired (%d), instead we reached (%d)" % (num_documents_desired, doc_counter))

        batch_enc_states, _ = self._model.run_encoder(self._sess, batch)
        for batch_idx, enc_states in enumerate(batch_enc_states):
            art_oovs = batch.art_oovs[batch_idx]
            all_original_abstracts_sents = batch.all_original_abstracts_sents[batch_idx]

            tokenizer = Tokenizer('english')
            # List of lists of words
            enc_sentences, enc_tokens = batch.tokenized_sents[batch_idx], batch.word_ids_sents[batch_idx]
            enc_sent_indices = importance_features.get_sent_indices(enc_sentences, batch.doc_indices[batch_idx])
            enc_sentences_str = [' '.join(sent) for sent in enc_sentences]

            sent_representations_separate = importance_features.get_separate_enc_states(self._model, self._sess, enc_sentences, self._vocab, hps)

            sent_indices = enc_sent_indices
            sent_reps = importance_features.get_importance_features_for_article(
                enc_states, enc_sentences, sent_indices, tokenizer, sent_representations_separate)
            y, y_hat = importance_features.get_ROUGE_Ls(art_oovs, all_original_abstracts_sents, self._vocab, enc_tokens)
            binary_y = importance_features.get_best_ROUGE_L_for_each_abs_sent(art_oovs, all_original_abstracts_sents, self._vocab, enc_tokens)
            for rep_idx, rep in enumerate(sent_reps):
                rep.y = y[rep_idx]
                rep.binary_y = binary_y[rep_idx]

            for rep_idx, rep in enumerate(sent_reps):
                # Keep all sentences with importance above threshold. All others will be kept with a probability of prob_to_keep
                if FLAGS.importance_fn == 'svr':
                    instances.append(rep)
                    sentences.append(enc_sentences_str[rep_idx])
                    counter += 1  # this is how many examples we've decoded

        doc_counter += len(batch_enc_states)
        pbar.update(len(batch_enc_states))
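# Illustrative sketch (not part of this file): one way the 'svr.pickle' loaded in
# get_importances() could be produced from the feature files that
# calc_importance_features() writes. Assumes scikit-learn is available, that the
# module-level glob/os/cPickle imports exist (they are used above), and that each
# pickled instance exposes its features via a hypothetical to_input_vector()
# helper alongside the rep.y ROUGE-L target assigned above.
def _example_train_svr(feature_dir, out_path):
    from sklearn.svm import SVR
    xs, ys = [], []
    for path in sorted(glob.glob(os.path.join(feature_dir, 'train_*'))):
        with open(path, 'rb') as f:
            for rep in cPickle.load(f):
                xs.append(rep.to_input_vector())  # hypothetical: flatten rep's features into a vector
                ys.append(rep.y)                  # ROUGE-L regression target set in calc_importance_features()
    svr_model = SVR(kernel='rbf')
    svr_model.fit(xs, ys)
    with open(out_path, 'wb') as f:               # e.g. os.path.join(FLAGS.actual_log_root, 'svr.pickle')
        cPickle.dump(svr_model, f)
    return svr_model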
def main(unused_argv):
    print('Running statistics on %s' % FLAGS.exp_name)
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.all_actions:
        FLAGS.sent_dataset = True
        FLAGS.ssi_dataset = True
        FLAGS.print_output = True
        FLAGS.highlight = True

    original_dataset_name = 'xsum' if 'xsum' in FLAGS.dataset_name else 'cnn_dm' if (
        'cnn_dm' in FLAGS.dataset_name or 'duc_2004' in FLAGS.dataset_name) else ''
    vocab = Vocab(FLAGS.vocab_path + '_' + original_dataset_name, FLAGS.vocab_size)  # create a vocabulary

    source_dir = os.path.join(data_dir, FLAGS.dataset_name)
    util.create_dirs(html_dir)

    if FLAGS.dataset_split == 'all':
        if FLAGS.dataset_name == 'duc_2004':
            dataset_splits = ['test']
        else:
            dataset_splits = ['test', 'val', 'train']
    else:
        dataset_splits = [FLAGS.dataset_split]

    for dataset_split in dataset_splits:
        source_files = sorted(glob.glob(source_dir + '/' + dataset_split + '*'))

        if FLAGS.exp_name == 'reference':
            # summary_dir = log_dir + default_exp_name + '/decode_test_' + str(max_enc_steps) + \
            #               'maxenc_4beam_' + str(min_dec_steps) + 'mindec_' + str(max_dec_steps) + 'maxdec_ckpt-238410/reference'
            # summary_files = sorted(glob.glob(summary_dir + '/*_reference.A.txt'))
            summary_dir = source_dir
            summary_files = source_files
        else:
            if FLAGS.exp_name == 'cnn_dm':
                summary_dir = log_dir + FLAGS.exp_name + '/decode_test_400maxenc_4beam_35mindec_100maxdec_ckpt-238410/decoded'
            else:
                ckpt_folder = util.find_largest_ckpt_folder(log_dir + FLAGS.exp_name)
                summary_dir = log_dir + FLAGS.exp_name + '/' + ckpt_folder + '/decoded'
                # summary_dir = log_dir + FLAGS.exp_name + '/decode_test_' + str(max_enc_steps) + \
                #               'maxenc_4beam_' + str(min_dec_steps) + 'mindec_' + str(max_dec_steps) + 'maxdec_ckpt-238410/decoded'
            summary_files = sorted(glob.glob(summary_dir + '/*'))
        if len(summary_files) == 0:
            raise Exception('No files found in %s' % summary_dir)

        example_generator = data.example_generator(source_dir + '/' + dataset_split + '*', True, False, is_original=True)
        pros = {'annotators': 'dcoref', 'outputFormat': 'json', 'timeout': '5000000'}

        all_merge_examples = []
        num_extracted_list = []
        distances = []
        relative_distances = []
        html_str = ''
        extracted_sents_in_article_html = ''
        name = FLAGS.dataset_name + '_' + FLAGS.exp_name
        if FLAGS.coreference_replacement:
            name += '_coref'
        highlight_file_name = os.path.join(html_dir, FLAGS.dataset_name + '_' + FLAGS.exp_name)
        if FLAGS.consider_stopwords:
            highlight_file_name += '_stopwords'
        if FLAGS.highlight:
            extracted_sents_in_article_html_file = open(highlight_file_name + '_extracted_sents.html', 'wb')

        if FLAGS.kaiqiang:
            kaiqiang_article_texts = []
            kaiqiang_abstract_texts = []
            util.create_dirs(kaiqiang_dir)
            kaiqiang_article_file = open(
                os.path.join(kaiqiang_dir, FLAGS.dataset_name + '_' + dataset_split + '_' + str(FLAGS.min_matched_tokens) + '_articles.txt'), 'wb')
            kaiqiang_abstract_file = open(
                os.path.join(kaiqiang_dir, FLAGS.dataset_name + '_' + dataset_split + '_' + str(FLAGS.min_matched_tokens) + '_abstracts.txt'), 'wb')

        if FLAGS.ssi_dataset:
            if FLAGS.tag_tokens:
                with_coref_and_ssi_dir = lambdamart_dir + '_and_tag_tokens'
            else:
                with_coref_and_ssi_dir = lambdamart_dir
            lambdamart_out_dir = os.path.join(with_coref_and_ssi_dir, FLAGS.dataset_name)
            if FLAGS.sentence_limit == 1:
                lambdamart_out_dir += '_singles'
            if FLAGS.consider_stopwords:
                lambdamart_out_dir += '_stopwords'
            lambdamart_out_full_dir = os.path.join(lambdamart_out_dir, 'all')
            util.create_dirs(lambdamart_out_full_dir)
            lambdamart_writer = open(os.path.join(lambdamart_out_full_dir, dataset_split + '.bin'), 'wb')

        simple_similar_source_indices_list_plus_empty = []
        example_idx = -1
        instance_idx = 0
        total = len(source_files) * 1000 if (
            'cnn' in FLAGS.dataset_name or 'newsroom' in FLAGS.dataset_name or 'xsum' in FLAGS.dataset_name) else len(source_files)
        random_choices = None
        if FLAGS.randomize:
            if FLAGS.dataset_name == 'cnn_dm':
                list_order = np.random.permutation(11490)
                random_choices = list_order[:FLAGS.num_instances]
        for example in tqdm(example_generator, total=total):
            example_idx += 1
            if FLAGS.num_instances != -1 and instance_idx >= FLAGS.num_instances:
                break
            if random_choices is not None and example_idx not in random_choices:
                continue
            # for file_idx in tqdm(range(len(source_files))):
            #     example = get_tf_example(source_files[file_idx])
            article_text = example.features.feature['article'].bytes_list.value[0].decode().lower()
            if FLAGS.exp_name == 'reference':
                summary_text, all_summary_texts = get_summary_from_example(example)
            else:
                summary_text = get_summary_text(summary_files[example_idx])
            article_tokens = split_into_tokens(article_text)
            if 'raw_article_sents' in example.features.feature and len(
                    example.features.feature['raw_article_sents'].bytes_list.value) > 0:
                raw_article_sents = example.features.feature['raw_article_sents'].bytes_list.value
                raw_article_sents = [sent.decode() for sent in raw_article_sents if sent.decode().strip() != '']
                article_sent_tokens = [util.process_sent(sent, whitespace=True) for sent in raw_article_sents]
            else:
                # article_text = util.to_unicode(article_text)
                # sent_pros = {'annotators': 'ssplit', 'outputFormat': 'json', 'timeout': '5000000'}
                # sents_result_dict = nlp.annotate(str(article_text), properties=sent_pros)
                # article_sent_tokens = [[token['word'] for token in sent['tokens']] for sent in sents_result_dict['sentences']]
                raw_article_sents = nltk.tokenize.sent_tokenize(article_text)
                article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
            if FLAGS.top_n_sents != -1:
                article_sent_tokens = article_sent_tokens[:FLAGS.top_n_sents]
                raw_article_sents = raw_article_sents[:FLAGS.top_n_sents]
            article_sents = [' '.join(sent) for sent in article_sent_tokens]
            try:
                article_tokens_string = str(' '.join(article_sents))
            except:
                try:
                    article_tokens_string = str(' '.join([sent.decode('latin-1') for sent in article_sents]))
                except:
                    raise

            if len(article_sent_tokens) == 0:
                continue

            summary_sent_tokens = split_into_sent_tokens(summary_text)
            if 'doc_indices' in example.features.feature and len(
                    example.features.feature['doc_indices'].bytes_list.value) > 0:
                doc_indices_str = example.features.feature['doc_indices'].bytes_list.value[0].decode()
                if '1' in doc_indices_str:
                    doc_indices = [int(x) for x in doc_indices_str.strip().split()]
                    rel_sent_positions = importance_features.get_sent_indices(article_sent_tokens, doc_indices)
                else:
                    num_tokens_total = sum([len(sent) for sent in article_sent_tokens])
                    rel_sent_positions = list(range(len(raw_article_sents)))
                    doc_indices = [0] * num_tokens_total
            else:
                rel_sent_positions = None
                doc_indices = None
                doc_indices_str = None
            if 'corefs' in example.features.feature and len(
                    example.features.feature['corefs'].bytes_list.value) > 0:
                corefs_str = example.features.feature['corefs'].bytes_list.value[0]
                corefs = json.loads(corefs_str)
            # summary_sent_tokens = limit_to_n_tokens(summary_sent_tokens, 100)

            similar_source_indices_list_plus_empty = []

            simple_similar_source_indices, lcs_paths_list, article_lcs_paths_list, smooth_article_paths_list = ssi_functions.get_simple_source_indices_list(
                summary_sent_tokens, article_sent_tokens, vocab, FLAGS.sentence_limit,
                FLAGS.min_matched_tokens, not FLAGS.consider_stopwords,
                lemmatize=FLAGS.lemmatize, multiple_ssi=FLAGS.multiple_ssi)

            article_paths_parameter = article_lcs_paths_list if FLAGS.tag_tokens else None
            article_paths_parameter = smooth_article_paths_list if FLAGS.smart_tags else article_paths_parameter

            restricted_source_indices = util.enforce_sentence_limit(simple_similar_source_indices, FLAGS.sentence_limit)
            for summ_sent_idx, summ_sent in enumerate(summary_sent_tokens):
                if FLAGS.sent_dataset:
                    if len(restricted_source_indices[summ_sent_idx]) == 0:
                        continue
                    merge_example = get_merge_example(
                        restricted_source_indices[summ_sent_idx], article_sent_tokens, summ_sent,
                        corefs, article_paths_parameter[summ_sent_idx])
                    all_merge_examples.append(merge_example)

            simple_similar_source_indices_list_plus_empty.append(simple_similar_source_indices)
            if FLAGS.ssi_dataset:
                summary_text_to_save = [s for s in all_summary_texts] if FLAGS.dataset_name == 'duc_2004' else summary_text
                write_lambdamart_example(simple_similar_source_indices, raw_article_sents, summary_text_to_save,
                                         corefs_str, doc_indices_str, article_paths_parameter, lambdamart_writer)

            if FLAGS.highlight:
                highlight_article_lcs_paths_list = smooth_article_paths_list if FLAGS.smart_tags else article_lcs_paths_list
                # simple_ssi_plus_empty = [[s[0] for s in sim_source_ind] for sim_source_ind in simple_similar_source_indices]
                extracted_sents_in_article_html = ssi_functions.html_highlight_sents_in_article(
                    summary_sent_tokens, simple_similar_source_indices, article_sent_tokens,
                    doc_indices, lcs_paths_list, highlight_article_lcs_paths_list)
                extracted_sents_in_article_html_file.write(extracted_sents_in_article_html.encode())
            a = 0

            instance_idx += 1

        if FLAGS.ssi_dataset:
            lambdamart_writer.close()
            if FLAGS.dataset_name == 'cnn_dm' or FLAGS.dataset_name == 'newsroom' or FLAGS.dataset_name == 'xsum':
                chunk_size = 1000
            else:
                chunk_size = 1
            util.chunk_file(dataset_split, lambdamart_out_full_dir, lambdamart_out_dir, chunk_size=chunk_size)

        if FLAGS.sent_dataset:
            with_coref_dir = data_dir + '_and_tag_tokens' if FLAGS.tag_tokens else data_dir
            out_dir = os.path.join(with_coref_dir, FLAGS.dataset_name + '_sent')
            if FLAGS.sentence_limit == 1:
                out_dir += '_singles'
            if FLAGS.consider_stopwords:
                out_dir += '_stopwords'
            if FLAGS.coreference_replacement:
                out_dir += '_coref'
            if FLAGS.top_n_sents != -1:
                out_dir += '_n=' + str(FLAGS.top_n_sents)
            util.create_dirs(out_dir)
            convert_data.write_with_generator(iter(all_merge_examples), len(all_merge_examples), out_dir, dataset_split)

        if FLAGS.print_output:
            # html_str = FLAGS.dataset + ' | ' + FLAGS.exp_name + '<br><br><br>' + html_str
            # save_fusions_to_file(html_str)
            ssi_path = os.path.join(ssi_dir, FLAGS.dataset_name)
            if FLAGS.consider_stopwords:
                ssi_path += '_stopwords'
            util.create_dirs(ssi_path)
            if FLAGS.dataset_name == 'duc_2004' and FLAGS.abstract_idx != 0:
                abstract_idx_str = '_%d' % FLAGS.abstract_idx
            else:
                abstract_idx_str = ''
            with open(os.path.join(ssi_path, dataset_split + '_ssi' + abstract_idx_str + '.pkl'), 'wb') as f:
                pickle.dump(simple_similar_source_indices_list_plus_empty, f)

        if FLAGS.kaiqiang:
            # kaiqiang_article_file.write('\n'.join(kaiqiang_article_texts))
            # kaiqiang_abstract_file.write('\n'.join(kaiqiang_abstract_texts))
            kaiqiang_article_file.close()
            kaiqiang_abstract_file.close()
        if FLAGS.highlight:
            extracted_sents_in_article_html_file.close()
    a = 0
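# Illustrative sketch (not part of the original script): loading the SSI pickle
# that main() writes above. The path layout mirrors the pickle.dump() call in the
# FLAGS.print_output branch, assuming FLAGS.consider_stopwords is False and
# abstract_idx_str is empty (the default for every dataset except duc_2004 with a
# nonzero FLAGS.abstract_idx). `_example_load_ssi` is a hypothetical helper name.
def _example_load_ssi(dataset_name, dataset_split):
    ssi_path = os.path.join(ssi_dir, dataset_name, dataset_split + '_ssi.pkl')
    with open(ssi_path, 'rb') as f:
        ssi_list = pickle.load(f)
    # Each entry appears to hold, for one article, the source-sentence indices
    # selected for each summary sentence by get_simple_source_indices_list().
    for article_idx, ssi in enumerate(ssi_list[:3]):
        print('Article %d source indices per summary sentence: %s' % (article_idx, ssi))
    return ssi_list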