class BeamSearch(object):
    """Decode a dataset with beam search and score the output with ROUGE.

    The constructor prepares the output directories, the vocabulary, a
    single-pass decode batcher and the model; ``decode`` drives the whole
    run and ``beam_search`` handles one batch.
    """

    def __init__(self, model_file_path):
        """Set up output directories, vocabulary, batcher and model.

        Args:
            model_file_path: Path of the saved model. Its base name is used
                to name the decode directory under ``config.log_root``.
        """
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root, 'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                # makedirs (unlike mkdir) also creates config.log_root when
                # it does not exist yet.
                os.makedirs(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        """Return ``beams`` as a new list, highest ``avg_log_prob`` first."""
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        """Decode every batch from the batcher, then run ROUGE evaluation.

        For each batch the best hypothesis is converted back to words,
        truncated at the first stop token and written out together with the
        reference abstract via ``write_for_rouge``. Once the batcher returns
        ``None`` the ROUGE results are computed and logged.
        """
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get the best hypothesis. No gradients are
            # needed while decoding, so skip building the autograd graph.
            with torch.no_grad():
                best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis (dropping the
            # leading start token) and convert them back to words.
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary.
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                # No stop token was produced; keep the full sequence.
                pass

            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d example in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)

    def _token_to_word(self, token_id, batch, beam_idx, vocab_size):
        """Map a decoded token id to its word.

        Ids below ``vocab_size`` are looked up in the vocabulary. Larger ids
        index into ``batch.art_oovs[beam_idx]``; ids beyond that list map to
        ``data.UNKNOWN_TOKEN``. ``batch.art_oovs`` is only touched for ids
        outside the vocabulary.
        """
        if token_id < vocab_size:
            return self.vocab.id2word(token_id)
        oov_idx = token_id - vocab_size
        article_oovs = batch.art_oovs[beam_idx]
        if oov_idx >= len(article_oovs):
            return data.UNKNOWN_TOKEN
        return article_oovs[oov_idx]

    def _embed_word(self, word):
        """Return the embedding vector used as decoder input for ``word``."""
        if word == data.START_DECODING:
            return embedding.start_decoding_embedd
        if word == data.STOP_DECODING:
            return embedding.stop_decoding_embedd
        if word == data.PAD_TOKEN:
            return embedding.padding_embedd
        if word == data.UNKNOWN_TOKEN:
            return embedding.unknown_decoding_embedd
        return embedding.fasttext.get_numpy_vector(word.lower())

    def beam_search(self, batch):
        """Run beam search on ``batch`` and return the best hypothesis.

        Args:
            batch: A decode batch; the original author's note says it should
                contain only one example (the batcher is created with
                ``batch_size=config.beam_size``).

        Returns:
            The hypothesis with the highest ``avg_log_prob`` among the
            finished ones, or among the live beams when none finished.
        """
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # Loop-invariant vocabulary lookups.
        start_id = self.vocab.word2id(data.START_DECODING)
        stop_id = self.vocab.word2id(data.STOP_DECODING)
        vocab_size = self.vocab.size()

        # Decoder batch preparation: beam_size hypotheses, all identical
        # copies of the start state.
        beams = [
            Beam(tokens=[start_id],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            # Build the decoder input: one embedding row per live beam.
            y_t_1 = np.zeros((len(beams), config.emb_dim))
            for i, h in enumerate(beams):
                word = self._token_to_word(h.latest_token, batch, i, vocab_size)
                y_t_1[i] = self._embed_word(word)
            y_t_1 = Variable(torch.FloatTensor(y_t_1))
            if use_cuda:
                y_t_1 = y_t_1.cuda()

            # Stack the per-beam decoder states / contexts into batches.
            all_state_h = []
            all_state_c = []
            all_context = []
            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)
                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                     torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                coverage_t_1 = torch.stack([h.coverage for h in beams], 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage_t_1, steps)

            # Hypotheses are ranked by avg_log_prob, so they must accumulate
            # log-probabilities. The training code in this file takes
            # torch.log of final_dist, i.e. treats it as probabilities.
            # NOTE(review): confirm the decoder does not already return
            # log-probabilities in eval mode.
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs, config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            # On the first step every beam is the same copy, so expanding
            # more than one would only produce duplicates.
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i, h in enumerate(beams[:num_orig_beams]):
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                # Extend with each of the top 2*beam_size candidates.
                for j in range(config.beam_size * 2):
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == stop_id:
                    # Hypotheses that stop before min_dec_steps are dropped.
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            # Nothing finished within max_dec_steps; fall back to live beams.
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
class Train(object):
    """Training driver: builds batchers, model and optimizer and runs the loop."""

    def __init__(self):
        """Create the vocabulary, train/eval batchers, log dirs and summary writer."""
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)
        # NOTE(review): presumably gives the batcher time to start filling
        # its queue before use -- confirm against Batcher.
        time.sleep(15)
        self.val_batcher = Batcher(config.eval_data_path,
                                   self.vocab,
                                   mode='eval',
                                   batch_size=config.batch_size,
                                   single_pass=False)
        time.sleep(15)

        train_dir = os.path.join(config.log_root,
                                 'train_%d' % (int(time.time())))
        self.model_dir = os.path.join(train_dir, 'model')
        for path in (train_dir, self.model_dir):
            if not os.path.exists(path):
                # makedirs also creates config.log_root when it is missing.
                os.makedirs(path)

        self.summary_writer = tf.compat.v1.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter):
        """Save model, optimizer state and progress to ``self.model_dir``.

        Args:
            running_avg_loss: Stored under the ``'current_loss'`` key.
            iter: Iteration count; stored and used in the file name.
        """
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):
        """Build the model, optimizer and LR scheduler.

        Args:
            model_file_path: Optional checkpoint path to resume from.

        Returns:
            ``(start_iter, start_loss)``: taken from the checkpoint when one
            is given, ``(0, 0)`` otherwise.
        """
        self.model = Model(model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
            list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adagrad(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)
        self.scheduler = ExponentialLR(self.optimizer, gamma=0.99)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            # Resuming from a checkpoint.
            checkpoint = torch.load(
                model_file_path,
                map_location=lambda storage, location: storage)
            start_iter = checkpoint['iter']
            start_loss = checkpoint['current_loss']

            if not config.is_coverage:
                # Coverage disabled: restore the optimizer state as well.
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                if use_cuda:
                    # The checkpoint was mapped to CPU storage above; move
                    # the optimizer's tensors back to the GPU.
                    for param_state in self.optimizer.state.values():
                        for k, v in param_state.items():
                            if torch.is_tensor(v):
                                param_state[k] = v.cuda()

        return start_iter, start_loss

    def _word_for_id(self, word_id):
        """Return a printable word for ``word_id``.

        Ids at or above ``self.vocab.size()`` are not passed to
        ``id2word``; they are rendered as an ``[OOV_k]`` placeholder so the
        debug output cannot fail on them.
        """
        vocab_size = self.vocab.size()
        if word_id < vocab_size:
            return self.vocab.id2word(word_id)
        return '[OOV_%d]' % (word_id - vocab_size)

    def _compute_loss(self, batch, log_predictions=False):
        """Run a teacher-forced forward pass and return the mean batch loss.

        Args:
            batch: Batch consumed by ``get_input_from_batch`` and
                ``get_output_from_batch``.
            log_predictions: When true, print the per-step argmax words of
                the first example plus its decoder input and target words.

        Returns:
            The loss tensor: masked per-step losses summed over steps,
            divided by ``dec_lens_var`` and averaged over the batch.
        """
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        words = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            if log_predictions:
                # .item() is only paid for when the words are printed.
                words.append(self._word_for_id(final_dist[0].argmax().item()))

            target = target_batch[:, di]
            # gather returns one column per row; squeeze only that column so
            # a batch of size 1 keeps its batch dimension.
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze(1)
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            # Zero the loss at padded decoder positions.
            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        if log_predictions:
            print(words)
            print([self._word_for_id(idx.item()) for idx in dec_batch[0]])
            print([self._word_for_id(idx.item()) for idx in target_batch[0]])

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        return torch.mean(batch_avg_loss)

    def train_one_batch(self, batch, iter):
        """Run one optimization step on ``batch`` and return the loss as a float.

        Args:
            batch: Training batch.
            iter: Current iteration; every 100th one prints debug words.
        """
        self.optimizer.zero_grad()
        loss = self._compute_loss(batch, log_predictions=(iter % 100 == 0))
        loss.backward()

        # Each sub-module is clipped separately; only the encoder's norm is
        # kept on self.norm.
        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def eval_one_batch(self, batch):
        """Return the loss on ``batch`` as a float without updating the model."""
        # No backward pass happens here, so skip building the autograd graph.
        with torch.no_grad():
            loss = self._compute_loss(batch)
        return loss.item()

    def trainIters(self, n_iters, model_file_path=None):
        """Train until the iteration counter reaches ``n_iters``.

        Args:
            n_iters: Iteration count at which training stops.
            model_file_path: Optional checkpoint to resume from.
        """
        step, running_avg_loss = self.setup_train(model_file_path)
        print_interval = 1
        start = time.time()
        while step < n_iters:
            batch = self.batcher.next_batch()
            loss = self.train_one_batch(batch, step)

            val_loss = None
            if step % 100 == 0:
                val_batch = self.val_batcher.next_batch()
                val_loss = self.eval_one_batch(val_batch)
                # NOTE(review): the original file's indentation was lost;
                # the LR decay is assumed to belong to this every-100-steps
                # branch -- confirm.
                self.scheduler.step()
                # Read the rate from the optimizer: it is the value actually
                # in effect, independent of the scheduler API version.
                print("lr",
                      [group['lr'] for group in self.optimizer.param_groups])

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, step)
            step += 1

            if step % 100 == 0:
                self.summary_writer.flush()
            if step % print_interval == 0:
                if val_loss is not None:
                    print(
                        'steps %d, seconds for %d batch: %.2f , loss: %f , eval_loss: %f'
                        % (step, print_interval, time.time() - start, loss,
                           val_loss))
                else:
                    print('steps %d, seconds for %d batch: %.2f , loss: %f' %
                          (step, print_interval, time.time() - start, loss))
                start = time.time()
            if step % 1000 == 0:
                self.save_model(running_avg_loss, step)