def decode(self):
    """Beam-search decode every batch from the (single-pass) batcher and
    write reference/decoded files for later ROUGE evaluation.

    Consumes batches until ``self.batcher.next_batch()`` returns None.
    """
    start = time.time()
    counter = 0
    batch = self.batcher.next_batch()
    while batch is not None:
        # Run beam search to get the best hypothesis for this batch.
        best_summary = self.beam_search(batch)

        # Drop the leading [START] token and map ids back to words;
        # article OOVs are only resolvable in pointer-generator mode.
        output_ids = [int(t) for t in best_summary.tokens[1:]]
        decoded_words = data.outputids2words(
            output_ids, self.vocab,
            (batch.art_oovs[0] if config.pointer_gen else None))

        # Truncate at the first [STOP] token, if one was produced.
        try:
            fst_stop_idx = decoded_words.index(data.STOP_DECODING)
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            decoded_words = decoded_words

        # NOTE: only the first example of the batch is evaluated here —
        # presumably decode batches repeat one example; confirm in batcher.
        original_abstract_sents = batch.original_abstracts_sents[0]
        write_for_rouge(original_abstract_sents, decoded_words, counter,
                        self._rouge_ref_dir, self._rouge_dec_dir)
        counter += 1
        if counter % 1000 == 0:
            # Periodic throughput report; the timer restarts each report.
            print('%d example in %d sec' % (counter, time.time() - start))
            start = time.time()
        batch = self.batcher.next_batch()

    print("Decoder has finished reading dataset for single_pass.")
def forward(self, input, seq_lens, vocab, batchartoovs=None):
    """Encode a batch: ids -> word strings -> ELMo embeddings -> BiLSTM.

    :param input: batch of encoder input id tensors
    :param seq_lens: sequence lengths for packing (assumed sorted descending
        as required by pack_padded_sequence — TODO confirm at call site)
    :param vocab: vocabulary used to map ids back to words
    :param batchartoovs: article OOV words (pointer-gen mode only)
    :returns: (encoder_outputs, encoder_feature, hidden)
    """
    # Convert each example's ids back to word strings so ELMo can
    # character-encode them.
    strings = []
    for examp in input:
        copy = examp.clone().cpu().numpy().astype(int)
        converted = data.outputids2words(
            copy, vocab, (batchartoovs if config.pointer_gen else None))
        strings.append(converted)
    strings = batch_to_ids(strings).cuda()
    embedded = self.elmo_layer(strings)['elmo_representations']
    embedded = embedded[0]  #[batch size, max enc steps, 1024]
    #embedded = self.embedding(input)
    packed = pack_padded_sequence(embedded, seq_lens, batch_first=True)
    output, hidden = self.lstm(packed)
    encoder_outputs, _ = pad_packed_sequence(
        output, batch_first=True)  # h dim = B x t_k x n
    encoder_outputs = encoder_outputs.contiguous()
    # Flatten to (B * t_k, 2*hidden) for the attention projection W_h.
    encoder_feature = encoder_outputs.view(
        -1, 2 * config.hidden_dim)  # B * t_k x 2*hidden_dim
    encoder_feature = self.W_h(encoder_feature)
    return encoder_outputs, encoder_feature, hidden
def decode(self):
    """Beam-search decode the whole dataset, write ROUGE files, then run
    the ROUGE evaluation and log the results.
    """
    start = time.time()
    counter = 0
    batch = self.batcher.next_batch()
    while batch is not None:
        # Run beam search to get best Hypothesis
        best_summary = self.beam_search(batch)

        # Extract the output ids from the hypothesis and convert back to words
        output_ids = [int(t) for t in best_summary.tokens[1:]]
        decoded_words = data.outputids2words(output_ids, self.vocab,
                                             (batch.art_oovs[0] if config.pointer_gen else None))

        # Remove the [STOP] token from decoded_words, if necessary
        try:
            fst_stop_idx = decoded_words.index(data.STOP_DECODING)
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            decoded_words = decoded_words

        original_abstract_sents = batch.original_abstracts_sents[0]
        write_for_rouge(original_abstract_sents, decoded_words, counter,
                        self._rouge_ref_dir, self._rouge_dec_dir)
        counter += 1
        if counter % 1000 == 0:
            # Periodic throughput report; the timer restarts each report.
            print('%d example in %d sec'%(counter, time.time() - start))
            start = time.time()
        batch = self.batcher.next_batch()

    print("Decoder has finished reading dataset for single_pass.")
    print("Now starting ROUGE eval...")
    # Score the decoded output against the references written above.
    results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
    rouge_log(results_dict, self._decode_dir)
def decode(self):
    """Beam-search decode every pre-loaded batch and write each result
    via write_results.

    Fixes: the original ended with an unterminated ``'''print(...`` string
    (leftover commented-out code) which broke the syntax; it has been
    removed. The no-op ``decoded_words = decoded_words`` in the except arm
    is replaced with ``pass``.
    """
    start = time.time()
    counter = 0
    for batch in self.batches:
        # Run beam search to get best Hypothesis
        best_summary = self.beam_search(batch)

        # Extract the output ids from the hypothesis and convert back to words
        output_ids = [int(t) for t in best_summary.tokens[1:]]
        decoded_words = data.outputids2words(
            output_ids, self.vocab,
            (batch.art_oovs[0] if config.pointer_gen else None))

        # Remove the [STOP] token from decoded_words, if necessary
        try:
            fst_stop_idx = decoded_words.index(data.STOP_DECODING)
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            pass
        write_results(decoded_words, counter, self._rouge_dec_dir)
        counter += 1
        if counter % 1000 == 0:
            # Periodic throughput report; the timer restarts each report.
            print('%d example in %d sec' % (counter, time.time() - start))
            start = time.time()
def train_one_batch(self, batch):
    """Run one teacher-forced training step using fastText word vectors
    as encoder/decoder input embeddings; returns the scalar batch loss.
    """
    article_oovs, enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
        get_input_from_batch(batch, use_cuda)
    dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
        get_output_from_batch(batch, use_cuda)

    self.optimizer.zero_grad()

    # Map encoder input ids back to word strings so fastText can embed them.
    enc_batch = [outputids2words(ids, self.vocab, article_oovs[i]) for i, ids in enumerate(enc_batch.numpy())]
    # Look up a fastText vector per word; ft_model is module-level —
    # presumably a loaded fastText model, confirm where it is initialized.
    enc_batch_list = []
    for words in enc_batch:
        temp_list = []
        for w in words:
            l = ft_model.get_numpy_vector(w)
            temp_list.append(l)
        enc_batch_list.append(temp_list)
    enc_batch_list = torch.Tensor(enc_batch_list)

    encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch_list, enc_lens)
    s_t_1 = self.model.reduce_state(encoder_hidden)

    step_losses = []
    for di in range(min(max_dec_len, config.max_dec_steps)):
        y_t_1 = dec_batch[:, di]  # Teacher forcing
        # for i, id in enumerate(y_t_1):
        #     print (id)
        #     myid2word(id, self.vocab, article_oovs[i])
        # Decoder input is also fed as fastText vectors (id -> word -> vec).
        y_t_1 = [myid2word(id, self.vocab, article_oovs[i]) for i, id in enumerate(y_t_1.numpy())]
        y_t_1 = torch.Tensor([ft_model.get_numpy_vector(w) for w in y_t_1])
        final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(y_t_1, s_t_1,
                                                                                       encoder_outputs,
                                                                                       encoder_feature,
                                                                                       enc_padding_mask, c_t_1,
                                                                                       extra_zeros,
                                                                                       enc_batch_extend_vocab,
                                                                                       coverage, di)
        target = target_batch[:, di]
        # NLL of the gold token under the predicted distribution.
        gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
        step_loss = -torch.log(gold_probs + config.eps)
        if config.is_coverage:
            # Coverage penalty (See et al. pointer-generator style).
            step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
            step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
            coverage = next_coverage
        # Mask out padded decoder positions.
        step_mask = dec_padding_mask[:, di]
        step_loss = step_loss * step_mask
        step_losses.append(step_loss)

    # Average per-example loss over true decoder lengths, then over batch.
    sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
    batch_avg_loss = sum_losses/dec_lens_var
    loss = torch.mean(batch_avg_loss)

    loss.backward()

    self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                config.max_grad_norm)
    clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
    clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)

    self.optimizer.step()

    return loss.item()
def evaluate_batch(self, print_sents=False):
    """Decode a single batch (note the ``break``) and stash the first
    summary and its OOVs on the Batcher class for later display.
    """
    self.setup_valid()
    batch = self.batcher.next_batch()
    start_id = self.vocab.word2id(data.START_DECODING)
    end_id = self.vocab.word2id(data.STOP_DECODING)
    unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
    decoded_sents = []
    ref_sents = []
    article_sents = []
    rouge = Rouge()
    while batch is not None:
        enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = get_enc_data(
            batch)
        with T.autograd.no_grad():
            enc_batch = self.model.embeds(enc_batch)
            enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)
        #-----------------------Summarization----------------------------------------------------
        with T.autograd.no_grad():
            pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask,
                                   ct_e, extra_zeros, enc_batch_extend_vocab,
                                   self.model, start_id, end_id, unk_id)
        for i in range(len(pred_ids)):
            decoded_words = data.outputids2words(pred_ids[i], self.vocab,
                                                 batch.art_oovs[i])
            # Guard against degenerate (near-empty) decodes.
            if len(decoded_words) < 2:
                decoded_words = "xxx"
            else:
                decoded_words = " ".join(decoded_words)
            decoded_sents.append(decoded_words)
            abstract = batch.original_abstracts[i]
            article = batch.original_articles[i]
            ref_sents.append(abstract)
            article_sents.append(article)
            article_art_oovs = batch.art_oovs[i]
        #batch = self.batcher.next_batch()
        # Deliberately stop after the first batch.
        break
    load_file = self.opt.load_model  # just a model name
    #if print_sents:
    #    self.print_original_predicted(decoded_sents, ref_sents, article_sents, load_file)
    # Publish the first decoded summary / OOV list as class attributes.
    Batcher.article_summary = decoded_sents[0]
    Batcher.oovs = " ".join(article_art_oovs)
    # print('Article: ',article_sents[0], '\n==> Summary: [',decoded_sents[0],']\nOut of vocabulary: ', " ".join(article_art_oovs),'\nModel used: ', load_file)
    scores = 0  #rouge.get_scores(decoded_sents, ref_sents, avg = True)
    if self.opt.task == "test":
        print('Done.')
        #print(load_file, "scores:", scores)
    else:
        # NOTE(review): scores is the int 0 here, so this subscript would
        # raise TypeError — looks like dead/disabled code, confirm intent.
        rouge_l = scores["rouge-l"]["f"]
def decode(self):
    """Decode the dataset and compute averaged Russian-language ROUGE
    metrics (pymystem3 lemmatization + sumeval RougeCalculator).
    """
    lemm = pymystem3.Mystem()
    rouge = RougeCalculator(stopwords=True, lang=LangRU())
    # Running sums for 6 ROUGE metrics — presumably matches what
    # calk_rouge returns; confirm against its definition.
    result_rouge = [0] * 6
    batch = self.batcher.next_batch()
    iters = 0
    while batch is not None:
        # Run beam search to get best Hypothesis
        with torch.no_grad():
            best_summary = self.beam_search(batch)

        # Extract the output ids from the hypothesis and convert back to words
        output_ids = [int(t) for t in best_summary.tokens[1:]]
        decoded_words = data.outputids2words(
            output_ids, self.vocab,
            (batch.art_oovs[0] if config.pointer_gen else None))

        # Remove the [STOP] token from decoded_words, if necessary
        try:
            fst_stop_idx = decoded_words.index(data.STOP_DECODING)
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            decoded_words = decoded_words

        original_abstract_sents = batch.original_abstracts_sents[0]
        original_text = batch.original_articles
        article_oov = batch.art_oovs[0] if batch.art_oovs else None
        # Fetch the next batch before scoring the current one.
        batch = self.batcher.next_batch()

        # Restore readable text (detokenize) before printing/scoring.
        original_abstract_sents = self.restore_text(
            original_abstract_sents)
        decoded_words_restore = self.restore_text(decoded_words)
        decoded_words = " ".join(decoded_words)
        print(f"original_abstract : {original_abstract_sents}")
        print(f"original_text : {original_text}")
        print(f"decoded_words : {decoded_words_restore}")
        print(
            f"decoded_words_oov : {show_abs_oovs(decoded_words, self.vocab, article_oov)}"
        )
        # Accumulate element-wise metric sums.
        cur_rouge = calk_rouge(original_abstract_sents,
                               [decoded_words_restore], rouge, lemm)
        result_rouge = list(
            map(lambda x: x[0] + x[1], zip(result_rouge, cur_rouge)))
        iters += 1
        print("--" * 100)

    print("RESULT METRICS")
    # Average the accumulated sums over the number of examples.
    result_rouge = [i / iters for i in result_rouge]
    print_results(result_rouge)
    print("++++" * 100)
def evaluate_batch(self, print_sents=False):
    """Decode the full validation/test set, score with ROUGE, and print
    (and append to test_rg.txt) the resulting rouge-l F score.

    :param print_sents: when True, also dump original vs. predicted text.
    """
    self.setup_valid()
    batch = self.batcher.next_batch()
    start_id = self.vocab.word2id(data.START_DECODING)
    end_id = self.vocab.word2id(data.STOP_DECODING)
    unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
    decoded_sents = []
    ref_sents = []
    article_sents = []
    rouge = Rouge()
    while batch is not None:
        enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = get_enc_data(
            batch)
        with T.autograd.no_grad():
            enc_batch = self.model.embeds(enc_batch)
            enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)
        #-----------------------Summarization----------------------------------------------------
        with T.autograd.no_grad():
            pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask,
                                   ct_e, extra_zeros, enc_batch_extend_vocab,
                                   self.model, start_id, end_id, unk_id)
        for i in range(len(pred_ids)):
            decoded_words = data.outputids2words(pred_ids[i], self.vocab,
                                                 batch.art_oovs[i])
            # Guard against degenerate decodes; Rouge() errors on empties.
            if len(decoded_words) < 2:
                decoded_words = "xxx"
            else:
                decoded_words = " ".join(decoded_words)
            decoded_sents.append(decoded_words)
            abstract = batch.original_abstracts[i]
            article = batch.original_articles[i]
            ref_sents.append(abstract)
            article_sents.append(article)
        batch = self.batcher.next_batch()
    load_file = self.opt.load_model
    if print_sents:
        self.print_original_predicted(decoded_sents, ref_sents,
                                      article_sents, load_file)
    scores = rouge.get_scores(decoded_sents, ref_sents, avg=True)
    if self.opt.task == "test":
        print(load_file, "scores:", scores)
    else:
        rouge_l = scores["rouge-l"]["f"]
        print(load_file, "rouge_l:", "%.4f" % rouge_l)
        with open("test_rg.txt", "a") as f:
            f.write("\n" + load_file + " - rouge_l: " + str(rouge_l))
            # f.close() is redundant inside the with-block but harmless.
            f.close()
def evaluate_batch(self, article):
    """Decode the full dataset and log per-corpus ROUGE-1/2/L F scores.

    :param article: truthy flag — when set, also dump original vs.
        predicted sentences via print_original_predicted.

    Fix: the original reassigned the ``article`` parameter inside the
    decode loop (``article = batch.original_articles[i]``), so the
    ``if article:`` flag check afterwards tested the last article string
    (always truthy for non-empty data) instead of the caller's argument.
    The loop now uses a separate local name.
    """
    self.setup_valid()
    batch = self.batcher.next_batch()
    start_id = self.vocab.word2id(data.START_DECODING)
    end_id = self.vocab.word2id(data.STOP_DECODING)
    unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
    decoded_sents = []
    ref_sents = []
    article_sents = []
    rouge = Rouge()
    while batch is not None:
        enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = get_enc_data(
            batch)
        with T.autograd.no_grad():
            enc_batch = self.model.embeds(enc_batch)
            enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)
        #-----------------------Summarization----------------------------------------------------
        with T.autograd.no_grad():
            pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask,
                                   ct_e, extra_zeros, enc_batch_extend_vocab,
                                   self.model, start_id, end_id, unk_id)
        for i in range(len(pred_ids)):
            decoded_words = data.outputids2words(pred_ids[i], self.vocab,
                                                 batch.art_oovs[i])
            # Guard against degenerate decodes; Rouge() errors on empties.
            if len(decoded_words) < 2:
                decoded_words = "xxx"
            else:
                decoded_words = " ".join(decoded_words)
            decoded_sents.append(decoded_words)
            abstract = batch.original_abstracts[i]
            # Renamed local: must not clobber the `article` flag parameter.
            article_text = batch.original_articles[i]
            ref_sents.append(abstract)
            article_sents.append(article_text)
        batch = self.batcher.next_batch()
    load_file = self.opt.load_model
    if article:
        self.print_original_predicted(decoded_sents, ref_sents,
                                      article_sents, load_file)
    scores = rouge.get_scores(decoded_sents, ref_sents)
    rouge_1 = sum([x["rouge-1"]["f"] for x in scores]) / len(scores)
    rouge_2 = sum([x["rouge-2"]["f"] for x in scores]) / len(scores)
    rouge_l = sum([x["rouge-l"]["f"] for x in scores]) / len(scores)
    logger.info(load_file + " rouge_1:" + "%.4f" % rouge_1 + " rouge_2:" +
                "%.4f" % rouge_2 + " rouge_l:" + "%.4f" % rouge_l)
def decode(self):
    """Decode each batch with a transformer-style summarizer (positional
    inputs + summarize_batch), write ROUGE files, then run ROUGE eval.
    """
    start = time.time()
    counter = 0
    batch = self.batcher.next_batch()
    #print(batch.enc_batch)
    while batch is not None:
        # Run beam search to get best Hypothesis
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = get_input_from_batch(
            batch, use_cuda)
        # Keep only the first example of the batch.
        enc_batch = enc_batch[0:1, :]
        enc_padding_mask = enc_padding_mask[0:1, :]
        in_seq = enc_batch
        in_pos = self.get_pos_data(enc_padding_mask)
        #print("enc_padding_mask", enc_padding_mask)
        #print("Summarizing one batch...")
        batch_hyp, batch_scores = self.summarize_batch(in_seq, in_pos)

        # Extract the output ids from the hypothesis and convert back to words
        output_words = np.array(batch_hyp)
        # Keep the top hypothesis (index 0) and drop the leading token.
        output_words = output_words[:, 0, 1:]
        for i, out_sent in enumerate(output_words):
            # NOTE(review): art_oovs[0] is used for every i — fine while
            # only one example survives the slicing above; confirm if the
            # slice is ever removed.
            decoded_words = data.outputids2words(
                out_sent, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))
            original_abstract_sents = batch.original_abstracts_sents[i]
            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            # Reports every example (modulus 1) — debugging cadence.
            if counter % 1 == 0:
                print('%d example in %d sec' % (counter, time.time() - start))
                start = time.time()
        batch = self.batcher.next_batch()

    print("Decoder has finished reading dataset for single_pass.")
    print("Now starting ROUGE eval...")
    results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
    rouge_log(results_dict, self._decode_dir)
def evaluate(self, pred, art_oovs, abstract_sentences):
    """Convert batched prediction ids to HTML-safe sentence strings and
    accumulate them into self.prediction / self.referece.

    :param pred: [batch, N] predicted token ids
    :param art_oovs: per-example article OOV word lists (pointer-gen)
    :param abstract_sentences: per-example reference sentence lists

    Fixes: removed a trailing unterminated ``'''`` (dead commented-out
    code that broke the syntax) and an inner triple-quoted dead block;
    replaced the no-op ``decoded_words = decoded_words`` with ``pass``.
    """
    batch_size = len(pred)
    for j in range(batch_size):
        # print(j,"----------------------",pred[j])
        output_ids = [int(id) for id in pred[j]]
        decoded_words = outputids2words(
            output_ids, self.vocab,
            (art_oovs[j] if self.config.pointer_gen else None))

        # Remove the [STOP] token from decoded_words, if necessary
        try:
            fst_stop_idx = decoded_words.index(data.STOP_DECODING)
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            pass

        # Skip degenerate decodes shorter than the configured minimum.
        if len(decoded_words) < self.config.min_dec_steps:
            continue

        # Split the flat word list into sentences at each period.
        decoded_sents = []
        while len(decoded_words) > 0:
            try:
                fst_period_idx = decoded_words.index(".")
            except ValueError:
                fst_period_idx = len(decoded_words)
            sent = decoded_words[:fst_period_idx + 1]
            decoded_words = decoded_words[fst_period_idx + 1:]
            decoded_sents.append(' '.join(sent))

        # pyrouge renders HTML, so escape both prediction and reference.
        self.prediction.append("\n".join(
            [make_html_safe(sent) for sent in decoded_sents]))
        self.referece.append("\n".join(
            [make_html_safe(sent) for sent in abstract_sentences[j]]))
def evaluate_batch(self):
    """Decode the whole dataset and return the collected decoded,
    reference, and article strings (scoring is left to the caller).
    """
    batch = self.batcher.next_batch()
    start_id = self.vocab.word2id(data.START_DECODING)
    end_id = self.vocab.word2id(data.STOP_DECODING)
    unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
    decoded_sents = []
    ref_sents = []
    article_sents = []
    rouge = Rouge()
    while batch is not None:
        enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = get_enc_data(
            batch)
        with T.autograd.no_grad():
            enc_batch = self.model.embeds(enc_batch)
            enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)
        print('Summarizing Batch...')
        #-----------------------Summarization----------------------------------------------------
        with T.autograd.no_grad():
            # This beam_search variant also takes the vocab size.
            pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask,
                                   ct_e, extra_zeros, enc_batch_extend_vocab,
                                   self.model, start_id, end_id, unk_id,
                                   self.vocab.size())
        for i in range(len(pred_ids)):
            decoded_words = data.outputids2words(pred_ids[i], self.vocab,
                                                 batch.art_oovs[i])
            decoded_words = " ".join(decoded_words)
            decoded_sents.append(decoded_words)
            abstract = batch.original_abstracts[i]
            article = batch.original_articles[i]
            ref_sents.append(abstract)
            article_sents.append(article)
        batch = self.batcher.next_batch()
    load_file = self.opt.load_model
    # if print_sents:
    #     self.print_original_predicted(decoded_sents, ref_sents, article_sents, load_file)
    # scores = rouge.get_scores(decoded_sents, ref_sents, avg = True)
    return decoded_sents, ref_sents, article_sents  # , scores
def validate_batch(self):
    """Decode the validation set with batched beam search and print the
    rouge-l F score for the currently loaded checkpoint.
    """
    self.setup_valid()
    batch = self.batcher.next_batch()
    start_id = self.vocab.word2id(data.START_DECODING)
    end_id = self.vocab.word2id(data.STOP_DECODING)
    unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
    decoded_sents = []
    original_sents = []
    rouge = Rouge()
    while batch is not None:
        enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, c_t_1 = get_enc_data(
            batch)
        with T.autograd.no_grad():
            enc_batch = self.model.embeds(enc_batch)
            enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)
        with T.autograd.no_grad():
            pred_ids = beam_search_on_batch(enc_hidden, enc_out,
                                            enc_padding_mask, c_t_1,
                                            extra_zeros,
                                            enc_batch_extend_vocab,
                                            self.model, start_id, end_id,
                                            unk_id)
        for i in range(len(pred_ids)):
            decoded_words = data.outputids2words(pred_ids[i], self.vocab,
                                                 batch.art_oovs[i])
            # Guard against degenerate decodes; Rouge() errors on empties.
            if len(decoded_words) < 2:
                decoded_words = "xxx"
            else:
                decoded_words = " ".join(decoded_words)
            decoded_sents.append(decoded_words)
            tar = batch.original_abstracts[i]
            original_sents.append(tar)
        batch = self.batcher.next_batch()
    # Report under the checkpoint's file name.
    load_file = config.load_model_path.split("/")[-1]
    scores = rouge.get_scores(decoded_sents, original_sents, avg=True)
    rouge_l = scores["rouge-l"]["f"]
    print(load_file, "rouge_l:", "%.4f" % rouge_l)
def decode(self, file_id_start, file_id_stop):
    """Decode test examples with ids in [file_id_start, file_id_stop),
    skipping ids whose output already exists (resumable / parallel-safe).

    Ids are shuffled so multiple workers tend to pick different examples.
    """
    if file_id_stop > MAX_TEST_ID:
        file_id_stop = MAX_TEST_ID
    # while batch is not None:
    # do this for faster stack CPU machines - to replace those that fail!!
    idx_list = [i for i in range(file_id_start, file_id_stop)]
    random.shuffle(idx_list)
    for idx in idx_list:
        # check if this is written already
        if self.if_already_exists(idx):
            # print("ID {} already exists".format(idx))
            continue
        # batch = self.batcher.next_batch()
        batch = self.batches[idx]
        # Run beam search to get best Hypothesis
        best_summary = self.beam_search(batch)

        # Extract the output ids from the hypothesis and convert back to words
        output_ids = [int(t) for t in best_summary.tokens[1:]]
        decoded_words = data.outputids2words(
            output_ids, self.vocab,
            (batch.art_oovs[0] if config.pointer_gen else None))

        # Remove the [STOP] token from decoded_words, if necessary
        try:
            fst_stop_idx = decoded_words.index(data.STOP_DECODING)
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            decoded_words = decoded_words

        original_abstract_sents = batch.original_abstracts_sents[0]
        # File index is the example id, not a running counter.
        write_for_rouge(original_abstract_sents, decoded_words, idx,
                        self._rouge_ref_dir, self._rouge_dec_dir)
        print("decoded idx = {}".format(idx))
    print("Finished decoding idx [{},{})".format(file_id_start, file_id_stop))
def abstract(self, article):
    """Produce an abstractive summary for a single article string.

    The article is segmented with jieba, wrapped as a one-example batch,
    encoded, and decoded via beam search; the first hypothesis is
    returned as a space-joined word string.
    """
    bos_id = self.vocab.word2id(data.START_DECODING)
    eos_id = self.vocab.word2id(data.STOP_DECODING)
    unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)

    # Build a single-example batch from the segmented article.
    segmented = ' '.join(jieba.cut(article))
    single = Example(segmented, '', self.vocab)
    batch = Batch([single], self.vocab, 1)

    (enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab,
     extra_zeros, ct_e) = get_enc_data(batch)

    with T.autograd.no_grad():
        embedded = self.model.embeds(enc_batch)
        enc_out, enc_hidden = self.model.encoder(embedded, enc_lens)
        pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask, ct_e,
                               extra_zeros, enc_batch_extend_vocab,
                               self.model, bos_id, eos_id, unk_id)

    # The unconditional return means only the first hypothesis is used
    # (matches the original's loop-with-immediate-return shape).
    for hyp_idx in range(len(pred_ids)):
        words = data.outputids2words(pred_ids[hyp_idx], self.vocab,
                                     batch.art_oovs[hyp_idx])
        return " ".join(words)
def evaluate_batch(self):
    """Decode the dataset using token + segment embeddings and return
    decoded, reference, task, and context strings for external scoring.
    """
    batch = self.batcher.next_batch()
    start_id = self.vocab.word2id(data.START_DECODING)
    end_id = self.vocab.word2id(data.STOP_DECODING)
    unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
    decoded_sents = []
    ref_sents = []
    task_sents = []
    context_sents = []
    while batch is not None:
        enc_batch, enc_seg_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = get_enc_seg_data(
            batch)
        with T.autograd.no_grad():
            enc_batch = self.model.embeds(
                enc_batch)  #Get embeddings for encoder input
            enc_seg_batch = self.model.seg_embeds(enc_seg_batch)
            # Concatenate token and segment embeddings on the feature axis.
            enc_batch = T.cat([enc_batch, enc_seg_batch], dim=2)
            enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)
        print('Summarizing Batch...')
        #-----------------------Summarization----------------------------------------------------
        with T.autograd.no_grad():
            # This beam_search variant also takes the vocab size.
            pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask,
                                   ct_e, extra_zeros, enc_batch_extend_vocab,
                                   self.model, start_id, end_id, unk_id,
                                   self.vocab.size())
        for i in range(len(pred_ids)):
            decoded_words = data.outputids2words(pred_ids[i], self.vocab,
                                                 batch.art_oovs[i])
            decoded_sents.append(" ".join(decoded_words))
            ref_sents.append(batch.original_abstracts[i])
            task_sents.append(batch.original_tasks[i])
            context_sents.append(batch.original_contexts[i])
        batch = self.batcher.next_batch()
    return decoded_sents, ref_sents, task_sents, context_sents
def decode(self, file_id_start, file_id_stop, ami_id='191209'):
    """Decode AMI meeting transcriptions with ids in
    [file_id_start, file_id_stop), skipping already-written outputs.

    :param ami_id: AMI transcription set identifier to load.
    """
    print("AMI transcription:", ami_id)
    test_data = load_ami_data(ami_id, 'test')
    # do this for faster stack CPU machines - to replace those that fail!!
    idx_list = [i for i in range(file_id_start, file_id_stop)]
    random.shuffle(idx_list)
    for idx in idx_list:
        # for idx in range(file_id_start, file_id_stop):
        # check if this is written already
        if self.if_already_exists(idx):
            print("ID {} already exists".format(idx))
            continue
        # Run beam search to get best Hypothesis
        best_summary, art_oovs = self.beam_search(test_data, idx)

        # Extract the output ids from the hypothesis and convert back to words
        output_ids = [int(t) for t in best_summary.tokens[1:]]
        decoded_words = data.outputids2words(
            output_ids, self.vocab,
            (art_oovs[0] if config.pointer_gen else None))

        # Remove the [STOP] token from decoded_words, if necessary
        try:
            fst_stop_idx = decoded_words.index(data.STOP_DECODING)
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            decoded_words = decoded_words

        # original_abstract_sents = batch.original_abstracts_sents[0]
        # No reference available here — an empty reference is written.
        original_abstract_sents = []
        write_for_rouge(original_abstract_sents, decoded_words, idx,
                        self._rouge_ref_dir, self._rouge_dec_dir)
        print("decoded idx = {}".format(idx))
    print("Finished decoding idx [{},{})".format(file_id_start, file_id_stop))
def decode(self):
    """Decode every batch and write (article, reference, decoded) triples
    via self.write_result — the Chinese-output variant of decode().
    """
    start = time.time()
    counter = 0
    batch = self.batcher.next_batch()  # in the new architecture this lives in the training decode step
    while batch is not None:
        # Run beam search to get best Hypothesis
        best_summary = self.beam_search(batch)

        # Extract the output ids from the hypothesis and convert back to words
        output_ids = [int(t) for t in best_summary.tokens[1:]]
        decoded_words = data.outputids2words(
            output_ids, self.vocab,
            (batch.art_oovs[0] if config.pointer_gen else None))

        # Remove the end-of-sequence marker, if necessary (this variant
        # uses MARK_EOS rather than STOP_DECODING).
        try:
            fst_stop_idx = decoded_words.index(data.MARK_EOS)
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            decoded_words = decoded_words

        original_abstract_sents = batch.original_abstracts_sents[0]
        original_article = batch.original_articles[0]

        # English output path (disabled):
        # write_for_rouge(original_abstract_sents, decoded_words, counter,
        #                 self._rouge_ref_dir, self._rouge_dec_dir)
        # Chinese output path:
        self.write_result(original_article, original_abstract_sents,
                          decoded_words, counter)
        counter += 1
        # if counter % 1000 == 0:
        #     print('%d example in %d sec'%(counter, time.time() - start))
        #     start = time.time()
        batch = self.batcher.next_batch()
def decode(self):
    """Decode every batch with beam search and return the list of
    decoded summaries (each a list of words, truncated at [STOP])."""
    t0 = time.time()
    n_done = 0
    summaries = []
    batch = self.batcher.next_batch()
    while batch is not None:
        # Best hypothesis for this batch; token 0 is the [START] marker.
        hyp = self.beam_search(batch)
        ids = [int(tok) for tok in hyp.tokens[1:]]
        ### TODO Needs to be fixed
        words = data.outputids2words(
            ids, self.vocab,
            (batch.art_oovs[0] if config.pointer_gen else None))

        # Truncate at the first [STOP] token when one was produced.
        if data.STOP_DECODING in words:
            words = words[:words.index(data.STOP_DECODING)]

        # original_abstract_sents = batch.original_abstracts_sents[0]
        summaries.append(words)
        # summaries += output_ids
        # write_for_rouge(original_abstract_sents, decoded_words, counter,
        #                 self._rouge_ref_dir, self._rouge_dec_dir)
        n_done += 1
        if n_done % 1000 == 0:
            print('%d example in %d sec' % (n_done, time.time() - t0))
            t0 = time.time()
        batch = self.batcher.next_batch()
    return summaries
def test_calc(self, article):
    """Decode a single raw article string and print the restored article,
    the decoded summary, and its OOV usage (interactive smoke test).
    """
    example = batcher.Example(article, [], self.vocab)
    # Beam search expects a batch of beam_size copies of the example.
    batch = batcher.Batch([example for _ in range(config.beam_size)],
                          self.vocab, config.beam_size)
    with torch.no_grad():
        best_summary = self.beam_search(batch)
    # Drop the leading [START] token and map ids back to words.
    output_ids = [int(t) for t in best_summary.tokens[1:]]
    decoded_words = data.outputids2words(
        output_ids, self.vocab,
        (batch.art_oovs[0] if config.pointer_gen else None))
    # Restore readable text (detokenize) for display.
    article_restore = self.restore_text(
        batch.original_articles[-1].split())
    decoded_words_restore = self.restore_text(decoded_words).replace(
        "[STOP]", "")
    print(f"original_text : {article_restore}")
    print(f"decoded_words : {decoded_words_restore}")
    decoded_words = " ".join(decoded_words)
    print(
        f"decoded_words_oov : {show_abs_oovs(decoded_words, self.vocab, batch.art_oovs[0] if batch.art_oovs else None)}"
    )
def eval_one_batch(self, batch):
    """Evaluate one batch by free-running decoding (feeding back the
    argmax prediction instead of teacher forcing) and return the scalar
    loss produced by compute_batched_loss.
    """
    batch_size = batch.batch_size
    enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
        get_input_from_batch(batch, use_cuda)
    dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
        get_output_from_batch(batch, use_cuda)

    encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
        enc_batch, enc_lens)
    s_t_1 = self.model.reduce_state(encoder_hidden)

    step_losses = []
    output_ids = []
    # Start every example with the [START] token id.
    y_t_1 = torch.ones(batch_size, dtype=torch.long) * self.vocab.word2id(
        data.START_DECODING)
    if config.use_gpu:
        y_t_1 = y_t_1.cuda()
    # Per-example accumulators for predicted ids and step losses.
    for _ in range(batch_size):
        output_ids.append([])
        step_losses.append([])

    for di in range(min(max_dec_len, config.max_dec_steps)):
        #y_t_1 = dec_batch[:, di]  # Teacher forcing
        final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
            y_t_1, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask,
            c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di)
        target = target_batch[:, di]
        gold_probs = torch.gather(final_dist, 1,
                                  target.unsqueeze(1)).squeeze()
        step_loss = -torch.log(gold_probs + config.eps)  #NLL
        if config.is_coverage:
            step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
            step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
            coverage = next_coverage
        step_mask = dec_padding_mask[:, di]
        step_loss = step_loss * step_mask
        # Move on to the next token: feed back the greedy prediction.
        _, idx = torch.max(final_dist, 1)
        idx = idx.reshape(batch_size, -1).squeeze()
        y_t_1 = idx
        for i, pred in enumerate(y_t_1):
            # NOTE(review): data.PAD_TOKEN is compared against an int id —
            # if PAD_TOKEN is the token *string*, this is always True;
            # confirm against data module.
            if not pred.item() == data.PAD_TOKEN:
                output_ids[i].append(pred.item())
        for i, loss in enumerate(step_loss):
            step_losses[i].append(step_loss[i])

    # Obtain the original and predicted summaries
    original_abstracts = batch.original_abstracts_sents
    predicted_abstracts = [
        data.outputids2words(ids, self.vocab, None) for ids in output_ids
    ]

    # Compute the batched loss
    batched_losses = self.compute_batched_loss(step_losses,
                                               original_abstracts,
                                               predicted_abstracts)
    # Normalize by true decoder lengths, then average over the batch.
    losses = torch.stack(batched_losses)
    losses = losses / dec_lens_var
    loss = torch.mean(losses)
    return loss.item()
def train_one_batch(self, batch):
    """One training step mixing teacher-forced MLE loss with a
    self-critical RL loss (sampled rollout vs. greedy baseline, scored
    with ROUGE-L); returns the scalar combined loss.

    Fixes over the original:
      * the baseline rollout used ``s_t_baseline`` / ``c_t_baseline``
        before assignment (NameError) — they are now initialized like the
        sample rollout's state;
      * ``Variable(final_dist.max(1))`` wrapped the (values, indices)
        tuple returned by ``Tensor.max(dim)`` — the argmax *indices* are
        now taken;
      * the gathered log-prob column (B, 1) is flattened to (B,) before
        accumulating into ``sample_log_probs`` (avoids broadcasting to
        (B, B));
      * the RL loss vector is reduced with ``torch.mean`` so that
        ``loss.backward()`` is called on a scalar;
      * ``Variable(<float rouge score>)`` wrappers were dropped (a plain
        float is fine on the right-hand side of the tensor assignment).
    """
    enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
        get_input_from_batch(batch, use_cuda)
    dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
        get_output_from_batch(batch, use_cuda)

    self.optimizer.zero_grad()

    encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
    s_t_1 = self.model.reduce_state(encoder_hidden)
    s_t_1_origin = s_t_1  # kept so the rollouts restart from the encoder state
    batch_size = batch.batch_size

    step_losses = []
    sample_idx = []
    sample_log_probs = Variable(torch.zeros(batch_size))
    baseline_idx = []

    for di in range(min(max_dec_len, config.max_dec_steps)):
        # ---- teacher-forced MLE step ----
        y_t_1 = dec_batch[:, di]  # Teacher forcing, shape [batch_size]
        final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
            y_t_1, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask,
            c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di)
        target = target_batch[:, di]
        gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
        step_loss = -torch.log(gold_probs + config.eps)
        if config.is_coverage:
            step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
            step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
            coverage = next_coverage
        step_mask = dec_padding_mask[:, di]
        step_loss = step_loss * step_mask
        step_losses.append(step_loss)

        # ---- sampled rollout (multinomial) ----
        if di == 0:  # use decoder input[0], which is <BOS>
            sample_t_1 = dec_batch[:, di]
            s_t_sample = s_t_1_origin
            c_t_sample = Variable(torch.zeros((batch_size, 2 * config.hidden_dim)))
        final_dist, s_t_sample, c_t_sample, attn_dist, p_gen, next_coverage = self.model.decoder(
            sample_t_1, s_t_sample, encoder_outputs, encoder_feature,
            enc_padding_mask, c_t_sample, extra_zeros, enc_batch_extend_vocab,
            coverage, di)
        # Sample the next input token from the predicted distribution.
        dist = torch.distributions.Categorical(final_dist)
        sample_t_1 = Variable(dist.sample())
        sample_idx.append(sample_t_1)  # tensor list
        # Accumulate per-example log-probs; flatten (B, 1) -> (B,).
        sample_log_probs += torch.log(
            final_dist.gather(1, sample_t_1.view(-1, 1))).view(-1)

        # ---- greedy baseline rollout ----
        if di == 0:  # use decoder input[0], which is <BOS>
            baseline_t_1 = dec_batch[:, di]
            s_t_baseline = s_t_1_origin
            c_t_baseline = Variable(torch.zeros((batch_size, 2 * config.hidden_dim)))
        final_dist, s_t_baseline, c_t_baseline, attn_dist, p_gen, next_coverage = self.model.decoder(
            baseline_t_1, s_t_baseline, encoder_outputs, encoder_feature,
            enc_padding_mask, c_t_baseline, extra_zeros,
            enc_batch_extend_vocab, coverage, di)
        # Greedy next token: argmax indices along the vocab axis.
        baseline_t_1 = final_dist.max(1)[1]
        baseline_idx.append(baseline_t_1)

    sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
    batch_avg_loss = sum_losses / dec_lens_var
    loss = torch.mean(batch_avg_loss)

    # RL loss: reward difference (baseline - sample) weighted by the
    # sampled sequence log-probability (self-critical policy gradient).
    sample_idx = torch.stack(sample_idx, dim=1).squeeze()  # (batch_size, seq_len)
    baseline_idx = torch.stack(baseline_idx, dim=1).squeeze()
    rl_loss = torch.zeros(batch_size)
    for i in range(sample_idx.shape[0]):  # each example in a batch
        sample_y = data.outputids2words(sample_idx[i], self.vocab,
                                        (batch.art_oovs[i] if config.pointer_gen else None))
        baseline_y = data.outputids2words(baseline_idx[i], self.vocab,
                                          (batch.art_oovs[i] if config.pointer_gen else None))
        true_y = batch.original_abstracts[i]
        sample_score = rouge_l_f(sample_y, true_y)
        baseline_score = rouge_l_f(baseline_y, true_y)
        rl_loss[i] = baseline_score - sample_score
    rl_loss = rl_loss * sample_log_probs

    gamma = 0.9984
    # Reduce the RL term to a scalar so backward() has a scalar graph root.
    loss = (1 - gamma) * loss + gamma * torch.mean(rl_loss)

    loss.backward()

    self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                config.max_grad_norm)
    clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
    clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)

    self.optimizer.step()

    return loss.item()
def evaluate_batch(self, print_sents=False):
    """Decode up to 100 validation batches and report ROUGE.

    Runs the encoder + beam search over batches from ``self.batcher``,
    detokenizes predictions, then prints either the full score dict
    (task == 'test') or just ROUGE-L F1.

    :param print_sents: if True, also dump originals/predictions via
        ``self.print_original_predicted``.
    """
    self.setup_valid()
    start_id = self.vocab.word2id(data.START_DECODING)
    end_id = self.vocab.word2id(data.STOP_DECODING)
    unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
    decoded_sents, ref_sents, article_sents = [], [], []
    rouge = Rouge()
    batch_number = 0
    batch = self.batcher.next_batch()
    while batch is not None:
        (enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab,
         extra_zeros, ct_e) = get_enc_data(batch)
        with torch.no_grad():
            enc_batch = self.model.embeds(enc_batch)
            enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)
        with torch.no_grad():
            pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask,
                                   ct_e, extra_zeros, enc_batch_extend_vocab,
                                   self.model, start_id, end_id, unk_id)
        for idx, ids in enumerate(pred_ids):
            # Map ids (including per-article OOVs) back to a word list.
            words = data.outputids2words(ids, self.vocab, batch.art_oovs[idx])
            # Outputs shorter than 2 tokens break the ROUGE scorer, so
            # substitute a placeholder string.
            hypothesis = 'xxx' if len(words) < 2 else ' '.join(words)
            decoded_sents.append(hypothesis)
            ref_sents.append(batch.original_summarys[idx])
            article_sents.append(batch.original_articles[idx])
        batch = self.batcher.next_batch()
        batch_number += 1
        if batch_number >= 100:  # evaluate at most 100 batches
            break
    load_file = self.opt.load_model
    if print_sents:
        self.print_original_predicted(decoded_sents, ref_sents,
                                      article_sents, load_file)
    scores = rouge.get_scores(decoded_sents, ref_sents, avg=True)
    if self.opt.task == 'test':
        print(load_file, 'scores:', scores)
        sys.stdout.flush()
    else:
        rouge_l = scores['rouge-l']['f']
        print(load_file, 'rouge-l:', '%.4f' % rouge_l)
def decode(self):
    """Run beam-search decoding over the whole (single-pass) dataset,
    then compute and log averaged ROUGE-1/2/L F/R/P via the `rouge` package.

    Side effects: writes originals/predictions through
    ``self.print_original_predicted`` and a summary line to
    ``<decode_dir>/ROUGE_results.txt``.
    """
    start = time.time()
    counter = 0
    batch = self.batcher.next_batch()
    decoded_result = []   # one joined decoded summary per example
    refered_result = []   # one joined reference summary per example
    article_result = []   # source article per example
    while batch is not None:
        # Run beam search to get best Hypothesis
        best_summary = self.beam_search(batch)
        # Extract the output ids from the hypothesis (dropping the leading
        # START token) and convert back to words.
        output_ids = [int(t) for t in best_summary.tokens[1:]]
        decoded_words = data.outputids2words(
            output_ids, self.vocab,
            (batch.art_oovs[0] if config.pointer_gen else None))
        # Remove the [STOP] token from decoded_words, if present.
        try:
            fst_stop_idx = decoded_words.index(data.STOP_DECODING)
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            decoded_words = decoded_words
        # Beam search operates on one example per batch, hence index [0].
        original_abstract_sents = batch.original_abstracts_sents[0]
        article = batch.original_articles[0]
        #write_for_rouge(original_abstract_sents, decoded_words, counter,
        #                self._rouge_ref_dir, self._rouge_dec_dir)
        # Split the flat word list into sentences at each "." token.
        decoded_sents = []
        while len(decoded_words) > 0:
            try:
                fst_period_idx = decoded_words.index(".")
            except ValueError:
                # No period left: treat the remainder as one sentence.
                fst_period_idx = len(decoded_words)
            sent = decoded_words[:fst_period_idx + 1]
            decoded_words = decoded_words[fst_period_idx + 1:]
            decoded_sents.append(' '.join(sent))
        # pyrouge calls a perl script that puts the data into HTML files.
        # Therefore we need to make our output HTML safe.
        decoded_sents = [make_html_safe(w) for w in decoded_sents]
        reference_sents = [
            make_html_safe(w) for w in original_abstract_sents
        ]
        decoded_result.append(' '.join(decoded_sents))
        refered_result.append(' '.join(reference_sents))
        article_result.append(article)
        counter += 1
        if counter % 1000 == 0:
            # Periodic throughput report; timer restarts each report.
            print('%d example in %d sec' % (counter, time.time() - start))
            start = time.time()
        batch = self.batcher.next_batch()
    print("Decoder has finished reading dataset for single_pass.")
    print("Now starting ROUGE eval...")
    load_file = self.model_path_name
    self.print_original_predicted(decoded_result, refered_result,
                                  article_result, load_file)
    rouge = Rouge()
    # Per-example scores, then macro-averaged by hand below.
    scores = rouge.get_scores(decoded_result, refered_result)
    rouge_1 = sum([x["rouge-1"]["f"] for x in scores]) / len(scores)
    rouge_2 = sum([x["rouge-2"]["f"] for x in scores]) / len(scores)
    rouge_l = sum([x["rouge-l"]["f"] for x in scores]) / len(scores)
    rouge_1_r = sum([x["rouge-1"]["r"] for x in scores]) / len(scores)
    rouge_2_r = sum([x["rouge-2"]["r"] for x in scores]) / len(scores)
    rouge_l_r = sum([x["rouge-l"]["r"] for x in scores]) / len(scores)
    rouge_1_p = sum([x["rouge-1"]["p"] for x in scores]) / len(scores)
    rouge_2_p = sum([x["rouge-2"]["p"] for x in scores]) / len(scores)
    rouge_l_p = sum([x["rouge-l"]["p"] for x in scores]) / len(scores)
    log_str = " rouge_1:" + "%.4f" % rouge_1 + " rouge_2:" + "%.4f" % rouge_2 + " rouge_l:" + "%.4f" % rouge_l
    log_str_r = " rouge_1_r:" + "%.4f" % rouge_1_r + " rouge_2_r:" + "%.4f" % rouge_2_r + " rouge_l_r:" + "%.4f" % rouge_l_r
    logger.info(load_file + " rouge_1:" + "%.4f" % rouge_1 + " rouge_2:" +
                "%.4f" % rouge_2 + " rouge_l:" + "%.4f" % rouge_l)
    log_str_p = " rouge_1_p:" + "%.4f" % rouge_1_p + " rouge_2_p:" + "%.4f" % rouge_2_p + " rouge_l_p:" + "%.4f" % rouge_l_p
    results_file = os.path.join(self._decode_dir, "ROUGE_results.txt")
    with open(results_file, "w") as f:
        f.write(log_str + '\n')
        f.write(log_str_r + '\n')
        f.write(log_str_p + '\n')
def trainRLStep(self, enc_output, enc_hidden, enc_padding_mask, ct_e,
                extra_zeros, enc_batch_extend_vocab, article_oovs, type):
    """Roll out one decoded sequence per example, sampled or greedy.

    :param enc_output: encoder outputs for all timesteps
    :param enc_hidden: final encoder hidden state (initial decoder state)
    :param enc_padding_mask: encoder padding mask
    :param ct_e: initial encoder context vector
    :param extra_zeros: vocab-extension zeros for the pointer mechanism
    :param enc_batch_extend_vocab: encoder input with OOV ids
    :param article_oovs: per-example OOV word lists
    :param type: "sample" for multinomial sampling (also returns
        length-normalized log-probs), anything else for greedy argmax
    :returns: (decoded_strs, log_probs) — decoded strings per example, and
        per-sentence mean log-probabilities (an empty list when greedy)

    NOTE(review): parameter name ``type`` shadows the builtin; kept for
    caller compatibility.
    """
    s_t = enc_hidden
    # Start every example with the START token id.
    x_t = torch.LongTensor(len(enc_output)).fill_(self.start_id).to(DEVICE)
    prev_s = None
    sum_exp = None
    inds = []                   # sampled/argmax ids per timestep
    decoder_padding_mask = []   # 1 until [STOP] is produced, 0 after
    log_probs = []              # log-prob of each sampled token
    # 1 => [STOP] not yet encountered for that example, 0 otherwise.
    mask = torch.LongTensor(len(enc_output)).fill_(1).to(DEVICE)
    for t in range(MAX_DEC_STEPS):
        x_t = self.model.embeds(x_t)
        probs, s_t, ct_e, sum_exp, prev_s = self.model.decoder(
            x_t, s_t, enc_output, enc_padding_mask, ct_e, extra_zeros,
            enc_batch_extend_vocab, sum_exp, prev_s)
        if type == "sample":
            # Multinomial sampling from the output distribution.
            multi_dist = Categorical(probs)
            # print(multi_dist)
            x_t = multi_dist.sample()
            # print(x_t.shape)
            log_prob = multi_dist.log_prob(x_t)
            log_probs.append(log_prob)
        else:
            # Greedy sampling (argmax).
            _, x_t = torch.max(probs, dim=1)
            x_t = x_t.detach()
        inds.append(x_t)
        mask_t = torch.zeros(len(enc_output)).to(DEVICE)
        # mask_t = 1 where [STOP] was not yet seen before this step.
        mask_t[mask == 1] = 1
        # Clear mask where [STOP] was just produced (both conditions true
        # <=> their byte sum equals 2).
        mask[(mask == 1) + (x_t == self.end_id) == 2] = 0
        decoder_padding_mask.append(mask_t)
        # Detect ids beyond the fixed vocabulary (pointer OOVs) and feed
        # UNK back in as the next input, since embeds has no OOV rows.
        is_oov = (x_t >= VOCAB_SIZE).long()
        x_t = (1 - is_oov) * x_t + (is_oov) * self.unk_id
    inds = torch.stack(inds, dim=1)
    decoder_padding_mask = torch.stack(decoder_padding_mask, dim=1)
    if type == "sample":
        log_probs = torch.stack(log_probs, dim=1)
        # Zero out log-probs at padded (post-[STOP]) positions.
        log_probs = log_probs * decoder_padding_mask
        lens = torch.sum(decoder_padding_mask, dim=1)
        # Length-normalized sentence log-probability (eq. 15 log p in the
        # referenced paper — presumably Paulus et al. 2017; confirm).
        log_probs = torch.sum(log_probs, dim=1) / lens
        # print(log_prob.shape)
    decoded_strs = []
    # Convert output ids back to words, one string per example.
    for i in range(len(enc_output)):
        id_list = inds[i].cpu().numpy()
        oovs = article_oovs[i]
        S = data.outputids2words(id_list, self.vocab,
                                 oovs)  # Generate sentence corresponding to sampled words
        try:
            end_idx = S.index(data.STOP_DECODING)
            S = S[:end_idx]
        except ValueError:
            S = S
        # Sentences shorter than 2 words (e.g. ".") break ROUGE; replace.
        if len(S) < 2:
            S = ["xxx"]
        S = " ".join(S)
        decoded_strs.append(S)
    return decoded_strs, log_probs
def train_batch_RL(self, enc_out, enc_hidden, enc_padding_mask, ct_e,
                   extra_zeros, enc_batch_extend_vocab, article_oovs,
                   greedy):
    '''Generate sentences from decoder entirely using sampled tokens as input. These sentences are used for ROUGE evaluation
    Args
    :param enc_out: Outputs of the encoder for all time steps (batch_size, length_input_sequence, 2*hidden_size)
    :param enc_hidden: Tuple containing final hidden state & cell state of encoder. Shape of h & c: (batch_size, hidden_size)
    :param enc_padding_mask: Mask for encoder input; Tensor of size (batch_size, length_input_sequence) with values of 0 for pad tokens & 1 for others
    :param ct_e: encoder context vector for time_step=0 (eq 5 in https://arxiv.org/pdf/1705.04304.pdf)
    :param extra_zeros: Tensor used to extend vocab distribution for pointer mechanism
    :param enc_batch_extend_vocab: Input batch that stores OOV ids
    :param article_oovs: Batch containing list of OOVs in each example
    :param greedy: If true, performs greedy based sampling, else performs multinomial sampling
    Returns:
    :decoded_strs: List of decoded sentences
    :log_probs: Log probabilities of sampled words (empty list when greedy)
    '''
    s_t = enc_hidden  #Decoder hidden states
    x_t = get_cuda(T.LongTensor(len(enc_out)).fill_(
        self.start_id))  #Input to the decoder
    prev_s = None  #Used for intra-decoder attention (section 2.2 in https://arxiv.org/pdf/1705.04304.pdf)
    sum_temporal_srcs = None  #Used for intra-temporal attention (section 2.1 in https://arxiv.org/pdf/1705.04304.pdf)
    inds = []  #Stores sampled indices for each time step
    decoder_padding_mask = []  #Stores padding masks of generated samples
    log_probs = []  #Stores log probabilites of generated samples
    mask = get_cuda(
        T.LongTensor(len(enc_out)).fill_(1)
    )  #Values that indicate whether [STOP] token has already been encountered; 1 => Not encountered, 0 otherwise
    for t in range(config.max_dec_steps):
        x_t = self.model.embeds(x_t)
        probs, s_t, ct_e, sum_temporal_srcs, prev_s = self.model.decoder(
            x_t, s_t, enc_out, enc_padding_mask,
            ct_e, extra_zeros, enc_batch_extend_vocab,
            sum_temporal_srcs, prev_s)
        if greedy is False:
            multi_dist = Categorical(probs)
            x_t = multi_dist.sample()  #perform multinomial sampling
            log_prob = multi_dist.log_prob(x_t)
            log_probs.append(log_prob)
        else:
            _, x_t = T.max(probs, dim=1)  #perform greedy sampling
            # Detach so greedy rollouts do not backprop into the decoder.
            x_t = x_t.detach()
        inds.append(x_t)
        mask_t = get_cuda(T.zeros(
            len(enc_out)))  #Padding mask of batch for current time step
        mask_t[
            mask ==
            1] = 1  #If [STOP] is not encountered till previous time step, mask_t = 1 else mask_t = 0
        mask[
            (mask == 1) +
            (x_t == self.end_id) ==
            2] = 0  #If [STOP] is not encountered till previous time step and current word is [STOP], make mask = 0
        decoder_padding_mask.append(mask_t)
        is_oov = (x_t >= config.vocab_size
                  ).long()  #Mask indicating whether sampled word is OOV
        x_t = (1 - is_oov) * x_t + (
            is_oov) * self.unk_id  #Replace OOVs with [UNK] token
    inds = T.stack(inds, dim=1)
    decoder_padding_mask = T.stack(decoder_padding_mask, dim=1)
    if greedy is False:  #If multinomial based sampling, compute log probabilites of sampled words
        log_probs = T.stack(log_probs, dim=1)
        log_probs = log_probs * decoder_padding_mask  #Not considering sampled words with padding mask = 0
        lens = T.sum(decoder_padding_mask,
                     dim=1)  #Length of sampled sentence
        log_probs = T.sum(
            log_probs, dim=1
        ) / lens  # (bs,) #compute normalizied log probability of a sentence
    decoded_strs = []
    for i in range(len(enc_out)):
        id_list = inds[i].cpu().numpy()
        oovs = article_oovs[i]
        S = data.outputids2words(
            id_list, self.vocab,
            oovs)  # Generate sentence corresponding to sampled words
        try:
            end_idx = S.index(data.STOP_DECODING)
            S = S[:end_idx]
        except ValueError:
            S = S
        if len(
                S
        ) < 2:  #If length of sentence is less than 2 words, replace it with "xxx"; Avoids setences like "." which throws error while calculating ROUGE
            S = ["xxx"]
        S = " ".join(S)
        decoded_strs.append(S)
    return decoded_strs, log_probs
def forward(self, y_t_1, s_t_1, encoder_outputs, encoder_feature,
            enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
            coverage, step, vocab, batchartoovs):
    """One decoder step of a pointer-generator LSTM with ELMo input embeddings.

    :param y_t_1: previous-step token ids (one per batch element)
    :param s_t_1: previous decoder LSTM state tuple (h, c)
    :param encoder_outputs / encoder_feature: encoder states and their
        projected features for attention
    :param enc_padding_mask: encoder padding mask for attention
    :param c_t_1: previous attention context vector
    :param extra_zeros: vocab-extension zeros for pointer mechanism (or None)
    :param enc_batch_extend_vocab: encoder input with per-article OOV ids
    :param coverage: running coverage vector (used when config.is_coverage)
    :param step: decode timestep index
    :param vocab: vocabulary used to map ids back to words for ELMo
    :param batchartoovs: article OOV words (pointer-gen only)
    :returns: (final_dist, s_t, c_t, attn_dist, p_gen, coverage)

    Fix: ``F.sigmoid`` is deprecated (removed in modern PyTorch);
    replaced with ``torch.sigmoid`` — numerically identical.
    """
    if not self.training and step == 0:
        # At decode time, warm up the attention/coverage state once from
        # the initial decoder state before consuming the first token.
        h_decoder, c_decoder = s_t_1
        s_t_hat = torch.cat((h_decoder.view(-1, config.hidden_dim),
                             c_decoder.view(-1, config.hidden_dim)),
                            1)  # B x 2*hidden_dim
        c_t, _, coverage_next = self.attention_network(
            s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask,
            coverage)
        coverage = coverage_next
    # Map the previous token ids back to words so ELMo (a character-level
    # embedder) can embed them; OOV words are resolved via batchartoovs.
    strings = []
    for examp in y_t_1:
        copy = [examp.clone().cpu().item()]
        converted = data.outputids2words(
            copy, vocab, (batchartoovs if config.pointer_gen else None))
        strings.append(converted)
    strings = batch_to_ids(strings).cuda()
    y_t_1_embd = self.elmo_layer(
        strings)['elmo_representations'][0].squeeze_()  # only last layer
    #y_t_1_embd = self.embedding(y_t_1)
    x = self.x_context(torch.cat((c_t_1, y_t_1_embd), 1))
    lstm_out, s_t = self.lstm(x.unsqueeze(1), s_t_1)
    h_decoder, c_decoder = s_t
    s_t_hat = torch.cat((h_decoder.view(-1, config.hidden_dim),
                         c_decoder.view(-1, config.hidden_dim)),
                        1)  # B x 2*hidden_dim
    c_t, attn_dist, coverage_next = self.attention_network(
        s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask,
        coverage)
    if self.training or step > 0:
        coverage = coverage_next
    p_gen = None
    if config.pointer_gen:
        # Generation probability from context, state and input (See et al.).
        p_gen_input = torch.cat((c_t, s_t_hat, x),
                                1)  # B x (2*2*hidden_dim + emb_dim)
        p_gen = self.p_gen_linear(p_gen_input)
        p_gen = torch.sigmoid(p_gen)  # F.sigmoid is deprecated
    output = torch.cat((lstm_out.view(-1, config.hidden_dim), c_t),
                       1)  # B x hidden_dim * 3
    output = self.out1(output)  # B x hidden_dim
    #output = F.relu(output)
    output = self.out2(output)  # B x vocab_size
    vocab_dist = F.softmax(output, dim=1)
    if config.pointer_gen:
        # Blend the generator distribution with the copy (attention)
        # distribution over the extended vocabulary.
        vocab_dist_ = p_gen * vocab_dist
        attn_dist_ = (1 - p_gen) * attn_dist
        if extra_zeros is not None:
            vocab_dist_ = torch.cat([vocab_dist_, extra_zeros], 1)
        final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab,
                                             attn_dist_)
    else:
        final_dist = vocab_dist
    return final_dist, s_t, c_t, attn_dist, p_gen, coverage
def beam_search(self, batch):
    """Beam-search decode a single-example batch; return the best hypothesis.

    Encoder inputs are re-embedded through fastText (``ft_model``) word
    vectors before encoding, and each step's input token is likewise
    converted to a word vector.

    Fix: replaced Python-2-only ``xrange`` with ``range`` — this file is
    Python 3 (``print(...)`` calls, ``.item()``), so ``xrange`` raises
    ``NameError`` at runtime.
    """
    #batch should have only one example
    article_oovs, enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
        get_input_from_batch(batch, use_cuda)
    # Map input ids (incl. article OOVs) back to words for fastText lookup.
    enc_batch = [
        outputids2words(ids, self.vocab, article_oovs[i])
        for i, ids in enumerate(enc_batch.numpy())
    ]
    enc_batch_list = []
    for words in enc_batch:
        temp_list = []
        for w in words:
            l = ft_model.get_numpy_vector(w)
            temp_list.append(l)
        enc_batch_list.append(temp_list)
    enc_batch_list = torch.Tensor(enc_batch_list)
    encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
        enc_batch_list, enc_lens)
    s_t_0 = self.model.reduce_state(encoder_hidden)
    dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
    dec_h = dec_h.squeeze()
    dec_c = dec_c.squeeze()
    #decoder batch preparation, it has beam_size example initially everything is repeated
    beams = [
        Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
             log_probs=[0.0],
             state=(dec_h[0], dec_c[0]),
             context=c_t_0[0],
             coverage=(coverage_t_0[0] if config.is_coverage else None))
        for _ in range(config.beam_size)
    ]
    results = []
    steps = 0
    while steps < config.max_dec_steps and len(results) < config.beam_size:
        latest_tokens = [h.latest_token for h in beams]
        # Replace extended-vocab (OOV) ids with UNK for the decoder input.
        latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                         for t in latest_tokens]
        y_t_1 = Variable(torch.LongTensor(latest_tokens))
        if use_cuda:
            y_t_1 = y_t_1.cuda()
        all_state_h = []
        all_state_c = []
        all_context = []
        for h in beams:
            state_h, state_c = h.state
            all_state_h.append(state_h)
            all_state_c.append(state_c)
            all_context.append(h.context)
        s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                 torch.stack(all_state_c, 0).unsqueeze(0))
        c_t_1 = torch.stack(all_context, 0)
        coverage_t_1 = None
        if config.is_coverage:
            all_coverage = []
            for h in beams:
                all_coverage.append(h.coverage)
            coverage_t_1 = torch.stack(all_coverage, 0)
        # Convert token ids to words, then to fastText vectors for input.
        y_t_1 = [
            myid2word(id, self.vocab, article_oovs[i])
            for i, id in enumerate(y_t_1.numpy())
        ]
        y_t_1 = torch.Tensor([ft_model.get_numpy_vector(w) for w in y_t_1])
        final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
            y_t_1, s_t_1, encoder_outputs, encoder_feature,
            enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
            coverage_t_1, steps)
        # NOTE(review): topk is taken on final_dist (probabilities) but the
        # values are passed to Beam.extend as log_prob — verify whether a
        # torch.log(final_dist) is intended here, as in the original
        # pointer-generator implementation.
        topk_log_probs, topk_ids = torch.topk(final_dist,
                                              config.beam_size * 2)
        dec_h, dec_c = s_t
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()
        all_beams = []
        # On the first step all beams are identical, so expand only one.
        num_orig_beams = 1 if steps == 0 else len(beams)
        for i in range(num_orig_beams):
            h = beams[i]
            state_i = (dec_h[i], dec_c[i])
            context_i = c_t[i]
            coverage_i = (coverage_t[i] if config.is_coverage else None)
            for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                new_beam = h.extend(token=topk_ids[i, j].item(),
                                    log_prob=topk_log_probs[i, j].item(),
                                    state=state_i,
                                    context=context_i,
                                    coverage=coverage_i)
                all_beams.append(new_beam)
        beams = []
        for h in self.sort_beams(all_beams):
            if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                # Completed hypotheses count only after min_dec_steps.
                if steps >= config.min_dec_steps:
                    results.append(h)
            else:
                beams.append(h)
            if len(beams) == config.beam_size or len(
                    results) == config.beam_size:
                break
        steps += 1
    # If nothing finished, fall back to the live beams.
    if len(results) == 0:
        results = beams
    beams_sorted = self.sort_beams(results)
    return beams_sorted[0]
def test(self):
    """Decode the test set with beam search and print running-average
    ROUGE statistics accumulated over (up to) the first 10 batches.
    """
    # time.sleep(5)
    batcher = Batcher(TEST_DATA_PATH, self.vocab, mode='test',
                      batch_size=BATCH_SIZE, single_pass=True)
    batch = batcher.next_batch()
    decoded_sents = []
    ref_sents = []
    article_sents = []
    rouge = Rouge()
    count = 0
    while batch is not None:
        enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = self.getEncData(batch)
        with torch.autograd.no_grad():
            enc_batch = self.model.embeds(enc_batch)
            enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)
        with torch.autograd.no_grad():
            pred_ids = self.beamSearch(enc_hidden, enc_out,
                                       enc_padding_mask, ct_e, extra_zeros,
                                       enc_batch_extend_vocab)
        # print(len(pred_ids[0]))
        for i in range(len(pred_ids)):
            # print('t',pred_ids[i])
            decoded_words = data.outputids2words(pred_ids[i], self.vocab,
                                                 batch.art_oovs[i])
            # print(decoded_words)
            # Outputs shorter than 2 words break ROUGE; use a placeholder.
            if len(decoded_words) < 2:
                decoded_words = "xxx"
            else:
                decoded_words = " ".join(decoded_words)
            decoded_sents.append(decoded_words)
            abstract = batch.original_abstracts[i]
            article = batch.original_articles[i]
            ref_sents.append(abstract)
            article_sents.append(article)
        # print(decoded_sents)
        batch = batcher.next_batch()
        # Note: lists are cumulative, so each get_scores call re-averages
        # over everything decoded so far, not just the current batch.
        scores = rouge.get_scores(decoded_sents, ref_sents, avg=True)
        # Accumulate statistics (dict values summed via Counter addition).
        # NOTE(review): the first iteration (count == 0) sets nothing, so
        # k0_sum/k1_sum/k2_sum are undefined — and the prints below raise
        # NameError — if the dataset yields fewer than 2 batches; confirm
        # intended behavior.
        if count == 1:
            k0_sum = scores[KEYS[0]]
            k1_sum = scores[KEYS[1]]
            k2_sum = scores[KEYS[2]]
        if count > 1:
            k0_sum = dict(Counter(Counter(k0_sum) + Counter(scores[KEYS[0]])))
            k1_sum = dict(Counter(Counter(k1_sum) + Counter(scores[KEYS[1]])))
            k2_sum = dict(Counter(Counter(k2_sum) + Counter(scores[KEYS[2]])))
        if count == 10:
            break
        count += 1
    # print(scores)
    print(KEYS[0], end=' ')
    for k in k0_sum:
        print(k, k0_sum[k] / count, end=' ')
    print('\n')
    print(KEYS[1], end=' ')
    for k in k1_sum:
        print(k, k1_sum[k] / count, end=' ')
    print('\n')
    print(KEYS[2], end=' ')
    for k in k2_sum:
        print(k, k2_sum[k] / count, end=' ')
    print('\n')
def train_one_batch_pg(self, batch):
    """One policy-gradient training step on a batch.

    Decodes free-running (feeding back argmax predictions, not teacher
    forcing), collects per-step NLL of the gold targets, then weights the
    per-sentence losses via ``compute_batched_sentence_loss`` using the
    predicted vs. original abstracts, backprops and steps the optimizer.

    :param batch: project Batch object with encoder/decoder tensors
    :returns: scalar loss value (float)
    """
    batch_size = batch.batch_size
    enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
        get_input_from_batch(batch, config.use_gpu)
    dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
        get_output_from_batch(batch, config.use_gpu)
    self.optimizer.zero_grad()
    encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
        enc_batch, enc_lens)
    s_t_1 = self.model.reduce_state(encoder_hidden)
    step_losses = []   # per-example lists of per-step masked NLL
    output_ids = []    # per-example lists of predicted token ids
    # Begin with START symbol
    y_t_1 = torch.ones(batch_size, dtype=torch.long) * self.vocab.word2id(
        data.START_DECODING)
    if config.use_gpu:
        y_t_1 = y_t_1.cuda()
    for _ in range(batch_size):
        output_ids.append([])
        step_losses.append([])
    for di in range(min(max_dec_len, config.max_dec_steps)):
        #y_t_1 = dec_batch[:, di]  # Teacher forcing
        final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
            y_t_1, s_t_1, encoder_outputs, encoder_feature,
            enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
            coverage, di)
        target = target_batch[:, di]
        # Probability assigned to the gold token at this step.
        gold_probs = torch.gather(final_dist, 1,
                                  target.unsqueeze(1)).squeeze()
        step_loss = -torch.log(gold_probs + config.eps)  # NLL
        step_mask = dec_padding_mask[:, di]
        step_loss = step_loss * step_mask
        # Move on to next token: feed back the argmax prediction.
        _, idx = torch.max(final_dist, 1)
        idx = idx.reshape(batch_size, -1).squeeze()
        y_t_1 = idx
        for i, pred in enumerate(y_t_1):
            # NOTE(review): pred.item() is an int id while data.PAD_TOKEN
            # is presumably the pad *string* (e.g. '[PAD]'), so this test
            # may always be True and pad ids would never be filtered —
            # verify against the data module.
            if not pred.item() == data.PAD_TOKEN:
                output_ids[i].append(pred.item())
        for i, loss in enumerate(step_loss):
            step_losses[i].append(step_loss[i])
    # Obtain the original and predicted summaries
    original_abstracts = batch.original_abstracts_sents
    predicted_abstracts = [
        data.outputids2words(ids, self.vocab, None) for ids in output_ids
    ]
    # Compute the batched loss
    batched_losses = self.compute_batched_sentence_loss(
        step_losses, original_abstracts, predicted_abstracts)
    #batched_losses = Variable(batched_losses, requires_grad=True)
    losses = torch.stack(batched_losses)
    # Normalize by true (unpadded) decoder lengths.
    losses = losses / dec_lens_var
    loss = torch.mean(losses)
    loss.backward()
    self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                config.max_grad_norm)
    clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
    clip_grad_norm_(self.model.reduce_state.parameters(),
                    config.max_grad_norm)
    self.optimizer.step()
    return loss.item()