Example #1
def get_importances(model, batch, enc_states, vocab, sess, hps):
    if FLAGS.pg_mmr:
        enc_sentences, enc_tokens = batch.tokenized_sents[
            0], batch.word_ids_sents[0]
        if FLAGS.importance_fn == 'oracle':
            human_tokens = get_tokens_for_human_summaries(
                batch,
                vocab)  # list (of 4 human summaries) of list of token ids
            metric = 'recall'
            importances_hat = rouge_l_similarity(enc_tokens,
                                                 human_tokens,
                                                 vocab,
                                                 metric=metric)
        elif FLAGS.importance_fn == 'svr':
            # Load the SVR regressor trained offline on sentence-level importance features
            with open(os.path.join(FLAGS.actual_log_root, 'svr.pickle'),
                      'rb') as f:
                svr_model = cPickle.load(f)
            enc_sent_indices = importance_features.get_sent_indices(
                enc_sentences, batch.doc_indices[0])
            sent_representations_separate = importance_features.get_separate_enc_states(
                model, sess, enc_sentences, vocab, hps)
            importances_hat = get_svr_importances(
                enc_states[0], enc_sentences, enc_sent_indices, svr_model,
                sent_representations_separate)
        elif FLAGS.importance_fn == 'tfidf':
            importances_hat = get_tfidf_importances(batch.raw_article_sents[0])
        importances = util.special_squash(importances_hat)
    else:
        importances = None
    return importances
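
For context, the 'tfidf' branch above scores sentences with the repository's get_tfidf_importances helper, which is not shown here. Below is a minimal sketch of the same idea built on scikit-learn's TfidfVectorizer; the vectorizer settings and the mean-weight scoring are illustrative assumptions, not the actual implementation.

# Hypothetical illustration of TF-IDF sentence importance; not the repo's get_tfidf_importances.
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer

def tfidf_sentence_importances(raw_article_sents):
    """Score each sentence by the mean TF-IDF weight of its terms (illustrative assumption)."""
    vectorizer = TfidfVectorizer(stop_words='english')
    tfidf = vectorizer.fit_transform(raw_article_sents)  # shape: (num_sents, vocab_size)
    sums = np.asarray(tfidf.sum(axis=1)).ravel()         # total TF-IDF weight per sentence
    nonzero = np.maximum(tfidf.getnnz(axis=1), 1)        # avoid division by zero for empty rows
    return sums / nonzero                                # mean nonzero TF-IDF per sentence

# Example call, mirroring the shape used above:
# importances_hat = tfidf_sentence_importances(batch.raw_article_sents[0])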
Example #2
    def calc_importance_features(self, data_path, hps, model_save_path, docs_desired):
        """Calculate sentence-level features and save as a dataset"""
        data_path_filter_name = os.path.basename(data_path)
        if 'train' in data_path_filter_name:
            data_split = 'train'
        elif 'val' in data_path_filter_name:
            data_split = 'val'
        elif 'test' in data_path_filter_name:
            data_split = 'test'
        else:
            data_split = 'feats'
        if 'cnn-dailymail' in data_path:
            inst_per_file = 1000
        else:
            inst_per_file = 1
        filelist = glob.glob(data_path)
        num_documents_desired = docs_desired
        pbar = tqdm(initial=0, total=num_documents_desired)

        instances = []
        sentences = []
        counter = 0
        doc_counter = 0
        file_counter = 0
        while True:
            batch = self._batcher.next_batch()  # one example repeated across the batch
            if doc_counter >= num_documents_desired:
                save_path = os.path.join(model_save_path, data_split + '_%06d'%file_counter)
                with open(save_path, 'wb') as f:
                    cPickle.dump(instances, f)
                print('Saved features at %s' % save_path)
                return

            if batch is None: # finished decoding dataset in single_pass mode
                raise Exception("We haven't reached the desired number of docs (%d); only reached (%d)" % (num_documents_desired, doc_counter))


            batch_enc_states, _ = self._model.run_encoder(self._sess, batch)
            for batch_idx, enc_states in enumerate(batch_enc_states):
                art_oovs = batch.art_oovs[batch_idx]
                all_original_abstracts_sents = batch.all_original_abstracts_sents[batch_idx]

                tokenizer = Tokenizer('english')
                # List of lists of words
                enc_sentences, enc_tokens = batch.tokenized_sents[batch_idx], batch.word_ids_sents[batch_idx]
                enc_sent_indices = importance_features.get_sent_indices(enc_sentences, batch.doc_indices[batch_idx])
                enc_sentences_str = [' '.join(sent) for sent in enc_sentences]

                sent_representations_separate = importance_features.get_separate_enc_states(self._model, self._sess, enc_sentences, self._vocab, hps)

                sent_indices = enc_sent_indices
                sent_reps = importance_features.get_importance_features_for_article(
                    enc_states, enc_sentences, sent_indices, tokenizer, sent_representations_separate)
                y, y_hat = importance_features.get_ROUGE_Ls(art_oovs, all_original_abstracts_sents, self._vocab, enc_tokens)
                binary_y = importance_features.get_best_ROUGE_L_for_each_abs_sent(art_oovs, all_original_abstracts_sents, self._vocab, enc_tokens)
                for rep_idx, rep in enumerate(sent_reps):
                    rep.y = y[rep_idx]
                    rep.binary_y = binary_y[rep_idx]

                for rep_idx, rep in enumerate(sent_reps):
                    # Collect every sentence representation as an SVR training instance
                    if FLAGS.importance_fn == 'svr':
                        instances.append(rep)
                        sentences.append(enc_sentences[rep_idx])
                        counter += 1  # running count of sentence instances collected
            doc_counter += len(batch_enc_states)
            pbar.update(len(batch_enc_states))
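
The instances pickled above are the training data that the 'svr' branch of Example #1 expects a fitted regressor for (it loads svr.pickle from FLAGS.actual_log_root). A minimal training sketch follows, assuming each instance exposes a feature vector as rep.features (a hypothetical attribute name) alongside the ROUGE-L target rep.y set in the loop above.

# Hypothetical training sketch; attribute names other than .y are assumptions.
import os
import pickle
from sklearn.svm import SVR

def train_svr(feats_path, out_dir):
    """Fit an SVR on the pickled sentence features (sketch; not the repo's trainer)."""
    with open(feats_path, 'rb') as f:
        instances = pickle.load(f)
    X = [rep.features for rep in instances]  # assumed per-sentence feature vectors
    y = [rep.y for rep in instances]         # ROUGE-L importance targets
    svr_model = SVR(kernel='rbf')
    svr_model.fit(X, y)
    with open(os.path.join(out_dir, 'svr.pickle'), 'wb') as f:
        pickle.dump(svr_model, f)
    return svr_model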
def main(unused_argv):

    print('Running statistics on %s' % FLAGS.exp_name)

    if len(unused_argv) != 1:  # raise if flags were entered incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.all_actions:
        FLAGS.sent_dataset = True
        FLAGS.ssi_dataset = True
        FLAGS.print_output = True
        FLAGS.highlight = True

    if 'xsum' in FLAGS.dataset_name:
        original_dataset_name = 'xsum'
    elif 'cnn_dm' in FLAGS.dataset_name or 'duc_2004' in FLAGS.dataset_name:
        original_dataset_name = 'cnn_dm'
    else:
        original_dataset_name = ''
    vocab = Vocab(FLAGS.vocab_path + '_' + original_dataset_name,
                  FLAGS.vocab_size)  # create a vocabulary

    source_dir = os.path.join(data_dir, FLAGS.dataset_name)
    util.create_dirs(html_dir)

    if FLAGS.dataset_split == 'all':
        if FLAGS.dataset_name == 'duc_2004':
            dataset_splits = ['test']
        else:
            dataset_splits = ['test', 'val', 'train']
    else:
        dataset_splits = [FLAGS.dataset_split]
    for dataset_split in dataset_splits:
        source_files = sorted(glob.glob(source_dir + '/' + dataset_split +
                                        '*'))
        if FLAGS.exp_name == 'reference':
            # summary_dir = log_dir + default_exp_name + '/decode_test_' + str(max_enc_steps) + \
            #                 'maxenc_4beam_' + str(min_dec_steps) + 'mindec_' + str(max_dec_steps) + 'maxdec_ckpt-238410/reference'
            # summary_files = sorted(glob.glob(summary_dir + '/*_reference.A.txt'))
            summary_dir = source_dir
            summary_files = source_files
        else:
            if FLAGS.exp_name == 'cnn_dm':
                summary_dir = log_dir + FLAGS.exp_name + '/decode_test_400maxenc_4beam_35mindec_100maxdec_ckpt-238410/decoded'
            else:
                ckpt_folder = util.find_largest_ckpt_folder(log_dir +
                                                            FLAGS.exp_name)
                summary_dir = log_dir + FLAGS.exp_name + '/' + ckpt_folder + '/decoded'
                # summary_dir = log_dir + FLAGS.exp_name + '/decode_test_' + str(max_enc_steps) + \
                #             'maxenc_4beam_' + str(min_dec_steps) + 'mindec_' + str(max_dec_steps) + 'maxdec_ckpt-238410/decoded'
            summary_files = sorted(glob.glob(summary_dir + '/*'))
        if len(summary_files) == 0:
            raise Exception('No files found in %s' % summary_dir)
        example_generator = data.example_generator(source_dir + '/' +
                                                   dataset_split + '*',
                                                   True,
                                                   False,
                                                   is_original=True)
        pros = {
            'annotators': 'dcoref',
            'outputFormat': 'json',
            'timeout': '5000000'
        }
        all_merge_examples = []
        num_extracted_list = []
        distances = []
        relative_distances = []
        html_str = ''
        extracted_sents_in_article_html = ''
        name = FLAGS.dataset_name + '_' + FLAGS.exp_name
        if FLAGS.coreference_replacement:
            name += '_coref'
        highlight_file_name = os.path.join(
            html_dir, FLAGS.dataset_name + '_' + FLAGS.exp_name)
        if FLAGS.consider_stopwords:
            highlight_file_name += '_stopwords'
        if FLAGS.highlight:
            extracted_sents_in_article_html_file = open(
                highlight_file_name + '_extracted_sents.html', 'wb')
        if FLAGS.kaiqiang:
            kaiqiang_article_texts = []
            kaiqiang_abstract_texts = []
            util.create_dirs(kaiqiang_dir)
            kaiqiang_article_file = open(
                os.path.join(
                    kaiqiang_dir, FLAGS.dataset_name + '_' + dataset_split +
                    '_' + str(FLAGS.min_matched_tokens) + '_articles.txt'),
                'wb')
            kaiqiang_abstract_file = open(
                os.path.join(
                    kaiqiang_dir, FLAGS.dataset_name + '_' + dataset_split +
                    '_' + str(FLAGS.min_matched_tokens) + '_abstracts.txt'),
                'wb')
        if FLAGS.ssi_dataset:
            if FLAGS.tag_tokens:
                with_coref_and_ssi_dir = lambdamart_dir + '_and_tag_tokens'
            else:
                with_coref_and_ssi_dir = lambdamart_dir
            lambdamart_out_dir = os.path.join(with_coref_and_ssi_dir,
                                              FLAGS.dataset_name)
            if FLAGS.sentence_limit == 1:
                lambdamart_out_dir += '_singles'
            if FLAGS.consider_stopwords:
                lambdamart_out_dir += '_stopwords'
            lambdamart_out_full_dir = os.path.join(lambdamart_out_dir, 'all')
            util.create_dirs(lambdamart_out_full_dir)
            lambdamart_writer = open(
                os.path.join(lambdamart_out_full_dir, dataset_split + '.bin'),
                'wb')

        simple_similar_source_indices_list_plus_empty = []
        example_idx = -1
        instance_idx = 0
        total = len(source_files) * 1000 if (
            'cnn' in FLAGS.dataset_name or 'newsroom' in FLAGS.dataset_name
            or 'xsum' in FLAGS.dataset_name) else len(source_files)
        random_choices = None
        if FLAGS.randomize:
            if FLAGS.dataset_name == 'cnn_dm':
                list_order = np.random.permutation(11490)
                random_choices = list_order[:FLAGS.num_instances]
        for example in tqdm(example_generator, total=total):
            example_idx += 1
            if FLAGS.num_instances != -1 and instance_idx >= FLAGS.num_instances:
                break
            if random_choices is not None and example_idx not in random_choices:
                continue
        # for file_idx in tqdm(range(len(source_files))):
        #     example = get_tf_example(source_files[file_idx])
            article_text = example.features.feature[
                'article'].bytes_list.value[0].decode().lower()
            if FLAGS.exp_name == 'reference':
                summary_text, all_summary_texts = get_summary_from_example(
                    example)
            else:
                summary_text = get_summary_text(summary_files[example_idx])
            article_tokens = split_into_tokens(article_text)
            if 'raw_article_sents' in example.features.feature and len(
                    example.features.feature['raw_article_sents'].bytes_list.
                    value) > 0:
                raw_article_sents = example.features.feature[
                    'raw_article_sents'].bytes_list.value

                raw_article_sents = [
                    sent.decode() for sent in raw_article_sents
                    if sent.decode().strip() != ''
                ]
                article_sent_tokens = [
                    util.process_sent(sent, whitespace=True)
                    for sent in raw_article_sents
                ]
            else:
                # article_text = util.to_unicode(article_text)

                # sent_pros = {'annotators': 'ssplit', 'outputFormat': 'json', 'timeout': '5000000'}
                # sents_result_dict = nlp.annotate(str(article_text), properties=sent_pros)
                # article_sent_tokens = [[token['word'] for token in sent['tokens']] for sent in sents_result_dict['sentences']]

                raw_article_sents = nltk.tokenize.sent_tokenize(article_text)
                article_sent_tokens = [
                    util.process_sent(sent) for sent in raw_article_sents
                ]
            if FLAGS.top_n_sents != -1:
                article_sent_tokens = article_sent_tokens[:FLAGS.top_n_sents]
                raw_article_sents = raw_article_sents[:FLAGS.top_n_sents]
            article_sents = [' '.join(sent) for sent in article_sent_tokens]
            try:
                article_tokens_string = str(' '.join(article_sents))
            except:
                # Fall back to latin-1 decoding for articles with non-ASCII bytes
                article_tokens_string = str(' '.join(
                    [sent.decode('latin-1') for sent in article_sents]))

            if len(article_sent_tokens) == 0:
                continue

            summary_sent_tokens = split_into_sent_tokens(summary_text)
            if 'doc_indices' in example.features.feature and len(
                    example.features.feature['doc_indices'].bytes_list.value
            ) > 0:
                doc_indices_str = example.features.feature[
                    'doc_indices'].bytes_list.value[0].decode()
                if '1' in doc_indices_str:
                    doc_indices = [
                        int(x) for x in doc_indices_str.strip().split()
                    ]
                    rel_sent_positions = importance_features.get_sent_indices(
                        article_sent_tokens, doc_indices)
                else:
                    num_tokens_total = sum(
                        [len(sent) for sent in article_sent_tokens])
                    rel_sent_positions = list(range(len(raw_article_sents)))
                    doc_indices = [0] * num_tokens_total

            else:
                rel_sent_positions = None
                doc_indices = None
                doc_indices_str = None
            if 'corefs' in example.features.feature and len(
                    example.features.feature['corefs'].bytes_list.value) > 0:
                corefs_str = example.features.feature[
                    'corefs'].bytes_list.value[0]
                corefs = json.loads(corefs_str)
            else:
                # Defaults so later uses of corefs / corefs_str don't raise a NameError
                corefs_str = None
                corefs = []
            # summary_sent_tokens = limit_to_n_tokens(summary_sent_tokens, 100)

            similar_source_indices_list_plus_empty = []

            simple_similar_source_indices, lcs_paths_list, article_lcs_paths_list, smooth_article_paths_list = ssi_functions.get_simple_source_indices_list(
                summary_sent_tokens,
                article_sent_tokens,
                vocab,
                FLAGS.sentence_limit,
                FLAGS.min_matched_tokens,
                not FLAGS.consider_stopwords,
                lemmatize=FLAGS.lemmatize,
                multiple_ssi=FLAGS.multiple_ssi)

            article_paths_parameter = article_lcs_paths_list if FLAGS.tag_tokens else None
            article_paths_parameter = smooth_article_paths_list if FLAGS.smart_tags else article_paths_parameter
            restricted_source_indices = util.enforce_sentence_limit(
                simple_similar_source_indices, FLAGS.sentence_limit)
            for summ_sent_idx, summ_sent in enumerate(summary_sent_tokens):
                if FLAGS.sent_dataset:
                    if len(restricted_source_indices[summ_sent_idx]) == 0:
                        continue
                    merge_example = get_merge_example(
                        restricted_source_indices[summ_sent_idx],
                        article_sent_tokens, summ_sent, corefs,
                        article_paths_parameter[summ_sent_idx])
                    all_merge_examples.append(merge_example)

            simple_similar_source_indices_list_plus_empty.append(
                simple_similar_source_indices)
            if FLAGS.ssi_dataset:
                summary_text_to_save = [
                    s for s in all_summary_texts
                ] if FLAGS.dataset_name == 'duc_2004' else summary_text
                write_lambdamart_example(simple_similar_source_indices,
                                         raw_article_sents,
                                         summary_text_to_save, corefs_str,
                                         doc_indices_str,
                                         article_paths_parameter,
                                         lambdamart_writer)

            if FLAGS.highlight:
                highlight_article_lcs_paths_list = smooth_article_paths_list if FLAGS.smart_tags else article_lcs_paths_list
                # simple_ssi_plus_empty = [ [s[0] for s in sim_source_ind] for sim_source_ind in simple_similar_source_indices]
                extracted_sents_in_article_html = ssi_functions.html_highlight_sents_in_article(
                    summary_sent_tokens, simple_similar_source_indices,
                    article_sent_tokens, doc_indices, lcs_paths_list,
                    highlight_article_lcs_paths_list)
                extracted_sents_in_article_html_file.write(
                    extracted_sents_in_article_html.encode())

            instance_idx += 1

        if FLAGS.ssi_dataset:
            lambdamart_writer.close()
            if FLAGS.dataset_name == 'cnn_dm' or FLAGS.dataset_name == 'newsroom' or FLAGS.dataset_name == 'xsum':
                chunk_size = 1000
            else:
                chunk_size = 1
            util.chunk_file(dataset_split,
                            lambdamart_out_full_dir,
                            lambdamart_out_dir,
                            chunk_size=chunk_size)

        if FLAGS.sent_dataset:
            with_coref_dir = data_dir + '_and_tag_tokens' if FLAGS.tag_tokens else data_dir
            out_dir = os.path.join(with_coref_dir,
                                   FLAGS.dataset_name + '_sent')
            if FLAGS.sentence_limit == 1:
                out_dir += '_singles'
            if FLAGS.consider_stopwords:
                out_dir += '_stopwords'
            if FLAGS.coreference_replacement:
                out_dir += '_coref'
            if FLAGS.top_n_sents != -1:
                out_dir += '_n=' + str(FLAGS.top_n_sents)
            util.create_dirs(out_dir)
            convert_data.write_with_generator(iter(all_merge_examples),
                                              len(all_merge_examples), out_dir,
                                              dataset_split)

        if FLAGS.print_output:
            # html_str = FLAGS.dataset + ' | ' + FLAGS.exp_name + '<br><br><br>' + html_str
            # save_fusions_to_file(html_str)
            ssi_path = os.path.join(ssi_dir, FLAGS.dataset_name)
            if FLAGS.consider_stopwords:
                ssi_path += '_stopwords'
            util.create_dirs(ssi_path)
            if FLAGS.dataset_name == 'duc_2004' and FLAGS.abstract_idx != 0:
                abstract_idx_str = '_%d' % FLAGS.abstract_idx
            else:
                abstract_idx_str = ''
            with open(
                    os.path.join(
                        ssi_path,
                        dataset_split + '_ssi' + abstract_idx_str + '.pkl'),
                    'wb') as f:
                pickle.dump(simple_similar_source_indices_list_plus_empty, f)

        if FLAGS.kaiqiang:
            # kaiqiang_article_file.write('\n'.join(kaiqiang_article_texts))
            # kaiqiang_abstract_file.write('\n'.join(kaiqiang_abstract_texts))
            kaiqiang_article_file.close()
            kaiqiang_abstract_file.close()
        if FLAGS.highlight:
            extracted_sents_in_article_html_file.close()
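
main takes the leftover argv that TensorFlow's flag parser passes through, so the module is presumably launched with tf.app.run(); a minimal entry-point sketch (the flag definitions themselves are assumed to live elsewhere in the module):

# Hypothetical TF1-style entry point; the actual module may differ.
import tensorflow as tf

if __name__ == '__main__':
    tf.app.run(main)  # parses the FLAGS referenced above, then calls main(unused_argv)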