import os

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sklearn.metrics import confusion_matrix

# AverageMeter, Logger, accuracy, and load_split_train_test are project-local
# helpers; the import path below is an assumption -- adjust it to the
# repository's actual module layout.
from utils import AverageMeter, Logger, accuracy, load_split_train_test


class Trainer:
    def __init__(self,
                 args,
                 model,
                 criterion,
                 optimizer,
                 wandb,
                 scheduler=None):
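        """Wire up the model, criterion, optimizer, data loaders, and logging.

        Running loss/accuracy are tracked with AverageMeter instances, and
        predictions/labels are accumulated across each validation epoch so a
        confusion matrix can be built for the best checkpoint.
        """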
        self.args = args
        self.model = model
        self.criterion = criterion
        self.scheduler = scheduler
        self.optimizer = optimizer
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.train_loader, self.test_loader = load_split_train_test(args)
        self.loss = {'train': AverageMeter(), 'test': AverageMeter()}
        self.accuracy = {'train': AverageMeter(), 'test': AverageMeter()}
        self.logger = Logger(wandb, args, len(self.train_loader.dataset))
        self.epoch = 0
        self.best_accuracy = 0  # best test accuracy observed so far
        self.predlist = torch.zeros(0, dtype=torch.long, device='cpu')
        self.lbllist = torch.zeros(0, dtype=torch.long, device='cpu')

    def before_training_step(self):
        """Put the model in train mode and advance the epoch counter."""
        self.model.train()
        self.epoch += 1

    def after_training_step(self):
        """Step the LR scheduler and log the epoch's training metrics."""
        # Since PyTorch 1.1 the scheduler is meant to be stepped after the
        # optimizer updates, so the per-epoch step happens here rather than
        # before the epoch starts.
        if self.scheduler is not None:
            self.scheduler.step()
        self.logger.epoch_log(self.accuracy, self.loss, 'train')

    def training_step(self):
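        """Run one epoch over the training loader, logging per-batch loss."""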
        self.before_training_step()

        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device,
                                   dtype=torch.float), target.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            self.loss['train'].update(loss.item())
            self.accuracy['train'].update(accuracy(output, target)[0])
            loss.backward()
            self.optimizer.step()
            self.logger.batch_log(self.epoch, batch_idx, loss.item())

        self.after_training_step()

    def before_validation_step(self):
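        """Switch the model to evaluation mode before the test loop."""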
        self.model.eval()

    def after_validation_step(self):
        """Checkpoint on improvement, log the epoch, and reset running meters."""
        if self.accuracy['test'].avg > self.best_accuracy:
            # New best test accuracy: save a checkpoint and keep this epoch's
            # confusion matrix for the final report.
            torch.save(
                self.model.state_dict(),
                os.path.join(self.logger.wandb.run.dir, "best_model.pth"))
            self.best_accuracy = self.accuracy['test'].avg
            self.conf_mat = confusion_matrix(self.lbllist.numpy(),
                                             self.predlist.numpy())
        # Clear the accumulated predictions/labels for the next epoch.
        self.predlist = torch.zeros(0, dtype=torch.long, device='cpu')
        self.lbllist = torch.zeros(0, dtype=torch.long, device='cpu')
        self.logger.epoch_log(self.accuracy, self.loss, 'test')
        # Reset the meters in place: reassigning them to the return value of
        # reset() would replace them with None for the usual AverageMeter
        # implementation, whose reset() returns nothing.
        for meter in self.loss.values():
            meter.reset()
        for meter in self.accuracy.values():
            meter.reset()

    def validation_step(self):
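        """Evaluate on the test loader, accumulating predictions and labels."""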
        self.before_validation_step()
        with torch.no_grad():
            for i, (data, target) in enumerate(self.test_loader):
                data, target = data.to(
                    self.device, dtype=torch.float), target.to(self.device)
                output = self.model(data)
                _, preds = torch.max(output, 1)
                # Append batch prediction results
                self.predlist = torch.cat(
                    [self.predlist, preds.view(-1).cpu()])
                self.lbllist = torch.cat([self.lbllist, target.view(-1).cpu()])
                self.loss['test'].update(self.criterion(output, target).item())
                self.accuracy['test'].update(accuracy(output, target)[0])
        self.after_validation_step()

    def run_training(self):
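        """Train for ``args.epochs`` epochs, then report per-class accuracy
        and save a confusion-matrix heatmap for the best checkpoint."""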
        self.model.to(self.device)
        for _ in range(1, self.args.epochs + 1):
            self.training_step()
            self.validation_step()

        # Per-class accuracy from the confusion matrix captured at the
        # best-performing validation epoch (set in after_validation_step).
        class_accuracy = 100 * self.conf_mat.diagonal() / self.conf_mat.sum(1)
        self.logger.final_accuracy(class_accuracy)
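        # The nested .dataset.dataset access assumes the loaders wrap a
        # torch.utils.data.Subset around a dataset exposing .classes
        # (e.g. torchvision's ImageFolder).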
        classes = self.train_loader.dataset.dataset.classes
        conf_mat = sns.heatmap(self.conf_mat,
                               annot=True,
                               fmt='g',
                               xticklabels=classes,
                               yticklabels=classes)
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        figure = conf_mat.get_figure()
        dataset_name = os.path.basename(self.args.datapath)
        figure.savefig(
            f'confusion_matrix_{self.args.model_name}_{dataset_name}.png')
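

# Hypothetical usage sketch (not part of the original module). The concrete
# model, wandb setup, and argument parsing depend on the project; only the
# attribute names accessed above (args.epochs, args.datapath, args.model_name)
# are taken from the Trainer itself.
#
#   import wandb
#   import torch.nn as nn
#   import torch.optim as optim
#
#   wandb.init(project="classification")      # assumed project name
#   model = build_model(args)                  # any nn.Module (hypothetical helper)
#   criterion = nn.CrossEntropyLoss()
#   optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
#   scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
#   trainer = Trainer(args, model, criterion, optimizer, wandb, scheduler)
#   trainer.run_training()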