import logging
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import classification_report

# Assumes module-level config and helpers defined elsewhere in the project:
# train_data, dev_data, test_data, train_batch_size, test_batch_size, epochs,
# early_stops, log_interval, clip, save_model, save_test, use_cuda, device,
# get_examples, data_iter, get_score, reformat, Optimizer.


class Trainer():
    def __init__(self, model, vocab):
        self.model = model
        self.report = True

        self.train_data = get_examples(train_data, vocab)
        self.batch_num = int(
            np.ceil(len(self.train_data) / float(train_batch_size)))
        self.dev_data = get_examples(dev_data, vocab)
        self.test_data = get_examples(test_data, vocab)

        # criterion
        self.criterion = nn.CrossEntropyLoss()

        # label name
        self.target_names = vocab.target_names

        # optimizer
        self.optimizer = Optimizer(model.all_parameters)

        # count
        self.step = 0
        self.early_stop = -1
        self.best_train_f1, self.best_dev_f1 = 0, 0
        self.last_epoch = epochs

    def train(self):
        logging.info('Start training...')
        for epoch in range(1, epochs + 1):
            logging.info("Epoch: %d/%d" % (epoch, epochs))
            train_f1 = self._train(epoch)
            dev_f1 = self._eval(epoch)

            if self.best_dev_f1 <= dev_f1:
                logging.info(
                    "Exceed history dev = %.2f, current dev = %.2f" %
                    (self.best_dev_f1, dev_f1))
                torch.save(self.model.state_dict(), save_model)

                self.best_train_f1 = train_f1
                self.best_dev_f1 = dev_f1
                self.early_stop = 0
            else:
                self.early_stop += 1
                if self.early_stop == early_stops:
                    logging.info(
                        "Early stop in epoch %d, best train: %.2f, dev: %.2f" % (
                            epoch - early_stops, self.best_train_f1, self.best_dev_f1))
                    self.last_epoch = epoch
                    break

    def test(self):
        # reload the best checkpoint saved during training, then run on the test set
        self.model.load_state_dict(torch.load(save_model))
        self._eval(self.last_epoch + 1, test=True)

    def _train(self, epoch):
        self.optimizer.zero_grad()
        self.model.train()

        start_time = time.time()
        epoch_start_time = time.time()
        overall_losses = 0
        losses = 0
        batch_idx = 1
        y_pred = []
        y_true = []
        for batch_data in data_iter(self.train_data, train_batch_size, shuffle=True):
            torch.cuda.empty_cache()
            batch_inputs, batch_labels = self.batch2tensor(batch_data)
            batch_outputs = self.model(batch_inputs)
            loss = self.criterion(batch_outputs, batch_labels)
            loss.backward()

            loss_value = loss.detach().cpu().item()
            losses += loss_value
            overall_losses += loss_value

            y_pred.extend(
                torch.max(batch_outputs, dim=1)[1].cpu().numpy().tolist())
            y_true.extend(batch_labels.cpu().numpy().tolist())

            # clip gradients, then step every optimizer/scheduler pair
            nn.utils.clip_grad_norm_(self.optimizer.all_params, max_norm=clip)
            for optimizer, scheduler in zip(self.optimizer.optims,
                                            self.optimizer.schedulers):
                optimizer.step()
                scheduler.step()
            self.optimizer.zero_grad()

            self.step += 1

            if batch_idx % log_interval == 0:
                elapsed = time.time() - start_time
                lrs = self.optimizer.get_lr()
                logging.info(
                    '| step {:3d} | batch {:3d}/{:3d} | lr{} | loss {:.4f} | s/batch {:.2f}'.format(
                        self.step, batch_idx, self.batch_num, lrs,
                        losses / log_interval, elapsed / log_interval))

                losses = 0
                start_time = time.time()

            batch_idx += 1

        overall_losses /= self.batch_num
        during_time = time.time() - epoch_start_time

        # reformat
        overall_losses = reformat(overall_losses, 4)
        score, f1 = get_score(y_true, y_pred)

        logging.info(
            '| epoch {:3d} | score {} | f1 {} | loss {:.4f} | time {:.2f}'.format(
                epoch, score, f1, overall_losses, during_time))
        # classification_report needs target_names to line up with the classes
        # present, so only report when every class occurs in both lists
        if set(y_true) == set(y_pred) and self.report:
            report = classification_report(
                y_true, y_pred, digits=4, target_names=self.target_names)
            logging.info('\n' + report)

        return f1

    def _eval(self, epoch, test=False):
        self.model.eval()
        start_time = time.time()
        data = self.test_data if test else self.dev_data
        y_pred = []
        y_true = []
        with torch.no_grad():
            for batch_data in data_iter(data, test_batch_size, shuffle=False):
                torch.cuda.empty_cache()
                batch_inputs, batch_labels = self.batch2tensor(batch_data)
                batch_outputs = self.model(batch_inputs)
                y_pred.extend(
                    torch.max(batch_outputs, dim=1)[1].cpu().numpy().tolist())
                y_true.extend(batch_labels.cpu().numpy().tolist())

            score, f1 = get_score(y_true, y_pred)

            during_time = time.time() - start_time

            if test:
                # on the test set only the predictions are written out
                df = pd.DataFrame({'label': y_pred})
                df.to_csv(save_test, index=False, sep=',')
            else:
                logging.info(
                    '| epoch {:3d} | dev | score {} | f1 {} | time {:.2f}'.format(
                        epoch, score, f1, during_time))
                if set(y_true) == set(y_pred) and self.report:
                    report = classification_report(
                        y_true, y_pred, digits=4, target_names=self.target_names)
                    logging.info('\n' + report)

        return f1

    def batch2tensor(self, batch_data):
        '''
        batch_data: [[label, doc_len, [[sent_len, [sent_id0, ...], [sent_id1, ...]], ...]], ...]
        Pads every document to the batch's max doc length and every sentence to
        the batch's max sentence length; batch_masks marks real (non-pad) tokens.
        '''
        batch_size = len(batch_data)
        doc_labels = []
        doc_lens = []
        doc_max_sent_len = []
        for doc_data in batch_data:
            doc_labels.append(doc_data[0])
            doc_lens.append(doc_data[1])
            sent_lens = [sent_data[0] for sent_data in doc_data[2]]
            max_sent_len = max(sent_lens)
            doc_max_sent_len.append(max_sent_len)

        max_doc_len = max(doc_lens)
        max_sent_len = max(doc_max_sent_len)

        batch_inputs1 = torch.zeros(
            (batch_size, max_doc_len, max_sent_len), dtype=torch.int64)
        batch_inputs2 = torch.zeros(
            (batch_size, max_doc_len, max_sent_len), dtype=torch.int64)
        batch_masks = torch.zeros(
            (batch_size, max_doc_len, max_sent_len), dtype=torch.float32)
        batch_labels = torch.LongTensor(doc_labels)

        for b in range(batch_size):
            for sent_idx in range(doc_lens[b]):
                sent_data = batch_data[b][2][sent_idx]
                for word_idx in range(sent_data[0]):
                    batch_inputs1[b, sent_idx, word_idx] = sent_data[1][word_idx]
                    batch_inputs2[b, sent_idx, word_idx] = sent_data[2][word_idx]
                    batch_masks[b, sent_idx, word_idx] = 1

        if use_cuda:
            batch_inputs1 = batch_inputs1.to(device)
            batch_inputs2 = batch_inputs2.to(device)
            batch_masks = batch_masks.to(device)
            batch_labels = batch_labels.to(device)

        return (batch_inputs1, batch_inputs2, batch_masks), batch_labels
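# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original pipeline). Trainer
# reads the module-level config and helpers listed at the top of this file;
# `build_vocab` and `Model` below are hypothetical stand-ins for the
# project's actual vocabulary and model constructors.
#
#     vocab = build_vocab(train_data)   # hypothetical vocab builder
#     model = Model(vocab)              # hypothetical model class
#     trainer = Trainer(model, vocab)
#     trainer.train()  # checkpoints the best dev-F1 weights to save_model
#     trainer.test()   # reloads the checkpoint, writes predictions to save_test
# ---------------------------------------------------------------------------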