import os
import pickle

import numpy as np
import torch
from torch import optim
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from tqdm import tqdm

# Project-level names used below whose import paths are not shown in this file:
# config, PGN, SampleDataset, collate_fn, ScheduledSampler, config_info, evaluate.


def train(dataset, val_dataset, v, start_epoch=0):
    """Train the model, evaluate it and store it.

    Args:
        dataset (dataset.PairDataset): The training dataset.
        val_dataset (dataset.PairDataset): The evaluation dataset.
        v (vocab.Vocab): The vocabulary built from the training dataset.
        start_epoch (int, optional): The starting epoch number. Defaults to 0.
    """
    DEVICE = torch.device("cuda" if config.is_cuda else "cpu")

    model = PGN(v)
    model.load_model()
    model.to(DEVICE)
    if config.fine_tune:
        # In fine-tuning mode, we fix the weights of all parameters except attention.wc.
        print('Fine-tuning mode.')
        for name, params in model.named_parameters():
            if name != 'attention.wc.weight':
                params.requires_grad = False

    print("loading data")
    train_data = SampleDataset(dataset.pairs, v)
    val_data = SampleDataset(val_dataset.pairs, v)

    print("initializing optimizer")
    # Define the optimizer.
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    train_dataloader = DataLoader(dataset=train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn)

    val_losses = np.inf
    if os.path.exists(config.losses_path):
        with open(config.losses_path, 'rb') as f:
            val_losses = pickle.load(f)

    # torch.cuda.empty_cache()
    # SummaryWriter: log writer used for TensorboardX visualization.
    writer = SummaryWriter(config.log_path)
    # tqdm: a tool for drawing progress bars during training.
    # ScheduledSampler: a tool for choosing whether to use teacher forcing.
    num_epochs = len(range(start_epoch, config.epochs))
    scheduled_sampler = ScheduledSampler(num_epochs)
    if config.scheduled_sampling:
        print('scheduled_sampling mode.')

    with tqdm(total=config.epochs) as epoch_progress:
        for epoch in range(start_epoch, config.epochs):
            print(config_info(config))
            batch_losses = []  # Loss of each batch in this epoch.
            num_batches = len(train_dataloader)
            # Set the teacher-forcing signal for this epoch.
            if config.scheduled_sampling:
                teacher_forcing = scheduled_sampler.teacher_forcing(
                    epoch - start_epoch)
            else:
                teacher_forcing = True
            print('teacher_forcing = {}'.format(teacher_forcing))

            with tqdm(total=num_batches) as batch_progress:
                for batch, data in enumerate(tqdm(train_dataloader)):
                    x, y, x_len, y_len, oov, len_oovs = data
                    assert not np.any(np.isnan(x.numpy()))
                    if config.is_cuda:  # Training with GPUs.
                        x = x.to(DEVICE)
                        y = y.to(DEVICE)
                        x_len = x_len.to(DEVICE)
                        len_oovs = len_oovs.to(DEVICE)

                    model.train()  # Sets the module in training mode.
                    optimizer.zero_grad()  # Clear gradients.
                    # Calculate loss by calling the model's forward pass.
                    loss = model(x, x_len, y, len_oovs,
                                 batch=batch,
                                 num_batches=num_batches,
                                 teacher_forcing=teacher_forcing)
                    batch_losses.append(loss.item())
                    loss.backward()  # Backpropagation.

                    # Do gradient clipping to prevent gradient explosion.
                    clip_grad_norm_(model.encoder.parameters(),
                                    config.max_grad_norm)
                    clip_grad_norm_(model.decoder.parameters(),
                                    config.max_grad_norm)
                    clip_grad_norm_(model.attention.parameters(),
                                    config.max_grad_norm)
                    optimizer.step()  # Update weights.

                    # Output and record the running loss every 32 batches.
                    if batch % 32 == 0:
                        batch_progress.set_description(f'Epoch {epoch}')
                        batch_progress.set_postfix(Batch=batch,
                                                   Loss=loss.item())
                        batch_progress.update()
                        # Write loss for tensorboard.
                        writer.add_scalar(f'Average loss for epoch {epoch}',
                                          np.mean(batch_losses),
                                          global_step=batch)

            # Calculate average loss over all batches in an epoch.
            epoch_loss = np.mean(batch_losses)

            epoch_progress.set_description(f'Epoch {epoch}')
            epoch_progress.set_postfix(Loss=epoch_loss)
            epoch_progress.update()

            avg_val_loss = evaluate(model, val_data, epoch)
            print('training loss:{}'.format(epoch_loss),
                  'validation loss:{}'.format(avg_val_loss))

            # Update the minimum validation loss and save the best model.
            if avg_val_loss < val_losses:
                torch.save(model.encoder, config.encoder_save_name)
                torch.save(model.decoder, config.decoder_save_name)
                torch.save(model.attention, config.attention_save_name)
                torch.save(model.reduce_state, config.reduce_state_save_name)
                val_losses = avg_val_loss
                with open(config.losses_path, 'wb') as f:
                    pickle.dump(val_losses, f)

    writer.close()
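# ---------------------------------------------------------------------------
# The loop above delegates the per-epoch teacher-forcing decision to a
# ScheduledSampler defined elsewhere in the project. The class below is only
# an illustrative sketch of the assumed behaviour (a teacher-forcing
# probability that decays linearly over the planned epochs); it is not the
# repository's actual implementation, and the class name is hypothetical.
# ---------------------------------------------------------------------------
import random


class ScheduledSamplerSketch:
    """Illustrative stand-in for ScheduledSampler (behaviour assumed)."""

    def __init__(self, phases):
        self.phases = phases
        # Sampling threshold for each phase, rising linearly from 0 to 1, so
        # teacher forcing becomes less likely as training progresses.
        denom = max(1, self.phases - 1)
        self.scheduled_probs = [i / denom for i in range(self.phases)]

    def teacher_forcing(self, phase):
        """Return True (use teacher forcing) with probability 1 - scheduled_probs[phase]."""
        return random.random() >= self.scheduled_probs[phase]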
import logging
import os

import numpy as np
import torch
from torch import optim
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter

# Project-level names used below whose import paths are not shown in this file:
# config, PGN, collate_fn, evaluate.


def train(dataset, val_dataset, v, start_epoch=0):
    """Train the model, evaluate it and store it.

    Args:
        dataset (dataset.PairDataset): The training dataset.
        val_dataset (dataset.PairDataset): The evaluation dataset.
        v (vocab.Vocab): The vocabulary built from the training dataset.
        start_epoch (int, optional): The starting epoch number. Defaults to 0.
    """
    torch.autograd.set_detect_anomaly(True)
    DEVICE = torch.device("cuda" if config.is_cuda else "cpu")

    model = PGN(v)
    model.load_model()
    model.to(DEVICE)
    if config.fine_tune:
        # In fine-tuning mode, we fix the weights of all parameters except attention.wc.
        logging.info('Fine-tuning mode.')
        for name, params in model.named_parameters():
            if name != 'attention.wc.weight':
                params.requires_grad = False

    logging.info("loading data")
    train_data = dataset
    val_data = val_dataset

    logging.info("initializing optimizer")
    # Define the optimizer.
    # optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    optimizer = optim.Adagrad(
        model.parameters(),
        lr=config.learning_rate,
        initial_accumulator_value=config.initial_accumulator_value)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.2)  # Learning-rate schedule.
    train_dataloader = DataLoader(dataset=train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn)

    val_loss = np.inf
    if os.path.exists(config.losses_path):
        with open(config.losses_path, 'r') as f:
            val_loss = float(f.readlines()[-1].split("=")[-1])
        logging.info("the last best val loss is: " + str(val_loss))

    # torch.cuda.empty_cache()
    # SummaryWriter: log writer used for TensorboardX visualization.
    writer = SummaryWriter(config.log_path)

    early_stopping_count = 0
    logging.info("start training model {}, ".format(config.model_name) +
                 "epoch : {}, ".format(config.epochs) +
                 "batch_size : {}, ".format(config.batch_size) +
                 "num batches: {}, ".format(len(train_dataloader)))
    for epoch in range(start_epoch, config.epochs):
        batch_losses = []  # Loss of each batch in this epoch.
        num_batches = len(train_dataloader)
        for batch, data in enumerate(train_dataloader):
            x, y, x_len, y_len, oov, len_oovs, img_vec = data
            assert not np.any(np.isnan(x.numpy()))
            if config.is_cuda:  # Training with GPUs.
                x = x.to(DEVICE)
                y = y.to(DEVICE)
                x_len = x_len.to(DEVICE)
                len_oovs = len_oovs.to(DEVICE)
                img_vec = img_vec.to(DEVICE)
            if batch == 0:
                # Log the first batch so the tensor shapes can be checked.
                logging.info("x: %s, shape: %s" % (x, x.shape))
                logging.info("y: %s, shape: %s" % (y, y.shape))
                logging.info("oov: %s" % oov)
                logging.info("img_vec: %s, shape: %s" % (img_vec, img_vec.shape))

            model.train()  # Sets the module in training mode.
            optimizer.zero_grad()  # Clear gradients.
            # Calculate loss by calling the model's forward pass.
            loss = model(x, y, len_oovs, img_vec,
                         batch=batch,
                         num_batches=num_batches)
            batch_losses.append(loss.item())
            loss.backward()  # Backpropagation.

            # Do gradient clipping to prevent gradient explosion.
            clip_grad_norm_(model.encoder.parameters(), config.max_grad_norm)
            clip_grad_norm_(model.decoder.parameters(), config.max_grad_norm)
            clip_grad_norm_(model.attention.parameters(), config.max_grad_norm)
            clip_grad_norm_(model.reduce_state.parameters(), config.max_grad_norm)
            optimizer.step()  # Update weights.
            # scheduler.step()

            # Output and record the running loss every 100 batches.
            if batch % 100 == 0:
                # Write loss for tensorboard.
                writer.add_scalar(f'Average_loss_for_epoch_{epoch}',
                                  np.mean(batch_losses),
                                  global_step=batch)
                logging.info('epoch: {}, batch:{}, training loss:{}'.format(
                    epoch, batch, np.mean(batch_losses)))

        # Calculate average loss over all batches in an epoch.
        epoch_loss = np.mean(batch_losses)

        avg_val_loss = evaluate(model, val_data, epoch)
        logging.info('epoch: {} '.format(epoch) +
                     'training loss:{} '.format(epoch_loss) +
                     'validation loss:{} '.format(avg_val_loss))

        # Update the minimum validation loss and save the best model.
        if not os.path.exists(os.path.dirname(config.encoder_save_name)):
            os.mkdir(os.path.dirname(config.encoder_save_name))
        if avg_val_loss < val_loss:
            logging.info("saving model to ../saved_model/ %s" % config.model_name)
            torch.save(model.encoder, config.encoder_save_name)
            torch.save(model.decoder, config.decoder_save_name)
            torch.save(model.attention, config.attention_save_name)
            torch.save(model.reduce_state, config.reduce_state_save_name)
            val_loss = avg_val_loss
            with open(config.losses_path, 'a') as f:
                f.write(f"best val loss={val_loss}\n")
        else:
            early_stopping_count += 1
            if early_stopping_count >= config.patience:
                logging.info(
                    f'Validation loss did not decrease for {config.patience} epochs, stop training.')
                break

    writer.close()
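# ---------------------------------------------------------------------------
# Hypothetical usage sketch. PairDataset is only referenced in the docstring
# above; its constructor arguments, the build_vocab() helper, and the config
# attributes train_data_path / val_data_path are assumptions for illustration,
# not this project's actual API.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # Build the (source, reference) pair datasets; the path names are assumed.
    training_pairs = PairDataset(config.train_data_path)
    validation_pairs = PairDataset(config.val_data_path)
    # The docstring says the vocabulary is built from the training dataset;
    # build_vocab() is a placeholder for however the repo actually does this.
    vocabulary = training_pairs.build_vocab()
    train(training_pairs, validation_pairs, vocabulary, start_epoch=0)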