コード例 #1
0
def main(unused_argv):
    """Score each human summary against the other human summaries with ROUGE.

    For candidate index ``i`` in ``range(4)``, summary ``i`` of every source
    file is written as the "decoded" summary and all other human summaries of
    that file as references. ROUGE is then run and its results are logged
    under ``<log_dir>/reference_<i>``.

    Args:
        unused_argv: Positional arguments left after flag parsing; must
            contain only the program name.

    Raises:
        Exception: If extra positional arguments were supplied.
        IndexError: If a source file has fewer human summaries than the
            candidate index being evaluated (unchanged behavior).
    """
    # Anything besides the program name means flags were entered incorrectly.
    if len(unused_argv) != 1:
        raise Exception("Problem with flags: %s" % unused_argv)

    num_candidates = 4  # how many human summaries take a turn as the candidate

    source_dir = os.path.join(data_dir, FLAGS.dataset)
    source_files = sorted(glob.glob(source_dir + '/*'))

    # The parsed summaries do not depend on the candidate index, so read and
    # split each file once instead of once per candidate.
    summaries_per_file = [
        [data.abstract2sents(summary_text)
         for summary_text in get_human_summary_texts(source_file)]
        for source_file in source_files
    ]

    for i in range(num_candidates):
        run_dir = os.path.join(log_dir, 'reference_' + str(i))
        ref_dir = os.path.join(run_dir, 'reference')
        dec_dir = os.path.join(run_dir, 'decoded')
        util.create_dirs(ref_dir)
        util.create_dirs(dec_dir)
        for source_idx, summaries in enumerate(summaries_per_file):
            # Summary i plays the system output; every other one is a reference.
            candidate = summaries[i]
            references = [
                summary for idx, summary in enumerate(summaries) if idx != i
            ]
            rouge_functions.write_for_rouge(references, candidate, source_idx,
                                            ref_dir, dec_dir)

        results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir)
        rouge_functions.rouge_log(results_dict, run_dir)
コード例 #2
0
def main(unused_argv):
    """Select source sentences from BERT importance scores and evaluate them.

    Builds per-example inputs from the articles matching
    ``<data_dir>/<dataset_articles>/<dataset_split>*`` together with the
    scores returned by ``rank_source_sents``. It runs ``evaluate_example`` on
    each input, pickles the resulting list to ``<my_log_dir>/ssi.pkl``, and
    reports sentence-selection F1 and ROUGE. It then prints the share of
    selected groups that are singletons vs. pairs.

    Args:
        unused_argv: Positional arguments left after flag parsing; must
            contain only the program name.

    Raises:
        Exception: If extra positional arguments were supplied.
    """
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)
    print('Running statistics on %s' % exp_name)

    start_time = time.time()
    np.random.seed(random_seed)
    source_dir = os.path.join(data_dir, dataset_articles)
    file_pattern = source_dir + '/' + dataset_split + '*'
    source_files = sorted(glob.glob(file_pattern))

    # Estimated number of examples (1000 per matching data file).
    total = len(source_files) * 1000
    example_generator = data.example_generator(file_pattern, True, False,
                                               should_check_valid=False)

    # Read output of BERT and put into a dictionary with:
    # key=(article idx, source indices {this is a tuple of length 1 or 2, depending on if it is a singleton or pair})
    # value=score
    qid_ssi_to_importances = rank_source_sents(temp_in_path, temp_out_path)
    ex_gen = example_generator_extended(example_generator, total,
                                        qid_ssi_to_importances, None,
                                        FLAGS.singles_and_pairs)
    print('Creating list')
    ex_list = list(ex_gen)

    # Main function to get results on all test examples
    ssi_list = list(map(evaluate_example, ex_list))

    # Save ssi_list for later use. The previous immediate read-back of the
    # same file returned an equal object and has been dropped.
    with open(os.path.join(my_log_dir, 'ssi.pkl'), 'wb') as f:
        pickle.dump(ssi_list, f)
    print('Evaluating BERT model F1 score...')
    suffix = util.all_sent_selection_eval(ssi_list)
    print('Evaluating ROUGE...')
    results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir, l_param=l_param)
    rouge_functions.rouge_log(results_dict, my_log_dir, suffix=suffix)

    # Each triple's second item is the system source-index list; only its
    # first ``triple[2]`` groups were actually used.
    ssis_restricted = [ssi_triple[1][:ssi_triple[2]] for ssi_triple in ssi_list]
    ssi_lens = [len(source_indices)
                for source_indices in util.flatten_list_of_lists(ssis_restricted)]
    if ssi_lens:
        num_singles = ssi_lens.count(1)
        num_pairs = ssi_lens.count(2)
        print('Percent singles/pairs: %.2f %.2f' %
              (num_singles * 100. / len(ssi_lens), num_pairs * 100. / len(ssi_lens)))
    else:
        # Avoid ZeroDivisionError when no source indices were selected.
        print('Percent singles/pairs: no source indices were selected')

    util.print_execution_time(start_time)
コード例 #3
0
def main(unused_argv):
    """Run sumy summarizers over datasets and evaluate them.

    For every (summarizer, dataset) pair selected by ``FLAGS.summarizer`` and
    ``FLAGS.dataset_name`` (``'all'`` selects every entry), each article is
    summarized with 5 sentences. Reference and decoded files for ROUGE are
    written under ``<log_root>/<dataset>_<summarizer>``, and
    sentence-selection F1 is computed against the groundtruth source
    indices. ROUGE is then run. The per-pair ROUGE sheet strings are printed
    at the end.

    Args:
        unused_argv: Positional arguments left after flag parsing; must
            contain only the program name.

    Raises:
        Exception: If extra positional arguments were supplied.
    """
    print('Running statistics on %s' % FLAGS.dataset_name)

    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    summary_methods = (list(summarizers.keys()) if FLAGS.summarizer == 'all'
                       else [FLAGS.summarizer])
    dataset_names = (datasets if FLAGS.dataset_name == 'all'
                     else [FLAGS.dataset_name])

    sheets_strs = []
    for summary_method in summary_methods:
        summary_fn = summarizers[summary_method]
        for dataset_name in dataset_names:
            # NOTE(review): overwrites the global flag, presumably so helpers
            # that read FLAGS.dataset_name see the current dataset — confirm.
            FLAGS.dataset_name = dataset_name

            if 'xsum' in dataset_name:
                original_dataset_name = 'xsum'
            elif 'cnn_dm' in dataset_name or 'duc_2004' in dataset_name:
                original_dataset_name = 'cnn_dm'
            else:
                original_dataset_name = ''
            vocab = Vocab('logs/vocab' + '_' + original_dataset_name,
                          50000)  # create a vocabulary

            # Output dirs depend only on (dataset, summarizer). They are set
            # up before the example loop. Previously they were assigned inside
            # it, so a dataset with no examples made ROUGE below use undefined
            # dirs or the previous dataset's dirs.
            log_dir = os.path.join(log_root, dataset_name + '_' + summary_method)
            dec_dir = os.path.join(log_dir, 'decoded')
            ref_dir = os.path.join(log_dir, 'reference')
            util.create_dirs(dec_dir)
            util.create_dirs(ref_dir)

            source_dir = os.path.join(data_dir, dataset_name)
            file_pattern = source_dir + '/' + FLAGS.dataset_split + '*'
            source_files = sorted(glob.glob(file_pattern))

            # Estimated example count: 1000 per file for these datasets,
            # one per file otherwise.
            many_per_file = ('cnn' in dataset_name or 'newsroom' in dataset_name
                             or 'xsum' in dataset_name)
            total = len(source_files) * 1000 if many_per_file else len(source_files)
            example_generator = data.example_generator(file_pattern,
                                                       True,
                                                       False,
                                                       should_check_valid=False)

            if dataset_name == 'duc_2004':
                # duc_2004 reference abstracts are read from a parallel dataset.
                abs_source_dir = os.path.join(
                    os.path.expanduser('~') + '/data/tf_data/with_coref',
                    dataset_name)
                abs_example_generator = data.example_generator(
                    abs_source_dir + '/' + FLAGS.dataset_split + '*',
                    True,
                    False,
                    should_check_valid=False)
                abs_names_to_types = [('abstract', 'string_list')]

            # Neither depends on the example, so build them once per pair.
            summarizer = summary_fn()
            tokenizer = Tokenizer("english")

            triplet_ssi_list = []
            for example_idx, example in enumerate(
                    tqdm(example_generator, total=total)):
                (raw_article_sents, groundtruth_similar_source_indices_list,
                 groundtruth_summary_text, _corefs,
                 _doc_indices) = util.unpack_tf_example(example, names_to_types)
                if dataset_name == 'duc_2004':
                    abs_example = next(abs_example_generator)
                    groundtruth_summary_texts = util.unpack_tf_example(
                        abs_example, abs_names_to_types)[0]
                    groundtruth_summ_sents_list = [[
                        sent.strip() for sent in data.abstract2sents(abstract)
                    ] for abstract in groundtruth_summary_texts]
                else:
                    groundtruth_summ_sents_list = [[
                        sent.strip() for sent in
                        groundtruth_summary_text.strip().split('\n')
                    ]]
                article_sent_tokens = [
                    util.process_sent(sent) for sent in raw_article_sents
                ]
                groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                    groundtruth_similar_source_indices_list,
                    FLAGS.sentence_limit)

                parser = PlaintextParser.from_string(
                    ' '.join(raw_article_sents), tokenizer)
                summary = summarizer(parser.document,
                                     5)  # Summarize the document with 5 sentences
                summary_tokenized = [str(sentence).lower() for sentence in summary]

                rouge_functions.write_for_rouge(groundtruth_summ_sents_list,
                                                summary_tokenized,
                                                example_idx,
                                                ref_dir,
                                                dec_dir,
                                                log=False)

                # Map each summary sentence back to the article sentences it
                # was extracted from (at most 2 per summary sentence).
                decoded_sent_tokens = [sent.split() for sent in summary_tokenized]
                sys_ssi_list, _, _ = get_simple_source_indices_list(
                    decoded_sent_tokens, article_sent_tokens, vocab, 2,
                    min_matched_tokens)
                triplet_ssi_list.append(
                    (groundtruth_similar_source_indices_list, sys_ssi_list, -1))

            print('Evaluating Lambdamart model F1 score...')
            suffix = util.all_sent_selection_eval(triplet_ssi_list)
            print(suffix)

            results_dict = rouge_functions.rouge_eval(ref_dir, dec_dir)
            print("Results_dict: ", results_dict)
            sheets_str = rouge_functions.rouge_log(results_dict,
                                                   log_dir,
                                                   suffix=suffix)
            sheets_strs.append(dataset_name + '_' + summary_method + '\n' +
                               sheets_str)

    for sheets_str in sheets_strs:
        print(sheets_str + '\n')
def main(unused_argv):
    """Write LambdaMART input features, or generate and evaluate summaries.

    ``FLAGS.mode == 'write_to_file'``: featurizes every example with
    ``write_to_lambdamart_examples_to_file`` and concatenates the files found
    in ``temp_in_dir`` into ``temp_in_path``.

    ``FLAGS.mode == 'generate_summaries'``: ranks source sentences from
    ``temp_in_path``/``temp_out_path``, runs ``evaluate_example`` on every
    example, and pickles the result to ``<my_log_dir>/ssi.pkl``. It then
    reports sentence-selection F1 and ROUGE.

    Any other mode only runs the shared setup.

    Args:
        unused_argv: Positional arguments left after flag parsing; must
            contain only the program name.

    Raises:
        Exception: If extra positional arguments were supplied.
    """
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)
    print('Running statistics on %s' % exp_name)

    start_time = time.time()
    np.random.seed(random_seed)
    source_dir = os.path.join(data_dir, dataset_articles)
    file_pattern = source_dir + '/' + dataset_split + '*'
    source_files = sorted(glob.glob(file_pattern))

    # Featurize a tiny two-sentence dummy article just to learn how long the
    # singleton and pair feature vectors are (0 when that kind is disabled).
    ex_sents = ['single .', 'sentence .']
    article_text = ' '.join(ex_sents)
    sent_term_matrix = util.get_doc_substituted_tfidf_matrix(
        tfidf_vectorizer, ex_sents, article_text, pca)
    if FLAGS.singles_and_pairs == 'pairs':
        single_feat_len = 0
    else:
        single_feat_len = len(
            get_single_sent_features(0, sent_term_matrix,
                                     [['single', '.'], ['sentence', '.']],
                                     [0, 0], 0))
    if FLAGS.singles_and_pairs == 'singles':
        pair_feat_len = 0
    else:
        pair_feat_len = len(
            get_pair_sent_features([0, 1], sent_term_matrix,
                                   [['single', '.'], ['sentence', '.']],
                                   [0, 0], [0, 0]))

    # Estimated example count. The old test `'cnn' or 'newsroom' in ...` was
    # always true because the non-empty literal 'cnn' is truthy.
    if 'cnn' in dataset_articles or 'newsroom' in dataset_articles:
        total = len(source_files) * 1000
    else:
        total = len(source_files)
    example_generator = data.example_generator(file_pattern,
                                               True,
                                               False,
                                               should_check_valid=False)

    if FLAGS.mode == 'write_to_file':
        ex_gen = example_generator_extended(example_generator, total,
                                            single_feat_len, pair_feat_len,
                                            FLAGS.singles_and_pairs)
        print('Creating list')
        ex_list = list(ex_gen)
        print('Converting...')
        # NOTE(review): presumably each call writes one file into temp_in_dir;
        # those files are concatenated below — confirm.
        list(futures.map(write_to_lambdamart_examples_to_file, ex_list))

        file_names = sorted(glob.glob(os.path.join(temp_in_dir, '*')))
        contents = []
        for file_name in tqdm(file_names):
            with open(file_name) as f:
                contents.append(f.read())
        # The data is str, so open in text mode ('wb' rejects str on Python 3).
        with open(temp_in_path, 'w') as f:
            f.write(''.join(contents))

    # RUN LAMBDAMART SCORING COMMAND HERE

    if FLAGS.mode == 'generate_summaries':
        qid_ssi_to_importances = rank_source_sents(temp_in_path, temp_out_path)
        ex_gen = example_generator_extended(example_generator, total,
                                            qid_ssi_to_importances,
                                            pair_feat_len,
                                            FLAGS.singles_and_pairs)
        print('Creating list')
        ex_list = list(ex_gen)
        ssi_list = list(futures.map(evaluate_example, ex_list))

        # Save ssi_list. pickle needs a binary file handle on Python 3.
        with open(os.path.join(my_log_dir, 'ssi.pkl'), 'wb') as f:
            pickle.dump(ssi_list, f)
        print('Evaluating Lambdamart model F1 score...')
        suffix = util.all_sent_selection_eval(ssi_list)
        print('Evaluating ROUGE...')
        results_dict = rouge_functions.rouge_eval(ref_dir,
                                                  dec_dir,
                                                  l_param=l_param)
        rouge_functions.rouge_log(results_dict, my_log_dir, suffix=suffix)

    util.print_execution_time(start_time)
コード例 #5
0
def main(unused_argv):
    """Write inputs for, or evaluate outputs of, the external "kaiqiang" model.

    ``FLAGS.mode == 'write'``: for each requested split, write one line per
    non-empty group of selected source-sentence indices:

    * the reordered article sentences, to ``<split>.Ndocument``;
    * the matching groundtruth summary sentence, to ``<split>.Nsummary``
      (not written for 'test');
    * the example index, to ``<split>.Nexampleidx``.

    For 'test' the groups come from the BERT ``ssi.pkl``. Otherwise they come
    from the groundtruth, limited to ``FLAGS.sentence_limit``.

    ``FLAGS.mode == 'evaluate'``: read the model output ``testSummary.txt``
    line-by-line alongside ``test.Nexampleidx``. Group consecutive lines by
    example index and pair each group with that example's groundtruth
    summary. Then write ROUGE files and run ROUGE.

    Args:
        unused_argv: Positional arguments left after flag parsing; must
            contain only the program name.

    Raises:
        Exception: On extra positional arguments, an ssi list shorter than
            the number of test examples, mismatched line counts, decreasing
            example indices, or an unknown mode.
    """
    print('Running statistics on %s' % FLAGS.dataset_name)

    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    kaiqiang_root = os.path.expanduser('~') + '/data/kaiqiang_data'
    out_dir = os.path.join(kaiqiang_root, FLAGS.dataset_name)

    def gt_summ_sents_from_text(groundtruth_summary_text):
        """Return the groundtruth summaries as lists of stripped sentences."""
        # For duc_2004 the groundtruth is a list of summary texts; otherwise
        # it is a single newline-separated summary text.
        if FLAGS.dataset_name == 'duc_2004':
            texts = groundtruth_summary_text
        else:
            texts = [groundtruth_summary_text]
        return [[sent.strip() for sent in text.strip().split('\n')]
                for text in texts]

    if FLAGS.mode == 'write':
        util.create_dirs(out_dir)
        if FLAGS.dataset_name == 'duc_2004':
            dataset_splits = ['test']
        elif FLAGS.dataset_split == 'all':
            dataset_splits = ['test', 'val', 'train']
        else:
            dataset_splits = [FLAGS.dataset_split]

        for dataset_split in dataset_splits:
            is_test = dataset_split == 'test'
            if is_test:
                ssi_data_path = os.path.join(
                    'logs/%s_bert_both_sentemb_artemb_plushidden' %
                    FLAGS.dataset_name, 'ssi.pkl')
                print(util.bcolors.OKGREEN +
                      "Loading SSI from BERT at %s" % ssi_data_path +
                      util.bcolors.ENDC)
                # pickle needs a binary file handle on Python 3.
                with open(ssi_data_path, 'rb') as f:
                    ssi_triple_list = pickle.load(f)

            source_dir = os.path.join(data_dir, FLAGS.dataset_name)
            file_pattern = source_dir + '/' + dataset_split + '*'
            source_files = sorted(glob.glob(file_pattern))

            many_per_file = ('cnn' in FLAGS.dataset_name
                             or 'newsroom' in FLAGS.dataset_name
                             or 'xsum' in FLAGS.dataset_name)
            total = len(source_files) * 1000 if many_per_file else len(source_files)
            example_generator = data.example_generator(file_pattern,
                                                       True,
                                                       False,
                                                       should_check_valid=False)

            out_document_path = os.path.join(out_dir, dataset_split + '.Ndocument')
            out_summary_path = os.path.join(out_dir, dataset_split + '.Nsummary')
            out_example_idx_path = os.path.join(out_dir,
                                                dataset_split + '.Nexampleidx')

            # Writers were previously never closed, so output could remain
            # unflushed; close them whatever happens.
            doc_writer = sum_writer = ex_idx_writer = None
            try:
                doc_writer = open(out_document_path, 'w')
                if not is_test:
                    sum_writer = open(out_summary_path, 'w')
                ex_idx_writer = open(out_example_idx_path, 'w')

                for example_idx, example in enumerate(
                        tqdm(example_generator, total=total)):
                    if FLAGS.num_instances != -1 and example_idx >= FLAGS.num_instances:
                        break
                    (raw_article_sents, groundtruth_similar_source_indices_list,
                     groundtruth_summary_text,
                     _doc_indices) = util.unpack_tf_example(example, names_to_types)
                    groundtruth_summ_sents = gt_summ_sents_from_text(
                        groundtruth_summary_text)

                    if is_test:
                        if example_idx >= len(ssi_triple_list):
                            raise Exception(
                                'Len of ssi list (%d) is less than number of examples (>=%d)'
                                % (len(ssi_triple_list), example_idx + 1))
                        # Triple: (groundtruth ssi, system ssi, extractive length);
                        # keep only the groups the extractor actually used.
                        ssi_triple = ssi_triple_list[example_idx]
                        selected_ssis = ssi_triple[1][:ssi_triple[2]]
                    else:
                        selected_ssis = util.enforce_sentence_limit(
                            groundtruth_similar_source_indices_list,
                            FLAGS.sentence_limit)

                    for ssi_idx, ssi in enumerate(selected_ssis):
                        if len(ssi) == 0:
                            continue
                        my_article = ' '.join(util.reorder(raw_article_sents, ssi))
                        doc_writer.write(my_article + '\n')
                        if sum_writer is not None:
                            sum_writer.write(groundtruth_summ_sents[0][ssi_idx] +
                                             '\n')
                        ex_idx_writer.write(str(example_idx) + '\n')
            finally:
                for writer in (doc_writer, sum_writer, ex_idx_writer):
                    if writer is not None:
                        writer.close()
    elif FLAGS.mode == 'evaluate':
        # Home-relative, like out_dir (was hard-coded to /home/logan).
        summary_dir = os.path.join(kaiqiang_root, 'logan_ACL',
                                   'trained_on_' + FLAGS.train_dataset,
                                   FLAGS.dataset_name)
        out_summary_path = os.path.join(summary_dir, 'test' + 'Summary.txt')
        out_example_idx_path = os.path.join(out_dir, 'test' + '.Nexampleidx')
        decode_dir = 'logs/kaiqiang_%s_trainedon%s' % (FLAGS.dataset_name,
                                                       FLAGS.train_dataset)
        rouge_ref_dir = os.path.join(decode_dir, 'reference')
        rouge_dec_dir = os.path.join(decode_dir, 'decoded')
        util.create_dirs(rouge_ref_dir)
        util.create_dirs(rouge_dec_dir)

        def num_lines_in_file(file_path):
            """Count the lines in a text file."""
            with open(file_path) as f:
                return sum(1 for _ in f)

        def process_example(sents, ex_idx, groundtruth_summ_sents):
            """Write one example's decoded lines and references for ROUGE."""
            final_decoded_words = [
                word for sent in sents for word in sent.split(' ')
            ]
            rouge_functions.write_for_rouge(groundtruth_summ_sents,
                                            None,
                                            ex_idx,
                                            rouge_ref_dir,
                                            rouge_dec_dir,
                                            decoded_words=final_decoded_words,
                                            log=False)

        num_lines_summary = num_lines_in_file(out_summary_path)
        num_lines_example_indices = num_lines_in_file(out_example_idx_path)
        if num_lines_summary != num_lines_example_indices:
            raise Exception(
                'Num lines summary != num lines example indices: (%d, %d)' %
                (num_lines_summary, num_lines_example_indices))

        source_dir = os.path.join(data_dir, FLAGS.dataset_name)
        example_generator = data.example_generator(source_dir + '/' + 'test' +
                                                   '*',
                                                   True,
                                                   False,
                                                   should_check_valid=False)

        # Group consecutive summary lines that share an example index.
        groups = []  # [(example index, [summary lines]), ...] in file order
        with open(out_summary_path) as summary_file, \
                open(out_example_idx_path) as ex_idx_file:
            for line, ex_idx_line in zip(summary_file, ex_idx_file):
                ex_idx = int(ex_idx_line)
                line = line.rstrip('\n')
                if groups and groups[-1][0] == ex_idx:
                    groups[-1][1].append(line)
                else:
                    groups.append((ex_idx, [line]))

        # Advance the generator to each group's own example. The old loop
        # paired every group with the *next* index (so the final call
        # overwrote one example). It also went out of step whenever an
        # example produced no lines.
        next_example_idx = 0  # index of the example next() will return
        for ex_idx, sents in tqdm(groups):
            if ex_idx < next_example_idx:
                raise Exception(
                    'Example indices must not decrease: got %d after %d' %
                    (ex_idx, next_example_idx - 1))
            while next_example_idx <= ex_idx:
                example = next(example_generator)
                next_example_idx += 1
            (_raw_article_sents, _groundtruth_ssi_list, groundtruth_summary_text,
             _doc_indices) = util.unpack_tf_example(example, names_to_types)
            process_example(sents, ex_idx,
                            gt_summ_sents_from_text(groundtruth_summary_text))

        print("Now starting ROUGE eval...")
        l_param = 100  # the former xsum / non-xsum branches were both 100
        results_dict = rouge_functions.rouge_eval(rouge_ref_dir,
                                                  rouge_dec_dir,
                                                  l_param=l_param)
        rouge_functions.rouge_log(results_dict, decode_dir)

    else:
        raise Exception('mode flag was not evaluate or write.')
コード例 #6
0
    def decode(self):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals.

        Single-pass mode: every decoded example is written with
        ``rouge_functions.write_for_rouge``. The first 1000 are also written
        via ``write_for_human``, and each is written for the attention
        visualizer when ``FLAGS.attn_vis`` is set. Once the batcher returns
        None, ROUGE runs over the written files (if any) and the method
        returns.

        Otherwise: each result is printed and written for the attention
        visualizer. The checkpoint is reloaded once more than
        ``SECS_UNTIL_NEW_CKPT`` seconds have passed. This loop does not
        return on its own.
        """
        t0 = time.time()
        counter = 0  # examples decoded so far (advanced only in single-pass mode)
        # Progress-bar estimate only: assumes 1000 examples per matching file.
        total = len(glob.glob(self._batcher._data_path)) * 1000
        # The context manager closes the bar on return or error. The old
        # `pbar.close()` after `while True` was unreachable.
        with tqdm(total=total) as pbar:
            while True:
                batch = self._batcher.next_batch(
                )  # 1 example repeated across batch
                if batch is None:  # finished decoding dataset in single_pass mode
                    assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                    logging.info(
                        "Decoder has finished reading dataset for single_pass.")
                    logging.info("Output has been saved in %s and %s.",
                                 self._rouge_ref_dir, self._rouge_dec_dir)
                    if len(os.listdir(self._rouge_ref_dir)) != 0:
                        logging.info("Now starting ROUGE eval...")
                        results_dict = rouge_functions.rouge_eval(
                            self._rouge_ref_dir, self._rouge_dec_dir)
                        rouge_functions.rouge_log(results_dict, self._decode_dir)
                    return

                original_article = batch.original_articles[0]  # string
                original_abstract = batch.original_abstracts[0]  # string
                all_original_abstract_sents = batch.all_original_abstracts_sents[0]
                raw_article_sents = batch.raw_article_sents[0]

                article_withunks = data.show_art_oovs(original_article,
                                                      self._vocab)  # string
                abstract_withunks = data.show_abs_oovs(
                    original_abstract, self._vocab,
                    (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

                decoded_words, decoded_output, best_hyp = decode_example(
                    self._sess, self._model, self._vocab, batch, counter,
                    self._batcher._hps)

                if FLAGS.single_pass:
                    if counter < 1000:
                        self.write_for_human(raw_article_sents,
                                             all_original_abstract_sents,
                                             decoded_words, counter)
                    rouge_functions.write_for_rouge(
                        all_original_abstract_sents,
                        None,
                        counter,
                        self._rouge_ref_dir,
                        self._rouge_dec_dir,
                        decoded_words=decoded_words
                    )  # write ref summary and decoded summary to file, to eval with pyrouge later
                    if FLAGS.attn_vis:
                        self.write_for_attnvis(
                            article_withunks, abstract_withunks, decoded_words,
                            best_hyp.attn_dists, best_hyp.p_gens, counter
                        )  # write info to .json file for visualization tool

                    counter += 1  # this is how many examples we've decoded
                else:
                    print_results(article_withunks, abstract_withunks,
                                  decoded_output)  # log output to screen
                    self.write_for_attnvis(
                        article_withunks, abstract_withunks, decoded_words,
                        best_hyp.attn_dists, best_hyp.p_gens,
                        counter)  # write info to .json file for visualization tool

                    # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so reload so
                    # decoding continues with the latest checkpoint.
                    t1 = time.time()
                    if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                        logging.info(
                            'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                            t1 - t0)
                        _ = util.load_ckpt(self._saver, self._sess)
                        t0 = time.time()
                pbar.update(1)
コード例 #7
0
    def decode_iteratively(self, example_generator, total, names_to_types,
                           ssi_list, hps):
        """Decode each example one selected source-sentence group at a time.

        For every example, each group of selected source indices
        (``sys_ssi``) is decoded separately. The decoded words are
        concatenated, stopping once at least 100 words have been produced.
        The result is written for ROUGE. The first 1000 examples are also
        written in human-readable form and as highlighted HTML.

        Args:
            example_generator: Iterable of serialized examples, unpacked with
                ``util.unpack_tf_example``.
            total: Expected number of examples; used only for the progress bar.
            names_to_types: Field spec passed to ``util.unpack_tf_example``.
            ssi_list: Per-example ``(groundtruth_ssi, system_ssi, ext_len)``
                triples. None means upper-bound evaluation, where the
                groundtruth source indices are decoded instead.
            hps: Hyperparameters forwarded to ``create_batch`` and
                ``decode_example``.
        """
        # Max source sentences per group for each selection mode. Other modes
        # leave the groups unchanged.
        sentence_limit = {'singles': 1, 'both': 2}.get(FLAGS.singles_and_pairs)
        min_matched_tokens = 2  # passed to the highlight matcher below

        for example_idx, example in enumerate(
                tqdm(example_generator, total=total)):
            (raw_article_sents, groundtruth_similar_source_indices_list,
             groundtruth_summary_text) = util.unpack_tf_example(
                 example, names_to_types)
            article_sent_tokens = [
                util.process_sent(sent) for sent in raw_article_sents
            ]
            groundtruth_summ_sents = [[
                sent.strip()
                for sent in groundtruth_summary_text.strip().split('\n')
            ]]

            if ssi_list is None:  # upper bound: ssi comes straight from the groundtruth
                sys_ssi = groundtruth_similar_source_indices_list
                if sentence_limit is not None:
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, sentence_limit)
                sys_ssi = util.replace_empty_ssis(sys_ssi, raw_article_sents)
            else:
                gt_ssi, sys_ssi, _ext_len = ssi_list[example_idx]
                if sentence_limit is not None:
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, sentence_limit)
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                        groundtruth_similar_source_indices_list, sentence_limit)
                    gt_ssi = util.enforce_sentence_limit(gt_ssi, sentence_limit)
                if gt_ssi != groundtruth_similar_source_indices_list:
                    # The '%d' was previously never filled in.
                    print(
                        'Warning: Example %d has different groundtruth source indices: %s || %s'
                        % (example_idx, groundtruth_similar_source_indices_list,
                           gt_ssi))
                if FLAGS.dataset_name == 'xsum':
                    # Keep only the first group; slicing avoids IndexError
                    # when the list is empty.
                    sys_ssi = list(sys_ssi[:1])

            final_decoded_words = []
            highlight_html_total = ''
            for ssi_idx, ssi in enumerate(sys_ssi):
                selected_raw_article_sents = util.reorder(
                    raw_article_sents, ssi)
                selected_article_text = ' '.join(
                    ' '.join(sent)
                    for sent in util.reorder(article_sent_tokens, ssi))
                selected_doc_indices_str = '0 ' * len(
                    selected_article_text.split())
                if FLAGS.upper_bound:
                    # Pair each group with its own groundtruth summary sentence.
                    selected_groundtruth_summ_sent = [[
                        groundtruth_summ_sents[0][ssi_idx]
                    ]]
                else:
                    selected_groundtruth_summ_sent = groundtruth_summ_sents

                batch = create_batch(selected_article_text,
                                     selected_groundtruth_summ_sent,
                                     selected_doc_indices_str,
                                     selected_raw_article_sents,
                                     FLAGS.batch_size, hps, self._vocab)

                decoded_words, _decoded_output, _best_hyp = decode_example(
                    self._sess, self._model, self._vocab, batch, example_idx,
                    hps)
                final_decoded_words.extend(decoded_words)

                if example_idx < 1000:
                    selected_article_sent_tokens = [
                        util.process_sent(sent)
                        for sent in selected_raw_article_sents
                    ]
                    highlight_summary_sent_tokens = [decoded_words]
                    highlight_ssi_list, lcs_paths_list, highlight_smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
                        highlight_summary_sent_tokens,
                        selected_article_sent_tokens, None, 2,
                        min_matched_tokens)
                    highlighted_html = ssi_functions.html_highlight_sents_in_article(
                        highlight_summary_sent_tokens,
                        highlight_ssi_list,
                        selected_article_sent_tokens,
                        lcs_paths_list=lcs_paths_list,
                        article_lcs_paths_list=
                        highlight_smooth_article_lcs_paths_list)
                    highlight_html_total += '<u>System Summary</u><br><br>' + highlighted_html + '<br><br>'

                # Stop adding groups once the summary has reached 100 words.
                if len(final_decoded_words) >= 100:
                    break

            if example_idx < 1000:
                self.write_for_human(raw_article_sents, groundtruth_summ_sents,
                                     final_decoded_words, example_idx)
                ssi_functions.write_highlighted_html(highlight_html_total,
                                                     self._highlight_dir,
                                                     example_idx)

            rouge_functions.write_for_rouge(
                groundtruth_summ_sents,
                None,
                example_idx,
                self._rouge_ref_dir,
                self._rouge_dec_dir,
                decoded_words=final_decoded_words,
                log=False
            )  # write ref summary and decoded summary to file, to eval with pyrouge later

        logging.info("Decoder has finished reading dataset for single_pass.")
        logging.info("Output has been saved in %s and %s.",
                     self._rouge_ref_dir, self._rouge_dec_dir)
        if len(os.listdir(self._rouge_ref_dir)) != 0:
            l_param = 100
            logging.info("Now starting ROUGE eval...")
            results_dict = rouge_functions.rouge_eval(self._rouge_ref_dir,
                                                      self._rouge_dec_dir,
                                                      l_param=l_param)
            rouge_functions.rouge_log(results_dict, self._decode_dir)
コード例 #8
0
    def decode_iteratively(self, example_generator, total, names_to_types,
                           ssi_list, hps):
        """Decode every example group-by-group and write evaluation output.

        For each example, groups of source sentences are selected:
          - from ``ssi_list[example_idx]`` when ``ssi_list`` is given;
          - otherwise from the groundtruth source indices (upper-bound mode).

        Each group is decoded with ``decode_example``. When
        ``FLAGS.first_intact`` is set, the first group is instead copied
        verbatim. Decoded words are concatenated until at least 100 words
        have been produced.

        Per example, this writes:
          - the reference and decoded summary via
            ``rouge_functions.write_for_rouge``;
          - for examples 0-99 and 2000-2099, human-readable output and
            side-by-side highlighted HTML (system vs. reference);
          - for the first 200 examples, attention-visualisation data, when
            ``FLAGS.attn_vis`` is set.

        ROUGE evaluation runs at the end if the reference directory is
        non-empty.

        Args:
            example_generator: iterable of examples, unpacked with
                ``util.unpack_tf_example(example, names_to_types)``.
            total: expected number of examples (used only by the progress bar).
            names_to_types: field specification passed to
                ``util.unpack_tf_example``.
            ssi_list: sequence indexed by example position whose entries unpack
                as ``(gt_ssi, sys_ssi, ext_len, sys_token_probs_list)``, or
                None to decode from the groundtruth source indices.
            hps: hyperparameters forwarded to ``decode_example``.
        """
        attn_vis_idx = 0  # running index of attention-visualisation files
        for example_idx, example in enumerate(
                tqdm(example_generator, total=total)):
            raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs, groundtruth_article_lcs_paths_list = util.unpack_tf_example(
                example, names_to_types)
            article_sent_tokens = [
                util.process_sent(sent) for sent in raw_article_sents
            ]
            # A single reference summary: one stripped sentence per line.
            groundtruth_summ_sents = [[
                sent.strip()
                for sent in groundtruth_summary_text.strip().split('\n')
            ]]
            groundtruth_summ_sent_tokens = [
                sent.split(' ') for sent in groundtruth_summ_sents[0]
            ]

            if ssi_list is None:  # this is if we are doing the upper bound evaluation (ssi_list comes straight from the groundtruth)
                sys_ssi = groundtruth_similar_source_indices_list
                sys_alp_list = groundtruth_article_lcs_paths_list
                if FLAGS.singles_and_pairs == 'singles':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                    sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 1)
                elif FLAGS.singles_and_pairs == 'both':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                    sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 2)
                sys_ssi, sys_alp_list = util.replace_empty_ssis(
                    sys_ssi, raw_article_sents, sys_alp_list=sys_alp_list)
            else:
                gt_ssi, sys_ssi, ext_len, sys_token_probs_list = ssi_list[
                    example_idx]
                sys_alp_list = ssi_functions.list_labels_from_probs(
                    sys_token_probs_list, FLAGS.tag_threshold)
                if FLAGS.singles_and_pairs == 'singles':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 1)
                    sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 1)
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                        groundtruth_similar_source_indices_list, 1)
                    gt_ssi = util.enforce_sentence_limit(gt_ssi, 1)
                elif FLAGS.singles_and_pairs == 'both':
                    sys_ssi = util.enforce_sentence_limit(sys_ssi, 2)
                    sys_alp_list = util.enforce_sentence_limit(sys_alp_list, 2)
                    groundtruth_similar_source_indices_list = util.enforce_sentence_limit(
                        groundtruth_similar_source_indices_list, 2)
                    gt_ssi = util.enforce_sentence_limit(gt_ssi, 2)
                if FLAGS.dataset_name == 'xsum':
                    # Only the first sentence group is decoded for xsum.
                    sys_ssi = [sys_ssi[0]]

            # Highlighted HTML / human-readable output is only written for
            # this window of examples.
            write_highlights = example_idx < 100 or 2000 <= example_idx < 2100

            final_decoded_words = []
            highlight_html_total = '<u>System Summary</u><br><br>'
            for ssi_idx, ssi in enumerate(sys_ssi):
                selected_article_lcs_paths = sys_alp_list[ssi_idx]
                ssi, selected_article_lcs_paths = util.make_ssi_chronological(
                    ssi, selected_article_lcs_paths)
                selected_article_lcs_paths = [selected_article_lcs_paths]
                selected_raw_article_sents = util.reorder(
                    raw_article_sents, ssi)
                selected_article_text = ' '.join([
                    ' '.join(sent)
                    for sent in util.reorder(article_sent_tokens, ssi)
                ])
                selected_doc_indices_str = '0 ' * len(
                    selected_article_text.split())
                if FLAGS.upper_bound:
                    selected_groundtruth_summ_sent = [[
                        groundtruth_summ_sents[0][ssi_idx]
                    ]]
                else:
                    selected_groundtruth_summ_sent = groundtruth_summ_sents

                batch = create_batch(selected_article_text,
                                     selected_groundtruth_summ_sent,
                                     selected_doc_indices_str,
                                     selected_raw_article_sents,
                                     selected_article_lcs_paths,
                                     FLAGS.batch_size, hps, self._vocab)

                # best_hyp stays None when the group is copied verbatim: no
                # beam search ran, so there are no attention distributions for
                # it. Previously the attn-vis code below then read an unbound
                # (first example) or stale (previous batch) best_hyp.
                best_hyp = None
                if FLAGS.first_intact and ssi_idx == 0:
                    decoded_words = selected_article_text.strip().split()
                else:
                    decoded_words, _, best_hyp = decode_example(
                        self._sess, self._model, self._vocab, batch,
                        example_idx, hps)
                final_decoded_words.extend(decoded_words)

                if write_highlights:
                    highlight_html_total += self._highlight_sents_html(
                        [decoded_words], selected_raw_article_sents) + '<br>'

                if (FLAGS.attn_vis and example_idx < 200
                        and best_hyp is not None):
                    article_withunks = data.show_art_oovs(
                        batch.original_articles[0], self._vocab)  # string
                    abstract_withunks = data.show_abs_oovs(
                        batch.original_abstracts[0], self._vocab,
                        (batch.art_oovs[0]
                         if FLAGS.pointer_gen else None))  # string
                    # write info to .json file for visualization tool
                    self.write_for_attnvis(article_withunks, abstract_withunks,
                                           decoded_words, best_hyp.attn_dists,
                                           best_hyp.p_gens, attn_vis_idx)
                    attn_vis_idx += 1

                if len(final_decoded_words) >= 100:
                    break

            if write_highlights:
                # Reference-side highlighting; its only consumer is the
                # two-column HTML written below.
                gt_ssi_list, gt_alp_list = util.replace_empty_ssis(
                    groundtruth_similar_source_indices_list,
                    raw_article_sents,
                    sys_alp_list=groundtruth_article_lcs_paths_list)
                highlight_html_gt = '<u>Reference Summary</u><br><br>'
                for ssi_idx, ssi in enumerate(gt_ssi_list):
                    selected_article_lcs_paths = gt_alp_list[ssi_idx]
                    try:
                        ssi, selected_article_lcs_paths = util.make_ssi_chronological(
                            ssi, selected_article_lcs_paths)
                    except Exception:
                        # Dump the offending inputs before propagating.
                        util.print_vars(ssi, example_idx,
                                        selected_article_lcs_paths)
                        raise
                    selected_raw_article_sents = util.reorder(
                        raw_article_sents, ssi)
                    highlight_html_gt += self._highlight_sents_html(
                        [groundtruth_summ_sent_tokens[ssi_idx]],
                        selected_raw_article_sents) + '<br>'

                self.write_for_human(raw_article_sents, groundtruth_summ_sents,
                                     final_decoded_words, example_idx)
                highlight_html_total = ssi_functions.put_html_in_two_columns(
                    highlight_html_total, highlight_html_gt)
                ssi_functions.write_highlighted_html(highlight_html_total,
                                                     self._highlight_dir,
                                                     example_idx)

            # write ref summary and decoded summary to file, to eval with pyrouge later
            rouge_functions.write_for_rouge(
                groundtruth_summ_sents,
                None,
                example_idx,
                self._rouge_ref_dir,
                self._rouge_dec_dir,
                decoded_words=final_decoded_words,
                log=False)

        logging.info("Decoder has finished reading dataset for single_pass.")
        logging.info("Output has been saved in %s and %s.",
                     self._rouge_ref_dir, self._rouge_dec_dir)
        if os.listdir(self._rouge_ref_dir):
            logging.info("Now starting ROUGE eval...")
            results_dict = rouge_functions.rouge_eval(self._rouge_ref_dir,
                                                      self._rouge_dec_dir,
                                                      l_param=100)
            rouge_functions.rouge_log(results_dict, self._decode_dir)

    @staticmethod
    def _highlight_sents_html(summary_sent_tokens, raw_article_sents):
        """Return HTML highlighting summary tokens matched in the article sentences.

        Args:
            summary_sent_tokens: list of tokenized summary sentences.
            raw_article_sents: raw article sentences to match against; each is
                tokenized with ``util.process_sent``.
        """
        min_matched_tokens = 2
        article_sent_tokens = [
            util.process_sent(sent) for sent in raw_article_sents
        ]
        highlight_ssi_list, lcs_paths_list, _, smooth_article_lcs_paths_list = ssi_functions.get_simple_source_indices_list(
            summary_sent_tokens, article_sent_tokens, None, 2,
            min_matched_tokens)
        return ssi_functions.html_highlight_sents_in_article(
            summary_sent_tokens,
            highlight_ssi_list,
            article_sent_tokens,
            lcs_paths_list=lcs_paths_list,
            article_lcs_paths_list=smooth_article_lcs_paths_list)
コード例 #9
0
def main(unused_argv):
    """Compute ROUGE for extractive baselines or for a trained experiment.

    When ``FLAGS.exp_name == 'extractive'``, each method in
    ``summary_methods`` is processed as follows:
      - Its summaries (files in ``summaries_dir/<method>`` whose names start
        with 'd', taken in sorted order) are word-tokenized, lower-cased and
        written to ``out_dir/<method>/decoded``.
      - The matching reference files from ``ref_dir`` are passed to
        ``rouge_functions.write_for_rouge``.
      - The decoded directory is scored against ``ref_dir``.

    Afterwards, each method's ``sheets_results.txt`` is concatenated and
    printed.

    Otherwise, the ``reference``/``decoded`` directories of experiment
    ``FLAGS.exp_name`` (under the largest checkpoint folder) are scored.

    Args:
        unused_argv: leftover command-line arguments; must contain only the
            program name.

    Raises:
        Exception: if any unparsed command-line arguments remain.
    """
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.exp_name == 'extractive':
        for summary_method in summary_methods:
            method_out_dir = os.path.join(out_dir, summary_method)
            method_dec_dir = os.path.join(method_out_dir, 'decoded')
            method_ref_dir = os.path.join(method_out_dir, 'reference')
            os.makedirs(method_dec_dir, exist_ok=True)
            os.makedirs(method_ref_dir, exist_ok=True)
            print(method_out_dir)

            method_dir = os.path.join(summaries_dir, summary_method)
            file_names = sorted(
                name for name in os.listdir(method_dir) if name.startswith('d'))
            for art_idx, article_name in enumerate(tqdm(file_names)):
                # Text mode: nltk tokenizes str, and the decoded summary is
                # written as str (binary mode raised TypeError on Python 3).
                with open(os.path.join(method_dir, article_name),
                          encoding='utf-8') as f:
                    lines = f.readlines()
                sentences = [
                    ' '.join(token.lower()
                             for token in nltk.tokenize.word_tokenize(line))
                    for line in lines
                ]
                out_name = '%06d_decoded.txt' % art_idx
                with open(os.path.join(method_dec_dir, out_name), 'w',
                          encoding='utf-8') as f:
                    f.write('\n'.join(sentences))

                # Sorted so multiple references are passed in a stable order.
                reference_files = sorted(
                    glob.glob(os.path.join(ref_dir, '%06d' % art_idx + '*')))
                abstract_sentences = []
                for ref_file in reference_files:
                    with open(ref_file, encoding='utf-8') as f:
                        abstract_sentences.append(f.readlines())
                # NOTE(review): write_for_rouge also writes into method_dec_dir;
                # confirm whether it overwrites the file written just above.
                rouge_functions.write_for_rouge(abstract_sentences, sentences,
                                                art_idx, method_ref_dir,
                                                method_dec_dir)

            # NOTE(review): scoring uses the global ref_dir, not the
            # per-method reference copies written above — confirm intent.
            results_dict = rouge_functions.rouge_eval(ref_dir, method_dec_dir)
            rouge_functions.rouge_log(results_dict, method_out_dir)

        for summary_method in summary_methods:
            print(summary_method)
        all_results = ''
        for summary_method in summary_methods:
            sheet_results_file = os.path.join(out_dir, summary_method,
                                              "sheets_results.txt")
            with open(sheet_results_file) as f:
                all_results += f.read()
        print(all_results)

    else:
        log_root = os.path.join('logs', FLAGS.exp_name)
        ckpt_folder = util.find_largest_ckpt_folder(log_root)
        print(ckpt_folder)
        # When the largest "checkpoint folder" is 'decoded' itself, the
        # reference/decoded dirs sit directly under the experiment directory.
        if ckpt_folder != 'decoded':
            summary_dir = os.path.join(log_dir, FLAGS.exp_name, ckpt_folder)
        else:
            summary_dir = os.path.join(log_dir, FLAGS.exp_name)
        ref_dir = os.path.join(summary_dir, 'reference')
        dec_dir = os.path.join(summary_dir, 'decoded')
        summary_files = glob.glob(os.path.join(dec_dir, '*_decoded.txt'))

        lengths = []
        for summary_file in tqdm(summary_files):
            with open(summary_file) as f:
                lengths.append(len(f.read().strip().split()))
        if lengths:
            print('Average summary length: %.2f' % np.mean(lengths))
        else:
            # np.mean([]) would print nan with a RuntimeWarning.
            print('No decoded summaries found in %s' % dec_dir)

        print('Evaluating on %d files' % len(os.listdir(dec_dir)))
        results_dict = rouge_functions.rouge_eval(ref_dir,
                                                  dec_dir,
                                                  l_param=FLAGS.l_param)
        rouge_functions.rouge_log(results_dict, summary_dir)
コード例 #10
0
def main(unused_argv):
    """Select source sentences for every example and evaluate the selections.

    Steps:
      1. Build an extended example generator from the dataset split, the
         sentence importances returned by ``rank_source_sents`` and the
         token scores from ``get_token_scores_for_ssi``.
      2. Run ``evaluate_example`` on every example via ``futures.map``.
      3. Pickle the results to ``my_log_dir/ssi.pkl``.
      4. Report sentence-selection F1 via ``util.all_sent_selection_eval``.
      5. Report ROUGE on ``ref_dir``/``dec_dir``.
      6. Print the percentage of selected groups that are single sentences
         versus sentence pairs.

    Args:
        unused_argv: leftover command-line arguments; must contain only the
            program name.

    Raises:
        Exception: if any unparsed command-line arguments remain.
    """
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)
    print('Running statistics on %s' % exp_name)

    start_time = time.time()
    np.random.seed(random_seed)
    source_dir = os.path.join(data_dir, dataset_articles)
    source_pattern = source_dir + '/' + dataset_split + '*'
    source_files = sorted(glob.glob(source_pattern))

    # Progress-bar upper bound; presumably each source file holds at most
    # ~1000 examples — TODO confirm.
    total = len(source_files) * 1000
    example_generator = data.example_generator(source_pattern,
                                               True,
                                               False,
                                               should_check_valid=False)

    qid_ssi_to_importances = rank_source_sents(temp_in_path, temp_out_path)
    qid_ssi_to_token_scores_and_mappings = get_token_scores_for_ssi(
        temp_in_path, file_path_seq, file_path_mappings)
    ex_gen = example_generator_extended(example_generator, total,
                                        qid_ssi_to_importances,
                                        qid_ssi_to_token_scores_and_mappings)
    print('Creating list')
    ex_list = list(ex_gen)
    ssi_list = list(futures.map(evaluate_example, ex_list))

    # Save the selections; the in-memory list is used below (the previous
    # immediate reload of the same pickle had no effect).
    with open(os.path.join(my_log_dir, 'ssi.pkl'), 'wb') as f:
        pickle.dump(ssi_list, f)

    print('Evaluating BERT model F1 score...')
    suffix = util.all_sent_selection_eval(ssi_list)

    print('Evaluating ROUGE...')
    results_dict = rouge_functions.rouge_eval(ref_dir,
                                              dec_dir,
                                              l_param=l_param)
    rouge_functions.rouge_log(results_dict, my_log_dir, suffix=suffix)

    # Keep only the first ssi_triple[2] system groups of each example.
    ssis_restricted = [
        ssi_triple[1][:ssi_triple[2]] for ssi_triple in ssi_list
    ]
    ssi_lens = [
        len(source_indices)
        for source_indices in util.flatten_list_of_lists(ssis_restricted)
    ]
    if ssi_lens:
        num_singles = ssi_lens.count(1)
        num_pairs = ssi_lens.count(2)
        print('Percent singles/pairs: %.2f %.2f' %
              (num_singles * 100. / len(ssi_lens),
               num_pairs * 100. / len(ssi_lens)))
    else:
        # Avoids ZeroDivisionError when nothing was selected.
        print('Percent singles/pairs: no source indices selected')

    util.print_execution_time(start_time)