Example #1
    def text_generator(self, example_generator):
        """Generates article and abstract text from tf.Example.

        Args:
            example_generator: a generator of tf.Examples from file. See data.example_generator"""
        # i = 0

        while True:
            e = next(example_generator) # e is a tf.Example
            try:
                names_to_types = [('raw_article_sents', 'string_list'), ('similar_source_indices', 'delimited_list_of_tuples'), ('summary_text', 'string_list')]
                if self._hps.dataset_name == 'duc_2004':
                    names_to_types[2] = ('summary_text', 'string_list')

                raw_article_sents, ssi, groundtruth_summary_sents = util.unpack_tf_example(
                    e, names_to_types)
                groundtruth_summary_text = '\n'.join(groundtruth_summary_sents)
                article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
                article_text = ' '.join([' '.join(sent) for sent in article_sent_tokens])
                if self._hps.dataset_name == 'duc_2004':
                    # duc_2004 has multiple reference summaries, so iterate over the list of
                    # summaries rather than the characters of the joined text.
                    abstract_sentences = [['<s> ' + sent.strip() + ' </s>' for sent in
                                          gt_summ_text.strip().split('\n')] for gt_summ_text in groundtruth_summary_sents]
                    abstract_sentences = [abs_sents[:max_dec_sents] for abs_sents in abstract_sentences]
                    abstract_texts = [' '.join(abs_sents) for abs_sents in abstract_sentences]
                else:
                    abstract_sentences = ['<s> ' + sent.strip() + ' </s>' for sent in groundtruth_summary_text.strip().split('\n')]
                    abstract_sentences = abstract_sentences[:max_dec_sents]
                    abstract_texts = [' '.join(abstract_sentences)]
                if 'doc_indices' not in e.features.feature or len(e.features.feature['doc_indices'].bytes_list.value) == 0:
                    num_words = len(article_text.split())
                    doc_indices_text = '0 ' * num_words
                else:
                    doc_indices_text = e.features.feature['doc_indices'].bytes_list.value[0]
                sentence_limit = 1 if self._hps.singles_and_pairs == 'singles' else 2
                ssi = util.enforce_sentence_limit(ssi, sentence_limit)
                ssi = ssi[:max_dec_sents]
                ssi = util.make_ssi_chronological(ssi)
            except:
                logging.error('Failed to get article or abstract from example')
                raise
            if len(article_text)==0: # See https://github.com/abisee/pointer-generator/issues/1
                logging.warning('Found an example with empty article text. Skipping it.\n*********************************************')
            elif len(article_text.strip().split()) < 3 and self._hps.skip_with_less_than_3:
                print('Article has less than 3 tokens, so skipping\n*********************************************')
            elif len(abstract_texts[0].strip().split()) < 3 and self._hps.skip_with_less_than_3:
                print('Abstract has less than 3 tokens, so skipping\n*********************************************')
            else:
                yield (article_text, abstract_texts, doc_indices_text, raw_article_sents, ssi)
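The generator above relies on util.enforce_sentence_limit and util.make_ssi_chronological, which are not shown here. Minimal sketches of what they are assumed to do (assumptions, not the project's actual util module):

# Hypothetical stand-ins for two helpers used above; the real util module may differ.
def enforce_sentence_limit(ssi_list, sentence_limit):
    # Keep at most `sentence_limit` source-sentence indices per summary sentence.
    return [tuple(ssi[:sentence_limit]) for ssi in ssi_list]

def make_ssi_chronological(ssi_list):
    # Sort the indices inside each tuple so pairs always read (earlier, later).
    return [tuple(sorted(ssi)) for ssi in ssi_list]

print(enforce_sentence_limit([(4, 1, 7), (2,)], 2))  # [(4, 1), (2,)]
print(make_ssi_chronological([(4, 1), (2,)]))        # [(1, 4), (2,)]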
def write_to_lambdamart_examples_to_file(ex):
    example, example_idx, single_feat_len, pair_feat_len, singles_and_pairs = ex
    print(example_idx)
    # example_idx += 1
    temp_in_path = os.path.join(temp_in_dir, '%06d.txt' % example_idx)
    if not FLAGS.start_over and os.path.exists(temp_in_path):
        return
    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
        example, names_to_types)
    article_sent_tokens = [
        util.process_sent(sent) for sent in raw_article_sents
    ]
    if doc_indices is None:
        doc_indices = [0] * len(
            util.flatten_list_of_lists(article_sent_tokens))
    doc_indices = [int(doc_idx) for doc_idx in doc_indices]
    if len(doc_indices) != len(
            util.flatten_list_of_lists(article_sent_tokens)):
        doc_indices = [0] * len(
            util.flatten_list_of_lists(article_sent_tokens))
    rel_sent_indices, _, _ = get_rel_sent_indices(doc_indices,
                                                  article_sent_tokens)
    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
        groundtruth_similar_source_indices_list, sentence_limit)
    groundtruth_summ_sents = [[
        sent.strip() for sent in groundtruth_summary_text.strip().split('\n')
    ]]
    groundtruth_summ_sent_tokens = [
        sent.split(' ') for sent in groundtruth_summ_sents[0]
    ]
    # summ_sent_tokens = [sent.strip().split() for sent in summary_text.strip().split('\n')]

    if FLAGS.dataset_name == 'duc_2004':
        first_k_indices = get_indices_of_first_k_sents_of_each_article(
            rel_sent_indices, FLAGS.first_k)
    else:
        first_k_indices = [idx for idx in range(len(raw_article_sents))]

    if importance:
        get_instances(example_idx, raw_article_sents, article_sent_tokens,
                      corefs, rel_sent_indices, first_k_indices, temp_in_path,
                      temp_out_path, single_feat_len, pair_feat_len,
                      singles_and_pairs)
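get_rel_sent_indices is used above to map each article sentence to its position within its own source document. A rough sketch of how such a helper could work, assuming doc_indices holds one document id per article token (the real implementation may differ):

# Hypothetical sketch: recover each sentence's relative index within its own document.
def get_rel_sent_indices(doc_indices, article_sent_tokens):
    sent_doc_ids = []
    token_ptr = 0
    for sent in article_sent_tokens:
        sent_doc_ids.append(doc_indices[token_ptr])  # doc id of the sentence's first token
        token_ptr += len(sent)
    rel_sent_indices, counts = [], {}
    for doc_id in sent_doc_ids:
        rel_sent_indices.append(counts.get(doc_id, 0))
        counts[doc_id] = counts.get(doc_id, 0) + 1
    return rel_sent_indices, sent_doc_ids, counts

print(get_rel_sent_indices([0, 0, 0, 1, 1], [['a', 'b'], ['c'], ['d', 'e']])[0])  # [0, 1, 0]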
Example #3
def evaluate_example(ex):
    example, example_idx, qid_ssi_to_importances, _, _ = ex
    print(example_idx)

    # Read example from dataset
    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(example, names_to_types)
    article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
    enforced_groundtruth_ssi_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, sentence_limit)
    groundtruth_summ_sents = [[sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
    groundtruth_summ_sent_tokens = [sent.split(' ') for sent in groundtruth_summ_sents[0]]

    if FLAGS.upper_bound:
        # If upper bound, then get the groundtruth singletons/pairs
        replaced_ssi_list = util.replace_empty_ssis(enforced_groundtruth_ssi_list, raw_article_sents)
        selected_article_sent_indices = util.flatten_list_of_lists(replaced_ssi_list)
        summary_sents = [' '.join(sent) for sent in util.reorder(article_sent_tokens, selected_article_sent_indices)]
        similar_source_indices_list = groundtruth_similar_source_indices_list
        ssi_length_extractive = len(similar_source_indices_list)
    else:
        # Generates summary based on BERT output. This is an extractive summary.
        summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive = generate_summary(article_sent_tokens, qid_ssi_to_importances, example_idx)
        similar_source_indices_list_trunc = similar_source_indices_list[:ssi_length_extractive]
        summary_sents_for_html_trunc = summary_sents_for_html[:ssi_length_extractive]
        if example_idx <= 1:
            summary_sent_tokens = [sent.split(' ') for sent in summary_sents_for_html_trunc]
            extracted_sents_in_article_html = html_highlight_sents_in_article(summary_sent_tokens, similar_source_indices_list_trunc,
                                            article_sent_tokens, doc_indices=doc_indices)

            groundtruth_ssi_list, lcs_paths_list, article_lcs_paths_list = get_simple_source_indices_list(
                                            groundtruth_summ_sent_tokens,
                                           article_sent_tokens, None, sentence_limit, min_matched_tokens)
            groundtruth_highlighted_html = html_highlight_sents_in_article(groundtruth_summ_sent_tokens, groundtruth_ssi_list,
                                            article_sent_tokens, lcs_paths_list=lcs_paths_list, article_lcs_paths_list=article_lcs_paths_list, doc_indices=doc_indices)

            all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html + '<u>Groundtruth Summary</u><br><br>' + groundtruth_highlighted_html
            ssi_functions.write_highlighted_html(all_html, html_dir, example_idx)
    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents, example_idx, ref_dir, dec_dir)
    return (groundtruth_similar_source_indices_list, similar_source_indices_list, ssi_length_extractive)
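The upper-bound branch assembles its summary with util.flatten_list_of_lists and util.reorder. Minimal sketches of what those helpers are assumed to do:

# Hypothetical stand-ins; the real util functions may handle more cases.
def flatten_list_of_lists(list_of_lists):
    return [item for sublist in list_of_lists for item in sublist]

def reorder(items, indices):
    # Pick out `items` in the order given by `indices`.
    return [items[i] for i in indices]

article_sent_tokens = [['a', 'b'], ['c'], ['d', 'e']]
selected = flatten_list_of_lists([(2, 0), (1,)])                            # [2, 0, 1]
print([' '.join(sent) for sent in reorder(article_sent_tokens, selected)])  # ['d e', 'a b', 'c']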
Example #4
def main(unused_argv):

    print('Running statistics on %s' % FLAGS.dataset_name)

    if len(unused_argv) != 1: # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    source_dir = os.path.join(data_dir, FLAGS.dataset_name)
    source_files = sorted(glob.glob(source_dir + '/' + FLAGS.dataset_split + '*'))

    total = len(source_files) * 1000 if ('cnn' in FLAGS.dataset_name or 'newsroom' in FLAGS.dataset_name or 'xsum' in FLAGS.dataset_name) else len(source_files)
    example_generator = data.example_generator(source_dir + '/' + FLAGS.dataset_split + '*', True, False,
                                               should_check_valid=False)

    for example_idx, example in enumerate(tqdm(example_generator, total=total)):
        raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
            example, names_to_types)
        article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
        groundtruth_summ_sents = [[sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
        if doc_indices is None:
            doc_indices = [0] * len(util.flatten_list_of_lists(article_sent_tokens))
        doc_indices = [int(doc_idx) for doc_idx in doc_indices]
        rel_sent_indices, _, _ = preprocess_for_lambdamart_no_flags.get_rel_sent_indices(doc_indices, article_sent_tokens)
        groundtruth_similar_source_indices_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, FLAGS.sentence_limit)
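Every example above starts by calling util.unpack_tf_example with a names_to_types list. A simplified sketch of how such a reader could be written for the two simplest type tags; the real helper also supports tags like 'json' and 'delimited_list_of_tuples' and is only assumed here:

# Hypothetical sketch of a names_to_types-driven tf.Example reader.
def unpack_tf_example(example, names_to_types):
    def get_string(name):
        return example.features.feature[name].bytes_list.value[0].decode()

    def get_string_list(name):
        return [v.decode() for v in example.features.feature[name].bytes_list.value]

    readers = {'string': get_string, 'string_list': get_string_list}
    return [readers[typ](name) if name in example.features.feature else None
            for name, typ in names_to_types]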
Example #5
def main(unused_argv):

    print('Running statistics on %s' % FLAGS.dataset_name)

    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.singles_and_pairs == 'singles':
        FLAGS.sentence_limit = 1
    else:
        FLAGS.sentence_limit = 2

    if FLAGS.dataset_name == 'all':
        dataset_names = ['cnn_dm', 'xsum', 'duc_2004']
    else:
        dataset_names = [FLAGS.dataset_name]

    for dataset_name in dataset_names:
        FLAGS.dataset_name = dataset_name

        source_dir = os.path.join(data_dir, dataset_name)

        if FLAGS.dataset_split == 'all':
            if dataset_name == 'duc_2004':
                dataset_splits = ['test']
            else:
                # dataset_splits = ['val_test', 'test', 'val', 'train']
                dataset_splits = ['test', 'val', 'train']
        else:
            dataset_splits = [FLAGS.dataset_split]

        for dataset_split in dataset_splits:
            if dataset_split == 'val_test':
                source_dataset_split = 'val'
            else:
                source_dataset_split = dataset_split

            source_files = sorted(
                glob.glob(source_dir + '/' + source_dataset_split + '*'))

            total = len(source_files) * 1000
            example_generator = data.example_generator(
                source_dir + '/' + source_dataset_split + '*',
                True,
                False,
                should_check_valid=False)

            out_dir = os.path.join('data', 'bert', dataset_name,
                                   FLAGS.singles_and_pairs, 'input')
            util.create_dirs(out_dir)

            writer = open(os.path.join(out_dir, dataset_split) + '.tsv', 'wb')
            header_list = [
                'should_merge', 'sent1', 'sent2', 'example_idx', 'inst_id',
                'ssi'
            ]
            writer.write(('\t'.join(header_list) + '\n').encode())
            inst_id = 0
            for example_idx, example in enumerate(
                    tqdm(example_generator, total=total)):
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(
                    example, names_to_types)
                article_sent_tokens = [
                    util.process_sent(sent, whitespace=True)
                    for sent in raw_article_sents
                ]
                groundtruth_summ_sents = [[
                    sent.strip()
                    for sent in groundtruth_summary_text.strip().split('\n')
                ]]
                if dataset_name != 'duc_2004' or doc_indices is None or (
                        dataset_name != 'duc_2004' and len(doc_indices) != len(
                            util.flatten_list_of_lists(article_sent_tokens))):
                    doc_indices = [0] * len(
                        util.flatten_list_of_lists(article_sent_tokens))
                doc_indices = [int(doc_idx) for doc_idx in doc_indices]
                rel_sent_indices, _, _ = ssi_functions.get_rel_sent_indices(
                    doc_indices, article_sent_tokens)
                similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list,
                    FLAGS.sentence_limit)

                possible_pairs = list(
                    itertools.combinations(range(len(raw_article_sents)),
                                           2))  # all pairs
                possible_pairs = filter_pairs_by_sent_position(
                    possible_pairs, rel_sent_indices=rel_sent_indices)
                possible_singles = [(i, )
                                    for i in range(len(raw_article_sents))]
                positives = [ssi for ssi in similar_source_indices_list]

                if dataset_split == 'test' or dataset_split == 'val_test':
                    if FLAGS.singles_and_pairs == 'singles':
                        possible_combinations = possible_singles
                    else:
                        possible_combinations = possible_pairs + possible_singles
                    negatives = [
                        ssi for ssi in possible_combinations
                        if not (ssi in positives or ssi[::-1] in positives)
                    ]

                    for ssi_idx, ssi in enumerate(positives):
                        if len(ssi) == 0:
                            continue
                        if chronological_ssi and len(ssi) >= 2:
                            if ssi[0] > ssi[1]:
                                ssi = (min(ssi), max(ssi))
                        writer.write(
                            get_string_bert_example(raw_article_sents, ssi, 1,
                                                    example_idx,
                                                    inst_id).encode())
                        inst_id += 1
                    for ssi in negatives:
                        writer.write(
                            get_string_bert_example(raw_article_sents, ssi, 0,
                                                    example_idx,
                                                    inst_id).encode())
                        inst_id += 1

                else:
                    positive_sents = list(
                        set(util.flatten_list_of_lists(positives)))
                    negative_pairs = [
                        pair for pair in possible_pairs
                        if not any(i in positive_sents for i in pair)
                    ]
                    negative_singles = [
                        sing for sing in possible_singles
                        if not sing[0] in positive_sents
                    ]
                    random_negative_pairs = np.random.permutation(
                        len(negative_pairs)).tolist()
                    random_negative_singles = np.random.permutation(
                        len(negative_singles)).tolist()

                    for ssi in similar_source_indices_list:
                        if len(ssi) == 0:
                            continue
                        if chronological_ssi and len(ssi) >= 2:
                            if ssi[0] > ssi[1]:
                                ssi = (min(ssi), max(ssi))
                        is_pair = len(ssi) == 2
                        writer.write(
                            get_string_bert_example(raw_article_sents, ssi, 1,
                                                    example_idx,
                                                    inst_id).encode())
                        inst_id += 1

                        # False sentence single/pair
                        if is_pair:
                            if len(random_negative_pairs) == 0:
                                continue
                            negative_indices = negative_pairs[
                                random_negative_pairs.pop()]
                        else:
                            if len(random_negative_singles) == 0:
                                continue
                            negative_indices = negative_singles[
                                random_negative_singles.pop()]
                        article_lcs_paths = None
                        writer.write(
                            get_string_bert_example(raw_article_sents,
                                                    negative_indices, 0,
                                                    example_idx,
                                                    inst_id).encode())
                        inst_id += 1
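get_string_bert_example is assumed to emit one TSV row per candidate, matching the header written above (should_merge, sent1, sent2, example_idx, inst_id, ssi). A minimal sketch under that assumption:

# Hypothetical sketch of get_string_bert_example; the real formatting may differ.
def get_string_bert_example(raw_article_sents, ssi, should_merge, example_idx, inst_id):
    sent1 = raw_article_sents[ssi[0]]
    sent2 = raw_article_sents[ssi[1]] if len(ssi) == 2 else ''
    ssi_str = ' '.join(str(i) for i in ssi)
    return '\t'.join([str(should_merge), sent1, sent2,
                      str(example_idx), str(inst_id), ssi_str]) + '\n'

print(get_string_bert_example(['First sent .', 'Second sent .'], (0, 1), 1, 0, 0))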
Example #6
def main(unused_argv):

    print('Running statistics on %s' % FLAGS.dataset_name)

    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.summarizer == 'all':
        summary_methods = list(summarizers.keys())
    else:
        summary_methods = [FLAGS.summarizer]
    if FLAGS.dataset_name == 'all':
        dataset_names = datasets
    else:
        dataset_names = [FLAGS.dataset_name]

    sheets_strs = []
    for summary_method in summary_methods:
        summary_fn = summarizers[summary_method]
        for dataset_name in dataset_names:
            FLAGS.dataset_name = dataset_name

            original_dataset_name = 'xsum' if 'xsum' in dataset_name else 'cnn_dm' if 'cnn_dm' in dataset_name or 'duc_2004' in dataset_name else ''
            vocab = Vocab('logs/vocab' + '_' + original_dataset_name,
                          50000)  # create a vocabulary

            source_dir = os.path.join(data_dir, dataset_name)
            source_files = sorted(
                glob.glob(source_dir + '/' + FLAGS.dataset_split + '*'))

            total = len(source_files) * 1000 if (
                'cnn' in dataset_name or 'newsroom' in dataset_name
                or 'xsum' in dataset_name) else len(source_files)
            example_generator = data.example_generator(
                source_dir + '/' + FLAGS.dataset_split + '*',
                True,
                False,
                should_check_valid=False)

            if dataset_name == 'duc_2004':
                abs_source_dir = os.path.join(
                    os.path.expanduser('~') + '/data/tf_data/with_coref',
                    dataset_name)
                abs_example_generator = data.example_generator(
                    abs_source_dir + '/' + FLAGS.dataset_split + '*',
                    True,
                    False,
                    should_check_valid=False)
                abs_names_to_types = [('abstract', 'string_list')]

            triplet_ssi_list = []
            for example_idx, example in enumerate(
                    tqdm(example_generator, total=total)):
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
                    example, names_to_types)
                if dataset_name == 'duc_2004':
                    abs_example = next(abs_example_generator)
                    groundtruth_summary_texts = util.unpack_tf_example(
                        abs_example, abs_names_to_types)
                    groundtruth_summary_texts = groundtruth_summary_texts[0]
                    groundtruth_summ_sents_list = [[
                        sent.strip() for sent in data.abstract2sents(abstract)
                    ] for abstract in groundtruth_summary_texts]

                else:
                    groundtruth_summary_texts = [groundtruth_summary_text]
                    groundtruth_summ_sents_list = []
                    for groundtruth_summary_text in groundtruth_summary_texts:
                        groundtruth_summ_sents = [
                            sent.strip() for sent in
                            groundtruth_summary_text.strip().split('\n')
                        ]
                        groundtruth_summ_sents_list.append(
                            groundtruth_summ_sents)
                article_sent_tokens = [
                    util.process_sent(sent) for sent in raw_article_sents
                ]
                if doc_indices is None:
                    doc_indices = [0] * len(
                        util.flatten_list_of_lists(article_sent_tokens))
                doc_indices = [int(doc_idx) for doc_idx in doc_indices]
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list,
                    FLAGS.sentence_limit)

                log_dir = os.path.join(log_root,
                                       dataset_name + '_' + summary_method)
                dec_dir = os.path.join(log_dir, 'decoded')
                ref_dir = os.path.join(log_dir, 'reference')
                util.create_dirs(dec_dir)
                util.create_dirs(ref_dir)

                parser = PlaintextParser.from_string(
                    ' '.join(raw_article_sents), Tokenizer("english"))
                summarizer = summary_fn()

                summary = summarizer(
                    parser.document,
                    5)  #Summarize the document with 5 sentences
                summary = [str(sentence) for sentence in summary]

                summary_tokenized = []
                for sent in summary:
                    summary_tokenized.append(sent.lower())

                rouge_functions.write_for_rouge(groundtruth_summ_sents_list,
                                                summary_tokenized,
                                                example_idx,
                                                ref_dir,
                                                dec_dir,
                                                log=False)

                decoded_sent_tokens = [
                    sent.split() for sent in summary_tokenized
                ]
                sentence_limit = 2
                sys_ssi_list, _, _ = get_simple_source_indices_list(
                    decoded_sent_tokens, article_sent_tokens, vocab,
                    sentence_limit, min_matched_tokens)
                triplet_ssi_list.append(
                    (groundtruth_similar_source_indices_list, sys_ssi_list,
                     -1))

            print('Evaluating Lambdamart model F1 score...')
            suffix = util.all_sent_selection_eval(triplet_ssi_list)
            print(suffix)

            results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir)
            print(("Results_dict: ", results_dict))
            sheets_str = rouge_functions.rouge_log(results_dict,
                                                   log_dir,
                                                   suffix=suffix)
            sheets_strs.append(dataset_name + '_' + summary_method + '\n' +
                               sheets_str)

    for sheets_str in sheets_strs:
        print(sheets_str + '\n')
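util.all_sent_selection_eval is assumed to score how well each system's selected source sentences match the groundtruth SSIs. A small sketch of that kind of selection F1, purely as an illustration:

# Hypothetical sketch of a per-example sentence-selection F1.
def sent_selection_f1(gt_ssi_list, sys_ssi_list):
    gt = set(idx for ssi in gt_ssi_list for idx in ssi)
    sys = set(idx for ssi in sys_ssi_list for idx in ssi)
    if not gt or not sys:
        return 0.0
    precision = len(gt & sys) / len(sys)
    recall = len(gt & sys) / len(gt)
    return 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)

print(sent_selection_f1([(0, 3), (5,)], [(0,), (4, 5)]))  # ~0.667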
def convert_article_to_lambdamart_features(ex):
    # example_idx += 1
    # if num_instances != -1 and example_idx >= num_instances:
    #     break
    example, example_idx, single_feat_len, pair_feat_len, singles_and_pairs, out_path = ex
    print(example_idx)
    raw_article_sents, similar_source_indices_list, summary_text, corefs, doc_indices = util.unpack_tf_example(
        example, names_to_types)
    article_sent_tokens = [
        util.process_sent(sent) for sent in raw_article_sents
    ]
    if doc_indices is None:
        doc_indices = [0] * len(
            util.flatten_list_of_lists(article_sent_tokens))
    doc_indices = [int(doc_idx) for doc_idx in doc_indices]
    if len(doc_indices) != len(
            util.flatten_list_of_lists(article_sent_tokens)):
        doc_indices = [0] * len(
            util.flatten_list_of_lists(article_sent_tokens))
    rel_sent_indices, _, _ = util.get_rel_sent_indices(doc_indices,
                                                       article_sent_tokens)
    if FLAGS.singles_and_pairs == 'singles':
        sentence_limit = 1
    else:
        sentence_limit = 2
    similar_source_indices_list = util.enforce_sentence_limit(
        similar_source_indices_list, sentence_limit)
    summ_sent_tokens = [
        sent.strip().split() for sent in summary_text.strip().split('\n')
    ]

    # sent_term_matrix = util.get_tfidf_matrix(raw_article_sents)
    article_text = ' '.join(raw_article_sents)
    sent_term_matrix = util.get_doc_substituted_tfidf_matrix(
        tfidf_vectorizer, raw_article_sents, article_text, pca)
    doc_vector = np.mean(sent_term_matrix, axis=0)

    out_str = ''
    # ssi_idx_cur_inst_id = defaultdict(int)
    instances = []

    if importance:
        importances = util.special_squash(
            util.get_tfidf_importances(tfidf_vectorizer, raw_article_sents,
                                       pca))
        possible_pairs = list(
            itertools.combinations(range(len(raw_article_sents)),
                                   2))  # all pairs
        if FLAGS.use_pair_criteria:
            possible_pairs = filter_pairs_by_criteria(raw_article_sents,
                                                      possible_pairs, corefs)
        if FLAGS.sent_position_criteria:
            possible_pairs = filter_pairs_by_sent_position(
                possible_pairs, rel_sent_indices)
        possible_singles = [(i, ) for i in range(len(raw_article_sents))]
        possible_combinations = possible_pairs + possible_singles
        positives = [ssi for ssi in similar_source_indices_list]
        negatives = [
            ssi for ssi in possible_combinations
            if not (ssi in positives or ssi[::-1] in positives)
        ]

        negative_pairs = [
            x for x in possible_pairs
            if not (x in similar_source_indices_list
                    or x[::-1] in similar_source_indices_list)
        ]
        negative_singles = [
            x for x in possible_singles
            if not (x in similar_source_indices_list
                    or x[::-1] in similar_source_indices_list)
        ]
        random_negative_pairs = np.random.permutation(
            len(negative_pairs)).tolist()
        random_negative_singles = np.random.permutation(
            len(negative_singles)).tolist()

        qid = example_idx
        for similar_source_indices in positives:
            # True sentence single/pair
            relevance = 1
            features = get_features(similar_source_indices, sent_term_matrix,
                                    article_sent_tokens, rel_sent_indices,
                                    single_feat_len, pair_feat_len,
                                    importances, singles_and_pairs)
            if features is None:
                continue
            instances.append(
                Lambdamart_Instance(features, relevance, qid,
                                    similar_source_indices))
            a = 0

            if FLAGS.dataset_name == 'xsum' and FLAGS.special_xsum_balance:
                neg_relevance = 0
                num_negative = 4
                if FLAGS.singles_and_pairs == 'singles':
                    num_neg_singles = num_negative
                    num_neg_pairs = 0
                else:
                    num_neg_singles = num_negative // 2  # integer division: range() below needs an int
                    num_neg_pairs = num_negative // 2
                for _ in range(num_neg_singles):
                    if len(random_negative_singles) == 0:
                        continue
                    negative_indices = negative_singles[
                        random_negative_singles.pop()]
                    neg_features = get_features(negative_indices,
                                                sent_term_matrix,
                                                article_sent_tokens,
                                                rel_sent_indices,
                                                single_feat_len, pair_feat_len,
                                                importances, singles_and_pairs)
                    if neg_features is None:
                        continue
                    instances.append(
                        Lambdamart_Instance(neg_features, neg_relevance, qid,
                                            negative_indices))
                for _ in range(num_neg_pairs):
                    if len(random_negative_pairs) == 0:
                        continue
                    negative_indices = negative_pairs[
                        random_negative_pairs.pop()]
                    neg_features = get_features(negative_indices,
                                                sent_term_matrix,
                                                article_sent_tokens,
                                                rel_sent_indices,
                                                single_feat_len, pair_feat_len,
                                                importances, singles_and_pairs)
                    if neg_features is None:
                        continue
                    instances.append(
                        Lambdamart_Instance(neg_features, neg_relevance, qid,
                                            negative_indices))
            elif balance:
                # False sentence single/pair
                is_pair = len(similar_source_indices) == 2
                if is_pair:
                    if len(random_negative_pairs) == 0:
                        continue
                    negative_indices = negative_pairs[
                        random_negative_pairs.pop()]
                else:
                    if len(random_negative_singles) == 0:
                        continue
                    negative_indices = negative_singles[
                        random_negative_singles.pop()]
                neg_relevance = 0
                neg_features = get_features(negative_indices, sent_term_matrix,
                                            article_sent_tokens,
                                            rel_sent_indices, single_feat_len,
                                            pair_feat_len, importances,
                                            singles_and_pairs)
                if neg_features is None:
                    continue
                instances.append(
                    Lambdamart_Instance(neg_features, neg_relevance, qid,
                                        negative_indices))
        if not balance:
            for negative_indices in negatives:
                neg_relevance = 0
                neg_features = get_features(negative_indices, sent_term_matrix,
                                            article_sent_tokens,
                                            rel_sent_indices, single_feat_len,
                                            pair_feat_len, importances,
                                            singles_and_pairs)
                if neg_features is None:
                    continue
                instances.append(
                    Lambdamart_Instance(neg_features, neg_relevance, qid,
                                        negative_indices))

    sorted_instances = sorted(instances,
                              key=lambda x: (x.qid, x.source_indices))
    assign_inst_ids(sorted_instances)
    if FLAGS.lr:
        return sorted_instances
    else:
        for instance in sorted_instances:
            lambdamart_str = format_to_lambdamart(instance, single_feat_len)
            out_str += lambdamart_str + '\n'
        with open(os.path.join(out_path, '%06d.txt' % example_idx), 'wb') as f:
            f.write(out_str.encode())
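format_to_lambdamart is assumed to serialize each Lambdamart_Instance into the usual RankLib/LambdaMART line format, "relevance qid:<qid> 1:<f1> 2:<f2> ... # comment". A sketch under that assumption (the instance fields used here are guesses):

# Hypothetical sketch of format_to_lambdamart; single_feat_len is unused in this simplified version.
def format_to_lambdamart(instance, single_feat_len):
    feature_str = ' '.join('%d:%f' % (i + 1, feat)
                           for i, feat in enumerate(instance.features))
    comment = ' '.join(str(idx) for idx in instance.source_indices)
    return '%d qid:%d %s # %s' % (instance.relevance, instance.qid, feature_str, comment)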
def evaluate_example(ex):
    example, example_idx, qid_ssi_to_importances, _, _ = ex
    print(example_idx)
    # example_idx += 1
    qid = example_idx
    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
        example, names_to_types)
    article_sent_tokens = [
        util.process_sent(sent) for sent in raw_article_sents
    ]
    enforced_groundtruth_ssi_list = util.enforce_sentence_limit(
        groundtruth_similar_source_indices_list, sentence_limit)
    if FLAGS.dataset_name == 'duc_2004':
        groundtruth_summ_sents = [[
            sent.strip() for sent in gt_summ_text.strip().split('\n')
        ] for gt_summ_text in groundtruth_summary_text]
    else:
        groundtruth_summ_sents = [[
            sent.strip()
            for sent in groundtruth_summary_text.strip().split('\n')
        ]]
    groundtruth_summ_sent_tokens = [
        sent.split(' ') for sent in groundtruth_summ_sents[0]
    ]

    if FLAGS.upper_bound:
        replaced_ssi_list = util.replace_empty_ssis(
            enforced_groundtruth_ssi_list, raw_article_sents)
        selected_article_sent_indices = util.flatten_list_of_lists(
            replaced_ssi_list)
        summary_sents = [
            ' '.join(sent) for sent in util.reorder(
                article_sent_tokens, selected_article_sent_indices)
        ]
        similar_source_indices_list = groundtruth_similar_source_indices_list
        ssi_length_extractive = len(similar_source_indices_list)
    elif FLAGS.lead:
        lead_ssi_list = [(idx, ) for idx in list(
            range(util.average_sents_for_dataset[FLAGS.dataset_name]))]
        # make sure the sentence indices don't go past the total number of sentences in the article
        lead_ssi_list = lead_ssi_list[:len(raw_article_sents)]
        selected_article_sent_indices = util.flatten_list_of_lists(
            lead_ssi_list)
        summary_sents = [
            ' '.join(sent) for sent in util.reorder(
                article_sent_tokens, selected_article_sent_indices)
        ]
        similar_source_indices_list = lead_ssi_list
        ssi_length_extractive = len(similar_source_indices_list)
    else:
        summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive = generate_summary(
            article_sent_tokens, qid_ssi_to_importances, example_idx)
        similar_source_indices_list_trunc = similar_source_indices_list[:ssi_length_extractive]
        summary_sents_for_html_trunc = summary_sents_for_html[:ssi_length_extractive]
        if example_idx <= 100:
            summary_sent_tokens = [
                sent.split(' ') for sent in summary_sents_for_html_trunc
            ]
            extracted_sents_in_article_html = html_highlight_sents_in_article(
                summary_sent_tokens,
                similar_source_indices_list_trunc,
                article_sent_tokens,
                doc_indices=doc_indices)
            # write_highlighted_html(extracted_sents_in_article_html, html_dir, example_idx)

            groundtruth_ssi_list, lcs_paths_list, article_lcs_paths_list = get_simple_source_indices_list(
                groundtruth_summ_sent_tokens, article_sent_tokens, None,
                sentence_limit, min_matched_tokens)
            groundtruth_highlighted_html = html_highlight_sents_in_article(
                groundtruth_summ_sent_tokens,
                groundtruth_ssi_list,
                article_sent_tokens,
                lcs_paths_list=lcs_paths_list,
                article_lcs_paths_list=article_lcs_paths_list,
                doc_indices=doc_indices)
            all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html + '<u>Groundtruth Summary</u><br><br>' + groundtruth_highlighted_html
            write_highlighted_html(all_html, html_dir, example_idx)
    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents,
                                    example_idx, ref_dir, dec_dir)
    return (groundtruth_similar_source_indices_list,
            similar_source_indices_list, ssi_length_extractive)
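generate_summary is not shown; it is assumed to rank the candidate singletons/pairs by the importance scores stored for each example and greedily extract until a word budget is hit. A minimal sketch of that idea, with invented names and a 100-word budget as assumptions:

# Hypothetical sketch of a greedy importance-based extractor.
def greedy_generate_summary(article_sent_tokens, ssi_to_importance, max_words=100):
    summary_sents, chosen_ssis, num_words = [], [], 0
    for ssi, _ in sorted(ssi_to_importance.items(), key=lambda kv: -kv[1]):
        sent = ' '.join(word for idx in ssi for word in article_sent_tokens[idx])
        summary_sents.append(sent)
        chosen_ssis.append(ssi)
        num_words += len(sent.split())
        if num_words >= max_words:
            break
    return summary_sents, chosen_ssis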
Example #9
    def text_generator(self, example_generator):
        """Generates article and abstract text from tf.Example.

        Args:
            example_generator: a generator of tf.Examples from file. See data.example_generator"""
        # i = 0
        while True:
            # i += 1
            e = next(example_generator)  # e is a tf.Example
            abstract_texts = []
            raw_article_sents = []
            # if self._hps.pg_mmr or '_sent' in self._hps.dataset_name:
            try:
                # names_to_types = [('raw_article_sents', 'string_list'), ('similar_source_indices', 'delimited_list_of_tuples'), ('summary_text', 'string'), ('corefs', 'json'), ('article_lcs_paths_list', 'delimited_list_of_list_of_lists')]
                names_to_types = [('raw_article_sents', 'string_list'),
                                  ('similar_source_indices',
                                   'delimited_list_of_tuples'),
                                  ('summary_text', 'string_list'),
                                  ('corefs', 'json'),
                                  ('article_lcs_paths_list',
                                   'delimited_list_of_list_of_lists')]
                if self._hps.dataset_name == 'duc_2004':
                    names_to_types[2] = ('summary_text', 'string_list')

                # raw_article_sents, ssi, groundtruth_summary_text, corefs, article_lcs_paths_list = util.unpack_tf_example(
                #     e, names_to_types)
                raw_article_sents, ssi, groundtruth_summary_sents, corefs, article_lcs_paths_list = util.unpack_tf_example(
                    e, names_to_types)
                groundtruth_summary_text = '\n'.join(groundtruth_summary_sents)
                article_sent_tokens = [
                    util.process_sent(sent) for sent in raw_article_sents
                ]
                article_text = ' '.join(
                    [' '.join(sent) for sent in article_sent_tokens])
                if self._hps.dataset_name == 'duc_2004':
                    # duc_2004 has multiple reference summaries, so iterate over the list of
                    # summaries rather than the characters of the joined text.
                    abstract_sentences = [[
                        '<s> ' + sent.strip() + ' </s>'
                        for sent in gt_summ_text.strip().split('\n')
                    ] for gt_summ_text in groundtruth_summary_sents]
                    abstract_sentences = [
                        abs_sents[:max_dec_sents]
                        for abs_sents in abstract_sentences
                    ]
                    abstract_texts = [
                        ' '.join(abs_sents) for abs_sents in abstract_sentences
                    ]
                else:
                    abstract_sentences = [
                        '<s> ' + sent.strip() + ' </s>' for sent in
                        groundtruth_summary_text.strip().split('\n')
                    ]
                    abstract_sentences = abstract_sentences[:max_dec_sents]
                    abstract_texts = [' '.join(abstract_sentences)]
                if 'doc_indices' not in e.features.feature or len(
                        e.features.feature['doc_indices'].bytes_list.value
                ) == 0:
                    num_words = len(article_text.split())
                    doc_indices_text = '0 ' * num_words
                else:
                    doc_indices_text = e.features.feature[
                        'doc_indices'].bytes_list.value[0]
                sentence_limit = 1 if self._hps.singles_and_pairs == 'singles' else 2
                ssi = util.enforce_sentence_limit(ssi, sentence_limit)
                ssi = ssi[:max_dec_sents]
                article_lcs_paths_list = util.enforce_sentence_limit(
                    article_lcs_paths_list, sentence_limit)
                article_lcs_paths_list = article_lcs_paths_list[:max_dec_sents]
                ssi, article_lcs_paths_list = util.make_ssi_chronological(
                    ssi, article_lcs_paths_list)
            except:
                logging.error('Failed to get article or abstract from example')
                raise
                # continue
            # else:
            #     try:
            #         article_text = e.features.feature['article'].bytes_list.value[0] # the article text was saved under the key 'article' in the data files
            #         for abstract in e.features.feature['abstract'].bytes_list.value:
            #             abstract_texts.append(abstract) # the abstract text was saved under the key 'abstract' in the data files
            #         if 'doc_indices' not in e.features.feature or len(e.features.feature['doc_indices'].bytes_list.value) == 0:
            #             num_words = len(article_text.split())
            #             doc_indices_text = '0 ' * num_words
            #         else:
            #             doc_indices_text = e.features.feature['doc_indices'].bytes_list.value[0]
            #         for sent in e.features.feature['raw_article_sents'].bytes_list.value:
            #             raw_article_sents.append(sent) # the abstract text was saved under the key 'abstract' in the data files
            #         ssi = None
            #         article_lcs_paths_list = None
            #     except ValueError:
            #         logging.error('Failed to get article or abstract from example\n*********************************************')
            #         raise
            #         # continue
            if len(article_text) == 0:  # See https://github.com/abisee/pointer-generator/issues/1
                logging.warning(
                    'Found an example with empty article text. Skipping it.\n*********************************************'
                )
            elif len(article_text.strip().split()) < 3 and self._hps.skip_with_less_than_3:
                print(
                    'Article has less than 3 tokens, so skipping\n*********************************************'
                )
            elif len(abstract_texts[0].strip().split()) < 3 and self._hps.skip_with_less_than_3:
                print(
                    'Abstract has less than 3 tokens, so skipping\n*********************************************'
                )
            else:
                # print i
                yield (article_text, abstract_texts, doc_indices_text,
                       raw_article_sents, ssi, article_lcs_paths_list)
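A small, self-contained illustration of the <s> ... </s> wrapping produced above and one way a consumer might split the abstract text back into sentences (the unwrapping here is an assumption, not the project's actual data.abstract2sents):

sents = ['the first sentence .', 'the second sentence .']
abstract_text = ' '.join('<s> ' + sent.strip() + ' </s>' for sent in sents)
recovered = [chunk.strip() for chunk in abstract_text.replace('</s>', '').split('<s>')
             if chunk.strip()]
assert recovered == sents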
Example #10
def main(unused_argv):

    print('Running statistics on %s' % FLAGS.dataset_name)

    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    out_dir = os.path.join(
        os.path.expanduser('~') + '/data/kaiqiang_data', FLAGS.dataset_name)
    if FLAGS.mode == 'write':
        util.create_dirs(out_dir)
        if FLAGS.dataset_name == 'duc_2004':
            dataset_splits = ['test']
        elif FLAGS.dataset_split == 'all':
            dataset_splits = ['test', 'val', 'train']
        else:
            dataset_splits = [FLAGS.dataset_split]

        for dataset_split in dataset_splits:

            if dataset_split == 'test':
                ssi_data_path = os.path.join(
                    'logs/%s_bert_both_sentemb_artemb_plushidden' %
                    FLAGS.dataset_name, 'ssi.pkl')
                print(util.bcolors.OKGREEN +
                      "Loading SSI from BERT at %s" % ssi_data_path +
                      util.bcolors.ENDC)
                with open(ssi_data_path, 'rb') as f:  # pickle files must be opened in binary mode
                    ssi_triple_list = pickle.load(f)

            source_dir = os.path.join(data_dir, FLAGS.dataset_name)
            source_files = sorted(
                glob.glob(source_dir + '/' + dataset_split + '*'))

            total = len(source_files) * 1000 if (
                'cnn' in FLAGS.dataset_name or 'newsroom' in FLAGS.dataset_name
                or 'xsum' in FLAGS.dataset_name) else len(source_files)
            example_generator = data.example_generator(
                source_dir + '/' + dataset_split + '*',
                True,
                False,
                should_check_valid=False)

            out_document_path = os.path.join(out_dir,
                                             dataset_split + '.Ndocument')
            out_summary_path = os.path.join(out_dir,
                                            dataset_split + '.Nsummary')
            out_example_idx_path = os.path.join(out_dir,
                                                dataset_split + '.Nexampleidx')

            doc_writer = open(out_document_path, 'w')
            if dataset_split != 'test':
                sum_writer = open(out_summary_path, 'w')
            ex_idx_writer = open(out_example_idx_path, 'w')

            for example_idx, example in enumerate(
                    tqdm(example_generator, total=total)):
                if FLAGS.num_instances != -1 and example_idx >= FLAGS.num_instances:
                    break
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(
                    example, names_to_types)
                article_sent_tokens = [
                    util.process_sent(sent) for sent in raw_article_sents
                ]
                if FLAGS.dataset_name == 'duc_2004':
                    groundtruth_summ_sents = [[
                        sent.strip()
                        for sent in gt_summ_text.strip().split('\n')
                    ] for gt_summ_text in groundtruth_summary_text]
                else:
                    groundtruth_summ_sents = [[
                        sent.strip() for sent in
                        groundtruth_summary_text.strip().split('\n')
                    ]]
                if doc_indices is None:
                    doc_indices = [0] * len(
                        util.flatten_list_of_lists(article_sent_tokens))
                doc_indices = [int(doc_idx) for doc_idx in doc_indices]
                # rel_sent_indices, _, _ = preprocess_for_lambdamart_no_flags.get_rel_sent_indices(doc_indices, article_sent_tokens)

                if dataset_split == 'test':
                    if example_idx >= len(ssi_triple_list):
                        raise Exception(
                            'Len of ssi list (%d) is less than number of examples (>=%d)'
                            % (len(ssi_triple_list), example_idx))
                    ssi_length_extractive = ssi_triple_list[example_idx][2]
                    if ssi_length_extractive > 1:
                        a = 0
                    ssi = ssi_triple_list[example_idx][1]
                    ssi = ssi[:ssi_length_extractive]
                    groundtruth_similar_source_indices_list = ssi
                else:
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                        groundtruth_similar_source_indices_list,
                        FLAGS.sentence_limit)

                for ssi_idx, ssi in enumerate(
                        groundtruth_similar_source_indices_list):
                    if len(ssi) == 0:
                        continue
                    my_article = ' '.join(util.reorder(raw_article_sents, ssi))
                    doc_writer.write(my_article + '\n')
                    if dataset_split != 'test':
                        sum_writer.write(groundtruth_summ_sents[0][ssi_idx] +
                                         '\n')
                    ex_idx_writer.write(str(example_idx) + '\n')
    elif FLAGS.mode == 'evaluate':
        summary_dir = '/home/logan/data/kaiqiang_data/logan_ACL/trained_on_' + FLAGS.train_dataset + '/' + FLAGS.dataset_name
        out_summary_path = os.path.join(summary_dir, 'test' + 'Summary.txt')
        out_example_idx_path = os.path.join(out_dir, 'test' + '.Nexampleidx')
        decode_dir = 'logs/kaiqiang_%s_trainedon%s' % (FLAGS.dataset_name,
                                                       FLAGS.train_dataset)
        rouge_ref_dir = os.path.join(decode_dir, 'reference')
        rouge_dec_dir = os.path.join(decode_dir, 'decoded')
        util.create_dirs(rouge_ref_dir)
        util.create_dirs(rouge_dec_dir)

        def num_lines_in_file(file_path):
            with open(file_path) as f:
                num_lines = sum(1 for line in f)
            return num_lines

        def process_example(sents, ex_idx, groundtruth_summ_sents):
            final_decoded_words = []
            for sent in sents:
                final_decoded_words.extend(sent.split(' '))
            rouge_functions.write_for_rouge(groundtruth_summ_sents,
                                            None,
                                            ex_idx,
                                            rouge_ref_dir,
                                            rouge_dec_dir,
                                            decoded_words=final_decoded_words,
                                            log=False)

        num_lines_summary = num_lines_in_file(out_summary_path)
        num_lines_example_indices = num_lines_in_file(out_example_idx_path)
        if num_lines_summary != num_lines_example_indices:
            raise Exception(
                'Num lines summary != num lines example indices: (%d, %d)' %
                (num_lines_summary, num_lines_example_indices))

        source_dir = os.path.join(data_dir, FLAGS.dataset_name)
        example_generator = data.example_generator(source_dir + '/' + 'test' +
                                                   '*',
                                                   True,
                                                   False,
                                                   should_check_valid=False)

        sum_reader = open(out_summary_path)  # opened for reading, despite being written earlier
        ex_idx_reader = open(out_example_idx_path)
        prev_ex_idx = 0
        sents = []

        for line_idx in tqdm(range(num_lines_summary)):
            line = sum_reader.readline()
            ex_idx = int(ex_idx_reader.readline())

            if ex_idx == prev_ex_idx:
                sents.append(line)
            else:
                example = next(example_generator)
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(
                    example, names_to_types)
                if FLAGS.dataset_name == 'duc_2004':
                    groundtruth_summ_sents = [[
                        sent.strip()
                        for sent in gt_summ_text.strip().split('\n')
                    ] for gt_summ_text in groundtruth_summary_text]
                else:
                    groundtruth_summ_sents = [[
                        sent.strip() for sent in
                        groundtruth_summary_text.strip().split('\n')
                    ]]
                process_example(sents, ex_idx, groundtruth_summ_sents)
                prev_ex_idx = ex_idx
                sents = [line]

        example = next(example_generator)
        raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(
            example, names_to_types)
        if FLAGS.dataset_name == 'duc_2004':
            groundtruth_summ_sents = [[
                sent.strip() for sent in gt_summ_text.strip().split('\n')
            ] for gt_summ_text in groundtruth_summary_text]
        else:
            groundtruth_summ_sents = [[
                sent.strip()
                for sent in groundtruth_summary_text.strip().split('\n')
            ]]
        process_example(sents, ex_idx, groundtruth_summ_sents)

        print("Now starting ROUGE eval...")
        if FLAGS.dataset_name == 'xsum':
            l_param = 100
        else:
            l_param = 100
        results_dict = rouge_functions.rouge_eval(rouge_ref_dir,
                                                  rouge_dec_dir,
                                                  l_param=l_param)
        rouge_functions.rouge_log(results_dict, decode_dir)

    else:
        raise Exception('mode flag was not evaluate or write.')
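The evaluate branch pairs each summary line with an example index by hand, assuming the two files stay line-aligned. The same grouping can be sketched with itertools.groupby, shown here only as an illustration of what the loop above is doing:

import itertools

def group_summary_lines(summary_lines, example_indices):
    # Collect consecutive summary lines that share the same example index.
    pairs = zip(example_indices, summary_lines)
    for ex_idx, group in itertools.groupby(pairs, key=lambda pair: pair[0]):
        yield ex_idx, [line for _, line in group]

for ex_idx, sents in group_summary_lines(['s1', 's2', 's3'], [0, 0, 1]):
    print(ex_idx, sents)  # 0 ['s1', 's2']  then  1 ['s3']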
    def decode_iteratively(self, example_generator, total, names_to_types,
                           ssi_list, hps):
        for example_idx, example in enumerate(
                tqdm(example_generator, total=total)):
            raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text = util.unpack_tf_example(
                example, names_to_types)
            article_sent_tokens = [
                util.process_sent(sent) for sent in raw_article_sents
            ]
            groundtruth_summ_sents = [[
                sent.strip()
                for sent in groundtruth_summary_text.strip().split('\n')
            ]]

            if ssi_list is None:  # this is if we are doing the upper bound evaluation (ssi_list comes straight from the groundtruth)
                sys_ssi = groundtruth_similar_source_indices_list
                if FLAGS.singles_and_pairs == 'singles':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                elif FLAGS.singles_and_pairs == 'both':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                sys_ssi = util.replace_empty_ssis(sys_ssi, raw_article_sents)
            else:
                gt_ssi, sys_ssi, ext_len = ssi_list[example_idx]
                if FLAGS.singles_and_pairs == 'singles':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                        groundtruth_similar_source_indices_list, 1)
                    gt_ssi = util.enforce_sentence_limit(gt_ssi, 1)
                elif FLAGS.singles_and_pairs == 'both':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                        groundtruth_similar_source_indices_list, 2)
                    gt_ssi = util.enforce_sentence_limit(gt_ssi, 2)
                if gt_ssi != groundtruth_similar_source_indices_list:
                    print(
                        'Warning: Example %d has different groundtruth source indices: '
                        % example_idx +
                        str(groundtruth_similar_source_indices_list) + ' || ' +
                        str(gt_ssi))
                if FLAGS.dataset_name == 'xsum':
                    sys_ssi = [sys_ssi[0]]

            final_decoded_words = []
            final_decoded_outputs = ''
            best_hyps = []
            highlight_html_total = ''
            for ssi_idx, ssi in enumerate(sys_ssi):
                selected_raw_article_sents = util.reorder(
                    raw_article_sents, ssi)
                selected_article_text = ' '.join([
                    ' '.join(sent)
                    for sent in util.reorder(article_sent_tokens, ssi)
                ])
                selected_doc_indices_str = '0 ' * len(
                    selected_article_text.split())
                if FLAGS.upper_bound:
                    selected_groundtruth_summ_sent = [[
                        groundtruth_summ_sents[0][ssi_idx]
                    ]]
                else:
                    selected_groundtruth_summ_sent = groundtruth_summ_sents

                batch = create_batch(selected_article_text,
                                     selected_groundtruth_summ_sent,
                                     selected_doc_indices_str,
                                     selected_raw_article_sents,
                                     FLAGS.batch_size, hps, self._vocab)

                decoded_words, decoded_output, best_hyp = decode_example(
                    self._sess, self._model, self._vocab, batch, example_idx,
                    hps)
                best_hyps.append(best_hyp)
                final_decoded_words.extend(decoded_words)
                final_decoded_outputs += decoded_output

                if example_idx < 1000:
                    min_matched_tokens = 2
                    selected_article_sent_tokens = [
                        util.process_sent(sent)
                        for sent in selected_raw_article_sents
                    ]
                    highlight_summary_sent_tokens = [decoded_words]
                    highlight_ssi_list, lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                        highlight_summary_sent_tokens,
                        selected_article_sent_tokens, None, 2,
                        min_matched_tokens)
                    highlighted_html = ssi_functions.html_highlight_sents_in_article(
                        highlight_summary_sent_tokens,
                        highlight_ssi_list,
                        selected_article_sent_tokens,
                        lcs_paths_list=lcs_paths_list,
                        article_lcs_paths_list=highlight_smooth_article_lcs_paths_list)
                    highlight_html_total += '<u>System Summary</u><br><br>' + highlighted_html + '<br><br>'

                if len(final_decoded_words) >= 100:
                    break

            if example_idx < 1000:
                self.write_for_human(raw_article_sents, groundtruth_summ_sents,
                                     final_decoded_words, example_idx)
                ssi_functions.write_highlighted_html(highlight_html_total,
                                                     self._highlight_dir,
                                                     example_idx)

            rouge_functions.write_for_rouge(
                groundtruth_summ_sents,
                None,
                example_idx,
                self._rouge_ref_dir,
                self._rouge_dec_dir,
                decoded_words=final_decoded_words,
                log=False
            )  # write ref summary and decoded summary to file, to eval with pyrouge later
            example_idx += 1  # this is how many examples we've decoded

        logging.info("Decoder has finished reading dataset for single_pass.")
        logging.info("Output has been saved in %s and %s.",
                     self._rouge_ref_dir, self._rouge_dec_dir)
        if len(os.listdir(self._rouge_ref_dir)) != 0:
            l_param = 100
            logging.info("Now starting ROUGE eval...")
            results_dict = rouge_functions.rouge_eval(self._rouge_ref_dir,
                                                      self._rouge_dec_dir,
                                                      l_param=l_param)
            rouge_functions.rouge_log(results_dict, self._decode_dir)
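The decode loop above leans on two project-local helpers, util.enforce_sentence_limit and util.replace_empty_ssis, whose definitions are not shown here. The following is a minimal sketch of their assumed behavior, not the project's actual implementation:

def enforce_sentence_limit(ssi_list, sentence_limit):
    # Keep at most `sentence_limit` source-sentence indices per summary sentence.
    return [ssi[:sentence_limit] for ssi in ssi_list]

def replace_empty_ssis(ssi_list, raw_article_sents):
    # Give every summary sentence at least one source sentence to condition on;
    # in this sketch an empty selection falls back to the first article sentence.
    return [ssi if len(ssi) > 0 else (0,) for ssi in ssi_list]

# Example:
#   enforce_sentence_limit([(3, 7, 9), (2,), ()], 2)          -> [(3, 7), (2,), ()]
#   replace_empty_ssis([(3, 7), ()], ['Sent 0.', 'Sent 1.'])  -> [(3, 7), (0,)]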
Example #12
    def decode_iteratively(self, example_generator, total, names_to_types,
                           ssi_list, hps):
        attn_vis_idx = 0
        for example_idx, example in enumerate(
                tqdm(example_generator, total=total)):
            raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, groundtruth_article_lcs_paths_list = util.unpack_tf_example(
                example, names_to_types)
            article_sent_tokens = [
                util.process_sent(sent) for sent in raw_article_sents
            ]
            groundtruth_summ_sents = [[
                sent.strip()
                for sent in groundtruth_summary_text.strip().split('\n')
            ]]
            groundtruth_summ_sent_tokens = [
                sent.split(' ') for sent in groundtruth_summ_sents[0]
            ]

            if ssi_list is None:  # this is if we are doing the upper bound evaluation (ssi_list comes straight from the groundtruth)
                sys_ssi = groundtruth_similar_source_indices_list
                sys_alp_list = groundtruth_article_lcs_paths_list
                if FLAGS.singles_and_pairs == 'singles':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                    sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 1)
                elif FLAGS.singles_and_pairs == 'both':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                    sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 2)
                sys_ssi, sys_alp_list = util.replace_empty_ssis(
                    sys_ssi, raw_article_sents, sys_alp_list=sys_alp_list)
            else:
                gt_ssi, sys_ssi, ext_len, sys_token_probs_list = ssi_list[
                    example_idx]
                sys_alp_list = ssi_functions.list_labels_from_probs(
                    sys_token_probs_list, FLAGS.tag_threshold)
                if FLAGS.singles_and_pairs == 'singles':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                    sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 1)
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                        groundtruth_similar_source_indices_list, 1)
                    gt_ssi = util.enforce_sentence_limit(gt_ssi, 1)
                elif FLAGS.singles_and_pairs == 'both':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                    sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 2)
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                        groundtruth_similar_source_indices_list, 2)
                    gt_ssi = util.enforce_sentence_limit(gt_ssi, 2)
                # if gt_ssi != groundtruth_similar_source_indices_list:
                #     raise Exception('Example %d has different groundtruth source indices: ' + str(groundtruth_similar_source_indices_list) + ' || ' + str(gt_ssi))
                if FLAGS.dataset_name == 'xsum':
                    sys_ssi = [sys_ssi[0]]

            final_decoded_words = []
            final_decoded_outputs = ''
            best_hyps = []
            highlight_html_total = '<u>System Summary</u><br><br>'
            for ssi_idx, ssi in enumerate(sys_ssi):
                # selected_article_lcs_paths = None
                selected_article_lcs_paths = sys_alp_list[ssi_idx]
                ssi, selected_article_lcs_paths = util.make_ssi_chronological(
                    ssi, selected_article_lcs_paths)
                selected_article_lcs_paths = [selected_article_lcs_paths]
                selected_raw_article_sents = util.reorder(
                    raw_article_sents, ssi)
                selected_article_text = ' '.join([
                    ' '.join(sent)
                    for sent in util.reorder(article_sent_tokens, ssi)
                ])
                selected_doc_indices_str = '0 ' * len(
                    selected_article_text.split())
                if FLAGS.upper_bound:
                    selected_groundtruth_summ_sent = [[
                        groundtruth_summ_sents[0][ssi_idx]
                    ]]
                else:
                    selected_groundtruth_summ_sent = groundtruth_summ_sents

                batch = create_batch(selected_article_text,
                                     selected_groundtruth_summ_sent,
                                     selected_doc_indices_str,
                                     selected_raw_article_sents,
                                     selected_article_lcs_paths,
                                     FLAGS.batch_size, hps, self._vocab)

                original_article = batch.original_articles[0]  # string
                original_abstract = batch.original_abstracts[0]  # string
                article_withunks = data.show_art_oovs(original_article,
                                                      self._vocab)  # string
                abstract_withunks = data.show_abs_oovs(
                    original_abstract, self._vocab,
                    (batch.art_oovs[0]
                     if FLAGS.pointer_gen else None))  # string

                if FLAGS.first_intact and ssi_idx == 0:
                    decoded_words = selected_article_text.strip().split()
                    decoded_output = selected_article_text
                else:
                    decoded_words, decoded_output, best_hyp = decode_example(
                        self._sess, self._model, self._vocab, batch,
                        example_idx, hps)
                    best_hyps.append(best_hyp)
                final_decoded_words.extend(decoded_words)
                final_decoded_outputs += decoded_output

                if example_idx < 100 or (example_idx >= 2000
                                         and example_idx < 2100):
                    min_matched_tokens = 2
                    selected_article_sent_tokens = [
                        util.process_sent(sent)
                        for sent in selected_raw_article_sents
                    ]
                    highlight_summary_sent_tokens = [decoded_words]
                    highlight_ssi_list, lcs_paths_list, highlight_article_lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                        highlight_summary_sent_tokens,
                        selected_article_sent_tokens, None, 2,
                        min_matched_tokens)
                    highlighted_html = ssi_functions.html_highlight_sents_in_article(
                        highlight_summary_sent_tokens,
                        highlight_ssi_list,
                        selected_article_sent_tokens,
                        lcs_paths_list=lcs_paths_list,
                        article_lcs_paths_list=highlight_smooth_article_lcs_paths_list)
                    highlight_html_total += highlighted_html + '<br>'

                if FLAGS.attn_vis and example_idx < 200:
                    self.write_for_attnvis(
                        article_withunks, abstract_withunks, decoded_words,
                        best_hyp.attn_dists, best_hyp.p_gens, attn_vis_idx
                    )  # write info to .json file for visualization tool
                    attn_vis_idx += 1

                if len(final_decoded_words) >= 100:
                    break

            gt_ssi_list, gt_alp_list = util.replace_empty_ssis(
                groundtruth_similar_source_indices_list,
                raw_article_sents,
                sys_alp_list=groundtruth_article_lcs_paths_list)
            highlight_html_gt = '<u>Reference Summary</u><br><br>'
            for ssi_idx, ssi in enumerate(gt_ssi_list):
                selected_article_lcs_paths = gt_alp_list[ssi_idx]
                try:
                    ssi, selected_article_lcs_paths = util.make_ssi_chronological(
                        ssi, selected_article_lcs_paths)
                except:
                    util.print_vars(ssi, example_idx,
                                    selected_article_lcs_paths)
                    raise
                selected_raw_article_sents = util.reorder(
                    raw_article_sents, ssi)

                if example_idx < 100 or (example_idx >= 2000
                                         and example_idx < 2100):
                    min_matched_tokens = 2
                    selected_article_sent_tokens = [
                        util.process_sent(sent)
                        for sent in selected_raw_article_sents
                    ]
                    highlight_summary_sent_tokens = [
                        groundtruth_summ_sent_tokens[ssi_idx]
                    ]
                    highlight_ssi_list, lcs_paths_list, highlight_article_lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                        highlight_summary_sent_tokens,
                        selected_article_sent_tokens, None, 2,
                        min_matched_tokens)
                    highlighted_html = ssi_functions.html_highlight_sents_in_article(
                        highlight_summary_sent_tokens,
                        highlight_ssi_list,
                        selected_article_sent_tokens,
                        lcs_paths_list=lcs_paths_list,
                        article_lcs_paths_list=highlight_smooth_article_lcs_paths_list)
                    highlight_html_gt += highlighted_html + '<br>'

            if example_idx < 100 or (example_idx >= 2000
                                     and example_idx < 2100):
                self.write_for_human(raw_article_sents, groundtruth_summ_sents,
                                     final_decoded_words, example_idx)
                highlight_html_total = ssi_functions.put_html_in_two_columns(
                    highlight_html_total, highlight_html_gt)
                ssi_functions.write_highlighted_html(highlight_html_total,
                                                     self._highlight_dir,
                                                     example_idx)

            # if example_idx % 100 == 0:
            #     attn_dir = os.path.join(self._decode_dir, 'attn_vis_data')
            #     attn_selections.process_attn_selections(attn_dir, self._decode_dir, self._vocab)

            rouge_functions.write_for_rouge(
                groundtruth_summ_sents,
                None,
                example_idx,
                self._rouge_ref_dir,
                self._rouge_dec_dir,
                decoded_words=final_decoded_words,
                log=False
            )  # write ref summary and decoded summary to file, to eval with pyrouge later
            # if FLAGS.attn_vis:
            #     self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens, example_idx) # write info to .json file for visualization tool
            example_idx += 1  # this is how many examples we've decoded

        logging.info("Decoder has finished reading dataset for single_pass.")
        logging.info("Output has been saved in %s and %s.",
                     self._rouge_ref_dir, self._rouge_dec_dir)
        if len(os.listdir(self._rouge_ref_dir)) != 0:
            l_param = 100
            logging.info("Now starting ROUGE eval...")
            results_dict = rouge_functions.rouge_eval(self._rouge_ref_dir,
                                                      self._rouge_dec_dir,
                                                      l_param=l_param)
            rouge_functions.rouge_log(results_dict, self._decode_dir)
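In the variant above, token-level scores from the extractor are turned into tag labels with ssi_functions.list_labels_from_probs and FLAGS.tag_threshold. Below is a minimal sketch of the assumed thresholding, with the nesting inferred from how sys_alp_list is indexed afterwards (one group of per-sentence probability lists per summary sentence); the real helper may differ:

def list_labels_from_probs(token_probs_list, tag_threshold):
    # For each summary sentence, and for each of its selected source sentences,
    # keep the indices of tokens whose probability clears the threshold.
    return [
        [[tok_idx for tok_idx, prob in enumerate(sent_probs) if prob >= tag_threshold]
         for sent_probs in group]
        for group in token_probs_list
    ]

# Example: list_labels_from_probs([[[0.9, 0.2, 0.7]]], 0.5) -> [[[0, 2]]]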
def evaluate_example(ex):
    example, example_idx, qid_ssi_to_importances, qid_ssi_to_token_scores_and_mappings = ex
    print(example_idx)
    # example_idx += 1
    qid = example_idx
    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
        example, names_to_types)
    article_sent_tokens = [
        util.process_sent(sent) for sent in raw_article_sents
    ]
    enforced_groundtruth_ssi_list = util.enforce_sentence_limit(
        groundtruth_similar_source_indices_list, sentence_limit)
    groundtruth_summ_sents = [[
        sent.strip() for sent in groundtruth_summary_text.strip().split('\n')
    ]]
    groundtruth_summ_sent_tokens = [
        sent.split(' ') for sent in groundtruth_summ_sents[0]
    ]

    if FLAGS.upper_bound:
        replaced_ssi_list = util.replace_empty_ssis(
            enforced_groundtruth_ssi_list, raw_article_sents)
        selected_article_sent_indices = util.flatten_list_of_lists(
            replaced_ssi_list)
        summary_sents = [
            ' '.join(sent) for sent in util.reorder(
                article_sent_tokens, selected_article_sent_indices)
        ]
        similar_source_indices_list = groundtruth_similar_source_indices_list
        ssi_length_extractive = len(similar_source_indices_list)
        token_probs_list = None  # no token-level scores in the upper-bound setting; keeps the return below well-defined
    else:
        summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive, \
            article_lcs_paths_list, token_probs_list = generate_summary(article_sent_tokens, qid_ssi_to_importances, example_idx, qid_ssi_to_token_scores_and_mappings)
        similar_source_indices_list_trunc = similar_source_indices_list[:ssi_length_extractive]
        summary_sents_for_html_trunc = summary_sents_for_html[:ssi_length_extractive]
        if example_idx < 100 or (example_idx >= 2000 and example_idx < 2100):
            summary_sent_tokens = [
                sent.split(' ') for sent in summary_sents_for_html_trunc
            ]
            if FLAGS.tag_tokens and FLAGS.tag_loss_wt != 0:
                lcs_paths_list_param = copy.deepcopy(article_lcs_paths_list)
            else:
                lcs_paths_list_param = None
            extracted_sents_in_article_html = html_highlight_sents_in_article(
                summary_sent_tokens,
                similar_source_indices_list_trunc,
                article_sent_tokens,
                doc_indices=doc_indices,
                lcs_paths_list=lcs_paths_list_param)
            # write_highlighted_html(extracted_sents_in_article_html, html_dir, example_idx)

            groundtruth_ssi_list, gt_lcs_paths_list, gt_article_lcs_paths_list, gt_smooth_article_paths_list = get_simple_source_indices_list(
                groundtruth_summ_sent_tokens, article_sent_tokens, None,
                sentence_limit, min_matched_tokens)
            groundtruth_highlighted_html = html_highlight_sents_in_article(
                groundtruth_summ_sent_tokens,
                groundtruth_ssi_list,
                article_sent_tokens,
                lcs_paths_list=gt_lcs_paths_list,
                article_lcs_paths_list=gt_smooth_article_paths_list,
                doc_indices=doc_indices)

            all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html + '<u>Groundtruth Summary</u><br><br>' + groundtruth_highlighted_html
            # all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html
            write_highlighted_html(all_html, html_dir, example_idx)
    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents,
                                    example_idx, ref_dir, dec_dir)
    return (groundtruth_similar_source_indices_list,
            similar_source_indices_list, ssi_length_extractive,
            token_probs_list)
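evaluate_example takes a single tuple argument, which makes it easy to fan out over a dataset. Below is a minimal driver sketch that assumes the same work-item layout as the tuple unpacked at the top of the function; whether the project runs this serially or with multiprocessing is an assumption here.

from multiprocessing import Pool

from tqdm import tqdm

def evaluate_all(example_generator, qid_ssi_to_importances,
                 qid_ssi_to_token_scores_and_mappings, num_workers=8):
    # Materialize one work item per example, mirroring the tuple evaluate_example expects.
    work = [(ex, idx, qid_ssi_to_importances, qid_ssi_to_token_scores_and_mappings)
            for idx, ex in enumerate(example_generator)]
    with Pool(num_workers) as p:
        # Each result is (groundtruth ssi list, system ssi list,
        # ssi_length_extractive, token_probs_list).
        return list(tqdm(p.imap(evaluate_example, work), total=len(work)))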
Example #14
def main(unused_argv):

    print('Running statistics on %s' % FLAGS.dataset_name)

    if len(unused_argv) != 1:  # raises an error if flags were entered incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.dataset_name == 'all':
        dataset_names = ['cnn_dm', 'xsum', 'duc_2004']
    else:
        dataset_names = [FLAGS.dataset_name]

    if not os.path.exists(plot_data_file):
        all_lists_of_histogram_pairs = []
        for dataset_name in dataset_names:
            FLAGS.dataset_name = dataset_name

            if dataset_name == 'duc_2004':
                dataset_splits = ['test']
            elif FLAGS.dataset_split == 'all':
                dataset_splits = ['test', 'val', 'train']
            else:
                dataset_splits = [FLAGS.dataset_split]

            ssi_list = []
            for dataset_split in dataset_splits:

                ssi_path = os.path.join(ssi_dir, FLAGS.dataset_name,
                                        dataset_split + '_ssi.pkl')

                with open(ssi_path, 'rb') as f:
                    ssi_list.extend(pickle.load(f))

                if FLAGS.dataset_name == 'duc_2004':
                    for abstract_idx in [1, 2, 3]:
                        ssi_path = os.path.join(
                            ssi_dir, FLAGS.dataset_name, dataset_split +
                            '_ssi_' + str(abstract_idx) + '.pkl')
                        with open(ssi_path, 'rb') as f:
                            temp_ssi_list = pickle.load(f)
                        ssi_list.extend(temp_ssi_list)

            ssi_2d = util.flatten_list_of_lists(ssi_list)

            num_extracted = [
                len(ssi) for ssi in util.flatten_list_of_lists(ssi_list)
            ]
            hist_num_extracted = np.histogram(num_extracted,
                                              bins=6,
                                              range=(0, 5))
            print(hist_num_extracted)
            print('Histogram of number of sentences merged: ' +
                  util.hist_as_pdf_str(hist_num_extracted))

            distances = [
                abs(ssi[0] - ssi[1]) for ssi in ssi_2d if len(ssi) >= 2
            ]
            print('Distance between sentences (mean, median): ',
                  np.mean(distances), np.median(distances))
            hist_dist = np.histogram(distances, bins=max(distances))
            print('Histogram of distances: ' + util.hist_as_pdf_str(hist_dist))

            summ_sent_idx_to_number_of_source_sents = [[] for _ in range(10)]
            for ssi in ssi_list:
                for summ_sent_idx, source_indices in enumerate(ssi):
                    if len(source_indices) == 0 or summ_sent_idx >= len(
                            summ_sent_idx_to_number_of_source_sents):
                        continue
                    num_sents = len(source_indices)
                    if num_sents > 2:
                        num_sents = 2
                    summ_sent_idx_to_number_of_source_sents[
                        summ_sent_idx].append(num_sents)
            print(
                "Number of source sents for summary sentence indices (Is the first summary sent more likely to match with a singleton or a pair?):"
            )
            for summ_sent_idx, list_of_numbers_of_source_sents in enumerate(
                    summ_sent_idx_to_number_of_source_sents):
                if len(list_of_numbers_of_source_sents) == 0:
                    percent_singleton = 0.
                else:
                    percent_singleton = list_of_numbers_of_source_sents.count(
                        1) * 1. / len(list_of_numbers_of_source_sents)
                print(str(percent_singleton) + '\t', end='')
            print('')
            for summ_sent_idx, list_of_numbers_of_source_sents in enumerate(
                    summ_sent_idx_to_number_of_source_sents):
                if len(list_of_numbers_of_source_sents) == 0:
                    percent_pair = 0.
                else:
                    percent_pair = list_of_numbers_of_source_sents.count(
                        2) * 1. / len(list_of_numbers_of_source_sents)
                print(str(percent_pair) + '\t', end='')
            print('')

            primary_pos = [ssi[0] for ssi in ssi_2d if len(ssi) >= 1]
            secondary_pos = [ssi[1] for ssi in ssi_2d if len(ssi) >= 2]
            all_pos = [max(ssi) for ssi in ssi_2d if len(ssi) >= 1]

            # if FLAGS.dataset_name != 'duc_2004':
            #     plot_positions(primary_pos, secondary_pos, all_pos)

            if FLAGS.dataset_split == 'all':
                glob_string = '*.bin'
            else:
                glob_string = dataset_splits[0]

            print('Loading TFIDF vectorizer')
            with open(tfidf_vec_path, 'rb') as f:
                tfidf_vectorizer = pickle.load(f)

            source_dir = os.path.join(data_dir, FLAGS.dataset_name)
            source_files = sorted(
                glob.glob(source_dir + '/' + glob_string + '*'))

            total = len(source_files) * 1000 if (
                'cnn' in FLAGS.dataset_name or 'newsroom' in FLAGS.dataset_name
                or 'xsum' in FLAGS.dataset_name) else len(source_files)
            example_generator = data.example_generator(
                source_dir + '/' + glob_string + '*',
                True,
                False,
                should_check_valid=False)

            all_possible_singles = 0
            all_possible_pairs = [0]
            all_filtered_pairs = 0
            all_all_combinations = 0
            all_ssi_pairs = [0]
            ssi_pairs_with_shared_coref = [0]
            ssi_pairs_with_shared_word = [0]
            ssi_pairs_with_either_coref_or_word = [0]
            all_pairs_with_shared_coref = [0]
            all_pairs_with_shared_word = [0]
            all_pairs_with_either_coref_or_word = [0]
            actual_total = [0]
            rel_positions_primary = []
            rel_positions_secondary = []
            rel_positions_all = []
            sent_lens = []
            all_sent_lens = []
            all_pos = []
            y = []
            normalized_positions_primary = []
            normalized_positions_secondary = []
            all_normalized_positions_primary = []
            all_normalized_positions_secondary = []
            normalized_positions_singles = []
            normalized_positions_pairs_first = []
            normalized_positions_pairs_second = []
            primary_pos_duc = []
            secondary_pos_duc = []
            all_pos_duc = []
            all_distances = []
            distances_duc = []
            tfidf_similarities = []
            all_tfidf_similarities = []
            average_mmrs = []
            all_average_mmrs = []

            for example_idx, example in enumerate(
                    tqdm(example_generator, total=total)):

                # def process(example_idx_example):
                #     # print '0'
                #     example = example_idx_example
                if FLAGS.num_instances != -1 and example_idx >= FLAGS.num_instances:
                    break
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
                    example, names_to_types)
                article_sent_tokens = [
                    util.process_sent(sent) for sent in raw_article_sents
                ]
                article_text = ' '.join(raw_article_sents)
                groundtruth_summ_sents = [[
                    sent.strip()
                    for sent in groundtruth_summary_text.strip().split('\n')
                ]]
                if doc_indices is None:
                    doc_indices = [0] * len(
                        util.flatten_list_of_lists(article_sent_tokens))
                doc_indices = [int(doc_idx) for doc_idx in doc_indices]
                rel_sent_indices, doc_sent_indices, doc_sent_lens = preprocess_for_lambdamart_no_flags.get_rel_sent_indices(
                    doc_indices, article_sent_tokens)
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list,
                    FLAGS.sentence_limit)

                sent_term_matrix = util.get_doc_substituted_tfidf_matrix(
                    tfidf_vectorizer, raw_article_sents, article_text)
                sents_similarities = util.cosine_similarity(
                    sent_term_matrix, sent_term_matrix)
                importances = util.special_squash(
                    util.get_tfidf_importances(tfidf_vectorizer,
                                               raw_article_sents))

                if FLAGS.dataset_name == 'duc_2004':
                    first_k_indices = lambdamart_scores_to_summaries.get_indices_of_first_k_sents_of_each_article(
                        rel_sent_indices, FLAGS.first_k)
                else:
                    first_k_indices = [
                        idx for idx in range(len(raw_article_sents))
                    ]
                article_indices = list(range(len(raw_article_sents)))

                possible_pairs = list(itertools.combinations(article_indices, 2))  # all pairs
                # # # filtered_possible_pairs = preprocess_for_lambdamart_no_flags.filter_pairs_by_criteria(raw_article_sents, possible_pairs, corefs)
                # if FLAGS.dataset_name == 'duc_2004':
                #     filtered_possible_pairs = [x for x in list(itertools.combinations(first_k_indices, 2))]  # all pairs
                # else:
                #     filtered_possible_pairs = preprocess_for_lambdamart_no_flags.filter_pairs_by_sent_position(possible_pairs)
                # # removed_pairs = list(set(possible_pairs) - set(filtered_possible_pairs))
                # possible_singles = [(i,) for i in range(len(raw_article_sents))]
                # all_combinations = filtered_possible_pairs + possible_singles
                #
                # all_possible_singles += len(possible_singles)
                # all_possible_pairs[0] += len(possible_pairs)
                # all_filtered_pairs += len(filtered_possible_pairs)
                # all_all_combinations += len(all_combinations)

                # for ssi in groundtruth_similar_source_indices_list:
                #     if len(ssi) > 0:
                #         idx = rel_sent_indices[ssi[0]]
                #         rel_positions_primary.append(idx)
                #         rel_positions_all.append(idx)
                #     if len(ssi) > 1:
                #         idx = rel_sent_indices[ssi[1]]
                #         rel_positions_secondary.append(idx)
                #         rel_positions_all.append(idx)
                #
                #
                #

                # coref_pairs = preprocess_for_lambdamart_no_flags.get_coref_pairs(corefs)
                # # DO OVER LAP PAIRS BETTER
                # overlap_pairs = preprocess_for_lambdamart_no_flags.filter_by_overlap(article_sent_tokens, possible_pairs)
                # either_coref_or_word = list(set(list(coref_pairs) + overlap_pairs))
                #
                # for ssi in groundtruth_similar_source_indices_list:
                #     if len(ssi) == 2:
                #         all_ssi_pairs[0] += 1
                #         do_share_coref = ssi in coref_pairs
                #         do_share_words = ssi in overlap_pairs
                #         if do_share_coref:
                #             ssi_pairs_with_shared_coref[0] += 1
                #         if do_share_words:
                #             ssi_pairs_with_shared_word[0] += 1
                #         if do_share_coref or do_share_words:
                #             ssi_pairs_with_either_coref_or_word[0] += 1
                # all_pairs_with_shared_coref[0] += len(coref_pairs)
                # all_pairs_with_shared_word[0] += len(overlap_pairs)
                # all_pairs_with_either_coref_or_word[0] += len(either_coref_or_word)

                if FLAGS.dataset_name == 'duc_2004':
                    primary_pos_duc.extend([
                        rel_sent_indices[ssi[0]]
                        for ssi in groundtruth_similar_source_indices_list
                        if len(ssi) >= 1
                    ])
                    secondary_pos_duc.extend([
                        rel_sent_indices[ssi[1]]
                        for ssi in groundtruth_similar_source_indices_list
                        if len(ssi) >= 2
                    ])
                    all_pos_duc.extend([
                        max([rel_sent_indices[sent_idx] for sent_idx in ssi])
                        for ssi in groundtruth_similar_source_indices_list
                        if len(ssi) >= 1
                    ])

                for ssi in groundtruth_similar_source_indices_list:
                    for sent_idx in ssi:
                        sent_lens.append(len(article_sent_tokens[sent_idx]))
                    if len(ssi) >= 1:
                        orig_val = ssi[0]
                        vals_to_add = get_integral_values_for_histogram(
                            orig_val, rel_sent_indices, doc_sent_indices,
                            doc_sent_lens, raw_article_sents)
                        normalized_positions_primary.extend(vals_to_add)
                    if len(ssi) >= 2:
                        orig_val = ssi[1]
                        vals_to_add = get_integral_values_for_histogram(
                            orig_val, rel_sent_indices, doc_sent_indices,
                            doc_sent_lens, raw_article_sents)
                        normalized_positions_secondary.extend(vals_to_add)

                        if FLAGS.dataset_name == 'duc_2004':
                            distances_duc.append(
                                abs(rel_sent_indices[ssi[1]] -
                                    rel_sent_indices[ssi[0]]))

                        tfidf_similarities.append(sents_similarities[ssi[0],
                                                                     ssi[1]])
                        average_mmrs.append(
                            (importances[ssi[0]] + importances[ssi[1]]) / 2)

                for ssi in groundtruth_similar_source_indices_list:
                    if len(ssi) == 1:
                        orig_val = ssi[0]
                        vals_to_add = get_integral_values_for_histogram(
                            orig_val, rel_sent_indices, doc_sent_indices,
                            doc_sent_lens, raw_article_sents)
                        normalized_positions_singles.extend(vals_to_add)
                    if len(ssi) >= 2:
                        if doc_sent_indices[ssi[0]] != doc_sent_indices[
                                ssi[1]]:
                            continue
                        orig_val_first = min(ssi[0], ssi[1])
                        vals_to_add = get_integral_values_for_histogram(
                            orig_val_first, rel_sent_indices, doc_sent_indices,
                            doc_sent_lens, raw_article_sents)
                        normalized_positions_pairs_first.extend(vals_to_add)
                        orig_val_second = max(ssi[0], ssi[1])
                        vals_to_add = get_integral_values_for_histogram(
                            orig_val_second, rel_sent_indices,
                            doc_sent_indices, doc_sent_lens, raw_article_sents)
                        normalized_positions_pairs_second.extend(vals_to_add)

                # all_normalized_positions_primary.extend(util.flatten_list_of_lists([get_integral_values_for_histogram(single[0], rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents) for single in possible_singles]))
                # all_normalized_positions_secondary.extend(util.flatten_list_of_lists([get_integral_values_for_histogram(pair[1], rel_sent_indices, doc_sent_indices, doc_sent_lens, raw_article_sents) for pair in possible_pairs]))
                all_sent_lens.extend(
                    [len(sent) for sent in article_sent_tokens])
                all_distances.extend([
                    abs(rel_sent_indices[pair[1]] - rel_sent_indices[pair[0]])
                    for pair in possible_pairs
                ])
                all_tfidf_similarities.extend([
                    sents_similarities[pair[0], pair[1]]
                    for pair in possible_pairs
                ])
                all_average_mmrs.extend([
                    (importances[pair[0]] + importances[pair[1]]) / 2
                    for pair in possible_pairs
                ])

                # if FLAGS.dataset_name == 'duc_2004':
                #     rel_pos_single = [rel_sent_indices[single[0]] for single in possible_singles]
                #     rel_pos_pair = [[rel_sent_indices[pair[0]], rel_sent_indices[pair[1]]] for pair in possible_pairs]
                #     all_pos.extend(rel_pos_single)
                #     all_pos.extend([max(pair) for pair in rel_pos_pair])
                # else:
                #     all_pos.extend(util.flatten_list_of_lists(possible_singles))
                #     all_pos.extend([max(pair) for pair in possible_pairs])
                # y.extend([1 if single in groundtruth_similar_source_indices_list else 0 for single in possible_singles])
                # y.extend([1 if pair in groundtruth_similar_source_indices_list else 0 for pair in possible_pairs])

                # actual_total[0] += 1

            # # p = Pool(144)
            # # list(tqdm(p.imap(process, example_generator), total=total))
            #
            # # print 'Possible_singles\tPossible_pairs\tFiltered_pairs\tAll_combinations: \n%.2f\t%.2f\t%.2f\t%.2f' % (all_possible_singles*1./actual_total, \
            # #     all_possible_pairs*1./actual_total, all_filtered_pairs*1./actual_total, all_all_combinations*1./actual_total)
            # #
            # # # print 'Relative positions of groundtruth source sentences in document:\nPrimary\tSecondary\tBoth\n%.2f\t%.2f\t%.2f' % (np.mean(rel_positions_primary), np.mean(rel_positions_secondary), np.mean(rel_positions_all))
            # #
            # # print 'SSI Pair statistics:\nShare_coref\tShare_word\tShare_either\n%.2f\t%.2f\t%.2f' \
            # #       % (ssi_pairs_with_shared_coref[0]*100./all_ssi_pairs[0], ssi_pairs_with_shared_word[0]*100./all_ssi_pairs[0], ssi_pairs_with_either_coref_or_word[0]*100./all_ssi_pairs[0])
            # # print 'All Pair statistics:\nShare_coref\tShare_word\tShare_either\n%.2f\t%.2f\t%.2f' \
            # #       % (all_pairs_with_shared_coref[0]*100./all_possible_pairs[0], all_pairs_with_shared_word[0]*100./all_possible_pairs[0], all_pairs_with_either_coref_or_word[0]*100./all_possible_pairs[0])
            #
            # # hist_all_pos = np.histogram(all_pos, bins=max(all_pos)+1)
            # # print 'Histogram of all sent positions: ', util.hist_as_pdf_str(hist_all_pos)
            # # min_sent_len = min(sent_lens)
            # # hist_sent_lens = np.histogram(sent_lens, bins=max(sent_lens)-min_sent_len+1)
            # # print 'min, max sent lens:', min_sent_len, max(sent_lens)
            # # print 'Histogram of sent lens: ', util.hist_as_pdf_str(hist_sent_lens)
            # # min_all_sent_len = min(all_sent_lens)
            # # hist_all_sent_lens = np.histogram(all_sent_lens, bins=max(all_sent_lens)-min_all_sent_len+1)
            # # print 'min, max all sent lens:', min_all_sent_len, max(all_sent_lens)
            # # print 'Histogram of all sent lens: ', util.hist_as_pdf_str(hist_all_sent_lens)
            #
            # # print 'Pearsons r, p value', pearsonr(all_pos, y)
            # # fig, ax1 = plt.subplots(nrows=1)
            # # plt.scatter(all_pos, y)
            # # pp = PdfPages(os.path.join('stuff/plots', FLAGS.dataset_name + '_position_scatter.pdf'))
            # # plt.savefig(pp, format='pdf',bbox_inches='tight')
            # # plt.show()
            # # pp.close()
            #
            # # if FLAGS.dataset_name == 'duc_2004':
            # #     plot_positions(primary_pos_duc, secondary_pos_duc, all_pos_duc)
            #
            # normalized_positions_all = normalized_positions_primary + normalized_positions_secondary
            # # plot_histogram(normalized_positions_primary, num_bins=100)
            # # plot_histogram(normalized_positions_secondary, num_bins=100)
            # # plot_histogram(normalized_positions_all, num_bins=100)
            #
            # sent_lens_together = [sent_lens, all_sent_lens]
            # # plot_histogram(sent_lens_together, pdf=True, start_at_0=True, max_val=70)
            #
            # if FLAGS.dataset_name == 'duc_2004':
            #     distances = distances_duc
            # sent_distances_together = [distances, all_distances]
            # # plot_histogram(sent_distances_together, pdf=True, start_at_0=True, max_val=100)
            #
            # tfidf_similarities_together = [tfidf_similarities, all_tfidf_similarities]
            # # plot_histogram(tfidf_similarities_together, pdf=True, num_bins=100)
            #
            # average_mmrs_together = [average_mmrs, all_average_mmrs]
            # # plot_histogram(average_mmrs_together, pdf=True, num_bins=100)
            #
            # normalized_positions_primary_together = [normalized_positions_primary, bin_values]
            # normalized_positions_secondary_together = [normalized_positions_secondary, bin_values]
            # # plot_histogram(normalized_positions_primary_together, pdf=True, num_bins=100)
            # # plot_histogram(normalized_positions_secondary_together, pdf=True, num_bins=100)
            #
            #
            # list_of_hist_pairs = [
            #     {
            #         'lst': normalized_positions_primary_together,
            #         'pdf': True,
            #         'num_bins': 100,
            #         'y_lim': 3.9,
            #         'y_label': FLAGS.dataset_name,
            #         'x_label': 'Sent position (primary)'
            #     },
            #     {
            #         'lst': normalized_positions_secondary_together,
            #         'pdf': True,
            #         'num_bins': 100,
            #         'y_lim': 3.9,
            #         'x_label': 'Sent position (secondary)'
            #     },
            #     {
            #         'lst': sent_distances_together,
            #         'pdf': True,
            #         'start_at_0': True,
            #         'max_val': 100,
            #         'x_label': 'Sent distance'
            #     },
            #     {
            #         'lst': sent_lens_together,
            #         'pdf': True,
            #         'start_at_0': True,
            #         'max_val': 70,
            #         'x_label': 'Sent length'
            #     },
            #     {
            #         'lst': average_mmrs_together,
            #         'pdf': True,
            #         'num_bins': 100,
            #         'x_label': 'Average TF-IDF importance'
            #     }
            # ]

            normalized_positions_pairs_together = [
                normalized_positions_pairs_first,
                normalized_positions_pairs_second
            ]
            list_of_hist_pairs = [
                {
                    'lst': [normalized_positions_singles],
                    'pdf': True,
                    'num_bins': 100,
                    # 'y_lim': 3.9,
                    'x_lim': 1.0,
                    'y_label': FLAGS.dataset_name,
                    'x_label': 'Sent Position (Singles)',
                    'legend_labels': ['Primary']
                },
                {
                    'lst': normalized_positions_pairs_together,
                    'pdf': True,
                    'num_bins': 100,
                    # 'y_lim': 3.9,
                    'x_lim': 1.0,
                    'x_label': 'Sent Position (Pairs)',
                    'legend_labels': ['Primary', 'Secondary']
                }
            ]

            all_lists_of_histogram_pairs.append(list_of_hist_pairs)
        with open(plot_data_file, 'wb') as f:
            pickle.dump(all_lists_of_histogram_pairs, f)
    else:
        with open(plot_data_file, 'rb') as f:
            all_lists_of_histogram_pairs = pickle.load(f)
    plot_histograms(all_lists_of_histogram_pairs)
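The statistics pass above relies on get_integral_values_for_histogram to turn a sentence index into normalized document positions that can be histogrammed across articles of different lengths. A rough sketch of the assumed behavior follows; the real helper may return several interpolated bin values per sentence, and indexing doc_sent_lens by document id is also an assumption:

def get_integral_values_for_histogram(sent_idx, rel_sent_indices, doc_sent_indices,
                                      doc_sent_lens, raw_article_sents):
    # rel_sent_indices[sent_idx]: position of the sentence within its own document
    # doc_sent_indices[sent_idx]: which document the sentence belongs to
    # doc_sent_lens[doc_idx]: number of sentences in that document
    doc_idx = doc_sent_indices[sent_idx]
    doc_len = max(doc_sent_lens[doc_idx], 1)
    return [rel_sent_indices[sent_idx] / float(doc_len)]  # normalized position in [0, 1)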
def main(unused_argv):

    print('Running statistics on %s' % FLAGS.exp_name)

    if len(unused_argv) != 1:  # raises an error if flags were entered incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.all_actions:
        FLAGS.sent_dataset = True
        FLAGS.ssi_dataset = True
        FLAGS.print_output = True
        FLAGS.highlight = True

    original_dataset_name = 'xsum' if 'xsum' in FLAGS.dataset_name else 'cnn_dm' if (
        'cnn_dm' in FLAGS.dataset_name
        or 'duc_2004' in FLAGS.dataset_name) else ''
    vocab = Vocab(FLAGS.vocab_path + '_' + original_dataset_name,
                  FLAGS.vocab_size)  # create a vocabulary

    source_dir = os.path.join(data_dir, FLAGS.dataset_name)
    util.create_dirs(html_dir)

    if FLAGS.dataset_split == 'all':
        if FLAGS.dataset_name == 'duc_2004':
            dataset_splits = ['test']
        else:
            dataset_splits = ['test', 'val', 'train']
    else:
        dataset_splits = [FLAGS.dataset_split]
    for dataset_split in dataset_splits:
        source_files = sorted(glob.glob(source_dir + '/' + dataset_split +
                                        '*'))
        if FLAGS.exp_name == 'reference':
            # summary_dir = log_dir + default_exp_name + '/decode_test_' + str(max_enc_steps) + \
            #                 'maxenc_4beam_' + str(min_dec_steps) + 'mindec_' + str(max_dec_steps) + 'maxdec_ckpt-238410/reference'
            # summary_files = sorted(glob.glob(summary_dir + '/*_reference.A.txt'))
            summary_dir = source_dir
            summary_files = source_files
        else:
            if FLAGS.exp_name == 'cnn_dm':
                summary_dir = log_dir + FLAGS.exp_name + '/decode_test_400maxenc_4beam_35mindec_100maxdec_ckpt-238410/decoded'
            else:
                ckpt_folder = util.find_largest_ckpt_folder(log_dir +
                                                            FLAGS.exp_name)
                summary_dir = log_dir + FLAGS.exp_name + '/' + ckpt_folder + '/decoded'
                # summary_dir = log_dir + FLAGS.exp_name + '/decode_test_' + str(max_enc_steps) + \
                #             'maxenc_4beam_' + str(min_dec_steps) + 'mindec_' + str(max_dec_steps) + 'maxdec_ckpt-238410/decoded'
            summary_files = sorted(glob.glob(summary_dir + '/*'))
        if len(summary_files) == 0:
            raise Exception('No files found in %s' % summary_dir)
        example_generator = data.example_generator(source_dir + '/' +
                                                   dataset_split + '*',
                                                   True,
                                                   False,
                                                   is_original=True)
        pros = {
            'annotators': 'dcoref',
            'outputFormat': 'json',
            'timeout': '5000000'
        }
        all_merge_examples = []
        num_extracted_list = []
        distances = []
        relative_distances = []
        html_str = ''
        extracted_sents_in_article_html = ''
        name = FLAGS.dataset_name + '_' + FLAGS.exp_name
        if FLAGS.coreference_replacement:
            name += '_coref'
        highlight_file_name = os.path.join(
            html_dir, FLAGS.dataset_name + '_' + FLAGS.exp_name)
        if FLAGS.consider_stopwords:
            highlight_file_name += '_stopwords'
        if FLAGS.highlight:
            extracted_sents_in_article_html_file = open(
                highlight_file_name + '_extracted_sents.html', 'wb')
        if FLAGS.kaiqiang:
            kaiqiang_article_texts = []
            kaiqiang_abstract_texts = []
            util.create_dirs(kaiqiang_dir)
            kaiqiang_article_file = open(
                os.path.join(
                    kaiqiang_dir, FLAGS.dataset_name + '_' + dataset_split +
                    '_' + str(FLAGS.min_matched_tokens) + '_articles.txt'),
                'wb')
            kaiqiang_abstract_file = open(
                os.path.join(
                    kaiqiang_dir, FLAGS.dataset_name + '_' + dataset_split +
                    '_' + str(FLAGS.min_matched_tokens) + '_abstracts.txt'),
                'wb')
        if FLAGS.ssi_dataset:
            if FLAGS.tag_tokens:
                with_coref_and_ssi_dir = lambdamart_dir + '_and_tag_tokens'
            else:
                with_coref_and_ssi_dir = lambdamart_dir
            lambdamart_out_dir = os.path.join(with_coref_and_ssi_dir,
                                              FLAGS.dataset_name)
            if FLAGS.sentence_limit == 1:
                lambdamart_out_dir += '_singles'
            if FLAGS.consider_stopwords:
                lambdamart_out_dir += '_stopwords'
            lambdamart_out_full_dir = os.path.join(lambdamart_out_dir, 'all')
            util.create_dirs(lambdamart_out_full_dir)
            lambdamart_writer = open(
                os.path.join(lambdamart_out_full_dir, dataset_split + '.bin'),
                'wb')

        simple_similar_source_indices_list_plus_empty = []
        example_idx = -1
        instance_idx = 0
        total = len(source_files) * 1000 if (
            'cnn' in FLAGS.dataset_name or 'newsroom' in FLAGS.dataset_name
            or 'xsum' in FLAGS.dataset_name) else len(source_files)
        random_choices = None
        if FLAGS.randomize:
            if FLAGS.dataset_name == 'cnn_dm':
                list_order = np.random.permutation(11490)
                random_choices = list_order[:FLAGS.num_instances]
        for example in tqdm(example_generator, total=total):
            example_idx += 1
            if FLAGS.num_instances != -1 and instance_idx >= FLAGS.num_instances:
                break
            if random_choices is not None and example_idx not in random_choices:
                continue
        # for file_idx in tqdm(range(len(source_files))):
        #     example = get_tf_example(source_files[file_idx])
            article_text = example.features.feature[
                'article'].bytes_list.value[0].decode().lower()
            if FLAGS.exp_name == 'reference':
                summary_text, all_summary_texts = get_summary_from_example(
                    example)
            else:
                summary_text = get_summary_text(summary_files[example_idx])
            article_tokens = split_into_tokens(article_text)
            if 'raw_article_sents' in example.features.feature and len(
                    example.features.feature['raw_article_sents'].bytes_list.
                    value) > 0:
                raw_article_sents = example.features.feature[
                    'raw_article_sents'].bytes_list.value

                raw_article_sents = [
                    sent.decode() for sent in raw_article_sents
                    if sent.decode().strip() != ''
                ]
                article_sent_tokens = [
                    util.process_sent(sent, whitespace=True)
                    for sent in raw_article_sents
                ]
            else:
                # article_text = util.to_unicode(article_text)

                # sent_pros = {'annotators': 'ssplit', 'outputFormat': 'json', 'timeout': '5000000'}
                # sents_result_dict = nlp.annotate(str(article_text), properties=sent_pros)
                # article_sent_tokens = [[token['word'] for token in sent['tokens']] for sent in sents_result_dict['sentences']]

                raw_article_sents = nltk.tokenize.sent_tokenize(article_text)
                article_sent_tokens = [
                    util.process_sent(sent) for sent in raw_article_sents
                ]
            if FLAGS.top_n_sents != -1:
                article_sent_tokens = article_sent_tokens[:FLAGS.top_n_sents]
                raw_article_sents = raw_article_sents[:FLAGS.top_n_sents]
            article_sents = [' '.join(sent) for sent in article_sent_tokens]
            try:
                article_tokens_string = str(' '.join(article_sents))
            except:
                article_tokens_string = str(' '.join(
                    [sent.decode('latin-1') for sent in article_sents]))

            if len(article_sent_tokens) == 0:
                continue

            summary_sent_tokens = split_into_sent_tokens(summary_text)
            if 'doc_indices' in example.features.feature and len(
                    example.features.feature['doc_indices'].bytes_list.value
            ) > 0:
                doc_indices_str = example.features.feature[
                    'doc_indices'].bytes_list.value[0].decode()
                if '1' in doc_indices_str:
                    doc_indices = [
                        int(x) for x in doc_indices_str.strip().split()
                    ]
                    rel_sent_positions = importance_features.get_sent_indices(
                        article_sent_tokens, doc_indices)
                else:
                    num_tokens_total = sum(
                        [len(sent) for sent in article_sent_tokens])
                    rel_sent_positions = list(range(len(raw_article_sents)))
                    doc_indices = [0] * num_tokens_total

            else:
                rel_sent_positions = None
                doc_indices = None
                doc_indices_str = None
            if 'corefs' in example.features.feature and len(
                    example.features.feature['corefs'].bytes_list.value) > 0:
                corefs_str = example.features.feature[
                    'corefs'].bytes_list.value[0]
                corefs = json.loads(corefs_str)
            else:
                corefs_str = None
                corefs = []  # no coreference annotations for this example; avoids a NameError below
            # summary_sent_tokens = limit_to_n_tokens(summary_sent_tokens, 100)

            similar_source_indices_list_plus_empty = []

            simple_similar_source_indices, lcs_paths_list, article_lcs_paths_list, smooth_article_paths_list = ssi_functions.get_simple_source_indices_list(
                summary_sent_tokens,
                article_sent_tokens,
                vocab,
                FLAGS.sentence_limit,
                FLAGS.min_matched_tokens,
                not FLAGS.consider_stopwords,
                lemmatize=FLAGS.lemmatize,
                multiple_ssi=FLAGS.multiple_ssi)

            article_paths_parameter = article_lcs_paths_list if FLAGS.tag_tokens else None
            article_paths_parameter = smooth_article_paths_list if FLAGS.smart_tags else article_paths_parameter
            restricted_source_indices = util.enforce_sentence_limit(
                simple_similar_source_indices, FLAGS.sentence_limit)
            for summ_sent_idx, summ_sent in enumerate(summary_sent_tokens):
                if FLAGS.sent_dataset:
                    if len(restricted_source_indices[summ_sent_idx]) == 0:
                        continue
                    merge_example = get_merge_example(
                        restricted_source_indices[summ_sent_idx],
                        article_sent_tokens, summ_sent, corefs,
                        article_paths_parameter[summ_sent_idx])
                    all_merge_examples.append(merge_example)

            simple_similar_source_indices_list_plus_empty.append(
                simple_similar_source_indices)
            if FLAGS.ssi_dataset:
                summary_text_to_save = [
                    s for s in all_summary_texts
                ] if FLAGS.dataset_name == 'duc_2004' else summary_text
                write_lambdamart_example(simple_similar_source_indices,
                                         raw_article_sents,
                                         summary_text_to_save, corefs_str,
                                         doc_indices_str,
                                         article_paths_parameter,
                                         lambdamart_writer)

            if FLAGS.highlight:
                highlight_article_lcs_paths_list = smooth_article_paths_list if FLAGS.smart_tags else article_lcs_paths_list
                # simple_ssi_plus_empty = [ [s[0] for s in sim_source_ind] for sim_source_ind in simple_similar_source_indices]
                extracted_sents_in_article_html = ssi_functions.html_highlight_sents_in_article(
                    summary_sent_tokens, simple_similar_source_indices,
                    article_sent_tokens, doc_indices, lcs_paths_list,
                    highlight_article_lcs_paths_list)
                extracted_sents_in_article_html_file.write(
                    extracted_sents_in_article_html.encode())

            instance_idx += 1

        if FLAGS.ssi_dataset:
            lambdamart_writer.close()
            if FLAGS.dataset_name == 'cnn_dm' or FLAGS.dataset_name == 'newsroom' or FLAGS.dataset_name == 'xsum':
                chunk_size = 1000
            else:
                chunk_size = 1
            util.chunk_file(dataset_split,
                            lambdamart_out_full_dir,
                            lambdamart_out_dir,
                            chunk_size=chunk_size)

        if FLAGS.sent_dataset:
            with_coref_dir = data_dir + '_and_tag_tokens' if FLAGS.tag_tokens else data_dir
            out_dir = os.path.join(with_coref_dir,
                                   FLAGS.dataset_name + '_sent')
            if FLAGS.sentence_limit == 1:
                out_dir += '_singles'
            if FLAGS.consider_stopwords:
                out_dir += '_stopwords'
            if FLAGS.coreference_replacement:
                out_dir += '_coref'
            if FLAGS.top_n_sents != -1:
                out_dir += '_n=' + str(FLAGS.top_n_sents)
            util.create_dirs(out_dir)
            convert_data.write_with_generator(iter(all_merge_examples),
                                              len(all_merge_examples), out_dir,
                                              dataset_split)

        if FLAGS.print_output:
            # html_str = FLAGS.dataset + ' | ' + FLAGS.exp_name + '<br><br><br>' + html_str
            # save_fusions_to_file(html_str)
            ssi_path = os.path.join(ssi_dir, FLAGS.dataset_name)
            if FLAGS.consider_stopwords:
                ssi_path += '_stopwords'
            util.create_dirs(ssi_path)
            if FLAGS.dataset_name == 'duc_2004' and FLAGS.abstract_idx != 0:
                abstract_idx_str = '_%d' % FLAGS.abstract_idx
            else:
                abstract_idx_str = ''
            with open(
                    os.path.join(
                        ssi_path,
                        dataset_split + '_ssi' + abstract_idx_str + '.pkl'),
                    'wb') as f:
                pickle.dump(simple_similar_source_indices_list_plus_empty, f)

        if FLAGS.kaiqiang:
            # kaiqiang_article_file.write('\n'.join(kaiqiang_article_texts))
            # kaiqiang_abstract_file.write('\n'.join(kaiqiang_abstract_texts))
            kaiqiang_article_file.close()
            kaiqiang_abstract_file.close()
        if FLAGS.highlight:
            extracted_sents_in_article_html_file.close()
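After the LambdaMART examples are written to a single <split>.bin file, util.chunk_file splits it into per-chunk files. Below is a sketch of what that step is assumed to do, using the length-prefixed serialized-example binary format read by data.example_generator; the real helper's naming scheme and chunk layout may differ:

import os
import struct

def chunk_file(dataset_split, in_dir, out_dir, chunk_size=1000):
    # Read length-prefixed serialized examples from <in_dir>/<split>.bin and
    # rewrite them into files of `chunk_size` examples each in out_dir.
    in_path = os.path.join(in_dir, dataset_split + '.bin')
    writer, chunk_idx, count = None, 0, 0
    with open(in_path, 'rb') as reader:
        while True:
            len_bytes = reader.read(8)
            if not len_bytes:
                break  # end of file
            str_len = struct.unpack('q', len_bytes)[0]
            example_bytes = reader.read(str_len)
            if writer is None or count == chunk_size:
                if writer is not None:
                    writer.close()
                writer = open(os.path.join(out_dir, '%s_%03d.bin' % (dataset_split, chunk_idx)), 'wb')
                chunk_idx += 1
                count = 0
            writer.write(len_bytes)
            writer.write(example_bytes)
            count += 1
    if writer is not None:
        writer.close()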