def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    source_dir = os.path.join(data_dir, FLAGS.dataset)
    source_files = sorted(glob.glob(source_dir + '/*'))

    for i in range(4):
        ref_dir = os.path.join(log_dir, 'reference_' + str(i), 'reference')
        dec_dir = os.path.join(log_dir, 'reference_' + str(i), 'decoded')
        util.create_dirs(ref_dir)
        util.create_dirs(dec_dir)
        for source_idx, source_file in enumerate(source_files):
            human_summary_texts = get_human_summary_texts(source_file)
            summaries = []
            for summary_text in human_summary_texts:
                summary = data.abstract2sents(summary_text)
                summaries.append(summary)
            # Hold out summary i as the candidate; score it against the rest.
            candidate = summaries[i]
            references = [summaries[idx] for idx in range(len(summaries)) if idx != i]
            rouge_functions.write_for_rouge(references, candidate, source_idx, ref_dir, dec_dir)
        results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir)
        # print("Results_dict: ", results_dict)
        rouge_functions.rouge_log(results_dict, os.path.join(log_dir, 'reference_' + str(i)))
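# The loop above scores each of the four human summaries against the other
# three (leave-one-out), giving a human ceiling for ROUGE. A minimal
# standalone sketch of the same scheme, with a toy unigram-overlap score
# standing in for ROUGE (these helper names are illustrative, not from this repo):
def toy_overlap(candidate_sents, reference_sents):
    cand = set(w for s in candidate_sents for w in s.split())
    ref = set(w for s in reference_sents for w in s.split())
    return len(cand & ref) / max(len(cand), 1)

def leave_one_out_scores(summaries):
    scores = []
    for i, candidate in enumerate(summaries):
        references = [s for j, s in enumerate(summaries) if j != i]
        # score the held-out summary against its best-matching reference
        scores.append(max(toy_overlap(candidate, ref) for ref in references))
    return scores

print(leave_one_out_scores([['a b c'], ['a b d'], ['a c d'], ['b c d']]))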
def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    print('Running statistics on %s' % exp_name)
    start_time = time.time()
    np.random.seed(random_seed)
    source_dir = os.path.join(data_dir, dataset_articles)
    source_files = sorted(glob.glob(source_dir + '/' + dataset_split + '*'))
    total = len(source_files) * 1000
    example_generator = data.example_generator(
        source_dir + '/' + dataset_split + '*', True, False, should_check_valid=False)

    # Read output of BERT and put into a dictionary with:
    #   key = (article idx, source indices) -- source indices is a tuple of
    #         length 1 or 2, depending on if it is a singleton or pair
    #   value = score
    qid_ssi_to_importances = rank_source_sents(temp_in_path, temp_out_path)
    ex_gen = example_generator_extended(example_generator, total,
                                        qid_ssi_to_importances, None,
                                        FLAGS.singles_and_pairs)
    print('Creating list')
    ex_list = [ex for ex in ex_gen]

    # # Main function to get results on all test examples (parallel version)
    # pool = mp.Pool(mp.cpu_count())
    # ssi_list = list(tqdm(pool.imap(evaluate_example, ex_list), total=total))
    # pool.close()

    # Main function to get results on all test examples
    ssi_list = list(map(evaluate_example, ex_list))

    # save ssi_list
    with open(os.path.join(my_log_dir, 'ssi.pkl'), 'wb') as f:
        pickle.dump(ssi_list, f)
    with open(os.path.join(my_log_dir, 'ssi.pkl'), 'rb') as f:
        ssi_list = pickle.load(f)

    print('Evaluating BERT model F1 score...')
    suffix = util.all_sent_selection_eval(ssi_list)
    print('Evaluating ROUGE...')
    results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir, l_param=l_param)
    rouge_functions.rouge_log(results_dict, my_log_dir, suffix=suffix)

    ssis_restricted = [ssi_triple[1][:ssi_triple[2]] for ssi_triple in ssi_list]
    ssi_lens = [len(source_indices) for source_indices
                in util.flatten_list_of_lists(ssis_restricted)]
    num_singles = ssi_lens.count(1)
    num_pairs = ssi_lens.count(2)
    print('Percent singles/pairs: %.2f %.2f' % (
        num_singles * 100. / len(ssi_lens), num_pairs * 100. / len(ssi_lens)))

    util.print_execution_time(start_time)
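# The dump-then-reload above also verifies that the pickled results file is
# readable by later analysis runs. A minimal self-contained sketch of the same
# round trip (note pickle requires binary file modes in Python 3):
import os
import pickle
import tempfile

results = [([(0,)], [(1,), (2, 3)], 2)]  # toy (gt_ssi, sys_ssi, ext_len) triple
path = os.path.join(tempfile.gettempdir(), 'ssi_example.pkl')
with open(path, 'wb') as f:  # 'wb', not 'w': pickle writes bytes
    pickle.dump(results, f)
with open(path, 'rb') as f:  # 'rb', not 'r'
    assert pickle.load(f) == results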
def main(unused_argv):
    print('Running statistics on %s' % FLAGS.dataset_name)
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.summarizer == 'all':
        summary_methods = list(summarizers.keys())
    else:
        summary_methods = [FLAGS.summarizer]
    if FLAGS.dataset_name == 'all':
        dataset_names = datasets
    else:
        dataset_names = [FLAGS.dataset_name]

    sheets_strs = []
    for summary_method in summary_methods:
        summary_fn = summarizers[summary_method]
        for dataset_name in dataset_names:
            FLAGS.dataset_name = dataset_name
            original_dataset_name = ('xsum' if 'xsum' in dataset_name
                                     else 'cnn_dm' if 'cnn_dm' in dataset_name or 'duc_2004' in dataset_name
                                     else '')
            vocab = Vocab('logs/vocab' + '_' + original_dataset_name, 50000)  # create a vocabulary
            source_dir = os.path.join(data_dir, dataset_name)
            source_files = sorted(glob.glob(source_dir + '/' + FLAGS.dataset_split + '*'))
            total = len(source_files) * 1000 if (
                'cnn' in dataset_name or 'newsroom' in dataset_name
                or 'xsum' in dataset_name) else len(source_files)
            example_generator = data.example_generator(
                source_dir + '/' + FLAGS.dataset_split + '*', True, False,
                should_check_valid=False)
            if dataset_name == 'duc_2004':
                abs_source_dir = os.path.join(
                    os.path.expanduser('~') + '/data/tf_data/with_coref', dataset_name)
                abs_example_generator = data.example_generator(
                    abs_source_dir + '/' + FLAGS.dataset_split + '*', True, False,
                    should_check_valid=False)
                abs_names_to_types = [('abstract', 'string_list')]

            triplet_ssi_list = []
            for example_idx, example in enumerate(tqdm(example_generator, total=total)):
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
                    example, names_to_types)
                if dataset_name == 'duc_2004':
                    abs_example = next(abs_example_generator)
                    groundtruth_summary_texts = util.unpack_tf_example(abs_example, abs_names_to_types)
                    groundtruth_summary_texts = groundtruth_summary_texts[0]
                    groundtruth_summ_sents_list = [
                        [sent.strip() for sent in data.abstract2sents(abstract)]
                        for abstract in groundtruth_summary_texts]
                else:
                    groundtruth_summary_texts = [groundtruth_summary_text]
                    groundtruth_summ_sents_list = []
                    for groundtruth_summary_text in groundtruth_summary_texts:
                        groundtruth_summ_sents = [
                            sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]
                        groundtruth_summ_sents_list.append(groundtruth_summ_sents)
                article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
                if doc_indices is None:
                    doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
                doc_indices = [int(doc_idx) for doc_idx in doc_indices]
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list, FLAGS.sentence_limit)

                log_dir = os.path.join(log_root, dataset_name + '_' + summary_method)
                dec_dir = os.path.join(log_dir, 'decoded')
                ref_dir = os.path.join(log_dir, 'reference')
                util.create_dirs(dec_dir)
                util.create_dirs(ref_dir)

                parser = PlaintextParser.from_string(' '.join(raw_article_sents),
                                                     Tokenizer("english"))
                summarizer = summary_fn()
                summary = summarizer(parser.document, 5)  # summarize the document with 5 sentences
                summary = [str(sentence) for sentence in summary]
                summary_tokenized = []
                for sent in summary:
                    summary_tokenized.append(sent.lower())
                rouge_functions.write_for_rouge(groundtruth_summ_sents_list, summary_tokenized,
                                                example_idx, ref_dir, dec_dir, log=False)
                decoded_sent_tokens = [sent.split() for sent in summary_tokenized]
                sentence_limit = 2
                sys_ssi_list, _, _ = get_simple_source_indices_list(
                    decoded_sent_tokens, article_sent_tokens, vocab, sentence_limit,
                    min_matched_tokens)
                triplet_ssi_list.append(
                    (groundtruth_similar_source_indices_list, sys_ssi_list, -1))

            print('Evaluating Lambdamart model F1 score...')
            suffix = util.all_sent_selection_eval(triplet_ssi_list)
            print(suffix)
            results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir)
            print("Results_dict: ", results_dict)
            sheets_str = rouge_functions.rouge_log(results_dict, log_dir, suffix=suffix)
            sheets_strs.append(dataset_name + '_' + summary_method + '\n' + sheets_str)

    for sheets_str in sheets_strs:
        print(sheets_str + '\n')
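# The baseline above drives each sumy extractive summarizer through a common
# callable interface (parser.document plus a sentence count). A minimal
# standalone sketch of that interface, using LexRank as one concrete choice
# for summary_fn (assumes the `sumy` package; the other entries of this
# repo's `summarizers` dict are not shown here):
from sumy.parsers.plaintext import PlaintextParser
from sumy.nlp.tokenizers import Tokenizer
from sumy.summarizers.lex_rank import LexRankSummarizer

document_text = ("The cat sat on the mat. It was a sunny day. "
                 "The dog barked at the mailman. Everyone went home early.")
parser = PlaintextParser.from_string(document_text, Tokenizer("english"))
summarizer = LexRankSummarizer()
for sentence in summarizer(parser.document, 2):  # extract 2 sentences
    print(str(sentence))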
def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    print('Running statistics on %s' % exp_name)
    start_time = time.time()
    np.random.seed(random_seed)
    source_dir = os.path.join(data_dir, dataset_articles)
    source_files = sorted(glob.glob(source_dir + '/' + dataset_split + '*'))

    ex_sents = ['single .', 'sentence .']
    article_text = ' '.join(ex_sents)
    sent_term_matrix = util.get_doc_substituted_tfidf_matrix(
        tfidf_vectorizer, ex_sents, article_text, pca)
    if FLAGS.singles_and_pairs == 'pairs':
        single_feat_len = 0
    else:
        single_feat_len = len(get_single_sent_features(
            0, sent_term_matrix, [['single', '.'], ['sentence', '.']], [0, 0], 0))
    if FLAGS.singles_and_pairs == 'singles':
        pair_feat_len = 0
    else:
        pair_feat_len = len(get_pair_sent_features(
            [0, 1], sent_term_matrix, [['single', '.'], ['sentence', '.']], [0, 0], [0, 0]))

    # Note: this condition was originally `if 'cnn' or 'newsroom' in dataset_articles`,
    # which is always truthy; each substring must be tested explicitly.
    total = len(source_files) * 1000 if (
        'cnn' in dataset_articles or 'newsroom' in dataset_articles) else len(source_files)
    example_generator = data.example_generator(
        source_dir + '/' + dataset_split + '*', True, False, should_check_valid=False)

    if FLAGS.mode == 'write_to_file':
        ex_gen = example_generator_extended(example_generator, total, single_feat_len,
                                            pair_feat_len, FLAGS.singles_and_pairs)
        print('Creating list')
        ex_list = [ex for ex in ex_gen]
        print('Converting...')
        # if len(sys.argv) > 1 and sys.argv[1] == '-m':
        list(futures.map(write_to_lambdamart_examples_to_file, ex_list))
        # else:
        #     instances_list = []
        #     for ex in tqdm(ex_list):
        #         instances_list.append(write_to_lambdamart_examples_to_file(ex))

        file_names = sorted(glob.glob(os.path.join(temp_in_dir, '*')))
        instances_str = ''
        for file_name in tqdm(file_names):
            with open(file_name) as f:
                instances_str += f.read()
        with open(temp_in_path, 'w') as f:  # text mode: instances_str is a str
            f.write(instances_str)

    # RUN LAMBDAMART SCORING COMMAND HERE

    if FLAGS.mode == 'generate_summaries':
        qid_ssi_to_importances = rank_source_sents(temp_in_path, temp_out_path)
        ex_gen = example_generator_extended(example_generator, total,
                                            qid_ssi_to_importances, pair_feat_len,
                                            FLAGS.singles_and_pairs)
        print('Creating list')
        ex_list = [ex for ex in ex_gen]
        ssi_list = list(futures.map(evaluate_example, ex_list))

        # save ssi_list (binary modes: pickle reads and writes bytes)
        with open(os.path.join(my_log_dir, 'ssi.pkl'), 'wb') as f:
            pickle.dump(ssi_list, f)
        with open(os.path.join(my_log_dir, 'ssi.pkl'), 'rb') as f:
            ssi_list = pickle.load(f)
        print('Evaluating Lambdamart model F1 score...')
        suffix = util.all_sent_selection_eval(ssi_list)
        # for ex in tqdm(ex_list, total=total):
        #     load_and_evaluate_example(ex)

        print('Evaluating ROUGE...')
        results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir, l_param=l_param)
        # print("Results_dict: ", results_dict)
        rouge_functions.rouge_log(results_dict, my_log_dir, suffix=suffix)

    util.print_execution_time(start_time)
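# Why the `total` expression needed fixing: in Python,
# `'cnn' or 'newsroom' in x` parses as `'cnn' or ('newsroom' in x)`, and the
# non-empty string 'cnn' is always truthy, so the old condition could never be
# False. A tiny standalone demo:
dataset = 'duc_2004'
broken = bool('cnn' or 'newsroom' in dataset)        # always True
correct = 'cnn' in dataset or 'newsroom' in dataset  # False for duc_2004
print(broken, correct)  # True False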
def main(unused_argv):
    print('Running statistics on %s' % FLAGS.dataset_name)
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    out_dir = os.path.join(os.path.expanduser('~') + '/data/kaiqiang_data',
                           FLAGS.dataset_name)

    if FLAGS.mode == 'write':
        util.create_dirs(out_dir)
        if FLAGS.dataset_name == 'duc_2004':
            dataset_splits = ['test']
        elif FLAGS.dataset_split == 'all':
            dataset_splits = ['test', 'val', 'train']
        else:
            dataset_splits = [FLAGS.dataset_split]

        for dataset_split in dataset_splits:
            if dataset_split == 'test':
                ssi_data_path = os.path.join(
                    'logs/%s_bert_both_sentemb_artemb_plushidden' % FLAGS.dataset_name,
                    'ssi.pkl')
                print(util.bcolors.OKGREEN + "Loading SSI from BERT at %s" % ssi_data_path
                      + util.bcolors.ENDC)
                with open(ssi_data_path, 'rb') as f:  # binary mode for pickle
                    ssi_triple_list = pickle.load(f)
            source_dir = os.path.join(data_dir, FLAGS.dataset_name)
            source_files = sorted(glob.glob(source_dir + '/' + dataset_split + '*'))
            total = len(source_files) * 1000 if (
                'cnn' in FLAGS.dataset_name or 'newsroom' in FLAGS.dataset_name
                or 'xsum' in FLAGS.dataset_name) else len(source_files)
            example_generator = data.example_generator(
                source_dir + '/' + dataset_split + '*', True, False,
                should_check_valid=False)

            out_document_path = os.path.join(out_dir, dataset_split + '.Ndocument')
            out_summary_path = os.path.join(out_dir, dataset_split + '.Nsummary')
            out_example_idx_path = os.path.join(out_dir, dataset_split + '.Nexampleidx')
            doc_writer = open(out_document_path, 'w')
            if dataset_split != 'test':
                sum_writer = open(out_summary_path, 'w')
            ex_idx_writer = open(out_example_idx_path, 'w')

            for example_idx, example in enumerate(tqdm(example_generator, total=total)):
                if FLAGS.num_instances != -1 and example_idx >= FLAGS.num_instances:
                    break
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(
                    example, names_to_types)
                article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
                if FLAGS.dataset_name == 'duc_2004':
                    groundtruth_summ_sents = [
                        [sent.strip() for sent in gt_summ_text.strip().split('\n')]
                        for gt_summ_text in groundtruth_summary_text]
                else:
                    groundtruth_summ_sents = [
                        [sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
                if doc_indices is None:
                    doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
                doc_indices = [int(doc_idx) for doc_idx in doc_indices]
                # rel_sent_indices, _, _ = preprocess_for_lambdamart_no_flags.get_rel_sent_indices(doc_indices, article_sent_tokens)

                if dataset_split == 'test':
                    if example_idx >= len(ssi_triple_list):
                        raise Exception(
                            'Len of ssi list (%d) is less than number of examples (>=%d)'
                            % (len(ssi_triple_list), example_idx))
                    ssi_length_extractive = ssi_triple_list[example_idx][2]
                    ssi = ssi_triple_list[example_idx][1]
                    ssi = ssi[:ssi_length_extractive]
                    groundtruth_similar_source_indices_list = ssi
                else:
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                        groundtruth_similar_source_indices_list, FLAGS.sentence_limit)

                for ssi_idx, ssi in enumerate(groundtruth_similar_source_indices_list):
                    if len(ssi) == 0:
                        continue
                    my_article = ' '.join(util.reorder(raw_article_sents, ssi))
                    doc_writer.write(my_article + '\n')
                    if dataset_split != 'test':
                        sum_writer.write(groundtruth_summ_sents[0][ssi_idx] + '\n')
                    ex_idx_writer.write(str(example_idx) + '\n')

    elif FLAGS.mode == 'evaluate':
        summary_dir = ('/home/logan/data/kaiqiang_data/logan_ACL/trained_on_'
                       + FLAGS.train_dataset + '/' + FLAGS.dataset_name)
        out_summary_path = os.path.join(summary_dir, 'test' + 'Summary.txt')
        out_example_idx_path = os.path.join(out_dir, 'test' + '.Nexampleidx')
        decode_dir = 'logs/kaiqiang_%s_trainedon%s' % (FLAGS.dataset_name,
                                                       FLAGS.train_dataset)
        rouge_ref_dir = os.path.join(decode_dir, 'reference')
        rouge_dec_dir = os.path.join(decode_dir, 'decoded')
        util.create_dirs(rouge_ref_dir)
        util.create_dirs(rouge_dec_dir)

        def num_lines_in_file(file_path):
            with open(file_path) as f:
                num_lines = sum(1 for line in f)
            return num_lines

        def process_example(sents, ex_idx, groundtruth_summ_sents):
            final_decoded_words = []
            for sent in sents:
                final_decoded_words.extend(sent.split(' '))
            rouge_functions.write_for_rouge(groundtruth_summ_sents, None, ex_idx,
                                            rouge_ref_dir, rouge_dec_dir,
                                            decoded_words=final_decoded_words, log=False)

        num_lines_summary = num_lines_in_file(out_summary_path)
        num_lines_example_indices = num_lines_in_file(out_example_idx_path)
        if num_lines_summary != num_lines_example_indices:
            raise Exception('Num lines summary != num lines example indices: (%d, %d)'
                            % (num_lines_summary, num_lines_example_indices))

        source_dir = os.path.join(data_dir, FLAGS.dataset_name)
        example_generator = data.example_generator(
            source_dir + '/' + 'test' + '*', True, False, should_check_valid=False)

        # These two files are read, not written (renamed from the original
        # misleading *_writer names).
        sum_reader = open(out_summary_path)
        ex_idx_reader = open(out_example_idx_path)
        prev_ex_idx = 0
        sents = []
        for line_idx in tqdm(range(num_lines_summary)):
            line = sum_reader.readline()
            ex_idx = int(ex_idx_reader.readline())
            if ex_idx == prev_ex_idx:
                sents.append(line)
            else:
                example = next(example_generator)  # generator.next() was Python 2 only
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(
                    example, names_to_types)
                if FLAGS.dataset_name == 'duc_2004':
                    groundtruth_summ_sents = [
                        [sent.strip() for sent in gt_summ_text.strip().split('\n')]
                        for gt_summ_text in groundtruth_summary_text]
                else:
                    groundtruth_summ_sents = [
                        [sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
                process_example(sents, ex_idx, groundtruth_summ_sents)
                prev_ex_idx = ex_idx
                sents = [line]

        # Flush the final example, which the loop above never reaches.
        example = next(example_generator)
        raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(
            example, names_to_types)
        if FLAGS.dataset_name == 'duc_2004':
            groundtruth_summ_sents = [
                [sent.strip() for sent in gt_summ_text.strip().split('\n')]
                for gt_summ_text in groundtruth_summary_text]
        else:
            groundtruth_summ_sents = [
                [sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
        process_example(sents, ex_idx, groundtruth_summ_sents)

        print("Now starting ROUGE eval...")
        l_param = 100  # same length cutoff for xsum and the other datasets
        results_dict = rouge_functions.rouge_eval(rouge_ref_dir, rouge_dec_dir,
                                                  l_param=l_param)
        rouge_functions.rouge_log(results_dict, decode_dir)

    else:
        raise Exception('mode flag was not evaluate or write.')
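# The read loop above groups consecutive summary lines by their example index
# and must remember to flush the last group after the loop ends.
# itertools.groupby expresses the same pattern without the manual flush; a
# minimal sketch on toy data:
from itertools import groupby

lines = ['sent a', 'sent b', 'sent c', 'sent d']
example_indices = [0, 0, 1, 1]
for ex_idx, group in groupby(zip(example_indices, lines), key=lambda pair: pair[0]):
    sents = [line for _, line in group]
    print(ex_idx, sents)
# 0 ['sent a', 'sent b']
# 1 ['sent c', 'sent d']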
def decode(self):
    """Decode examples until data is exhausted (if FLAGS.single_pass) and return,
    or decode indefinitely, loading latest checkpoint at regular intervals."""
    t0 = time.time()
    counter = 0
    total = len(glob.glob(self._batcher._data_path)) * 1000
    pbar = tqdm(total=total)
    while True:
        batch = self._batcher.next_batch()  # 1 example repeated across batch
        if batch is None:  # finished decoding dataset in single_pass mode
            assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
            logging.info("Decoder has finished reading dataset for single_pass.")
            logging.info("Output has been saved in %s and %s.",
                         self._rouge_ref_dir, self._rouge_dec_dir)
            if len(os.listdir(self._rouge_ref_dir)) != 0:
                logging.info("Now starting ROUGE eval...")
                results_dict = rouge_functions.rouge_eval(self._rouge_ref_dir,
                                                          self._rouge_dec_dir)
                rouge_functions.rouge_log(results_dict, self._decode_dir)
            return

        original_article = batch.original_articles[0]  # string
        original_abstract = batch.original_abstracts[0]  # string
        all_original_abstract_sents = batch.all_original_abstracts_sents[0]
        raw_article_sents = batch.raw_article_sents[0]

        article_withunks = data.show_art_oovs(original_article, self._vocab)  # string
        abstract_withunks = data.show_abs_oovs(
            original_abstract, self._vocab,
            (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

        decoded_words, decoded_output, best_hyp = decode_example(
            self._sess, self._model, self._vocab, batch, counter, self._batcher._hps)

        if FLAGS.single_pass:
            if counter < 1000:
                self.write_for_human(raw_article_sents, all_original_abstract_sents,
                                     decoded_words, counter)
            rouge_functions.write_for_rouge(
                all_original_abstract_sents, None, counter,
                self._rouge_ref_dir, self._rouge_dec_dir,
                decoded_words=decoded_words
            )  # write ref summary and decoded summary to file, to eval with pyrouge later
            if FLAGS.attn_vis:
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words,
                    best_hyp.attn_dists, best_hyp.p_gens, counter
                )  # write info to .json file for visualization tool
            counter += 1  # this is how many examples we've decoded
        else:
            print_results(article_withunks, abstract_withunks, decoded_output)  # log output to screen
            self.write_for_attnvis(
                article_withunks, abstract_withunks, decoded_words,
                best_hyp.attn_dists, best_hyp.p_gens,
                counter)  # write info to .json file for visualization tool

            # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
            t1 = time.time()
            if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                logging.info("We've been decoding with same checkpoint for %i seconds. "
                             "Time to load new checkpoint", t1 - t0)
                _ = util.load_ckpt(self._saver, self._sess)
                t0 = time.time()
        pbar.update(1)
    pbar.close()
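# The non-single_pass branch implements "decode until the checkpoint goes
# stale, then reload": compare elapsed wall time against a threshold and reset
# the clock after each reload. A minimal standalone sketch of that timing
# pattern (load_latest_checkpoint is a hypothetical stand-in for util.load_ckpt):
import time

SECS_UNTIL_NEW_CKPT_DEMO = 2.0

def load_latest_checkpoint():
    print('reloading checkpoint...')

t0 = time.time()
for step in range(5):
    time.sleep(1)  # stands in for decoding one batch
    if time.time() - t0 > SECS_UNTIL_NEW_CKPT_DEMO:
        load_latest_checkpoint()
        t0 = time.time()  # restart the staleness clock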
def decode_iteratively(self, example_generator, total, names_to_types, ssi_list, hps):
    for example_idx, example in enumerate(tqdm(example_generator, total=total)):
        raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text = util.unpack_tf_example(
            example, names_to_types)
        article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
        groundtruth_summ_sents = [
            [sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]

        if ssi_list is None:
            # upper-bound evaluation: the ssi_list comes straight from the groundtruth
            sys_ssi = groundtruth_similar_source_indices_list
            if FLAGS.singles_and_pairs == 'singles':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
            elif FLAGS.singles_and_pairs == 'both':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
            sys_ssi = util.replace_empty_ssis(sys_ssi, raw_article_sents)
        else:
            gt_ssi, sys_ssi, ext_len = ssi_list[example_idx]
            if FLAGS.singles_and_pairs == 'singles':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list, 1)
                gt_ssi = util.enforce_sentence_limit(gt_ssi, 1)
            elif FLAGS.singles_and_pairs == 'both':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list, 2)
                gt_ssi = util.enforce_sentence_limit(gt_ssi, 2)
            if gt_ssi != groundtruth_similar_source_indices_list:
                print('Warning: Example %d has different groundtruth source indices: %s || %s'
                      % (example_idx, groundtruth_similar_source_indices_list, gt_ssi))

        if FLAGS.dataset_name == 'xsum':
            sys_ssi = [sys_ssi[0]]

        final_decoded_words = []
        final_decoded_outputs = ''
        best_hyps = []
        highlight_html_total = ''
        for ssi_idx, ssi in enumerate(sys_ssi):
            selected_raw_article_sents = util.reorder(raw_article_sents, ssi)
            selected_article_text = ' '.join(
                [' '.join(sent) for sent in util.reorder(article_sent_tokens, ssi)])
            selected_doc_indices_str = '0 ' * len(selected_article_text.split())
            if FLAGS.upper_bound:
                selected_groundtruth_summ_sent = [[groundtruth_summ_sents[0][ssi_idx]]]
            else:
                selected_groundtruth_summ_sent = groundtruth_summ_sents

            batch = create_batch(selected_article_text, selected_groundtruth_summ_sent,
                                 selected_doc_indices_str, selected_raw_article_sents,
                                 FLAGS.batch_size, hps, self._vocab)
            decoded_words, decoded_output, best_hyp = decode_example(
                self._sess, self._model, self._vocab, batch, example_idx, hps)
            best_hyps.append(best_hyp)
            final_decoded_words.extend(decoded_words)
            final_decoded_outputs += decoded_output

            if example_idx < 1000:
                min_matched_tokens = 2
                selected_article_sent_tokens = [
                    util.process_sent(sent) for sent in selected_raw_article_sents]
                highlight_summary_sent_tokens = [decoded_words]
                highlight_ssi_list, lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                    highlight_summary_sent_tokens, selected_article_sent_tokens, None, 2,
                    min_matched_tokens)
                highlighted_html = ssi_functions.html_highlight_sents_in_article(
                    highlight_summary_sent_tokens, highlight_ssi_list,
                    selected_article_sent_tokens, lcs_paths_list=lcs_paths_list,
                    article_lcs_paths_list=highlight_smooth_article_lcs_paths_list)
                highlight_html_total += ('<u>System Summary</u><br><br>'
                                         + highlighted_html + '<br><br>')

            if len(final_decoded_words) >= 100:
                break

        if example_idx < 1000:
            self.write_for_human(raw_article_sents, groundtruth_summ_sents,
                                 final_decoded_words, example_idx)
            ssi_functions.write_highlighted_html(highlight_html_total,
                                                 self._highlight_dir, example_idx)

        rouge_functions.write_for_rouge(
            groundtruth_summ_sents, None, example_idx,
            self._rouge_ref_dir, self._rouge_dec_dir,
            decoded_words=final_decoded_words, log=False
        )  # write ref summary and decoded summary to file, to eval with pyrouge later
        example_idx += 1  # this is how many examples we've decoded

    logging.info("Decoder has finished reading dataset for single_pass.")
    logging.info("Output has been saved in %s and %s.",
                 self._rouge_ref_dir, self._rouge_dec_dir)
    if len(os.listdir(self._rouge_ref_dir)) != 0:
        l_param = 100
        logging.info("Now starting ROUGE eval...")
        results_dict = rouge_functions.rouge_eval(self._rouge_ref_dir,
                                                  self._rouge_dec_dir, l_param=l_param)
        rouge_functions.rouge_log(results_dict, self._decode_dir)
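# decode_iteratively generates one summary sentence per selected source-sentence
# group and stops once the running summary reaches a ~100-word budget. A minimal
# standalone sketch of that accumulate-until-budget control flow (decode_one is
# a hypothetical stand-in for decode_example):
def decode_one(group):
    return ('summary for %s' % (group,)).split()

def summarize_groups(groups, word_budget=100):
    words = []
    for group in groups:
        words.extend(decode_one(group))
        if len(words) >= word_budget:
            break  # stop adding sentences once the budget is hit
    return ' '.join(words)

print(summarize_groups([(0,), (2, 3), (5,)], word_budget=8))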
def decode_iteratively(self, example_generator, total, names_to_types, ssi_list, hps):
    attn_vis_idx = 0
    for example_idx, example in enumerate(tqdm(example_generator, total=total)):
        raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, groundtruth_article_lcs_paths_list = util.unpack_tf_example(
            example, names_to_types)
        article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
        groundtruth_summ_sents = [
            [sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
        groundtruth_summ_sent_tokens = [sent.split(' ')
                                        for sent in groundtruth_summ_sents[0]]

        if ssi_list is None:
            # upper-bound evaluation: the ssi_list comes straight from the groundtruth
            sys_ssi = groundtruth_similar_source_indices_list
            sys_alp_list = groundtruth_article_lcs_paths_list
            if FLAGS.singles_and_pairs == 'singles':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 1)
            elif FLAGS.singles_and_pairs == 'both':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 2)
            sys_ssi, sys_alp_list = util.replace_empty_ssis(
                sys_ssi, raw_article_sents, sys_alp_list=sys_alp_list)
        else:
            gt_ssi, sys_ssi, ext_len, sys_token_probs_list = ssi_list[example_idx]
            sys_alp_list = ssi_functions.list_labels_from_probs(
                sys_token_probs_list, FLAGS.tag_threshold)
            if FLAGS.singles_and_pairs == 'singles':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 1)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list, 1)
                gt_ssi = util.enforce_sentence_limit(gt_ssi, 1)
            elif FLAGS.singles_and_pairs == 'both':
                sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 2)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list, 2)
                gt_ssi = util.enforce_sentence_limit(gt_ssi, 2)
            # if gt_ssi != groundtruth_similar_source_indices_list:
            #     raise Exception('Example %d has different groundtruth source indices: ' + str(groundtruth_similar_source_indices_list) + ' || ' + str(gt_ssi))

        if FLAGS.dataset_name == 'xsum':
            sys_ssi = [sys_ssi[0]]

        final_decoded_words = []
        final_decoded_outputs = ''
        best_hyps = []
        highlight_html_total = '<u>System Summary</u><br><br>'
        for ssi_idx, ssi in enumerate(sys_ssi):
            # selected_article_lcs_paths = None
            selected_article_lcs_paths = sys_alp_list[ssi_idx]
            ssi, selected_article_lcs_paths = util.make_ssi_chronological(
                ssi, selected_article_lcs_paths)
            selected_article_lcs_paths = [selected_article_lcs_paths]
            selected_raw_article_sents = util.reorder(raw_article_sents, ssi)
            selected_article_text = ' '.join(
                [' '.join(sent) for sent in util.reorder(article_sent_tokens, ssi)])
            selected_doc_indices_str = '0 ' * len(selected_article_text.split())
            if FLAGS.upper_bound:
                selected_groundtruth_summ_sent = [[groundtruth_summ_sents[0][ssi_idx]]]
            else:
                selected_groundtruth_summ_sent = groundtruth_summ_sents

            batch = create_batch(selected_article_text, selected_groundtruth_summ_sent,
                                 selected_doc_indices_str, selected_raw_article_sents,
                                 selected_article_lcs_paths, FLAGS.batch_size, hps,
                                 self._vocab)

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            article_withunks = data.show_art_oovs(original_article, self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            if FLAGS.first_intact and ssi_idx == 0:
                decoded_words = selected_article_text.strip().split()
                decoded_output = selected_article_text
            else:
                decoded_words, decoded_output, best_hyp = decode_example(
                    self._sess, self._model, self._vocab, batch, example_idx, hps)
                best_hyps.append(best_hyp)
            final_decoded_words.extend(decoded_words)
            final_decoded_outputs += decoded_output

            if example_idx < 100 or (example_idx >= 2000 and example_idx < 2100):
                min_matched_tokens = 2
                selected_article_sent_tokens = [
                    util.process_sent(sent) for sent in selected_raw_article_sents]
                highlight_summary_sent_tokens = [decoded_words]
                highlight_ssi_list, lcs_paths_list, highlight_article_lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                    highlight_summary_sent_tokens, selected_article_sent_tokens, None, 2,
                    min_matched_tokens)
                highlighted_html = ssi_functions.html_highlight_sents_in_article(
                    highlight_summary_sent_tokens, highlight_ssi_list,
                    selected_article_sent_tokens, lcs_paths_list=lcs_paths_list,
                    article_lcs_paths_list=highlight_smooth_article_lcs_paths_list)
                highlight_html_total += highlighted_html + '<br>'

            if FLAGS.attn_vis and example_idx < 200:
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words,
                    best_hyp.attn_dists, best_hyp.p_gens, attn_vis_idx
                )  # write info to .json file for visualization tool
                attn_vis_idx += 1

            if len(final_decoded_words) >= 100:
                break

        gt_ssi_list, gt_alp_list = util.replace_empty_ssis(
            groundtruth_similar_source_indices_list, raw_article_sents,
            sys_alp_list=groundtruth_article_lcs_paths_list)
        highlight_html_gt = '<u>Reference Summary</u><br><br>'
        for ssi_idx, ssi in enumerate(gt_ssi_list):
            selected_article_lcs_paths = gt_alp_list[ssi_idx]
            try:
                ssi, selected_article_lcs_paths = util.make_ssi_chronological(
                    ssi, selected_article_lcs_paths)
            except:
                util.print_vars(ssi, example_idx, selected_article_lcs_paths)
                raise
            selected_raw_article_sents = util.reorder(raw_article_sents, ssi)

            if example_idx < 100 or (example_idx >= 2000 and example_idx < 2100):
                min_matched_tokens = 2
                selected_article_sent_tokens = [
                    util.process_sent(sent) for sent in selected_raw_article_sents]
                highlight_summary_sent_tokens = [groundtruth_summ_sent_tokens[ssi_idx]]
                highlight_ssi_list, lcs_paths_list, highlight_article_lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                    highlight_summary_sent_tokens, selected_article_sent_tokens, None, 2,
                    min_matched_tokens)
                highlighted_html = ssi_functions.html_highlight_sents_in_article(
                    highlight_summary_sent_tokens, highlight_ssi_list,
                    selected_article_sent_tokens, lcs_paths_list=lcs_paths_list,
                    article_lcs_paths_list=highlight_smooth_article_lcs_paths_list)
                highlight_html_gt += highlighted_html + '<br>'

        if example_idx < 100 or (example_idx >= 2000 and example_idx < 2100):
            self.write_for_human(raw_article_sents, groundtruth_summ_sents,
                                 final_decoded_words, example_idx)
            highlight_html_total = ssi_functions.put_html_in_two_columns(
                highlight_html_total, highlight_html_gt)
            ssi_functions.write_highlighted_html(highlight_html_total,
                                                 self._highlight_dir, example_idx)

        # if example_idx % 100 == 0:
        #     attn_dir = os.path.join(self._decode_dir, 'attn_vis_data')
        #     attn_selections.process_attn_selections(attn_dir, self._decode_dir, self._vocab)

        rouge_functions.write_for_rouge(
            groundtruth_summ_sents, None, example_idx,
            self._rouge_ref_dir, self._rouge_dec_dir,
            decoded_words=final_decoded_words, log=False
        )  # write ref summary and decoded summary to file, to eval with pyrouge later
        # if FLAGS.attn_vis:
        #     self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens, example_idx)  # write info to .json file for visualization tool
        example_idx += 1  # this is how many examples we've decoded

    logging.info("Decoder has finished reading dataset for single_pass.")
    logging.info("Output has been saved in %s and %s.",
                 self._rouge_ref_dir, self._rouge_dec_dir)
    if len(os.listdir(self._rouge_ref_dir)) != 0:
        l_param = 100  # same length cutoff whether or not the dataset is xsum
        logging.info("Now starting ROUGE eval...")
        results_dict = rouge_functions.rouge_eval(self._rouge_ref_dir,
                                                  self._rouge_dec_dir, l_param=l_param)
        rouge_functions.rouge_log(results_dict, self._decode_dir)
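# ssi_functions.list_labels_from_probs turns per-token probabilities into
# binary keep/drop tags by thresholding against FLAGS.tag_threshold. A minimal
# sketch of that operation (the nested-list shape mirrors one group of token
# probabilities per ssi; this is an illustrative re-implementation, not the
# repo's exact helper):
def labels_from_probs(token_probs_list, threshold=0.5):
    return [[1 if p >= threshold else 0 for p in sent_probs]
            for sent_probs in token_probs_list]

print(labels_from_probs([[0.9, 0.2, 0.7], [0.4, 0.6]], threshold=0.5))
# [[1, 0, 1], [0, 1]]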
def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.exp_name == 'extractive':
        for summary_method in summary_methods:
            if not os.path.exists(os.path.join(out_dir, summary_method, 'decoded')):
                os.makedirs(os.path.join(out_dir, summary_method, 'decoded'))
            if not os.path.exists(os.path.join(out_dir, summary_method, 'reference')):
                os.makedirs(os.path.join(out_dir, summary_method, 'reference'))
            print(os.path.join(out_dir, summary_method))
            method_dir = os.path.join(summaries_dir, summary_method)
            file_names = sorted([name for name in os.listdir(method_dir)
                                 if name[0] == 'd'])
            for art_idx, article_name in enumerate(tqdm(file_names)):
                file = os.path.join(method_dir, article_name)
                with open(file) as f:  # text mode; the lines are tokenized below
                    lines = f.readlines()
                tokenized_sents = [
                    [token.lower() for token in nltk.tokenize.word_tokenize(line)]
                    for line in lines]
                sentences = [' '.join(sent) for sent in tokenized_sents]
                processed_summary = '\n'.join(sentences)
                out_name = '%06d_decoded.txt' % art_idx
                with open(os.path.join(out_dir, summary_method, 'decoded', out_name),
                          'w') as f:  # text mode: processed_summary is a str
                    f.write(processed_summary)
                reference_files = glob.glob(
                    os.path.join(ref_dir, '%06d' % art_idx + '*'))
                abstract_sentences = []
                for ref_file in reference_files:
                    with open(ref_file) as f:
                        lines = f.readlines()
                    abstract_sentences.append(lines)
                rouge_functions.write_for_rouge(
                    abstract_sentences, sentences, art_idx,
                    os.path.join(out_dir, summary_method, 'reference'),
                    os.path.join(out_dir, summary_method, 'decoded'))
            results_dict = rouge_functions.rouge_eval(
                ref_dir, os.path.join(out_dir, summary_method, 'decoded'))
            # print("Results_dict: ", results_dict)
            rouge_functions.rouge_log(results_dict,
                                      os.path.join(out_dir, summary_method))

        for summary_method in summary_methods:
            print(summary_method)
        all_results = ''
        for summary_method in summary_methods:
            sheet_results_file = os.path.join(out_dir, summary_method,
                                              "sheets_results.txt")
            with open(sheet_results_file) as f:
                results = f.read()
            all_results += results
        print(all_results)
    else:
        # source_dir = os.path.join(data_dir, FLAGS.dataset)
        log_root = os.path.join('logs', FLAGS.exp_name)
        ckpt_folder = util.find_largest_ckpt_folder(log_root)
        print(ckpt_folder)
        # if os.path.exists(os.path.join(log_dir, FLAGS.exp_name, ckpt_folder)):
        if ckpt_folder != 'decoded':
            summary_dir = os.path.join(log_dir, FLAGS.exp_name, ckpt_folder)
        else:
            summary_dir = os.path.join(log_dir, FLAGS.exp_name)
        ref_dir = os.path.join(summary_dir, 'reference')
        dec_dir = os.path.join(summary_dir, 'decoded')
        summary_files = glob.glob(os.path.join(dec_dir, '*_decoded.txt'))
        # summary_files = glob.glob(os.path.join(log_dir + FLAGS.exp_name, 'test_*.txt.result.summary'))
        # if len(summary_files) > 0 and not os.path.exists(dec_dir):  # reformat files from extract + rewrite
        #     os.makedirs(ref_dir)
        #     os.makedirs(dec_dir)
        #     for summary_file in tqdm(summary_files):
        #         ex_index = extract_digits(os.path.basename(summary_file))
        #         new_file = os.path.join(dec_dir, "%06d_decoded.txt" % ex_index)
        #         shutil.copyfile(summary_file, new_file)
        #     ref_files_to_copy = glob.glob(os.path.join(log_dir, FLAGS.dataset, ckpt_folder, 'reference', '*'))
        #     for file in tqdm(ref_files_to_copy):
        #         basename = os.path.basename(file)
        #         shutil.copyfile(file, os.path.join(ref_dir, basename))

        lengths = []
        for summary_file in tqdm(summary_files):
            with open(summary_file) as f:
                summary = f.read()
            length = len(summary.strip().split())
            lengths.append(length)
        print('Average summary length: %.2f' % np.mean(lengths))

        print('Evaluating on %d files' % len(os.listdir(dec_dir)))
        results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir,
                                                  l_param=FLAGS.l_param)
        rouge_functions.rouge_log(results_dict, summary_dir)
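# util.find_largest_ckpt_folder picks the newest checkpoint directory under a
# log root. A minimal sketch of one plausible implementation, assuming folder
# names ending in a step number (e.g. 'ckpt-1000'); the repo's real helper may
# use different logic:
import os
import re

def find_largest_ckpt_folder_sketch(log_root):
    folders = [d for d in os.listdir(log_root)
               if os.path.isdir(os.path.join(log_root, d))]
    def step_of(name):
        match = re.search(r'(\d+)$', name)
        return int(match.group(1)) if match else -1
    return max(folders, key=step_of)  # folder with the highest trailing step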
def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    print('Running statistics on %s' % exp_name)
    start_time = time.time()
    np.random.seed(random_seed)
    source_dir = os.path.join(data_dir, dataset_articles)
    source_files = sorted(glob.glob(source_dir + '/' + dataset_split + '*'))

    ex_sents = ['single .', 'sentence .']
    article_text = ' '.join(ex_sents)
    total = len(source_files) * 1000
    example_generator = data.example_generator(
        source_dir + '/' + dataset_split + '*', True, False, should_check_valid=False)

    qid_ssi_to_importances = rank_source_sents(temp_in_path, temp_out_path)
    qid_ssi_to_token_scores_and_mappings = get_token_scores_for_ssi(
        temp_in_path, file_path_seq, file_path_mappings)
    ex_gen = example_generator_extended(example_generator, total,
                                        qid_ssi_to_importances,
                                        qid_ssi_to_token_scores_and_mappings)
    print('Creating list')
    ex_list = [ex for ex in ex_gen]
    ssi_list = list(futures.map(evaluate_example, ex_list))

    # save ssi_list
    with open(os.path.join(my_log_dir, 'ssi.pkl'), 'wb') as f:
        pickle.dump(ssi_list, f)
    with open(os.path.join(my_log_dir, 'ssi.pkl'), 'rb') as f:
        ssi_list = pickle.load(f)

    print('Evaluating BERT model F1 score...')
    suffix = util.all_sent_selection_eval(ssi_list)
    # for ex in tqdm(ex_list, total=total):
    #     load_and_evaluate_example(ex)

    print('Evaluating ROUGE...')
    results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir, l_param=l_param)
    # print("Results_dict: ", results_dict)
    rouge_functions.rouge_log(results_dict, my_log_dir, suffix=suffix)

    ssis_restricted = [ssi_triple[1][:ssi_triple[2]] for ssi_triple in ssi_list]
    ssi_lens = [len(source_indices) for source_indices
                in util.flatten_list_of_lists(ssis_restricted)]
    # print(ssi_lens)
    num_singles = ssi_lens.count(1)
    num_pairs = ssi_lens.count(2)
    print('Percent singles/pairs: %.2f %.2f' % (
        num_singles * 100. / len(ssi_lens), num_pairs * 100. / len(ssi_lens)))

    util.print_execution_time(start_time)
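# The singles/pairs statistic above reports how often the system drew from a
# fused pair of source sentences versus a single one, after truncating each
# system ssi to its extractive length. A self-contained sketch on a toy
# ssi_list of (gt_ssi, sys_ssi, ext_len) triples:
def percent_singles_pairs(ssi_list):
    restricted = [sys_ssi[:ext_len] for _, sys_ssi, ext_len in ssi_list]
    lens = [len(indices) for group in restricted for indices in group]
    return (100. * lens.count(1) / len(lens), 100. * lens.count(2) / len(lens))

toy = [([(0,)], [(0,), (1, 2), (4,)], 2),
       ([(3,)], [(3, 5), (0,)], 1)]
print('Percent singles/pairs: %.2f %.2f' % percent_singles_pairs(toy))
# Percent singles/pairs: 33.33 66.67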