def _validation(self, valid_loader):
        """Evaluate the model on the validation set for the current epoch.

        Runs a full pass over ``valid_loader`` with gradient tracking
        disabled, accumulates the classification loss and the competition
        metric (weighted AUC), writes both to the tensorboard-style writer,
        and steps the LR scheduler on the average validation loss (the
        ``metrics=`` keyword implies a ReduceLROnPlateau-style scheduler —
        TODO confirm).

        Args:
            valid_loader: iterable yielding ``(img, label)`` pairs, where
                each element is a sequence of tensors that is concatenated
                into a single batch (mirrors the 4-way batch split used by
                ``_train``).
        """
        # Guard against an empty loader: the post-loop logging below uses
        # the loop index `i`, which would otherwise be undefined.
        if len(valid_loader) == 0:
            return

        with torch.no_grad():
            self.model.eval()
            loss_tracker = LossTracker()
            pred_scores = []
            true_scores = []

            for i, (img, label) in enumerate(valid_loader):
                # Merge the per-variant sub-batches into one batch before
                # the forward pass.
                img = torch.cat(img).to(cfg.device)
                label = torch.cat(label).to(cfg.device)

                classification_output = self.model(img)
                loss = self.criterion(classification_output, label)

                ModelUtils.append_results(label, classification_output, pred_scores, true_scores)
                loss_tracker.increment_loss(loss)

                # Periodic progress report. Uses the shared cfg.log_freq so
                # validation logs at the same cadence as training (was a
                # hard-coded 100, inconsistent with _train).
                if i % cfg.log_freq == 0:
                    weighted_AUC = CompetitionMetric.alaska_weighted_auc(true_scores, pred_scores)
                    loss_tracker.print_losses(self.epoch, i, len(valid_loader), weighted_AUC)

            # Final metrics over the entire validation set.
            weighted_AUC = CompetitionMetric.alaska_weighted_auc(true_scores, pred_scores)
            loss_dict = loss_tracker.write_dict(weighted_AUC)
            loss_tracker.print_losses(self.epoch, i, len(valid_loader), weighted_AUC)
            self.writer.write_scalars(loss_dict, tag='val', n_iter=self.train_step)

            # Drive the scheduler with the epoch's average validation loss,
            # then log the (possibly updated) learning rate.
            self.scheduler.step(metrics=loss_tracker.loss.avg)
            lr = self.optimizer.param_groups[-1]['lr']
            self.writer.write_scalars({'lr': lr}, tag='val', n_iter=self.train_step)
    def _train(self, train_loader):
        """Run one training epoch over ``train_loader``.

        Each incoming batch is split into 4 sub-batches (one per image
        variant: cover, JUNIWARD, JMiPOD, UERD); the losses of the four
        forward passes are averaged for a single backward/optimizer step.
        Loss and weighted AUC are logged every ``cfg.log_freq`` steps over
        a fresh window, and a checkpoint is written every
        ``cfg.save_freq`` steps.
        """
        tracker = LossTracker()
        preds = []
        targets = []

        self.model.train()

        print('Epoch: {} : LR = {}'.format(self.epoch, self.lr))

        for step, batch in enumerate(train_loader):
            self.train_step += 1

            # Split the batch into 4 sub-batches, each containing a single
            # variant (cover, JUNIWARD, JMiPOD or UERD) of every image.
            batch_splits, labels = ModelUtils.batch_splitter(batch, num_imgs=4)

            # Forward pass and per-sub-batch loss; average for one update.
            outputs = []
            sub_losses = []
            for sub_batch, sub_labels in zip(batch_splits, labels):
                out = self.model(sub_batch)
                outputs.append(out)
                sub_losses.append(self.criterion(out, sub_labels))
            mean_loss = sum(sub_losses) / len(sub_losses)

            self.optimizer.zero_grad()
            mean_loss.backward()
            self.optimizer.step()

            # Collect scores/labels for the windowed AUC computation.
            for sub_labels, sub_out in zip(labels, outputs):
                ModelUtils.append_results(sub_labels, sub_out, preds, targets)

            tracker.increment_loss(mean_loss)

            # Windowed logging: compute metrics over the window, then reset
            # the score buffers and the loss tracker for the next window.
            if step > 0 and step % cfg.log_freq == 0:
                weighted_AUC = CompetitionMetric.alaska_weighted_auc(targets, preds)
                preds = []
                targets = []

                loss_dict = tracker.write_dict(weighted_AUC)
                tracker.print_losses(self.epoch, step, len(train_loader), weighted_AUC)
                tracker = LossTracker()  # fresh tracker for the next window
                self.writer.write_scalars(loss_dict, tag='train', n_iter=self.train_step)

            if step > 0 and step % cfg.save_freq == 0:
                ModelUtils.save_model(self)