def main(unused_argv):

    if len(unused_argv) != 1:  # raise an exception if extra flags were entered
        raise Exception("Problem with flags: %s" % unused_argv)

    source_dir = os.path.join(data_dir, FLAGS.dataset)
    source_files = sorted(glob.glob(source_dir + '/*'))

    for i in range(4):
        ref_dir = os.path.join(log_dir, 'reference_' + str(i), 'reference')
        dec_dir = os.path.join(log_dir, 'reference_' + str(i), 'decoded')
        util.create_dirs(ref_dir)
        util.create_dirs(dec_dir)
        for source_idx, source_file in enumerate(source_files):
            human_summary_texts = get_human_summary_texts(source_file)
            summaries = []
            for summary_text in human_summary_texts:
                summary = data.abstract2sents(summary_text)
                summaries.append(summary)
            candidate = summaries[i]
            references = [
                summaries[idx] for idx in range(len(summaries)) if idx != i
            ]
            rouge_functions.write_for_rouge(references, candidate, source_idx,
                                            ref_dir, dec_dir)

        results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir)
        # print("Results_dict: ", results_dict)
        rouge_functions.rouge_log(results_dict,
                                  os.path.join(log_dir, 'reference_' + str(i)))
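Every example on this page funnels into rouge_functions.write_for_rouge, whose source is not shown here. Below is a minimal sketch of what it plausibly does, inferred from the call sites and from the standard pyrouge file-naming conventions; the body, the lettered multi-reference naming, and the sketch's name are assumptions, not the project's actual code.

import os

def write_for_rouge_sketch(references, candidate_sents, ex_index, ref_dir, dec_dir):
    # Assumed pyrouge layout: one reference file per human summary,
    # lettered A, B, C, ..., plus one decoded file per example.
    for ref_idx, reference_sents in enumerate(references):
        letter = chr(ord('A') + ref_idx)
        ref_path = os.path.join(ref_dir, '%06d_reference.%s.txt' % (ex_index, letter))
        with open(ref_path, 'w') as f:
            f.write('\n'.join(reference_sents))
    dec_path = os.path.join(dec_dir, '%06d_decoded.txt' % ex_index)
    with open(dec_path, 'w') as f:
        f.write('\n'.join(candidate_sents))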
Example #2
def process_example(sents, ex_idx, groundtruth_summ_sents):
    final_decoded_words = []
    for sent in sents:
        final_decoded_words.extend(sent.split(' '))
    rouge_functions.write_for_rouge(groundtruth_summ_sents,
                                    None,
                                    ex_idx,
                                    rouge_ref_dir,
                                    rouge_dec_dir,
                                    decoded_words=final_decoded_words,
                                    log=False)
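Example #2 relies on module-level rouge_ref_dir and rouge_dec_dir, which are created elsewhere with util.create_dirs (seen in Examples #1 and #4). util's source is not shown; a one-line sketch of the likely helper, assuming it simply wraps os.makedirs:

import os

def create_dirs(path):
    # Assumed behavior: create the directory tree if it does not exist yet.
    if not os.path.exists(path):
        os.makedirs(path)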
Example #3
def evaluate_example(ex):
    example, example_idx, qid_ssi_to_importances, _, _ = ex
    print(example_idx)

    # Read example from dataset
    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, doc_indices = util.unpack_tf_example(example, names_to_types)
    article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
    enforced_groundtruth_ssi_list = util.enforce_sentence_limit(groundtruth_similar_source_indices_list, sentence_limit)
    groundtruth_summ_sents = [[sent.strip() for sent in groundtruth_summary_text.strip().split('\n')]]
    groundtruth_summ_sent_tokens = [sent.split(' ') for sent in groundtruth_summ_sents[0]]

    if FLAGS.upper_bound:
        # If upper bound, then get the groundtruth singletons/pairs
        replaced_ssi_list = util.replace_empty_ssis(enforced_groundtruth_ssi_list, raw_article_sents)
        selected_article_sent_indices = util.flatten_list_of_lists(replaced_ssi_list)
        summary_sents = [' '.join(sent) for sent in util.reorder(article_sent_tokens, selected_article_sent_indices)]
        similar_source_indices_list = groundtruth_similar_source_indices_list
        ssi_length_extractive = len(similar_source_indices_list)
    else:
        # Generates summary based on BERT output. This is an extractive summary.
        summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive = generate_summary(article_sent_tokens, qid_ssi_to_importances, example_idx)
        similar_source_indices_list_trunc = similar_source_indices_list[:ssi_length_extractive]
        summary_sents_for_html_trunc = summary_sents_for_html[:ssi_length_extractive]
        if example_idx <= 1:
            summary_sent_tokens = [sent.split(' ') for sent in summary_sents_for_html_trunc]
            extracted_sents_in_article_html = html_highlight_sents_in_article(
                summary_sent_tokens, similar_source_indices_list_trunc,
                article_sent_tokens, doc_indices=doc_indices)

            groundtruth_ssi_list, lcs_paths_list, article_lcs_paths_list = get_simple_source_indices_list(
                groundtruth_summ_sent_tokens, article_sent_tokens, None,
                sentence_limit, min_matched_tokens)
            groundtruth_highlighted_html = html_highlight_sents_in_article(
                groundtruth_summ_sent_tokens, groundtruth_ssi_list,
                article_sent_tokens, lcs_paths_list=lcs_paths_list,
                article_lcs_paths_list=article_lcs_paths_list,
                doc_indices=doc_indices)

            all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html + '<u>Groundtruth Summary</u><br><br>' + groundtruth_highlighted_html
            ssi_functions.write_highlighted_html(all_html, html_dir, example_idx)
    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents, example_idx, ref_dir, dec_dir)
    return (groundtruth_similar_source_indices_list, similar_source_indices_list, ssi_length_extractive)
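The upper-bound branch in Example #3 leans on two small util helpers whose semantics can be inferred from how their results are used. Hedged sketches follow; the names match the calls above but the bodies are assumptions:

def flatten_list_of_lists(list_of_lists):
    # Assumed behavior: [[0, 2], [1]] -> [0, 2, 1]
    return [item for sublist in list_of_lists for item in sublist]

def reorder(items, indices):
    # Assumed behavior: pick out items at the given indices, in that order.
    return [items[idx] for idx in indices]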
Example #4
def main(unused_argv):

    print('Running statistics on %s' % FLAGS.dataset_name)

    if len(unused_argv) != 1:  # raise an exception if extra flags were entered
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.summarizer == 'all':
        summary_methods = list(summarizers.keys())
    else:
        summary_methods = [FLAGS.summarizer]
    if FLAGS.dataset_name == 'all':
        dataset_names = datasets
    else:
        dataset_names = [FLAGS.dataset_name]

    sheets_strs = []
    for summary_method in summary_methods:
        summary_fn = summarizers[summary_method]
        for dataset_name in dataset_names:
            FLAGS.dataset_name = dataset_name

            original_dataset_name = ('xsum' if 'xsum' in dataset_name else
                                     'cnn_dm' if 'cnn_dm' in dataset_name
                                     or 'duc_2004' in dataset_name else '')
            vocab = Vocab('logs/vocab' + '_' + original_dataset_name,
                          50000)  # create a vocabulary

            source_dir = os.path.join(data_dir, dataset_name)
            source_files = sorted(
                glob.glob(source_dir + '/' + FLAGS.dataset_split + '*'))

            total = len(source_files) * 1000 if (
                'cnn' in dataset_name or 'newsroom' in dataset_name
                or 'xsum' in dataset_name) else len(source_files)
            example_generator = data.example_generator(
                source_dir + '/' + FLAGS.dataset_split + '*',
                True,
                False,
                should_check_valid=False)

            if dataset_name == 'duc_2004':
                abs_source_dir = os.path.join(
                    os.path.expanduser('~') + '/data/tf_data/with_coref',
                    dataset_name)
                abs_example_generator = data.example_generator(
                    abs_source_dir + '/' + FLAGS.dataset_split + '*',
                    True,
                    False,
                    should_check_valid=False)
                abs_names_to_types = [('abstract', 'string_list')]

            triplet_ssi_list = []
            for example_idx, example in enumerate(
                    tqdm(example_generator, total=total)):
                raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
                    example, names_to_types)
                if dataset_name == 'duc_2004':
                    abs_example = next(abs_example_generator)
                    groundtruth_summary_texts = util.unpack_tf_example(
                        abs_example, abs_names_to_types)
                    groundtruth_summary_texts = groundtruth_summary_texts[0]
                    groundtruth_summ_sents_list = [[
                        sent.strip() for sent in data.abstract2sents(abstract)
                    ] for abstract in groundtruth_summary_texts]

                else:
                    groundtruth_summary_texts = [groundtruth_summary_text]
                    groundtruth_summ_sents_list = []
                    for groundtruth_summary_text in groundtruth_summary_texts:
                        groundtruth_summ_sents = [
                            sent.strip() for sent in
                            groundtruth_summary_text.strip().split('\n')
                        ]
                        groundtruth_summ_sents_list.append(
                            groundtruth_summ_sents)
                article_sent_tokens = [
                    util.process_sent(sent) for sent in raw_article_sents
                ]
                if doc_indices is None:
                    doc_indices = [0] * len(
                        util.flatten_list_of_lists(article_sent_tokens))
                doc_indices = [int(doc_idx) for doc_idx in doc_indices]
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list,
                    FLAGS.sentence_limit)

                log_dir = os.path.join(log_root,
                                       dataset_name + '_' + summary_method)
                dec_dir = os.path.join(log_dir, 'decoded')
                ref_dir = os.path.join(log_dir, 'reference')
                util.create_dirs(dec_dir)
                util.create_dirs(ref_dir)

                parser = PlaintextParser.from_string(
                    ' '.join(raw_article_sents), Tokenizer("english"))
                summarizer = summary_fn()

                summary = summarizer(parser.document, 5)  # summarize the document with 5 sentences
                summary = [str(sentence) for sentence in summary]

                summary_tokenized = []
                for sent in summary:
                    summary_tokenized.append(sent.lower())

                rouge_functions.write_for_rouge(groundtruth_summ_sents_list,
                                                summary_tokenized,
                                                example_idx,
                                                ref_dir,
                                                dec_dir,
                                                log=False)

                decoded_sent_tokens = [
                    sent.split() for sent in summary_tokenized
                ]
                sentence_limit = 2
                sys_ssi_list, _, _ = get_simple_source_indices_list(
                    decoded_sent_tokens, article_sent_tokens, vocab,
                    sentence_limit, min_matched_tokens)
                triplet_ssi_list.append(
                    (groundtruth_similar_source_indices_list, sys_ssi_list,
                     -1))

            print('Evaluating Lambdamart model F1 score...')
            suffix = util.all_sent_selection_eval(triplet_ssi_list)
            print(suffix)

            results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir)
            print(("Results_dict: ", results_dict))
            sheets_str = rouge_functions.rouge_log(results_dict,
                                                   log_dir,
                                                   suffix=suffix)
            sheets_strs.append(dataset_name + '_' + summary_method + '\n' +
                               sheets_str)

    for sheets_str in sheets_strs:
        print(sheets_str + '\n')
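The summarizers in Example #4 come from the sumy package (PlaintextParser and Tokenizer are sumy classes). A self-contained version of the core calls, assuming the summarizers dict maps method names to classes such as LexRankSummarizer (the exact mapping is not shown on this page):

from sumy.parsers.plaintext import PlaintextParser
from sumy.nlp.tokenizers import Tokenizer
from sumy.summarizers.lex_rank import LexRankSummarizer

# Tokenizer("english") relies on NLTK's punkt data being installed.
text = "First sentence here. A second sentence follows. The third one closes."
parser = PlaintextParser.from_string(text, Tokenizer("english"))
summarizer = LexRankSummarizer()
# summarizer(document, n) returns the n highest-ranked Sentence objects.
summary = [str(sentence) for sentence in summarizer(parser.document, 2)]
print(summary)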
Example #5
def evaluate_example(ex):
    example, example_idx, qid_ssi_to_importances, _, _ = ex
    print(example_idx)
    # example_idx += 1
    qid = example_idx
    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
        example, names_to_types)
    article_sent_tokens = [
        util.process_sent(sent) for sent in raw_article_sents
    ]
    enforced_groundtruth_ssi_list = util.enforce_sentence_limit(
        groundtruth_similar_source_indices_list, sentence_limit)
    if FLAGS.dataset_name == 'duc_2004':
        groundtruth_summ_sents = [[
            sent.strip() for sent in gt_summ_text.strip().split('\n')
        ] for gt_summ_text in groundtruth_summary_text]
    else:
        groundtruth_summ_sents = [[
            sent.strip()
            for sent in groundtruth_summary_text.strip().split('\n')
        ]]
    groundtruth_summ_sent_tokens = [
        sent.split(' ') for sent in groundtruth_summ_sents[0]
    ]

    if FLAGS.upper_bound:
        replaced_ssi_list = util.replace_empty_ssis(
            enforced_groundtruth_ssi_list, raw_article_sents)
        selected_article_sent_indices = util.flatten_list_of_lists(
            replaced_ssi_list)
        summary_sents = [
            ' '.join(sent) for sent in util.reorder(
                article_sent_tokens, selected_article_sent_indices)
        ]
        similar_source_indices_list = groundtruth_similar_source_indices_list
        ssi_length_extractive = len(similar_source_indices_list)
    elif FLAGS.lead:
        lead_ssi_list = [(idx,) for idx in range(
            util.average_sents_for_dataset[FLAGS.dataset_name])]
        # make sure the sentence indices don't go past the total number of
        # sentences in the article
        lead_ssi_list = lead_ssi_list[:len(raw_article_sents)]
        selected_article_sent_indices = util.flatten_list_of_lists(
            lead_ssi_list)
        summary_sents = [
            ' '.join(sent) for sent in util.reorder(
                article_sent_tokens, selected_article_sent_indices)
        ]
        similar_source_indices_list = lead_ssi_list
        ssi_length_extractive = len(similar_source_indices_list)
    else:
        summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive = generate_summary(
            article_sent_tokens, qid_ssi_to_importances, example_idx)
        similar_source_indices_list_trunc = similar_source_indices_list[:ssi_length_extractive]
        summary_sents_for_html_trunc = summary_sents_for_html[:ssi_length_extractive]
        if example_idx <= 100:
            summary_sent_tokens = [
                sent.split(' ') for sent in summary_sents_for_html_trunc
            ]
            extracted_sents_in_article_html = html_highlight_sents_in_article(
                summary_sent_tokens,
                similar_source_indices_list_trunc,
                article_sent_tokens,
                doc_indices=doc_indices)
            # write_highlighted_html(extracted_sents_in_article_html, html_dir, example_idx)

            groundtruth_ssi_list, lcs_paths_list, article_lcs_paths_list = get_simple_source_indices_list(
                groundtruth_summ_sent_tokens, article_sent_tokens, None,
                sentence_limit, min_matched_tokens)
            groundtruth_highlighted_html = html_highlight_sents_in_article(
                groundtruth_summ_sent_tokens,
                groundtruth_ssi_list,
                article_sent_tokens,
                lcs_paths_list=lcs_paths_list,
                article_lcs_paths_list=article_lcs_paths_list,
                doc_indices=doc_indices)
            all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html + '<u>Groundtruth Summary</u><br><br>' + groundtruth_highlighted_html
            write_highlighted_html(all_html, html_dir, example_idx)
    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents,
                                    example_idx, ref_dir, dec_dir)
    return (groundtruth_similar_source_indices_list,
            similar_source_indices_list, ssi_length_extractive)
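util.enforce_sentence_limit appears throughout these examples with limits of 1 and 2. Its body is not shown; a plausible sketch, assuming each entry in the list is a tuple of source-sentence indices that should be capped at the limit:

def enforce_sentence_limit(ssi_list, sentence_limit):
    # Assumed behavior: cap each source-indices tuple at `sentence_limit`
    # entries, e.g. limit 1 keeps only singletons, limit 2 allows pairs.
    return [ssi[:sentence_limit] for ssi in ssi_list]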
Example #6
    def decode(self):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
        t0 = time.time()
        counter = 0
        total = len(glob.glob(self._batcher._data_path)) * 1000
        pbar = tqdm(total=total)
        while True:
            batch = self._batcher.next_batch()  # 1 example repeated across batch
            if batch is None:  # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                logging.info("Output has been saved in %s and %s.",
                             self._rouge_ref_dir, self._rouge_dec_dir)
                if len(os.listdir(self._rouge_ref_dir)) != 0:
                    logging.info("Now starting ROUGE eval...")
                    results_dict = rouge_functions.rouge_eval(
                        self._rouge_ref_dir, self._rouge_dec_dir)
                    rouge_functions.rouge_log(results_dict, self._decode_dir)
                pbar.close()
                return

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            all_original_abstract_sents = batch.all_original_abstracts_sents[0]
            raw_article_sents = batch.raw_article_sents[0]

            article_withunks = data.show_art_oovs(original_article,
                                                  self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            decoded_words, decoded_output, best_hyp = decode_example(
                self._sess, self._model, self._vocab, batch, counter,
                self._batcher._hps)

            if FLAGS.single_pass:
                if counter < 1000:
                    self.write_for_human(raw_article_sents,
                                         all_original_abstract_sents,
                                         decoded_words, counter)
                rouge_functions.write_for_rouge(
                    all_original_abstract_sents,
                    None,
                    counter,
                    self._rouge_ref_dir,
                    self._rouge_dec_dir,
                    decoded_words=decoded_words
                )  # write ref summary and decoded summary to file, to eval with pyrouge later
                if FLAGS.attn_vis:
                    self.write_for_attnvis(
                        article_withunks, abstract_withunks, decoded_words,
                        best_hyp.attn_dists, best_hyp.p_gens, counter
                    )  # write info to .json file for visualization tool

                counter += 1  # this is how many examples we've decoded
            else:
                print_results(article_withunks, abstract_withunks,
                              decoded_output)  # log output to screen
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words,
                    best_hyp.attn_dists, best_hyp.p_gens,
                    counter)  # write info to .json file for visualization tool

                # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
                t1 = time.time()
                if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                    logging.info(
                        'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                        t1 - t0)
                    _ = util.load_ckpt(self._saver, self._sess)
                    t0 = time.time()
            pbar.update(1)
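rouge_functions.rouge_eval, called once the dataset is exhausted, most likely wraps the pyrouge package's Rouge155 runner; the file patterns below match the %06d_reference / %06d_decoded naming used across these examples, but the body is an assumption and Rouge155 itself requires the ROUGE-1.5.5 perl distribution to be installed and configured:

import logging
import pyrouge

def rouge_eval_sketch(ref_dir, dec_dir):
    r = pyrouge.Rouge155()
    r.model_filename_pattern = '#ID#_reference.[A-Z].txt'  # reference summaries
    r.system_filename_pattern = r'(\d+)_decoded.txt'       # system output
    r.model_dir = ref_dir
    r.system_dir = dec_dir
    logging.getLogger('global').setLevel(logging.WARNING)  # silence pyrouge chatter
    return r.output_to_dict(r.convert_and_evaluate())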
Example #7
    def decode_iteratively(self, example_generator, total, names_to_types,
                           ssi_list, hps):
        for example_idx, example in enumerate(
                tqdm(example_generator, total=total)):
            raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text = util.unpack_tf_example(
                example, names_to_types)
            article_sent_tokens = [
                util.process_sent(sent) for sent in raw_article_sents
            ]
            groundtruth_summ_sents = [[
                sent.strip()
                for sent in groundtruth_summary_text.strip().split('\n')
            ]]

            if ssi_list is None:  # this is if we are doing the upper bound evaluation (ssi_list comes straight from the groundtruth)
                sys_ssi = groundtruth_similar_source_indices_list
                if FLAGS.singles_and_pairs == 'singles':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                elif FLAGS.singles_and_pairs == 'both':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                sys_ssi = util.replace_empty_ssis(sys_ssi, raw_article_sents)
            else:
                gt_ssi, sys_ssi, ext_len = ssi_list[example_idx]
                if FLAGS.singles_and_pairs == 'singles':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                        groundtruth_similar_source_indices_list, 1)
                    gt_ssi = util.enforce_sentence_limit(gt_ssi, 1)
                elif FLAGS.singles_and_pairs == 'both':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                        groundtruth_similar_source_indices_list, 2)
                    gt_ssi = util.enforce_sentence_limit(gt_ssi, 2)
                if gt_ssi != groundtruth_similar_source_indices_list:
                    print('Warning: Example %d has different groundtruth '
                          'source indices: %s || %s' %
                          (example_idx, groundtruth_similar_source_indices_list,
                           gt_ssi))
                if FLAGS.dataset_name == 'xsum':
                    sys_ssi = [sys_ssi[0]]

            final_decoded_words = []
            final_decoded_outputs = ''
            best_hyps = []
            highlight_html_total = ''
            for ssi_idx, ssi in enumerate(sys_ssi):
                selected_raw_article_sents = util.reorder(
                    raw_article_sents, ssi)
                selected_article_text = ' '.join([
                    ' '.join(sent)
                    for sent in util.reorder(article_sent_tokens, ssi)
                ])
                selected_doc_indices_str = '0 ' * len(
                    selected_article_text.split())
                if FLAGS.upper_bound:
                    selected_groundtruth_summ_sent = [[
                        groundtruth_summ_sents[0][ssi_idx]
                    ]]
                else:
                    selected_groundtruth_summ_sent = groundtruth_summ_sents

                batch = create_batch(selected_article_text,
                                     selected_groundtruth_summ_sent,
                                     selected_doc_indices_str,
                                     selected_raw_article_sents,
                                     FLAGS.batch_size, hps, self._vocab)

                decoded_words, decoded_output, best_hyp = decode_example(
                    self._sess, self._model, self._vocab, batch, example_idx,
                    hps)
                best_hyps.append(best_hyp)
                final_decoded_words.extend(decoded_words)
                final_decoded_outputs += decoded_output

                if example_idx < 1000:
                    min_matched_tokens = 2
                    selected_article_sent_tokens = [
                        util.process_sent(sent)
                        for sent in selected_raw_article_sents
                    ]
                    highlight_summary_sent_tokens = [decoded_words]
                    highlight_ssi_list, lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                        highlight_summary_sent_tokens,
                        selected_article_sent_tokens, None, 2,
                        min_matched_tokens)
                    highlighted_html = ssi_functions.html_highlight_sents_in_article(
                        highlight_summary_sent_tokens,
                        highlight_ssi_list,
                        selected_article_sent_tokens,
                        lcs_paths_list=lcs_paths_list,
                        article_lcs_paths_list=highlight_smooth_article_lcs_paths_list)
                    highlight_html_total += '<u>System Summary</u><br><br>' + highlighted_html + '<br><br>'

                if len(final_decoded_words) >= 100:
                    break

            if example_idx < 1000:
                self.write_for_human(raw_article_sents, groundtruth_summ_sents,
                                     final_decoded_words, example_idx)
                ssi_functions.write_highlighted_html(highlight_html_total,
                                                     self._highlight_dir,
                                                     example_idx)

            rouge_functions.write_for_rouge(
                groundtruth_summ_sents,
                None,
                example_idx,
                self._rouge_ref_dir,
                self._rouge_dec_dir,
                decoded_words=final_decoded_words,
                log=False
            )  # write ref summary and decoded summary to file, to eval with pyrouge later

        logging.info("Decoder has finished reading dataset for single_pass.")
        logging.info("Output has been saved in %s and %s.",
                     self._rouge_ref_dir, self._rouge_dec_dir)
        if len(os.listdir(self._rouge_ref_dir)) != 0:
            l_param = 100
            logging.info("Now starting ROUGE eval...")
            results_dict = rouge_functions.rouge_eval(self._rouge_ref_dir,
                                                      self._rouge_dec_dir,
                                                      l_param=l_param)
            rouge_functions.rouge_log(results_dict, self._decode_dir)
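ssi_functions.write_highlighted_html saves the per-example visualization built in Example #7. A minimal sketch of what such a helper might look like; the file name and page wrapper are guesses:

import os

def write_highlighted_html(html, out_dir, example_idx):
    # Assumed behavior: wrap the fragment in a bare page and write one
    # HTML file per example.
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    with open(os.path.join(out_dir, '%06d.html' % example_idx), 'w') as f:
        f.write('<html><body>' + html + '</body></html>')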
Example #8
    def decode_iteratively(self, example_generator, total, names_to_types,
                           ssi_list, hps):
        attn_vis_idx = 0
        for example_idx, example in enumerate(
                tqdm(example_generator, total=total)):
            raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, groundtruth_article_lcs_paths_list = util.unpack_tf_example(
                example, names_to_types)
            article_sent_tokens = [
                util.process_sent(sent) for sent in raw_article_sents
            ]
            groundtruth_summ_sents = [[
                sent.strip()
                for sent in groundtruth_summary_text.strip().split('\n')
            ]]
            groundtruth_summ_sent_tokens = [
                sent.split(' ') for sent in groundtruth_summ_sents[0]
            ]

            if ssi_list is None:  # this is if we are doing the upper bound evaluation (ssi_list comes straight from the groundtruth)
                sys_ssi = groundtruth_similar_source_indices_list
                sys_alp_list = groundtruth_article_lcs_paths_list
                if FLAGS.singles_and_pairs == 'singles':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                    sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 1)
                elif FLAGS.singles_and_pairs == 'both':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                    sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 2)
                sys_ssi, sys_alp_list = util.replace_empty_ssis(
                    sys_ssi, raw_article_sents, sys_alp_list=sys_alp_list)
            else:
                gt_ssi, sys_ssi, ext_len, sys_token_probs_list = ssi_list[
                    example_idx]
                sys_alp_list = ssi_functions.list_labels_from_probs(
                    sys_token_probs_list, FLAGS.tag_threshold)
                if FLAGS.singles_and_pairs == 'singles':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                    sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 1)
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                        groundtruth_similar_source_indices_list, 1)
                    gt_ssi = util.enforce_sentence_limit(gt_ssi, 1)
                elif FLAGS.singles_and_pairs == 'both':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                    sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 2)
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                        groundtruth_similar_source_indices_list, 2)
                    gt_ssi = util.enforce_sentence_limit(gt_ssi, 2)
                # if gt_ssi != groundtruth_similar_source_indices_list:
                #     raise Exception('Example %d has different groundtruth source indices: ' + str(groundtruth_similar_source_indices_list) + ' || ' + str(gt_ssi))
                if FLAGS.dataset_name == 'xsum':
                    sys_ssi = [sys_ssi[0]]

            final_decoded_words = []
            final_decoded_outputs = ''
            best_hyps = []
            highlight_html_total = '<u>System Summary</u><br><br>'
            for ssi_idx, ssi in enumerate(sys_ssi):
                # selected_article_lcs_paths = None
                selected_article_lcs_paths = sys_alp_list[ssi_idx]
                ssi, selected_article_lcs_paths = util.make_ssi_chronological(
                    ssi, selected_article_lcs_paths)
                selected_article_lcs_paths = [selected_article_lcs_paths]
                selected_raw_article_sents = util.reorder(
                    raw_article_sents, ssi)
                selected_article_text = ' '.join([
                    ' '.join(sent)
                    for sent in util.reorder(article_sent_tokens, ssi)
                ])
                selected_doc_indices_str = '0 ' * len(
                    selected_article_text.split())
                if FLAGS.upper_bound:
                    selected_groundtruth_summ_sent = [[
                        groundtruth_summ_sents[0][ssi_idx]
                    ]]
                else:
                    selected_groundtruth_summ_sent = groundtruth_summ_sents

                batch = create_batch(selected_article_text,
                                     selected_groundtruth_summ_sent,
                                     selected_doc_indices_str,
                                     selected_raw_article_sents,
                                     selected_article_lcs_paths,
                                     FLAGS.batch_size, hps, self._vocab)

                original_article = batch.original_articles[0]  # string
                original_abstract = batch.original_abstracts[0]  # string
                article_withunks = data.show_art_oovs(original_article,
                                                      self._vocab)  # string
                abstract_withunks = data.show_abs_oovs(
                    original_abstract, self._vocab,
                    (batch.art_oovs[0]
                     if FLAGS.pointer_gen else None))  # string

                if FLAGS.first_intact and ssi_idx == 0:
                    decoded_words = selected_article_text.strip().split()
                    decoded_output = selected_article_text
                else:
                    decoded_words, decoded_output, best_hyp = decode_example(
                        self._sess, self._model, self._vocab, batch,
                        example_idx, hps)
                    best_hyps.append(best_hyp)
                final_decoded_words.extend(decoded_words)
                final_decoded_outputs += decoded_output

                if example_idx < 100 or (example_idx >= 2000
                                         and example_idx < 2100):
                    min_matched_tokens = 2
                    selected_article_sent_tokens = [
                        util.process_sent(sent)
                        for sent in selected_raw_article_sents
                    ]
                    highlight_summary_sent_tokens = [decoded_words]
                    highlight_ssi_list, lcs_paths_list, highlight_article_lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                        highlight_summary_sent_tokens,
                        selected_article_sent_tokens, None, 2,
                        min_matched_tokens)
                    highlighted_html = ssi_functions.html_highlight_sents_in_article(
                        highlight_summary_sent_tokens,
                        highlight_ssi_list,
                        selected_article_sent_tokens,
                        lcs_paths_list=lcs_paths_list,
                        article_lcs_paths_list=highlight_smooth_article_lcs_paths_list)
                    highlight_html_total += highlighted_html + '<br>'

                if FLAGS.attn_vis and example_idx < 200:
                    self.write_for_attnvis(
                        article_withunks, abstract_withunks, decoded_words,
                        best_hyp.attn_dists, best_hyp.p_gens, attn_vis_idx
                    )  # write info to .json file for visualization tool
                    attn_vis_idx += 1

                if len(final_decoded_words) >= 100:
                    break

            gt_ssi_list, gt_alp_list = util.replace_empty_ssis(
                groundtruth_similar_source_indices_list,
                raw_article_sents,
                sys_alp_list=groundtruth_article_lcs_paths_list)
            highlight_html_gt = '<u>Reference Summary</u><br><br>'
            for ssi_idx, ssi in enumerate(gt_ssi_list):
                selected_article_lcs_paths = gt_alp_list[ssi_idx]
                try:
                    ssi, selected_article_lcs_paths = util.make_ssi_chronological(
                        ssi, selected_article_lcs_paths)
                except:
                    util.print_vars(ssi, example_idx,
                                    selected_article_lcs_paths)
                    raise
                selected_raw_article_sents = util.reorder(
                    raw_article_sents, ssi)

                if example_idx < 100 or (example_idx >= 2000
                                         and example_idx < 2100):
                    min_matched_tokens = 2
                    selected_article_sent_tokens = [
                        util.process_sent(sent)
                        for sent in selected_raw_article_sents
                    ]
                    highlight_summary_sent_tokens = [
                        groundtruth_summ_sent_tokens[ssi_idx]
                    ]
                    highlight_ssi_list, lcs_paths_list, highlight_article_lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                        highlight_summary_sent_tokens,
                        selected_article_sent_tokens, None, 2,
                        min_matched_tokens)
                    highlighted_html = ssi_functions.html_highlight_sents_in_article(
                        highlight_summary_sent_tokens,
                        highlight_ssi_list,
                        selected_article_sent_tokens,
                        lcs_paths_list=lcs_paths_list,
                        article_lcs_paths_list=highlight_smooth_article_lcs_paths_list)
                    highlight_html_gt += highlighted_html + '<br>'

            if example_idx < 100 or (example_idx >= 2000
                                     and example_idx < 2100):
                self.write_for_human(raw_article_sents, groundtruth_summ_sents,
                                     final_decoded_words, example_idx)
                highlight_html_total = ssi_functions.put_html_in_two_columns(
                    highlight_html_total, highlight_html_gt)
                ssi_functions.write_highlighted_html(highlight_html_total,
                                                     self._highlight_dir,
                                                     example_idx)

            # if example_idx % 100 == 0:
            #     attn_dir = os.path.join(self._decode_dir, 'attn_vis_data')
            #     attn_selections.process_attn_selections(attn_dir, self._decode_dir, self._vocab)

            rouge_functions.write_for_rouge(
                groundtruth_summ_sents,
                None,
                example_idx,
                self._rouge_ref_dir,
                self._rouge_dec_dir,
                decoded_words=final_decoded_words,
                log=False
            )  # write ref summary and decoded summary to file, to eval with pyrouge later
            # if FLAGS.attn_vis:
            #     self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens, example_idx) # write info to .json file for visualization tool

        logging.info("Decoder has finished reading dataset for single_pass.")
        logging.info("Output has been saved in %s and %s.",
                     self._rouge_ref_dir, self._rouge_dec_dir)
        if len(os.listdir(self._rouge_ref_dir)) != 0:
            l_param = 100
            logging.info("Now starting ROUGE eval...")
            results_dict = rouge_functions.rouge_eval(self._rouge_ref_dir,
                                                      self._rouge_dec_dir,
                                                      l_param=l_param)
            rouge_functions.rouge_log(results_dict, self._decode_dir)
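Example #8 reorders each source-index tuple with util.make_ssi_chronological before building a batch. Inferred semantics, sketched under the assumption that the per-sentence token paths must stay aligned with their sentences; the real body is not shown on this page:

def make_ssi_chronological(ssi, article_lcs_paths):
    # Assumed behavior: sort the source sentence indices into article
    # order and permute the per-sentence token paths identically.
    order = sorted(range(len(ssi)), key=lambda i: ssi[i])
    new_ssi = tuple(ssi[i] for i in order)
    new_paths = [article_lcs_paths[i] for i in order]
    return new_ssi, new_paths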
Example #9
def main(unused_argv):

    if len(unused_argv) != 1:  # raise an exception if extra flags were entered
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.exp_name == 'extractive':
        for summary_method in summary_methods:
            if not os.path.exists(
                    os.path.join(out_dir, summary_method, 'decoded')):
                os.makedirs(os.path.join(out_dir, summary_method, 'decoded'))
            if not os.path.exists(
                    os.path.join(out_dir, summary_method, 'reference')):
                os.makedirs(os.path.join(out_dir, summary_method, 'reference'))
            print(os.path.join(out_dir, summary_method))
            method_dir = os.path.join(summaries_dir, summary_method)
            file_names = sorted(
                [name for name in os.listdir(method_dir) if name[0] == 'd'])
            for art_idx, article_name in enumerate(tqdm(file_names)):
                file = os.path.join(method_dir, article_name)
                with open(file) as f:
                    lines = f.readlines()
                tokenized_sents = [[
                    token.lower()
                    for token in nltk.tokenize.word_tokenize(line)
                ] for line in lines]
                sentences = [' '.join(sent) for sent in tokenized_sents]
                processed_summary = '\n'.join(sentences)
                out_name = '%06d_decoded.txt' % art_idx
                with open(
                        os.path.join(out_dir, summary_method, 'decoded',
                                     out_name), 'w') as f:
                    f.write(processed_summary)

                reference_files = glob.glob(
                    os.path.join(ref_dir, '%06d' % art_idx + '*'))
                abstract_sentences = []
                for ref_file in reference_files:
                    with open(ref_file) as f:
                        lines = f.readlines()
                    abstract_sentences.append(lines)
                rouge_functions.write_for_rouge(
                    abstract_sentences, sentences, art_idx,
                    os.path.join(out_dir, summary_method, 'reference'),
                    os.path.join(out_dir, summary_method, 'decoded'))

            results_dict = rouge_functions.rouge_eval(
                ref_dir, os.path.join(out_dir, summary_method, 'decoded'))
            # print("Results_dict: ", results_dict)
            rouge_functions.rouge_log(results_dict,
                                      os.path.join(out_dir, summary_method))

        for summary_method in summary_methods:
            print(summary_method)
        all_results = ''
        for summary_method in summary_methods:
            sheet_results_file = os.path.join(out_dir, summary_method,
                                              "sheets_results.txt")
            with open(sheet_results_file) as f:
                results = f.read()
            all_results += results
        print(all_results)

    else:
        # source_dir = os.path.join(data_dir, FLAGS.dataset)
        log_root = os.path.join('logs', FLAGS.exp_name)
        ckpt_folder = util.find_largest_ckpt_folder(log_root)
        print(ckpt_folder)
        # if os.path.exists(os.path.join(log_dir,FLAGS.exp_name,ckpt_folder)):
        if ckpt_folder != 'decoded':
            summary_dir = os.path.join(log_dir, FLAGS.exp_name, ckpt_folder)
        else:
            summary_dir = os.path.join(log_dir, FLAGS.exp_name)
        ref_dir = os.path.join(summary_dir, 'reference')
        dec_dir = os.path.join(summary_dir, 'decoded')
        summary_files = glob.glob(os.path.join(dec_dir, '*_decoded.txt'))
        # summary_files = glob.glob(os.path.join(log_dir + FLAGS.exp_name, 'test_*.txt.result.summary'))
        # if len(summary_files) > 0 and not os.path.exists(dec_dir):      # reformat files from extract + rewrite
        #     os.makedirs(ref_dir)
        #     os.makedirs(dec_dir)
        #     for summary_file in tqdm(summary_files):
        #         ex_index = extract_digits(os.path.basename(summary_file))
        #         new_file = os.path.join(dec_dir, "%06d_decoded.txt" % ex_index)
        #         shutil.copyfile(summary_file, new_file)
        #
        #     ref_files_to_copy = glob.glob(os.path.join(log_dir, FLAGS.dataset, ckpt_folder, 'reference', '*'))
        #     for file in tqdm(ref_files_to_copy):
        #         basename = os.path.basename(file)
        #         shutil.copyfile(file, os.path.join(ref_dir, basename))

        lengths = []
        for summary_file in tqdm(summary_files):
            with open(summary_file) as f:
                summary = f.read()
            length = len(summary.strip().split())
            lengths.append(length)
        print('Average summary length: %.2f' % np.mean(lengths))

        print('Evaluating on %d files' % len(os.listdir(dec_dir)))
        results_dict = rouge_functions.rouge_eval(ref_dir,
                                                  dec_dir,
                                                  l_param=FLAGS.l_param)
        rouge_functions.rouge_log(results_dict, summary_dir)
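The extractive branch of Example #9 lowercases and tokenizes each summary line with NLTK. A standalone version of that step; word_tokenize needs the punkt model downloaded once:

import nltk

nltk.download('punkt', quiet=True)  # one-time model download

line = 'The Quick Brown Fox jumped over the fence.'
tokens = [token.lower() for token in nltk.tokenize.word_tokenize(line)]
print(' '.join(tokens))  # -> "the quick brown fox jumped over the fence ."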
Example #10
def evaluate_example(ex):
    example, example_idx, qid_ssi_to_importances, qid_ssi_to_token_scores_and_mappings = ex
    print(example_idx)
    # example_idx += 1
    qid = example_idx
    raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, doc_indices = util.unpack_tf_example(
        example, names_to_types)
    article_sent_tokens = [
        util.process_sent(sent) for sent in raw_article_sents
    ]
    enforced_groundtruth_ssi_list = util.enforce_sentence_limit(
        groundtruth_similar_source_indices_list, sentence_limit)
    groundtruth_summ_sents = [[
        sent.strip() for sent in groundtruth_summary_text.strip().split('\n')
    ]]
    groundtruth_summ_sent_tokens = [
        sent.split(' ') for sent in groundtruth_summ_sents[0]
    ]

    if FLAGS.upper_bound:
        replaced_ssi_list = util.replace_empty_ssis(
            enforced_groundtruth_ssi_list, raw_article_sents)
        selected_article_sent_indices = util.flatten_list_of_lists(
            replaced_ssi_list)
        summary_sents = [
            ' '.join(sent) for sent in util.reorder(
                article_sent_tokens, selected_article_sent_indices)
        ]
        similar_source_indices_list = groundtruth_similar_source_indices_list
        ssi_length_extractive = len(similar_source_indices_list)
    else:
        summary_sents, similar_source_indices_list, summary_sents_for_html, ssi_length_extractive, \
            article_lcs_paths_list, token_probs_list = generate_summary(article_sent_tokens, qid_ssi_to_importances, example_idx, qid_ssi_to_token_scores_and_mappings)
        similar_source_indices_list_trunc = similar_source_indices_list[:ssi_length_extractive]
        summary_sents_for_html_trunc = summary_sents_for_html[:ssi_length_extractive]
        if example_idx < 100 or (example_idx >= 2000 and example_idx < 2100):
            summary_sent_tokens = [
                sent.split(' ') for sent in summary_sents_for_html_trunc
            ]
            if FLAGS.tag_tokens and FLAGS.tag_loss_wt != 0:
                lcs_paths_list_param = copy.deepcopy(article_lcs_paths_list)
            else:
                lcs_paths_list_param = None
            extracted_sents_in_article_html = html_highlight_sents_in_article(
                summary_sent_tokens,
                similar_source_indices_list_trunc,
                article_sent_tokens,
                doc_indices=doc_indices,
                lcs_paths_list=lcs_paths_list_param)
            # write_highlighted_html(extracted_sents_in_article_html, html_dir, example_idx)

            groundtruth_ssi_list, gt_lcs_paths_list, gt_article_lcs_paths_list, gt_smooth_article_paths_list = get_simple_source_indices_list(
                groundtruth_summ_sent_tokens, article_sent_tokens, None,
                sentence_limit, min_matched_tokens)
            groundtruth_highlighted_html = html_highlight_sents_in_article(
                groundtruth_summ_sent_tokens,
                groundtruth_ssi_list,
                article_sent_tokens,
                lcs_paths_list=gt_lcs_paths_list,
                article_lcs_paths_list=gt_smooth_article_paths_list,
                doc_indices=doc_indices)

            all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html + '<u>Groundtruth Summary</u><br><br>' + groundtruth_highlighted_html
            # all_html = '<u>System Summary</u><br><br>' + extracted_sents_in_article_html
            write_highlighted_html(all_html, html_dir, example_idx)
    rouge_functions.write_for_rouge(groundtruth_summ_sents, summary_sents,
                                    example_idx, ref_dir, dec_dir)
    return (groundtruth_similar_source_indices_list,
            similar_source_indices_list, ssi_length_extractive,
            token_probs_list)
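Example #10 returns token_probs_list so a later stage can tag tokens; Example #8 converts such probabilities into labels with ssi_functions.list_labels_from_probs and FLAGS.tag_threshold. A hedged sketch of that conversion; the real function's output format is not shown, so this is only one plausible reading:

def list_labels_from_probs(token_probs_list, tag_threshold):
    # Assumed behavior: for each sentence, keep the positions of tokens
    # whose predicted probability clears the threshold.
    return [[idx for idx, prob in enumerate(probs) if prob >= tag_threshold]
            for probs in token_probs_list]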