def _validation(self, valid_loader): with torch.no_grad(): self.model.eval() loss_tracker = LossTracker() pred_scores = [] true_scores = [] for i, (img, label) in enumerate(valid_loader): img = torch.cat(img) img = img.to(cfg.device) label = torch.cat(label) label = label.to(cfg.device) classification_output = self.model(img) loss = self.criterion(classification_output, label) ModelUtils.append_results(label, classification_output, pred_scores, true_scores) loss_tracker.increment_loss(loss) if i % 100 == 0: weighted_AUC = CompetitionMetric.alaska_weighted_auc(true_scores, pred_scores) loss_tracker.print_losses(self.epoch, i, len(valid_loader), weighted_AUC) weighted_AUC = CompetitionMetric.alaska_weighted_auc(true_scores, pred_scores) loss_dict = loss_tracker.write_dict(weighted_AUC) loss_tracker.print_losses(self.epoch, i, len(valid_loader), weighted_AUC) self.writer.write_scalars(loss_dict, tag='val', n_iter=self.train_step) self.scheduler.step(metrics=loss_tracker.loss.avg) lr = self.optimizer.param_groups[-1]['lr'] self.writer.write_scalars({'lr': lr}, tag='val', n_iter=self.train_step)
def _train(self, train_loader):
    """Run one training epoch over ``train_loader``.

    ``ModelUtils.batch_splitter`` splits each batch into ``num_imgs=4``
    sub-batches, each containing the cover, JUNIWARD, JMiPOD or UERD version of
    every image. The model is run on each sub-batch, and the per-sub-batch
    criterion losses are averaged. One optimizer step is taken on that mean.

    Every ``cfg.log_freq`` batches of this epoch (batch index 0 excluded):
      - the weighted AUC over predictions accumulated since the previous log is
        computed;
      - the losses are printed and written to ``self.writer`` under the
        ``'train'`` tag;
      - the score lists and the ``LossTracker`` are reset.
    Batches after the last logging boundary of the epoch are not logged.

    Every ``cfg.save_freq`` batches (index 0 excluded),
    ``ModelUtils.save_model(self)`` is called.

    Args:
        train_loader: iterable of batches accepted by
            ``ModelUtils.batch_splitter``. Must support ``len()``, which is used
            for progress printing.

    Side effects:
        Puts the model in train mode, increments ``self.train_step`` once per
        batch, and updates the model parameters through ``self.optimizer``.
    """
    loss_tracker = LossTracker()
    pred_scores = []
    true_scores = []
    self.model.train()
    # Report the LR actually in use. Schedulers (stepped in _validation) update
    # optimizer.param_groups in place, so a separately stored self.lr may be stale.
    current_lr = self.optimizer.param_groups[-1]['lr']
    print('Epoch: {} : LR = {}'.format(self.epoch, current_lr))
    for i, img in enumerate(train_loader):
        self.train_step += 1
        # Split the batch in 4 sub-batches, such that each sub-batch contains either the cover, JUNIWARD, JMiPOD or
        # UERD version of each image.
        batch_splits, labels = ModelUtils.batch_splitter(img, num_imgs=4)
        classification_out = [self.model(sub_batch) for sub_batch in batch_splits]
        losses = [self.criterion(c, l) for c, l in zip(classification_out, labels)]
        total_loss = sum(losses) / len(losses)

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        for batch_labs, batch_res in zip(labels, classification_out):
            ModelUtils.append_results(batch_labs, batch_res, pred_scores, true_scores)
        # Detach before handing the loss to the tracker. If LossTracker stores or
        # sums the value (not visible here), an attached tensor would keep each
        # step's autograd graph alive until the next reset.
        loss_tracker.increment_loss(total_loss.detach())

        if i % cfg.log_freq == 0 and i > 0:
            weighted_AUC = CompetitionMetric.alaska_weighted_auc(true_scores, pred_scores)
            pred_scores = []
            true_scores = []
            loss_dict = loss_tracker.write_dict(weighted_AUC)
            loss_tracker.print_losses(self.epoch, i, len(train_loader), weighted_AUC)
            loss_tracker = LossTracker()  # Reinitialize the loss tracking
            self.writer.write_scalars(loss_dict, tag='train', n_iter=self.train_step)

        if i % cfg.save_freq == 0 and i > 0:
            ModelUtils.save_model(self)