def get_importances(model, batch, enc_states, vocab, sess, hps):
    """Compute an importance score for each source sentence of the example in the batch."""
    if FLAGS.pg_mmr:
        enc_sentences, enc_tokens = batch.tokenized_sents[0], batch.word_ids_sents[0]
        if FLAGS.importance_fn == 'oracle':
            human_tokens = get_tokens_for_human_summaries(batch, vocab)  # list (of 4 human summaries) of list of token ids
            metric = 'recall'
            importances_hat = rouge_l_similarity(enc_tokens, human_tokens, vocab, metric=metric)
        elif FLAGS.importance_fn == 'svr':
            with open(os.path.join(FLAGS.actual_log_root, 'svr.pickle'), 'rb') as f:
                svr_model = cPickle.load(f)
            enc_sent_indices = importance_features.get_sent_indices(enc_sentences, batch.doc_indices[0])
            sent_representations_separate = importance_features.get_separate_enc_states(
                model, sess, enc_sentences, vocab, hps)
            importances_hat = get_svr_importances(
                enc_states[0], enc_sentences, enc_sent_indices, svr_model, sent_representations_separate)
        elif FLAGS.importance_fn == 'tfidf':
            importances_hat = get_tfidf_importances(batch.raw_article_sents[0])
        importances = util.special_squash(importances_hat)
    else:
        importances = None
    return importances
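# --- Hedged sketch (not part of the original repo) ---------------------------
# The 'tfidf' branch above delegates to get_tfidf_importances(), whose body is
# defined elsewhere. The snippet below is only an illustration, assuming
# scikit-learn is available, of one common way to score sentence importance
# with TF-IDF: each sentence is scored by the cosine similarity between its
# TF-IDF vector and the TF-IDF vector of the whole article. The real helper
# may be implemented differently.
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def tfidf_importances_sketch(raw_article_sents):
    """raw_article_sents: list of sentence strings for one article."""
    vectorizer = TfidfVectorizer()
    sent_vecs = vectorizer.fit_transform(raw_article_sents)      # (num_sents, vocab_size)
    doc_vec = np.asarray(sent_vecs.sum(axis=0))                  # whole-document TF-IDF vector
    sims = cosine_similarity(doc_vec, sent_vecs)                 # (1, num_sents)
    return sims[0]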
def calc_importance_features(self, data_path, hps, model_save_path, docs_desired):
    """Calculate sentence-level features and save as a dataset"""
    data_path_filter_name = os.path.basename(data_path)
    if 'train' in data_path_filter_name:
        data_split = 'train'
    elif 'val' in data_path_filter_name:
        data_split = 'val'
    elif 'test' in data_path_filter_name:
        data_split = 'test'
    else:
        data_split = 'feats'
    if 'cnn-dailymail' in data_path:
        inst_per_file = 1000
    else:
        inst_per_file = 1
    filelist = glob.glob(data_path)
    num_documents_desired = docs_desired
    pbar = tqdm(initial=0, total=num_documents_desired)

    instances = []
    sentences = []
    counter = 0
    doc_counter = 0
    file_counter = 0
    while True:
        batch = self._batcher.next_batch()  # 1 example repeated across batch
        if doc_counter >= num_documents_desired:
            save_path = os.path.join(model_save_path, data_split + '_%06d' % file_counter)
            with open(save_path, 'wb') as f:
                cPickle.dump(instances, f)
            print('Saved features at %s' % save_path)
            return
        if batch is None:  # finished decoding dataset in single_pass mode
            raise Exception("We haven't reached the num docs desired (%d); instead we reached (%d)"
                            % (num_documents_desired, doc_counter))

        batch_enc_states, _ = self._model.run_encoder(self._sess, batch)
        for batch_idx, enc_states in enumerate(batch_enc_states):
            art_oovs = batch.art_oovs[batch_idx]
            all_original_abstracts_sents = batch.all_original_abstracts_sents[batch_idx]
            tokenizer = Tokenizer('english')
            # List of lists of words
            enc_sentences, enc_tokens = batch.tokenized_sents[batch_idx], batch.word_ids_sents[batch_idx]
            enc_sent_indices = importance_features.get_sent_indices(enc_sentences, batch.doc_indices[batch_idx])
            enc_sentences_str = [' '.join(sent) for sent in enc_sentences]
            sent_representations_separate = importance_features.get_separate_enc_states(
                self._model, self._sess, enc_sentences, self._vocab, hps)

            sent_indices = enc_sent_indices
            sent_reps = importance_features.get_importance_features_for_article(
                enc_states, enc_sentences, sent_indices, tokenizer, sent_representations_separate)
            y, y_hat = importance_features.get_ROUGE_Ls(art_oovs, all_original_abstracts_sents, self._vocab, enc_tokens)
            binary_y = importance_features.get_best_ROUGE_L_for_each_abs_sent(
                art_oovs, all_original_abstracts_sents, self._vocab, enc_tokens)
            for rep_idx, rep in enumerate(sent_reps):
                rep.y = y[rep_idx]
                rep.binary_y = binary_y[rep_idx]

            for rep_idx, rep in enumerate(sent_reps):
                # Keep all sentences with importance above threshold. All others will be kept with a probability of prob_to_keep
                if FLAGS.importance_fn == 'svr':
                    instances.append(rep)
                    # Original line was `sentences.append(sentences)`, which appends the list to itself;
                    # appending the sentence strings for this article appears to be the intent.
                    sentences.append(enc_sentences_str)
                    counter += 1  # this is how many examples we've decoded

        doc_counter += len(batch_enc_states)
        pbar.update(len(batch_enc_states))
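# --- Hedged sketch (not part of the original repo) ---------------------------
# The instances pickled above are presumably what trains the 'svr.pickle' model
# that get_importances() loads. The snippet below shows, under assumptions, how
# such a model could be fit with scikit-learn. `rep_to_feature_vector` is a
# hypothetical stand-in: the actual per-sentence feature layout (beyond the
# .y / .binary_y targets assigned above) lives in importance_features and may differ.
try:
    import cPickle
except ImportError:
    import pickle as cPickle
import glob
import os
import numpy as np
from sklearn.svm import SVR

def rep_to_feature_vector(rep):
    # Placeholder: flatten whatever per-sentence features the rep object carries.
    return np.asarray(rep.features, dtype=np.float32)  # hypothetical attribute

def train_svr_sketch(feature_dir, save_path):
    """Fit an SVR on the saved feature files and pickle it for later use."""
    X, y = [], []
    for path in sorted(glob.glob(os.path.join(feature_dir, 'train_*'))):
        with open(path, 'rb') as f:
            for rep in cPickle.load(f):
                X.append(rep_to_feature_vector(rep))
                y.append(rep.y)  # ROUGE-L target set in calc_importance_features
    svr_model = SVR(kernel='rbf')
    svr_model.fit(np.stack(X), np.asarray(y))
    with open(save_path, 'wb') as f:  # e.g. os.path.join(FLAGS.actual_log_root, 'svr.pickle')
        cPickle.dump(svr_model, f)
    return svr_model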