def test(model, path, vocab):
    model.load_state_dict(torch.load(path))
    model.eval()
    beam_search = LSTMBeamSearch(config.beam_size, config.vocab_size,
                                 config.max_decode_len, model)
    batcher = Batcher(config.decode_data_path, vocab, mode='decode',
                      batch_size=1, single_pass=True)
    counter = 0
    batch = batcher.next_batch()
    while batch is not None:
        input_ids, input_mask, input_lens, extended_input_ids, extra_zeros = \
            prepare_src_batch(batch)
        best_summary = beam_search.generate(input_ids, extended_input_ids,
                                            extra_zeros)
        # Drop the leading [START] token before mapping ids back to words.
        output_ids = [int(t) for t in best_summary.tokens[1:]]
        decoded_words = outputids2words(output_ids, vocab, batch.art_oovs[0])
        # Truncate at the first [STOP] token, if one was produced.
        try:
            fst_stop_idx = decoded_words.index(STOP_DECODING)
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            pass
        write_for_rouge(batch.original_abstracts_sents[0], decoded_words,
                        counter, config.rouge_ref_dir, config.rouge_dec_dir)
        batch = batcher.next_batch()
        counter += 1
    results_dict = rouge_eval(config.rouge_ref_dir, config.rouge_dec_dir)
    rouge_log(results_dict, config.decode_dir)
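
# A hedged usage sketch for `test` (the checkpoint path and the constructor
# call are illustrative; `model` is whatever seq2seq module exposes the
# encoder/decode/generator interface that LSTMBeamSearch drives):
#
# vocab = VocabBert(config.vocab_path, config.vocab_size)
# model = build_model(config, vocab)  # hypothetical constructor
# test(model, path=os.path.join(config.log_root, 'best_model.pt'), vocab=vocab)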

class BeamSearchDecoder:
    def __init__(self, model):
        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%s' % ("model_bert_coverage"))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)
        self.vocab = VocabBert(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path, self.vocab,
                               mode='decode', batch_size=config.beam_size,
                               single_pass=True)
        self.model = model

    def beam_search(self, batch, conf):
        # batch should have only one example, repeated beam_size times
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros = \
            helper.prepare_src_batch(batch)
        encoder_output, _ = self.model.encoder.forward(
            enc_batch, enc_padding_mask.squeeze(1))
        hyps_list = [
            Hypothesis(tokens=[self.vocab.word2id(data.START_DECODING)],
                       log_probs=[0.0]) for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        yt = torch.zeros(config.beam_size, 1).long().to(device)
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in hyps_list]
            # Map extended-vocabulary (OOV) ids back to [UNK] before feeding
            # the decoder.
            latest_tokens = [
                t if t < self.vocab.size() else
                self.vocab.word2id(data.UNKNOWN_TOKEN) for t in latest_tokens
            ]
            curr_yt = torch.LongTensor(latest_tokens).unsqueeze(1).to(device)  # [B x 1]
            yt = torch.cat((yt, curr_yt), dim=1)
            out, _ = self.model.decode(
                encoder_output, yt[:, 1:], enc_padding_mask,
                helper.subsequent_mask(yt[:, 1:].size(-1)))
            extra_zeros_ip = None
            if extra_zeros is not None:
                extra_zeros_ip = extra_zeros[:, 0:steps + 1, :]
            if config.coverage:
                op_dist, _ = self.model.generator(out, encoder_output,
                                                  enc_padding_mask,
                                                  enc_batch_extend_vocab,
                                                  extra_zeros_ip)
            else:
                op_dist = self.model.generator(out, encoder_output,
                                               enc_padding_mask,
                                               enc_batch_extend_vocab,
                                               extra_zeros_ip)
            log_probs = op_dist[:, -1, :]
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)
            all_hyps = []
            num_orig_hyps = 1 if steps == 0 else len(hyps_list)
            for i in range(num_orig_hyps):
                h = hyps_list[i]
                for j in range(config.beam_size * 2):
                    # extend each hypothesis with the top 2*beam_size tokens
                    hyp = h.extend(token=topk_ids[i, j].item(),
                                   log_prob=topk_log_probs[i, j].item())
                    all_hyps.append(hyp)
            hyps_list = []
            sorted_hyps = sorted(all_hyps, key=lambda h: h.avg_log_prob,
                                 reverse=True)
            for h in sorted_hyps:
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    # Finished hypotheses shorter than min_dec_steps are discarded.
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    hyps_list.append(h)
                if len(hyps_list) == config.beam_size or \
                        len(results) == config.beam_size:
                    break
            steps += 1
        if len(results) == 0:
            results = hyps_list
        results_sorted = sorted(results, key=lambda h: h.avg_log_prob,
                                reverse=True)
        return results_sorted[0]

    def decode(self, conf):
        self.model.eval()
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        article_list = list()
        i = 0
        while batch is not None:
            i += 1
            # Run beam search to get the best Hypothesis
            best_summary = self.beam_search(batch, conf)
            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))
            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass
            if i % 100 == 0:
                print("Batch: {}".format(i))
                print(decoded_words)
            original_abstract_sents = batch.original_abstracts_sents[0]
            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()
            batch = self.batcher.next_batch()
        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
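
# `BeamSearchDecoder.beam_search` only relies on a small Hypothesis container
# with `extend`, `latest_token`, and `avg_log_prob`. This is a minimal sketch
# consistent with that usage (the project's actual Hypothesis class may differ
# in details such as the exact length normalization):
class HypothesisSketch(object):
    def __init__(self, tokens, log_probs):
        self.tokens = tokens        # token ids decoded so far, starting with [START]
        self.log_probs = log_probs  # per-token log-probabilities

    def extend(self, token, log_prob):
        # Beam expansion never mutates a hypothesis; it returns a new one.
        return HypothesisSketch(tokens=self.tokens + [token],
                                log_probs=self.log_probs + [log_prob])

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        # Length-normalized score used when ranking beams.
        return sum(self.log_probs) / len(self.tokens)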

class Train(object):
    def __init__(self, args, model_name=None):
        self.args = args
        vocab = args.vocab_path if args.vocab_path is not None else config.vocab_path
        self.vocab = Vocab(vocab, config.vocab_size, config.embeddings_file, args)
        self.train_batcher = Batcher(args.train_data_path, self.vocab,
                                     mode='train', batch_size=args.batch_size,
                                     single_pass=False, args=args)
        self.eval_batcher = Batcher(args.eval_data_path, self.vocab,
                                    mode='eval', batch_size=args.batch_size,
                                    single_pass=True, args=args)
        time.sleep(30)
        if model_name is None:
            self.train_dir = os.path.join(config.log_root,
                                          'train_%d' % (int(time.time())))
        else:
            self.train_dir = os.path.join(config.log_root, model_name)
        if not os.path.exists(self.train_dir):
            os.mkdir(self.train_dir)
        self.model_dir = os.path.join(self.train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

    def save_model(self, running_avg_loss, iter, logger, best_val_loss):
        state = {
            'iter': iter,
            'best_val_loss': best_val_loss,
            'encoder_state_dict': self.model.module.encoder.state_dict(),
            'decoder_state_dict': self.model.module.decoder.state_dict(),
            'reduce_state_dict': self.model.module.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        print(model_save_path)
        logger.debug(model_save_path)
        torch.save(state, model_save_path)
        if self.args.clear_old_checkpoints:
            self.clear_model_dir(checkpoints=self.args.keep_ckpts, logger=logger)

    def clear_model_dir(self, checkpoints, logger):
        """
        Clears the model directory, keeping only the latest `checkpoints`
        checkpoint files.
        """
        files = os.listdir(self.model_dir)
        last_modification = [(os.path.getmtime(os.path.join(self.model_dir, f)), f)
                             for f in files]
        # Sort by last-modified time, oldest first.
        last_modification.sort(key=itemgetter(0))
        # Delete everything but the last `checkpoints` files.
        # (`mtime` deliberately avoids shadowing the `time` module.)
        ckpnt_no = 0
        for mtime, f in last_modification[:-checkpoints]:
            ckpnt_no += 1
            os.remove(os.path.join(self.model_dir, f))
        msg = "Deleted %d checkpoints" % (ckpnt_no)
        logger.debug(msg)
        print(msg)

    def setup_train(self, args):
        self.model = nn.DataParallel(Model(args, self.vocab)).to(device)
        params = list(self.model.module.encoder.parameters()) + \
            list(self.model.module.decoder.parameters()) + \
            list(self.model.module.reduce_state.parameters())
        initial_lr = args.lr_coverage if args.is_coverage else args.lr
        self.optimizer = AdagradCustom(
            params, lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)
        self.crossentropy = nn.CrossEntropyLoss(ignore_index=-1)
        self.head_child_crossent = nn.CrossEntropyLoss(
            ignore_index=-1, weight=torch.Tensor([0.1, 1]).cuda())
        self.attn_mse_loss = nn.MSELoss()
        start_iter, start_loss = 0, 0
        best_val_loss = None
        if args.reload_path is not None:
            print('Loading from checkpoint: ' + str(args.reload_path))
            state = torch.load(args.reload_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']
            if not args.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    # Move the reloaded optimizer tensors onto the GPU.
                    for opt_state in self.optimizer.state.values():
                        for k, v in opt_state.items():
                            if torch.is_tensor(v):
                                opt_state[k] = v.to(device)
        return start_iter, start_loss, best_val_loss

    def setup_logging(self):
        logger = logging.getLogger()
        logger.setLevel(logging.DEBUG)
        filename = os.path.join(self.train_dir, 'train.log')
        ah = logging.FileHandler(filename)
        ah.setLevel(logging.DEBUG)
        formatter = logging.Formatter('%(asctime)s - %(message)s')
        ah.setFormatter(formatter)
        logger.addHandler(ah)
        return logger

    def train_one_batch(self, batch, args):
        self.optimizer.zero_grad()
        self.model.module.encoder.document_structure_att.output = None
        loss, _, _, _ = self.get_loss(batch, args)
        if loss is None:
            return None
        loss.backward()
        # clip_grad_norm was renamed clip_grad_norm_ in newer PyTorch releases.
        clip_grad_norm(self.model.module.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm(self.model.module.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm(self.model.module.reduce_state.parameters(), config.max_grad_norm)
        self.optimizer.step()
        return loss.item()

    def train_iters(self, n_iters, args):
        start_iter, running_avg_loss, best_val_loss = self.setup_train(args)
        logger = self.setup_logging()
        logger.debug(str(args))
        logger.debug(str(config))
        start = time.time()
        for it in tqdm(range(n_iters), dynamic_ncols=True):
            iter = start_iter + it
            self.model.module.train()
            batch = self.train_batcher.next_batch()
            loss = self.train_one_batch(batch, args)
            # Check for None before isnan: train_one_batch may skip a batch.
            if loss is not None and math.isnan(loss):
                msg = "Loss has reached NaN. Exiting"
                logger.debug(msg)
                print(msg)
                exit()
            if loss is not None:
                running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, iter)
            iter += 1
            print_interval = 200
            if iter % print_interval == 0:
                msg = 'steps %d, seconds for %d batch: %.2f , loss: %f' % (
                    iter, print_interval, time.time() - start, loss)
                print(msg)
                logger.debug(msg)
                start = time.time()
            if iter % config.eval_interval == 0:
                print("Starting Eval")
                loss = self.run_eval(logger, args)
                if best_val_loss is None or loss < best_val_loss:
                    best_val_loss = loss
                    self.save_model(running_avg_loss, iter, logger, best_val_loss)
                    print("Saving best model")
                    logger.debug("Saving best model")

    def get_loss(self, batch, args, mode='train'):
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)
        enc_batch, enc_padding_token_mask, enc_padding_sent_mask, enc_doc_lens, enc_sent_lens, \
            enc_batch_extend_vocab, extra_zeros, c_t_1, coverage, word_batch, word_padding_mask, enc_word_lens, \
            enc_tags_batch, enc_sent_tags, enc_sent_token_mat, adj_mat, weighted_adj_mat, norm_adj_mat, \
            parent_heads, undir_weighted_adj_mat = get_input_from_batch(batch, use_cuda, args)
        final_dist_list, attn_dist_list, p_gen_list, coverage_list, sent_attention_matrix, \
            sent_single_head_scores, sent_all_head_scores, sent_all_child_scores, \
            token_score, sent_score, doc_score = self.model.forward(
                enc_batch, enc_padding_token_mask, enc_padding_sent_mask,
                enc_doc_lens, enc_sent_lens, enc_batch_extend_vocab, extra_zeros,
                c_t_1, coverage, word_batch, word_padding_mask, enc_word_lens,
                enc_tags_batch, enc_sent_token_mat, max_dec_len, dec_batch,
                adj_mat, weighted_adj_mat, undir_weighted_adj_mat, args)
        step_losses = []
        loss = 0
        ind_losses = {
            'summ_loss': 0,
            'sent_single_head_loss': 0,
            'sent_all_head_loss': 0,
            'sent_all_child_loss': 0,
            'token_contsel_loss': 0,
            'sent_imp_loss': 0,
            'doc_imp_loss': 0
        }
        counts = {
            'token_consel_num_correct': 0,
            'token_consel_num': 0,
            'sent_imp_num_correct': 0,
            'doc_imp_num_correct': 0,
            'sent_single_heads_num_correct': 0,
            'sent_single_heads_num': 0,
            'sent_all_heads_num_correct': 0,
            'sent_all_heads_num': 0,
            'sent_all_child_num_correct': 0,
            'sent_all_child_num': 0
        }
        eval_data = {}
        if args.use_summ_loss:
            # Negative log-likelihood of the gold summary under the predicted
            # distributions, with an optional coverage penalty.
            for di in range(min(max_dec_len, args.max_dec_steps)):
                final_dist = final_dist_list[:, di, :]
                attn_dist = attn_dist_list[:, di, :]
                if args.is_coverage:
                    coverage = coverage_list[:, di, :]
                target = target_batch[:, di]
                gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
                step_loss = -torch.log(gold_probs + config.eps)
                if args.is_coverage:
                    step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                    step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                step_mask = dec_padding_mask[:, di]
                step_loss = step_loss * step_mask
                step_losses.append(step_loss)
            sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
            batch_avg_loss = sum_losses / dec_lens_var
            loss += torch.mean(batch_avg_loss)
            ind_losses['summ_loss'] += torch.mean(batch_avg_loss).item()
        if args.heuristic_chains:
            if args.use_attmat_loss:
                pred = sent_attention_matrix[:, :, 1:].contiguous().view(-1)
                gold = norm_adj_mat.view(-1)
                loss_aux = self.attn_mse_loss(pred, gold)
                loss += 100 * loss_aux
            if args.use_sent_single_head_loss:
                pred = sent_single_head_scores
                pred = pred.view(-1, pred.size(2))
                head_labels = parent_heads.view(-1)
                loss_aux = self.crossentropy(pred, head_labels.long())
                loss += loss_aux
                prediction = torch.argmax(
                    pred.clone().detach().requires_grad_(False), dim=1)
                if mode == 'eval':
                    # Explicitly set masked positions so they can never match
                    # the gold value.
                    prediction[head_labels == -1] = -2
                    counts['sent_single_heads_num_correct'] = torch.sum(
                        prediction.eq(head_labels.long())).item()
                    counts['sent_single_heads_num'] = torch.sum(head_labels != -1).item()
                ind_losses['sent_single_head_loss'] += loss_aux.item()
            if args.use_sent_all_head_loss:
                pred = sent_all_head_scores
                pred = pred.view(-1, pred.size(3))
                target_h = adj_mat.permute(0, 2, 1).contiguous().view(-1)
                loss_aux = self.head_child_crossent(pred, target_h.long())
                loss += loss_aux
                prediction = torch.argmax(
                    pred.clone().detach().requires_grad_(False), dim=1)
                if mode == 'eval':
                    prediction[target_h == -1] = -2  # mask out padded positions
                    counts['sent_all_heads_num_correct'] = torch.sum(
                        prediction.eq(target_h.long())).item()
                    counts['sent_all_heads_num_correct_1'] = torch.sum(
                        prediction[target_h == 1].eq(target_h[target_h == 1].long())).item()
                    counts['sent_all_heads_num_correct_0'] = torch.sum(
                        prediction[target_h == 0].eq(target_h[target_h == 0].long())).item()
                    counts['sent_all_heads_num_1'] = torch.sum(target_h == 1).item()
                    counts['sent_all_heads_num_0'] = torch.sum(target_h == 0).item()
                    counts['sent_all_heads_num'] = torch.sum(target_h != -1).item()
                    eval_data['sent_all_heads_pred'] = prediction.cpu().numpy()
                    eval_data['sent_all_heads_true'] = target_h.cpu().numpy()
                ind_losses['sent_all_head_loss'] += loss_aux.item()
            if args.use_sent_all_child_loss:
                pred = sent_all_child_scores
                pred = pred.view(-1, pred.size(3))
                target = adj_mat.contiguous().view(-1)
                loss_aux = self.head_child_crossent(pred, target.long())
                loss += loss_aux
                prediction = torch.argmax(
                    pred.clone().detach().requires_grad_(False), dim=1)
                if mode == 'eval':
                    prediction[target == -1] = -2  # mask out padded positions
                    counts['sent_all_child_num_correct'] = torch.sum(
                        prediction.eq(target.long())).item()
                    counts['sent_all_child_num_correct_1'] = torch.sum(
                        prediction[target == 1].eq(target[target == 1].long())).item()
                    counts['sent_all_child_num_correct_0'] = torch.sum(
                        prediction[target == 0].eq(target[target == 0].long())).item()
                    counts['sent_all_child_num_1'] = torch.sum(target == 1).item()
                    counts['sent_all_child_num_0'] = torch.sum(target == 0).item()
                    counts['sent_all_child_num'] = torch.sum(target != -1).item()
                    eval_data['sent_all_child_pred'] = prediction.cpu().numpy()
                    eval_data['sent_all_child_true'] = target.cpu().numpy()
                ind_losses['sent_all_child_loss'] += loss_aux.item()
        if args.use_token_contsel_loss:
            pred = token_score.view(-1, 2)
            gold = enc_tags_batch.view(-1)
            loss1 = self.crossentropy(pred, gold.long())
            loss += loss1
            if mode == 'eval':
                prediction = torch.argmax(
                    pred.clone().detach().requires_grad_(False), dim=1)
                prediction[gold == -1] = -2  # mask out padded positions
                counts['token_consel_num_correct'] = torch.sum(prediction.eq(gold)).item()
                counts['token_consel_num'] = torch.sum(gold != -1).item()
            ind_losses['token_contsel_loss'] += loss1.item()
        if args.use_sent_imp_loss:
            pred = sent_score.view(-1)
            enc_sent_tags[enc_sent_tags == -1] = 0
            gold = enc_sent_tags.sum(dim=-1).float()
            gold = gold / gold.sum(dim=1, keepdim=True).repeat(1, gold.size(1))
            gold = gold.view(-1)
            loss2 = self.attn_mse_loss(pred, gold)
            ind_losses['sent_imp_loss'] += loss2.item()
            loss += loss2
        if args.use_doc_imp_loss:
            pred = doc_score.view(-1)
            # Count real tokens (tag 0 or 1); -1 marks padding.
            count_tags = enc_tags_batch.clone().detach()
            count_tags[count_tags == 0] = 1
            count_tags[count_tags == -1] = 0
            token_count = count_tags.sum(dim=-1).sum(dim=-1)
            enc_tags_batch[enc_tags_batch == -1] = 0
            gold = enc_tags_batch.sum(dim=-1)
            gold = gold.sum(dim=-1)
            gold = gold / token_count
            loss3 = self.attn_mse_loss(pred, gold)
            loss += loss3
            ind_losses['doc_imp_loss'] += loss3.item()
        return loss, ind_losses, counts, eval_data

    def run_eval(self, logger, args):
        running_avg_loss, iter = 0, 0
        run_avg_losses = {
            'summ_loss': 0,
            'sent_single_head_loss': 0,
            'sent_all_head_loss': 0,
            'sent_all_child_loss': 0,
            'token_contsel_loss': 0,
            'sent_imp_loss': 0,
            'doc_imp_loss': 0
        }
        counts = {
            'token_consel_num_correct': 0,
            'token_consel_num': 0,
            'sent_single_heads_num_correct': 0,
            'sent_single_heads_num': 0,
            'sent_all_heads_num_correct': 0,
            'sent_all_heads_num': 0,
            'sent_all_heads_num_correct_1': 0,
            'sent_all_heads_num_1': 0,
            'sent_all_heads_num_correct_0': 0,
            'sent_all_heads_num_0': 0,
            'sent_all_child_num_correct': 0,
            'sent_all_child_num': 0,
            'sent_all_child_num_correct_1': 0,
            'sent_all_child_num_1': 0,
            'sent_all_child_num_correct_0': 0,
            'sent_all_child_num_0': 0
        }
        eval_res = {
            'sent_all_heads_pred': [],
            'sent_all_heads_true': [],
            'sent_all_child_pred': [],
            'sent_all_child_true': [],
        }
        self.model.module.eval()
        self.eval_batcher._finished_reading = False
        self.eval_batcher.setup_queues()
        batch = self.eval_batcher.next_batch()
        while batch is not None:
            loss, sample_ind_losses, sample_counts, eval_data = self.get_loss(
                batch, args, mode='eval')
            loss = loss.item() if loss is not None else None
            if loss is not None:
                running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, iter)
                if args.use_summ_loss:
                    run_avg_losses['summ_loss'] = calc_running_avg_loss(
                        sample_ind_losses['summ_loss'],
                        run_avg_losses['summ_loss'], iter)
                if args.use_sent_single_head_loss:
                    run_avg_losses['sent_single_head_loss'] = calc_running_avg_loss(
                        sample_ind_losses['sent_single_head_loss'],
                        run_avg_losses['sent_single_head_loss'], iter)
                    counts['sent_single_heads_num_correct'] += sample_counts['sent_single_heads_num_correct']
                    counts['sent_single_heads_num'] += sample_counts['sent_single_heads_num']
                if args.use_sent_all_head_loss:
                    run_avg_losses['sent_all_head_loss'] = calc_running_avg_loss(
                        sample_ind_losses['sent_all_head_loss'],
                        run_avg_losses['sent_all_head_loss'], iter)
                    counts['sent_all_heads_num_correct'] += sample_counts['sent_all_heads_num_correct']
                    counts['sent_all_heads_num'] += sample_counts['sent_all_heads_num']
                    counts['sent_all_heads_num_correct_1'] += sample_counts['sent_all_heads_num_correct_1']
                    counts['sent_all_heads_num_1'] += sample_counts['sent_all_heads_num_1']
                    counts['sent_all_heads_num_correct_0'] += sample_counts['sent_all_heads_num_correct_0']
                    counts['sent_all_heads_num_0'] += sample_counts['sent_all_heads_num_0']
                    eval_res['sent_all_heads_pred'].append(eval_data['sent_all_heads_pred'])
                    eval_res['sent_all_heads_true'].append(eval_data['sent_all_heads_true'])
                if args.use_sent_all_child_loss:
                    run_avg_losses['sent_all_child_loss'] = calc_running_avg_loss(
                        sample_ind_losses['sent_all_child_loss'],
                        run_avg_losses['sent_all_child_loss'], iter)
                    counts['sent_all_child_num_correct'] += sample_counts['sent_all_child_num_correct']
                    counts['sent_all_child_num'] += sample_counts['sent_all_child_num']
                    counts['sent_all_child_num_correct_1'] += sample_counts['sent_all_child_num_correct_1']
                    counts['sent_all_child_num_1'] += sample_counts['sent_all_child_num_1']
                    counts['sent_all_child_num_correct_0'] += sample_counts['sent_all_child_num_correct_0']
                    counts['sent_all_child_num_0'] += sample_counts['sent_all_child_num_0']
                    eval_res['sent_all_child_pred'].append(eval_data['sent_all_child_pred'])
                    eval_res['sent_all_child_true'].append(eval_data['sent_all_child_true'])
                if args.use_token_contsel_loss:
                    run_avg_losses['token_contsel_loss'] = calc_running_avg_loss(
                        sample_ind_losses['token_contsel_loss'],
                        run_avg_losses['token_contsel_loss'], iter)
                    counts['token_consel_num_correct'] += sample_counts['token_consel_num_correct']
                    counts['token_consel_num'] += sample_counts['token_consel_num']
                if args.use_sent_imp_loss:
                    run_avg_losses['sent_imp_loss'] = calc_running_avg_loss(
                        sample_ind_losses['sent_imp_loss'],
                        run_avg_losses['sent_imp_loss'], iter)
                if args.use_doc_imp_loss:
                    run_avg_losses['doc_imp_loss'] = calc_running_avg_loss(
                        sample_ind_losses['doc_imp_loss'],
                        run_avg_losses['doc_imp_loss'], iter)
            iter += 1
            batch = self.eval_batcher.next_batch()
        msg = 'Eval: loss: %f' % running_avg_loss
        print(msg)
        logger.debug(msg)
        if args.use_summ_loss:
            msg = 'Summ Eval: loss: %f' % run_avg_losses['summ_loss']
            print(msg)
            logger.debug(msg)
        if args.use_sent_single_head_loss:
            msg = 'Single Sent Head Eval: loss: %f' % run_avg_losses['sent_single_head_loss']
            print(msg)
            logger.debug(msg)
            msg = 'Average Sent Single Head Accuracy: %f' % (
                counts['sent_single_heads_num_correct'] /
                float(counts['sent_single_heads_num']))
            print(msg)
            logger.debug(msg)
        if args.use_sent_all_head_loss:
            msg = 'All Sent Head Eval: loss: %f' % run_avg_losses['sent_all_head_loss']
            print(msg)
            logger.debug(msg)
            msg = 'Average Sent All Head Accuracy: %f' % (
                counts['sent_all_heads_num_correct'] /
                float(counts['sent_all_heads_num']))
            print(msg)
            logger.debug(msg)
            y_pred = np.concatenate(eval_res['sent_all_heads_pred'])
            y_true = np.concatenate(eval_res['sent_all_heads_true'])
            msg = classification_report(y_true, y_pred, labels=[0, 1])
            print(msg)
            logger.debug(msg)
        if args.use_sent_all_child_loss:
            msg = 'All Sent Child Eval: loss: %f' % run_avg_losses['sent_all_child_loss']
            print(msg)
            logger.debug(msg)
            msg = 'Average Sent All Child Accuracy: %f' % (
                counts['sent_all_child_num_correct'] /
                float(counts['sent_all_child_num']))
            print(msg)
            logger.debug(msg)
            y_pred = np.concatenate(eval_res['sent_all_child_pred'])
            y_true = np.concatenate(eval_res['sent_all_child_true'])
            msg = classification_report(y_true, y_pred, labels=[0, 1])
            print(msg)
            logger.debug(msg)
        if args.use_token_contsel_loss:
            msg = 'Token Contsel Eval: loss: %f' % run_avg_losses['token_contsel_loss']
            print(msg)
            logger.debug(msg)
            msg = 'Average token content sel Accuracy: %f' % (
                counts['token_consel_num_correct'] /
                float(counts['token_consel_num']))
            print(msg)
            logger.debug(msg)
        if args.use_sent_imp_loss:
            msg = 'Sent Imp Eval: loss: %f' % run_avg_losses['sent_imp_loss']
            print(msg)
            logger.debug(msg)
        if args.use_doc_imp_loss:
            msg = 'Doc Imp Eval: loss: %f' % run_avg_losses['doc_imp_loss']
            print(msg)
            logger.debug(msg)
        return running_avg_loss
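
# `calc_running_avg_loss` is assumed to be the usual exponentially decayed
# running average used in pointer-generator training loops. A minimal sketch
# consistent with the three-argument calls above (the 0.99 decay constant is
# an assumption, not taken from this file):
def calc_running_avg_loss_sketch(loss, running_avg_loss, step, decay=0.99):
    if running_avg_loss == 0:
        # First step: initialize the average with the current loss.
        return loss
    return decay * running_avg_loss + (1 - decay) * loss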

class Evaluate(object):
    def __init__(self, model_file_path):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.eval_data_path, self.vocab, mode='eval',
                               batch_size=config.batch_size, single_pass=True)
        time.sleep(15)
        model_name = os.path.basename(model_file_path)
        eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
        if not os.path.exists(eval_dir):
            os.mkdir(eval_dir)
        self.summary_writer = SummaryWriter(eval_dir)
        self.model = Model(model_file_path, is_eval=True)

    def eval_one_batch(self, batch):
        enc_batch, enc_padding_token_mask, enc_padding_sent_mask, enc_doc_lens, enc_sent_lens, \
            enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)
        encoder_outputs, encoder_hidden, max_encoder_output = self.model.encoder(
            enc_batch, enc_sent_lens, enc_doc_lens, enc_padding_token_mask,
            enc_padding_sent_mask)
        s_t_1 = self.model.reduce_state(encoder_hidden)
        if config.use_maxpool_init_ctx:
            c_t_1 = max_encoder_output
        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, enc_padding_sent_mask, c_t_1,
                extra_zeros, enc_batch_extend_vocab, coverage)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)
        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)
        # Free the batch tensors before the next iteration.
        del enc_batch, enc_padding_token_mask, enc_padding_sent_mask, \
            enc_doc_lens, enc_sent_lens, enc_batch_extend_vocab, extra_zeros, \
            c_t_1, coverage
        gc.collect()
        torch.cuda.empty_cache()
        return loss.item()

    def run_eval(self):
        running_avg_loss, iter = 0, 0
        start = time.time()
        batch = self.batcher.next_batch()
        while batch is not None:
            loss = self.eval_one_batch(batch)
            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter)
            iter += 1
            print_interval = 1000
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' % (
                    iter, print_interval, time.time() - start, running_avg_loss))
                start = time.time()
            batch = self.batcher.next_batch()
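
# Hedged usage sketch for the standalone evaluator (the checkpoint path is
# illustrative only):
#
# evaluator = Evaluate(os.path.join(config.log_root, 'train_0/model/model_5000'))
# evaluator.run_eval()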

class BeamSearch(object):
    def __init__(self, args, model_file_path, save_path):
        model_name = os.path.basename(model_file_path)
        self.args = args
        self._decode_dir = os.path.join(config.log_root, save_path,
                                        'decode_%s' % (model_name))
        self._structures_dir = os.path.join(self._decode_dir, 'structures')
        self._sent_single_heads_dir = os.path.join(self._decode_dir, 'sent_heads_preds')
        self._sent_single_heads_ref_dir = os.path.join(self._decode_dir, 'sent_heads_ref')
        self._contsel_dir = os.path.join(self._decode_dir, 'content_sel_preds')
        self._contsel_ref_dir = os.path.join(self._decode_dir, 'content_sel_ref')
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        self._rouge_ref_file = os.path.join(self._decode_dir, 'rouge_ref.json')
        self._rouge_pred_file = os.path.join(self._decode_dir, 'rouge_pred.json')
        self.stat_res_file = os.path.join(self._decode_dir, 'stats.txt')
        self.sent_count_file = os.path.join(self._decode_dir, 'sent_used_counts.txt')
        for p in [self._decode_dir, self._structures_dir,
                  self._sent_single_heads_ref_dir, self._sent_single_heads_dir,
                  self._contsel_ref_dir, self._contsel_dir,
                  self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)
        vocab = args.vocab_path if args.vocab_path is not None else config.vocab_path
        self.vocab = Vocab(vocab, config.vocab_size, config.embeddings_file, args)
        self.batcher = Batcher(args.decode_data_path, self.vocab, mode='decode',
                               batch_size=args.beam_size, single_pass=True,
                               args=args)
        self.batcher.setup_queues()
        time.sleep(30)
        self.model = Model(args, self.vocab).to(device)
        self.model.eval()

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def extract_structures(self, batch, sent_attention_matrix,
                           doc_attention_matrix, count, use_cuda, sent_scores):
        fileName = os.path.join(self._structures_dir, "%06d_struct.txt" % count)
        fp = open(fileName, "w")
        fp.write("Doc: " + str(count) + "\n")
        doc_attention_matrix = doc_attention_matrix[:, :]  # this change yet to be tested!
        l = batch.enc_doc_lens[0].item()
        doc_sent_no = 0
        # (disabled) per-sentence token-level tree extraction:
        # for i in range(l):
        #     printstr = ''
        #     sent = batch.enc_batch[0][i]
        #     token_count = 0
        #     for j in range(batch.enc_sent_lens[0][i].item()):
        #         token = sent[j].item()
        #         printstr += self.vocab.id2word(token) + " "
        #         token_count = token_count + 1
        #     fp.write(printstr + "\n")
        #     scores = sent_attention_matrix[doc_sent_no][0:token_count, 0:token_count]
        #     shape2 = scores.size()
        #     row = torch.ones([1, shape2[1] + 1]).cuda()
        #     column = torch.zeros([shape2[0], 1]).cuda()
        #     new_scores = torch.cat([column, scores], dim=1)
        #     new_scores = torch.cat([row, new_scores], dim=0)
        #     heads, tree_score = chu_liu_edmonds(new_scores.data.cpu().numpy().astype(np.float64))
        #     fp.write(str(heads) + " ")
        #     fp.write(str(tree_score) + "\n")
        #     doc_sent_no += 1
        shape2 = doc_attention_matrix[0:l, 0:l + 1].size()
        row = torch.zeros([1, shape2[1]]).cuda()
        scores = doc_attention_matrix[0:l, 0:l + 1]
        new_scores = torch.cat([row, scores], dim=0)
        # Force a single root: keep only the highest-scoring root edge.
        val, root_edge = torch.max(new_scores[:, 0], dim=0)
        root_score = torch.zeros([shape2[0] + 1, 1]).cuda()
        root_score[root_edge] = 1
        new_scores[:, 0] = root_score.squeeze()
        heads, tree_score = chu_liu_edmonds(
            new_scores.data.cpu().numpy().astype(np.float64))
        height = find_height(heads)
        leaf_nodes = leaf_node_proportion(heads)
        fp.write("\n")
        sentences = str(batch.original_articles[0]).split("<split1>")
        for idx, sent in enumerate(sentences):
            fp.write(str(idx) + "\t" + str(sent) + "\n")
        fp.write(str(heads) + " ")
        fp.write(str(tree_score) + "\n")
        fp.write(str(height) + "\n")
        s = sent_scores[0].data.cpu().numpy()
        for val in s:
            fp.write(str(val) + " ")  # separate scores with spaces
        fp.close()
        structure_info = dict()
        structure_info['heads'] = heads
        structure_info['height'] = height
        structure_info['leaf_nodes'] = leaf_nodes
        return structure_info

    def decode(self):
        start = time.time()
        counter = 0
        sent_counter = []
        avg_max_seq_len_list = []
        copied_sequence_len = Counter()
        copied_sequence_per_sent = []
        article_copy_id_count_tot = Counter()
        sentence_copy_id_count = Counter()
        novel_counter = Counter()
        repeated_counter = Counter()
        summary_sent_count = Counter()
        summary_sent = []
        article_sent = []
        summary_len = []
        abstract_ref = []
        abstract_pred = []
        sentence_count = []
        tot_sentence_id_count = Counter()
        height_avg = []
        leaf_node_proportion_avg = []
        precision_tree_dist = []
        recall_tree_dist = []
        batch = self.batcher.next_batch()
        height_counter = Counter()
        leaf_nodes_counter = Counter()
        sent_count_fp = open(self.sent_count_file, 'w')
        counts = {
            'token_consel_num_correct': 0,
            'token_consel_num': 0,
            'sent_single_heads_num_correct': 0,
            'sent_single_heads_num': 0,
            'sent_all_heads_num_correct': 0,
            'sent_all_heads_num': 0,
            'sent_all_child_num_correct': 0,
            'sent_all_child_num': 0
        }
        no_batches_processed = 0
        while batch is not None:
            # Run beam search to get the best Hypothesis
            has_summary, best_summary, sample_predictions, sample_counts, \
                structure_info, adj_mat = self.get_decoded_outputs(batch, counter)
            if self.args.predict_contsel_tags:
                no_words = batch.enc_word_lens[0]
                prediction = sample_predictions['token_contsel_prediction'][0:no_words]
                ref = batch.contsel_tags[0]
                write_tags(prediction, ref, counter, self._contsel_dir,
                           self._contsel_ref_dir)
                counts['token_consel_num_correct'] += sample_counts['token_consel_num_correct']
                counts['token_consel_num'] += sample_counts['token_consel_num']
            if self.args.predict_sent_single_head:
                no_sents = batch.enc_doc_lens[0]
                prediction = sample_predictions['sent_single_heads_prediction'][0:no_sents].tolist()
                ref = batch.original_parent_heads[0]
                write_tags(prediction, ref, counter, self._sent_single_heads_dir,
                           self._sent_single_heads_ref_dir)
                counts['sent_single_heads_num_correct'] += sample_counts['sent_single_heads_num_correct']
                counts['sent_single_heads_num'] += sample_counts['sent_single_heads_num']
            if self.args.predict_sent_all_head:
                counts['sent_all_heads_num_correct'] += sample_counts['sent_all_heads_num_correct']
                counts['sent_all_heads_num'] += sample_counts['sent_all_heads_num']
            if self.args.predict_sent_all_child:
                counts['sent_all_child_num_correct'] += sample_counts['sent_all_child_num_correct']
                counts['sent_all_child_num'] += sample_counts['sent_all_child_num']
            if not has_summary:
                batch = self.batcher.next_batch()
                continue
            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if self.args.pointer_gen else None))
            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass
            original_abstract_sents = batch.original_abstracts_sents[0]
            summary_len.append(len(decoded_words))
            assert adj_mat is not None, "Explicit matrix is none."
            assert structure_info['heads'] is not None, "Heads is none."
            precision, recall = tree_distance(structure_info['heads'],
                                              adj_mat.cpu().data.numpy()[0, :, :])
            if precision is not None and recall is not None:
                precision_tree_dist.append(precision)
                recall_tree_dist.append(recall)
            height_counter[structure_info['height']] += 1
            height_avg.append(structure_info['height'])
            leaf_node_proportion_avg.append(structure_info['leaf_nodes'])
            leaf_nodes_counter[np.floor(structure_info['leaf_nodes'] * 10)] += 1
            abstract_ref.append(" ".join(original_abstract_sents))
            abstract_pred.append(" ".join(decoded_words))
            sent_res = get_sent_dist(" ".join(decoded_words),
                                     batch.original_articles[0].decode(),
                                     minimum_seq=self.args.minimum_seq)
            sent_counter.append((sent_res['seen_sent'], sent_res['article_sent']))
            summary_len.append(sent_res['summary_len'])
            summary_sent.append(sent_res['summary_sent'])
            summary_sent_count[sent_res['summary_sent']] += 1
            article_sent.append(sent_res['article_sent'])
            if sent_res['avg_copied_seq_len'] is not None:
                avg_max_seq_len_list.append(sent_res['avg_copied_seq_len'])
                copied_sequence_per_sent.append(
                    np.average(list(sent_res['counter_summary_sent_id'].values())))
            copied_sequence_len.update(sent_res['counter_copied_sequence_len'])
            sentence_copy_id_count.update(sent_res['counter_summary_sent_id'])
            article_copy_id_count_tot.update(sent_res['counter_article_sent_id'])
            novel_counter.update(sent_res['novel_ngram_counter'])
            repeated_counter.update(sent_res['repeated_ngram_counter'])
            sent_count_fp.write(str(counter) + "\t" + str(sent_res['article_sent']) +
                                "\t" + str(sent_res['seen_sent']) + "\n")
            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            batch = self.batcher.next_batch()
            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()
            if self.args.decode_for_subset:
                if counter == 1000:
                    break
        print("Decoder has finished reading dataset for single_pass.")
        fp = open(self.stat_res_file, 'w')
        percentages = [float(len(seen_sent)) / float(sent_count)
                       for seen_sent, sent_count in sent_counter]
        avg_percentage = sum(percentages) / float(len(percentages))
        nosents = [len(seen_sent) for seen_sent, sent_count in sent_counter]
        avg_nosents = sum(nosents) / float(len(nosents))
        res = dict()
        res['avg_percentage_seen_sent'] = avg_percentage
        res['avg_nosents'] = avg_nosents
        res['summary_len'] = summary_sent_count
        res['avg_summary_len'] = np.average(summary_len)
        res['summary_sent'] = np.average(summary_sent)
        res['article_sent'] = np.average(article_sent)
        res['avg_copied_seq_len'] = np.average(avg_max_seq_len_list)
        res['avg_sequences_per_sent'] = np.average(copied_sequence_per_sent)
        res['counter_copied_sequence_len'] = copied_sequence_len
        res['counter_summary_sent_id'] = sentence_copy_id_count
        res['counter_article_sent_id'] = article_copy_id_count_tot
        res['novel_ngram_counter'] = novel_counter
        res['repeated_ngram_counter'] = repeated_counter
        fp.write("Summary metrics\n")
        for key in res:
            fp.write('{}: {}\n'.format(key, res[key]))
        fp.write("Structures metrics\n")
        fp.write("Average depth of RST tree: " +
                 str(sum(height_avg) / len(height_avg)) + "\n")
        fp.write("Average proportion of leaf nodes in RST tree: " +
                 str(sum(leaf_node_proportion_avg) / len(leaf_node_proportion_avg)) + "\n")
        fp.write("Precision of edges latent to explicit: " +
                 str(np.average(precision_tree_dist)) + "\n")
        fp.write("Recall of edges latent to explicit: " +
                 str(np.average(recall_tree_dist)) + "\n")
        fp.write("Tree height counter:\n")
        fp.write(str(height_counter) + "\n")
        fp.write("Tree leaf proportion counter:\n")
        fp.write(str(leaf_nodes_counter) + "\n")
        if self.args.predict_contsel_tags:
            fp.write("Avg token_contsel: " + str(
                counts['token_consel_num_correct'] /
                float(counts['token_consel_num'])) + "\n")
        if self.args.predict_sent_single_head:
            fp.write("Avg single sent heads: " + str(
                counts['sent_single_heads_num_correct'] /
                float(counts['sent_single_heads_num'])) + "\n")
        if self.args.predict_sent_all_head:
            fp.write("Avg all sent heads: " + str(
                counts['sent_all_heads_num_correct'] /
                float(counts['sent_all_heads_num'])) + "\n")
        if self.args.predict_sent_all_child:
            fp.write("Avg all sent child: " + str(
                counts['sent_all_child_num_correct'] /
                float(counts['sent_all_child_num'])) + "\n")
        fp.close()
        sent_count_fp.close()
        write_to_json_file(abstract_ref, self._rouge_ref_file)
        write_to_json_file(abstract_pred, self._rouge_pred_file)

    def get_decoded_outputs(self, batch, count):
        # batch should have only one example, repeated beam_size times
        enc_batch, enc_padding_token_mask, enc_padding_sent_mask, enc_doc_lens, enc_sent_lens, \
            enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0, word_batch, word_padding_mask, enc_word_lens, \
            enc_tags_batch, enc_sent_tags, enc_sent_token_mat, adj_mat, weighted_adj_mat, norm_adj_mat, \
            parent_heads, undir_weighted_adj_mat = get_input_from_batch(batch, use_cuda, self.args)
        enc_adj_mat = adj_mat
        if self.args.use_weighted_annotations:
            if self.args.use_undirected_weighted_graphs:
                enc_adj_mat = undir_weighted_adj_mat
            else:
                enc_adj_mat = weighted_adj_mat
        encoder_output = self.model.encoder.forward_test(
            enc_batch, enc_sent_lens, enc_doc_lens, enc_padding_token_mask,
            enc_padding_sent_mask, word_batch, word_padding_mask, enc_word_lens,
            enc_tags_batch, enc_sent_token_mat, enc_adj_mat)
        encoder_outputs, enc_padding_mask, encoder_last_hidden, max_encoder_output, \
            enc_batch_extend_vocab, token_level_sentence_scores, sent_outputs, token_scores, \
            sent_scores, sent_matrix, sent_level_rep = \
            self.model.get_app_outputs(encoder_output, enc_padding_token_mask,
                                       enc_padding_sent_mask, enc_batch_extend_vocab,
                                       enc_sent_token_mat)
        # Build a (sent x sent) validity mask with an extra leading column for
        # the root, and mask the document-level attention matrix with it.
        mask = enc_padding_sent_mask[0].unsqueeze(0).repeat(
            enc_padding_sent_mask.size(1), 1) * \
            enc_padding_sent_mask[0].unsqueeze(1).transpose(1, 0)
        mask = torch.cat((enc_padding_sent_mask[0].unsqueeze(1), mask), dim=1)
        mat = encoder_output['sent_attention_matrix'][0][:, :] * mask
        structure_info = self.extract_structures(
            batch, encoder_output['token_attention_matrix'], mat, count,
            use_cuda, encoder_output['sent_score'])
        counts = {}
        predictions = {}
        if self.args.predict_contsel_tags:
            pred = encoder_output['token_score'][0, :, :].view(-1, 2)
            token_contsel_gold = enc_tags_batch[0, :].view(-1)
            token_contsel_prediction = torch.argmax(
                pred.clone().detach().requires_grad_(False), dim=1)
            # Explicitly set masked tokens so they can never match the gold value.
            token_contsel_prediction[token_contsel_gold == -1] = -2
            token_consel_num_correct = torch.sum(
                token_contsel_prediction.eq(token_contsel_gold)).item()
            token_consel_num = torch.sum(token_contsel_gold != -1).item()
            predictions['token_contsel_prediction'] = token_contsel_prediction
            counts['token_consel_num_correct'] = token_consel_num_correct
            counts['token_consel_num'] = token_consel_num
        if self.args.predict_sent_single_head:
            pred = encoder_output['sent_single_head_scores'][0, :, :]
            head_labels = parent_heads[0, :].view(-1)
            sent_single_heads_prediction = torch.argmax(
                pred.clone().detach().requires_grad_(False), dim=1)
            sent_single_heads_prediction[head_labels == -1] = -2
            sent_single_heads_num_correct = torch.sum(
                sent_single_heads_prediction.eq(head_labels)).item()
            sent_single_heads_num = torch.sum(head_labels != -1).item()
            predictions['sent_single_heads_prediction'] = sent_single_heads_prediction
            counts['sent_single_heads_num_correct'] = sent_single_heads_num_correct
            counts['sent_single_heads_num'] = sent_single_heads_num
        if self.args.predict_sent_all_head:
            pred = encoder_output['sent_all_head_scores'][0, :, :, :]
            # Transpose so the flattened target matches the head direction used
            # at training time (adj_mat.permute(0, 2, 1) in get_loss).
            target = adj_mat[0, :, :].permute(1, 0).contiguous().view(-1)
            sent_all_heads_prediction = torch.argmax(
                pred.clone().detach().requires_grad_(False), dim=1)
            sent_all_heads_prediction[target == -1] = -2
            sent_all_heads_num_correct = torch.sum(
                sent_all_heads_prediction.eq(target)).item()
            sent_all_heads_num = torch.sum(target != -1).item()
            predictions['sent_all_heads_prediction'] = sent_all_heads_prediction
            counts['sent_all_heads_num_correct'] = sent_all_heads_num_correct
            counts['sent_all_heads_num'] = sent_all_heads_num
        if self.args.predict_sent_all_child:
            pred = encoder_output['sent_all_child_scores'][0, :, :, :]
            target = adj_mat[0, :, :].view(-1)
            sent_all_child_prediction = torch.argmax(
                pred.clone().detach().requires_grad_(False), dim=1)
            sent_all_child_prediction[target == -1] = -2
            sent_all_child_num_correct = torch.sum(
                sent_all_child_prediction.eq(target)).item()
            sent_all_child_num = torch.sum(target != -1).item()
            predictions['sent_all_child_prediction'] = sent_all_child_prediction
            counts['sent_all_child_num_correct'] = sent_all_child_num_correct
            counts['sent_all_child_num'] = sent_all_child_num
        results = []
        steps = 0
        has_summary = False
        beams_sorted = [None]
        if self.args.predict_summaries:
            has_summary = True
            if self.args.fixed_scorer:
                # self.model is not wrapped in DataParallel here, so the scorer
                # is accessed directly rather than through .module.
                scorer_output = self.model.pretrained_scorer.forward_test(
                    enc_batch, enc_sent_lens, enc_doc_lens,
                    enc_padding_token_mask, enc_padding_sent_mask, word_batch,
                    word_padding_mask, enc_word_lens, enc_tags_batch)
                token_scores = scorer_output['token_score']
                sent_scores = scorer_output['sent_score'].unsqueeze(1).repeat(
                    1, enc_padding_token_mask.size(2), 1, 1).view(
                        enc_padding_token_mask.size(0),
                        enc_padding_token_mask.size(1) * enc_padding_token_mask.size(2))
            all_child, all_head = None, None
            if self.args.use_gold_annotations_for_decode:

                def row_normalize(mat):
                    # Normalize each row to sum to 1, leaving all-zero rows as-is.
                    normalized = mat.clone()
                    row_sums = torch.sum(mat, dim=2, keepdim=True).expand_as(mat)
                    nonzero = row_sums != 0
                    normalized[nonzero] = mat[nonzero] / row_sums[nonzero]
                    return normalized

                if self.args.use_weighted_annotations:
                    base = undir_weighted_adj_mat if \
                        self.args.use_undirected_weighted_graphs else weighted_adj_mat
                else:
                    base = adj_mat
                all_head = row_normalize(base.permute(0, 2, 1))
                all_child = row_normalize(base)
            s_t_0 = self.model.reduce_state(encoder_last_hidden)
            if config.use_maxpool_init_ctx:
                c_t_0 = max_encoder_output
            dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()
            # Decoder batch preparation: beam_size hypotheses that initially
            # share the same decoder state.
            beams = [
                Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                     log_probs=[0.0],
                     state=(dec_h[0], dec_c[0]),
                     context=c_t_0[0],
                     coverage=(coverage_t_0[0] if self.args.is_coverage or
                               self.args.bu_coverage_penalty else None))
                for _ in range(self.args.beam_size)
            ]
            while steps < self.args.max_dec_steps and len(results) < self.args.beam_size:
                latest_tokens = [h.latest_token for h in beams]
                latest_tokens = [t if t < self.vocab.size() else
                                 self.vocab.word2id(data.UNKNOWN_TOKEN)
                                 for t in latest_tokens]
                y_t_1 = Variable(torch.LongTensor(latest_tokens))
                if use_cuda:
                    y_t_1 = y_t_1.cuda()
                all_state_h = []
                all_state_c = []
                all_context = []
                for h in beams:
                    state_h, state_c = h.state
                    all_state_h.append(state_h)
                    all_state_c.append(state_c)
                    all_context.append(h.context)
                s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                         torch.stack(all_state_c, 0).unsqueeze(0))
                c_t_1 = torch.stack(all_context, 0)
                coverage_t_1 = None
                if self.args.is_coverage or self.args.bu_coverage_penalty:
                    all_coverage = []
                    for h in beams:
                        all_coverage.append(h.coverage)
                    coverage_t_1 = torch.stack(all_coverage, 0)
                final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                    y_t_1, s_t_1, encoder_outputs, word_padding_mask, c_t_1,
                    extra_zeros, enc_batch_extend_vocab, coverage_t_1,
                    token_scores, sent_scores, sent_outputs, enc_sent_token_mat,
                    all_head, all_child, sent_level_rep)
                if self.args.bu_coverage_penalty:
                    # Bottom-up style coverage penalty on over-attended positions.
                    penalty = torch.max(coverage_t,
                                        coverage_t.clone().fill_(1.0)).sum(-1)
                    penalty -= coverage_t.size(-1)
                    final_dist -= self.args.beta * penalty.unsqueeze(1).expand_as(final_dist)
                if self.args.bu_length_penalty:
                    penalty = ((5 + steps + 1) / 6.0) ** self.args.alpha
                    final_dist = final_dist / penalty
                topk_log_probs, topk_ids = torch.topk(final_dist,
                                                      self.args.beam_size * 2)
                dec_h, dec_c = s_t
                dec_h = dec_h.squeeze()
                dec_c = dec_c.squeeze()
                all_beams = []
                num_orig_beams = 1 if steps == 0 else len(beams)
                for i in range(num_orig_beams):
                    h = beams[i]
                    state_i = (dec_h[i], dec_c[i])
                    context_i = c_t[i]
                    coverage_i = (coverage_t[i] if self.args.is_coverage or
                                  self.args.bu_coverage_penalty else None)
                    for j in range(self.args.beam_size * 2):
                        # for each of the top 2*beam_size hyps:
                        new_beam = h.extend(token=topk_ids[i, j].item(),
                                            log_prob=topk_log_probs[i, j].item(),
                                            state=state_i,
                                            context=context_i,
                                            coverage=coverage_i)
                        all_beams.append(new_beam)
                beams = []
                for h in self.sort_beams(all_beams):
                    if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                        if steps >= config.min_dec_steps:
                            results.append(h)
                    else:
                        beams.append(h)
                    if len(beams) == self.args.beam_size or \
                            len(results) == self.args.beam_size:
                        break
                steps += 1
            if len(results) == 0:
                results = beams
            beams_sorted = self.sort_beams(results)
        return has_summary, beams_sorted[0], predictions, counts, \
            structure_info, undir_weighted_adj_mat
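
# `BeamSearch.get_decoded_outputs` assumes a Beam container that extends the
# plain Hypothesis with decoder state. This is a minimal sketch consistent
# with that usage (the project's actual Beam class may differ in details):
class BeamSketch(object):
    def __init__(self, tokens, log_probs, state, context, coverage):
        self.tokens = tokens        # decoded token ids, starting with [START]
        self.log_probs = log_probs  # per-token log-probabilities
        self.state = state          # (h, c) LSTM decoder state for this beam
        self.context = context      # attention context vector
        self.coverage = coverage    # coverage vector, or None

    def extend(self, token, log_prob, state, context, coverage):
        # Return a new beam carrying the updated decoder state.
        return BeamSketch(tokens=self.tokens + [token],
                          log_probs=self.log_probs + [log_prob],
                          state=state, context=context, coverage=coverage)

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        # Length-normalized score used by sort_beams.
        return sum(self.log_probs) / len(self.tokens)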