def process_one_example(example):
    raw_article_sents, groundtruth_summ_sents, example_idx, pretty_html_path, out_full_dir = example
    article_sent_tokens = [util.process_sent(sent, whitespace=True) for sent in raw_article_sents]
    doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
    summary_sent_tokens = [util.process_sent(sent, whitespace=True) for sent in groundtruth_summ_sents]

    writer = open(os.path.join(out_full_dir, '%06d_ssi.tsv' % example_idx), 'wb')

    if len(article_sent_tokens) == 0:
        print('Skipping because empty')
        writer.write('\n'.encode())
        writer.close()
        return

    # Find the article sentences that were fused to create each summary sentence.
    simple_similar_source_indices, lcs_paths_list, smooth_article_paths_list = ssi_functions.get_simple_source_indices_list(
        summary_sent_tokens, article_sent_tokens, vocab=None, sentence_limit=FLAGS.sentence_limit, min_matched_tokens=FLAGS.min_matched_tokens)

    highlight_line = '\t'.join([','.join([str(src_idx) for src_idx in source_indices]) for source_indices in simple_similar_source_indices]) + '\n'
    writer.write(highlight_line.encode())

    if example_idx < 1:
        f_pretty_html = open(pretty_html_path, 'wb')
        extracted_sents_in_article_html = ssi_functions.html_highlight_sents_in_article(summary_sent_tokens, simple_similar_source_indices,
                                                                          article_sent_tokens, doc_indices,
                                                                          lcs_paths_list, smooth_article_paths_list)
        f_pretty_html.write(extracted_sents_in_article_html.encode())
        f_pretty_html.close()
    writer.close()
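# The highlight line written above is one tab-separated field per summary sentence,
# where each field is a comma-joined list of source-sentence indices (a singleton or
# a pair). A minimal, standalone round trip with made-up indices:
ssi_list = [(3,), (0, 5)]
line = '\t'.join(','.join(str(i) for i in ssi) for ssi in ssi_list) + '\n'
parsed = [tuple(int(i) for i in field.split(',')) for field in line.strip().split('\t')]
assert parsed == [(3,), (0, 5)]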
Example 2
def generate_summary_importance(raw_article_sents, article_sent_tokens,
                                temp_in_path, temp_out_path, single_feat_len,
                                pair_feat_len):
    tfidfs = util.get_tfidf_importances(tfidf_vectorizer, raw_article_sents)
    instances = get_features_all_combinations(raw_article_sents,
                                              article_sent_tokens, tfidfs,
                                              single_feat_len, pair_feat_len)
    source_indices_to_importances = rank_source_sents(instances, temp_in_path,
                                                      temp_out_path,
                                                      single_feat_len)
    summary_sent_tokens = []
    summary_tokens = util.flatten_list_of_lists(summary_sent_tokens)
    already_used_source_indices = []
    similar_source_indices_list = []
    summary_sents_for_html = []
    while len(summary_tokens) < 120:
        mmr_dict = util.calc_MMR_source_indices(article_sent_tokens,
                                                summary_tokens, None,
                                                source_indices_to_importances)
        sents, source_indices = get_best_source_sents(
            article_sent_tokens, mmr_dict, already_used_source_indices)
        if len(source_indices) == 0:
            break
        summary_sent_tokens.extend(sents)
        summary_tokens = util.flatten_list_of_lists(summary_sent_tokens)
        similar_source_indices_list.append(source_indices)
        summary_sents_for_html.append(' <br> '.join(
            [' '.join(sent) for sent in sents]))
        if filter_sentences:
            already_used_source_indices.extend(source_indices)
    summary_sents = [' '.join(sent) for sent in summary_sent_tokens]
    # summary = '\n'.join([' '.join(tokens) for tokens in summary_sent_tokens])
    return summary_sents, similar_source_indices_list, summary_sents_for_html
Example 3
def generate_summary(article_sent_tokens, qid_ssi_to_importances, example_idx):
    qid = example_idx

    summary_sent_tokens = []
    summary_tokens = util.flatten_list_of_lists(summary_sent_tokens)
    already_used_source_indices = []
    similar_source_indices_list = []
    summary_sents_for_html = []
    ssi_length_extractive = None

    # Iteratively select a singleton/pair from the article that has the highest score from BERT
    while len(summary_tokens) < 300:
        if len(summary_tokens) >= l_param and ssi_length_extractive is None:
            ssi_length_extractive = len(similar_source_indices_list)
        mmr_dict = util.calc_MMR_source_indices(article_sent_tokens, summary_tokens, None, qid_ssi_to_importances, qid=qid)
        sents, source_indices = get_best_source_sents(article_sent_tokens, mmr_dict, already_used_source_indices)
        if len(source_indices) == 0:
            break
        summary_sent_tokens.extend(sents)
        summary_tokens = util.flatten_list_of_lists(summary_sent_tokens)
        similar_source_indices_list.append(source_indices)
        summary_sents_for_html.append(' <br> '.join([' '.join(sent) for sent in sents]))
        if filter_sentences:
            already_used_source_indices.extend(source_indices)
    if ssi_length_extractive is None:
        ssi_length_extractive = len(similar_source_indices_list)
    selected_article_sent_indices = util.flatten_list_of_lists(similar_source_indices_list[:ssi_length_extractive])
    summary_sents = [' '.join(sent) for sent in util.reorder(article_sent_tokens, selected_article_sent_indices)]
    return summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive
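# util.calc_MMR_source_indices is not shown in this snippet. Conceptually, MMR
# (maximal marginal relevance) trades an importance score off against redundancy
# with the summary built so far; a generic sketch of that trade-off (not the util
# implementation, lambda value chosen arbitrarily):
def mmr_score_sketch(importance, candidate_tokens, summary_tokens, lam=0.6):
    candidate_set = set(candidate_tokens)
    overlap = len(candidate_set & set(summary_tokens))
    redundancy = overlap / max(len(candidate_set), 1)
    return lam * importance - (1.0 - lam) * redundancy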
Example 4
    def get_enc_importances(self, tokenized_sents, abstract_words):
        lemmatize = True
        if lemmatize:
            article_sent_tokens_lemma = util.lemmatize_sent_tokens(tokenized_sents)
            summary_sent_tokens_lemma = util.lemmatize_sent_tokens([abstract_words])
        else:
            article_sent_tokens_lemma = tokenized_sents
            summary_sent_tokens_lemma = [abstract_words]
        article_tokens = util.flatten_list_of_lists(article_sent_tokens_lemma)
        abstract_tokens = util.flatten_list_of_lists(summary_sent_tokens_lemma)
        enc_importances = [1. if token in abstract_tokens else 0. for token in article_tokens]
        return enc_importances
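# Toy illustration of the binary importances produced by get_enc_importances
# (lemmatization skipped for brevity; util.lemmatize_sent_tokens is assumed to map
# each token to its lemma before the membership test):
article = [['the', 'cat', 'sat'], ['dogs', 'bark']]
abstract_words = ['cat', 'bark']
flat_article = [token for sent in article for token in sent]
importances = [1. if token in abstract_words else 0. for token in flat_article]
assert importances == [0., 1., 0., 0., 1.]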
def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.dataset_split == 'all':
        dataset_splits = ['test', 'val', 'train']
    else:
        dataset_splits = [FLAGS.dataset_split]

    vocab_counter = collections.Counter()

    for dataset_split in dataset_splits:

        source_dir = os.path.join(FLAGS.data_root, FLAGS.dataset_name)
        source_files = sorted(glob.glob(source_dir + '/' + dataset_split +
                                        '*'))

        total = len(source_files) * 1000
        example_generator = data.example_generator(source_dir + '/' +
                                                   dataset_split + '*',
                                                   True,
                                                   False,
                                                   should_check_valid=False)

        for example_idx, example in enumerate(
                tqdm(example_generator, total=total)):

            raw_article_sents, article, abstracts, doc_indices = util.unpack_tf_example(
                example, names_to_types)
            article_sent_tokens = [
                util.process_sent(sent) for sent in raw_article_sents
            ]
            # groundtruth_summ_sent_tokens = [sent.strip().split() for sent in groundtruth_summary_text.strip().split('\n')]
            groundtruth_summ_sent_tokens = [[
                token for token in abstract.strip().split()
                if token not in ['<s>', '</s>']
            ] for abstract in abstracts]
            all_tokens = util.flatten_list_of_lists(
                article_sent_tokens) + util.flatten_list_of_lists(
                    groundtruth_summ_sent_tokens)

            vocab_counter.update(all_tokens)

    print("Writing vocab file...")
    with open(os.path.join('logs', "vocab_" + FLAGS.dataset_name),
              'w') as writer:
        for word, count in vocab_counter.most_common(VOCAB_SIZE):
            writer.write(word + ' ' + str(count) + '\n')
    print("Finished writing vocab file")
def get_rel_sent_indices(doc_indices, article_sent_tokens):
    if FLAGS.dataset_name != 'duc_2004' and len(doc_indices) != len(util.flatten_list_of_lists(article_sent_tokens)):
        doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
    doc_indices_sent_tokens = util.reshape_like(doc_indices, article_sent_tokens)
    sent_doc = [sent[0] for sent in doc_indices_sent_tokens]
    rel_sent_indices = []
    doc_sent_indices = []
    cur_doc_idx = 0
    rel_sent_idx = 0
    for doc_idx in sent_doc:
        if doc_idx != cur_doc_idx:
            rel_sent_idx = 0
            cur_doc_idx = doc_idx
        rel_sent_indices.append(rel_sent_idx)
        doc_sent_indices.append(cur_doc_idx)
        rel_sent_idx += 1
    doc_sent_lens = [sum(1 for my_doc_idx in doc_sent_indices if my_doc_idx == doc_idx) for doc_idx in
                     range(max(doc_sent_indices) + 1)]
    return rel_sent_indices, doc_sent_indices, doc_sent_lens
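# Worked example for get_rel_sent_indices: suppose two documents, where document 0
# has sentences of lengths 3 and 2 and document 1 has one sentence of length 4.
#   doc_indices (per token) : [0, 0, 0, 0, 0, 1, 1, 1, 1]
#   article_sent_tokens     : [[t, t, t], [t, t], [t, t, t, t]]
# Reshaping doc_indices like the sentences gives sent_doc = [0, 0, 1], so
#   rel_sent_indices -> [0, 1, 0]   (position of each sentence within its own document)
#   doc_sent_indices -> [0, 0, 1]   (which document each sentence belongs to)
#   doc_sent_lens    -> [2, 1]      (number of sentences per document)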
Example 7
def convert_singpairmix_to_tf_examples(dataset_name, processed_data_dir, tf_example_dir, dataset_split='all'):
    out_dir = os.path.join(tf_example_dir, dataset_name)
    out_full_dir = os.path.join(out_dir, 'all')
    util.create_dirs(out_full_dir)
    if dataset_split == 'all':
        if dataset_name == 'duc_2004':
            dataset_splits = ['test']
        else:
            dataset_splits = ['test', 'val', 'train']
    else:
        dataset_splits = [dataset_split]
    for dataset_split in dataset_splits:
        processed_data_path = os.path.join(processed_data_dir, dataset_name, dataset_split)
        articles_path = os.path.join(processed_data_path, 'articles.tsv')
        abstracts_path = os.path.join(processed_data_path, 'summaries.tsv')
        highlight_path = os.path.join(processed_data_path, 'highlight.tsv')

        f_art = open(articles_path)
        f_abs = open(abstracts_path)
        f_hl = open(highlight_path)
        writer = open(os.path.join(out_full_dir, dataset_split + '.bin'), 'wb')
        total = util.num_lines_in_file(articles_path)
        for example_idx in tqdm(range(total)):
            raw_article_sents = f_art.readline().strip().split('\t')
            groundtruth_summ_sents = f_abs.readline().strip().split('\t')
            summary_text = '\n'.join(groundtruth_summ_sents)
            article_sent_tokens = [util.process_sent(sent, whitespace=True) for sent in raw_article_sents]
            doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
            doc_indices_str = ' '.join([str(idx) for idx in doc_indices])
            similar_source_indices = [source_indices.split(',') for source_indices in f_hl.readline().strip().split('\t')]

            write_bert_tf_example(similar_source_indices, raw_article_sents, summary_text, None,
                                  doc_indices_str, None, writer, dataset_name)

        writer.close()
        f_art.close()
        f_abs.close()
        f_hl.close()
        if dataset_name == 'cnn_dm' or dataset_name == 'newsroom' or dataset_name == 'xsum':
            chunk_size = 1000
        else:
            chunk_size = 1
        util.chunk_file(dataset_split, out_full_dir, out_dir, chunk_size=chunk_size)
def write_to_lambdamart_examples_to_file(ex):
    example, example_idx, single_feat_len, pair_feat_len, singles_and_pairs = ex
    print(example_idx)
    # example_idx += 1
    temp_in_path = os.path.join(temp_in_dir, '%06d.txt' % example_idx)
    if not FLAGS.start_over and os.path.exists(temp_in_path):
        return
    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
        example, names_to_types)
    article_sent_tokens = [
        util.process_sent(sent) for sent in raw_article_sents
    ]
    if doc_indices is None:
        doc_indices = [0] * len(
            util.flatten_list_of_lists(article_sent_tokens))
    doc_indices = [int(doc_idx) for doc_idx in doc_indices]
    if len(doc_indices) != len(
            util.flatten_list_of_lists(article_sent_tokens)):
        doc_indices = [0] * len(
            util.flatten_list_of_lists(article_sent_tokens))
    rel_sent_indices, _, _ = get_rel_sent_indices(doc_indices,
                                                  article_sent_tokens)
    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
        groundtruth_similar_source_indices_list, sentence_limit)
    groundtruth_summ_sents = [[
        sent.strip() for sent in groundtruth_summary_text.strip().split('\n')
    ]]
    groundtruth_summ_sent_tokens = [
        sent.split(' ') for sent in groundtruth_summ_sents[0]
    ]
    # summ_sent_tokens = [sent.strip().split() for sent in summary_text.strip().split('\n')]

    if FLAGS.dataset_name == 'duc_2004':
        first_k_indices = get_indices_of_first_k_sents_of_each_article(
            rel_sent_indices, FLAGS.first_k)
    else:
        first_k_indices = [idx for idx in range(len(raw_article_sents))]

    if importance:
        get_instances(example_idx, raw_article_sents, article_sent_tokens,
                      corefs, rel_sent_indices, first_k_indices, temp_in_path,
                      temp_out_path, single_feat_len, pair_feat_len,
                      singles_and_pairs)
Example 9
def convert_to_word_level(mmr_for_sentences, enc_tokens):
    num_tokens = len(util.flatten_list_of_lists(enc_tokens))
    mmr = np.ones([num_tokens], dtype=float) / num_tokens
    # Calculate how much for each word in source
    word_idx = 0
    for sent_idx in range(len(enc_tokens)):
        mmr_for_words = np.full([len(enc_tokens[sent_idx])],
                                mmr_for_sentences[sent_idx])
        mmr[word_idx:word_idx + len(mmr_for_words)] = mmr_for_words
        word_idx += len(mmr_for_words)
    return mmr
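# Quick standalone check of convert_to_word_level with toy inputs (the flatten step
# from util is inlined here so the sketch runs on its own):
import numpy as np

def _convert_to_word_level_sketch(mmr_for_sentences, enc_tokens):
    num_tokens = sum(len(sent) for sent in enc_tokens)
    mmr = np.ones([num_tokens], dtype=float) / num_tokens
    word_idx = 0
    for sent_idx, sent in enumerate(enc_tokens):
        mmr[word_idx:word_idx + len(sent)] = mmr_for_sentences[sent_idx]
        word_idx += len(sent)
    return mmr

print(_convert_to_word_level_sketch([0.8, 0.2], [['the', 'cat', 'sat'], ['dogs', 'bark']]))
# -> [0.8 0.8 0.8 0.2 0.2]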
def generate_summary(article_sent_tokens, qid_ssi_to_importances, example_idx):
    qid = example_idx

    summary_sent_tokens = []
    summary_tokens = util.flatten_list_of_lists(summary_sent_tokens)
    already_used_source_indices = []
    similar_source_indices_list = []
    summary_sents_for_html = []
    ssi_length_extractive = None
    while len(summary_tokens) < 1000:
        if len(summary_tokens) >= l_param and ssi_length_extractive is None:
            ssi_length_extractive = len(similar_source_indices_list)
        if FLAGS.dataset_name == 'xsum' and len(summary_tokens) > 0:
            ssi_length_extractive = len(similar_source_indices_list)
            break
        mmr_dict = util.calc_MMR_source_indices(article_sent_tokens,
                                                summary_tokens,
                                                None,
                                                qid_ssi_to_importances,
                                                qid=qid)
        sents, source_indices = get_best_source_sents(
            article_sent_tokens, mmr_dict, already_used_source_indices)
        if len(source_indices) == 0:
            break
        summary_sent_tokens.extend(sents)
        summary_tokens = util.flatten_list_of_lists(summary_sent_tokens)
        similar_source_indices_list.append(source_indices)
        summary_sents_for_html.append(' <br> '.join(
            [' '.join(sent) for sent in sents]))
        if filter_sentences:
            already_used_source_indices.extend(source_indices)
    if ssi_length_extractive is None:
        ssi_length_extractive = len(similar_source_indices_list)
    selected_article_sent_indices = util.flatten_list_of_lists(
        similar_source_indices_list[:ssi_length_extractive])
    summary_sents = [
        ' '.join(sent) for sent in util.reorder(article_sent_tokens,
                                                selected_article_sent_indices)
    ]
    # summary = '\n'.join([' '.join(tokens) for tokens in summary_sent_tokens])
    return summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive
Example 11
def main(unused_argv):

    if len(unused_argv) != 1: # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)
    print('Running statistics on %s' % exp_name)

    start_time = time.time()
    np.random.seed(random_seed)
    source_dir = os.path.join(data_dir, dataset_articles)
    source_files = sorted(glob.glob(source_dir + '/' + dataset_split + '*'))


    total = len(source_files)*1000
    example_generator = data.example_generator(source_dir + '/' + dataset_split + '*', True, False, should_check_valid=False)

    # Read output of BERT and put into a dictionary with:
    # key=(article idx, source indices {this is a tuple of length 1 or 2, depending on if it is a singleton or pair})
    # value=score
    qid_ssi_to_importances = rank_source_sents(temp_in_path, temp_out_path)
    ex_gen = example_generator_extended(example_generator, total, qid_ssi_to_importances, None, FLAGS.singles_and_pairs)
    print('Creating list')
    ex_list = [ex for ex in ex_gen]

    # # Main function to get results on all test examples
    # pool = mp.Pool(mp.cpu_count())
    # ssi_list = list(tqdm(pool.imap(evaluate_example, ex_list), total=total))
    # pool.close()

    # Main function to get results on all test examples
    ssi_list = list(map(evaluate_example, ex_list))

    # save ssi_list
    with open(os.path.join(my_log_dir, 'ssi.pkl'), 'wb') as f:
        pickle.dump(ssi_list, f)
    with open(os.path.join(my_log_dir, 'ssi.pkl'), 'rb') as f:
        ssi_list = pickle.load(f)
    print('Evaluating BERT model F1 score...')
    suffix = util.all_sent_selection_eval(ssi_list)
    print('Evaluating ROUGE...')
    results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir, l_param=l_param)
    rouge_functions.rouge_log(results_dict, my_log_dir, suffix=suffix)

    ssis_restricted = [ssi_triple[1][:ssi_triple[2]] for ssi_triple in ssi_list]
    ssi_lens = [len(source_indices) for source_indices in util.flatten_list_of_lists(ssis_restricted)]
    num_singles = ssi_lens.count(1)
    num_pairs = ssi_lens.count(2)
    print('Percent singles/pairs: %.2f %.2f' % (num_singles*100./len(ssi_lens), num_pairs*100./len(ssi_lens)))

    util.print_execution_time(start_time)
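# The pickled ssi_list holds one triple per example, matching evaluate_example's
# return value below: (groundtruth ssi list, system ssi list, extractive length).
# A minimal way to reload and inspect it, assuming the same my_log_dir:
import os
import pickle
with open(os.path.join(my_log_dir, 'ssi.pkl'), 'rb') as f:
    ssi_list = pickle.load(f)
groundtruth_ssi, system_ssi, ssi_length_extractive = ssi_list[0]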
Example 12
def evaluate_example(ex):
    example, example_idx, qid_ssi_to_importances, _, _ = ex
    print(example_idx)

    # Read example from dataset
    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(example, names_to_types)
    article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
    enforced_groundtruth_ssi_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, sentence_limit)
    groundtruth_summ_sents = [[sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
    groundtruth_summ_sent_tokens = [sent.split(' ') for sent in groundtruth_summ_sents[0]]

    if FLAGS.upper_bound:
        # If upper bound, then get the groundtruth singletons/pairs
        replaced_ssi_list = util.replace_empty_ssis(enforced_groundtruth_ssi_list, raw_article_sents)
        selected_article_sent_indices = util.flatten_list_of_lists(replaced_ssi_list)
        summary_sents = [' '.join(sent) for sent in util.reorder(article_sent_tokens, selected_article_sent_indices)]
        similar_source_indices_list = groundtruth_similar_source_indices_list
        ssi_length_extractive = len(similar_source_indices_list)
    else:
        # Generates summary based on BERT output. This is an extractive summary.
        summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive = generate_summary(article_sent_tokens, qid_ssi_to_importances, example_idx)
        similar_source_indices_list_trunc = similar_source_indices_list[:ssi_length_extractive]
        summary_sents_for_html_trunc = summary_sents_for_html[:ssi_length_extractive]
        if example_idx <= 1:
            summary_sent_tokens = [sent.split(' ') for sent in summary_sents_for_html_trunc]
            extracted_sents_in_article_html = html_highlight_sents_in_article(summary_sent_tokens, similar_source_indices_list_trunc,
                                            article_sent_tokens, doc_indices=doc_indices)

            groundtruth_ssi_list, lcs_paths_list, article_lcs_paths_list = get_simple_source_indices_list(
                                            groundtruth_summ_sent_tokens,
                                           article_sent_tokens, None, sentence_limit, min_matched_tokens)
            groundtruth_highlighted_html = html_highlight_sents_in_article(groundtruth_summ_sent_tokens, groundtruth_ssi_list,
                                            article_sent_tokens, lcs_paths_list=lcs_paths_list, article_lcs_paths_list=article_lcs_paths_list, doc_indices=doc_indices)

            all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html + '<u>Groundtruth Summary</u><br><br>' + groundtruth_highlighted_html
            ssi_functions.write_highlighted_html(all_html, html_dir, example_idx)
    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents, example_idx, ref_dir, dec_dir)
    return (groundtruth_similar_source_indices_list, similar_source_indices_list, ssi_length_extractive)
Example 13
def main(unused_argv):

    print('Running statistics on %s' % FLAGS.dataset_name)

    if len(unused_argv) != 1: # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    source_dir = os.path.join(data_dir, FLAGS.dataset_name)
    source_files = sorted(glob.glob(source_dir + '/' + FLAGS.dataset_split + '*'))

    total = len(source_files) * 1000 if ('cnn' in FLAGS.dataset_name or 'newsroom' in FLAGS.dataset_name or 'xsum' in FLAGS.dataset_name) else len(source_files)
    example_generator = data.example_generator(source_dir + '/' + FLAGS.dataset_split + '*', True, False,
                                               should_check_valid=False)

    for example_idx, example in enumerate(tqdm(example_generator, total=total)):
        raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
            example, names_to_types)
        article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
        groundtruth_summ_sents = [[sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
        if doc_indices is None:
            doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
        doc_indices = [int(doc_idx) for doc_idx in doc_indices]
        rel_sent_indices, _, _ = preprocess_for_lambdamart_no_flags.get_rel_sent_indices(doc_indices, article_sent_tokens)
        groundtruth_similar_source_indices_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, FLAGS.sentence_limit)
Example 14
def main(unused_argv):

    print('Running statistics on %s' % FLAGS.dataset_name)

    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.singles_and_pairs == 'singles':
        FLAGS.sentence_limit = 1
    else:
        FLAGS.sentence_limit = 2

    if FLAGS.dataset_name == 'all':
        dataset_names = ['cnn_dm', 'xsum', 'duc_2004']
    else:
        dataset_names = [FLAGS.dataset_name]

    for dataset_name in dataset_names:
        FLAGS.dataset_name = dataset_name

        source_dir = os.path.join(data_dir, dataset_name)

        if FLAGS.dataset_split == 'all':
            if dataset_name == 'duc_2004':
                dataset_splits = ['test']
            else:
                # dataset_splits = ['val_test', 'test', 'val', 'train']
                dataset_splits = ['test', 'val', 'train']
        else:
            dataset_splits = [FLAGS.dataset_split]

        for dataset_split in dataset_splits:
            if dataset_split == 'val_test':
                source_dataset_split = 'val'
            else:
                source_dataset_split = dataset_split

            source_files = sorted(
                glob.glob(source_dir + '/' + source_dataset_split + '*'))

            total = len(source_files) * 1000
            example_generator = data.example_generator(
                source_dir + '/' + source_dataset_split + '*',
                True,
                False,
                should_check_valid=False)

            out_dir = os.path.join('data', 'bert', dataset_name,
                                   FLAGS.singles_and_pairs, 'input')
            util.create_dirs(out_dir)

            writer = open(os.path.join(out_dir, dataset_split) + '.tsv', 'wb')
            header_list = [
                'should_merge', 'sent1', 'sent2', 'example_idx', 'inst_id',
                'ssi'
            ]
            writer.write(('\t'.join(header_list) + '\n').encode())
            inst_id = 0
            for example_idx, example in enumerate(
                    tqdm(example_generator, total=total)):
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(
                    example, names_to_types)
                article_sent_tokens = [
                    util.process_sent(sent, whitespace=True)
                    for sent in raw_article_sents
                ]
                groundtruth_summ_sents = [[
                    sent.strip()
                    for sent in groundtruth_summary_text.strip().split('\n')
                ]]
                if dataset_name != 'duc_2004' or doc_indices is None:
                    doc_indices = [0] * len(
                        util.flatten_list_of_lists(article_sent_tokens))
                doc_indices = [int(doc_idx) for doc_idx in doc_indices]
                rel_sent_indices, _, _ = ssi_functions.get_rel_sent_indices(
                    doc_indices, article_sent_tokens)
                similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list,
                    FLAGS.sentence_limit)

                possible_pairs = list(itertools.combinations(range(len(raw_article_sents)), 2))  # all pairs
                possible_pairs = filter_pairs_by_sent_position(
                    possible_pairs, rel_sent_indices=rel_sent_indices)
                possible_singles = [(i, )
                                    for i in range(len(raw_article_sents))]
                positives = [ssi for ssi in similar_source_indices_list]

                if dataset_split == 'test' or dataset_split == 'val_test':
                    if FLAGS.singles_and_pairs == 'singles':
                        possible_combinations = possible_singles
                    else:
                        possible_combinations = possible_pairs + possible_singles
                    negatives = [
                        ssi for ssi in possible_combinations
                        if not (ssi in positives or ssi[::-1] in positives)
                    ]

                    for ssi_idx, ssi in enumerate(positives):
                        if len(ssi) == 0:
                            continue
                        if chronological_ssi and len(ssi) >= 2:
                            if ssi[0] > ssi[1]:
                                ssi = (min(ssi), max(ssi))
                        writer.write(
                            get_string_bert_example(raw_article_sents, ssi, 1,
                                                    example_idx,
                                                    inst_id).encode())
                        inst_id += 1
                    for ssi in negatives:
                        writer.write(
                            get_string_bert_example(raw_article_sents, ssi, 0,
                                                    example_idx,
                                                    inst_id).encode())
                        inst_id += 1

                else:
                    positive_sents = list(
                        set(util.flatten_list_of_lists(positives)))
                    negative_pairs = [
                        pair for pair in possible_pairs
                        if not any(i in positive_sents for i in pair)
                    ]
                    negative_singles = [
                        sing for sing in possible_singles
                        if not sing[0] in positive_sents
                    ]
                    random_negative_pairs = np.random.permutation(
                        len(negative_pairs)).tolist()
                    random_negative_singles = np.random.permutation(
                        len(negative_singles)).tolist()

                    for ssi in similar_source_indices_list:
                        if len(ssi) == 0:
                            continue
                        if chronological_ssi and len(ssi) >= 2:
                            if ssi[0] > ssi[1]:
                                ssi = (min(ssi), max(ssi))
                        is_pair = len(ssi) == 2
                        writer.write(
                            get_string_bert_example(raw_article_sents, ssi, 1,
                                                    example_idx,
                                                    inst_id).encode())
                        inst_id += 1

                        # False sentence single/pair
                        if is_pair:
                            if len(random_negative_pairs) == 0:
                                continue
                            negative_indices = negative_pairs[
                                random_negative_pairs.pop()]
                        else:
                            if len(random_negative_singles) == 0:
                                continue
                            negative_indices = negative_singles[
                                random_negative_singles.pop()]
                        article_lcs_paths = None
                        writer.write(
                            get_string_bert_example(raw_article_sents,
                                                    negative_indices, 0,
                                                    example_idx,
                                                    inst_id).encode())
                        inst_id += 1
            writer.close()
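# get_string_bert_example is not shown in this snippet; a hypothetical sketch that is
# consistent with the header written above ('should_merge', 'sent1', 'sent2',
# 'example_idx', 'inst_id', 'ssi'), where sent2 is left empty for singletons:
def get_string_bert_example_sketch(raw_article_sents, ssi, should_merge, example_idx, inst_id):
    sent1 = raw_article_sents[ssi[0]]
    sent2 = raw_article_sents[ssi[1]] if len(ssi) == 2 else ''
    ssi_str = ','.join(str(idx) for idx in ssi)
    return '\t'.join([str(should_merge), sent1, sent2, str(example_idx), str(inst_id), ssi_str]) + '\n'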
Example 15
def main(unused_argv):

    print('Running statistics on %s' % FLAGS.dataset_name)

    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.summarizer == 'all':
        summary_methods = list(summarizers.keys())
    else:
        summary_methods = [FLAGS.summarizer]
    if FLAGS.dataset_name == 'all':
        dataset_names = datasets
    else:
        dataset_names = [FLAGS.dataset_name]

    sheets_strs = []
    for summary_method in summary_methods:
        summary_fn = summarizers[summary_method]
        for dataset_name in dataset_names:
            FLAGS.dataset_name = dataset_name

            original_dataset_name = 'xsum' if 'xsum' in dataset_name else 'cnn_dm' if 'cnn_dm' in dataset_name or 'duc_2004' in dataset_name else ''
            vocab = Vocab('logs/vocab' + '_' + original_dataset_name,
                          50000)  # create a vocabulary

            source_dir = os.path.join(data_dir, dataset_name)
            source_files = sorted(
                glob.glob(source_dir + '/' + FLAGS.dataset_split + '*'))

            total = len(source_files) * 1000 if (
                'cnn' in dataset_name or 'newsroom' in dataset_name
                or 'xsum' in dataset_name) else len(source_files)
            example_generator = data.example_generator(
                source_dir + '/' + FLAGS.dataset_split + '*',
                True,
                False,
                should_check_valid=False)

            if dataset_name == 'duc_2004':
                abs_source_dir = os.path.join(
                    os.path.expanduser('~') + '/data/tf_data/with_coref',
                    dataset_name)
                abs_example_generator = data.example_generator(
                    abs_source_dir + '/' + FLAGS.dataset_split + '*',
                    True,
                    False,
                    should_check_valid=False)
                abs_names_to_types = [('abstract', 'string_list')]

            triplet_ssi_list = []
            for example_idx, example in enumerate(
                    tqdm(example_generator, total=total)):
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
                    example, names_to_types)
                if dataset_name == 'duc_2004':
                    abs_example = next(abs_example_generator)
                    groundtruth_summary_texts = util.unpack_tf_example(
                        abs_example, abs_names_to_types)
                    groundtruth_summary_texts = groundtruth_summary_texts[0]
                    groundtruth_summ_sents_list = [[
                        sent.strip() for sent in data.abstract2sents(abstract)
                    ] for abstract in groundtruth_summary_texts]

                else:
                    groundtruth_summary_texts = [groundtruth_summary_text]
                    groundtruth_summ_sents_list = []
                    for groundtruth_summary_text in groundtruth_summary_texts:
                        groundtruth_summ_sents = [
                            sent.strip() for sent in
                            groundtruth_summary_text.strip().split('\n')
                        ]
                        groundtruth_summ_sents_list.append(
                            groundtruth_summ_sents)
                article_sent_tokens = [
                    util.process_sent(sent) for sent in raw_article_sents
                ]
                if doc_indices is None:
                    doc_indices = [0] * len(
                        util.flatten_list_of_lists(article_sent_tokens))
                doc_indices = [int(doc_idx) for doc_idx in doc_indices]
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list,
                    FLAGS.sentence_limit)

                log_dir = os.path.join(log_root,
                                       dataset_name + '_' + summary_method)
                dec_dir = os.path.join(log_dir, 'decoded')
                ref_dir = os.path.join(log_dir, 'reference')
                util.create_dirs(dec_dir)
                util.create_dirs(ref_dir)

                parser = PlaintextParser.from_string(
                    ' '.join(raw_article_sents), Tokenizer("english"))
                summarizer = summary_fn()

                summary = summarizer(parser.document, 5)  # summarize the document with 5 sentences
                summary = [str(sentence) for sentence in summary]

                summary_tokenized = []
                for sent in summary:
                    summary_tokenized.append(sent.lower())

                rouge_functions.write_for_rouge(groundtruth_summ_sents_list,
                                                summary_tokenized,
                                                example_idx,
                                                ref_dir,
                                                dec_dir,
                                                log=False)

                decoded_sent_tokens = [
                    sent.split() for sent in summary_tokenized
                ]
                sentence_limit = 2
                sys_ssi_list, _, _ = get_simple_source_indices_list(
                    decoded_sent_tokens, article_sent_tokens, vocab,
                    sentence_limit, min_matched_tokens)
                triplet_ssi_list.append(
                    (groundtruth_similar_source_indices_list, sys_ssi_list,
                     -1))

            print('Evaluating Lambdamart model F1 score...')
            suffix = util.all_sent_selection_eval(triplet_ssi_list)
            print(suffix)

            results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir)
            print(("Results_dict: ", results_dict))
            sheets_str = rouge_functions.rouge_log(results_dict,
                                                   log_dir,
                                                   suffix=suffix)
            sheets_strs.append(dataset_name + '_' + summary_method + '\n' +
                               sheets_str)

    for sheets_str in sheets_strs:
        print(sheets_str + '\n')
def main(unused_argv):
    print('Running statistics on %s' % exp_name)

    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.singles_and_pairs == 'both':
        in_dataset = FLAGS.dataset_name
        out_dataset = FLAGS.dataset_name + '_both'
    else:
        in_dataset = FLAGS.dataset_name + '_singles'
        out_dataset = FLAGS.dataset_name + '_singles'

    if FLAGS.lr:
        out_dataset = FLAGS.dataset_name + '_lr'

    start_time = time.time()
    np.random.seed(random_seed)
    source_dir = os.path.join(data_dir, in_dataset)
    ex_sents = ['single .', 'sentence .']
    article_text = ' '.join(ex_sents)
    sent_term_matrix = util.get_doc_substituted_tfidf_matrix(
        tfidf_vectorizer, ex_sents, article_text, pca)
    if FLAGS.singles_and_pairs == 'pairs':
        single_feat_len = 0
    else:
        single_feat_len = len(
            get_single_sent_features(0, sent_term_matrix,
                                     [['single', '.'], ['sentence', '.']],
                                     [0, 0], 0))
    if FLAGS.singles_and_pairs == 'singles':
        pair_feat_len = 0
    else:
        pair_feat_len = len(
            get_pair_sent_features([0, 1], sent_term_matrix,
                                   [['single', '.'], ['sentence', '.']],
                                   [0, 0], [0, 0]))
    util.print_vars(single_feat_len, pair_feat_len)
    util.create_dirs(temp_dir)

    if FLAGS.dataset_split == 'all':
        dataset_splits = ['test', 'val', 'train']
    elif FLAGS.dataset_split == 'train_val':
        dataset_splits = ['val', 'train']
    else:
        dataset_splits = [FLAGS.dataset_split]
    for split in dataset_splits:
        source_files = sorted(glob.glob(source_dir + '/' + split + '*'))

        out_path = os.path.join(out_dir, out_dataset, split)
        if FLAGS.pca:
            out_path += '_pca'
        util.create_dirs(os.path.join(out_path))
        total = len(source_files) * 1000 if (
            'cnn' in in_dataset or 'newsroom' in in_dataset
            or 'xsum' in in_dataset) else len(source_files)
        example_generator = data.example_generator(source_dir + '/' + split +
                                                   '*',
                                                   True,
                                                   False,
                                                   should_check_valid=False)
        # for example in tqdm(example_generator, total=total):
        ex_gen = example_generator_extended(example_generator, total,
                                            single_feat_len, pair_feat_len,
                                            FLAGS.singles_and_pairs, out_path)
        print('Creating list')
        ex_list = [ex for ex in ex_gen]
        if FLAGS.num_instances != -1:
            ex_list = ex_list[:FLAGS.num_instances]
        print('Converting...')
        # all_features = pool.map(convert_article_to_lambdamart_features, ex_list)

        # all_features = ray.get([convert_article_to_lambdamart_features.remote(ex) for ex in ex_list])

        if FLAGS.lr:
            all_instances = list(
                futures.map(convert_article_to_lambdamart_features, ex_list))
            all_instances = util.flatten_list_of_lists(all_instances)
            x = [inst.features for inst in all_instances]
            x = np.array(x)
            y = [inst.relevance for inst in all_instances]
            y = np.expand_dims(np.array(y), 1)
            x_y = np.concatenate((x, y), 1)
            np.save(writer, x_y)
        else:
            list(futures.map(convert_article_to_lambdamart_features, ex_list))
            # writer.write(''.join(all_features))

        # all_features = []
        # for example  in tqdm(ex_gen, total=total):
        #     all_features.append(convert_article_to_lambdamart_features(example))

        # all_features = util.flatten_list_of_lists(all_features)
        # num1 = sum(x == 1 for x in all_features)
        # num2 = sum(x == 2 for x in all_features)
        # print 'Single sent: %d instances. Pair sent: %d instances.' % (num1, num2)

        # for example in tqdm(ex_gen, total=total):
        #     features = convert_article_to_lambdamart_features(example)
        #     writer.write(features)

        final_out_path = out_path + '.txt'
        file_names = sorted(glob.glob(os.path.join(out_path, '*')))
        writer = open(final_out_path, 'wb')
        for file_name in tqdm(file_names):
            with open(file_name, 'rb') as f:
                text = f.read()
            writer.write(text)
        writer.close()
    util.print_execution_time(start_time)
def convert_article_to_lambdamart_features(ex):
    # example_idx += 1
    # if num_instances != -1 and example_idx >= num_instances:
    #     break
    example, example_idx, single_feat_len, pair_feat_len, singles_and_pairs, out_path = ex
    print(example_idx)
    raw_article_sents, similar_source_indices_list, summary_text, corefs, doc_indices = util.unpack_tf_example(
        example, names_to_types)
    article_sent_tokens = [
        util.process_sent(sent) for sent in raw_article_sents
    ]
    if doc_indices is None:
        doc_indices = [0] * len(
            util.flatten_list_of_lists(article_sent_tokens))
    doc_indices = [int(doc_idx) for doc_idx in doc_indices]
    if len(doc_indices) != len(
            util.flatten_list_of_lists(article_sent_tokens)):
        doc_indices = [0] * len(
            util.flatten_list_of_lists(article_sent_tokens))
    rel_sent_indices, _, _ = util.get_rel_sent_indices(doc_indices,
                                                       article_sent_tokens)
    if FLAGS.singles_and_pairs == 'singles':
        sentence_limit = 1
    else:
        sentence_limit = 2
    similar_source_indices_list = util.enforce_sentence_limit(
        similar_source_indices_list, sentence_limit)
    summ_sent_tokens = [
        sent.strip().split() for sent in summary_text.strip().split('\n')
    ]

    # sent_term_matrix = util.get_tfidf_matrix(raw_article_sents)
    article_text = ' '.join(raw_article_sents)
    sent_term_matrix = util.get_doc_substituted_tfidf_matrix(
        tfidf_vectorizer, raw_article_sents, article_text, pca)
    doc_vector = np.mean(sent_term_matrix, axis=0)

    out_str = ''
    # ssi_idx_cur_inst_id = defaultdict(int)
    instances = []

    if importance:
        importances = util.special_squash(
            util.get_tfidf_importances(tfidf_vectorizer, raw_article_sents,
                                       pca))
        possible_pairs = list(itertools.combinations(range(len(raw_article_sents)), 2))  # all pairs
        if FLAGS.use_pair_criteria:
            possible_pairs = filter_pairs_by_criteria(raw_article_sents,
                                                      possible_pairs, corefs)
        if FLAGS.sent_position_criteria:
            possible_pairs = filter_pairs_by_sent_position(
                possible_pairs, rel_sent_indices)
        possible_singles = [(i, ) for i in range(len(raw_article_sents))]
        possible_combinations = possible_pairs + possible_singles
        positives = [ssi for ssi in similar_source_indices_list]
        negatives = [
            ssi for ssi in possible_combinations
            if not (ssi in positives or ssi[::-1] in positives)
        ]

        negative_pairs = [
            x for x in possible_pairs
            if not (x in similar_source_indices_list
                    or x[::-1] in similar_source_indices_list)
        ]
        negative_singles = [
            x for x in possible_singles
            if not (x in similar_source_indices_list
                    or x[::-1] in similar_source_indices_list)
        ]
        random_negative_pairs = np.random.permutation(
            len(negative_pairs)).tolist()
        random_negative_singles = np.random.permutation(
            len(negative_singles)).tolist()

        qid = example_idx
        for similar_source_indices in positives:
            # True sentence single/pair
            relevance = 1
            features = get_features(similar_source_indices, sent_term_matrix,
                                    article_sent_tokens, rel_sent_indices,
                                    single_feat_len, pair_feat_len,
                                    importances, singles_and_pairs)
            if features is None:
                continue
            instances.append(
                Lambdamart_Instance(features, relevance, qid,
                                    similar_source_indices))

            if FLAGS.dataset_name == 'xsum' and FLAGS.special_xsum_balance:
                neg_relevance = 0
                num_negative = 4
                if FLAGS.singles_and_pairs == 'singles':
                    num_neg_singles = num_negative
                    num_neg_pairs = 0
                else:
                    num_neg_singles = num_negative // 2
                    num_neg_pairs = num_negative // 2
                for _ in range(num_neg_singles):
                    if len(random_negative_singles) == 0:
                        continue
                    negative_indices = negative_singles[
                        random_negative_singles.pop()]
                    neg_features = get_features(negative_indices,
                                                sent_term_matrix,
                                                article_sent_tokens,
                                                rel_sent_indices,
                                                single_feat_len, pair_feat_len,
                                                importances, singles_and_pairs)
                    if neg_features is None:
                        continue
                    instances.append(
                        Lambdamart_Instance(neg_features, neg_relevance, qid,
                                            negative_indices))
                for _ in range(num_neg_pairs):
                    if len(random_negative_pairs) == 0:
                        continue
                    negative_indices = negative_pairs[
                        random_negative_pairs.pop()]
                    neg_features = get_features(negative_indices,
                                                sent_term_matrix,
                                                article_sent_tokens,
                                                rel_sent_indices,
                                                single_feat_len, pair_feat_len,
                                                importances, singles_and_pairs)
                    if neg_features is None:
                        continue
                    instances.append(
                        Lambdamart_Instance(neg_features, neg_relevance, qid,
                                            negative_indices))
            elif balance:
                # False sentence single/pair
                is_pair = len(similar_source_indices) == 2
                if is_pair:
                    if len(random_negative_pairs) == 0:
                        continue
                    negative_indices = negative_pairs[
                        random_negative_pairs.pop()]
                else:
                    if len(random_negative_singles) == 0:
                        continue
                    negative_indices = negative_singles[
                        random_negative_singles.pop()]
                neg_relevance = 0
                neg_features = get_features(negative_indices, sent_term_matrix,
                                            article_sent_tokens,
                                            rel_sent_indices, single_feat_len,
                                            pair_feat_len, importances,
                                            singles_and_pairs)
                if neg_features is None:
                    continue
                instances.append(
                    Lambdamart_Instance(neg_features, neg_relevance, qid,
                                        negative_indices))
        if not balance:
            for negative_indices in negatives:
                neg_relevance = 0
                neg_features = get_features(negative_indices, sent_term_matrix,
                                            article_sent_tokens,
                                            rel_sent_indices, single_feat_len,
                                            pair_feat_len, importances,
                                            singles_and_pairs)
                if neg_features is None:
                    continue
                instances.append(
                    Lambdamart_Instance(neg_features, neg_relevance, qid,
                                        negative_indices))

    sorted_instances = sorted(instances,
                              key=lambda x: (x.qid, x.source_indices))
    assign_inst_ids(sorted_instances)
    if FLAGS.lr:
        return sorted_instances
    else:
        for instance in sorted_instances:
            lambdamart_str = format_to_lambdamart(instance, single_feat_len)
            out_str += lambdamart_str + '\n'
        with open(os.path.join(out_path, '%06d.txt' % example_idx), 'wb') as f:
            f.write(out_str.encode())
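# format_to_lambdamart is not shown in this snippet; a minimal sketch of the standard
# RankLib/LambdaMART line format it presumably produces ("<relevance> qid:<qid>
# <feat_idx>:<value> ..."), using the Lambdamart_Instance fields referenced above:
def format_to_lambdamart_sketch(instance, single_feat_len=None):
    feature_str = ' '.join('%d:%f' % (idx + 1, feat) for idx, feat in enumerate(instance.features))
    return '%d qid:%d %s' % (instance.relevance, instance.qid, feature_str)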
Example 18
    def __init__(self, article, abstract_sentences, all_abstract_sentences,
                 doc_indices, raw_article_sents, vocab, hps):
        """Initializes the Example, performing tokenization and truncation to produce the encoder, decoder and target sequences, which are stored in self.

        Args:
            article: source text; a string. each token is separated by a single space.
            abstract_sentences: list of strings, one per abstract sentence. In each sentence, each token is separated by a single space.
            vocab: Vocabulary object
            hps: hyperparameters
        """
        self.hps = hps

        # Get ids of special tokens
        start_decoding = vocab.word2id(data.START_DECODING)
        stop_decoding = vocab.word2id(data.STOP_DECODING)

        # Process the article
        article_words = article.split()
        if len(article_words) > hps.max_enc_steps:
            article_words = article_words[:hps.max_enc_steps]
        self.enc_input = [
            vocab.word2id(w) for w in article_words
        ]  # list of word ids; OOVs are represented by the id for UNK token

        # Process the abstract
        abstract = ' '.join(abstract_sentences)  # string
        abstract_words = abstract.split()  # list of strings
        abs_ids = [
            vocab.word2id(w) for w in abstract_words
        ]  # list of word ids; OOVs are represented by the id for UNK token

        # Get the decoder input sequence and target sequence
        self.dec_input, self.target = self.get_dec_inp_targ_seqs(
            abs_ids, hps.max_dec_steps, start_decoding, stop_decoding)
        self.dec_len = len(self.dec_input)

        # If using pointer-generator mode, we need to store some extra info
        if hps.pointer_gen:

            if raw_article_sents is not None and len(raw_article_sents) > 0:
                self.tokenized_sents = [
                    process_sent(sent) for sent in raw_article_sents
                ]
                self.word_ids_sents, self.article_oovs = data.tokenizedarticle2ids(
                    self.tokenized_sents, vocab)
                self.enc_input_extend_vocab = util.flatten_list_of_lists(
                    self.word_ids_sents)
                self.enc_len = len(
                    self.enc_input_extend_vocab
                )  # store the length after truncation but before padding
            else:
                # Store a version of the enc_input where in-article OOVs are represented by their temporary OOV id; also store the in-article OOVs words themselves
                article_str = util.to_unicode(article)
                raw_article_sents = nltk.tokenize.sent_tokenize(article_str)
                self.tokenized_sents = [
                    process_sent(sent) for sent in raw_article_sents
                ]
                self.word_ids_sents, self.article_oovs = data.tokenizedarticle2ids(
                    self.tokenized_sents, vocab)
                self.enc_input_extend_vocab = util.flatten_list_of_lists(
                    self.word_ids_sents)
                # self.enc_input_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab)
                self.enc_len = len(
                    self.enc_input_extend_vocab
                )  # store the length after truncation but before padding

            # Get a version of the reference summary where in-article OOVs are represented by their temporary article OOV id
            abs_ids_extend_vocab = data.abstract2ids(abstract_words, vocab,
                                                     self.article_oovs)

            # Overwrite decoder target sequence so it uses the temp article OOV ids
            _, self.target = self.get_dec_inp_targ_seqs(
                abs_ids_extend_vocab, hps.max_dec_steps, start_decoding,
                stop_decoding)

        # Store the original strings
        self.original_article = article
        self.raw_article_sents = raw_article_sents
        self.original_abstract = abstract
        self.original_abstract_sents = abstract_sentences
        self.all_original_abstract_sents = all_abstract_sentences

        self.doc_indices = doc_indices  # doc_id in multidoc correspond to each word
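    # get_dec_inp_targ_seqs is not shown in this snippet; in pointer-generator style
    # code it typically builds the decoder input (prefixed with START) and the target
    # (suffixed with STOP), both truncated to max_len. A sketch under that assumption:
    def get_dec_inp_targ_seqs(self, sequence, max_len, start_id, stop_id):
        inp = [start_id] + sequence[:]
        target = sequence[:]
        if len(inp) > max_len:  # truncate; no STOP token when the target is cut short
            inp = inp[:max_len]
            target = target[:max_len]
        else:
            target.append(stop_id)
        assert len(inp) == len(target)
        return inp, target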
def evaluate_example(ex):
    example, example_idx, qid_ssi_to_importances, _, _ = ex
    print(example_idx)
    # example_idx += 1
    qid = example_idx
    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
        example, names_to_types)
    article_sent_tokens = [
        util.process_sent(sent) for sent in raw_article_sents
    ]
    enforced_groundtruth_ssi_list = util.enforce_sentence_limit(
        groundtruth_similar_source_indices_list, sentence_limit)
    if FLAGS.dataset_name == 'duc_2004':
        groundtruth_summ_sents = [[
            sent.strip() for sent in gt_summ_text.strip().split('\n')
        ] for gt_summ_text in groundtruth_summary_text]
    else:
        groundtruth_summ_sents = [[
            sent.strip()
            for sent in groundtruth_summary_text.strip().split('\n')
        ]]
    groundtruth_summ_sent_tokens = [
        sent.split(' ') for sent in groundtruth_summ_sents[0]
    ]

    if FLAGS.upper_bound:
        replaced_ssi_list = util.replace_empty_ssis(
            enforced_groundtruth_ssi_list, raw_article_sents)
        selected_article_sent_indices = util.flatten_list_of_lists(
            replaced_ssi_list)
        summary_sents = [
            ' '.join(sent) for sent in util.reorder(
                article_sent_tokens, selected_article_sent_indices)
        ]
        similar_source_indices_list = groundtruth_similar_source_indices_list
        ssi_length_extractive = len(similar_source_indices_list)
    elif FLAGS.lead:
        lead_ssi_list = [(idx,) for idx in range(util.average_sents_for_dataset[FLAGS.dataset_name])]
        # make sure the sentence indices don't go past the total number of sentences in the article
        lead_ssi_list = lead_ssi_list[:len(raw_article_sents)]
        selected_article_sent_indices = util.flatten_list_of_lists(
            lead_ssi_list)
        summary_sents = [
            ' '.join(sent) for sent in util.reorder(
                article_sent_tokens, selected_article_sent_indices)
        ]
        similar_source_indices_list = lead_ssi_list
        ssi_length_extractive = len(similar_source_indices_list)
    else:
        summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive = generate_summary(
            article_sent_tokens, qid_ssi_to_importances, example_idx)
        similar_source_indices_list_trunc = similar_source_indices_list[:ssi_length_extractive]
        summary_sents_for_html_trunc = summary_sents_for_html[:ssi_length_extractive]
        if example_idx <= 100:
            summary_sent_tokens = [
                sent.split(' ') for sent in summary_sents_for_html_trunc
            ]
            extracted_sents_in_article_html = html_highlight_sents_in_article(
                summary_sent_tokens,
                similar_source_indices_list_trunc,
                article_sent_tokens,
                doc_indices=doc_indices)
            # write_highlighted_html(extracted_sents_in_article_html, html_dir, example_idx)

            groundtruth_ssi_list, lcs_paths_list, article_lcs_paths_list = get_simple_source_indices_list(
                groundtruth_summ_sent_tokens, article_sent_tokens, None,
                sentence_limit, min_matched_tokens)
            groundtruth_highlighted_html = html_highlight_sents_in_article(
                groundtruth_summ_sent_tokens,
                groundtruth_ssi_list,
                article_sent_tokens,
                lcs_paths_list=lcs_paths_list,
                article_lcs_paths_list=article_lcs_paths_list,
                doc_indices=doc_indices)
            all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html + '<u>Groundtruth Summary</u><br><br>' + groundtruth_highlighted_html
            write_highlighted_html(all_html, html_dir, example_idx)
    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents,
                                    example_idx, ref_dir, dec_dir)
    return (groundtruth_similar_source_indices_list,
            similar_source_indices_list, ssi_length_extractive)
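The FLAGS.upper_bound branch above builds an oracle summary by emitting the article sentences named by the ground-truth source indices. A minimal sketch of that idea, assuming a helper that simply picks sentences by index (hypothetical name; util.reorder fills this role in the code above):

def sketch_oracle_summary(article_sent_tokens, ssi_list):
    # Flatten the per-summary-sentence source indices, then emit those article sentences.
    selected = [idx for indices in ssi_list for idx in indices]
    return [' '.join(article_sent_tokens[idx]) for idx in selected]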
def evaluate_example(ex):
    example, example_idx, qid_ssi_to_importances, qid_ssi_to_token_scores_and_mappings = ex
    print(example_idx)
    # example_idx += 1
    qid = example_idx
    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
        example, names_to_types)
    article_sent_tokens = [
        util.process_sent(sent) for sent in raw_article_sents
    ]
    enforced_groundtruth_ssi_list = util.enforce_sentence_limit(
        groundtruth_similar_source_indices_list, sentence_limit)
    groundtruth_summ_sents = [[
        sent.strip() for sent in groundtruth_summary_text.strip().split('\n')
    ]]
    groundtruth_summ_sent_tokens = [
        sent.split(' ') for sent in groundtruth_summ_sents[0]
    ]

    if FLAGS.upper_bound:
        replaced_ssi_list = util.replace_empty_ssis(
            enforced_groundtruth_ssi_list, raw_article_sents)
        selected_article_sent_indices = util.flatten_list_of_lists(
            replaced_ssi_list)
        summary_sents = [
            ' '.join(sent) for sent in util.reorder(
                article_sent_tokens, selected_article_sent_indices)
        ]
        similar_source_indices_list = groundtruth_similar_source_indices_list
        ssi_length_extractive = len(similar_source_indices_list)
    else:
        summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive, \
            article_lcs_paths_list, token_probs_list = generate_summary(article_sent_tokens, qid_ssi_to_importances, example_idx, qid_ssi_to_token_scores_and_mappings)
        similar_source_indices_list_trunc = similar_source_indices_list[:ssi_length_extractive]
        summary_sents_for_html_trunc = summary_sents_for_html[:ssi_length_extractive]
        if example_idx < 100 or (example_idx >= 2000 and example_idx < 2100):
            summary_sent_tokens = [
                sent.split(' ') for sent in summary_sents_for_html_trunc
            ]
            if FLAGS.tag_tokens and FLAGS.tag_loss_wt != 0:
                lcs_paths_list_param = copy.deepcopy(article_lcs_paths_list)
            else:
                lcs_paths_list_param = None
            extracted_sents_in_article_html = html_highlight_sents_in_article(
                summary_sent_tokens,
                similar_source_indices_list_trunc,
                article_sent_tokens,
                doc_indices=doc_indices,
                lcs_paths_list=lcs_paths_list_param)
            # write_highlighted_html(extracted_sents_in_article_html, html_dir, example_idx)

            groundtruth_ssi_list, gt_lcs_paths_list, gt_article_lcs_paths_list, gt_smooth_article_paths_list = get_simple_source_indices_list(
                groundtruth_summ_sent_tokens, article_sent_tokens, None,
                sentence_limit, min_matched_tokens)
            groundtruth_highlighted_html = html_highlight_sents_in_article(
                groundtruth_summ_sent_tokens,
                groundtruth_ssi_list,
                article_sent_tokens,
                lcs_paths_list=gt_lcs_paths_list,
                article_lcs_paths_list=gt_smooth_article_paths_list,
                doc_indices=doc_indices)

            all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html + '<u>Groundtruth Summary</u><br><br>' + groundtruth_highlighted_html
            # all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html
            write_highlighted_html(all_html, html_dir, example_idx)
    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents,
                                    example_idx, ref_dir, dec_dir)
    return (groundtruth_similar_source_indices_list,
            similar_source_indices_list, ssi_length_extractive,
            token_probs_list)
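Both variants of evaluate_example first cap how many source sentences a ground-truth summary sentence may map to. A minimal sketch of that capping step, assuming it simply truncates each index tuple (a hedged guess at what util.enforce_sentence_limit does):

def sketch_enforce_sentence_limit(ssi_list, limit):
    # Keep at most `limit` source-sentence indices per summary sentence.
    return [indices[:limit] for indices in ssi_list]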
Esempio n. 21
    def __init__(self, article, abstract_sentences, all_abstract_sentences,
                 doc_indices, raw_article_sents, ssi, article_lcs_paths_list,
                 vocab, hps):
        """Initializes the Example, performing tokenization and truncation to produce the encoder, decoder and target sequences, which are stored in self.

        Args:
            article: source text; a string. each token is separated by a single space.
            abstract_sentences: list of strings, one per abstract sentence. In each sentence, each token is separated by a single space.
            all_abstract_sentences: list of reference summaries (each a list of sentence strings); used to detect the single-reference case.
            doc_indices: document id for each article token (multi-document setting), or None.
            raw_article_sents: list of article sentences as strings, or None (in which case the article is sentence-tokenized here).
            ssi: similar source indices; for each summary sentence, the indices of the article sentences it was fused from, or None.
            article_lcs_paths_list: per summary sentence, the article token indices to mark as important, or None.
            vocab: Vocabulary object
            hps: hyperparameters
        """
        self.hps = hps

        # Get ids of special tokens
        start_decoding = vocab.word2id(data.START_DECODING)
        stop_decoding = vocab.word2id(data.STOP_DECODING)

        # # Process the article
        # article_words = article.split()
        # if len(article_words) > hps.max_enc_steps:
        #     article_words = article_words[:hps.max_enc_steps]
        # self.enc_input = [vocab.word2id(w) for w in article_words] # list of word ids; OOVs are represented by the id for UNK token

        # Process the abstract
        abstract = ' '.join(abstract_sentences)  # string
        abstract_words = abstract.split()  # list of strings
        abs_ids = [
            vocab.word2id(w) for w in abstract_words
        ]  # list of word ids; OOVs are represented by the id for UNK token

        # Get the decoder input sequence and target sequence
        self.dec_input, self.target = self.get_dec_inp_targ_seqs(
            abs_ids, hps.max_dec_steps, start_decoding, stop_decoding)
        self.dec_len = len(self.dec_input)

        # If using pointer-generator mode, we need to store some extra info
        if hps.pointer_gen:

            if raw_article_sents is not None and len(raw_article_sents) > 0:
                # self.tokenized_sents = [util.process_sent(sent) for sent in raw_article_sents]
                self.tokenized_sents = [
                    util.process_sent(sent, whitespace=True)
                    for sent in raw_article_sents
                ]
                if self.hps.sep:
                    for sent in self.tokenized_sents[:-1]:
                        sent.append(data.SEP_TOKEN)

                # Process the article
                article_words = util.flatten_list_of_lists(
                    self.tokenized_sents)
                if len(article_words) > hps.max_enc_steps:
                    article_words = article_words[:hps.max_enc_steps]
                self.enc_input = [
                    vocab.word2id(w) for w in article_words
                ]  # list of word ids; OOVs are represented by the id for UNK token

                if len(all_abstract_sentences) == 1:
                    doc_indices = [0] * len(article_words)

                self.word_ids_sents, self.article_oovs = data.tokenizedarticle2ids(
                    self.tokenized_sents, vocab)
                self.enc_input_extend_vocab = util.flatten_list_of_lists(
                    self.word_ids_sents)
                if len(self.enc_input_extend_vocab) > hps.max_enc_steps:
                    self.enc_input_extend_vocab = self.enc_input_extend_vocab[:hps.max_enc_steps]
                self.enc_len = len(
                    self.enc_input_extend_vocab
                )  # store the length after truncation but before padding
            else:
                # Store a version of the enc_input where in-article OOVs are represented by their temporary OOV id; also store the in-article OOV words themselves
                article_str = util.to_unicode(article)
                raw_article_sents = nltk.tokenize.sent_tokenize(article_str)
                self.tokenized_sents = [
                    util.process_sent(sent) for sent in raw_article_sents
                ]

                # Process the article
                article_words = util.flatten_list_of_lists(
                    self.tokenized_sents)
                if len(article_words) > hps.max_enc_steps:
                    article_words = article_words[:hps.max_enc_steps]
                self.enc_input = [
                    vocab.word2id(w) for w in article_words
                ]  # list of word ids; OOVs are represented by the id for UNK token

                if len(all_abstract_sentences) == 1:
                    doc_indices = [0] * len(article_words)

                self.word_ids_sents, self.article_oovs = data.tokenizedarticle2ids(
                    self.tokenized_sents, vocab)
                self.enc_input_extend_vocab = util.flatten_list_of_lists(
                    self.word_ids_sents)
                # self.enc_input_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab)
                if len(self.enc_input_extend_vocab) > hps.max_enc_steps:
                    self.enc_input_extend_vocab = self.enc_input_extend_vocab[:hps.max_enc_steps]
                self.enc_len = len(
                    self.enc_input_extend_vocab
                )  # store the length after truncation but before padding

            if self.hps.word_imp_reg:
                self.enc_importances = self.get_enc_importances(
                    self.tokenized_sents, abstract_words)

            # Get a version of the reference summary where in-article OOVs are represented by their temporary article OOV id
            abs_ids_extend_vocab = data.abstract2ids(abstract_words, vocab,
                                                     self.article_oovs)

            # Overwrite decoder target sequence so it uses the temp article OOV ids
            _, self.target = self.get_dec_inp_targ_seqs(
                abs_ids_extend_vocab, hps.max_dec_steps, start_decoding,
                stop_decoding)

        if ssi is not None:
            # Translate the similar source indices into masks over the encoder input
            self.ssi_masks = []
            for source_indices in ssi:
                ssi_sent_mask = [0.] * len(raw_article_sents)
                for source_idx in source_indices:
                    if source_idx >= len(ssi_sent_mask):
                        continue  # skip indices that fall outside the article's sentence list
                    ssi_sent_mask[source_idx] = 1.
                ssi_mask = pg_mmr_functions.convert_to_word_level(
                    ssi_sent_mask, self.tokenized_sents)
                self.ssi_masks.append(ssi_mask)

            summary_sent_tokens = [
                sent.strip().split() for sent in abstract_sentences
            ]
            if self.hps.ssi_data_path is None and len(
                    self.ssi_masks) != len(summary_sent_tokens):
                raise Exception(
                    'len(self.ssi_masks) != len(summary_sent_tokens)')

            self.sent_indices = pg_mmr_functions.convert_to_word_level(
                list(range(len(summary_sent_tokens))),
                summary_sent_tokens).tolist()

        if article_lcs_paths_list is not None:
            if len(article_lcs_paths_list) > 1:
                raise Exception('Need to implement for non-sent_dataset')
            article_lcs_paths = article_lcs_paths_list[0]
            imp_mask = [0] * len(article_words)
            to_add = 0
            for source_idx, word_indices_list in enumerate(article_lcs_paths):
                if source_idx > 0:
                    to_add += len(self.tokenized_sents[source_idx - 1])
                for word_idx in word_indices_list:
                    if word_idx + to_add >= len(imp_mask):
                        if len(imp_mask) == hps.max_enc_steps:
                            continue
                        else:
                            print(self.tokenized_sents, article_lcs_paths)
                            raise Exception(
                                'word_idx + to_add (%d) is larger than imp_mask size (%d)'
                                % (word_idx + to_add, len(imp_mask)))
                    imp_mask[word_idx + to_add] = 1
            self.importance_mask = imp_mask

        # Store the original strings
        self.original_article = article
        self.raw_article_sents = raw_article_sents
        self.original_abstract = abstract
        self.original_abstract_sents = abstract_sentences
        self.all_original_abstract_sents = all_abstract_sentences

        self.doc_indices = doc_indices
        self.ssi = ssi
        self.article_lcs_paths_list = article_lcs_paths_list
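In pointer-generator mode, the encoder input above is extended with temporary ids for in-article out-of-vocabulary words so the decoder can copy them. A minimal, self-contained sketch of that id scheme, assuming a plain dict vocabulary (hypothetical helper; the repo's data.tokenizedarticle2ids and data.abstract2ids perform the real mapping):

def sketch_article2ids(article_words, word2id, vocab_size):
    # Words missing from the fixed vocabulary get ids just past vocab_size,
    # numbered in the order they first appear in the article.
    ids, article_oovs = [], []
    for w in article_words:
        if w in word2id:
            ids.append(word2id[w])
        else:
            if w not in article_oovs:
                article_oovs.append(w)
            ids.append(vocab_size + article_oovs.index(w))
    return ids, article_oovs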
def generate_summary(article_sent_tokens, qid_ssi_to_importances, example_idx,
                     qid_ssi_to_token_scores_and_mappings):
    qid = example_idx

    summary_sent_tokens = []
    summary_tokens = util.flatten_list_of_lists(summary_sent_tokens)
    already_used_source_indices = []
    similar_source_indices_list = []
    summary_sents_for_html = []
    article_lcs_paths_list = []
    token_probs_list = []
    ssi_length_extractive = None
    while len(summary_tokens) < 300:
        if len(summary_tokens) >= l_param and ssi_length_extractive is None:
            ssi_length_extractive = len(similar_source_indices_list)
        # if FLAGS.dataset_name == 'xsum' and len(summary_tokens) > 0:
        #     ssi_length_extractive = len(similar_source_indices_list)
        #     break
        if FLAGS.use_mmr:
            score_dict = util.calc_MMR_source_indices(article_sent_tokens,
                                                      summary_tokens,
                                                      None,
                                                      qid_ssi_to_importances,
                                                      qid=qid)
        else:
            score_dict = qid_ssi_to_importances[qid]
        sents, source_indices = get_best_source_sents(
            article_sent_tokens, score_dict, already_used_source_indices)
        if len(source_indices) == 0:
            break

        token_scores, token_mappings = get_token_info_for_ssi(
            qid_ssi_to_token_scores_and_mappings, qid, source_indices)
        # if np.max(token_mappings) !=
        token_cons_scores = consolidate_token_scores(token_scores,
                                                     token_mappings)
        if len(token_cons_scores) != len(sents):
            print(token_cons_scores, sents)
            raise Exception('Len of token_cons_scores %d != Len of sents %d' %
                            (len(token_cons_scores), len(sents)))
        padded_token_cons_scores = []  # pad the end of each sentence with 0 probabilities, because instances that were too long for BERT got truncated
        for sent_idx, sent_scores in enumerate(token_cons_scores):
            sent = sents[sent_idx]
            if len(sent_scores) > len(sent):
                print(token_cons_scores, sents)
                raise Exception('Len of sent_scores %d > Len of sent %d' %
                                (len(sent_scores), len(sent)))
            while len(sent_scores) < len(sent):
                sent_scores.append(0.)
            padded_token_cons_scores.append(sent_scores)
        token_probs_list.append(padded_token_cons_scores)
        token_tags = threshold_token_scores(
            padded_token_cons_scores, FLAGS.tag_threshold
        )  # shape (num_source_sents, len(sent)); num_source_sents is 1 for a singleton, 2 for a pair
        article_lcs_paths = ssi_functions.binary_tags_to_list(token_tags)
        article_lcs_paths_list.append(article_lcs_paths)

        # if FLAGS.tag_tokens and FLAGS.tag_loss_wt != 0:
        #     sents_only_tagged = filter_untagged(sents, token_tags)
        #     summary_sent_tokens.extend(sents_only_tagged)
        # else:
        summary_sent_tokens.extend(sents)

        summary_tokens = util.flatten_list_of_lists(summary_sent_tokens)
        similar_source_indices_list.append(source_indices)
        summary_sents_for_html.append(' <br> '.join(
            [' '.join(sent) for sent in sents]))
        if filter_sentences:
            already_used_source_indices.extend(source_indices)
    if ssi_length_extractive is None:
        ssi_length_extractive = len(similar_source_indices_list)
    selected_article_sent_indices = util.flatten_list_of_lists(
        similar_source_indices_list[:ssi_length_extractive])
    summary_sents = [
        ' '.join(sent) for sent in util.reorder(article_sent_tokens,
                                                selected_article_sent_indices)
    ]
    # summary = '\n'.join([' '.join(tokens) for tokens in summary_sent_tokens])
    return summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive, article_lcs_paths_list, token_probs_list
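generate_summary greedily adds the highest-scoring remaining sentence (or sentence pair) until a token budget is exceeded, remembering where the extractive-length cutoff l_param falls. A stripped-down sketch of that greedy loop, with a hypothetical score_fn standing in for the MMR/importance scoring used above:

def sketch_greedy_extract(article_sent_tokens, score_fn, budget=300):
    # Repeatedly pick the best-scoring unused sentence until the token budget is hit.
    chosen, summary_tokens = [], []
    while len(summary_tokens) < budget:
        candidates = [i for i in range(len(article_sent_tokens)) if i not in chosen]
        if not candidates:
            break
        best = max(candidates, key=lambda i: score_fn(i, summary_tokens))
        chosen.append(best)
        summary_tokens.extend(article_sent_tokens[best])
    return chosen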
def main(unused_argv):

    if len(unused_argv) != 1:  # raises an exception if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)
    print('Running statistics on %s' % exp_name)

    start_time = time.time()
    np.random.seed(random_seed)
    source_dir = os.path.join(data_dir, dataset_articles)
    source_files = sorted(glob.glob(source_dir + '/' + dataset_split + '*'))
    ex_sents = ['single .', 'sentence .']
    article_text = ' '.join(ex_sents)

    total = len(source_files) * 1000
    example_generator = data.example_generator(source_dir + '/' +
                                               dataset_split + '*',
                                               True,
                                               False,
                                               should_check_valid=False)

    qid_ssi_to_importances = rank_source_sents(temp_in_path, temp_out_path)
    qid_ssi_to_token_scores_and_mappings = get_token_scores_for_ssi(
        temp_in_path, file_path_seq, file_path_mappings)
    ex_gen = example_generator_extended(example_generator, total,
                                        qid_ssi_to_importances,
                                        qid_ssi_to_token_scores_and_mappings)
    print('Creating list')
    ex_list = [ex for ex in ex_gen]
    ssi_list = list(futures.map(evaluate_example, ex_list))

    # save ssi_list
    with open(os.path.join(my_log_dir, 'ssi.pkl'), 'wb') as f:
        pickle.dump(ssi_list, f)
    with open(os.path.join(my_log_dir, 'ssi.pkl'), 'rb') as f:
        ssi_list = pickle.load(f)
    print('Evaluating BERT model F1 score...')
    suffix = util.all_sent_selection_eval(ssi_list)
    #
    # # for ex in tqdm(ex_list, total=total):
    # #     load_and_evaluate_example(ex)
    #
    print('Evaluating ROUGE...')
    results_dict = rouge_functions.rouge_eval(ref_dir,
                                              dec_dir,
                                              l_param=l_param)
    # print("Results_dict: ", results_dict)
    rouge_functions.rouge_log(results_dict, my_log_dir, suffix=suffix)

    ssis_restricted = [
        ssi_triple[1][:ssi_triple[2]] for ssi_triple in ssi_list
    ]
    ssi_lens = [
        len(source_indices)
        for source_indices in util.flatten_list_of_lists(ssis_restricted)
    ]
    # print ssi_lens
    num_singles = ssi_lens.count(1)
    num_pairs = ssi_lens.count(2)
    print(
        'Percent singles/pairs: %.2f %.2f' %
        (num_singles * 100. / len(ssi_lens), num_pairs * 100. / len(ssi_lens)))

    util.print_execution_time(start_time)
Esempio n. 24
def process_attn_selections(attn_dir, decode_dir, vocab, extraction_eval=False):

    html_dir = os.path.join(decode_dir, 'extr_vis')
    util.create_dirs(html_dir)
    file_names = sorted(glob.glob(os.path.join(attn_dir, '*')))

    if extraction_eval:
        ssi_dir = os.path.join('data/ssi', FLAGS.dataset_name, 'test_ssi.pkl')
        with open(ssi_dir, 'rb') as f:
            ssi_list = pickle.load(f)
        if len(ssi_list) != len(file_names):
            raise Exception('len of ssi_list does not equal len file_names: ', len(ssi_list), len(file_names))
    triplet_ssi_list = []
    for file_idx, file_name in enumerate(tqdm(file_names)):
        with open(file_name) as f:
            data = json.load(f)
        p_gens = util.flatten_list_of_lists(data['p_gens'])
        article_lst = data['article_lst']
        abstract_lst = data['abstract_str'].strip().split()
        decoded_lst = data['decoded_lst']
        attn_dists = np.array(data['attn_dists'])

        article_lst = [art_word.replace('__', '') for art_word in article_lst]
        decoded_lst = [dec_word.replace('__', '') for dec_word in decoded_lst]
        abstract_lst = [abs_word.replace('__', '') for abs_word in abstract_lst]

        min_matched_tokens = 2
        if 'singles' in FLAGS.exp_name:
            sentence_limit = 1
        else:
            sentence_limit = 2
        summary_sent_tokens = [nltk.tokenize.word_tokenize(sent) for sent in nltk.tokenize.sent_tokenize(' '.join(abstract_lst))]
        decoded_sent_tokens = [nltk.tokenize.word_tokenize(sent) for sent in nltk.tokenize.sent_tokenize(' '.join(decoded_lst))]
        article_sent_tokens = [nltk.tokenize.word_tokenize(sent) for sent in nltk.tokenize.sent_tokenize(' '.join(article_lst))]
        gt_ssi_list, lcs_paths_list, article_lcs_paths_list = get_simple_source_indices_list(summary_sent_tokens, article_sent_tokens, vocab, sentence_limit,
                                       min_matched_tokens)
        sys_ssi_list, _, _ = get_simple_source_indices_list(decoded_sent_tokens, article_sent_tokens, vocab, sentence_limit,
                                       min_matched_tokens)


        match_indices = []
        for dec_idx, dec in enumerate(decoded_lst):
            art_match_indices = [art_idx for art_idx, art_word in enumerate(article_lst) if art_word.replace('__', '') == dec or art_word == dec]
            if len(art_match_indices) == 0:
                match_indices.append(None)
            else:
                art_attns = [attn_dists[dec_idx, art_idx] for art_idx in art_match_indices]
                best_match_idx = art_match_indices[np.argmax(art_attns)]
                match_indices.append(best_match_idx)

        html = create_html(article_lst, match_indices, decoded_lst, [abstract_lst], file_idx, gt_ssi_list, lcs_paths_list, article_lcs_paths_list, summary_sent_tokens, article_sent_tokens)
        with open(os.path.join(html_dir, '%06d.html' % file_idx), 'wb') as f:
            f.write(html.encode())

        if extraction_eval:
            triplet_ssi_list.append((ssi_list[file_idx], sys_ssi_list, -1))

    if extraction_eval:
        print('Evaluating Lambdamart model F1 score...')
        suffix = util.all_sent_selection_eval(triplet_ssi_list)
        print(suffix)
        with open(os.path.join(decode_dir, 'extraction_results.txt'), 'wb') as f:
            f.write(suffix.encode())


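The match_indices loop above aligns each decoded word with the article position it attends to most, among positions sharing the same surface form. A self-contained sketch of that alignment (hypothetical inputs; attn has shape [num_decoded_words, num_article_words]):

import numpy as np

def sketch_align_by_attention(decoded_words, article_words, attn):
    match_indices = []
    for dec_idx, dec_word in enumerate(decoded_words):
        positions = [i for i, art_word in enumerate(article_words) if art_word == dec_word]
        if not positions:
            match_indices.append(None)
        else:
            # Among identical surface forms, keep the position with the highest attention weight.
            match_indices.append(positions[int(np.argmax(attn[dec_idx, positions]))])
    return match_indices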
Esempio n. 25
def main(unused_argv):

    print('Running statistics on %s' % FLAGS.dataset_name)

    if len(unused_argv) != 1:  # raises an exception if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    out_dir = os.path.join(
        os.path.expanduser('~') + '/data/kaiqiang_data', FLAGS.dataset_name)
    if FLAGS.mode == 'write':
        util.create_dirs(out_dir)
        if FLAGS.dataset_name == 'duc_2004':
            dataset_splits = ['test']
        elif FLAGS.dataset_split == 'all':
            dataset_splits = ['test', 'val', 'train']
        else:
            dataset_splits = [FLAGS.dataset_split]

        for dataset_split in dataset_splits:

            if dataset_split == 'test':
                ssi_data_path = os.path.join(
                    'logs/%s_bert_both_sentemb_artemb_plushidden' %
                    FLAGS.dataset_name, 'ssi.pkl')
                print(util.bcolors.OKGREEN +
                      "Loading SSI from BERT at %s" % ssi_data_path +
                      util.bcolors.ENDC)
                with open(ssi_data_path, 'rb') as f:
                    ssi_triple_list = pickle.load(f)

            source_dir = os.path.join(data_dir, FLAGS.dataset_name)
            source_files = sorted(
                glob.glob(source_dir + '/' + dataset_split + '*'))

            total = len(source_files) * 1000 if (
                'cnn' in FLAGS.dataset_name or 'newsroom' in FLAGS.dataset_name
                or 'xsum' in FLAGS.dataset_name) else len(source_files)
            example_generator = data.example_generator(
                source_dir + '/' + dataset_split + '*',
                True,
                False,
                should_check_valid=False)

            out_document_path = os.path.join(out_dir,
                                             dataset_split + '.Ndocument')
            out_summary_path = os.path.join(out_dir,
                                            dataset_split + '.Nsummary')
            out_example_idx_path = os.path.join(out_dir,
                                                dataset_split + '.Nexampleidx')

            doc_writer = open(out_document_path, 'w')
            if dataset_split != 'test':
                sum_writer = open(out_summary_path, 'w')
            ex_idx_writer = open(out_example_idx_path, 'w')

            for example_idx, example in enumerate(
                    tqdm(example_generator, total=total)):
                if FLAGS.num_instances != -1 and example_idx >= FLAGS.num_instances:
                    break
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(
                    example, names_to_types)
                article_sent_tokens = [
                    util.process_sent(sent) for sent in raw_article_sents
                ]
                if FLAGS.dataset_name == 'duc_2004':
                    groundtruth_summ_sents = [[
                        sent.strip()
                        for sent in gt_summ_text.strip().split('\n')
                    ] for gt_summ_text in groundtruth_summary_text]
                else:
                    groundtruth_summ_sents = [[
                        sent.strip() for sent in
                        groundtruth_summary_text.strip().split('\n')
                    ]]
                if doc_indices is None:
                    doc_indices = [0] * len(
                        util.flatten_list_of_lists(article_sent_tokens))
                doc_indices = [int(doc_idx) for doc_idx in doc_indices]
                # rel_sent_indices, _, _ = preprocess_for_lambdamart_no_flags.get_rel_sent_indices(doc_indices, article_sent_tokens)

                if dataset_split == 'test':
                    if example_idx >= len(ssi_triple_list):
                        raise Exception(
                            'Len of ssi list (%d) is less than number of examples (>=%d)'
                            % (len(ssi_triple_list), example_idx))
                    ssi_length_extractive = ssi_triple_list[example_idx][2]
                    ssi = ssi_triple_list[example_idx][1]
                    ssi = ssi[:ssi_length_extractive]
                    groundtruth_similar_source_indices_list = ssi
                else:
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                        groundtruth_similar_source_indices_list,
                        FLAGS.sentence_limit)

                for ssi_idx, ssi in enumerate(
                        groundtruth_similar_source_indices_list):
                    if len(ssi) == 0:
                        continue
                    my_article = ' '.join(util.reorder(raw_article_sents, ssi))
                    doc_writer.write(my_article + '\n')
                    if dataset_split != 'test':
                        sum_writer.write(groundtruth_summ_sents[0][ssi_idx] +
                                         '\n')
                    ex_idx_writer.write(str(example_idx) + '\n')
    elif FLAGS.mode == 'evaluate':
        summary_dir = '/home/logan/data/kaiqiang_data/logan_ACL/trained_on_' + FLAGS.train_dataset + '/' + FLAGS.dataset_name
        out_summary_path = os.path.join(summary_dir, 'test' + 'Summary.txt')
        out_example_idx_path = os.path.join(out_dir, 'test' + '.Nexampleidx')
        decode_dir = 'logs/kaiqiang_%s_trainedon%s' % (FLAGS.dataset_name,
                                                       FLAGS.train_dataset)
        rouge_ref_dir = os.path.join(decode_dir, 'reference')
        rouge_dec_dir = os.path.join(decode_dir, 'decoded')
        util.create_dirs(rouge_ref_dir)
        util.create_dirs(rouge_dec_dir)

        def num_lines_in_file(file_path):
            with open(file_path) as f:
                num_lines = sum(1 for line in f)
            return num_lines

        def process_example(sents, ex_idx, groundtruth_summ_sents):
            final_decoded_words = []
            for sent in sents:
                final_decoded_words.extend(sent.split(' '))
            rouge_functions.write_for_rouge(groundtruth_summ_sents,
                                            None,
                                            ex_idx,
                                            rouge_ref_dir,
                                            rouge_dec_dir,
                                            decoded_words=final_decoded_words,
                                            log=False)

        num_lines_summary = num_lines_in_file(out_summary_path)
        num_lines_example_indices = num_lines_in_file(out_example_idx_path)
        if num_lines_summary != num_lines_example_indices:
            raise Exception(
                'Num lines summary != num lines example indices: (%d, %d)' %
                (num_lines_summary, num_lines_example_indices))

        source_dir = os.path.join(data_dir, FLAGS.dataset_name)
        example_generator = data.example_generator(source_dir + '/' + 'test' +
                                                   '*',
                                                   True,
                                                   False,
                                                   should_check_valid=False)

        sum_writer = open(out_summary_path)
        ex_idx_writer = open(out_example_idx_path)
        prev_ex_idx = 0
        sents = []

        for line_idx in tqdm(range(num_lines_summary)):
            line = sum_writer.readline()
            ex_idx = int(ex_idx_writer.readline())

            if ex_idx == prev_ex_idx:
                sents.append(line)
            else:
                example = next(example_generator)
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(
                    example, names_to_types)
                if FLAGS.dataset_name == 'duc_2004':
                    groundtruth_summ_sents = [[
                        sent.strip()
                        for sent in gt_summ_text.strip().split('\n')
                    ] for gt_summ_text in groundtruth_summary_text]
                else:
                    groundtruth_summ_sents = [[
                        sent.strip() for sent in
                        groundtruth_summary_text.strip().split('\n')
                    ]]
                process_example(sents, ex_idx, groundtruth_summ_sents)
                prev_ex_idx = ex_idx
                sents = [line]

        example = next(example_generator)
        raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(
            example, names_to_types)
        if FLAGS.dataset_name == 'duc_2004':
            groundtruth_summ_sents = [[
                sent.strip() for sent in gt_summ_text.strip().split('\n')
            ] for gt_summ_text in groundtruth_summary_text]
        else:
            groundtruth_summ_sents = [[
                sent.strip()
                for sent in groundtruth_summary_text.strip().split('\n')
            ]]
        process_example(sents, ex_idx, groundtruth_summ_sents)

        print("Now starting ROUGE eval...")
        l_param = 100
        results_dict = rouge_functions.rouge_eval(rouge_ref_dir,
                                                  rouge_dec_dir,
                                                  l_param=l_param)
        rouge_functions.rouge_log(results_dict, decode_dir)

    else:
        raise Exception('mode flag was not evaluate or write.')
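The evaluate branch walks the generated-summary file and its example-index file in lockstep, grouping consecutive lines that share an example index before scoring each group. A compact sketch of that grouping pattern (hypothetical inputs, not the repo's API):

from itertools import groupby

def sketch_group_sents_by_example(lines, example_indices):
    # Group consecutive lines that belong to the same example index.
    grouped = []
    for ex_idx, pairs in groupby(zip(example_indices, lines), key=lambda pair: pair[0]):
        grouped.append((ex_idx, [line for _, line in pairs]))
    return grouped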
Esempio n. 26
def main(unused_argv):

    print('Running statistics on %s' % FLAGS.dataset_name)

    if len(unused_argv) != 1:  # raises an exception if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.dataset_name == 'all':
        dataset_names = ['cnn_dm', 'xsum', 'duc_2004']
    else:
        dataset_names = [FLAGS.dataset_name]

    if not os.path.exists(plot_data_file):
        all_lists_of_histogram_pairs = []
        for dataset_name in dataset_names:
            FLAGS.dataset_name = dataset_name

            if dataset_name == 'duc_2004':
                dataset_splits = ['test']
            elif FLAGS.dataset_split == 'all':
                dataset_splits = ['test', 'val', 'train']
            else:
                dataset_splits = [FLAGS.dataset_split]

            ssi_list = []
            for dataset_split in dataset_splits:

                ssi_path = os.path.join(ssi_dir, FLAGS.dataset_name,
                                        dataset_split + '_ssi.pkl')

                with open(ssi_path, 'rb') as f:
                    ssi_list.extend(pickle.load(f))

                if FLAGS.dataset_name == 'duc_2004':
                    for abstract_idx in [1, 2, 3]:
                        ssi_path = os.path.join(
                            ssi_dir, FLAGS.dataset_name, dataset_split +
                            '_ssi_' + str(abstract_idx) + '.pkl')
                        with open(ssi_path, 'rb') as f:
                            temp_ssi_list = pickle.load(f)
                        ssi_list.extend(temp_ssi_list)

            ssi_2d = util.flatten_list_of_lists(ssi_list)

            num_extracted = [
                len(ssi) for ssi in util.flatten_list_of_lists(ssi_list)
            ]
            hist_num_extracted = np.histogram(num_extracted,
                                              bins=6,
                                              range=(0, 5))
            print(hist_num_extracted)
            print('Histogram of number of sentences merged: ' +
                  util.hist_as_pdf_str(hist_num_extracted))

            distances = [
                abs(ssi[0] - ssi[1]) for ssi in ssi_2d if len(ssi) >= 2
            ]
            print('Distance between sentences (mean, median): ',
                  np.mean(distances), np.median(distances))
            hist_dist = np.histogram(distances, bins=max(distances))
            print('Histogram of distances: ' + util.hist_as_pdf_str(hist_dist))

            summ_sent_idx_to_number_of_source_sents = [[], [], [], [], [], [],
                                                       [], [], [], []]
            for ssi in ssi_list:
                for summ_sent_idx, source_indices in enumerate(ssi):
                    if len(source_indices) == 0 or summ_sent_idx >= len(
                            summ_sent_idx_to_number_of_source_sents):
                        continue
                    num_sents = len(source_indices)
                    if num_sents > 2:
                        num_sents = 2
                    summ_sent_idx_to_number_of_source_sents[
                        summ_sent_idx].append(num_sents)
            print(
                "Number of source sents for summary sentence indices (Is the first summary sent more likely to match with a singleton or a pair?):"
            )
            for summ_sent_idx, list_of_numbers_of_source_sents in enumerate(
                    summ_sent_idx_to_number_of_source_sents):
                if len(list_of_numbers_of_source_sents) == 0:
                    percent_singleton = 0.
                else:
                    percent_singleton = list_of_numbers_of_source_sents.count(
                        1) * 1. / len(list_of_numbers_of_source_sents)
                print(str(percent_singleton) + '\t', end='')
            print('')
            for summ_sent_idx, list_of_numbers_of_source_sents in enumerate(
                    summ_sent_idx_to_number_of_source_sents):
                if len(list_of_numbers_of_source_sents) == 0:
                    percent_pair = 0.
                else:
                    percent_pair = list_of_numbers_of_source_sents.count(
                        2) * 1. / len(list_of_numbers_of_source_sents)
                print(str(percent_pair) + '\t', end='')
            print('')

            primary_pos = [ssi[0] for ssi in ssi_2d if len(ssi) >= 1]
            secondary_pos = [ssi[1] for ssi in ssi_2d if len(ssi) >= 2]
            all_pos = [max(ssi) for ssi in ssi_2d if len(ssi) >= 1]

            # if FLAGS.dataset_name != 'duc_2004':
            #     plot_positions(primary_pos, secondary_pos, all_pos)

            if FLAGS.dataset_split == 'all':
                glob_string = '*.bin'
            else:
                glob_string = dataset_splits[0]

            print('Loading TFIDF vectorizer')
            with open(tfidf_vec_path, 'rb') as f:
                tfidf_vectorizer = pickle.load(f)

            source_dir = os.path.join(data_dir, FLAGS.dataset_name)
            source_files = sorted(
                glob.glob(source_dir + '/' + glob_string + '*'))

            total = len(source_files) * 1000 if (
                'cnn' in FLAGS.dataset_name or 'newsroom' in FLAGS.dataset_name
                or 'xsum' in FLAGS.dataset_name) else len(source_files)
            example_generator = data.example_generator(
                source_dir + '/' + glob_string + '*',
                True,
                False,
                should_check_valid=False)

            all_possible_singles = 0
            all_possible_pairs = [0]
            all_filtered_pairs = 0
            all_all_combinations = 0
            all_ssi_pairs = [0]
            ssi_pairs_with_shared_coref = [0]
            ssi_pairs_with_shared_word = [0]
            ssi_pairs_with_either_coref_or_word = [0]
            all_pairs_with_shared_coref = [0]
            all_pairs_with_shared_word = [0]
            all_pairs_with_either_coref_or_word = [0]
            actual_total = [0]
            rel_positions_primary = []
            rel_positions_secondary = []
            rel_positions_all = []
            sent_lens = []
            all_sent_lens = []
            all_pos = []
            y = []
            normalized_positions_primary = []
            normalized_positions_secondary = []
            all_normalized_positions_primary = []
            all_normalized_positions_secondary = []
            normalized_positions_singles = []
            normalized_positions_pairs_first = []
            normalized_positions_pairs_second = []
            primary_pos_duc = []
            secondary_pos_duc = []
            all_pos_duc = []
            all_distances = []
            distances_duc = []
            tfidf_similarities = []
            all_tfidf_similarities = []
            average_mmrs = []
            all_average_mmrs = []

            for example_idx, example in enumerate(
                    tqdm(example_generator, total=total)):

                # def process(example_idx_example):
                #     # print '0'
                #     example = example_idx_example
                if FLAGS.num_instances != -1 and example_idx >= FLAGS.num_instances:
                    break
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
                    example, names_to_types)
                article_sent_tokens = [
                    util.process_sent(sent) for sent in raw_article_sents
                ]
                article_text = ' '.join(raw_article_sents)
                groundtruth_summ_sents = [[
                    sent.strip()
                    for sent in groundtruth_summary_text.strip().split('\n')
                ]]
                if doc_indices is None:
                    doc_indices = [0] * len(
                        util.flatten_list_of_lists(article_sent_tokens))
                doc_indices = [int(doc_idx) for doc_idx in doc_indices]
                rel_sent_indices, doc_sent_indices, doc_sent_lens = preprocess_for_lambdamart_no_flags.get_rel_sent_indices(
                    doc_indices, article_sent_tokens)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list,
                    FLAGS.sentence_limit)

                sent_term_matrix = util.get_doc_substituted_tfidf_matrix(
                    tfidf_vectorizer, raw_article_sents, article_text)
                sents_similarities = util.cosine_similarity(
                    sent_term_matrix, sent_term_matrix)
                importances = util.special_squash(
                    util.get_tfidf_importances(tfidf_vectorizer,
                                               raw_article_sents))

                if FLAGS.dataset_name == 'duc_2004':
                    first_k_indices = lambdamart_scores_to_summaries.get_indices_of_first_k_sents_of_each_article(
                        rel_sent_indices, FLAGS.first_k)
                else:
                    first_k_indices = [
                        idx for idx in range(len(raw_article_sents))
                    ]
                article_indices = list(range(len(raw_article_sents)))

                possible_pairs = list(
                    itertools.combinations(article_indices, 2))  # all pairs
                # # # filtered_possible_pairs = preprocess_for_lambdamart_no_flags.filter_pairs_by_criteria(raw_article_sents, possible_pairs, corefs)
                # if FLAGS.dataset_name == 'duc_2004':
                #     filtered_possible_pairs = [x for x in list(itertools.combinations(first_k_indices, 2))]  # all pairs
                # else:
                #     filtered_possible_pairs = preprocess_for_lambdamart_no_flags.filter_pairs_by_sent_position(possible_pairs)
                # # removed_pairs = list(set(possible_pairs) - set(filtered_possible_pairs))
                # possible_singles = [(i,) for i in range(len(raw_article_sents))]
                # all_combinations = filtered_possible_pairs + possible_singles
                #
                # all_possible_singles += len(possible_singles)
                # all_possible_pairs[0] += len(possible_pairs)
                # all_filtered_pairs += len(filtered_possible_pairs)
                # all_all_combinations += len(all_combinations)

                # for ssi in groundtruth_similar_source_indices_list:
                #     if len(ssi) > 0:
                #         idx = rel_sent_indices[ssi[0]]
                #         rel_positions_primary.append(idx)
                #         rel_positions_all.append(idx)
                #     if len(ssi) > 1:
                #         idx = rel_sent_indices[ssi[1]]
                #         rel_positions_secondary.append(idx)
                #         rel_positions_all.append(idx)
                #
                #
                #

                # coref_pairs = preprocess_for_lambdamart_no_flags.get_coref_pairs(corefs)
                # # DO OVER LAP PAIRS BETTER
                # overlap_pairs = preprocess_for_lambdamart_no_flags.filter_by_overlap(article_sent_tokens, possible_pairs)
                # either_coref_or_word = list(set(list(coref_pairs) + overlap_pairs))
                #
                # for ssi in groundtruth_similar_source_indices_list:
                #     if len(ssi) == 2:
                #         all_ssi_pairs[0] += 1
                #         do_share_coref = ssi in coref_pairs
                #         do_share_words = ssi in overlap_pairs
                #         if do_share_coref:
                #             ssi_pairs_with_shared_coref[0] += 1
                #         if do_share_words:
                #             ssi_pairs_with_shared_word[0] += 1
                #         if do_share_coref or do_share_words:
                #             ssi_pairs_with_either_coref_or_word[0] += 1
                # all_pairs_with_shared_coref[0] += len(coref_pairs)
                # all_pairs_with_shared_word[0] += len(overlap_pairs)
                # all_pairs_with_either_coref_or_word[0] += len(either_coref_or_word)

                if FLAGS.dataset_name == 'duc_2004':
                    primary_pos_duc.extend([
                        rel_sent_indices[ssi[0]]
                        for ssi in groundtruth_similar_source_indices_list
                        if len(ssi) >= 1
                    ])
                    secondary_pos_duc.extend([
                        rel_sent_indices[ssi[1]]
                        for ssi in groundtruth_similar_source_indices_list
                        if len(ssi) >= 2
                    ])
                    all_pos_duc.extend([
                        max([rel_sent_indices[sent_idx] for sent_idx in ssi])
                        for ssi in groundtruth_similar_source_indices_list
                        if len(ssi) >= 1
                    ])

                for ssi in groundtruth_similar_source_indices_list:
                    for sent_idx in ssi:
                        sent_lens.append(len(article_sent_tokens[sent_idx]))
                    if len(ssi) >= 1:
                        orig_val = ssi[0]
                        vals_to_add = get_integral_values_for_histogram(
                            orig_val, rel_sent_indices, doc_sent_indices,
                            doc_sent_lens, raw_article_sents)
                        normalized_positions_primary.extend(vals_to_add)
                    if len(ssi) >= 2:
                        orig_val = ssi[1]
                        vals_to_add = get_integral_values_for_histogram(
                            orig_val, rel_sent_indices, doc_sent_indices,
                            doc_sent_lens, raw_article_sents)
                        normalized_positions_secondary.extend(vals_to_add)

                        if FLAGS.dataset_name == 'duc_2004':
                            distances_duc.append(
                                abs(rel_sent_indices[ssi[1]] -
                                    rel_sent_indices[ssi[0]]))

                        tfidf_similarities.append(sents_similarities[ssi[0],
                                                                     ssi[1]])
                        average_mmrs.append(
                            (importances[ssi[0]] + importances[ssi[1]]) / 2)

                for ssi in groundtruth_similar_source_indices_list:
                    if len(ssi) == 1:
                        orig_val = ssi[0]
                        vals_to_add = get_integral_values_for_histogram(
                            orig_val, rel_sent_indices, doc_sent_indices,
                            doc_sent_lens, raw_article_sents)
                        normalized_positions_singles.extend(vals_to_add)
                    if len(ssi) >= 2:
                        if doc_sent_indices[ssi[0]] != doc_sent_indices[
                                ssi[1]]:
                            continue
                        orig_val_first = min(ssi[0], ssi[1])
                        vals_to_add = get_integral_values_for_histogram(
                            orig_val_first, rel_sent_indices, doc_sent_indices,
                            doc_sent_lens, raw_article_sents)
                        normalized_positions_pairs_first.extend(vals_to_add)
                        orig_val_second = max(ssi[0], ssi[1])
                        vals_to_add = get_integral_values_for_histogram(
                            orig_val_second, rel_sent_indices,
                            doc_sent_indices, doc_sent_lens, raw_article_sents)
                        normalized_positions_pairs_second.extend(vals_to_add)

                # all_normalized_positions_primary.extend(util.flatten_list_of_lists([get_integral_values_for_histogram(single[0], rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents) for single in possible_singles]))
                # all_normalized_positions_secondary.extend(util.flatten_list_of_lists([get_integral_values_for_histogram(pair[1], rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents) for pair in possible_pairs]))
                all_sent_lens.extend(
                    [len(sent) for sent in article_sent_tokens])
                all_distances.extend([
                    abs(rel_sent_indices[pair[1]] - rel_sent_indices[pair[0]])
                    for pair in possible_pairs
                ])
                all_tfidf_similarities.extend([
                    sents_similarities[pair[0], pair[1]]
                    for pair in possible_pairs
                ])
                all_average_mmrs.extend([
                    (importances[pair[0]] + importances[pair[1]]) / 2
                    for pair in possible_pairs
                ])

                # if FLAGS.dataset_name == 'duc_2004':
                #     rel_pos_single = [rel_sent_indices[single[0]] for single in possible_singles]
                #     rel_pos_pair = [[rel_sent_indices[pair[0]], rel_sent_indices[pair[1]]] for pair in possible_pairs]
                #     all_pos.extend(rel_pos_single)
                #     all_pos.extend([max(pair) for pair in rel_pos_pair])
                # else:
                #     all_pos.extend(util.flatten_list_of_lists(possible_singles))
                #     all_pos.extend([max(pair) for pair in possible_pairs])
                # y.extend([1 if single in groundtruth_similar_source_indices_list else 0 for single in possible_singles])
                # y.extend([1 if pair in groundtruth_similar_source_indices_list else 0 for pair in possible_pairs])

                # actual_total[0] += 1

            # # p = Pool(144)
            # # list(tqdm(p.imap(process, example_generator), total=total))
            #
            # # print 'Possible_singles\tPossible_pairs\tFiltered_pairs\tAll_combinations: \n%.2f\t%.2f\t%.2f\t%.2f' % (all_possible_singles*1./actual_total, \
            # #     all_possible_pairs*1./actual_total, all_filtered_pairs*1./actual_total, all_all_combinations*1./actual_total)
            # #
            # # # print 'Relative positions of groundtruth source sentences in document:\nPrimary\tSecondary\tBoth\n%.2f\t%.2f\t%.2f' % (np.mean(rel_positions_primary), np.mean(rel_positions_secondary), np.mean(rel_positions_all))
            # #
            # # print 'SSI Pair statistics:\nShare_coref\tShare_word\tShare_either\n%.2f\t%.2f\t%.2f' \
            # #       % (ssi_pairs_with_shared_coref[0]*100./all_ssi_pairs[0], ssi_pairs_with_shared_word[0]*100./all_ssi_pairs[0], ssi_pairs_with_either_coref_or_word[0]*100./all_ssi_pairs[0])
            # # print 'All Pair statistics:\nShare_coref\tShare_word\tShare_either\n%.2f\t%.2f\t%.2f' \
            # #       % (all_pairs_with_shared_coref[0]*100./all_possible_pairs[0], all_pairs_with_shared_word[0]*100./all_possible_pairs[0], all_pairs_with_either_coref_or_word[0]*100./all_possible_pairs[0])
            #
            # # hist_all_pos = np.histogram(all_pos, bins=max(all_pos)+1)
            # # print 'Histogram of all sent positions: ', util.hist_as_pdf_str(hist_all_pos)
            # # min_sent_len = min(sent_lens)
            # # hist_sent_lens = np.histogram(sent_lens, bins=max(sent_lens)-min_sent_len+1)
            # # print 'min, max sent lens:', min_sent_len, max(sent_lens)
            # # print 'Histogram of sent lens: ', util.hist_as_pdf_str(hist_sent_lens)
            # # min_all_sent_len = min(all_sent_lens)
            # # hist_all_sent_lens = np.histogram(all_sent_lens, bins=max(all_sent_lens)-min_all_sent_len+1)
            # # print 'min, max all sent lens:', min_all_sent_len, max(all_sent_lens)
            # # print 'Histogram of all sent lens: ', util.hist_as_pdf_str(hist_all_sent_lens)
            #
            # # print 'Pearsons r, p value', pearsonr(all_pos, y)
            # # fig, ax1 = plt.subplots(nrows=1)
            # # plt.scatter(all_pos, y)
            # # pp = PdfPages(os.path.join('stuff/plots', FLAGS.dataset_name + '_position_scatter.pdf'))
            # # plt.savefig(pp, format='pdf',bbox_inches='tight')
            # # plt.show()
            # # pp.close()
            #
            # # if FLAGS.dataset_name == 'duc_2004':
            # #     plot_positions(primary_pos_duc, secondary_pos_duc, all_pos_duc)
            #
            # normalized_positions_all = normalized_positions_primary + normalized_positions_secondary
            # # plot_histogram(normalized_positions_primary, num_bins=100)
            # # plot_histogram(normalized_positions_secondary, num_bins=100)
            # # plot_histogram(normalized_positions_all, num_bins=100)
            #
            # sent_lens_together = [sent_lens, all_sent_lens]
            # # plot_histogram(sent_lens_together, pdf=True, start_at_0=True, max_val=70)
            #
            # if FLAGS.dataset_name == 'duc_2004':
            #     distances = distances_duc
            # sent_distances_together = [distances, all_distances]
            # # plot_histogram(sent_distances_together, pdf=True, start_at_0=True, max_val=100)
            #
            # tfidf_similarities_together = [tfidf_similarities, all_tfidf_similarities]
            # # plot_histogram(tfidf_similarities_together, pdf=True, num_bins=100)
            #
            # average_mmrs_together = [average_mmrs, all_average_mmrs]
            # # plot_histogram(average_mmrs_together, pdf=True, num_bins=100)
            #
            # normalized_positions_primary_together = [normalized_positions_primary, bin_values]
            # normalized_positions_secondary_together = [normalized_positions_secondary, bin_values]
            # # plot_histogram(normalized_positions_primary_together, pdf=True, num_bins=100)
            # # plot_histogram(normalized_positions_secondary_together, pdf=True, num_bins=100)
            #
            #
            # list_of_hist_pairs = [
            #     {
            #         'lst': normalized_positions_primary_together,
            #         'pdf': True,
            #         'num_bins': 100,
            #         'y_lim': 3.9,
            #         'y_label': FLAGS.dataset_name,
            #         'x_label': 'Sent position (primary)'
            #     },
            #     {
            #         'lst': normalized_positions_secondary_together,
            #         'pdf': True,
            #         'num_bins': 100,
            #         'y_lim': 3.9,
            #         'x_label': 'Sent position (secondary)'
            #     },
            #     {
            #         'lst': sent_distances_together,
            #         'pdf': True,
            #         'start_at_0': True,
            #         'max_val': 100,
            #         'x_label': 'Sent distance'
            #     },
            #     {
            #         'lst': sent_lens_together,
            #         'pdf': True,
            #         'start_at_0': True,
            #         'max_val': 70,
            #         'x_label': 'Sent length'
            #     },
            #     {
            #         'lst': average_mmrs_together,
            #         'pdf': True,
            #         'num_bins': 100,
            #         'x_label': 'Average TF-IDF importance'
            #     }
            # ]

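            # Build the histogram configurations that are actually plotted: one panel for the
            # positions of singleton source sentences and one for the primary/secondary
            # positions of fused sentence pairs. In each dict, 'lst' holds the data series
            # and the remaining keys are plotting options (PDF normalization, bin count,
            # axis limits, axis labels, legend labels) consumed by plot_histograms.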
            normalized_positions_pairs_together = [
                normalized_positions_pairs_first,
                normalized_positions_pairs_second
            ]
            list_of_hist_pairs = [
                {
                    'lst': [normalized_positions_singles],
                    'pdf': True,
                    'num_bins': 100,
                    # 'y_lim': 3.9,
                    'x_lim': 1.0,
                    'y_label': FLAGS.dataset_name,
                    'x_label': 'Sent Position (Singles)',
                    'legend_labels': ['Primary']
                },
                {
                    'lst': normalized_positions_pairs_together,
                    'pdf': True,
                    'num_bins': 100,
                    # 'y_lim': 3.9,
                    'x_lim': 1.0,
                    'x_label': 'Sent Position (Pairs)',
                    'legend_labels': ['Primary', 'Secondary']
                }
            ]

            all_lists_of_histogram_pairs.append(list_of_hist_pairs)
        # Cache the computed histogram data so later runs can re-plot without reprocessing.
        with open(plot_data_file, 'wb') as f:  # pickle requires a binary file handle
            cPickle.dump(all_lists_of_histogram_pairs, f)
    else:
        with open(plot_data_file, 'rb') as f:
            all_lists_of_histogram_pairs = cPickle.load(f)
    plot_histograms(all_lists_of_histogram_pairs)
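

# The histogram-config dicts above are consumed by plot_histograms, whose implementation is
# not shown in this excerpt. The sketch below is a hypothetical stand-in, assuming matplotlib
# is available and that only the keys built above ('lst', 'pdf', 'num_bins', 'x_lim', 'y_lim',
# 'y_label', 'x_label', 'legend_labels') need to be honored; it illustrates how the configs
# could be rendered and does not reproduce the repository's actual plotting code.
import matplotlib.pyplot as plt


def sketch_plot_histograms(all_lists_of_histogram_pairs, out_path='histograms.pdf'):
    """Render one row of subplots per dataset, one subplot per histogram config."""
    num_rows = len(all_lists_of_histogram_pairs)
    num_cols = max(len(row) for row in all_lists_of_histogram_pairs)
    fig, axes = plt.subplots(num_rows, num_cols, squeeze=False,
                             figsize=(4 * num_cols, 3 * num_rows))
    for row_idx, hist_pairs in enumerate(all_lists_of_histogram_pairs):
        for col_idx, cfg in enumerate(hist_pairs):
            ax = axes[row_idx][col_idx]
            labels = cfg.get('legend_labels', [None] * len(cfg['lst']))
            for series, label in zip(cfg['lst'], labels):
                # density=True normalizes the histogram to a PDF when 'pdf' is requested
                ax.hist(series, bins=cfg.get('num_bins', 50),
                        density=cfg.get('pdf', False), histtype='step', label=label)
            if 'x_lim' in cfg:
                ax.set_xlim(0, cfg['x_lim'])
            if 'y_lim' in cfg:
                ax.set_ylim(0, cfg['y_lim'])
            ax.set_xlabel(cfg.get('x_label', ''))
            ax.set_ylabel(cfg.get('y_label', ''))
            if cfg.get('legend_labels'):
                ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, bbox_inches='tight')  # out_path is a placeholder, not a repo convention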