Example #1
import datetime
import os
import time

import torch
import torch.nn as nn
import torchvision
from torchsummary import summary

# Project-local helpers used below (mkdir, Logger, to_str_args, CUB200Dataset,
# AverageMeter, ImProgressBar) are assumed to come from the surrounding codebase.


class Trainer():
    def __init__(self, args):
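        # Create a timestamped run directory under exp_dir and derive the log and best-checkpoint paths.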
        now_time = datetime.datetime.strftime(datetime.datetime.now(),
                                              '%m%d-%H%M%S')
        args.cur_dir = os.path.join(args.exp_dir, now_time)
        args.log_path = os.path.join(args.cur_dir, 'train.log')
        args.best_model_path = os.path.join(args.cur_dir, 'best_model.pth')

        self.args = args
        mkdir(self.args.exp_dir)
        mkdir(self.args.cur_dir)
        self.log = Logger(self.args.log_path, level='debug').logger
        self.log.critical("args: \n{}".format(to_str_args(self.args)))

        # CUB-200 data loaders; only the training split is shuffled.
        self.train_loader = torch.utils.data.DataLoader(
            dataset=CUB200Dataset(root=self.args.root, train=True),
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=self.args.pin_memory,
            shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(
            dataset=CUB200Dataset(root=self.args.root, train=False),
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=self.args.pin_memory,
            shuffle=False)

        # ImageNet-pretrained ResNet-18 with its final FC layer replaced to output num_classes logits.
        self.model = torchvision.models.resnet18(pretrained=True)
        self.model.fc = nn.Linear(in_features=self.model.fc.in_features,
                                  out_features=self.args.num_classes)
        self.model.cuda()

        self.criterion = nn.CrossEntropyLoss()
        # Pick the optimizer from args.optim_type: Adam (lr only) or SGD with momentum and weight decay.
        self.optimizer = torch.optim.Adam(
            params=self.model.parameters(), lr=self.args.lr
        ) if self.args.optim_type == 'Adam' else torch.optim.SGD(
            params=self.model.parameters(),
            lr=self.args.lr,
            momentum=self.args.momentum,
            weight_decay=self.args.decay)

        self.log.critical("model: \n{}".format(self.model))
        self.log.critical("torchsummary: \n{}".format(
            summary(model=self.model, input_size=(3, 224, 224))))
        self.log.critical("criterion: \n{}".format(self.criterion))
        self.log.critical("optimizer: \n{}".format(self.optimizer))

    def train(self):
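        """Run one training epoch and return (average batch loss, training accuracy)."""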
        self.model.train()
        losses = AverageMeter()
        correct = 0
        pbar = ImProgressBar(len(self.train_loader))
        for i, (imgs, targets) in enumerate(self.train_loader):
            imgs, targets = imgs.cuda(), targets.cuda()
            outputs = self.model(imgs)

            _, predicted = torch.max(outputs.data, dim=1)
            correct += (predicted == targets).sum().item()

            loss = self.criterion(outputs, targets)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            losses.update(loss.item(), 1)
            pbar.update(i)
        pbar.finish()
        return losses.avg, correct / len(self.train_loader.dataset)

    def eval(self, loader):
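        """Evaluate the model on `loader` without gradient tracking; return (average batch loss, accuracy)."""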
        self.model.eval()
        losses = AverageMeter()
        correct = 0
        with torch.no_grad():
            pbar = ImProgressBar(len(loader))
            for i, (imgs, targets) in enumerate(loader):
                imgs, targets = imgs.cuda(), targets.cuda()
                outputs = self.model(imgs)

                _, predicted = torch.max(outputs.data, dim=1)
                correct += (predicted == targets).sum().item()

                loss = self.criterion(outputs, targets)
                losses.update(loss.item(), 1)

                pbar.update(i)
            pbar.finish()
        return losses.avg, correct / len(loader.dataset)

    def fit(self):
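        """Train for args.epochs epochs, evaluating each epoch and checkpointing the best test accuracy."""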
        best_epoch, best_test_acc = 0, 0
        for epoch in range(0, self.args.epochs):
            end = time.time()
            train_loss, train_acc = self.train()
            test_loss, test_acc = self.eval(self.test_loader)

            # Save a full checkpoint whenever test accuracy improves.
            if test_acc > best_test_acc:
                best_epoch = epoch + 1  # 1-based, to match the epoch numbering in the log below
                best_test_acc = test_acc
                checkpoint = {
                    'epoch': epoch + 1,
                    'args': vars(self.args),
                    'state_dict': self.model.state_dict(),
                    'best_test_acc': best_test_acc,
                    'optimizer': self.optimizer.state_dict()
                }
                torch.save(checkpoint, self.args.best_model_path)

            self.log.info(
                '[Epoch: {:3}/{:3}][Time: {:.3f}] Train loss: {:.3f}, Test loss: {:.3f}, Train acc: {:.3f}%, Test acc: {:.3f}% (best_test_acc: {:.3f}%, epoch: {})'
                .format(epoch + 1, self.args.epochs,
                        time.time() - end, train_loss, test_loss,
                        train_acc * 100, test_acc * 100, best_test_acc * 100,
                        best_epoch))