class Train(object):
    """Training driver for ``Model``.

    Builds the vocabulary and the training batcher, creates a timestamped
    output directory under ``config.log_root`` and provides the training loop
    (``trainIters``) together with checkpoint saving (``save_model``).
    """

    def __init__(self):
        """Create the vocab, the training batcher and the output directories."""
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path, self.vocab, mode='train',
                               batch_size=config.batch_size, single_pass=False)
        # Fixed pause right after creating the batcher -- presumably gives its
        # background loading time to produce batches; TODO confirm.
        time.sleep(15)

        train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
        self.model_dir = os.path.join(train_dir, 'model')
        # makedirs creates train_dir (and log_root, if missing) on the way to
        # model_dir and does not fail when a directory already exists.
        os.makedirs(self.model_dir, exist_ok=True)

        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter):
        """Write a checkpoint with all sub-module and optimizer states.

        Args:
            running_avg_loss: value stored under ``'current_loss'``.
            iter: iteration number stored under ``'iter'`` and used in the
                checkpoint file name.
        """
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'section_encoder_state_dict': self.model.section_encoder.state_dict(),
            'sentence_filterer_state_dict': self.model.sentence_filterer.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'section_reduce_state_dict': self.model.section_reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):
        """Build the model and optimizer, optionally resuming from a checkpoint.

        Args:
            model_file_path: checkpoint to resume from, or ``None`` to start
                from scratch.

        Returns:
            ``(start_iter, start_loss)``: taken from the checkpoint when one is
            given, otherwise ``(0, 0)``.
        """
        self.model = Model(model_file_path)

        params = list(self.model.encoder.parameters()) + \
            list(self.model.section_encoder.parameters()) + \
            list(self.model.sentence_filterer.parameters()) + \
            list(self.model.decoder.parameters()) + \
            list(self.model.reduce_state.parameters()) + \
            list(self.model.section_reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = AdagradCustom(params, lr=initial_lr,
                                       initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            # Load every tensor onto the CPU first, regardless of where it was saved.
            state = torch.load(model_file_path, map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            # The saved optimizer state is only restored when neither coverage
            # nor sentence filtering is enabled.
            if not config.is_coverage and not config.is_sentence_filtering:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    # The state was loaded on the CPU; move its tensors to the GPU.
                    for param_state in self.optimizer.state.values():
                        for k, v in param_state.items():
                            if torch.is_tensor(v):
                                param_state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch(self, batch):
        """Run one optimisation step on ``batch`` and return the loss as a float.

        The decoder is run with teacher forcing for at most
        ``config.max_dec_steps`` steps. The per-step negative log-likelihood
        (plus the coverage loss when ``config.is_coverage``) is masked with the
        decoder padding mask, summed over steps, divided by ``dec_lens_var``
        and averaged over the batch. When ``config.is_sentence_filtering`` a
        binary cross-entropy term on the sentence filter output is added.
        """
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, \
            c_t_1, coverage, sent_lens = get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_hidden, max_encoder_output = \
            self.model.encoder(enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        if config.use_maxpool_init_ctx:
            c_t_1 = max_encoder_output

        gamma = None
        if config.is_sentence_filtering:
            gamma, sent_dists = self.model.sentence_filterer(encoder_outputs, sent_lens)

        section_outputs, section_hidden = self.model.section_encoder(s_t_1)
        s_t_1 = self.model.section_reduce_state(section_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, section_outputs, enc_padding_mask,
                c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, gamma)
            target = target_batch[:, di]
            # gather() returns one column per example. squeeze(1) removes only
            # that column; a bare squeeze() would also drop the batch dimension
            # for a single-example batch and break torch.stack(..., 1) below.
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze(1)
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(
                    torch.min(attn_dist, coverage.view(*attn_dist.shape)), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss

            # Zero out the loss at padded decoder positions.
            step_mask = dec_padding_mask[:, di]
            step_losses.append(step_loss * step_mask)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        if config.is_sentence_filtering:
            sim_scores = torch.FloatTensor(batch.sim_scores)
            if use_cuda:
                sim_scores = sim_scores.cuda()
            sent_filter_loss = F.binary_cross_entropy(sent_dists, sim_scores)
            loss = loss + config.sent_loss_wt * sent_filter_loss

        loss.backward()

        # Gradients are clipped separately for each sub-module.
        for module in (self.model.encoder,
                       self.model.section_encoder,
                       self.model.sentence_filterer,
                       self.model.decoder,
                       self.model.reduce_state,
                       self.model.section_reduce_state):
            clip_grad_norm_(module.parameters(), config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def trainIters(self, n_iters, model_file_path=None):
        """Train until the iteration counter reaches ``n_iters``.

        Args:
            n_iters: iteration number at which training stops.
            model_file_path: optional checkpoint to resume from; the counter
                and running loss then start from the stored values.

        Progress is printed every 10 iterations, a checkpoint is saved every
        1000 iterations and the summary writer is flushed every 5000.
        """
        step, running_avg_loss = self.setup_train(model_file_path)
        print_interval = 10
        start = time.time()
        while step < n_iters:
            print('\rBatch %d' % step, end="")
            batch = self.batcher.next_batch()
            loss = self.train_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, step)
            step += 1

            if step % 5000 == 0:
                self.summary_writer.flush()
            if step % print_interval == 0:
                print(' steps %d, seconds for %d batch: %.2f , loss: %f' %
                      (step, print_interval, time.time() - start, loss))
                start = time.time()
            if step % 1000 == 0:
                self.save_model(running_avg_loss, step)
class BeamSearch(object):
    """Beam-search decoder that writes outputs for ROUGE scoring.

    Loads a trained ``Model`` in eval mode, decodes the examples produced by a
    single-pass batcher one at a time and writes reference/decoded files into
    a ``decode_<model name>`` directory under ``config.log_root``.
    """

    def __init__(self, model_file_path):
        """Create the output directories, vocab, batcher and model.

        Args:
            model_file_path: checkpoint passed to ``Model``; its base name is
                used in the decode directory name.
        """
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root, 'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            # Does not fail if the directory exists or if log_root is missing.
            os.makedirs(p, exist_ok=True)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        # batch_size equals beam_size: beam_search() treats the batch as a
        # single example (it only reads index 0 of the initial states).
        self.batcher = Batcher(config.decode_data_path, self.vocab, mode='decode',
                               batch_size=config.beam_size, single_pass=True)
        # Fixed pause right after creating the batcher -- presumably gives its
        # background loading time to produce batches; TODO confirm.
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        """Return ``beams`` sorted by ``avg_log_prob``, best (highest) first."""
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        """Decode every batch from the batcher, then run ROUGE evaluation."""
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            with torch.no_grad():
                best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to
            # words; the first token (the start token) is skipped.
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token and everything after it, if present.
            try:
                decoded_words = decoded_words[:decoded_words.index(data.STOP_DECODING)]
            except ValueError:
                pass  # no [STOP] token was generated; keep the full sequence

            print("===============SUMMARY=============")
            print(' '.join(decoded_words))
            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d example in %d sec'%(counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)

    def beam_search(self, batch):
        """Run beam search on ``batch`` and return the best hypothesis.

        ``batch`` should have only one example (repeated ``beam_size`` times).
        Returns the hypothesis with the highest ``avg_log_prob`` among the
        finished ones, or among the live beams if none finished.
        """
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, \
            c_t_0, coverage_t_0, sent_lens = get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden, max_encoder_output = \
            self.model.encoder(enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        if config.use_maxpool_init_ctx:
            c_t_0 = max_encoder_output

        gamma = None
        if config.is_sentence_filtering:
            gamma, _ = self.model.sentence_filterer(encoder_outputs, sent_lens)

        section_outputs, section_hidden = self.model.section_encoder(s_t_0)
        s_t_0 = self.model.section_reduce_state(section_hidden)

        dec_h, dec_c = s_t_0
        # Decoder states carry a leading dimension of size 1 (they are rebuilt
        # below with stack(...).unsqueeze(0)). Remove exactly that dimension:
        # a bare squeeze() would also drop the beam dimension when it is 1.
        dec_h = dec_h.squeeze(0)
        dec_c = dec_c.squeeze(0)

        # Decoder batch preparation: beam_size copies of the same initial
        # hypothesis.
        start_id = self.vocab.word2id(data.START_DECODING)
        beams = [Beam(tokens=[start_id],
                      log_probs=[0.0],
                      state=(dec_h[0], dec_c[0]),
                      context=c_t_0[0],
                      coverage=(coverage_t_0[0] if config.is_coverage else None))
                 for _ in range(config.beam_size)]
        results = []
        steps = 0

        # Loop-invariant vocabulary lookups.
        vocab_size = self.vocab.size()
        unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        stop_id = self.vocab.word2id(data.STOP_DECODING)

        while steps < config.max_dec_steps and len(results) < config.beam_size:
            # Ids outside the fixed vocabulary are fed back as the unknown token.
            latest_tokens = [h.latest_token if h.latest_token < vocab_size else unk_id
                             for h in beams]
            y_t_1 = torch.LongTensor(latest_tokens)
            if use_cuda:
                y_t_1 = y_t_1.cuda()

            all_state_h = [h.state[0] for h in beams]
            all_state_c = [h.state[1] for h in beams]
            all_context = [h.context for h in beams]

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                     torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                coverage_t_1 = torch.stack([h.coverage for h in beams], 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, section_outputs, enc_padding_mask,
                c_t_1, extra_zeros, enc_batch_extend_vocab, coverage_t_1, gamma)

            # final_dist holds probabilities (training takes -log of it), while
            # hypotheses accumulate log-probabilities, so convert before top-k.
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs, config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze(0)
            dec_c = dec_c.squeeze(0)

            all_beams = []
            # On the first step all beams are identical, so expand only one.
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    all_beams.append(h.extend(token=topk_ids[i, j].item(),
                                              log_prob=topk_log_probs[i, j].item(),
                                              state=state_i,
                                              context=context_i,
                                              coverage=coverage_i))

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == stop_id:
                    # A hypothesis that stops before min_dec_steps is discarded.
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(results) == config.beam_size:
                    break

            steps += 1

        # No hypothesis finished: fall back to the live beams.
        if not results:
            results = beams

        return self.sort_beams(results)[0]
class Evaluate(object):
    """Computes the loss of a trained ``Model`` on the evaluation data.

    Loads the model in eval mode, reads ``config.eval_data_path`` once through
    a single-pass batcher and logs a running average loss to a summary writer
    in an ``eval_<model name>`` directory under ``config.log_root``.
    """

    def __init__(self, model_file_path):
        """Create the vocab, batcher, summary writer and model.

        Args:
            model_file_path: checkpoint passed to ``Model``; its base name is
                used in the eval directory name.
        """
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.eval_data_path, self.vocab, mode='eval',
                               batch_size=config.batch_size, single_pass=True)
        # Fixed pause right after creating the batcher -- presumably gives its
        # background loading time to produce batches; TODO confirm.
        time.sleep(15)

        model_name = os.path.basename(model_file_path)
        eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
        # Does not fail if the directory exists or if log_root is missing.
        os.makedirs(eval_dir, exist_ok=True)
        self.summary_writer = tf.summary.FileWriter(eval_dir)

        self.model = Model(model_file_path, is_eval=True)

    def eval_one_batch(self, batch):
        """Return the decoder loss on ``batch`` as a Python float.

        The forward pass follows the same sequence as training: encoder,
        optional sentence filterer, section encoder, then teacher-forced
        decoding for at most ``config.max_dec_steps`` steps. The loss is the
        masked negative log-likelihood (plus the coverage loss when
        ``config.is_coverage``), divided by ``dec_lens_var`` and averaged over
        the batch. The auxiliary sentence-filter loss is not included.
        """
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, \
            c_t_1, coverage, sent_lens = get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        # No parameters are updated here, so gradient tracking is not needed.
        with torch.no_grad():
            encoder_outputs, encoder_hidden, max_encoder_output = \
                self.model.encoder(enc_batch, enc_lens)
            s_t_1 = self.model.reduce_state(encoder_hidden)

            if config.use_maxpool_init_ctx:
                c_t_1 = max_encoder_output

            gamma = None
            if config.is_sentence_filtering:
                gamma, _ = self.model.sentence_filterer(encoder_outputs, sent_lens)

            section_outputs, section_hidden = self.model.section_encoder(s_t_1)
            s_t_1 = self.model.section_reduce_state(section_hidden)

            step_losses = []
            for di in range(min(max_dec_len, config.max_dec_steps)):
                y_t_1 = dec_batch[:, di]  # Teacher forcing
                final_dist, s_t_1, c_t_1, attn_dist, p_gen, coverage = self.model.decoder(
                    y_t_1, s_t_1, encoder_outputs, section_outputs, enc_padding_mask,
                    c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, gamma)
                target = target_batch[:, di]
                # squeeze(1) removes only the gathered column; a bare squeeze()
                # would also drop the batch dimension for a single-example
                # batch and break torch.stack(..., 1) below.
                gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze(1)
                step_loss = -torch.log(gold_probs + config.eps)
                if config.is_coverage:
                    step_coverage_loss = torch.sum(
                        torch.min(attn_dist, coverage.view(*attn_dist.shape)), 1)
                    step_loss = step_loss + config.cov_loss_wt * step_coverage_loss

                # Zero out the loss at padded decoder positions.
                step_mask = dec_padding_mask[:, di]
                step_losses.append(step_loss * step_mask)

            sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
            batch_avg_loss = sum_step_losses / dec_lens_var
            loss = torch.mean(batch_avg_loss)

        # .item() converts the 0-dim loss tensor to a float; indexing it with
        # [0] is not valid for 0-dim tensors.
        return loss.item()

    def run_eval(self):
        """Evaluate every batch once, logging the running average loss."""
        running_avg_loss, iter_count = 0, 0
        print_interval = 1
        start = time.time()
        batch = self.batcher.next_batch()
        while batch is not None:
            # eval_one_batch already returns a plain float.
            loss = self.eval_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter_count)
            iter_count += 1

            if iter_count % 100 == 0:
                self.summary_writer.flush()
            if iter_count % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' % (
                    iter_count, print_interval, time.time() - start, running_avg_loss))
                start = time.time()
            batch = self.batcher.next_batch()