class FER2013Trainer(Trainer):
    """for classification task"""
    def __init__(self, model, train_set, val_set, test_set, configs,
                 train_name):
        """Set up the model, data loaders, loss, optimizer and run bookkeeping.

        Args:
            model: callable invoked as
                ``model(in_channels=..., num_classes=...)`` to build the
                network.
            train_set: dataset wrapped by the shuffled training loader.
            val_set: dataset wrapped by the validation loader.
            test_set: dataset wrapped by the batch-size-1 test loader.
            configs: mapping holding the hyper-parameters and paths read
                below.
            train_name: suffix appended to the checkpoint file name.
        """
        super().__init__()
        print("Start trainer..")
        print(configs)

        # hyper-parameters and runtime options
        self._configs = configs
        cfg = self._configs
        self._lr = cfg["lr"]
        self._batch_size = cfg["batch_size"]
        self._momentum = cfg["momentum"]
        self._weight_decay = cfg["weight_decay"]
        self._distributed = cfg["distributed"]
        self._num_workers = cfg["num_workers"]
        self._device = torch.device(cfg["device"])
        self._max_epoch_num = cfg["max_epoch_num"]
        self._max_plateau_count = cfg["max_plateau_count"]

        # datasets and network
        self._train_set = train_set
        self._val_set = val_set
        self._test_set = test_set
        network = model(
            in_channels=configs["in_channels"],
            num_classes=configs["num_classes"],
        )
        # optional head replacements kept for reference:
        #   self._model.fc = nn.Linear(512, 7)
        #   self._model.fc = nn.Linear(256, 7)
        self._model = network.to(self._device)

        # options shared by all three loaders; the distributed variant also
        # seeds numpy in every worker with the worker id
        loader_options = {
            "num_workers": self._num_workers,
            "pin_memory": True,
        }
        if self._distributed == 1:
            torch.distributed.init_process_group(backend="nccl")
            self._model = nn.parallel.DistributedDataParallel(self._model)
            loader_options["worker_init_fn"] = lambda x: np.random.seed(x)

        self._train_loader = DataLoader(
            self._train_set,
            batch_size=self._batch_size,
            shuffle=True,
            **loader_options,
        )
        self._val_loader = DataLoader(
            self._val_set,
            batch_size=self._batch_size,
            shuffle=False,
            **loader_options,
        )
        self._test_loader = DataLoader(
            self._test_set,
            batch_size=1,
            shuffle=False,
            **loader_options,
        )

        # loss function: optionally weighted per class
        class_weights = torch.FloatTensor(
            np.array([
                1.02660468,
                9.40661861,
                1.00104606,
                0.56843877,
                0.84912748,
                1.29337298,
                0.82603942,
            ]))
        if self._configs["weighted_loss"] == 0:
            criterion = nn.CrossEntropyLoss()
        else:
            criterion = nn.CrossEntropyLoss(class_weights)
        self._criterion = criterion.to(self._device)

        # optimizer and learning-rate schedule
        self._optimizer = RAdam(
            params=self._model.parameters(),
            lr=self._lr,
            weight_decay=self._weight_decay,
        )
        self._scheduler = ReduceLROnPlateau(
            self._optimizer,
            patience=self._configs["plateau_patience"],
            min_lr=1e-6,
            verbose=True,
        )
        # TODO set step size equal to configs
        # self._scheduler = StepLR(
        #     self._optimizer,
        #     step_size=self._configs['steplr']
        # )

        # run bookkeeping
        self._start_time = datetime.datetime.now().replace(microsecond=0)

        log_dir = os.path.join(
            self._configs["cwd"],
            self._configs["log_dir"],
            "{}_{}_{}".format(
                self._configs["arch"],
                self._configs["model_name"],
                self._start_time.strftime("%Y%b%d_%H.%M"),
            ),
        )
        self._writer = SummaryWriter(log_dir)

        self._train_loss_list, self._train_acc_list = [], []
        self._val_loss_list, self._val_acc_list = [], []
        self._best_val_loss = 1e9
        self._best_val_acc = 0
        self._best_train_loss = 1e9
        self._best_train_acc = 0
        self._test_acc = 0.0
        self._plateau_count = 0
        self._current_epoch_num = 0

        # checkpoint location
        checkpoint_dir = os.path.join(self._configs["cwd"],
                                      "saved/checkpoints")
        self._checkpoint_dir = checkpoint_dir
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint_name = "{}_{}_{}".format(self._configs["arch"],
                                            self._configs["model_name"],
                                            train_name)
        self._checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)

    def _train(self):
        """Run one training epoch and record its mean loss and accuracy.

        Every batch from ``self._train_loader`` is moved to the configured
        device, passed through the model and used for one optimizer step.
        The per-batch averages of the loss and of ``accuracy(...)[0]`` are
        appended to ``self._train_loss_list`` / ``self._train_acc_list``.

        Raises:
            RuntimeError: if the training loader yields no batches.
        """
        self._model.train()
        train_loss = 0.0
        train_acc = 0.0
        num_batches = 0

        batches = tqdm(enumerate(self._train_loader, start=1),
                       total=len(self._train_loader),
                       leave=False)
        for num_batches, (images, targets) in batches:
            # follow the configured device instead of the default CUDA
            # device so inputs always live next to the model
            images = images.to(self._device, non_blocking=True)
            targets = targets.to(self._device, non_blocking=True)

            # compute output, measure accuracy and record loss
            outputs = self._model(images)
            loss = self._criterion(outputs, targets)
            acc = accuracy(outputs, targets)[0]

            train_loss += loss.item()
            train_acc += acc.item()

            # compute gradient and take one optimizer step
            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()

        if num_batches == 0:
            raise RuntimeError("training loader produced no batches")

        self._train_loss_list.append(train_loss / num_batches)
        self._train_acc_list.append(train_acc / num_batches)

    def _val(self):
        """Evaluate on the validation set and record mean loss and accuracy.

        Runs the model in eval mode without gradients over
        ``self._val_loader`` and appends the per-batch averages of the loss
        and of ``accuracy(...)[0]`` to ``self._val_loss_list`` /
        ``self._val_acc_list``.

        Raises:
            RuntimeError: if the validation loader yields no batches.
        """
        self._model.eval()
        val_loss = 0.0
        val_acc = 0.0
        num_batches = 0

        with torch.no_grad():
            batches = tqdm(enumerate(self._val_loader, start=1),
                           total=len(self._val_loader),
                           leave=False)
            for num_batches, (images, targets) in batches:
                # keep inputs on the same device as the model
                images = images.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)

                # compute output, measure accuracy and record loss
                outputs = self._model(images)
                loss = self._criterion(outputs, targets)
                acc = accuracy(outputs, targets)[0]

                val_loss += loss.item()
                val_acc += acc.item()

        if num_batches == 0:
            raise RuntimeError("validation loader produced no batches")

        self._val_loss_list.append(val_loss / num_batches)
        self._val_acc_list.append(val_acc / num_batches)

    def _calc_acc_on_private_test(self):
        """Compute the mean accuracy over ``self._test_loader``.

        One line ``"<index>_<accuracy>"`` per test batch is written to
        ``private_test_log.txt`` in the current working directory.

        Returns:
            The average of ``accuracy(...)[0]`` over all test batches.

        Raises:
            RuntimeError: if the test loader yields no batches.
        """
        self._model.eval()
        test_acc = 0.0
        num_batches = 0
        print("Calc acc on private test..")

        # the context manager closes the log even if inference fails
        with open("private_test_log.txt", "w") as log_file, torch.no_grad():
            batches = tqdm(enumerate(self._test_loader),
                           total=len(self._test_loader),
                           leave=False)
            for i, (images, targets) in batches:
                # keep inputs on the same device as the model
                images = images.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)

                outputs = self._model(images)
                acc = accuracy(outputs, targets)[0]
                test_acc += acc.item()
                log_file.write("{}_{}\n".format(i, acc.item()))
                num_batches = i + 1

        if num_batches == 0:
            raise RuntimeError("test loader produced no batches")

        test_acc = test_acc / num_batches
        print("Accuracy on private test: {:.3f}".format(test_acc))
        return test_acc

    def _calc_acc_on_private_test_with_tta(self):
        """Compute test accuracy with test-time augmentation (TTA).

        Each item of ``self._test_set`` is turned into a batch by
        ``make_batch``; the softmax outputs are summed over the batch
        dimension and the summed scores are compared with the target. One
        line ``"<index>_<accuracy>"`` per sample is written to
        ``private_test_log_<arch>_<model_name>.txt``.

        Returns:
            The average of ``accuracy(...)[0]`` over all test samples.

        Raises:
            RuntimeError: if the test set is empty.
        """
        self._model.eval()
        test_acc = 0.0
        print("Calc acc on private test with tta..")

        num_samples = len(self._test_set)
        if num_samples == 0:
            raise RuntimeError("test set is empty")

        log_name = "private_test_log_{}_{}.txt".format(
            self._configs["arch"], self._configs["model_name"])

        # the context manager closes the log even if inference fails
        with open(log_name, "w") as log_file, torch.no_grad():
            for idx in tqdm(range(num_samples),
                            total=num_samples,
                            leave=False):
                images, targets = self._test_set[idx]
                targets = torch.LongTensor([targets])

                images = make_batch(images)
                # keep inputs on the same device as the model
                images = images.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)

                outputs = self._model(images)
                outputs = F.softmax(outputs, 1)

                # collapse the augmented views into one row of summed scores
                outputs = torch.sum(outputs, 0)
                outputs = torch.unsqueeze(outputs, 0)

                # TODO: try with softmax first and see the change
                acc = accuracy(outputs, targets)[0]
                test_acc += acc.item()
                log_file.write("{}_{}\n".format(idx, acc.item()))

        test_acc = test_acc / num_samples
        print("Accuracy on private test with tta: {:.3f}".format(test_acc))
        return test_acc

    def train(self):
        """make a training job"""
        # print(self._model)

        # freeze the model
        """
        print('=' * 10)
        for idx, child in enumerate(self._model.children()):
            if idx < 6:
                print(child)
                print('=' * 10)
                
                for m in child.parameters():
                    m.requires_grad = False
          """

        # exit(0)

        try:
            while not self._is_stop():
                self._increase_epoch_num()
                self._train()
                self._val()

                self._update_training_state()
                self._logging()
        except KeyboardInterrupt:
            traceback.print_exc()
            pass

        # training stop
        try:
            # state = torch.load('saved/checkpoints/resatt18_rot30_2019Nov06_18.56')
            state = torch.load(self._checkpoint_path)
            if self._distributed:
                self._model.module.load_state_dict(state["net"])
            else:
                self._model.load_state_dict(state["net"])

            if not self._test_set.is_tta():
                self._test_acc = self._calc_acc_on_private_test()
            else:
                self._test_acc = self._calc_acc_on_private_test_with_tta()

            # self._test_acc = self._calc_acc_on_private_test()
            self._save_weights()
        except Exception as e:
            traceback.print_exc()
            pass

        consume_time = str(datetime.datetime.now() - self._start_time)
        self._writer.add_text(
            "Summary",
            "Converged after {} epochs, consume {}".format(
                self._current_epoch_num, consume_time[:-7]),
        )
        self._writer.add_text(
            "Results",
            "Best validation accuracy: {:.3f}".format(self._best_val_acc))
        self._writer.add_text(
            "Results",
            "Best training accuracy: {:.3f}".format(self._best_train_acc))
        self._writer.add_text(
            "Results", "Private test accuracy: {:.3f}".format(self._test_acc))
        self._writer.close()

    def _update_training_state(self):
        if self._val_acc_list[-1] > self._best_val_acc:
            self._save_weights()
            self._plateau_count = 0
            self._best_val_acc = self._val_acc_list[-1]
            self._best_val_loss = self._val_loss_list[-1]
            self._best_train_acc = self._train_acc_list[-1]
            self._best_train_loss = self._train_loss_list[-1]
        else:
            self._plateau_count += 1

        # self._scheduler.step(self._train_loss_list[-1])
        self._scheduler.step(100 - self._val_acc_list[-1])
        # self._scheduler.step()

    def _logging(self):
        consume_time = str(datetime.datetime.now() - self._start_time)

        message = "\nE{:03d}  {:.3f}/{:.3f}/{:.3f} {:.3f}/{:.3f}/{:.3f} | p{:02d}  Time {}\n".format(
            self._current_epoch_num,
            self._train_loss_list[-1],
            self._val_loss_list[-1],
            self._best_val_loss,
            self._train_acc_list[-1],
            self._val_acc_list[-1],
            self._best_val_acc,
            self._plateau_count,
            consume_time[:-7],
        )

        self._writer.add_scalar("Accuracy/Train", self._train_acc_list[-1],
                                self._current_epoch_num)
        self._writer.add_scalar("Accuracy/Val", self._val_acc_list[-1],
                                self._current_epoch_num)
        self._writer.add_scalar("Loss/Train", self._train_loss_list[-1],
                                self._current_epoch_num)
        self._writer.add_scalar("Loss/Val", self._val_loss_list[-1],
                                self._current_epoch_num)

        print(message)

    def _is_stop(self):
        """check stop condition"""
        return (self._plateau_count > self._max_plateau_count
                or self._current_epoch_num > (self._max_epoch_num - 1))

    def _increase_epoch_num(self):
        self._current_epoch_num += 1

    def _save_weights(self, test_acc=0.0):
        if self._distributed == 0:
            state_dict = self._model.state_dict()
        else:
            state_dict = self._model.module.state_dict()

        state = {
            **self._configs,
            "net": state_dict,
            "best_val_loss": self._best_val_loss,
            "best_val_acc": self._best_val_acc,
            "best_train_loss": self._best_train_loss,
            "best_train_acc": self._best_train_acc,
            "train_losses": self._train_loss_list,
            "val_loss_list": self._val_loss_list,
            "train_acc_list": self._train_acc_list,
            "val_acc_list": self._val_acc_list,
            "test_acc": self._test_acc,
        }

        torch.save(state, self._checkpoint_path)
# Example #2
# 0
class CkTrainer(Trainer):
    def __init__(self, model, train_set, val_set, fold_idx, configs):
        """Set up model, loaders, loss, optimizer and bookkeeping for one fold.

        Args:
            model: callable invoked as
                ``model(in_channels=..., num_classes=...)`` to build the
                network.
            train_set: dataset wrapped by the shuffled training loader.
            val_set: dataset wrapped by the validation loader.
            fold_idx: fold identifier; stored as a string and used in the
                log directory and checkpoint names.
            configs: mapping holding the hyper-parameters and paths read
                below.
        """
        super().__init__()
        print("Start trainer..")
        print(configs)

        # hyper-parameters and runtime options
        self._configs = configs
        cfg = self._configs
        self._lr = cfg['lr']
        self._batch_size = cfg['batch_size']
        self._momentum = cfg['momentum']
        self._weight_decay = cfg['weight_decay']
        self._distributed = cfg['distributed']
        self._num_workers = cfg['num_workers']
        self._device = torch.device(cfg['device'])
        self._max_epoch_num = cfg['max_epoch_num']
        self._max_plateau_count = cfg['max_plateau_count']
        self._fold_idx = str(fold_idx)

        # datasets and network
        self._train_set = train_set
        self._val_set = val_set
        network = model(
            in_channels=configs['in_channels'],
            num_classes=configs['num_classes'],
        )
        self._model = network.to(self._device)

        # options shared by both loaders; the distributed variant also seeds
        # numpy in every worker with the worker id
        loader_options = {
            'batch_size': self._batch_size,
            'num_workers': self._num_workers,
            'pin_memory': True,
        }
        if self._distributed == 1:
            torch.distributed.init_process_group(backend='nccl')
            self._model = nn.parallel.DistributedDataParallel(
                self._model, find_unused_parameters=True)
            loader_options['worker_init_fn'] = lambda x: np.random.seed(x)

        self._train_loader = DataLoader(self._train_set,
                                        shuffle=True,
                                        **loader_options)
        self._val_loader = DataLoader(self._val_set,
                                      shuffle=False,
                                      **loader_options)

        # loss function; the class weights are built but the criterion in
        # use is the unweighted one
        class_weights = torch.FloatTensor(
            np.array([
                1.02660468, 9.40661861, 1.00104606, 0.56843877, 0.84912748,
                1.29337298, 0.82603942
            ]))
        # self._criterion = nn.CrossEntropyLoss(class_weights).to(self._device)
        self._criterion = nn.CrossEntropyLoss().to(self._device)

        # optimizer and learning-rate schedule
        self._optimizer = RAdam(
            params=self._model.parameters(),
            lr=self._lr,
            weight_decay=self._weight_decay,
        )
        self._scheduler = ReduceLROnPlateau(
            self._optimizer,
            patience=self._configs['plateau_patience'],
            min_lr=1e-6,
            verbose=True)

        # run bookkeeping
        self._start_time = datetime.datetime.now().replace(microsecond=0)

        log_dir = os.path.join(
            self._configs['cwd'], self._configs['log_dir'],
            "{}_{}_fold_{}".format(self._configs['arch'],
                                   self._configs['model_name'],
                                   self._fold_idx))
        self._writer = SummaryWriter(log_dir)

        self._train_loss_list, self._train_acc_list = [], []
        self._val_loss_list, self._val_acc_list = [], []
        self._best_val_loss = 1e9
        self._best_val_acc = 0
        self._best_train_loss = 1e9
        self._best_train_acc = 0
        self._plateau_count = 0
        self._current_epoch_num = 0

        # checkpoint location
        checkpoint_dir = os.path.join(self._configs['cwd'],
                                      'saved/checkpoints')
        self._checkpoint_dir = checkpoint_dir
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint_name = "{}_{}_fold_{}".format(self._configs['arch'],
                                                 self._configs['model_name'],
                                                 self._fold_idx)
        self._checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)

    def _train(self):
        """Run one training epoch and record its mean loss and accuracy.

        Every batch from ``self._train_loader`` is moved to the configured
        device, passed through the model and used for one optimizer step.
        The per-batch averages of the loss and of ``accuracy(...)[0]`` are
        appended to ``self._train_loss_list`` / ``self._train_acc_list``.

        Raises:
            RuntimeError: if the training loader yields no batches.
        """
        self._model.train()
        train_loss = 0.
        train_acc = 0.
        num_batches = 0

        batches = tqdm(enumerate(self._train_loader, start=1),
                       total=len(self._train_loader),
                       leave=False)
        for num_batches, (images, targets) in batches:
            # follow the configured device instead of the default CUDA
            # device so inputs always live next to the model
            images = images.to(self._device, non_blocking=True)
            targets = targets.to(self._device, non_blocking=True)

            # compute output, measure accuracy and record loss
            outputs = self._model(images)
            loss = self._criterion(outputs, targets)
            acc = accuracy(outputs, targets)[0]

            train_loss += loss.item()
            train_acc += acc.item()

            # compute gradient and take one optimizer step
            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()

        if num_batches == 0:
            raise RuntimeError("training loader produced no batches")

        self._train_loss_list.append(train_loss / num_batches)
        self._train_acc_list.append(train_acc / num_batches)

    def _val(self):
        """Evaluate on the validation set and record mean loss and accuracy.

        Runs the model in eval mode without gradients over
        ``self._val_loader`` and appends the per-batch averages of the loss
        and of ``accuracy(...)[0]`` to ``self._val_loss_list`` /
        ``self._val_acc_list``.

        Raises:
            RuntimeError: if the validation loader yields no batches.
        """
        self._model.eval()
        val_loss = 0.
        val_acc = 0.
        num_batches = 0

        with torch.no_grad():
            batches = tqdm(enumerate(self._val_loader, start=1),
                           total=len(self._val_loader),
                           leave=False)
            for num_batches, (images, targets) in batches:
                # keep inputs on the same device as the model
                images = images.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)

                # compute output, measure accuracy and record loss
                outputs = self._model(images)
                loss = self._criterion(outputs, targets)
                acc = accuracy(outputs, targets)[0]

                val_loss += loss.item()
                val_acc += acc.item()

        if num_batches == 0:
            raise RuntimeError("validation loader produced no batches")

        self._val_loss_list.append(val_loss / num_batches)
        self._val_acc_list.append(val_acc / num_batches)

    def train(self):
        """make a training job"""
        print(self._model)

        try:
            while not self._is_stop():
                self._increase_epoch_num()
                self._train()
                self._val()

                self._update_training_state()
                self._logging()
        except KeyboardInterrupt:
            traceback.print_exc()
            pass

        consume_time = str(datetime.datetime.now() - self._start_time)
        self._writer.add_text(
            'Summary', 'Converged after {} epochs, consume {}'.format(
                self._current_epoch_num, consume_time[:-7]))
        self._writer.add_text(
            'Results',
            'Best validation accuracy: {:.3f}'.format(self._best_val_acc))
        self._writer.add_text(
            'Results',
            'Best training accuracy: {:.3f}'.format(self._best_train_acc))
        self._writer.close()

    def _update_training_state(self):
        if self._val_acc_list[-1] > self._best_val_acc:
            self._save_weights()
            self._best_val_acc = self._val_acc_list[-1]
            self._best_val_loss = self._val_loss_list[-1]
            self._best_train_acc = self._train_acc_list[-1]
            self._best_train_loss = self._train_loss_list[-1]

            self._plateau_count = 0
        else:
            self._plateau_count += 1

        self._scheduler.step(100 - self._val_acc_list[-1])

    def _logging(self):
        consume_time = str(datetime.datetime.now() - self._start_time)

        message = "\nFold {}  E{:03d}  {:.3f}/{:.3f}/{:.3f} {:.3f}/{:.3f}/{:.3f} | p{:02d}  Time {}\n".format(
            self._fold_idx, self._current_epoch_num, self._train_loss_list[-1],
            self._val_loss_list[-1], self._best_val_loss,
            self._train_acc_list[-1], self._val_acc_list[-1],
            self._best_val_acc, self._plateau_count, consume_time[:-7])

        self._writer.add_scalar('Accuracy/Train', self._train_acc_list[-1],
                                self._current_epoch_num)
        self._writer.add_scalar('Accuracy/Val', self._val_acc_list[-1],
                                self._current_epoch_num)
        self._writer.add_scalar('Loss/Train', self._train_loss_list[-1],
                                self._current_epoch_num)
        self._writer.add_scalar('Loss/Val', self._val_loss_list[-1],
                                self._current_epoch_num)
        print(message)

    def _is_stop(self):
        """check stop condition"""
        return (self._plateau_count > self._max_plateau_count
                or self._current_epoch_num > self._max_epoch_num)

    def _increase_epoch_num(self):
        self._current_epoch_num += 1

    def _save_weights(self):
        """save checkpoint"""
        if self._distributed == 0:
            state_dict = self._model.state_dict()
        else:
            state_dict = self._model.module.state_dict()

        state = {
            **self._configs,
            'net': state_dict,
            'best_val_loss': self._best_val_loss,
            'best_val_acc': self._best_val_acc,
            'best_train_loss': self._best_train_loss,
            'best_train_acc': self._best_train_acc,
            'train_losses': self._train_loss_list,
            'val_loss_list': self._val_loss_list,
            'train_acc_list': self._train_acc_list,
            'val_acc_list': self._val_acc_list,
        }

        torch.save(state, self._checkpoint_path)
class FER2013Trainer(Trainer):
    """for classification task"""
    def __init__(self, model, train_set, val_set, test_set, configs):
        """Set up model, loaders, cross-entropy + center loss and bookkeeping.

        Args:
            model: callable invoked as
                ``model(in_channels=..., num_classes=...)`` to build the
                network; its ``fc`` attribute is replaced by a 2-to-7 linear
                layer.
            train_set: dataset wrapped by the shuffled training loader.
            val_set: dataset wrapped by the validation loader.
            test_set: dataset wrapped by the batch-size-1 test loader.
            configs: mapping holding the hyper-parameters and paths read
                below.
        """
        super().__init__()
        print("Start trainer..")
        print(configs)

        # hyper-parameters and runtime options
        self._configs = configs
        cfg = self._configs
        self._lr = cfg["lr"]
        self._batch_size = cfg["batch_size"]
        self._momentum = cfg["momentum"]
        self._weight_decay = cfg["weight_decay"]
        self._distributed = cfg["distributed"]
        self._num_workers = cfg["num_workers"]
        self._device = torch.device(cfg["device"])
        self._max_epoch_num = cfg["max_epoch_num"]
        self._max_plateau_count = cfg["max_plateau_count"]

        # datasets and network
        self._train_set = train_set
        self._val_set = val_set
        self._test_set = test_set
        self._model = model(
            in_channels=configs["in_channels"],
            num_classes=configs["num_classes"],
        )

        # classifier head on top of the 2-d feature (earlier variants used
        # nn.Linear(512, 7) and nn.Linear(256, 7))
        self._model.fc = nn.Linear(2, 7)
        self._model = self._model.to(self._device)

        # options shared by all three loaders; the distributed variant also
        # seeds numpy in every worker with the worker id
        loader_options = {
            "num_workers": self._num_workers,
            "pin_memory": True,
        }
        if self._distributed == 1:
            torch.distributed.init_process_group(backend="nccl")
            self._model = nn.parallel.DistributedDataParallel(
                self._model, find_unused_parameters=True)
            loader_options["worker_init_fn"] = lambda x: np.random.seed(x)

        self._train_loader = DataLoader(
            self._train_set,
            batch_size=self._batch_size,
            shuffle=True,
            **loader_options,
        )
        self._val_loader = DataLoader(
            self._val_set,
            batch_size=self._batch_size,
            shuffle=False,
            **loader_options,
        )
        self._test_loader = DataLoader(
            self._test_set,
            batch_size=1,
            shuffle=False,
            **loader_options,
        )

        # class-weighted cross-entropy loss
        class_weights = torch.FloatTensor(
            np.array([
                1.02660468,
                9.40661861,
                1.00104606,
                0.56843877,
                0.84912748,
                1.29337298,
                0.82603942,
            ]))
        self._criterion = nn.CrossEntropyLoss(class_weights).to(self._device)
        # self._criterion = nn.CrossEntropyLoss().to(self._device)

        self._optimizer = RAdam(
            params=self._model.parameters(),
            lr=self._lr,
            weight_decay=self._weight_decay,
        )

        # center loss with its own optimizer
        self._criterion_cent = CenterLoss(num_classes=7,
                                          feat_dim=2,
                                          use_gpu=True)
        self._optimizer_cent = RAdam(self._criterion_cent.parameters(),
                                     lr=self._configs["clr"])

        # previous schedule, kept for reference:
        # self._scheduler = ReduceLROnPlateau(
        #     self._optimizer,
        #     patience=self._configs['plateau_patience'],
        #     min_lr=1e-6,
        #     verbose=True
        # )

        # TODO set step size equal to configs
        self._scheduler = StepLR(self._optimizer,
                                 step_size=self._configs["steplr"])
        self._center_scheduler = StepLR(self._optimizer_cent,
                                        step_size=self._configs["csteplr"])

        # run bookkeeping
        self._start_time = datetime.datetime.now().replace(microsecond=0)

        log_dir = os.path.join(
            self._configs["cwd"],
            self._configs["log_dir"],
            "{}_{}_{}".format(
                self._configs["arch"],
                self._configs["model_name"],
                self._start_time.strftime("%Y%b%d_%H.%M"),
            ),
        )
        self._writer = SummaryWriter(log_dir)

        self._train_loss_list, self._train_acc_list = [], []
        self._val_loss_list, self._val_acc_list = [], []
        self._best_val_loss = 1e9
        self._best_val_acc = 0
        self._best_train_loss = 1e9
        self._best_train_acc = 0
        self._test_acc = 0.0
        self._plateau_count = 0
        self._current_epoch_num = 0

        # checkpoint location; the file is named after the run start time
        checkpoint_dir = os.path.join(self._configs["cwd"],
                                      "saved/checkpoints")
        self._checkpoint_dir = checkpoint_dir
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint_name = "{}_{}_{}".format(
            self._configs["arch"],
            self._configs["model_name"],
            self._start_time.strftime("%Y%b%d_%H.%M"),
        )
        self._checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)

    def _train(self):
        """Run one training epoch with cross-entropy and (later) center loss.

        The model returns ``(outputs, features)``. Once the epoch number is
        greater than 20 the center loss on ``features``, scaled by
        ``configs["cweight"]``, is added to the cross-entropy loss and the
        center optimizer is stepped as well. Per-batch averages of loss and
        accuracy are appended to the training history, and all features and
        labels of the epoch are handed to ``self._plot_features`` with
        prefix ``"train"``.

        Raises:
            RuntimeError: if the training loader yields no batches.
        """
        self._model.train()
        train_loss = 0.0
        train_acc = 0.0
        num_batches = 0

        # collected for the feature plot
        all_features, all_labels = [], []

        # the epoch number does not change while this method runs
        use_center_loss = self._current_epoch_num > 20

        batches = tqdm(enumerate(self._train_loader, start=1),
                       total=len(self._train_loader),
                       leave=False)
        for num_batches, (images, targets) in batches:
            # follow the configured device instead of the default CUDA
            # device so inputs always live next to the model
            images = images.to(self._device, non_blocking=True)
            targets = targets.to(self._device, non_blocking=True)

            # compute output, measure accuracy and record loss
            outputs, features = self._model(images)
            loss = self._criterion(outputs, targets)

            if use_center_loss:
                loss_cent = self._criterion_cent(features, targets)
                loss_cent *= self._configs["cweight"]
                loss = loss + loss_cent

            acc = accuracy(outputs, targets)[0]

            train_loss += loss.item()
            train_acc += acc.item()

            # compute gradient and take one optimizer step
            self._optimizer.zero_grad()
            self._optimizer_cent.zero_grad()
            loss.backward()
            self._optimizer.step()

            if use_center_loss:
                # undo the loss weight on the center gradients so that
                # cweight does not affect how the centers are learned
                for param in self._criterion_cent.parameters():
                    param.grad.data *= 1.0 / self._configs["cweight"]
                self._optimizer_cent.step()

            all_features.append(features.detach().cpu().numpy())
            all_labels.append(targets.detach().cpu().numpy())

        if num_batches == 0:
            raise RuntimeError("training loader produced no batches")

        self._train_loss_list.append(train_loss / num_batches)
        self._train_acc_list.append(train_acc / num_batches)

        # plot the features of the whole epoch
        all_features = np.concatenate(all_features, 0)
        all_labels = np.concatenate(all_labels, 0)
        self._plot_features(all_features, all_labels, prefix="train")

    def _val(self):
        """Evaluate the model on the validation loader for the current epoch.

        Appends the mean batch loss and mean batch accuracy to the validation
        history lists and plots the collected features (the plot is saved
        with the ``"test"`` prefix).
        """
        self._model.eval()

        running_loss, running_acc = 0.0, 0.0
        # collected for the center-loss feature plot
        feature_chunks, label_chunks = [], []

        with torch.no_grad():
            progress = tqdm(enumerate(self._val_loader),
                            total=len(self._val_loader),
                            leave=False)
            for batch_idx, (images, targets) in progress:
                images = images.cuda(non_blocking=True)
                targets = targets.cuda(non_blocking=True)

                # forward pass only; gradients are disabled in this scope
                outputs, features = self._model(images)

                batch_loss = self._criterion(outputs, targets)
                batch_acc = accuracy(outputs, targets)[0]
                running_loss += batch_loss.item()
                running_acc += batch_acc.item()

                feature_chunks.append(features.data.cpu().numpy())
                label_chunks.append(targets.data.cpu().numpy())

            num_batches = batch_idx + 1
            self._val_loss_list.append(running_loss / num_batches)
            self._val_acc_list.append(running_acc / num_batches)

        # plot the collected features of this epoch
        self._plot_features(
            np.concatenate(feature_chunks, 0),
            np.concatenate(label_chunks, 0),
            prefix="test",
        )

    def _calc_acc_on_private_test(self):
        """Measure accuracy on the private test loader (no test-time augmentation).

        The accuracy of every batch is written to ``private_test_log.txt`` as
        ``<batch index>_<accuracy>`` lines.

        Returns:
            The mean of the per-batch accuracies, or ``0.0`` when the test
            loader yields no batches.
        """
        self._model.eval()
        test_acc = 0.0
        num_batches = 0
        print("Calc acc on private test..")

        # The context manager closes the log even if evaluation raises.
        with open("private_test_log.txt", "w") as f, torch.no_grad():
            for i, (images, targets) in tqdm(enumerate(self._test_loader),
                                             total=len(self._test_loader),
                                             leave=False):

                images = images.cuda(non_blocking=True)
                targets = targets.cuda(non_blocking=True)

                outputs = self._model(images)
                # The other methods of this trainer unpack the model result as
                # (outputs, features); keep only the class scores here while
                # still accepting a model that returns a bare tensor.
                if isinstance(outputs, (tuple, list)):
                    outputs = outputs[0]

                acc = accuracy(outputs, targets)[0]
                test_acc += acc.item()
                f.write("{}_{}\n".format(i, acc.item()))
                num_batches += 1

        if num_batches:
            test_acc = test_acc / num_batches
        print("Accuracy on private test: {:.3f}".format(test_acc))
        return test_acc

    def _calc_acc_on_private_test_with_tta(self):
        """Measure accuracy on the private test set using test-time augmentation.

        Each dataset item is turned into a batch with ``make_batch``, the
        softmax outputs of that batch are summed over the batch dimension and
        the summed scores are compared against the single target. Per-sample
        accuracies are written to
        ``private_test_log_<arch>_<model_name>.txt``.

        Returns:
            The mean of the per-sample accuracies, or ``0.0`` when the test
            set is empty.
        """
        self._model.eval()
        test_acc = 0.0
        num_samples = len(self._test_set)
        print("Calc acc on private test..")

        log_name = "private_test_log_{}_{}.txt".format(
            self._configs["arch"], self._configs["model_name"])

        # The context manager closes the log even if evaluation raises.
        with open(log_name, "w") as f, torch.no_grad():
            for idx in tqdm(range(num_samples), total=num_samples,
                            leave=False):
                images, targets = self._test_set[idx]
                targets = torch.LongTensor([targets])

                images = make_batch(images)
                images = images.cuda(non_blocking=True)
                targets = targets.cuda(non_blocking=True)

                outputs, _ = self._model(images)
                outputs = F.softmax(outputs, 1)

                # outputs.shape [tta_size, 7]
                # sum the probabilities of all views, then restore a batch
                # dimension of size one for the accuracy helper
                outputs = torch.sum(outputs, 0)
                outputs = torch.unsqueeze(outputs, 0)

                acc = accuracy(outputs, targets)[0]
                test_acc += acc.item()
                f.write("{}_{}\n".format(idx, acc.item()))

        if num_samples:
            test_acc = test_acc / num_samples
        print("Accuracy on private test: {:.3f}".format(test_acc))
        return test_acc

    def train(self):
        """Run the full training job and write the final summary.

        Epochs are executed until ``_is_stop`` reports a stop condition or the
        user interrupts with Ctrl-C. Afterwards the checkpoint stored at
        ``_checkpoint_path`` is reloaded, the private test accuracy is
        computed (with or without TTA, depending on the test set) and the
        checkpoint is saved again. Failures in that final phase are printed
        but do not abort the summary written to the writer.
        """
        print(self._model)

        try:
            while not self._is_stop():
                self._increase_epoch_num()
                self._train()
                self._val()

                self._update_training_state()
                self._logging()
        except KeyboardInterrupt:
            # an interrupt ends training early but still runs the final test
            traceback.print_exc()

        # training stop
        try:
            state = torch.load(self._checkpoint_path)
            if self._distributed:
                self._model.module.load_state_dict(state["net"])
            else:
                self._model.load_state_dict(state["net"])

            if not self._test_set.is_tta():
                self._test_acc = self._calc_acc_on_private_test()
            else:
                self._test_acc = self._calc_acc_on_private_test_with_tta()
            print(self._test_acc)
            self._save_weights()
        except Exception:
            # best effort: report the problem and still write the summary
            traceback.print_exc()

        # str(timedelta) only has a ".ffffff" suffix when the microseconds are
        # non-zero, so drop the fraction by splitting instead of slicing off a
        # fixed number of characters.
        consume_time = str(datetime.datetime.now() -
                           self._start_time).split(".")[0]
        self._writer.add_text(
            "Summary",
            "Converged after {} epochs, consume {}".format(
                self._current_epoch_num, consume_time),
        )
        self._writer.add_text(
            "Results",
            "Best validation accuracy: {:.3f}".format(self._best_val_acc))
        self._writer.add_text(
            "Results",
            "Best training accuracy: {:.3f}".format(self._best_train_acc))
        self._writer.add_text(
            "Results", "Private test accuracy: {:.3f}".format(self._test_acc))
        self._writer.close()

    def _update_training_state(self):
        if self._val_acc_list[-1] > self._best_val_acc:
            self._save_weights()
            self._plateau_count = 0
            self._best_val_acc = self._val_acc_list[-1]
            self._best_val_loss = self._val_loss_list[-1]
            self._best_train_acc = self._train_acc_list[-1]
            self._best_train_loss = self._train_loss_list[-1]
        else:
            self._plateau_count += 1

        # self._scheduler.step(self._train_loss_list[-1])
        # self._scheduler.step(100 - self._val_acc_list[-1])
        self._scheduler.step()

        if self._current_epoch_num > 20:
            self._center_scheduler.step()

    def _logging(self):
        consume_time = str(datetime.datetime.now() - self._start_time)

        message = "\nE{:03d}  {:.3f}/{:.3f}/{:.3f} {:.3f}/{:.3f}/{:.3f} | p{:02d}  Time {}\n".format(
            self._current_epoch_num,
            self._train_loss_list[-1],
            self._val_loss_list[-1],
            self._best_val_loss,
            self._train_acc_list[-1],
            self._val_acc_list[-1],
            self._best_val_acc,
            self._plateau_count,
            consume_time[:-7],
        )

        self._writer.add_scalar("Accuracy/Train", self._train_acc_list[-1],
                                self._current_epoch_num)
        self._writer.add_scalar("Accuracy/Val", self._val_acc_list[-1],
                                self._current_epoch_num)
        self._writer.add_scalar("Loss/Train", self._train_loss_list[-1],
                                self._current_epoch_num)
        self._writer.add_scalar("Loss/Val", self._val_loss_list[-1],
                                self._current_epoch_num)

        print(message)

    def _is_stop(self):
        """check stop condition"""
        return (self._plateau_count > self._max_plateau_count
                or self._current_epoch_num > self._max_epoch_num)

    def _increase_epoch_num(self):
        self._current_epoch_num += 1

    def _save_weights(self, test_acc=0.0):
        if self._distributed == 0:
            state_dict = self._model.state_dict()
        else:
            state_dict = self._model.module.state_dict()

        state = {
            **self._configs,
            "net": state_dict,
            "best_val_loss": self._best_val_loss,
            "best_val_acc": self._best_val_acc,
            "best_train_loss": self._best_train_loss,
            "best_train_acc": self._best_train_acc,
            "train_losses": self._train_loss_list,
            "val_loss_list": self._val_loss_list,
            "train_acc_list": self._train_acc_list,
            "val_acc_list": self._val_acc_list,
            "test_acc": self._test_acc,
        }

        torch.save(state, self._checkpoint_path)

    def _plot_features(self, features, labels, prefix):
        """Save a 2D scatter plot of the features, coloured by label.

        Only the first two feature columns are drawn, for the label values
        0 to 6. The image is written to
        ``saved/plot/<checkpoint name>/epoch_<epoch>_<prefix>.png``.

        Args:
            features: (num_instances, num_features).
            labels: (num_instances).
            prefix: tag placed in the file name of the saved image.
        """
        class_names = ["0", "1", "2", "3", "4", "5", "6"]
        for class_idx, class_name in enumerate(class_names):
            selected = labels == class_idx
            plt.scatter(
                features[selected, 0],
                features[selected, 1],
                c="C" + class_name,
                s=1,
            )
        plt.legend(class_names, loc="upper right")

        # one output directory per checkpoint name
        checkpoint_name = os.path.basename(self._checkpoint_path)
        plt_dirname = "saved/plot/{}".format(checkpoint_name)
        if not os.path.exists(plt_dirname):
            os.makedirs(plt_dirname, exist_ok=True)

        file_name = "epoch_{}_{}.png".format(self._current_epoch_num, prefix)
        plt.savefig(os.path.join(plt_dirname, file_name), bbox_inches="tight")
        plt.close()
Example #4
0
def train(
        train_data,
        exp_dir=datetime.now().strftime("corrector_model/%Y-%m-%d_%H%M"),
        learning_rate=0.00005,
        rsize=10,
        epochs=1,
        checkpoint_path='',
        seed=6548,
        batch_size=4,
        edge_loss=False,
        model_type='cnet',
        model_cap='normal',
        optimizer_type='radam',
        reset_optimizer=True,  # if true, does not load optimizer checkpoints
        safe_descent=True,
        activation_type='mish',
        activation_args=None,
        io=None,
        dynamic_lr=True,
        dropout=0,
        rotations=False,
        use_batch_norm=True,
        batch_norm_momentum=None,
        batch_norm_affine=True,
        use_gc=True,
        no_lr_schedule=False,
        diff_features_only=False):
    """Train the corrector model and save a checkpoint after every epoch.

    Each epoch draws batches from the training bins, optimizes on them
    (optionally with "safe descent", see below), evaluates on the test and
    validation samples, adjusts the learning rate and appends the model and
    optimizer state to the checkpoint written to
    ``exp_dir + '/corrector_checkpoints.t7'``.

    Args:
        train_data: mapping with the keys ``"train_bins"``,
            ``"test_samples"`` and ``"val_samples"``.
        exp_dir: directory the checkpoint file is written to.
            NOTE(review): the default is evaluated once, when the function
            is defined, not on every call -- confirm that this is intended.
        learning_rate: initial learning rate of the optimizer.
        rsize: forwarded to the model through ``model_args``.
        epochs: number of epochs to run in this call.
        checkpoint_path: checkpoint to resume from; ``''`` or ``None`` starts
            from an empty checkpoint.
        seed: seed for ``torch`` and ``numpy``.
        batch_size: batch size passed to ``utils.drawBinBatches``.
        edge_loss: forwarded to ``getBatchLoss`` / ``getTestLoss``.
        model_type, model_cap, activation_type, activation_args, dropout,
        batch_norm_momentum, batch_norm_affine, diff_features_only:
            forwarded to the model through ``model_args``.
            ``activation_args=None`` is treated as an empty dict.
        optimizer_type: ``'radam'`` or ``'lookahead'``.
        reset_optimizer: when False, the optimizer state is restored from the
            loaded checkpoint as well.
        safe_descent: after every step, re-evaluate the batch and retry with
            a smaller step (up to 10 times) while the loss did not decrease.
        io: logger object providing ``cprint``. Despite the ``None`` default
            it is required, because it is used unconditionally.
        dynamic_lr: scale the learning rate after each epoch (from the second
            one on) by a factor derived from the mean per-batch LR adjustment.
        rotations: pre-compute a random rotation per batch element and pass
            it to ``getBatchLoss``.
        use_batch_norm: forwarded to the model and enables the extra pass
            that accumulates BatchNorm statistics before evaluation.
        use_gc: forwarded to ``RAdam``.
        no_lr_schedule: disable the ``MultiplicativeLR`` scheduler; requires
            ``dynamic_lr``.

    Raises:
        ValueError: if ``optimizer_type`` is not supported.
        AssertionError: if both the LR schedule and dynamic LR are disabled.
    """

    # Fail fast: an unknown optimizer type would otherwise leave the
    # optimizer unset and crash later with an unrelated error.
    supported_optimizers = ('radam', 'lookahead')
    if optimizer_type not in supported_optimizers:
        raise ValueError(
            "unsupported optimizer_type %r (expected one of %s)" %
            (optimizer_type, ', '.join(supported_optimizers)))

    # avoid sharing one mutable default dict between calls
    if activation_args is None:
        activation_args = {}

    has_checkpoint = checkpoint_path is not None and checkpoint_path != ''
    train_data_count = sum(
        train_bin.size(0) for train_bin in train_data["train_bins"])

    start_time = time.time()

    io.cprint("-------------------------------------------------------" +
              "\nexport dir = " + '/checkpoints/' + exp_dir +
              "\nbase_learning_rate = " + str(learning_rate) +
              "\nuse_batch_norm = " + str(use_batch_norm) +
              "\nbatch_norm_momentum = " + str(batch_norm_momentum) +
              "\nbatch_norm_affine = " + str(batch_norm_affine) +
              "\nno_lr_schedule = " + str(no_lr_schedule) + "\nuse_gc = " +
              str(use_gc) + "\nrsize = " + str(rsize) + "\npython_version: " +
              sys.version + "\ntorch_version: " + torch.__version__ +
              "\nnumpy_version: " + np.version.version + "\nmodel_type: " +
              model_type + "\nmodel_cap: " + model_cap + "\noptimizer: " +
              optimizer_type + "\nactivation_type: " + activation_type +
              "\nsafe_descent: " + str(safe_descent) + "\ndynamic_lr: " +
              str(dynamic_lr) + "\nrotations: " + str(rotations) +
              "\nepochs = " + str(epochs) +
              (("\ncheckpoint = " + checkpoint_path) if has_checkpoint else '') +
              "\nseed = " + str(seed) + "\nbatch_size = " + str(batch_size) +
              "\n#train_data = " + str(train_data_count) +
              "\n#test_data = " + str(len(train_data["test_samples"])) +
              "\n#validation_data = " + str(len(train_data["val_samples"])) +
              "\nedge_loss = " + str(edge_loss) +
              "\n-------------------------------------------------------" +
              "\nstart_time: " + datetime.now().strftime("%Y-%m-%d_%H%M%S") +
              "\n-------------------------------------------------------")

    # initialize torch & cuda ---------------------------------------------------------------------

    torch.manual_seed(seed)
    np.random.seed(seed)

    device = utils.getDevice(io)

    # extract train- & test data -----------------------------------------------------------------
    # The data is only converted to float here; training batches are moved to
    # the model's device one at a time inside the training loop.

    train_bins = [train_bin.float() for train_bin in train_data["train_bins"]]
    test_samples = [sample.float() for sample in train_data["test_samples"]]
    val_samples = [sample.float() for sample in train_data["val_samples"]]

    # Initialize Model ------------------------------------------------------------------------------

    model_args = {
        'model_type': model_type,
        'model_cap': model_cap,
        'input_channels': test_samples[0].size(1),
        'output_channels': test_samples[0].size(1),
        'rsize': rsize,
        'emb_dims': 1024,
        'activation_type': activation_type,
        'activation_args': activation_args,
        'dropout': dropout,
        'batch_norm': use_batch_norm,
        'batch_norm_affine': batch_norm_affine,
        'batch_norm_momentum': batch_norm_momentum,
        'diff_features_only': diff_features_only
    }

    model = getModel(model_args).to(device)

    # init optimizer & scheduler -------------------------------------------------------------------

    lookahead_sync_period = 6

    if optimizer_type == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=learning_rate,
                          betas=(0.9, 0.999),
                          eps=1e-8,
                          use_gc=use_gc)
    else:  # 'lookahead' (validated at the top of the function)
        optimizer = Ranger(model.parameters(),
                           lr=learning_rate,
                           alpha=0.9,
                           k=lookahead_sync_period)

    # make sure that either a LR schedule is given or dynamic LR is enabled
    assert dynamic_lr or not no_lr_schedule

    scheduler = None if no_lr_schedule else MultiplicativeLR(
        optimizer, lr_lambda=MultiplicativeAnnealing(epochs))

    # set train settings & load previous model state ------------------------------------------------------------

    checkpoint = getEmptyCheckpoint()
    last_epoch = 0

    if has_checkpoint:
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'][-1])
        if not reset_optimizer:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'][-1])
        # one model state is stored per finished epoch
        last_epoch = len(checkpoint['model_state_dict'])
        print('> loaded checkpoint! (%d epochs)' % (last_epoch))

    checkpoint['train_settings'].append({
        'learning_rate':
        learning_rate,
        'scheduler':
        scheduler,
        'epochs':
        epochs,
        'seed':
        seed,
        'batch_size':
        batch_size,
        'edge_loss':
        edge_loss,
        'optimizer':
        optimizer_type,
        # NOTE(review): this key carries a trailing colon; it is kept as is
        # because existing checkpoints were written with it.
        'safe_descent:':
        str(safe_descent),
        'dynamic_lr':
        str(dynamic_lr),
        'rotations':
        str(rotations),
        'train_data_count':
        train_data_count,
        'test_data_count':
        len(train_data["test_samples"]),
        'validation_data_count':
        len(train_data["val_samples"]),
        'model_args':
        model_args
    })

    # set up report interval (for logging) and loss function -------------------------------------------------------------------

    report_interval = 100
    loss_function = torch.nn.MSELoss(reduction='mean')

    # begin training ###########################################################################################################################

    io.cprint("\nBeginning Training..\n")

    for epoch in range(last_epoch + 1, last_epoch + epochs + 1):

        io.cprint(
            "Epoch: %d ------------------------------------------------------------------------------------------"
            % (epoch))
        io.cprint("Current LR: %.10f" % (optimizer.param_groups[0]['lr']))

        model.train()
        optimizer.zero_grad()

        # open a new per-epoch slot in every per-batch history list
        checkpoint['train_batch_loss'].append([])
        checkpoint['train_batch_N'].append([])
        checkpoint['train_batch_lr_adjust'].append([])
        checkpoint['train_batch_loss_reduction'].append([])
        checkpoint['lr'].append(optimizer.param_groups[0]['lr'])

        # draw random batches from random bins
        binbatches = utils.drawBinBatches(
            [train_bin.size(0) for train_bin in train_bins],
            batchsize=batch_size)

        checkpoint['train_batch_N'][-1] = [
            train_bins[bin_id][batch_ids].size(1)
            for (bin_id, batch_ids) in binbatches
        ]

        failed_loss_optims = 0
        cum_lr_adjust_fac = 0
        cum_loss_reduction = 0

        # pre-compute random rotations if needed
        batch_rotations = [None] * len(binbatches)
        if rotations:
            start_rotations = time.time()
            batch_rotations = torch.zeros(
                (len(binbatches), batch_size, test_samples[0].size(1),
                 test_samples[0].size(1)),
                device=device)
            for i in range(len(binbatches)):
                for j in range(batch_size):
                    batch_rotations[i, j] = utils.getRandomRotation(
                        test_samples[0].size(1), device=device)
            print("created batch rotations (%ds)" %
                  (time.time() - start_rotations))

        b = 0  # batch counter

        train_start = time.time()

        for (bin_id, batch_ids) in binbatches:

            b += 1

            # prediction & loss ----------------------------------------

            batch_sample = train_bins[bin_id][batch_ids].to(
                model.base.device)  # size: (B x N x d x 2)

            batch_loss = getBatchLoss(model,
                                      batch_sample,
                                      loss_function,
                                      edge_loss=edge_loss,
                                      rotations=batch_rotations[b - 1])
            batch_loss.backward()

            checkpoint['train_batch_loss'][-1].append(batch_loss.item())

            new_loss = 0.0
            lr_adjust = 1.0
            loss_reduction = 0.0

            # if safe descent is enabled, try to optimize the descent step so that a reduction in loss is guaranteed
            if safe_descent:

                # create backups to restore states before the optimizer step
                model_state_backup = copy.deepcopy(model.state_dict())
                opt_state_backup = copy.deepcopy(optimizer.state_dict())

                # make an optimizer step
                optimizer.step()

                # in each iteration, check if the optimizer gave an improvement
                # if not, restore the original states, reduce the learning rate and try again
                # no gradient needed for the plain loss calculation
                # NOTE(review): after the last retry the re-stepped model is
                # not evaluated again, so new_loss then still refers to the
                # previous attempt.
                with torch.no_grad():
                    for _ in range(10):

                        new_loss = getBatchLoss(
                            model,
                            batch_sample,
                            loss_function,
                            edge_loss=edge_loss,
                            rotations=batch_rotations[b - 1]).item()

                        # if the model performs better now we continue, if not we try a smaller learning step
                        if (new_loss < batch_loss.item()):
                            break
                        else:
                            model.load_state_dict(model_state_backup)
                            optimizer.load_state_dict(opt_state_backup)
                            lr_adjust *= 0.7
                            optimizer.step(lr_adjust=lr_adjust)

                # relative loss reduction of this batch in percent
                loss_reduction = 100 * (batch_loss.item() -
                                        new_loss) / batch_loss.item()

                if new_loss >= batch_loss.item():
                    failed_loss_optims += 1
                else:
                    cum_lr_adjust_fac += lr_adjust
                    cum_loss_reduction += loss_reduction

            else:

                cum_lr_adjust_fac += lr_adjust
                optimizer.step()

            checkpoint['train_batch_lr_adjust'][-1].append(lr_adjust)
            checkpoint['train_batch_loss_reduction'][-1].append(loss_reduction)

            # reset gradients
            optimizer.zero_grad()

            # statistic calculation and output -------------------------

            if b % report_interval == 0:

                last_100_loss = sum(checkpoint['train_batch_loss'][-1]
                                    [b - report_interval:b]) / report_interval
                # '+' marks an interval whose loss is below the mean train
                # loss of the previous epoch
                improvement_indicator = '+' if epoch > 1 and last_100_loss < checkpoint[
                    'train_loss'][-1] else ''

                io.cprint(
                    '  Batch %4d to %4d | loss: %.10f%1s | av. dist. per neighbor: %.10f | E%3d | T:%5ds | Failed Optims: %3d (%05.2f%%) | Av. Adjust LR: %.6f | Av. Loss Reduction: %07.4f%%'
                    % (b - (report_interval - 1), b, last_100_loss,
                       improvement_indicator, np.sqrt(last_100_loss), epoch,
                       time.time() - train_start, failed_loss_optims, 100 *
                       (failed_loss_optims / report_interval),
                       (cum_lr_adjust_fac /
                        (report_interval - failed_loss_optims)
                        if failed_loss_optims < report_interval else -1),
                       (cum_loss_reduction /
                        (report_interval - failed_loss_optims)
                        if failed_loss_optims < report_interval else -1)))

                # the counters are per report interval
                failed_loss_optims = 0
                cum_lr_adjust_fac = 0
                cum_loss_reduction = 0

        checkpoint['train_loss'].append(
            sum(checkpoint['train_batch_loss'][-1]) / b)
        checkpoint['train_time'].append(time.time() - train_start)

        io.cprint(
            '----\n  TRN | time: %5ds | loss: %.10f| av. dist. per neighbor: %.10f'
            % (checkpoint['train_time'][-1], checkpoint['train_loss'][-1],
               np.sqrt(checkpoint['train_loss'][-1])))

        torch.cuda.empty_cache()

        ####################
        # Test & Validation
        ####################

        with torch.no_grad():

            if use_batch_norm:

                model.eval_bn()

                eval_bn_start = time.time()

                # run through all train samples again to accumulate layer-wise input distribution statistics (mean and variance) with fixed weights
                # these statistics are later used for the BatchNorm layers during inference
                for (bin_id, batch_ids) in binbatches:
                    bn_input = train_bins[bin_id][batch_ids][:, :, :, 0].squeeze(
                        -1)  # size: (B x N x d)
                    # only the forward pass matters here; the output is discarded
                    model(bn_input.transpose(1,
                                             2).to(model.base.device)).transpose(
                                                 1, 2)  # size: (B x N x d)

                io.cprint('Accumulated BN Layer statistics (%ds)' %
                          (time.time() - eval_bn_start))

            model.eval()

            test_start = time.time()

            test_loss = getTestLoss(model,
                                    test_samples,
                                    loss_function,
                                    edge_loss=edge_loss)

            checkpoint['test_loss'].append(test_loss)
            checkpoint['test_time'].append(time.time() - test_start)

            io.cprint(
                '  TST | time: %5ds | loss: %.10f| av. dist. per neighbor: %.10f'
                % (checkpoint['test_time'][-1], checkpoint['test_loss'][-1],
                   np.sqrt(checkpoint['test_loss'][-1])))

            val_start = time.time()

            val_loss = getTestLoss(model,
                                   val_samples,
                                   loss_function,
                                   edge_loss=edge_loss)

            checkpoint['val_loss'].append(val_loss)
            checkpoint['val_time'].append(time.time() - val_start)

            io.cprint(
                '  VAL | time: %5ds | loss: %.10f| av. dist. per neighbor: %.10f'
                % (checkpoint['val_time'][-1], checkpoint['val_loss'][-1],
                   np.sqrt(checkpoint['val_loss'][-1])))

        ####################
        # Scheduler Step
        ####################

        if not no_lr_schedule:
            scheduler.step()

        epoch_lr_adjusts = checkpoint['train_batch_lr_adjust'][-1]
        if epoch > 1 and dynamic_lr and sum(epoch_lr_adjusts) > 0:
            # average of 1 and the mean per-batch adjustment of this epoch
            dynamic_lr_factor = 0.5 * (
                1 + sum(epoch_lr_adjusts) / len(epoch_lr_adjusts))
            io.cprint("----\n  dynamic lr adjust: %.10f" %
                      (dynamic_lr_factor))
            for param_group in optimizer.param_groups:
                param_group['lr'] *= dynamic_lr_factor

        # Save model and optimizer state ..
        checkpoint['model_state_dict'].append(copy.deepcopy(
            model.state_dict()))
        checkpoint['optimizer_state_dict'].append(
            copy.deepcopy(optimizer.state_dict()))

        torch.save(checkpoint, exp_dir + '/corrector_checkpoints.t7')

    io.cprint("\n-------------------------------------------------------" +
              ("\ntotal_time: %.2fh" % ((time.time() - start_time) / 3600)) +
              ("\ntrain_time: %.2fh" %
               (sum(checkpoint['train_time']) / 3600)) +
              ("\ntest_time: %.2fh" % (sum(checkpoint['test_time']) / 3600)) +
              ("\nval_time: %.2fh" % (sum(checkpoint['val_time']) / 3600)) +
              "\n-------------------------------------------------------" +
              "\nend_time: " + datetime.now().strftime("%Y-%m-%d_%H%M%S") +
              "\n-------------------------------------------------------")
Example #5
0
class FER2013Trainer:
    """Train, validate and test an image classifier (FER2013 experiments).

    The trainer runs train/validation epochs until validation accuracy has
    stopped improving for too long or the epoch budget is used up, writes a
    checkpoint each time validation accuracy improves, and finally evaluates
    the best checkpoint on the test set.
    """

    def __init__(self, model, train_set, val_set, test_set, configs):
        """Build the model, data loaders, loss, optimizer and bookkeeping.

        Args:
            model: callable invoked as
                ``model(in_channels=..., num_classes=...)`` that returns the
                network to train.
            train_set: dataset wrapped in a shuffling ``DataLoader``.
            val_set: dataset wrapped in a non-shuffling ``DataLoader``.
            test_set: dataset used for the final evaluation; ``train()``
                additionally calls ``test_set.is_tta()`` on it.
            configs: dict read for the keys ``lr``, ``batch_size``,
                ``momentum``, ``weight_decay``, ``num_workers``, ``device``,
                ``max_epoch_num``, ``max_plateau_count``, ``in_channels``,
                ``num_classes``, ``weighted_loss``, ``plateau_patience``,
                ``min_lr``, ``cwd``, ``log_dir``, ``arch`` and
                ``model_name``.
        """
        print("Start trainer..")
        print(configs)

        # load config
        self._configs = configs
        self._lr = self._configs["lr"]
        self._batch_size = self._configs["batch_size"]
        self._momentum = self._configs["momentum"]
        self._weight_decay = self._configs["weight_decay"]
        self._num_workers = self._configs["num_workers"]
        self._device = torch.device(self._configs["device"])
        self._max_epoch_num = self._configs["max_epoch_num"]
        self._max_plateau_count = self._configs["max_plateau_count"]

        # load dataloader and model
        self._train_set = train_set
        self._val_set = val_set
        self._test_set = test_set
        self._model = model(
            in_channels=configs["in_channels"], num_classes=configs["num_classes"],
        )
        self._model = self._model.to(self._device)

        self._train_loader = DataLoader(
            self._train_set,
            batch_size=self._batch_size,
            num_workers=self._num_workers,
            pin_memory=True,
            shuffle=True,
        )
        self._val_loader = DataLoader(
            self._val_set,
            batch_size=self._batch_size,
            num_workers=self._num_workers,
            pin_memory=True,
            shuffle=False,
        )
        # The test loader yields one sample per batch.
        self._test_loader = DataLoader(
            self._test_set,
            batch_size=1,
            num_workers=self._num_workers,
            pin_memory=True,
            shuffle=False,
        )

        # define loss function (criterion) and optimizer
        if self._configs["weighted_loss"] == 0:
            self._criterion = nn.CrossEntropyLoss().to(self._device)
        else:
            # Fixed per-class weights, one entry per class index.
            class_weights = torch.FloatTensor(
                np.array(
                    [
                        1.02660468,
                        9.40661861,
                        1.00104606,
                        0.56843877,
                        0.84912748,
                        1.29337298,
                        0.82603942,
                    ]
                )
            )
            self._criterion = nn.CrossEntropyLoss(class_weights).to(self._device)

        self._optimizer = RAdam(
            params=self._model.parameters(),
            lr=self._lr,
            weight_decay=self._weight_decay,
        )

        self._scheduler = ReduceLROnPlateau(
            self._optimizer,
            patience=self._configs["plateau_patience"],
            min_lr=self._configs["min_lr"],
            verbose=True,
        )

        # training info
        self._start_time = datetime.datetime.now()
        self._start_time = self._start_time.replace(microsecond=0)

        # The same "<arch>_<model_name>_<start time>" tag names both the
        # tensorboard run directory and the checkpoint file.
        run_tag = "{}_{}_{}".format(
            self._configs["arch"],
            self._configs["model_name"],
            self._start_time.strftime("%Y%b%d_%H.%M"),
        )

        log_dir = os.path.join(
            self._configs["cwd"], self._configs["log_dir"], run_tag
        )
        self._writer = SummaryWriter(log_dir)
        self._train_loss_list = []
        self._train_acc_list = []
        self._val_loss_list = []
        self._val_acc_list = []
        self._best_val_loss = 1e9
        self._best_val_acc = 0
        self._best_train_loss = 1e9
        self._best_train_acc = 0
        self._test_acc = 0.0
        self._plateau_count = 0
        self._current_epoch_num = 0

        # for checkpoints
        self._checkpoint_dir = os.path.join(self._configs["cwd"], "saved/checkpoints")
        os.makedirs(self._checkpoint_dir, exist_ok=True)
        self._checkpoint_path = os.path.join(self._checkpoint_dir, run_tag)

    def queryCheckpointPath(self):
        """Return the path of the checkpoint file written by this trainer."""
        return self._checkpoint_path

    def _train(self):
        """Run one training epoch and record its mean loss and accuracy."""
        self._model.train()
        train_loss = 0.0
        train_acc = 0.0
        num_batches = 0
        for images, targets in tqdm(
            self._train_loader, total=len(self._train_loader), leave=False
        ):
            # Use the configured device rather than the default CUDA device.
            images = images.to(self._device, non_blocking=True)
            targets = targets.to(self._device, non_blocking=True)

            # compute output, measure accuracy and record loss
            outputs = self._model(images)

            loss = self._criterion(outputs, targets)
            acc = accuracy(outputs, targets)[0]

            train_loss += loss.item()
            train_acc += acc.item()

            # compute gradient and take an optimizer step
            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()

            num_batches += 1

        # An empty loader records 0.0 instead of dividing by zero.
        num_batches = max(num_batches, 1)
        self._train_loss_list.append(train_loss / num_batches)
        self._train_acc_list.append(train_acc / num_batches)

    def _val(self):
        """Evaluate on the validation loader and record mean loss/accuracy."""
        self._model.eval()
        val_loss = 0.0
        val_acc = 0.0
        num_batches = 0

        with torch.no_grad():
            for images, targets in tqdm(
                self._val_loader, total=len(self._val_loader), leave=False
            ):
                images = images.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)

                # compute output, measure accuracy and record loss
                outputs = self._model(images)

                loss = self._criterion(outputs, targets)
                acc = accuracy(outputs, targets)[0]

                val_loss += loss.item()
                val_acc += acc.item()
                num_batches += 1

        # An empty loader records 0.0 instead of raising on an unbound index.
        num_batches = max(num_batches, 1)
        self._val_loss_list.append(val_loss / num_batches)
        self._val_acc_list.append(val_acc / num_batches)

    def _calc_acc_on_private_test(self):
        """Return the mean per-batch accuracy over the test loader."""
        self._model.eval()
        test_acc = 0.0
        num_batches = 0
        print("Calc acc on private test..")
        with torch.no_grad():
            for images, targets in tqdm(
                self._test_loader, total=len(self._test_loader), leave=False
            ):
                images = images.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)

                outputs = self._model(images)
                acc = accuracy(outputs, targets)[0]
                test_acc += acc.item()
                num_batches += 1

        test_acc = test_acc / max(num_batches, 1)
        print("Accuracy on private test: {:.3f}".format(test_acc))
        return test_acc

    def _calc_acc_on_private_test_with_tta(self):
        """Return test accuracy using test-time augmentation.

        Each test item may be a list of augmented views of one sample; the
        softmax outputs of all views are summed before scoring.
        """
        self._model.eval()
        test_acc = 0.0
        num_samples = len(self._test_set)
        print("Calc acc on private test with tta..")

        with torch.no_grad():
            for idx in tqdm(range(num_samples), total=num_samples, leave=False):
                images, targets = self._test_set[idx]
                targets = torch.LongTensor([targets])

                # A single (non-augmented) image is treated as a 1-view list.
                if not isinstance(images, list):
                    images = [images]
                images = torch.stack(images, 0)
                images = images.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)

                outputs = self._model(images)
                outputs = F.softmax(outputs, 1)
                # Sum class probabilities over views, then restore the
                # leading batch dimension expected by accuracy().
                outputs = torch.sum(outputs, 0)
                outputs = torch.unsqueeze(outputs, 0)
                acc = accuracy(outputs, targets)[0]
                test_acc += acc.item()

        test_acc = test_acc / max(num_samples, 1)
        print("Accuracy on private test with tta: {:.3f}".format(test_acc))
        return test_acc

    def train(self):
        """Run the full training job, then test and log a summary."""
        print(self._model)
        while not self._is_stop():
            self._increase_epoch_num()
            self._train()
            self._val()
            self._update_training_state()
            self._logging()

        # training stop: restore the best weights if a checkpoint was written
        # (none exists when validation accuracy never rose above zero).
        if os.path.isfile(self._checkpoint_path):
            state = torch.load(self._checkpoint_path)
            self._model.load_state_dict(state["net"])

        if not self._test_set.is_tta():
            self._test_acc = self._calc_acc_on_private_test()
        else:
            self._test_acc = self._calc_acc_on_private_test_with_tta()

        # Re-save so the checkpoint also carries the test accuracy.
        self._save_weights()
        consume_time = str(datetime.datetime.now() - self._start_time)
        self._writer.add_text(
            "Summary",
            "Converged after {} epochs, consume {}".format(
                self._current_epoch_num, consume_time[:-7]
            ),
        )
        self._writer.add_text(
            "Results", "Best validation accuracy: {:.3f}".format(self._best_val_acc)
        )
        self._writer.add_text(
            "Results", "Best training accuracy: {:.3f}".format(self._best_train_acc)
        )
        self._writer.add_text(
            "Results", "Private test accuracy: {:.3f}".format(self._test_acc)
        )
        self._writer.close()

    def _update_training_state(self):
        """Track the best epoch, checkpoint on improvement, step the scheduler."""
        if self._val_acc_list[-1] > self._best_val_acc:
            self._plateau_count = 0
            # Update the best metrics BEFORE saving so the checkpoint
            # records the metrics of the weights it actually contains.
            self._best_val_acc = self._val_acc_list[-1]
            self._best_val_loss = self._val_loss_list[-1]
            self._best_train_acc = self._train_acc_list[-1]
            self._best_train_loss = self._train_loss_list[-1]
            self._save_weights()
        else:
            self._plateau_count += 1
        # The scheduler is fed (100 - accuracy) as the quantity to reduce.
        self._scheduler.step(100 - self._val_acc_list[-1])

    def _logging(self):
        """Print the epoch summary and write scalars to tensorboard."""
        consume_time = str(datetime.datetime.now() - self._start_time)
        message = "\nE{:03d}  {:.3f}/{:.3f}/{:.3f} {:.3f}/{:.3f}/{:.3f} | p{:02d}  Time {}\n".format(
            self._current_epoch_num,
            self._train_loss_list[-1],
            self._val_loss_list[-1],
            self._best_val_loss,
            self._train_acc_list[-1],
            self._val_acc_list[-1],
            self._best_val_acc,
            self._plateau_count,
            consume_time[:-7],
        )
        self._writer.add_scalar(
            "Accuracy/Train", self._train_acc_list[-1], self._current_epoch_num
        )
        self._writer.add_scalar(
            "Accuracy/Val", self._val_acc_list[-1], self._current_epoch_num
        )
        self._writer.add_scalar(
            "Loss/Train", self._train_loss_list[-1], self._current_epoch_num
        )
        self._writer.add_scalar(
            "Loss/Val", self._val_loss_list[-1], self._current_epoch_num
        )

        print(message)

    def _is_stop(self):
        """Return True once the plateau or epoch counter exceeds its maximum."""
        return (
            self._plateau_count > self._max_plateau_count
            or self._current_epoch_num > self._max_epoch_num
        )

    def _increase_epoch_num(self):
        """Advance the epoch counter by one."""
        self._current_epoch_num += 1

    def _save_weights(self, test_acc=0.0):
        """Write configs, model weights and metric history to the checkpoint.

        Args:
            test_acc: unused; kept for backward compatibility. The stored
                test accuracy is taken from ``self._test_acc``.
        """
        state_dict = self._model.state_dict()
        state = {
            **self._configs,
            "net": state_dict,
            "best_val_loss": self._best_val_loss,
            "best_val_acc": self._best_val_acc,
            "best_train_loss": self._best_train_loss,
            "best_train_acc": self._best_train_acc,
            "train_losses": self._train_loss_list,
            "val_loss_list": self._val_loss_list,
            "train_acc_list": self._train_acc_list,
            "val_acc_list": self._val_acc_list,
            "test_acc": self._test_acc,
        }
        torch.save(state, self._checkpoint_path)
Example #6
0
def _save_train_checkpoint(opt, model, margin, total_iters, filename):
    """Save model/margin weights to ``opt.model_dir``, backing up any old file."""
    if opt.distributed:
        # Unwrap DistributedDataParallel so the keys carry no "module." prefix.
        model_state_dict = model.module.state_dict()
        margin_state_dict = margin.module.state_dict()
    else:
        model_state_dict = model.state_dict()
        margin_state_dict = margin.state_dict()
    state = {
        'state_dict': model_state_dict,
        'margin_state_dict': margin_state_dict,
        'total_iters': total_iters,
    }
    target = os.path.join(opt.model_dir, filename)
    if os.path.isfile(target):
        # Keep the previous checkpoint as "<filename>_bak".
        shutil.copy(target, os.path.join(opt.model_dir, filename + '_bak'))
    utils.save_checkpoint(state, opt.model_dir, filename=filename)


def train(opt, logging):
    """Train the speaker model and its margin head, validating periodically.

    Args:
        opt: options object. The attributes read here are ``main_proc``,
            ``dataroot``, ``distributed``, ``batch_size``, ``num_gpus``,
            ``local_rank``, ``num_workers``, ``resume``, ``device``,
            ``optim_type``, ``lr``, ``beta1``, ``total_iters``,
            ``total_epoch``, ``smoothing``, ``clip_grad``, ``ignore_grad``,
            ``print_freq``, ``save_freq``, ``validate_freq`` and
            ``model_dir``. ``in_size`` and ``out_size`` are written from the
            training dataset.
        logging: logger-like object providing ``info``.

    Raises:
        ValueError: if ``opt.optim_type`` is not 'sgd', 'adam' or 'radam'.
    """
    ## Data Prepare ##
    if opt.main_proc:
        logging.info("Building dataset")

    train_dataset = DeepSpeakerUttDataset(opt, os.path.join(opt.dataroot, 'train'))
    if not opt.distributed:
        train_sampler = BucketingSampler(train_dataset, batch_size=opt.batch_size)
    else:
        train_sampler = DistributedBucketingSampler(train_dataset, batch_size=opt.batch_size,
                                                    num_replicas=opt.num_gpus, rank=opt.local_rank)
    train_loader = DeepSpeakerUttDataLoader(train_dataset, num_workers=opt.num_workers, batch_sampler=train_sampler)

    val_dataset = DeepSpeakerTestDataset(opt, os.path.join(opt.dataroot, 'test'))
    val_loader = DeepSpeakerTestDataLoader(val_dataset, batch_size=1, num_workers=opt.num_workers, shuffle=False, pin_memory=True)

    opt.in_size = train_dataset.in_size
    opt.out_size = train_dataset.class_nums
    print('opt.in_size {} opt.out_size {}'.format(opt.in_size, opt.out_size))

    if opt.main_proc:
        logging.info("Building dataset Sucessed")

    ##  Building Model ##
    if opt.main_proc:
        logging.info("Building Model")

    model = model_select(opt)
    margin = margin_select(opt)

    if opt.resume:
        model, opt.total_iters = load(model, opt.resume, 'state_dict')
        margin, opt.total_iters = load(margin, opt.resume, 'margin_state_dict')

    # define optimizers for different layer
    criterion = torch.nn.CrossEntropyLoss().to(opt.device)
    param_groups = [
        {'params': model.parameters(), 'weight_decay': 5e-4},
        {'params': margin.parameters(), 'weight_decay': 5e-4},
    ]
    if opt.optim_type == 'sgd':
        optimizer = optim.SGD(param_groups, lr=opt.lr, momentum=0.9, nesterov=True)
    elif opt.optim_type == 'adam':
        optimizer = optim.Adam(param_groups, lr=opt.lr, betas=(opt.beta1, 0.999))
    elif opt.optim_type == 'radam':
        optimizer = RAdam(param_groups, lr=opt.lr)
    else:
        # Fail with a clear message instead of an unbound `optimizer` below.
        raise ValueError("Unsupported optim_type: {!r}".format(opt.optim_type))

    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 40], gamma=0.1)

    model.to(opt.device)
    margin.to(opt.device)

    if opt.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[opt.local_rank],
                                                          output_device=opt.local_rank)
        margin = torch.nn.parallel.DistributedDataParallel(margin, device_ids=[opt.local_rank],
                                                           output_device=opt.local_rank)
    if opt.main_proc:
        print(model)
        print(margin)
        logging.info("Building Model Sucessed")

    best_perform_eer = 1.0

    losses = utils.AverageMeter()

    # Initial performance
    if opt.main_proc:
        EER = evaluate(opt, model, val_loader, logging)
        best_perform_eer = EER
        print('>>Start performance: EER = {}<<'.format(best_perform_eer))

    total_iters = opt.total_iters
    for epoch in range(1, opt.total_epoch + 1):
        train_sampler.shuffle(epoch)
        # NOTE(review): stepping the scheduler at the start of the epoch is
        # the pre-1.1 PyTorch convention; on newer versions it shifts the
        # milestones one epoch earlier. Confirm the targeted version.
        scheduler.step()
        # train model
        if opt.main_proc:
            logging.info('Train Epoch: {}/{} ...'.format(epoch, opt.total_epoch))
        model.train()
        margin.train()

        since = time.time()
        for _utt_ids, inputs, targets in train_loader:
            inputs, label = inputs.to(opt.device), targets.to(opt.device)
            optimizer.zero_grad()

            raw_logits, attn, w, b = model(inputs)
            output = margin(raw_logits, label)
            #loss = criterion(output, label)
            loss = cal_loss(output, label, criterion, smoothing=opt.smoothing)
            loss_dict_reduced = reduce_loss_dict(opt, {'loss': loss})
            losses_reduced = sum(loss_dict_reduced.values())
            loss_value = losses_reduced.item()

            # Skip the batch when the loss is inf OR NaN; back-propagating
            # either would corrupt the weights.
            if not np.isfinite(loss_value):
                print("WARNING: received a non-finite loss, skipping this batch")
                continue

            loss.backward()
            if utils.check_grad(model.parameters(), opt.clip_grad, opt.ignore_grad):
                if opt.main_proc:
                    logging.info('Not a finite gradient or too big, ignoring')
                optimizer.zero_grad()
                continue
            optimizer.step()

            total_iters += opt.num_gpus
            losses.update(loss_value)

            # print train information
            if total_iters % opt.print_freq == 0 and opt.main_proc:
                # current training accuracy
                _, predict = torch.max(output.data, 1)
                total = label.size(0)
                correct = (np.array(predict.cpu()) == np.array(label.data.cpu())).sum()
                # NOTE(review): the divisor 100 is hard-coded; the reported
                # s/iter is only exact if 100 iterations separate two prints.
                time_cur = (time.time() - since) / 100
                since = time.time()
                logging.info("Iters: {:0>6d}/[{:0>2d}], loss: {:.4f} ({:.4f}), train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}".format(total_iters, epoch, loss_value, losses.avg, correct/total, time_cur, scheduler.get_lr()[0]))

            # save model
            if total_iters % opt.save_freq == 0 and opt.main_proc:
                logging.info('Saving checkpoint: {}'.format(total_iters))
                _save_train_checkpoint(opt, model, margin, total_iters, 'newest_model.pth')

            # Validate the trained model
            if total_iters % opt.validate_freq == 0:
                EER = evaluate(opt, model, val_loader, logging)
                ##scheduler.step(EER)

                if opt.main_proc and EER < best_perform_eer:
                    best_perform_eer = EER
                    logging.info("Found better validated model (EER = %.3f), saving to model_best.pth" % (best_perform_eer))
                    _save_train_checkpoint(opt, model, margin, total_iters, 'model_best.pth')

                # Switch back to training mode after validation and restart
                # the running loss average.
                model.train()
                margin.train()
                losses.reset()
                   
class TeeTrainer(Trainer):
    """Trainer for the segmentation task.

    Fine-tunes a model initialised from a saved checkpoint, tracking
    loss/accuracy per epoch and saving weights whenever validation accuracy
    improves.
    """
    def __init__(self, model, train_set, val_set, configs):
        """Build the model, data loaders, loss, optimizer and bookkeeping.

        Args:
            model: callable invoked as
                ``model(in_channels=..., num_classes=...)`` that returns the
                network to train.
            train_set: dataset wrapped in a shuffling ``DataLoader``.
            val_set: dataset wrapped in a non-shuffling ``DataLoader``.
            configs: dict read for the keys ``lr``, ``batch_size``,
                ``momentum``, ``weight_decay``, ``distributed``,
                ``num_workers``, ``device``, ``max_epoch_num``,
                ``max_plateau_count``, ``in_channels``, ``num_classes``,
                ``plateau_patience``, ``cwd``, ``log_dir``, ``model_name``
                and (during training) ``little``.
        """
        super().__init__()
        print("Start trainer..")
        # load config
        self._configs = configs
        self._lr = self._configs["lr"]
        self._batch_size = self._configs["batch_size"]
        self._momentum = self._configs["momentum"]
        self._weight_decay = self._configs["weight_decay"]
        self._distributed = self._configs["distributed"]
        self._num_workers = self._configs["num_workers"]
        self._device = torch.device(self._configs["device"])
        self._max_epoch_num = self._configs["max_epoch_num"]
        self._max_plateau_count = self._configs["max_plateau_count"]

        # load dataloader and model
        self._train_set = train_set
        self._val_set = val_set
        self._model = model(
            in_channels=configs["in_channels"],
            num_classes=configs["num_classes"],
        )
        # NOTE(review): the initial weights come from a hard-coded path that
        # is resolved against the current working directory.
        self._model.load_state_dict(
            torch.load("saved/checkpoints/mixed.test")["net"])

        print(self._configs)
        self._model = self._model.to(self._device)

        # Both modes share the loader settings; only the distributed mode
        # seeds numpy in each worker.
        loader_kwargs = dict(
            batch_size=self._batch_size,
            num_workers=self._num_workers,
            pin_memory=True,
        )
        if self._distributed == 1:
            torch.distributed.init_process_group(backend="nccl")
            self._model = nn.parallel.DistributedDataParallel(
                self._model, find_unused_parameters=True)
            loader_kwargs["worker_init_fn"] = lambda x: np.random.seed(x)

        self._train_loader = DataLoader(
            self._train_set, shuffle=True, **loader_kwargs)
        self._val_loader = DataLoader(
            self._val_set, shuffle=False, **loader_kwargs)

        # define loss function (criterion) and optimizer
        self._criterion = nn.CrossEntropyLoss().to(self._device)

        self._optimizer = RAdam(
            params=self._model.parameters(),
            lr=self._lr,
            weight_decay=self._weight_decay,
        )

        self._scheduler = ReduceLROnPlateau(
            self._optimizer,
            patience=self._configs["plateau_patience"],
            verbose=True)

        # training info
        self._start_time = datetime.datetime.now()
        self._start_time = self._start_time.replace(microsecond=0)

        log_dir = os.path.join(
            self._configs["cwd"],
            self._configs["log_dir"],
            "{}_{}".format(self._configs["model_name"], str(self._start_time)),
        )

        self._writer = SummaryWriter(log_dir)
        self._train_loss = []
        self._train_acc = []
        self._val_loss = []
        self._val_acc = []
        self._best_loss = 1e9
        self._best_acc = 0
        self._plateau_count = 0
        self._current_epoch_num = 0

    def reset(self):
        """reset trainer"""
        pass

    def _train(self):
        """Run one training epoch and record its mean loss and accuracy."""
        self._model.train()
        train_loss = 0.0
        train_acc = 0.0
        num_batches = 0

        for i, (images, targets) in tqdm(enumerate(self._train_loader),
                                         total=len(self._train_loader),
                                         leave=False):
            # Use the configured device rather than the default CUDA device.
            images = images.to(self._device, non_blocking=True)
            targets = targets.to(self._device, non_blocking=True)

            # compute output, measure accuracy and record loss
            outputs = self._model(images)

            loss = self._criterion(outputs, targets)
            acc = accuracy(outputs, targets)[0]

            train_loss += loss.item()
            train_acc += acc.item()

            # compute gradient and take an optimizer step
            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()

            # log the first batch of every epoch as an image grid
            if i == 0:
                grid = torchvision.utils.make_grid(images)
                self._writer.add_image("images", grid, 0)

            # Debug aid: dump channel 1 of the prediction as a PNG.
            # NOTE(review): the squeeze on dim 0 only drops the batch axis
            # when the batch size is 1 - confirm "little" implies that.
            if self._configs["little"] == 1:
                mask = torch.squeeze(outputs, 0)
                mask = mask.detach().cpu().numpy() * 255
                mask = np.transpose(mask, (1, 2, 0)).astype(np.uint8)
                cv2.imwrite(
                    os.path.join("debug",
                                 "e{}.png".format(self._current_epoch_num)),
                    mask[..., 1],
                )

            num_batches = i + 1

        # An empty loader records 0.0 instead of raising on an unbound index.
        num_batches = max(num_batches, 1)
        self._train_loss.append(train_loss / num_batches)
        self._train_acc.append(train_acc / num_batches)

    def _val(self):
        """Evaluate on the validation loader and record mean loss/accuracy."""
        self._model.eval()
        val_loss = 0.0
        val_acc = 0.0
        num_batches = 0

        # Debug aid: clear images dumped by the previous validation pass.
        os.system("rm -rf debug/*")
        # No gradients are needed for evaluation; skipping them saves memory.
        with torch.no_grad():
            for i, (images, targets) in tqdm(enumerate(self._val_loader),
                                             total=len(self._val_loader),
                                             leave=False):
                images = images.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)

                # compute output, measure accuracy and record loss
                outputs = self._model(images)

                loss = self._criterion(outputs, targets)
                acc = accuracy(outputs, targets)[0]

                val_loss += loss.item()
                val_acc += acc.item()

                # debug time
                # NOTE(review): the output path embeds the argmax result and
                # presumably expects batch size 1 with an existing
                # debug/<class> directory - confirm before relying on it.
                outputs = torch.squeeze(outputs, dim=0)
                outputs = torch.argmax(outputs, dim=0)
                tmp_image = torch.squeeze(images, dim=0)
                print(tmp_image.shape)
                tmp_image = tmp_image.cpu().numpy()
                cv2.imwrite("debug/{}/{}.png".format(outputs, i), tmp_image)

                num_batches = i + 1

        # An empty loader records 0.0 instead of raising on an unbound index.
        num_batches = max(num_batches, 1)
        self._val_loss.append(val_loss / num_batches)
        self._val_acc.append(val_acc / num_batches)

    def train(self):
        """make a training job"""
        while not self._is_stop():
            self._train()
            self._val()

            self._update_training_state()
            self._logging()
            self._increase_epoch_num()

        self._writer.close()  # be careful with this line of code

    def _update_training_state(self):
        """Track the best epoch, checkpoint on improvement, step the scheduler."""
        if self._val_acc[-1] > self._best_acc:
            self._plateau_count = 0
            # Update the best metrics BEFORE saving so the checkpoint
            # records the metrics of the weights it actually contains.
            self._best_acc = self._val_acc[-1]
            self._best_loss = self._val_loss[-1]
            self._save_weights()
        else:
            self._plateau_count += 1
        self._scheduler.step(self._val_loss[-1])

    def _logging(self):
        """Print the epoch summary and write scalars to tensorboard."""
        # TODO: save message to log file, tensorboard then
        consume_time = str(datetime.datetime.now() - self._start_time)

        message = "\nE{:03d}  {:.3f}/{:.3f}/{:.3f} {:.3f}/{:.3f}/{:.3f} | p{:02d}  Time {}\n".format(
            self._current_epoch_num,
            self._train_loss[-1],
            self._val_loss[-1],
            self._best_loss,
            self._train_acc[-1],
            self._val_acc[-1],
            self._best_acc,
            self._plateau_count,
            consume_time[:-7],
        )

        self._writer.add_scalar("Accuracy/train", self._train_acc[-1],
                                self._current_epoch_num)
        self._writer.add_scalar("Accuracy/val", self._val_acc[-1],
                                self._current_epoch_num)
        self._writer.add_scalar("Loss/train", self._train_loss[-1],
                                self._current_epoch_num)
        self._writer.add_scalar("Loss/val", self._val_loss[-1],
                                self._current_epoch_num)

        print(message)

    def _is_stop(self):
        """check stop condition"""
        return (self._plateau_count > self._max_plateau_count
                or self._current_epoch_num > self._max_epoch_num)

    def _increase_epoch_num(self):
        """Advance the epoch counter by one."""
        self._current_epoch_num += 1

    def _store_trainer(self):
        """store config, training info and traning result to file"""
        pass

    def _save_weights(self):
        """save checkpoint"""
        # Unwrap DistributedDataParallel so the keys carry no "module." prefix.
        if self._distributed == 0:
            state_dict = self._model.state_dict()
        else:
            state_dict = self._model.module.state_dict()
        state = {
            **self._configs,
            "net": state_dict,
            "best_loss": self._best_loss,
            "best_acc": self._best_acc,
        }

        checkpoint_dir = os.path.join(self._configs["cwd"],
                                      "saved/checkpoints")
        os.makedirs(checkpoint_dir, exist_ok=True)

        torch.save(state,
                   os.path.join(checkpoint_dir, self._configs["model_name"]))
Example #8
0
        if i == 1:
            timer.start()

        lr = adjust_lr_iter(cfg, optimizer, i)

        img = img.cuda().detach()
        target = label.cuda().detach()

        with timer.counter('forward'):
            output = model(img)

        with timer.counter('loss'):
            loss = criterion(output, target)

        with timer.counter('backward'):
            optimizer.zero_grad()
            loss.backward()

        with timer.counter('update'):
            optimizer.step()

        time_this = time.time()
        if i > 0:
            batch_time = time_this - time_last
            timer.add_batch_time(batch_time)
        time_last = time_this

        if i > 0 and i % 10 == 0:
            time_name = [
                'batch', 'data', 'forward', 'loss', 'backward', 'update'
            ]
Example #9
0
class FER2013Trainer(Trainer):
    """Trainer for the FER2013 classification task.

    Builds the network, the three data loaders, the loss, a RAdam optimizer
    and a ReduceLROnPlateau scheduler from ``configs``. ``train()`` then runs
    epochs until ``_is_stop()`` fires, writing a checkpoint every time the
    validation accuracy improves, and finally evaluates the best checkpoint
    on the test set.
    """
    def __init__(self, model, train_set, val_set, test_set, configs):
        """Set up model, loaders, loss, optimizer, scheduler and logging.

        Args:
            model: callable invoked as
                ``model(in_channels=..., num_classes=...)`` to build the
                network.
            train_set: dataset wrapped by the (shuffled) training loader.
            val_set: dataset wrapped by the validation loader.
            test_set: dataset wrapped by the test loader.
            configs: mapping providing every key read below (``lr``,
                ``batch_size``, ``momentum``, ``weight_decay``,
                ``distributed``, ``num_workers``, ``device``,
                ``max_epoch_num``, ``max_plateau_count``, ``in_channels``,
                ``num_classes``, ``weighted_loss``, ``plateau_patience``,
                ``cwd``, ``log_dir``, ``model_name``).
        """
        super().__init__()
        print("Start trainer..")
        print(configs)

        # load config
        self._configs = configs
        self._lr = self._configs["lr"]
        self._batch_size = self._configs["batch_size"]
        self._momentum = self._configs["momentum"]
        self._weight_decay = self._configs["weight_decay"]
        self._distributed = self._configs["distributed"]
        self._num_workers = self._configs["num_workers"]
        self._device = torch.device(self._configs["device"])
        self._max_epoch_num = self._configs["max_epoch_num"]
        self._max_plateau_count = self._configs["max_plateau_count"]

        # load dataloader and model
        self._train_set = train_set
        self._val_set = val_set
        self._test_set = test_set
        self._model = model(
            in_channels=configs["in_channels"],
            num_classes=configs["num_classes"],
        )

        # NOTE(review): the classifier head is hard-wired to 512 inputs and
        # 7 outputs regardless of configs["num_classes"]; 7 also matches the
        # number of class weights below -- confirm before changing either.
        self._model.fc = nn.Linear(512, 7)
        self._model = self._model.to(self._device)

        # Settings shared by all three loaders; only ``shuffle`` differs.
        loader_kwargs = {
            "batch_size": self._batch_size,
            "num_workers": self._num_workers,
            "pin_memory": True,
        }

        if self._distributed == 1:
            torch.distributed.init_process_group(backend="nccl")
            self._model = nn.parallel.DistributedDataParallel(
                self._model, find_unused_parameters=True)
            # Seed numpy in every worker process with its worker id.
            loader_kwargs["worker_init_fn"] = np.random.seed

        self._train_loader = DataLoader(self._train_set,
                                        shuffle=True,
                                        **loader_kwargs)
        self._val_loader = DataLoader(self._val_set,
                                      shuffle=False,
                                      **loader_kwargs)
        self._test_loader = DataLoader(self._test_set,
                                       shuffle=False,
                                       **loader_kwargs)

        # define loss function (criterion) and optimizer
        class_weights = [
            1.02660468,
            9.40661861,
            1.00104606,
            0.56843877,
            0.84912748,
            1.29337298,
            0.82603942,
        ]
        class_weights = torch.FloatTensor(np.array(class_weights))

        if self._configs["weighted_loss"] == 0:
            self._criterion = nn.CrossEntropyLoss().to(self._device)
        else:
            self._criterion = nn.CrossEntropyLoss(class_weights).to(
                self._device)

        self._optimizer = RAdam(
            params=self._model.parameters(),
            lr=self._lr,
            weight_decay=self._weight_decay,
        )

        self._scheduler = ReduceLROnPlateau(
            self._optimizer,
            patience=self._configs["plateau_patience"],
            min_lr=1e-6,
            verbose=True,
        )

        # training info
        self._start_time = datetime.datetime.now()
        self._start_time = self._start_time.replace(microsecond=0)

        # Log directory and checkpoint file share the same
        # "<model_name>_<start time>" suffix.
        run_name = "{}_{}".format(self._configs["model_name"],
                                  self._start_time.strftime("%Y%b%d_%H.%M"))

        log_dir = os.path.join(
            self._configs["cwd"],
            self._configs["log_dir"],
            run_name,
        )

        self._writer = SummaryWriter(log_dir)
        self._train_loss = []
        self._train_acc = []
        self._val_loss = []
        self._val_acc = []
        self._best_loss = 1e9
        self._best_acc = 0
        self._test_acc = 0.0
        self._plateau_count = 0
        self._current_epoch_num = 0

        # for checkpoints
        self._checkpoint_dir = os.path.join(self._configs["cwd"],
                                            "saved/checkpoints")
        # exist_ok makes a separate existence check unnecessary.
        os.makedirs(self._checkpoint_dir, exist_ok=True)

        self._checkpoint_path = os.path.join(self._checkpoint_dir, run_name)

    def _train(self):
        """Run one training epoch and record its mean loss and accuracy."""
        self._model.train()
        train_loss = 0.0
        train_acc = 0.0
        num_batches = 0

        for images, targets in tqdm(self._train_loader,
                                    total=len(self._train_loader),
                                    leave=False):
            # Use the configured device so inputs always sit next to the
            # model, whatever configs["device"] says.
            images = images.to(self._device, non_blocking=True)
            targets = targets.to(self._device, non_blocking=True)

            # compute output, measure accuracy and record loss
            outputs = self._model(images)

            loss = self._criterion(outputs, targets)
            acc = accuracy(outputs, targets)[0]

            train_loss += loss.item()
            train_acc += acc.item()
            num_batches += 1

            # compute gradient and do an optimizer step
            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()

        # Guard the averages against an empty loader.
        num_batches = max(num_batches, 1)
        self._train_loss.append(train_loss / num_batches)
        self._train_acc.append(train_acc / num_batches)

    def _val(self):
        """Evaluate on the validation set and record mean loss/accuracy."""
        self._model.eval()
        val_loss = 0.0
        val_acc = 0.0
        num_batches = 0

        with torch.no_grad():
            for images, targets in tqdm(self._val_loader,
                                        total=len(self._val_loader),
                                        leave=False):
                images = images.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)

                # compute output, measure accuracy and record loss
                outputs = self._model(images)

                loss = self._criterion(outputs, targets)
                acc = accuracy(outputs, targets)[0]

                val_loss += loss.item()
                val_acc += acc.item()
                num_batches += 1

        # Guard the averages against an empty loader.
        num_batches = max(num_batches, 1)
        self._val_loss.append(val_loss / num_batches)
        self._val_acc.append(val_acc / num_batches)

    def _calc_acc_on_private_test(self):
        """Return the mean per-batch accuracy on the test loader."""
        self._model.eval()
        test_acc = 0.0
        num_batches = 0
        print("Calc acc on private test..")

        with torch.no_grad():
            for images, targets in tqdm(self._test_loader,
                                        total=len(self._test_loader),
                                        leave=False):
                images = images.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)

                outputs = self._model(images)
                acc = accuracy(outputs, targets)[0]
                test_acc += acc.item()
                num_batches += 1

        test_acc = test_acc / max(num_batches, 1)
        print("Accuracy on private test: {:.3f}".format(test_acc))
        return test_acc

    def _calc_acc_on_private_test_with_tta(self):
        """Return test accuracy; test-time augmentation is not done yet.

        The evaluation itself never applied any augmentation, so this
        delegates to ``_calc_acc_on_private_test``. The previous body also
        looped over ``len(self._test_set)`` (an int), which raised TypeError
        before any batch was evaluated.
        """
        # TODO: implement augment when predict
        return self._calc_acc_on_private_test()

    def train(self):
        """Run the whole training job, then test the best checkpoint.

        Epochs run until ``_is_stop()`` returns True. Afterwards the best
        checkpoint is reloaded, evaluated on the test set and saved again so
        the stored state carries the test accuracy. A failure in that final
        phase is printed, not raised, so the summary is still written.
        """
        print(self._model)
        while not self._is_stop():
            self._increase_epoch_num()
            self._train()
            self._val()

            self._update_training_state()
            self._logging()

        # training stop
        try:
            state = torch.load(self._checkpoint_path)
            if self._distributed:
                self._model.module.load_state_dict(state["net"])
            else:
                self._model.load_state_dict(state["net"])
            # Keep the result on the instance: both _save_weights() and the
            # summary below read self._test_acc.
            self._test_acc = self._calc_acc_on_private_test_with_tta()
            self._save_weights()
        except Exception as e:
            print("Testing error when training stop")
            print(e)

        self._writer.add_text(
            "Summary",
            "Converged after {} epochs".format(self._current_epoch_num))
        self._writer.add_text(
            "Summary",
            "Best validation accuracy: {:.3f}".format(self._best_acc),
        )
        self._writer.add_text(
            "Summary", "Private test accuracy: {:.3f}".format(self._test_acc))
        self._writer.close()

    def _update_training_state(self):
        """Checkpoint on a new best validation accuracy, else count plateau."""
        if self._val_acc[-1] > self._best_acc:
            self._save_weights()
            self._plateau_count = 0
            self._best_acc = self._val_acc[-1]
            self._best_loss = self._val_loss[-1]
        else:
            self._plateau_count += 1

        # The scheduler is stepped on (100 - accuracy).
        # NOTE(review): assumes accuracy() reports percentages -- confirm.
        self._scheduler.step(100 - self._val_acc[-1])

    def _logging(self):
        """Print the epoch summary line and push scalars to TensorBoard."""
        elapsed = datetime.datetime.now() - self._start_time
        # Drop the fractional seconds; splitting on "." also works when the
        # timedelta happens to have no microsecond part.
        consume_time = str(elapsed).split(".")[0]

        message = "\nE{:03d}  {:.3f}/{:.3f}/{:.3f} {:.3f}/{:.3f}/{:.3f} | p{:02d}  Time {}\n".format(
            self._current_epoch_num,
            self._train_loss[-1],
            self._val_loss[-1],
            self._best_loss,
            self._train_acc[-1],
            self._val_acc[-1],
            self._best_acc,
            self._plateau_count,
            consume_time,
        )

        self._writer.add_scalar("Accuracy/Train", self._train_acc[-1],
                                self._current_epoch_num)
        self._writer.add_scalar("Accuracy/Val", self._val_acc[-1],
                                self._current_epoch_num)
        self._writer.add_scalar("Loss/Train", self._train_loss[-1],
                                self._current_epoch_num)
        self._writer.add_scalar("Loss/Val", self._val_loss[-1],
                                self._current_epoch_num)

        print(message)

    def _is_stop(self):
        """Return True when the plateau or epoch limit has been exceeded."""
        return (self._plateau_count > self._max_plateau_count
                or self._current_epoch_num > self._max_epoch_num)

    def _increase_epoch_num(self):
        """Advance the epoch counter by one."""
        self._current_epoch_num += 1

    def _save_weights(self, test_acc=0.0):
        """Write configs, weights and training history to the checkpoint.

        Args:
            test_acc: unused; kept so existing callers keep working. The
                stored value is always ``self._test_acc``.
        """
        # In distributed mode the network is wrapped, so the real weights
        # live on ``.module``.
        if self._distributed == 0:
            state_dict = self._model.state_dict()
        else:
            state_dict = self._model.module.state_dict()

        state = {
            **self._configs,
            "net": state_dict,
            "best_loss": self._best_loss,
            "best_acc": self._best_acc,
            "train_losses": self._train_loss,
            "val_loss": self._val_loss,
            "train_acc": self._train_acc,
            "val_acc": self._val_acc,
            "test_acc": self._test_acc,
        }

        torch.save(state, self._checkpoint_path)
Example #10
0
def train(train_data,
          exp_dir=datetime.now().strftime("detector_model/%Y-%m-%d_%H%M"),
          learning_rate=0.00005,
          rsize=10,
          epochs=1,
          checkpoint_path='',
          seed=6548,
          batch_size=4,
          model_type='cnet',
          model_cap='normal',
          optimizer='radam',
          safe_descent=True,
          activation_type='mish',
          activation_args={},
          io=None,
          dynamic_lr=True,
          dropout=0,
          rotations=False,
          use_batch_norm=True,
          batch_norm_momentum=None,
          batch_norm_affine=True,
          use_gc=True,
          no_lr_schedule=False,
          diff_features_only=False,
          scale_min=1,
          scale_max=1,
          noise=0):
    """Train the two-class point model and checkpoint after every epoch.

    Each epoch draws batches from ``train_data["train_bins"]``, optionally
    augments them (rotation, noise, scaling), takes an optimizer step
    (optionally with the "safe descent" retry loop), then evaluates on the
    test and validation samples and appends the model/optimizer state to
    ``<exp_dir>/detector_checkpoints.t7``.

    Args:
        train_data: mapping with the keys ``"pts"``, ``"val_pts"``,
            ``"train_bins"``, ``"test_samples"``, ``"val_samples"`` and
            ``"knn"``.
        exp_dir: directory receiving the checkpoint file. NOTE: the default
            is evaluated once, when the function is defined.
        learning_rate: initial learning rate of the optimizer.
        rsize, model_type, model_cap, activation_type, activation_args,
            dropout, use_batch_norm, batch_norm_momentum, batch_norm_affine,
            diff_features_only: forwarded to ``getModel`` via ``model_args``.
        epochs: number of epochs to run in this call.
        checkpoint_path: when non-empty, resume from this checkpoint.
        seed: seed for torch and numpy.
        batch_size: batch size used when drawing bin batches.
        optimizer: ``'radam'`` or ``'lookahead'``.
        safe_descent: retry each step with a smaller learning-rate factor
            (up to 10 times) until the batch loss decreases.
        io: logger exposing ``cprint``; required despite the ``None``
            default.
        dynamic_lr: scale the learning rate after each epoch by the mean
            per-batch adjustment factor.
        rotations: apply a random rotation to every sample.
        use_gc: forwarded to ``RAdam``.
        no_lr_schedule: disable the multiplicative LR schedule.
        scale_min, scale_max: range of the random per-sample scale factor;
            clamped so that ``scale_min <= 1 <= scale_max``.
        noise: strength of the random offset added to every point, relative
            to the distance to its nearest neighbour.

    Raises:
        ValueError: if ``optimizer`` names an unsupported optimizer.
        AssertionError: if ``no_lr_schedule`` is set while ``dynamic_lr``
            is off.
    """

    start_time = time.time()

    scale_min = scale_min if scale_min < 1 else 1
    scale_max = scale_max if scale_max > 1 else 1

    train_data_count = sum(
        train_bin.size(0) for train_bin in train_data["train_bins"])

    io.cprint("-------------------------------------------------------" +
              "\nexport dir = " + '/checkpoints/' + exp_dir +
              "\nbase_learning_rate = " + str(learning_rate) +
              "\nuse_batch_norm = " + str(use_batch_norm) +
              "\nbatch_norm_momentum = " + str(batch_norm_momentum) +
              "\nbatch_norm_affine = " + str(batch_norm_affine) +
              "\nno_lr_schedule = " + str(no_lr_schedule) + "\nuse_gc = " +
              str(use_gc) + "\nrsize = " + str(rsize) + "\npython_version: " +
              sys.version + "\ntorch_version: " + torch.__version__ +
              "\nnumpy_version: " + np.version.version + "\nmodel_type: " +
              model_type + "\nmodel_cap: " + model_cap + "\noptimizer: " +
              optimizer + "\nactivation_type: " + activation_type +
              "\nsafe_descent: " + str(safe_descent) + "\ndynamic_lr: " +
              str(dynamic_lr) + "\nrotations: " + str(rotations) +
              "\nscaling: " + str(scale_min) + " to " + str(scale_max) +
              "\nnoise: " + str(noise) + "\nepochs = " + str(epochs) +
              (("\ncheckpoint = " +
                checkpoint_path) if checkpoint_path != '' else '') +
              "\nseed = " + str(seed) + "\nbatch_size = " + str(batch_size) +
              "\n#train_data = " + str(train_data_count) +
              "\n#test_data = " + str(len(train_data["test_samples"])) +
              "\n#validation_data = " + str(len(train_data["val_samples"])) +
              "\n-------------------------------------------------------" +
              "\nstart_time: " + datetime.now().strftime("%Y-%m-%d_%H%M%S") +
              "\n-------------------------------------------------------")

    # initialize torch & cuda ---------------------------------------------------------------------

    torch.manual_seed(seed)
    np.random.seed(seed)

    device = utils.getDevice(io)

    # extract train- & test data (and move to device) --------------------------------------------

    pts = train_data["pts"].to(device)
    val_pts = train_data["val_pts"].to(device)

    train_bins = train_data["train_bins"]
    test_samples = train_data["test_samples"]
    val_samples = train_data["val_samples"]

    # the maximum noise offset for each point is equal to the distance to its nearest neighbor
    max_noise = torch.square(pts[train_data["knn"][:, 0]] -
                             pts).sum(dim=1).sqrt()

    # Initialize Model ------------------------------------------------------------------------------

    model_args = {
        'model_type': model_type,
        'model_cap': model_cap,
        'input_channels': pts.size(1),
        'output_channels': 2,
        'rsize': rsize,
        'emb_dims': 1024,
        'activation_type': activation_type,
        'activation_args': activation_args,
        'dropout': dropout,
        'batch_norm': use_batch_norm,
        'batch_norm_affine': batch_norm_affine,
        'batch_norm_momentum': batch_norm_momentum,
        'diff_features_only': diff_features_only
    }

    model = getModel(model_args).to(device)

    # init optimizer & scheduler -------------------------------------------------------------------

    lookahead_sync_period = 6

    if optimizer == 'radam':
        opt = RAdam(model.parameters(),
                    lr=learning_rate,
                    betas=(0.9, 0.999),
                    eps=1e-8,
                    use_gc=use_gc)
    elif optimizer == 'lookahead':
        opt = Ranger(model.parameters(),
                     lr=learning_rate,
                     alpha=0.9,
                     k=lookahead_sync_period)
    else:
        # Fail here with a clear message instead of carrying opt=None into
        # the scheduler / training loop.
        raise ValueError("unsupported optimizer: %r" % (optimizer, ))

    # make sure that either a LR schedule is given or dynamic LR is enabled
    assert dynamic_lr or not no_lr_schedule

    scheduler = None if no_lr_schedule else MultiplicativeLR(
        opt, lr_lambda=MultiplicativeAnnealing(epochs))

    # set train settings & load previous model state ------------------------------------------------------------

    checkpoint = getEmptyCheckpoint()
    last_epoch = 0

    if (checkpoint_path != ''):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'][-1])
        # `optimizer` is only the optimizer's *name*; the state belongs to
        # the optimizer object `opt`.
        opt.load_state_dict(checkpoint['optimizer_state_dict'][-1])
        last_epoch = len(checkpoint['model_state_dict'])
        print('> loaded checkpoint! (%d epochs)' % (last_epoch))

    checkpoint['train_settings'].append({
        'learning_rate':
        learning_rate,
        'scheduler':
        scheduler,
        'epochs':
        epochs,
        'seed':
        seed,
        'batch_size':
        batch_size,
        'optimizer':
        optimizer,
        # NOTE(review): the trailing colon in this key looks like a typo; it
        # is left as-is because the key is part of the stored checkpoint.
        'safe_descent:':
        str(safe_descent),
        'dynamic_lr':
        str(dynamic_lr),
        'rotations':
        str(rotations),
        'scale_min':
        scale_min,
        'scale_max':
        scale_max,
        'noise':
        noise,
        'train_data_count':
        train_data_count,
        'test_data_count':
        len(train_data["test_samples"]),
        'validation_data_count':
        len(train_data["val_samples"]),
        'model_args':
        model_args
    })

    # calculate class weights ---------------------------------------------------------------------

    # The weights are swapped class frequencies: class 0 is weighted by the
    # frequency of class 1 and vice versa.
    av_c1_freq = sum(
        torch.sum(train_bin[:, :, 1]).item()
        for train_bin in train_data["train_bins"]) / sum(
            train_bin[:, :, 1].numel()
            for train_bin in train_data["train_bins"])
    class_weights = torch.tensor([av_c1_freq,
                                  1 - av_c1_freq]).float().to(device)

    io.cprint("\nC0 Weight: %.4f" % (class_weights[0].item()))
    io.cprint("C1 Weight: %.4f" % (class_weights[1].item()))

    # set up report interval (for logging) and batch size -------------------------------------------------------------------

    report_interval = 100

    # begin training ###########################################################################################################################

    io.cprint("\nBeginning Training..\n")

    for epoch in range(last_epoch + 1, last_epoch + epochs + 1):

        io.cprint(
            "Epoch: %d ------------------------------------------------------------------------------------------"
            % (epoch))
        io.cprint("Current LR: %.10f" % (opt.param_groups[0]['lr']))

        model.train()
        opt.zero_grad()

        checkpoint['train_batch_loss'].append([])
        checkpoint['train_batch_N'].append([])
        checkpoint['train_batch_acc'].append([])
        checkpoint['train_batch_C0_acc'].append([])
        checkpoint['train_batch_C1_acc'].append([])
        checkpoint['train_batch_lr_adjust'].append([])
        checkpoint['train_batch_loss_reduction'].append([])
        checkpoint['lr'].append(opt.param_groups[0]['lr'])

        # draw random batches from random bins
        binbatches = utils.drawBinBatches(
            [train_bin.size(0) for train_bin in train_bins],
            batchsize=batch_size)

        checkpoint['train_batch_N'][-1] = [
            train_bins[bin_id][batch_ids].size(1)
            for (bin_id, batch_ids) in binbatches
        ]

        failed_loss_optims = 0
        cum_lr_adjust_fac = 0
        cum_loss_reduction = 0

        # pre-compute random rotations if needed
        batch_rotations = [None] * len(binbatches)
        if rotations:
            start_rotations = time.time()
            batch_rotations = torch.zeros(
                (len(binbatches), batch_size, pts.size(1), pts.size(1)),
                device=device)
            for i in range(len(binbatches)):
                for j in range(batch_size):
                    batch_rotations[i,
                                    j] = utils.getRandomRotation(pts.size(1),
                                                                 device=device)
            print("created batch rotations (%ds)" %
                  (time.time() - start_rotations))

        b = 0  # batch counter

        train_start = time.time()

        for (bin_id, batch_ids) in binbatches:

            b += 1

            batch_pts_ids = train_bins[bin_id][batch_ids][:, :,
                                                          0]  # size: (B x N)
            batch_input = pts[batch_pts_ids]  # size: (B x N x d)
            batch_target = train_bins[bin_id][batch_ids][:, :, 1].to(
                device)  # size: (B x N)

            if batch_rotations[b - 1] is not None:
                batch_input = batch_input.matmul(batch_rotations[b - 1])

            if noise > 0:
                noise_v = torch.randn(
                    batch_input.size(),
                    device=batch_input.device)  # size: (B x N x d)
                noise_v.div_(
                    torch.square(noise_v).sum(
                        dim=2).sqrt()[:, :, None])  # norm to unit vectors
                # addcmul is out-of-place: keep its result, otherwise the
                # noise is computed and thrown away.
                batch_input = batch_input.addcmul(
                    noise_v,
                    max_noise[batch_pts_ids][:, :, None],
                    value=noise)

            if scale_min < 1 or scale_max > 1:
                # one uniform scale factor in [scale_min, scale_max) per sample
                batch_scales = torch.rand(batch_input.size(0),
                                          device=batch_input.device)
                batch_scales.mul_(scale_max - scale_min)
                batch_scales.add_(scale_min)
                # mul is out-of-place: keep its result, otherwise the
                # scaling has no effect.
                batch_input = batch_input.mul(batch_scales[:, None, None])

            batch_input = batch_input.transpose(1, 2)  # size: (B x d x N)

            # prediction & loss ----------------------------------------

            batch_prediction = model(batch_input).transpose(
                1, 2)  # size: (B x N x 2)
            batch_loss = cross_entropy(batch_prediction.reshape(-1, 2),
                                       batch_target.view(-1),
                                       class_weights,
                                       reduction='mean')
            batch_loss.backward()

            checkpoint['train_batch_loss'][-1].append(batch_loss.item())

            new_loss = 0.0
            lr_adjust = 1.0
            loss_reduction = 0.0

            # if safe descent is enabled, try to optimize the descent step so that a reduction in loss is guaranteed
            if safe_descent:

                # create backups to restore states before the optimizer step
                model_state_backup = copy.deepcopy(model.state_dict())
                opt_state_backup = copy.deepcopy(opt.state_dict())

                # make an optimizer step
                opt.step()

                # in each iteration, check if the optimizer gave an improvement
                # if not, restore the original states, reduce the learning rate and try again
                # no gradient needed for the plain loss calculation
                with torch.no_grad():
                    for i in range(10):

                        new_batch_prediction = model(batch_input).transpose(
                            1, 2)
                        new_loss = cross_entropy(new_batch_prediction.reshape(
                            -1, 2),
                                                 batch_target.view(-1),
                                                 class_weights,
                                                 reduction='mean').item()

                        # if the model performs better now we continue, if not we try a smaller learning step
                        if (new_loss < batch_loss.item()):
                            break
                        else:
                            model.load_state_dict(model_state_backup)
                            opt.load_state_dict(opt_state_backup)
                            lr_adjust *= 0.7
                            opt.step(lr_adjust=lr_adjust)

                loss_reduction = 100 * (batch_loss.item() -
                                        new_loss) / batch_loss.item()

                if new_loss >= batch_loss.item():
                    failed_loss_optims += 1
                else:
                    cum_lr_adjust_fac += lr_adjust
                    cum_loss_reduction += loss_reduction

            else:

                cum_lr_adjust_fac += lr_adjust
                opt.step()

            checkpoint['train_batch_lr_adjust'][-1].append(lr_adjust)
            checkpoint['train_batch_loss_reduction'][-1].append(loss_reduction)

            # reset gradients
            opt.zero_grad()

            # make class prediction and save stats -----------------------

            success_vector = torch.argmax(batch_prediction,
                                          dim=2) == batch_target

            c0_idx = batch_target == 0
            c1_idx = batch_target == 1

            checkpoint['train_batch_acc'][-1].append(
                torch.sum(success_vector).item() / success_vector.numel())
            checkpoint['train_batch_C0_acc'][-1].append(
                torch.sum(success_vector[c0_idx]).item() /
                torch.sum(c0_idx).item())  # TODO handle division by zero
            checkpoint['train_batch_C1_acc'][-1].append(
                torch.sum(success_vector[c1_idx]).item() /
                torch.sum(c1_idx).item())  # TODO handle division by zero

            # statistic calculation and output -------------------------

            if b % report_interval == 0:

                last_100_loss = sum(checkpoint['train_batch_loss'][-1]
                                    [b - report_interval:b]) / report_interval
                last_100_acc = sum(checkpoint['train_batch_acc'][-1]
                                   [b - report_interval:b]) / report_interval
                last_100_acc_c0 = sum(
                    checkpoint['train_batch_C0_acc'][-1]
                    [b - report_interval:b]) / report_interval
                last_100_acc_c1 = sum(
                    checkpoint['train_batch_C1_acc'][-1]
                    [b - report_interval:b]) / report_interval

                # A '+' marks a value that improved on the previous epoch's
                # average.
                io.cprint(
                    '  Batch %4d to %4d | loss: %.5f%1s| acc: %.4f%1s| C0 acc: %.4f%1s| C1 acc: %.4f%1s| E%3d | T:%5ds | Failed Optims: %3d (%05.2f%%) | Av. Adjust LR: %.6f | Av. Loss Reduction: %07.4f%%'
                    %
                    (b -
                     (report_interval - 1), b, last_100_loss, '+' if epoch > 1
                     and last_100_loss < checkpoint['train_loss'][-1] else '',
                     last_100_acc, '+' if epoch > 1
                     and last_100_acc > checkpoint['train_acc'][-1] else '',
                     last_100_acc_c0, '+' if epoch > 1
                     and last_100_acc_c0 > checkpoint['train_C0_acc'][-1] else
                     '', last_100_acc_c1, '+' if epoch > 1 and last_100_acc_c1
                     > checkpoint['train_C1_acc'][-1] else '', epoch,
                     time.time() - train_start, failed_loss_optims, 100 *
                     (failed_loss_optims / report_interval),
                     (cum_lr_adjust_fac /
                      (report_interval - failed_loss_optims)
                      if failed_loss_optims < report_interval else -1),
                     (cum_loss_reduction /
                      (report_interval - failed_loss_optims)
                      if failed_loss_optims < report_interval else -1)))

                failed_loss_optims = 0
                cum_lr_adjust_fac = 0
                cum_loss_reduction = 0

        checkpoint['train_loss'].append(
            sum(checkpoint['train_batch_loss'][-1]) / b)
        checkpoint['train_acc'].append(
            sum(checkpoint['train_batch_acc'][-1]) / b)
        checkpoint['train_C0_acc'].append(
            sum(checkpoint['train_batch_C0_acc'][-1]) / b)
        checkpoint['train_C1_acc'].append(
            sum(checkpoint['train_batch_C1_acc'][-1]) / b)
        checkpoint['train_time'].append(time.time() - train_start)

        io.cprint(
            '----\n  TRN | time: %5ds | loss: %.10f | acc: %.4f | C0 acc: %.4f | C1 acc: %.4f'
            % (checkpoint['train_time'][-1], checkpoint['train_loss'][-1],
               checkpoint['train_acc'][-1], checkpoint['train_C0_acc'][-1],
               checkpoint['train_C1_acc'][-1]))

        torch.cuda.empty_cache()

        ####################
        # Test & Validation
        ####################

        with torch.no_grad():

            if use_batch_norm:

                model.eval_bn()

                eval_bn_start = time.time()

                # run through all train samples again to accumulate layer-wise input distribution statistics (mean and variance) with fixed weights
                # these statistics are later used for the BatchNorm layers during inference
                for (bin_id, batch_ids) in binbatches:

                    batch_pts_ids = train_bins[bin_id][
                        batch_ids][:, :, 0]  # size: (B x N)
                    batch_input = pts[batch_pts_ids]  # size: (B x N x d)

                    batch_input = batch_input.transpose(1,
                                                        2)  # size: (B x d x N)
                    model(batch_input)

                io.cprint('Accumulated BN Layer statistics (%ds)' %
                          (time.time() - eval_bn_start))

            model.eval()

            if len(test_samples) > 0:

                test_start = time.time()

                test_loss, test_acc, test_acc_c0, test_acc_c1 = getTestLoss(
                    pts, test_samples, model, class_weights)

                checkpoint['test_loss'].append(test_loss)
                checkpoint['test_acc'].append(test_acc)
                checkpoint['test_C0_acc'].append(test_acc_c0)
                checkpoint['test_C1_acc'].append(test_acc_c1)

                checkpoint['test_time'].append(time.time() - test_start)

                io.cprint(
                    '  TST | time: %5ds | loss: %.10f | acc: %.4f | C0 acc: %.4f | C1 acc: %.4f'
                    %
                    (checkpoint['test_time'][-1], checkpoint['test_loss'][-1],
                     checkpoint['test_acc'][-1], checkpoint['test_C0_acc'][-1],
                     checkpoint['test_C1_acc'][-1]))

            else:
                io.cprint('  TST | n/a (no samples)')

            if len(val_samples) > 0:

                val_start = time.time()

                val_loss, val_acc, val_acc_c0, val_acc_c1 = getTestLoss(
                    val_pts, val_samples, model, class_weights)

                checkpoint['val_loss'].append(val_loss)
                checkpoint['val_acc'].append(val_acc)
                checkpoint['val_C0_acc'].append(val_acc_c0)
                checkpoint['val_C1_acc'].append(val_acc_c1)

                checkpoint['val_time'].append(time.time() - val_start)

                io.cprint(
                    '  VAL | time: %5ds | loss: %.10f | acc: %.4f | C0 acc: %.4f | C1 acc: %.4f'
                    % (checkpoint['val_time'][-1], checkpoint['val_loss'][-1],
                       checkpoint['val_acc'][-1], checkpoint['val_C0_acc'][-1],
                       checkpoint['val_C1_acc'][-1]))

            else:
                io.cprint('  VAL | n/a (no samples)')

        ####################
        # Scheduler Step
        ####################

        if not no_lr_schedule:
            scheduler.step()

        if epoch > 1 and dynamic_lr and sum(
                checkpoint['train_batch_lr_adjust'][-1]) > 0:
            # Halfway between 1 and this epoch's mean per-batch LR factor.
            dynamic_lr_factor = 0.5 * (
                1 + sum(checkpoint['train_batch_lr_adjust'][-1]) /
                len(checkpoint['train_batch_lr_adjust'][-1]))
            io.cprint("----\n  dynamic lr adjust: %.10f" %
                      (dynamic_lr_factor))
            for param_group in opt.param_groups:
                param_group['lr'] *= dynamic_lr_factor

        # Save model and optimizer state ..
        checkpoint['model_state_dict'].append(copy.deepcopy(
            model.state_dict()))
        checkpoint['optimizer_state_dict'].append(
            copy.deepcopy(opt.state_dict()))

        torch.save(checkpoint, exp_dir + '/detector_checkpoints.t7')

    io.cprint("\n-------------------------------------------------------" +
              ("\ntotal_time: %.2fh" % ((time.time() - start_time) / 3600)) +
              ("\ntrain_time: %.2fh" %
               (sum(checkpoint['train_time']) / 3600)) +
              ("\ntest_time: %.2fh" % (sum(checkpoint['test_time']) / 3600)) +
              ("\nval_time: %.2fh" % (sum(checkpoint['val_time']) / 3600)) +
              "\n-------------------------------------------------------" +
              "\nend_time: " + datetime.now().strftime("%Y-%m-%d_%H%M%S") +
              "\n-------------------------------------------------------")
Example #11
0
def train(args,
          log_dir,
          checkpoint_path,
          trainloader,
          testloader,
          tensorboard,
          c,
          model_name,
          ap,
          cuda=True,
          model_params=None):
    """Train the model described by ``c`` and return the best validation loss.

    Args:
        args: not used inside this function.
        log_dir: directory where periodic ``checkpoint_<step>.pt`` files are
            written; also forwarded to ``save_best_checkpoint``.
        checkpoint_path: path of a checkpoint to initialise the model from,
            or ``None`` to start from the freshly built model.
        trainloader: iterable yielding ``(feature, target)`` batches.
        testloader: loader forwarded to ``validation``.
        tensorboard: logger; ``log_training(loss, step)`` is called on it and
            it is forwarded to ``validation``.
        c: config object exposing ``train_config``, ``model`` and ``dataset``.
        model_name: not used inside this function.
        ap: forwarded unchanged to ``validation``.
        cuda: when true, the model and every batch are moved to the GPU.
        model_params: forwarded to ``return_model``.

    Returns:
        The best loss value as tracked by ``save_best_checkpoint``.

    Raises:
        Exception: if ``c.train_config['optimizer']`` is not one of
            ``'adam'``, ``'adamw'`` or ``'radam'``.
    """
    loss1_weight = c.train_config['loss1_weight']
    use_mixup = False if 'mixup' not in c.model else c.model['mixup']
    if use_mixup:
        mixup_alpha = 1 if 'mixup_alpha' not in c.model else c.model[
            'mixup_alpha']
        mixup_augmenter = Mixup(mixup_alpha=mixup_alpha)
        print("Enable Mixup with alpha:", mixup_alpha)

    model = return_model(c, model_params)

    optimizer_name = c.train_config['optimizer']
    if optimizer_name == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=c.train_config['learning_rate'],
            weight_decay=c.train_config['weight_decay'])
    elif optimizer_name == 'adamw':
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=c.train_config['learning_rate'],
            weight_decay=c.train_config['weight_decay'])
    elif optimizer_name == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=c.train_config['learning_rate'],
                          weight_decay=c.train_config['weight_decay'])
    else:
        # The name is read from train_config, like every other access here.
        raise Exception("The %s  not is a optimizer supported" %
                        optimizer_name)

    # The step counter always starts at 0; the 'step' stored in a checkpoint
    # is not restored.
    step = 0
    if checkpoint_path is not None:
        print("Continue training from checkpoint: %s" % checkpoint_path)
        # Load outside the try block: if the file itself cannot be read, the
        # fallback below has nothing to work with and the real error should
        # surface.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        try:
            model.load_state_dict(checkpoint['model'])
        except (RuntimeError, KeyError):
            # Key/shape mismatch (RuntimeError) or missing 'model' entry:
            # let set_init_dict build the state dict to load instead.
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint, c)
            model.load_state_dict(model_dict)
            del model_dict
    else:
        print("Starting new training run")

    if c.train_config['lr_decay']:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.train_config['warmup_steps'],
                           last_epoch=step - 1)
    else:
        scheduler = None
    # move the model to the GPU
    if cuda:
        model = model.cuda()

    # define loss function
    if use_mixup:
        criterion = Clip_BCE()
    else:
        criterion = nn.BCELoss()
    eval_criterion = nn.BCELoss(reduction='sum')

    best_loss = float('inf')

    # early stop definitions
    early_epochs = 0

    model.train()
    for epoch in range(c.train_config['epochs']):
        for feature, target in trainloader:

            if cuda:
                feature = feature.cuda()
                target = target.cuda()

            if use_mixup:
                # The batch is trimmed to an even length before mixing.
                batch_len = len(feature)
                if (batch_len % 2) != 0:
                    batch_len -= 1
                    feature = feature[:batch_len]
                    target = target[:batch_len]

                mixup_lambda = torch.FloatTensor(
                    mixup_augmenter.get_lambda(batch_len)).to(feature.device)
                output = model(feature[:batch_len], mixup_lambda)
                target = do_mixup(target, mixup_lambda)
            else:
                output = model(feature)
            # Calculate loss
            if c.dataset['class_balancer_batch'] and not use_mixup:
                # Average the per-class losses so both classes weigh equally
                # regardless of how many samples of each are in the batch.
                idxs = (target == c.dataset['control_class'])
                loss_control = criterion(output[idxs], target[idxs])
                idxs = (target == c.dataset['patient_class'])
                loss_patient = criterion(output[idxs], target[idxs])
                loss = (loss_control + loss_patient) / 2
            else:
                loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # update lr decay scheme
            if scheduler:
                scheduler.step()
            step += 1

            loss = loss.item()
            if loss > 1e8 or math.isnan(loss):
                print("Loss exploded to %.02f at step %d!" % (loss, step))
                # NOTE: this leaves only the batch loop; the end-of-epoch
                # validation below still runs and the next epoch starts.
                break

            # write loss to tensorboard
            if step % c.train_config['summary_interval'] == 0:
                tensorboard.log_training(loss, step)
                if c.dataset['class_balancer_batch'] and not use_mixup:
                    print("Write summary at step %d" % step, ' Loss: ', loss,
                          'Loss control:', loss_control.item(),
                          'Loss patient:', loss_patient.item())
                else:
                    print("Write summary at step %d" % step, ' Loss: ', loss)

            # save checkpoint file  and evaluate and save sample to tensorboard
            if step % c.train_config['checkpoint_interval'] == 0:
                save_path = os.path.join(log_dir, 'checkpoint_%d.pt' % step)
                torch.save(
                    {
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'step': step,
                        'config_str': str(c),
                    }, save_path)
                print("Saved checkpoint to: %s" % save_path)
                # run validation and save best checkpoint
                val_loss = validation(eval_criterion,
                                      ap,
                                      model,
                                      c,
                                      testloader,
                                      tensorboard,
                                      step,
                                      cuda=cuda,
                                      loss1_weight=loss1_weight)
                # The early-stop counter returned here is discarded; it is
                # only updated by the end-of-epoch call below.
                best_loss, _ = save_best_checkpoint(
                    log_dir, model, optimizer, c, step, val_loss, best_loss,
                    early_epochs
                    if c.train_config['early_stop_epochs'] != 0 else None)

        print('=================================================')
        print("Epoch %d End !" % epoch)
        print('=================================================')
        # run validation and save best checkpoint at end epoch
        val_loss = validation(eval_criterion,
                              ap,
                              model,
                              c,
                              testloader,
                              tensorboard,
                              step,
                              cuda=cuda,
                              loss1_weight=loss1_weight)
        best_loss, early_epochs = save_best_checkpoint(
            log_dir, model, optimizer, c, step, val_loss, best_loss,
            early_epochs if c.train_config['early_stop_epochs'] != 0 else None)
        if c.train_config['early_stop_epochs'] != 0:
            if early_epochs is not None:
                if early_epochs >= c.train_config['early_stop_epochs']:
                    break  # stop train
    return best_loss
class TableTrainer(Trainer):
    """for classification task"""
    def __init__(self, model, train_set, val_set, test_set, configs):
        """Set up model, data loaders, loss, optimizer, scheduler and logging.

        Args:
            model: callable invoked with ``in_channels`` and ``num_classes``
                keyword arguments (both taken from ``configs``) to build the
                network.
            train_set: dataset behind the shuffled training loader.
            val_set: dataset behind the validation loader.
            test_set: dataset behind the test loader (batch size 1).
            configs: mapping of settings. Keys read here: ``lr``,
                ``batch_size``, ``momentum``, ``weight_decay``,
                ``distributed``, ``num_workers``, ``device``,
                ``max_epoch_num``, ``max_plateau_count``, ``in_channels``,
                ``num_classes``, ``plateau_patience``, ``cwd``, ``log_dir``,
                ``arch`` and ``model_name``.
        """
        super().__init__()
        print("Start trainer..")
        print(configs)

        # Mirror the plain config values onto underscore-prefixed attributes.
        self._configs = configs
        for key in ('lr', 'batch_size', 'momentum', 'weight_decay',
                    'distributed', 'num_workers'):
            setattr(self, '_' + key, self._configs[key])
        self._device = torch.device(self._configs['device'])
        for key in ('max_epoch_num', 'max_plateau_count'):
            setattr(self, '_' + key, self._configs[key])

        # Datasets and network.
        self._train_set = train_set
        self._val_set = val_set
        self._test_set = test_set
        network = model(
            in_channels=configs['in_channels'],
            num_classes=configs['num_classes'],
        )
        self._model = network.to(self._device)

        # Options shared by the three loaders. In distributed mode the model
        # is wrapped in DDP and every worker seeds numpy with its worker id.
        loader_options = {
            'num_workers': self._num_workers,
            'pin_memory': True,
        }
        if self._distributed == 1:
            torch.distributed.init_process_group(backend='nccl')
            self._model = nn.parallel.DistributedDataParallel(self._model)
            loader_options['worker_init_fn'] = lambda x: np.random.seed(x)

        self._train_loader = DataLoader(self._train_set,
                                        batch_size=self._batch_size,
                                        shuffle=True,
                                        **loader_options)
        self._val_loader = DataLoader(self._val_set,
                                      batch_size=self._batch_size,
                                      shuffle=False,
                                      **loader_options)
        self._test_loader = DataLoader(self._test_set,
                                       batch_size=1,
                                       shuffle=False,
                                       **loader_options)

        # Loss (class-weighted cross entropy), optimizer and LR scheduler.
        weight_tensor = torch.FloatTensor(np.array([0.2, 0.8]))
        self._criterion = nn.CrossEntropyLoss(weight_tensor).to(self._device)

        self._optimizer = RAdam(
            params=self._model.parameters(),
            lr=self._lr,
            weight_decay=self._weight_decay,
        )

        self._scheduler = ReduceLROnPlateau(
            self._optimizer,
            patience=self._configs['plateau_patience'],
            min_lr=1e-8,
            verbose=True)
        # TODO set step size equal to configs, e.g.
        #   self._scheduler = StepLR(self._optimizer,
        #                            step_size=self._configs['steplr'])

        # Training bookkeeping.
        self._start_time = datetime.datetime.now().replace(microsecond=0)

        # The same "<arch>_<model_name>_<timestamp>" tag names both the
        # tensorboard run directory and the checkpoint file.
        run_tag = "{}_{}_{}".format(
            self._configs['arch'], self._configs['model_name'],
            self._start_time.strftime('%Y%b%d_%H.%M'))

        self._writer = SummaryWriter(
            os.path.join(self._configs['cwd'], self._configs['log_dir'],
                         run_tag))
        self._train_loss_list = []
        self._train_acc_list = []
        self._val_loss_list = []
        self._val_acc_list = []
        self._best_val_loss = 1e9
        self._best_val_acc = 0
        self._best_train_loss = 1e9
        self._best_train_acc = 0
        self._test_acc = 0.
        self._plateau_count = 0
        self._current_epoch_num = 0

        # Checkpoint location.
        self._checkpoint_dir = os.path.join(self._configs['cwd'],
                                            'saved/checkpoints')
        if not os.path.exists(self._checkpoint_dir):
            os.makedirs(self._checkpoint_dir, exist_ok=True)

        self._checkpoint_path = os.path.join(self._checkpoint_dir, run_tag)

    def _train(self):
        """Run one training epoch and record its mean loss and accuracy.

        Appends the per-batch average of the loss and of the first value
        returned by ``eval_metrics`` to ``_train_loss_list`` and
        ``_train_acc_list``.

        Raises:
            ValueError: if the training loader yields no batches.
        """
        self._model.train()
        train_loss = 0.
        train_acc = 0.
        num_batches = 0

        for images, targets in tqdm(self._train_loader,
                                    total=len(self._train_loader),
                                    leave=False):
            # Use the configured device (the one the model was moved to)
            # rather than whatever the current default CUDA device is.
            images = images.to(self._device, non_blocking=True)
            targets = targets.to(self._device, non_blocking=True)

            # compute output, measure accuracy and record loss
            outputs = self._model(images)

            loss = self._criterion(outputs, targets)
            # acc = accuracy(outputs, targets)[0]
            acc = eval_metrics(targets, outputs, 2)[0]

            train_loss += loss.item()
            train_acc += acc.item()

            # compute gradient and take an optimizer step
            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()

            if self._configs['little'] == 1:
                # Debug dump: write channel 1 of the output as an image, one
                # file per epoch (overwritten by each batch).
                # NOTE(review): the squeeze + 3-axis transpose presumably
                # expects a batch of size 1 here; confirm against the caller.
                # outputs = torch.softmax(outputs, dim=1)
                mask = torch.squeeze(outputs, 0)
                mask = mask.detach().cpu().numpy() * 255
                mask = np.transpose(mask, (1, 2, 0)).astype(np.uint8)
                cv2.imwrite(
                    os.path.join('debug',
                                 'e{}.png'.format(self._current_epoch_num)),
                    mask[..., 1])

            num_batches += 1

        if num_batches == 0:
            # Previously this surfaced as an UnboundLocalError on the loop
            # index; fail with an explicit message instead.
            raise ValueError('training loader produced no batches')

        self._train_loss_list.append(train_loss / num_batches)
        self._train_acc_list.append(train_acc / num_batches)

    def _val(self):
        """Evaluate on the validation loader and record mean loss/accuracy.

        Appends the per-batch average of the loss and of the first value
        returned by ``eval_metrics`` to ``_val_loss_list`` and
        ``_val_acc_list``.

        Raises:
            ValueError: if the validation loader yields no batches.
        """
        self._model.eval()
        val_loss = 0.
        val_acc = 0.

        # Number of full passes over the validation loader; results are
        # averaged over every batch of every pass.
        loop_time = 1
        num_batches = 0
        with torch.no_grad():
            for _ in range(loop_time):
                for images, targets in tqdm(self._val_loader,
                                            total=len(self._val_loader),
                                            leave=False):
                    # Use the configured device, matching where the model
                    # was placed in __init__.
                    images = images.to(self._device, non_blocking=True)
                    targets = targets.to(self._device, non_blocking=True)

                    # compute output, measure accuracy and record loss
                    outputs = self._model(images)

                    loss = self._criterion(outputs, targets)
                    # acc = accuracy(outputs, targets)[0]
                    acc = eval_metrics(targets, outputs, 2)[0]

                    val_loss += loss.item()
                    val_acc += acc.item()
                    num_batches += 1

        if num_batches == 0:
            # Previously this surfaced as an UnboundLocalError on the loop
            # index; fail with an explicit message instead.
            raise ValueError('validation loader produced no batches')

        self._val_loss_list.append(val_loss / num_batches)
        self._val_acc_list.append(val_acc / num_batches)

    def _calc_acc_on_private_test(self):
        """Compute the mean accuracy over the test loader.

        Writes one ``"<batch index>_<accuracy>"`` line per batch to
        ``private_test_log.txt`` in the current working directory.

        Returns:
            The average over batches of the first value returned by
            ``accuracy``.

        Raises:
            ValueError: if the test loader yields no batches.
        """
        self._model.eval()
        test_acc = 0.
        num_batches = 0
        print('Calc acc on private test..')
        # The context manager closes the log even if a batch raises.
        with open('private_test_log.txt', 'w') as f, torch.no_grad():
            for i, (images, targets) in tqdm(enumerate(self._test_loader),
                                             total=len(self._test_loader),
                                             leave=False):

                # Use the configured device, matching where the model was
                # placed in __init__.
                images = images.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)

                outputs = self._model(images)
                print(outputs.shape, outputs)
                acc = accuracy(outputs, targets)[0]
                test_acc += acc.item()
                f.write("{}_{}\n".format(i, acc.item()))
                num_batches += 1

        if num_batches == 0:
            # Previously this surfaced as an UnboundLocalError on the loop
            # index; fail with an explicit message instead.
            raise ValueError('test loader produced no batches')

        test_acc = test_acc / num_batches
        print("Accuracy on private test: {:.3f}".format(test_acc))
        return test_acc

    def _calc_acc_on_private_test_with_tta(self):
        """Compute the mean test accuracy using test-time augmentation.

        Each test-set item is turned into a batch by ``make_batch``; the
        softmax outputs of that batch are summed over dim 0 before being
        scored with ``accuracy``. One ``"<index>_<accuracy>"`` line per
        sample is written to
        ``private_test_log_<arch>_<model_name>.txt``.

        Returns:
            The average over samples of the first value returned by
            ``accuracy``.

        Raises:
            ValueError: if the test set is empty.
        """
        self._model.eval()
        test_acc = 0.
        num_samples = len(self._test_set)
        if num_samples == 0:
            # Previously this surfaced as an UnboundLocalError on the loop
            # index; fail with an explicit message instead.
            raise ValueError('test set is empty')

        print('Calc acc on private test with tta..')
        log_name = 'private_test_log_{}_{}.txt'.format(
            self._configs['arch'], self._configs['model_name'])

        # The context manager closes the log even if a sample raises.
        with open(log_name, 'w') as f, torch.no_grad():
            for idx in tqdm(range(num_samples),
                            total=num_samples,
                            leave=False):
                images, targets = self._test_set[idx]
                targets = torch.LongTensor([targets])

                images = make_batch(images)
                # Use the configured device, matching where the model was
                # placed in __init__.
                images = images.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)

                outputs = self._model(images)
                outputs = F.softmax(outputs, 1)

                # Sum the per-augmentation probabilities (dim 0), then
                # restore a leading batch dimension of 1 for accuracy().
                outputs = torch.sum(outputs, 0)
                outputs = torch.unsqueeze(outputs, 0)
                # TODO: try with softmax first and see the change
                acc = accuracy(outputs, targets)[0]
                test_acc += acc.item()
                f.write("{}_{}\n".format(idx, acc.item()))

        test_acc = test_acc / num_samples
        print("Accuracy on private test with tta: {:.3f}".format(test_acc))
        return test_acc

    def train(self):
        """Run epochs until the stop condition holds, then save and summarize.

        Each epoch runs training, validation, the state update and logging.
        A KeyboardInterrupt ends the loop early (its traceback is printed).
        Afterwards the current weights are saved and a text summary is sent
        to the summary writer, which is then closed.
        """
        print(self._model)

        try:
            while True:
                if self._is_stop():
                    break
                self._increase_epoch_num()
                self._train()
                self._val()

                self._update_training_state()
                self._logging()
        except KeyboardInterrupt:
            traceback.print_exc()

        # Training finished. Evaluation on the private test set is currently
        # disabled, so the reported test accuracy keeps its stored value.
        try:
            self._save_weights()
        except Exception:
            traceback.print_exc()

        elapsed = str(datetime.datetime.now() - self._start_time)
        summary_entries = [
            ('Summary', 'Converged after {} epochs, consume {}'.format(
                self._current_epoch_num, elapsed[:-7])),
            ('Results',
             'Best validation accuracy: {:.3f}'.format(self._best_val_acc)),
            ('Results',
             'Best training accuracy: {:.3f}'.format(self._best_train_acc)),
            ('Results',
             'Private test accuracy: {:.3f}'.format(self._test_acc)),
        ]
        for tag, text in summary_entries:
            self._writer.add_text(tag, text)
        self._writer.close()

    def _update_training_state(self):
        if self._val_loss_list[-1] < self._best_val_loss:
            self._save_weights()
            self._best_val_acc = self._val_acc_list[-1]
            self._best_val_loss = self._val_loss_list[-1]
            self._best_train_acc = self._train_acc_list[-1]
            self._best_train_loss = self._train_loss_list[-1]

        if self._train_loss_list[-1] == min(self._train_loss_list):
            self._plateau_count = 0
        else:
            self._plateau_count += 1

        self._scheduler.step(self._train_loss_list[-1])
        # self._scheduler.step()

    def _logging(self):
        """Send the latest epoch metrics to the summary writer and print them.

        The printed line shows: epoch number, train/val/best-val loss,
        train/val/best-val accuracy, the plateau counter and elapsed time.
        """
        elapsed = str(datetime.datetime.now() - self._start_time)
        epoch = self._current_epoch_num
        train_loss = self._train_loss_list[-1]
        val_loss = self._val_loss_list[-1]
        train_acc = self._train_acc_list[-1]
        val_acc = self._val_acc_list[-1]

        message = "\nE{:03d}  {:.3f}/{:.3f}/{:.3f} {:.3f}/{:.3f}/{:.3f} | p{:02d}  Time {}\n".format(
            epoch, train_loss, val_loss, self._best_val_loss, train_acc,
            val_acc, self._best_val_acc, self._plateau_count, elapsed[:-7])

        scalars = (
            ('Accuracy/Train', train_acc),
            ('Accuracy/Val', val_acc),
            ('Loss/Train', train_loss),
            ('Loss/Val', val_loss),
        )
        for tag, value in scalars:
            self._writer.add_scalar(tag, value, epoch)

        print(message)

    def _is_stop(self):
        """check stop condition"""
        return (self._plateau_count > self._max_plateau_count
                or self._current_epoch_num > self._max_epoch_num)

    def _increase_epoch_num(self):
        self._current_epoch_num += 1

    def _save_weights(self, test_acc=0.):
        if self._distributed == 0:
            state_dict = self._model.state_dict()
        else:
            state_dict = self._model.module.state_dict()

        state = {
            **self._configs,
            'net': state_dict,
            'best_val_loss': self._best_val_loss,
            'best_val_acc': self._best_val_acc,
            'best_train_loss': self._best_train_loss,
            'best_train_acc': self._best_train_acc,
            'train_losses': self._train_loss_list,
            'val_loss_list': self._val_loss_list,
            'train_acc_list': self._train_acc_list,
            'val_acc_list': self._val_acc_list,
            'test_acc': self._test_acc,
        }

        torch.save(state, self._checkpoint_path)