class Trainer(object):
    """Fine-tuning driver for a Fast-SCNN segmentation network.

    Construction builds the train/val data loaders, the model (optionally
    restored from ``args.resume`` while keeping the current model's
    classifier head), the optimizer, the poly learning-rate schedule and
    the evaluation metric. :meth:`train` runs the epoch loop and
    :meth:`validation` scores the model and saves a checkpoint.

    ``args`` is read for: ``dataset``, ``crop_size``, ``train_split``,
    ``batch_size``, ``aux``, ``aux_weight``, ``device``, ``resume``, ``lr``,
    ``momentum``, ``weight_decay``, ``epochs``, ``start_epoch`` and
    ``no_val``.
    """

    # The original code hard-coded ``device_ids=[0, 1, 2]``; that breaks on
    # hosts with exactly two GPUs. The cap keeps the three-GPU ceiling.
    _MAX_DATA_PARALLEL_GPUS = 3

    @staticmethod
    def _check_checkpoint_ext(path):
        """Raise ``ValueError`` unless *path* ends in ``.pkl`` or ``.pth``.

        The previous ``assert ext == '.pkl' or '.pth'`` was always true
        because the non-empty string ``'.pth'`` is truthy on its own.
        """
        ext = os.path.splitext(path)[1]
        if ext not in ('.pkl', '.pth'):
            raise ValueError('Sorry only .pth and .pkl files supported.')

    def __init__(self, args):
        """Build loaders, model, loss, optimizer, scheduler and metric.

        Args:
            args: Parsed run configuration (see the class docstring for the
                attributes that are read).

        Raises:
            ValueError: If ``args.resume`` names an existing file whose
                extension is neither ``.pkl`` nor ``.pth``.
        """
        self.args = args

        # Setup Dataloader
        data_kwargs = {'is_transform': True, 'img_size': args.crop_size}
        dataset_root = './datasets/' + args.dataset
        train_dataset = get_segmentation_dataset(args.dataset,
                                                 root=dataset_root,
                                                 split=args.train_split,
                                                 **data_kwargs)
        val_dataset = get_segmentation_dataset(args.dataset,
                                               root=dataset_root,
                                               split='val',
                                               **data_kwargs)
        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            drop_last=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_size=1,
                                          shuffle=False)

        # create network
        self.model = get_fast_scnn(dataset=args.dataset, aux=args.aux)
        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            device_ids = list(
                range(min(gpu_count, self._MAX_DATA_PARALLEL_GPUS)))
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=device_ids)
        self.model.to(args.device)

        # resume checkpoint if needed (a non-existent path is ignored)
        if args.resume and os.path.isfile(args.resume):
            self._check_checkpoint_ext(args.resume)
            print('Resuming training, loading {}...'.format(args.resume))

            # SECURITY: torch.load unpickles the file; only resume from
            # checkpoints you trust.
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage)
            # op surgery to change classes: overwrite the checkpoint's
            # classifier-head entries with the current model's tensors so
            # the rest of the network is restored while the head keeps the
            # shape of the freshly built model.
            # NOTE(review): when the model is wrapped in DataParallel its
            # state-dict keys are presumably prefixed with "module." and
            # these lookups would raise KeyError -- confirm.
            heads = ["classifier.conv.1.weight", "classifier.conv.1.bias"]
            curr_state_dict = self.model.state_dict()
            for param in heads:
                checkpoint[param] = curr_state_dict[param]
            # end op surgery
            self.model.load_state_dict(checkpoint)

        # create criterion
        # NOTE: train() currently calls multi_scale_cross_entropy2d instead
        # of this criterion; the attribute is kept for callers that use it.
        self.criterion = MixSoftmaxCrossEntropyOHEMLoss(
            aux=args.aux, aux_weight=args.aux_weight,
            ignore_index=-1).to(args.device)

        # optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # lr scheduling
        self.lr_scheduler = LRScheduler(mode='poly',
                                        base_lr=args.lr,
                                        nepochs=args.epochs,
                                        iters_per_epoch=len(self.train_loader),
                                        power=0.9)

        # evaluation metrics
        # NOTE(review): an unused read of ``train_dataset.n_classes`` was
        # removed from this constructor; the metric relies on ``num_class``
        # -- confirm the dataset exposes that attribute.
        self.metric = SegmentationMetric(train_dataset.num_class)

        # Best (pixAcc + mIoU) / 2 seen so far by validation().
        self.best_pred = 0.0

    def train(self):
        """Run the training loop from ``start_epoch`` up to ``epochs``.

        After every epoch either a checkpoint is saved (``args.no_val``) or
        :meth:`validation` is run, which saves one itself. A final
        checkpoint is saved once all epochs are done.
        """
        # Bind locally: the original referenced a module-level ``args``
        # global in the progress print, which fails when imported.
        args = self.args
        iters_per_epoch = len(self.train_loader)
        cur_iters = 0
        start_time = time.time()
        for epoch in range(args.start_epoch, args.epochs):
            # validation() switches to eval mode, so re-enable every epoch.
            self.model.train()

            for i, (images, targets) in enumerate(self.train_loader):
                cur_lr = self.lr_scheduler(cur_iters)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = cur_lr

                # Skip single-sample batches (presumably because batch-norm
                # layers cannot train on one sample -- confirm).
                if images.shape[0] == 1:
                    continue

                images = images.to(args.device)
                targets = targets.to(args.device)

                outputs = self.model(images)
                loss = multi_scale_cross_entropy2d(outputs, targets)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                cur_iters += 1
                if cur_iters % 10 == 0:
                    print(
                        'Epoch: [%2d/%2d] Iter [%4d/%4d] || Time: %4.4f sec || lr: %.8f || Loss: %.4f'
                        % (epoch, args.epochs, i + 1, iters_per_epoch,
                           time.time() - start_time, cur_lr, loss.item()))

            if args.no_val:
                # save every epoch
                save_checkpoint(self.model, args, is_best=False)
            else:
                self.validation(epoch)

        save_checkpoint(self.model, args, is_best=False)

    def validation(self, epoch):
        """Score the model on the validation loader and save a checkpoint.

        The running pixel accuracy / mIoU is printed after every sample.
        The checkpoint is flagged as best when the mean of the final
        pixAcc and mIoU exceeds ``self.best_pred``.

        Args:
            epoch: Epoch index, used only in the progress print.
        """
        is_best = False
        self.metric.reset()
        self.model.eval()
        # Defaults keep the summary below defined for an empty loader.
        pixAcc, mIoU = 0.0, 0.0
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for i, (image, target) in enumerate(self.val_loader):
                image = image.to(self.args.device)

                outputs = self.model(image)
                pred = torch.argmax(outputs[0], 1)
                pred = pred.cpu().data.numpy()
                self.metric.update(pred, target.numpy())
                pixAcc, mIoU = self.metric.get()
                print(
                    'Epoch %d, Sample %d, validation pixAcc: %.3f%%, mIoU: %.3f%%'
                    % (epoch, i + 1, pixAcc * 100, mIoU * 100))

        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        save_checkpoint(self.model, self.args, is_best)
# Example #2
# 0
class Trainer(object):
    """Training driver for a Fast-SCNN segmentation network.

    Construction builds the train/val data loaders, the model (optionally
    restored from ``args.resume``), the loss, the SGD optimizer, the poly
    learning-rate schedule and the evaluation metric. :meth:`train` runs
    the epoch loop, :meth:`validation` prints running metrics and
    :meth:`save_checkpoint` writes the model weights to disk.

    ``args`` is read for: ``dataset``, ``base_size``, ``crop_size``,
    ``train_split``, ``batch_size``, ``aux``, ``aux_weight``, ``device``,
    ``resume``, ``lr``, ``momentum``, ``weight_decay``, ``epochs``,
    ``start_epoch``, ``no_val``, ``save_folder`` and ``model``.
    """

    # The original code hard-coded ``device_ids=[0, 1, 2]``; that breaks on
    # hosts with exactly two GPUs. The cap keeps the three-GPU ceiling.
    _MAX_DATA_PARALLEL_GPUS = 3

    @staticmethod
    def _check_checkpoint_ext(path):
        """Raise ``ValueError`` unless *path* ends in ``.pkl`` or ``.pth``.

        The previous ``assert ext == '.pkl' or '.pth'`` was always true
        because the non-empty string ``'.pth'`` is truthy on its own.
        """
        ext = os.path.splitext(path)[1]
        if ext not in ('.pkl', '.pth'):
            raise ValueError('Sorry only .pth and .pkl files supported.')

    def __init__(self, args):
        """Build loaders, model, loss, optimizer, scheduler and metric.

        Args:
            args: Parsed run configuration (see the class docstring for the
                attributes that are read).

        Raises:
            ValueError: If ``args.resume`` names an existing file whose
                extension is neither ``.pkl`` nor ``.pth``.
        """
        self.args = args
        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        train_dataset = get_segmentation_dataset(args.dataset,
                                                 split=args.train_split,
                                                 mode='train',
                                                 **data_kwargs)
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split='val',
                                               mode='val',
                                               **data_kwargs)
        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            drop_last=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_size=1,
                                          shuffle=False)

        # create network
        self.model = get_fast_scnn(dataset=args.dataset, aux=args.aux)
        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            device_ids = list(
                range(min(gpu_count, self._MAX_DATA_PARALLEL_GPUS)))
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=device_ids)
        self.model.to(args.device)

        # resume checkpoint if needed (a non-existent path is ignored)
        if args.resume and os.path.isfile(args.resume):
            self._check_checkpoint_ext(args.resume)
            print('Resuming training, loading {}...'.format(args.resume))
            # SECURITY: torch.load unpickles the file; only resume from
            # checkpoints you trust.
            self.model.load_state_dict(
                torch.load(args.resume,
                           map_location=lambda storage, loc: storage))

        # create criterion
        self.criterion = MixSoftmaxCrossEntropyLoss(args.aux,
                                                    args.aux_weight,
                                                    ignore_label=-1).to(
                                                        args.device)

        # optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # lr scheduling
        self.lr_scheduler = LRScheduler(mode='poly',
                                        base_lr=args.lr,
                                        nepochs=args.epochs,
                                        iters_per_epoch=len(self.train_loader),
                                        power=0.9)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class)

    def train(self):
        """Run the training loop from ``start_epoch`` up to ``epochs``.

        Validation runs after each epoch unless ``args.no_val`` is set. A
        checkpoint is written every 10th epoch (except epoch 0) and once
        more when training finishes.
        """
        # Bind locally: the original referenced a module-level ``args``
        # global here, which fails when the class is imported.
        args = self.args
        iters_per_epoch = len(self.train_loader)
        cur_iters = 0
        start_time = time.time()
        for epoch in range(args.start_epoch, args.epochs):
            # validation() leaves the model in eval mode; switching back
            # once per epoch keeps later epochs training in train mode.
            self.model.train()
            for i, (images, targets) in enumerate(self.train_loader):
                cur_lr = self.lr_scheduler(cur_iters)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = cur_lr

                images = images.to(args.device)
                targets = targets.to(args.device)

                outputs = self.model(images)
                loss = self.criterion(outputs, targets)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                cur_iters += 1
                if cur_iters % 10 == 0:
                    print(
                        'Epoch: [%2d/%2d] Iter [%4d/%4d] || Time: %4.4f sec || lr: %.8f || Loss: %.4f'
                        % (epoch, args.epochs, i + 1, iters_per_epoch,
                           time.time() - start_time, cur_lr, loss.item()))

            if not args.no_val:
                self.validation(epoch)

            # save every 10 epoch
            if epoch != 0 and epoch % 10 == 0:
                print('Saving state, epoch:', epoch)
                self.save_checkpoint()
        self.save_checkpoint()

    def validation(self, epoch):
        """Print running pixAcc / mIoU over the validation loader.

        Puts the model in eval mode; the caller is responsible for
        switching back to train mode.

        Args:
            epoch: Epoch index, used only in the progress print.
        """
        self.metric.reset()
        self.model.eval()
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for i, (image, target) in enumerate(self.val_loader):
                image = image.to(self.args.device)

                outputs = self.model(image)
                pred = torch.argmax(outputs[0], 1)
                pred = pred.cpu().data.numpy()
                self.metric.update(pred, target.numpy())
                pixAcc, mIoU = self.metric.get()
                print(
                    'Epoch %d, Sample %d, validation pixAcc: %.3f%%, mIoU: %.3f%%'
                    % (epoch, i + 1, pixAcc * 100, mIoU * 100))

    def save_checkpoint(self):
        """Save the model's state dict to ``<save_folder>/<model>_<dataset>.pth``.

        The folder (``~`` expanded) is created if missing; an existing file
        of the same name is overwritten.
        """
        args = self.args
        directory = os.path.expanduser(args.save_folder)
        os.makedirs(directory, exist_ok=True)
        filename = '{}_{}.pth'.format(args.model, args.dataset)
        save_path = os.path.join(directory, filename)
        torch.save(self.model.state_dict(), save_path)
class Trainer(object):
    """Training/validation driver for a Fast-SCNN segmentation network.

    Construction builds the train/val data loaders, the model (optionally
    restored from ``args.resume``), the OHEM loss, the SGD optimizer, the
    poly learning-rate schedule and the evaluation metric. :meth:`train`
    runs the epoch loop and :meth:`validation` scores the model and saves
    a checkpoint.

    ``args`` is read for: ``dataset``, ``base_size``, ``crop_size``,
    ``train_split``, ``batch_size``, ``aux``, ``aux_weight``, ``device``,
    ``resume``, ``lr``, ``momentum``, ``weight_decay``, ``epochs``,
    ``start_epoch`` and ``no_val``.
    """

    # The original code hard-coded ``device_ids=[0, 1, 2]``; that breaks on
    # hosts with exactly two GPUs. The cap keeps the three-GPU ceiling.
    _MAX_DATA_PARALLEL_GPUS = 3

    @staticmethod
    def _check_checkpoint_ext(path):
        """Raise ``ValueError`` unless *path* ends in ``.pkl`` or ``.pth``.

        The previous ``assert ext == '.pkl' or '.pth'`` was always true
        because the non-empty string ``'.pth'`` is truthy on its own.
        """
        ext = os.path.splitext(path)[1]
        if ext not in ('.pkl', '.pth'):
            raise ValueError('Sorry only .pth and .pkl files supported.')

    def __init__(self, args, cfg=None):
        """Build loaders, model, loss, optimizer, scheduler and metric.

        Args:
            args: Parsed run configuration (see the class docstring for the
                attributes that are read).
            cfg: Currently unused. An earlier cfg-driven dataset/dataloader
                construction (``cfg.data.train`` / ``cfg.data.test``) was
                disabled; the parameter is kept for caller compatibility.

        Raises:
            ValueError: If ``args.resume`` names an existing file whose
                extension is neither ``.pkl`` nor ``.pth``.
        """
        self.args = args
        # image transform
        # NOTE(review): these statistics are on a 0-255 scale, whereas
        # transforms.ToTensor() normally yields values in [0, 1] for uint8
        # image input -- confirm what the dataset feeds in before changing
        # (existing checkpoints were presumably trained with these values).
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[123.675, 116.28, 103.53],
                                 std=[58.395, 57.12, 57.375]),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        train_dataset = get_segmentation_dataset(args.dataset,
                                                 split=args.train_split,
                                                 mode='train',
                                                 **data_kwargs)
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split='val',
                                               mode='val',
                                               **data_kwargs)
        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            drop_last=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_size=1,
                                          shuffle=False)

        # create network
        self.model = get_fast_scnn(dataset=args.dataset, aux=args.aux)
        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            device_ids = list(
                range(min(gpu_count, self._MAX_DATA_PARALLEL_GPUS)))
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=device_ids)
        self.model.to(args.device)

        # resume checkpoint if needed (a non-existent path is ignored)
        if args.resume and os.path.isfile(args.resume):
            self._check_checkpoint_ext(args.resume)
            print('Resuming training, loading {}...'.format(args.resume))
            # SECURITY: torch.load unpickles the file; only resume from
            # checkpoints you trust.
            self.model.load_state_dict(
                torch.load(args.resume,
                           map_location=lambda storage, loc: storage))

        # create criterion
        self.criterion = MixSoftmaxCrossEntropyOHEMLoss(
            aux=args.aux, aux_weight=args.aux_weight,
            ignore_index=-1).to(args.device)

        # optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # lr scheduling
        self.lr_scheduler = LRScheduler(mode='poly',
                                        base_lr=args.lr,
                                        nepochs=args.epochs,
                                        iters_per_epoch=len(self.train_loader),
                                        power=0.9)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class)

        # Best (pixAcc + mIoU) / 2 seen so far by validation().
        self.best_pred = 0.0

    def train(self):
        """Run the training loop from ``start_epoch`` up to ``epochs``.

        After every epoch either a checkpoint is saved (``args.no_val``) or
        :meth:`validation` is run, which saves one itself. A final
        checkpoint is saved once all epochs are done.
        """
        # Bind locally: the original referenced a module-level ``args``
        # global in the progress print, which fails when imported.
        args = self.args
        iters_per_epoch = len(self.train_loader)
        cur_iters = 0
        start_time = time.time()
        for epoch in range(args.start_epoch, args.epochs):
            # validation() switches to eval mode, so re-enable every epoch.
            self.model.train()

            for i, (images, targets) in enumerate(self.train_loader):
                cur_lr = self.lr_scheduler(cur_iters)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = cur_lr

                images = images.to(args.device)
                targets = targets.to(args.device)

                outputs = self.model(images)
                loss = self.criterion(outputs, targets)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                cur_iters += 1
                if cur_iters % 10 == 0:
                    print(
                        'Epoch: [%2d/%2d] Iter [%4d/%4d] || Time: %4.4f sec || lr: %.8f || Loss: %.4f'
                        % (epoch, args.epochs, i + 1, iters_per_epoch,
                           time.time() - start_time, cur_lr, loss.item()))

            if args.no_val:
                # save every epoch
                save_checkpoint(self.model, args, is_best=False)
            else:
                self.validation(epoch)

        save_checkpoint(self.model, args, is_best=False)

    def validation(self, epoch):
        """Score the model on the validation loader and save a checkpoint.

        The running pixel accuracy / mIoU is printed after every sample.
        The checkpoint is flagged as best when the mean of the final
        pixAcc and mIoU exceeds ``self.best_pred``.

        Args:
            epoch: Epoch index, used only in the progress print.
        """
        is_best = False
        self.metric.reset()
        self.model.eval()
        # Defaults keep the summary below defined for an empty loader.
        pixAcc, mIoU = 0.0, 0.0
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for i, (image, target) in enumerate(self.val_loader):
                image = image.to(self.args.device)

                outputs = self.model(image)
                pred = torch.argmax(outputs[0], 1)
                pred = pred.cpu().data.numpy()
                self.metric.update(pred, target.numpy())
                pixAcc, mIoU = self.metric.get()
                print(
                    'Epoch %d, Sample %d, validation pixAcc: %.3f%%, mIoU: %.3f%%'
                    % (epoch, i + 1, pixAcc * 100, mIoU * 100))

        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        save_checkpoint(self.model, self.args, is_best)
class Trainer(object):
    """Fast-SCNN training/validation driver with TensorBoard logging.

    Construction creates the summary writer, the train/val data loaders,
    the model (optionally restored from ``args.resume``), the OHEM loss,
    the SGD optimizer, the poly learning-rate schedule and the metrics.
    :meth:`train` runs the epoch loop and :meth:`validation` scores the
    model, logs the scores and saves a checkpoint.

    ``args`` is read for: ``dataset``, ``base_size``, ``crop_size``,
    ``train_split``, ``batch_size``, ``aux``, ``aux_weight``, ``device``,
    ``resume``, ``lr``, ``momentum``, ``weight_decay``, ``epochs``,
    ``start_epoch`` and ``no_val``.
    """

    # The original code hard-coded ``device_ids=[0, 1, 2]``; that breaks on
    # hosts with exactly two GPUs. The cap keeps the three-GPU ceiling.
    _MAX_DATA_PARALLEL_GPUS = 3

    @staticmethod
    def _check_checkpoint_ext(path):
        """Raise ``ValueError`` unless *path* ends in ``.pkl`` or ``.pth``.

        The previous ``assert ext == '.pkl' or '.pth'`` was always true
        because the non-empty string ``'.pth'`` is truthy on its own.
        """
        ext = os.path.splitext(path)[1]
        if ext not in ('.pkl', '.pth'):
            raise ValueError('Sorry only .pth and .pkl files supported.')

    def __init__(self, args):
        """Build writer, loaders, model, loss, optimizer, scheduler, metrics.

        Args:
            args: Parsed run configuration (see the class docstring for the
                attributes that are read).

        Raises:
            ValueError: If ``args.resume`` names an existing file whose
                extension is neither ``.pkl`` nor ``.pth``.
        """
        self.args = args

        # Define Tensorboard Summary
        self.summary = TensorboardSummary('./experiment')
        self.writer = self.summary.create_summary()

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        train_dataset = get_segmentation_dataset(args.dataset,
                                                 split=args.train_split,
                                                 mode='train',
                                                 **data_kwargs)
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split='val',
                                               mode='val',
                                               **data_kwargs)
        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            drop_last=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_size=1,
                                          shuffle=False)

        # create network
        self.model = get_fast_scnn(dataset=args.dataset, aux=args.aux)
        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            device_ids = list(
                range(min(gpu_count, self._MAX_DATA_PARALLEL_GPUS)))
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=device_ids)
        self.model.to(args.device)

        # resume checkpoint if needed (a non-existent path is ignored)
        if args.resume and os.path.isfile(args.resume):
            self._check_checkpoint_ext(args.resume)
            print('Resuming training, loading {}...'.format(args.resume))
            # SECURITY: torch.load unpickles the file; only resume from
            # checkpoints you trust.
            self.model.load_state_dict(
                torch.load(args.resume,
                           map_location=lambda storage, loc: storage))

        # create criterion
        self.criterion = MixSoftmaxCrossEntropyOHEMLoss(
            aux=args.aux, aux_weight=args.aux_weight,
            ignore_index=-1).to(args.device)

        # optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # lr scheduling
        self.lr_scheduler = LRScheduler(mode='poly',
                                        base_lr=args.lr,
                                        nepochs=args.epochs,
                                        iters_per_epoch=len(self.train_loader),
                                        power=0.9)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class)

        # Best (pixAcc + mIoU) / 2 seen so far by validation().
        self.best_pred = 0.0

        # NOTE(review): the evaluator is reset in validation() but never
        # updated in this class, and its argument is hard-coded to 8
        # (presumably a class count -- confirm against the dataset).
        self.evaluator = Eval(8)

    def train(self):
        """Run the training loop from ``start_epoch`` up to ``epochs``.

        After every epoch either a checkpoint is saved (``args.no_val``) or
        :meth:`validation` is run, which saves one itself; the loss of the
        epoch's last batch is then written to TensorBoard. A final
        checkpoint is saved once all epochs are done.
        """
        # Bind locally: the original referenced a module-level ``args``
        # global in the progress print, which fails when imported.
        args = self.args
        iters_per_epoch = len(self.train_loader)
        cur_iters = 0
        start_time = time.time()
        # Stays None until a batch has been processed, so the per-epoch
        # logging below cannot hit an unbound name on an empty loader.
        loss = None
        for epoch in range(args.start_epoch, args.epochs):
            # validation() switches to eval mode, so re-enable every epoch.
            self.model.train()

            for i, (images, targets) in enumerate(self.train_loader):
                cur_lr = self.lr_scheduler(cur_iters)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = cur_lr

                images = images.to(args.device)
                targets = targets.to(args.device)

                outputs = self.model(images)
                loss = self.criterion(outputs, targets)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                cur_iters += 1
                if cur_iters % 10 == 0:
                    print(
                        'Epoch: [%2d/%2d] Iter [%4d/%4d] || Time: %4.4f sec || lr: %.8f || Loss: %.4f'
                        % (epoch, args.epochs, i + 1, iters_per_epoch,
                           time.time() - start_time, cur_lr, loss.item()))

            if args.no_val:
                # save every epoch
                save_checkpoint(self.model, args, is_best=False)
            else:
                self.validation(epoch)

            # NOTE: despite the 'val/' tag this is the loss of the last
            # TRAINING batch; the tag is kept so existing logs line up.
            if loss is not None:
                self.writer.add_scalar('val/total_loss_epoch', loss.item(),
                                       epoch)

        save_checkpoint(self.model, args, is_best=False)

    def validation(self, epoch):
        """Score the model, log the scores and save a checkpoint.

        The running pixel accuracy / mIoU is printed after every sample and
        the final values are written to TensorBoard. The checkpoint is
        flagged as best when the mean of the final pixAcc and mIoU exceeds
        ``self.best_pred``.

        Args:
            epoch: Epoch index, used for printing and as the TensorBoard
                step.
        """
        is_best = False
        self.metric.reset()
        self.model.eval()
        self.evaluator.reset()
        # Defaults keep the logging below defined for an empty loader.
        pixAcc, mIoU = 0.0, 0.0
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for i, (image, target) in enumerate(self.val_loader):
                image = image.to(self.args.device)

                outputs = self.model(image)
                pred = torch.argmax(outputs[0], 1)
                pred = pred.cpu().data.numpy()
                self.metric.update(pred, target.numpy())
                pixAcc, mIoU = self.metric.get()
                print(
                    'Epoch %d, Sample %d, validation pixAcc: %.3f%%, mIoU: %.3f%%'
                    % (epoch, i + 1, pixAcc * 100, mIoU * 100))

        # Evaluator-based scores (Acc, Acc_class, mIoU, fwIoU) were
        # disabled in the original; only the SegmentationMetric values are
        # logged.
        self.writer.add_scalar('val/my_pixAcc', pixAcc, epoch)
        self.writer.add_scalar('val/my_mIoU', mIoU, epoch)

        print('Validation:')

        new_pred = (pixAcc + mIoU) / 2.0
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        save_checkpoint(self.model, self.args, is_best)