import os
import logging
from argparse import Namespace
from datetime import datetime

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

logger = logging.getLogger(__name__)


def train(config, model: Transformer, optimizer: ScheduledOptim,
          train_loader: DataLoader, eval_loader: DataLoader = None,
          device: str = 'cpu') -> Transformer:
    """Train for `config.epochs` epochs, evaluating and checkpointing along the way."""
    model.train()
    best_eval_accuracy = -float('inf')
    for epoch in range(config.epochs):
        logger.info("Epoch: {}".format(epoch))
        total_loss, n_word_total, n_word_correct = 0, 0, 0
        for ids, sample in enumerate(tqdm(train_loader)):
            # Move every tensor in the batch to the target device.
            for k, v in sample.items():
                sample[k] = v.to(device)
            input_ids, decoder_input_ids, decoder_target_ids = (
                sample['input_ids'], sample['decode_input_ids'],
                sample['decode_label_ids'])

            # Forward
            optimizer.zero_grad()
            logits = model(input_ids, decoder_input_ids)
            loss, n_correct, n_word = cal_performance(
                logits, gold=decoder_target_ids, trg_pad_idx=config.pad_idx,
                smoothing=config.label_smoothing)

            # Backward and update parameters (scheduler also adjusts the LR)
            loss.backward()
            optimizer.step_and_update_lr()

            # Bookkeeping
            n_word_total += n_word
            n_word_correct += n_correct
            total_loss += loss.item()

        loss_per_word = total_loss / n_word_total
        accuracy = n_word_correct / n_word_total
        logger.info("The {} epoch train loss: {}, train accuracy: {}".format(
            epoch, loss_per_word, accuracy))

        if eval_loader is not None:
            eval_loss, eval_accuracy = evaluate(config, model,
                                                eval_loader=eval_loader,
                                                device=device)
            if eval_accuracy > best_eval_accuracy:
                best_eval_accuracy = eval_accuracy
                # Save the best model so far (unwrap DataParallel if present).
                model_save = model.module if hasattr(model, "module") else model
                model_file = os.path.join(config.save_dir,
                                          "checkpoint_{}.pt".format(epoch))
                torch.save(model_save.state_dict(), f=model_file)

        if epoch % config.save_epoch == 0:
            # Periodic checkpoint, independent of eval accuracy.
            model_save = model.module if hasattr(model, "module") else model
            model_file = os.path.join(config.save_dir,
                                      "checkpoint_{}.pt".format(epoch))
            torch.save(model_save.state_dict(), f=model_file)
    return model
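# NOTE: `cal_performance` is called above but not defined in this file. The
# sketch below is a plausible implementation, assuming label-smoothed
# cross-entropy summed over non-pad tokens plus a token-level accuracy count;
# the project's actual helper may differ in layout or reduction.
import torch.nn.functional as F


def cal_performance(logits, gold, trg_pad_idx, smoothing=0.0):
    """Return (loss, #correct non-pad tokens, #non-pad tokens) for one batch."""
    logits = logits.view(-1, logits.size(-1))   # (batch * seq_len, vocab)
    gold = gold.contiguous().view(-1)           # (batch * seq_len,)
    loss = F.cross_entropy(
        logits, gold,
        ignore_index=trg_pad_idx,
        label_smoothing=smoothing,
        reduction='sum')
    pred = logits.argmax(dim=-1)
    non_pad_mask = gold.ne(trg_pad_idx)
    n_word = non_pad_mask.sum().item()
    n_correct = pred.eq(gold).masked_select(non_pad_mask).sum().item()
    return loss, n_correct, n_word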
def train_per_epoch(opt: Namespace, model: Transformer, optimizer: ScheduledAdam,
                    train_data, src_vocab, trg_vocab) -> dict:
    """Run one training epoch and return aggregate statistics."""
    model.train()
    start_time = datetime.now()
    total_loss = total_word = total_corrected_word = 0
    for i, batch in tqdm(enumerate(train_data), total=len(train_data), leave=False):
        src_input, trg_input, y_true = _prepare_batch_data(batch, opt.device)

        # Forward
        optimizer.zero_grad()
        y_pred = model(src_input, trg_input)

        # DEBUG: leftover interactive-debugging code, disabled so training runs
        # uninterrupted. Re-enable to inspect predictions batch by batch.
        # pred_sentence = to_sentence(y_pred[0], trg_vocab)
        # true_sentence = to_sentence(batch.trg[:, 0], trg_vocab)
        # print(pred_sentence)
        # print(true_sentence)
        # import ipdb; ipdb.set_trace()

        # Backward and update parameters
        loss = calculate_loss(y_pred, y_true, opt.trg_pad_idx, trg_vocab)
        n_word, n_corrected = calculate_performance(y_pred, y_true, opt.trg_pad_idx)
        loss.backward()
        optimizer.step()

        # Training logs
        total_loss += loss.item()
        total_word += n_word
        total_corrected_word += n_corrected

    loss_per_word = total_loss / total_word
    accuracy = total_corrected_word / total_word

    return {
        'total_seconds': (datetime.now() - start_time).total_seconds(),
        'total_loss': total_loss,
        'total_word': total_word,
        'total_corrected_word': total_corrected_word,
        'loss_per_word': loss_per_word,
        'accuracy': accuracy,
    }
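# NOTE: `_prepare_batch_data` is called above but not defined in this file.
# A minimal sketch, assuming a torchtext-style batch with `.src` / `.trg`
# tensors laid out (seq_len, batch) — consistent with the `batch.trg[:, 0]`
# indexing in the disabled debug block — and a target shifted by one position
# for teacher forcing. The real helper may use a different tensor layout.
def _prepare_batch_data(batch, device):
    """Move a batch to `device` and build teacher-forcing inputs and labels."""
    src_input = batch.src.to(device)
    trg = batch.trg.to(device)     # assumed shape: (seq_len, batch)
    trg_input = trg[:-1, :]        # decoder input: drop the final token
    y_true = trg[1:, :]            # gold labels: drop the initial <sos> token
    return src_input, trg_input, y_true


# Example driver loop (hypothetical: `opt.epochs` is an assumed field, and the
# project's actual entry point may aggregate stats differently):
#
#     for epoch in range(opt.epochs):
#         stats = train_per_epoch(opt, model, optimizer,
#                                 train_data, src_vocab, trg_vocab)
#         print('epoch {}: loss/word={:.4f}, acc={:.4f}'.format(
#             epoch, stats['loss_per_word'], stats['accuracy']))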