Example #1
    def __init__(self, args):
        if args.se_loss:
            args.checkname = args.checkname + "_se"

        self.args = args
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])])
        # dataset
        data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
                       'crop_size': args.crop_size}
        trainset = get_segmentation_dataset(args.dataset, split='train', mode='train',
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset, split='val', mode='val',
                                           **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': False} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                           drop_last=True, shuffle=True, **kwargs)
        self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
                                         drop_last=False, shuffle=False, **kwargs)
        self.nclass = trainset.num_class
        # model
        model = get_segmentation_model(args.model, dataset=args.dataset,
                                       backbone=args.backbone, aux=args.aux,
                                       se_loss=args.se_loss, norm_layer=BatchNorm2d,
                                       base_size=args.base_size, crop_size=args.crop_size)
        print(model)

        # count parameter number
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        print("Total number of parameters: %d"%pytorch_total_params)

        # optimizer using different LR
        params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr}, ]
        if hasattr(model, 'head'):
            if args.diflr:
                params_list.append({'params': model.head.parameters(), 'lr': args.lr * 10})
            else:
                params_list.append({'params': model.head.parameters(), 'lr': args.lr})
        if hasattr(model, 'auxlayer'):
            if args.diflr:
                params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr * 10})
            else:
                params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr})

        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        #optimizer = torch.optim.ASGD(params_list,
        #                            lr=args.lr,
        #                            weight_decay=args.weight_decay)

        # criterions
        self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux,
                                            nclass=self.nclass)
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # resuming checkpoint
        if args.resume is not None and len(args.resume) > 0:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                # load weights directly when model and checkpoint share the same structure
                # self.model.module.load_state_dict(checkpoint['state_dict'])

                # model and checkpoint have different structures:
                # copy only the parameters whose names match
                pretrained_dict = checkpoint['state_dict']
                model_dict = self.model.module.state_dict()

                for name, param in pretrained_dict.items():
                    if name not in model_dict:
                        continue
                    if isinstance(param, Parameter):
                        # backwards compatibility for serialized parameters
                        param = param.data
                    model_dict[name].copy_(param)

            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.trainloader),
                                            lr_step=args.lr_step)
        self.best_pred = 0.0
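
Note: `utils.LR_Scheduler` itself is not shown in these snippets. As a rough sketch of what a 'poly' policy typically does here (an assumption, not the repo's actual implementation), the base LR decays per iteration while each parameter group keeps its own initial ratio, so the head/auxlayer groups stay at 10x the backbone LR:

def make_poly_lr(optimizer, max_iter, power=0.9):
    """Return a step(cur_iter) callback implementing a poly LR decay (sketch)."""
    init_lrs = [g['lr'] for g in optimizer.param_groups]  # remembers the 10x head ratio
    def step(cur_iter):
        factor = (1.0 - cur_iter / float(max_iter)) ** power
        for g, lr0 in zip(optimizer.param_groups, init_lrs):
            g['lr'] = lr0 * factor
    return step

# usage (hypothetical): step = make_poly_lr(optimizer, args.epochs * len(trainloader))
#                       step(epoch * len(trainloader) + i) before each optimizer.step()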
Example #2
 def __init__(self, args):
     self.args = args
     # data transforms
     input_transform = transform.Compose([
         transform.ToTensor(),
         transform.Normalize([.485, .456, .406], [.229, .224, .225])
     ])
     # dataset
     data_kwargs = {
         'transform': input_transform,
         'base_size': args.base_size,
         'crop_size': args.crop_size
     }
     trainset = get_segmentation_dataset(args.dataset,
                                         split=args.train_split,
                                         mode='train',
                                         **data_kwargs)
     testset = get_segmentation_dataset(args.dataset,
                                        split='val',
                                        mode='val',
                                        **data_kwargs)
     # dataloader
     kwargs = {'num_workers': args.workers, 'pin_memory': True} \
         if args.cuda else {}
     self.trainloader = data.DataLoader(trainset,
                                        batch_size=args.batch_size,
                                        drop_last=True,
                                        shuffle=True,
                                        **kwargs)
     self.valloader = data.DataLoader(testset,
                                      batch_size=args.batch_size,
                                      drop_last=False,
                                      shuffle=False,
                                      **kwargs)
     self.nclass = trainset.num_class
     # model
     model = get_segmentation_model(
         args.model,
         dataset=args.dataset,
         backbone=args.backbone,
         dilated=args.dilated,
         lateral=args.lateral,
         jpu=args.jpu,
         aux=args.aux,
         se_loss=args.se_loss,
          norm_layer=torch.nn.BatchNorm2d,
         base_size=args.base_size,
         crop_size=args.crop_size)
     print(model)
     # optimizer using different LR
     params_list = [
         {
             'params': model.pretrained.parameters(),
             'lr': args.lr
         },
     ]
     if hasattr(model, 'jpu'):
         params_list.append({
             'params': model.jpu.parameters(),
             'lr': args.lr * 10
         })
     if hasattr(model, 'head'):
         params_list.append({
             'params': model.head.parameters(),
             'lr': args.lr * 10
         })
     if hasattr(model, 'auxlayer'):
         params_list.append({
             'params': model.auxlayer.parameters(),
             'lr': args.lr * 10
         })
     optimizer = torch.optim.SGD(params_list,
                                 lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
     # criterions
     self.criterion = SegmentationLosses(se_loss=args.se_loss,
                                         aux=args.aux,
                                         nclass=self.nclass,
                                         se_weight=args.se_weight,
                                         aux_weight=args.aux_weight)
     self.model, self.optimizer = model, optimizer
     # using cuda
     if args.cuda:
         self.model = DataParallelModel(self.model).cuda()
         self.criterion = DataParallelCriterion(self.criterion).cuda()
     # resuming checkpoint
     if args.resume is not None:
         if not os.path.isfile(args.resume):
             raise RuntimeError("=> no checkpoint found at '{}'".format(
                 args.resume))
         checkpoint = torch.load(args.resume)
         args.start_epoch = checkpoint['epoch']
         if args.cuda:
             self.model.module.load_state_dict(checkpoint['state_dict'])
         else:
             self.model.load_state_dict(checkpoint['state_dict'])
         if not args.ft:
             self.optimizer.load_state_dict(checkpoint['optimizer'])
         self.best_pred = checkpoint['best_pred']
         print("=> loaded checkpoint '{}' (epoch {})".format(
             args.resume, checkpoint['epoch']))
     # clear start epoch if fine-tuning
     if args.ft:
         args.start_epoch = 0
     # lr scheduler
     self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                         args.epochs, len(self.trainloader))
     self.best_pred = 0.0
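
For context, a trainer like this is usually driven from a small entry point. A minimal sketch (the `Options` parser and the `training`/`validation` methods are assumptions here; Example #3 below defines methods with those names):

if __name__ == "__main__":
    args = Options().parse()       # hypothetical argument parser producing the fields used above
    torch.manual_seed(args.seed)
    trainer = Trainer(args)
    for epoch in range(args.start_epoch, args.epochs):
        trainer.training(epoch)    # one pass over trainloader
        trainer.validation(epoch)  # pixAcc / mIoU on valloader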
Example #3
class Trainer():
    def __init__(self, args):
        self.args = args
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_dataset(args.dataset,
                               split=args.train_split,
                               mode='train',
                               **data_kwargs)
        valset = get_dataset(
            args.dataset,
            split='val',
            mode='ms_val' if args.multi_scale_eval else 'fast_val',
            **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                           **kwargs)
        if self.args.multi_scale_eval:
            kwargs['collate_fn'] = test_batchify_fn
        self.valloader = data.DataLoader(valset,
                                         batch_size=args.test_batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class
        # model
        if args.norm_layer == 'bn':
            norm_layer = BatchNorm2d
        elif args.norm_layer == 'sync_bn':
            assert args.multi_gpu, "SyncBatchNorm can only be used when multiple GPUs are available!"
            norm_layer = SyncBatchNorm
        else:
            raise ValueError('Invalid norm_layer {}'.format(args.norm_layer))
        model = get_segmentation_model(
            args.model,
            dataset=args.dataset,
            backbone=args.backbone,
            aux=args.aux,
            se_loss=args.se_loss,
            norm_layer=norm_layer,
            base_size=args.base_size,
            crop_size=args.crop_size,
            multi_grid=True,
            multi_dilation=[2, 4, 8],
            only_pam=True,
        )
        print(model)
        # optimizer using different LR
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr
            })
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # criterions
        self.criterion = SegmentationMultiLosses()
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.multi_gpu:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        else:
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()
        self.single_device_model = self.model.module if self.args.multi_gpu else self.model
        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            self.single_device_model.load_state_dict(checkpoint['state_dict'])
            if not args.ft and not (args.only_val or args.only_vis
                                    or args.only_infer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {}), best_pred {}".format(
                args.resume, checkpoint['epoch'], checkpoint['best_pred']))
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # lr scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, 0.6)
        self.best_pred = 0.0

    def save_ckpt(self, epoch, score):
        is_best = False
        if score >= self.best_pred:
            is_best = True
            self.best_pred = score
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self.single_device_model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, self.args, is_best)

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        self.lr_scheduler.step()
        tbar = tqdm(self.trainloader, miniters=20)
        for i, (image, target) in enumerate(tbar):

            if not self.args.multi_gpu:
                image = image.cuda()
                target = target.cuda()
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            if self.args.multi_gpu:
                loss = self.criterion(outputs, target)
            else:
                loss = self.criterion(*(outputs + (target, )))
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            ep_log = 'ep {}'.format(epoch + 1)
            lr_log = 'lr ' + '{:.6f}'.format(
                self.optimizer.param_groups[0]['lr']).rstrip('0')
            loss_log = 'loss {:.3f}'.format(train_loss / (i + 1))
            tbar.set_description(', '.join([ep_log, lr_log, loss_log]))

    def validation(self, epoch):
        def _get_pred(batch_im):
            with torch.no_grad():
                # metric.update also accepts list, so no need to gather results from multi gpus
                if self.args.multi_scale_eval:
                    assert len(batch_im) <= torch.cuda.device_count(), \
                        "Multi-scale testing only allows batch size <= number of GPUs"
                    scattered_pred = self.ms_evaluator.parallel_forward(
                        batch_im)
                else:
                    outputs = self.model(batch_im)
                    scattered_pred = [
                        out[0] for out in outputs
                    ] if self.args.multi_gpu else [outputs[0]]
            return scattered_pred

        # Lazy creation
        if not hasattr(self, 'ms_evaluator'):
            self.ms_evaluator = MultiEvalModule(self.single_device_model,
                                                self.nclass,
                                                scales=self.args.eval_scales,
                                                crop=self.args.crop_eval)
            self.metric = utils.SegmentationMetric(self.nclass)
        self.model.eval()
        tbar = tqdm(self.valloader, desc='\r')
        for i, (batch_im, target) in enumerate(tbar):
            # No need to put target to GPU, since the metrics are calculated by numpy.
            # And no need to put data to GPU manually if we use data parallel.
            if not self.args.multi_gpu and not isinstance(
                    batch_im, (list, tuple)):
                batch_im = batch_im.cuda()
            scattered_pred = _get_pred(batch_im)
            scattered_target = []
            ind = 0
            for p in scattered_pred:
                target_tmp = target[ind:ind + len(p)]
                # Multi-scale testing. In fact, len(target_tmp) == 1
                if isinstance(target_tmp, (list, tuple)):
                    assert len(target_tmp) == 1
                    target_tmp = torch.stack(target_tmp)
                scattered_target.append(target_tmp)
                ind += len(p)
            self.metric.update(scattered_target, scattered_pred)
            pixAcc, mIoU = self.metric.get()
            tbar.set_description('ep {}, pixAcc: {:.4f}, mIoU: {:.4f}'.format(
                epoch + 1, pixAcc, mIoU))
        return self.metric.get()

    def visualize(self, epoch):
        if (self.args.dir_of_im_to_vis
                == 'None') and (self.args.im_list_file_to_vis == 'None'):
            return
        if not hasattr(self, 'vis_im_paths'):
            if self.args.dir_of_im_to_vis != 'None':
                print('=> Visualize Dir {}'.format(self.args.dir_of_im_to_vis))
                im_paths = list(
                    walkdir(self.args.dir_of_im_to_vis, exts=['.jpg', '.png']))
            else:
                print('=> Visualize Image List {}'.format(
                    self.args.im_list_file_to_vis))
                im_paths = read_lines(self.args.im_list_file_to_vis)
            print('=> Save Dir {}'.format(self.args.vis_save_dir))
            im_paths = sorted(im_paths)
            # np.random.RandomState(seed=1).shuffle(im_paths)
            self.vis_im_paths = im_paths[:self.args.max_num_vis]
        cfg = {
            'save_path': os.path.join(self.args.vis_save_dir,
                                      'vis_epoch{}.png'.format(epoch)),
            'multi_scale': self.args.multi_scale_eval,
            'crop': self.args.crop_eval,
            'num_class': self.nclass,
            'scales': self.args.eval_scales,
            'base_size': self.args.base_size,
        }
        vis_im_list(self.single_device_model, self.vis_im_paths, cfg)

    def infer_and_save(self, infer_dir, infer_save_dir):
        print('=> Infer Dir {}'.format(infer_dir))
        print('=> Save Dir {}'.format(infer_save_dir))
        sub_im_paths = list(
            walkdir(infer_dir, exts=['.jpg', '.png'], sub_path=True))
        im_paths = [os.path.join(infer_dir, p) for p in sub_im_paths]
        # NOTE: Don't save result as JPEG, since it causes aliasing.
        save_paths = [
            os.path.join(infer_save_dir, p.replace('.jpg', '.png'))
            for p in sub_im_paths
        ]
        cfg = {
            'multi_scale': self.args.multi_scale_eval,
            'crop': self.args.crop_eval,
            'num_class': self.nclass,
            'scales': self.args.eval_scales,
            'base_size': self.args.base_size,
        }
        infer_and_save_im_list(self.single_device_model, im_paths, save_paths,
                               cfg)
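
Example #3 selects `SyncBatchNorm` only when multiple GPUs are used. With recent PyTorch the same effect can be obtained by converting an existing model, though this requires `torch.distributed` / `DistributedDataParallel` rather than the `DataParallelModel` wrapper used here; a minimal sketch under that assumption:

import torch

def to_sync_bn(model, device_id):
    # SyncBatchNorm synchronizes BN statistics across processes, so a process
    # group must already be initialized (e.g. by a torchrun launch).
    torch.distributed.init_process_group(backend='nccl')
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda(device_id)
    return torch.nn.parallel.DistributedDataParallel(model, device_ids=[device_id])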
Example #4
 def __init__(self, args):
     self.args = args
     # data transforms
     input_transform = transform.Compose([
         transform.ToTensor(),
         transform.Normalize([.485, .456, .406], [.229, .224, .225])
     ])
     # dataset
     data_kwargs = {
         'transform': input_transform,
         'base_size': args.base_size,
         'crop_size': args.crop_size
     }
     trainset = get_dataset(args.dataset,
                            split=args.train_split,
                            mode='train',
                            **data_kwargs)
     valset = get_dataset(
         args.dataset,
         split='val',
         mode='ms_val' if args.multi_scale_eval else 'fast_val',
         **data_kwargs)
     # dataloader
     kwargs = {'num_workers': args.workers, 'pin_memory': True}
     self.trainloader = data.DataLoader(trainset,
                                        batch_size=args.batch_size,
                                        drop_last=True,
                                        shuffle=True,
                                        **kwargs)
     if self.args.multi_scale_eval:
         kwargs['collate_fn'] = test_batchify_fn
     self.valloader = data.DataLoader(valset,
                                      batch_size=args.test_batch_size,
                                      drop_last=False,
                                      shuffle=False,
                                      **kwargs)
     self.nclass = trainset.num_class
     # model
     if args.norm_layer == 'bn':
         norm_layer = BatchNorm2d
     elif args.norm_layer == 'sync_bn':
          assert args.multi_gpu, "SyncBatchNorm can only be used when multiple GPUs are available!"
         norm_layer = SyncBatchNorm
     else:
         raise ValueError('Invalid norm_layer {}'.format(args.norm_layer))
     model = get_segmentation_model(
         args.model,
         dataset=args.dataset,
         backbone=args.backbone,
         aux=args.aux,
         se_loss=args.se_loss,
         norm_layer=norm_layer,
         base_size=args.base_size,
         crop_size=args.crop_size,
         multi_grid=True,
         multi_dilation=[2, 4, 8],
         only_pam=True,
     )
     print(model)
     # optimizer using different LR
     params_list = [
         {
             'params': model.pretrained.parameters(),
             'lr': args.lr
         },
     ]
     if hasattr(model, 'head'):
         params_list.append({
             'params': model.head.parameters(),
             'lr': args.lr
         })
     if hasattr(model, 'auxlayer'):
         params_list.append({
             'params': model.auxlayer.parameters(),
             'lr': args.lr
         })
     optimizer = torch.optim.SGD(params_list,
                                 lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
     # criterions
     self.criterion = SegmentationMultiLosses()
     self.model, self.optimizer = model, optimizer
     # using cuda
     if args.multi_gpu:
         self.model = DataParallelModel(self.model).cuda()
         self.criterion = DataParallelCriterion(self.criterion).cuda()
     else:
         self.model = self.model.cuda()
         self.criterion = self.criterion.cuda()
     self.single_device_model = self.model.module if self.args.multi_gpu else self.model
     # resuming checkpoint
     if args.resume is not None:
         if not os.path.isfile(args.resume):
             raise RuntimeError("=> no checkpoint found at '{}'".format(
                 args.resume))
         checkpoint = torch.load(args.resume)
         args.start_epoch = checkpoint['epoch']
         self.single_device_model.load_state_dict(checkpoint['state_dict'])
         if not args.ft and not (args.only_val or args.only_vis
                                 or args.only_infer):
             self.optimizer.load_state_dict(checkpoint['optimizer'])
         self.best_pred = checkpoint['best_pred']
         print("=> loaded checkpoint '{}' (epoch {}), best_pred {}".format(
             args.resume, checkpoint['epoch'], checkpoint['best_pred']))
     # clear start epoch if fine-tuning
     if args.ft:
         args.start_epoch = 0
     # lr scheduler
     self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
         optimizer, 0.6)
     self.best_pred = 0.0
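
The `utils.save_checkpoint` helper called from `save_ckpt` in Example #3 is not included in these snippets. A plausible minimal version (an assumption about its behaviour, including the hypothetical 'runs/...' directory layout) is:

import os
import shutil
import torch

def save_checkpoint(state, args, is_best, filename='checkpoint.pth.tar'):
    """Save the latest state and copy it to 'model_best.pth.tar' when it beats best_pred."""
    directory = os.path.join('runs', args.dataset, args.model, args.checkname)
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(directory, 'model_best.pth.tar'))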
Example #5
    def __init__(self, args):
        self.args = args
        args.log_name = str(args.checkname)
        self.logger = utils.create_logger(args.log_root, args.log_name)
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size,
            'logger': self.logger,
            'scale': args.scale
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split='train',
                                            mode='train',
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='val',
                                           **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class

        # model
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=BatchNorm2d,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size,
                                       multi_grid=args.multi_grid,
                                       multi_dilation=args.multi_dilation)
        #print(model)
        self.logger.info(model)
        # optimizer using different LR
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })

        cityscape_weight = torch.FloatTensor([
            0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489,
            0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
            1.0865, 1.1529, 1.0507
        ])

        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # weight for class imbalance (optional):
        # self.criterion = SegmentationMultiLosses(nclass=self.nclass, weight=cityscape_weight)
        self.criterion = SegmentationMultiLosses(nclass=self.nclass)
        # self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux, nclass=self.nclass)

        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # finetune from a trained model
        if args.ft:
            args.start_epoch = 0
            checkpoint = torch.load(args.ft_resume)
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'],
                                                  strict=False)
            else:
                self.model.load_state_dict(checkpoint['state_dict'],
                                           strict=False)
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.ft_resume, checkpoint['epoch']))
        # resuming checkpoint
        if args.resume:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler,
                                            args.lr,
                                            args.epochs,
                                            len(self.trainloader),
                                            logger=self.logger,
                                            lr_step=args.lr_step)
        self.best_pred = 0.0
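
The hard-coded `cityscape_weight` above is a table of frequency-based class weights. One common recipe for such weights (a sketch of the general idea; the exact script behind this particular 19-entry table is not shown) is the ENet-style w_k = 1 / ln(c + p_k), where p_k is a class's share of all labelled pixels:

import numpy as np

def enet_class_weights(pixel_counts, c=1.02):
    """Rare classes get larger weights, frequent classes smaller ones."""
    counts = np.asarray(pixel_counts, dtype=np.float64)
    p = counts / counts.sum()          # per-class pixel frequency
    return 1.0 / np.log(c + p)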
Example #6
 def __init__(self, args):
     self.args = args
     args.log_name = str(args.checkname)
     args.log_root = os.path.join(args.dataset, args.log_root) # dataset/log/
     self.logger = utils.create_logger(args.log_root, args.log_name)
     # data transforms
     input_transform = transform.Compose([
         transform.ToTensor(),
         transform.Normalize([.485, .456, .406], [.229, .224, .225])])
     # dataset
     data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
                    'crop_size': args.crop_size, 'logger': self.logger,
                    'scale': args.scale}
     trainset = get_segmentation_dataset(args.dataset, split='trainval', mode='trainval',
                                         **data_kwargs)
     testset = get_segmentation_dataset(args.dataset, split='val', mode='val',  # crop fixed size as model input
                                        **data_kwargs)
     # dataloader
     kwargs = {'num_workers': args.workers, 'pin_memory': True} \
         if args.cuda else {}
     self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                        drop_last=True, shuffle=True, **kwargs)
     self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
                                      drop_last=False, shuffle=False, **kwargs)
     self.nclass = trainset.num_class
     # model
     model = get_segmentation_model(args.model, dataset=args.dataset,
                                    backbone=args.backbone,
                                    norm_layer=BatchNorm2d,
                                    base_size=args.base_size, crop_size=args.crop_size,
                                    )
     #print(model)
     self.logger.info(model)
     # optimizer using different LR
      params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr}, ]
      if hasattr(model, 'head'):
          print("This model has a 'head' module")
          params_list.append({'params': model.head.parameters(), 'lr': args.lr * 10})
      optimizer = torch.optim.SGD(params_list,
                                  lr=args.lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
     self.criterion = SegmentationLosses(nclass=self.nclass)
     
     self.model, self.optimizer = model, optimizer
     # using cuda
     if args.cuda:
         self.model = DataParallelModel(self.model).cuda()
         self.criterion = DataParallelCriterion(self.criterion).cuda()
     
     # resuming checkpoint
     if args.resume:
         if not os.path.isfile(args.resume):
             raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
         checkpoint = torch.load(args.resume)
         args.start_epoch = checkpoint['epoch']
         if args.cuda:
             self.model.module.load_state_dict(checkpoint['state_dict'])
         else:
             self.model.load_state_dict(checkpoint['state_dict'])
         if not args.ft:
             self.optimizer.load_state_dict(checkpoint['optimizer'])
         self.best_pred = checkpoint['best_pred']
         self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
     # lr scheduler
     self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                         args.epochs, len(self.trainloader), logger=self.logger,
                                         lr_step=args.lr_step)
     self.best_pred = 0.0
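
`utils.create_logger` is also not part of these snippets. A minimal file-plus-console logger with the same call signature might look like this (a sketch, not the repo's implementation):

import logging
import os

def create_logger(log_root, log_name):
    os.makedirs(log_root, exist_ok=True)
    logger = logging.getLogger(log_name)
    logger.setLevel(logging.INFO)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    for handler in (logging.FileHandler(os.path.join(log_root, log_name + '.log')),
                    logging.StreamHandler()):
        handler.setFormatter(fmt)
        logger.addHandler(handler)
    return logger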
Example #7
def main(args):
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.ToPILImage()

    if '5' in args.data_dir:  # source domain is GTA5 (its path contains '5'), otherwise Cityscapes
        dataset = GTA5DataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=input_transform, target_transform=target_transform,
        )
    else:
        dataset = CityscapesDataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=input_transform, target_transform=target_transform,
        )
    loader = data.DataLoader(
        dataset, batch_size=args.batch_size,
        shuffle=True, num_workers=args.num_workers, pin_memory=True
    )
    val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target, args.data_list_target,
        # joint_transform=joint_transform,
        transform=input_transform, target_transform=target_transform
    )
    val_loader = data.DataLoader(
        val_dataset, batch_size=args.batch_size,
        shuffle=False, num_workers=args.num_workers, pin_memory=True
    )


    upsample = nn.Upsample(size=(h_target, w_target),
                           mode='bilinear', align_corners=True)

    net = PSP(
        nclass=args.n_classes, backbone='resnet101',
        root=args.model_path_prefix, norm_layer=BatchNorm2d,
    )

    params_list = [
        {'params': net.pretrained.parameters(), 'lr': args.learning_rate},
        {'params': net.head.parameters(), 'lr': args.learning_rate * 10},
        {'params': net.auxlayer.parameters(), 'lr': args.learning_rate * 10},
    ]
    optimizer = torch.optim.SGD(params_list,
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    criterion = SegmentationLosses(nclass=args.n_classes, aux=True, ignore_index=255)
    # criterion = SegmentationMultiLosses(nclass=args.n_classes, ignore_index=255)

    net = DataParallelModel(net).cuda()
    criterion = DataParallelCriterion(criterion).cuda()

    logger = utils.create_logger(args.tensorboard_log_dir, 'PSP_train')
    scheduler = utils.LR_Scheduler(args.lr_scheduler, args.learning_rate,
                                   args.num_epoch, len(loader), logger=logger,
                                   lr_step=args.lr_step)

    net_eval = Eval(net)

    num_batches = len(loader)
    best_pred = 0.0
    for epoch in range(args.num_epoch):

        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, batch_data in enumerate(loader):
            scheduler(optimizer, batch_index, epoch, best_pred)
            show_fig = (batch_index+1) % args.show_img_freq == 0
            iteration = batch_index+1+epoch*num_batches

            net.train()
            img, label, name = batch_data
            img = img.cuda()
            label_cuda = label.cuda()
            data_time_rec.update(time.time()-tem_time)

            output = net(img)
            loss = criterion(output, label_cuda)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_rec.update(loss.item())
            writer.add_scalar('A_seg_loss', loss.item(), iteration)
            batch_time_rec.update(time.time()-tem_time)
            tem_time = time.time()

            if (batch_index+1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Loss: {loss_rec.avg:.2f}'
                )
            # if show_fig:
            #     # base_lr = optimizer.param_groups[0]["lr"]
            #     output = torch.argmax(output[0][0], dim=1).detach()[0, ...].cpu()
            #     # fig, axes = plt.subplots(2, 1, figsize=(12, 14))
            #     # axes = axes.flat
            #     # axes[0].imshow(colorize_mask(output.numpy()))
            #     # axes[0].set_title(name[0])
            #     # axes[1].imshow(colorize_mask(label[0, ...].numpy()))
            #     # axes[1].set_title(f'seg_true_{base_lr:.6f}')
            #     # writer.add_figure('A_seg', fig, iteration)
            #     output_mask = np.asarray(colorize_mask(output.numpy()))
            #     label = np.asarray(colorize_mask(label[0,...].numpy()))
            #     image_out = np.concatenate([output_mask, label])
            #     writer.add_image('A_seg', image_out, iteration)

        mean_iu = test_miou(net_eval, val_loader, upsample,
                            './style_seg/dataset/info.json')
        torch.save(
            net.module.state_dict(),
            os.path.join(args.save_path_prefix, f'{epoch:d}_{mean_iu*100:.0f}.pth')
        )

    writer.close()
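
`AverageMeter`, used above for loss and timing statistics, is a small standard helper; a typical implementation (a sketch, not necessarily the one imported here) is:

class AverageMeter(object):
    """Tracks the most recent value and a running average."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count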
Example #8
    def __init__(self, args):
        self.args = args
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split=args.train_split,
                                            mode='train',
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='val',
                                           **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=False,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class
        # model
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       dilated=args.dilated,
                                       multi_grid=args.multi_grid,
                                       stride=args.stride,
                                       lateral=args.lateral,
                                       jpu=args.jpu,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=SyncBatchNorm,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size)
        # print(model)
        # optimizer using different LR
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'jpu') and model.jpu:
            params_list.append({
                'params': model.jpu.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        class_balance_weight = None  # fall back to an unweighted loss for other datasets
        if args.dataset == "pcontext60":
            class_balance_weight = torch.tensor([
                1.3225e-01, 2.0757e+00, 1.8146e+01, 5.5052e+00, 2.2060e+00,
                2.8054e+01, 2.0566e+00, 1.8598e+00, 2.4027e+00, 9.3435e+00,
                3.5990e+00, 2.7487e-01, 1.4216e+00, 2.4986e+00, 7.7258e-01,
                4.9020e-01, 2.9067e+00, 1.2197e+00, 2.2744e+00, 2.0444e+01,
                3.0057e+00, 1.8167e+01, 3.7405e+00, 5.6749e-01, 3.2631e+00,
                1.5007e+00, 5.5519e-01, 1.0056e+01, 1.8952e+01, 2.6792e-01,
                2.7479e-01, 1.8309e+00, 2.0428e+01, 1.4788e+01, 1.4908e+00,
                1.9113e+00, 2.6166e+02, 2.3233e-01, 1.9096e+01, 6.7025e+00,
                2.8756e+00, 6.8804e-01, 4.4140e+00, 2.5621e+00, 4.4409e+00,
                4.3821e+00, 1.3774e+01, 1.9803e-01, 3.6944e+00, 1.0397e+00,
                2.0601e+00, 5.5811e+00, 1.3242e+00, 3.0088e-01, 1.7344e+01,
                2.1569e+00, 2.7216e-01, 5.8731e-01, 1.9956e+00, 4.4004e+00
            ])

        elif args.dataset == "ade20k":
            class_balance_weight = torch.tensor([
                0.0772, 0.0431, 0.0631, 0.0766, 0.1095, 0.1399, 0.1502, 0.1702,
                0.2958, 0.3400, 0.3738, 0.3749, 0.4059, 0.4266, 0.4524, 0.5725,
                0.6145, 0.6240, 0.6709, 0.6517, 0.6591, 0.6818, 0.9203, 0.9965,
                1.0272, 1.0967, 1.1202, 1.2354, 1.2900, 1.5038, 1.5160, 1.5172,
                1.5036, 2.0746, 2.1426, 2.3159, 2.2792, 2.6468, 2.8038, 2.8777,
                2.9525, 2.9051, 3.1050, 3.1785, 3.3533, 3.5300, 3.6120, 3.7006,
                3.6790, 3.8057, 3.7604, 3.8043, 3.6610, 3.8268, 4.0644, 4.2698,
                4.0163, 4.0272, 4.1626, 4.3702, 4.3144, 4.3612, 4.4389, 4.5612,
                5.1537, 4.7653, 4.8421, 4.6813, 5.1037, 5.0729, 5.2657, 5.6153,
                5.8240, 5.5360, 5.6373, 6.6972, 6.4561, 6.9555, 7.9239, 7.3265,
                7.7501, 7.7900, 8.0528, 8.5415, 8.1316, 8.6557, 9.0550, 9.0081,
                9.3262, 9.1391, 9.7237, 9.3775, 9.4592, 9.7883, 10.6705,
                10.2113, 10.5845, 10.9667, 10.8754, 10.8274, 11.6427, 11.0687,
                10.8417, 11.0287, 12.2030, 12.8830, 12.5082, 13.0703, 13.8410,
                12.3264, 12.9048, 12.9664, 12.3523, 13.9830, 13.8105, 14.0345,
                15.0054, 13.9801, 14.1048, 13.9025, 13.6179, 17.0577, 15.8351,
                17.7102, 17.3153, 19.4640, 17.7629, 19.9093, 16.9529, 19.3016,
                17.6671, 19.4525, 20.0794, 18.3574, 19.1219, 19.5089, 19.2417,
                20.2534, 20.0332, 21.7496, 21.5427, 20.3008, 21.1942, 22.7051,
                23.3359, 22.4300, 20.9934, 26.9073, 31.7362, 30.0784
            ])
        elif args.dataset == "cocostuff":
            class_balance_weight = torch.tensor([
                4.8557e-02, 6.4709e-02, 3.9255e+00, 9.4797e-01, 1.2703e+00,
                1.4151e+00, 7.9733e-01, 8.4903e-01, 1.0751e+00, 2.4001e+00,
                8.9736e+00, 5.3036e+00, 6.0410e+00, 9.3285e+00, 1.5952e+00,
                3.6090e+00, 9.8772e-01, 1.2319e+00, 1.9194e+00, 2.7624e+00,
                2.0548e+00, 1.2058e+00, 3.6424e+00, 2.0789e+00, 1.7851e+00,
                6.7138e+00, 2.1315e+00, 6.9813e+00, 1.2679e+02, 2.0357e+00,
                2.2933e+01, 2.3198e+01, 1.7439e+01, 4.1294e+01, 7.8678e+00,
                4.3444e+01, 6.7543e+01, 1.0066e+01, 6.7520e+00, 1.3174e+01,
                3.3499e+00, 6.9737e+00, 2.1482e+00, 1.9428e+01, 1.3240e+01,
                1.9218e+01, 7.6836e-01, 2.6041e+00, 6.1822e+00, 1.4070e+00,
                4.4074e+00, 5.7792e+00, 1.0321e+01, 4.9922e+00, 6.7408e-01,
                3.1554e+00, 1.5832e+00, 8.9685e-01, 1.1686e+00, 2.6487e+00,
                6.5354e-01, 2.3801e-01, 1.9536e+00, 1.5862e+00, 1.7797e+00,
                2.7385e+01, 1.2419e+01, 3.9287e+00, 7.8897e+00, 7.5737e+00,
                1.9758e+00, 8.1962e+01, 3.6922e+00, 2.0039e+00, 2.7333e+00,
                5.4717e+00, 3.9048e+00, 1.9184e+01, 2.2689e+00, 2.6091e+02,
                4.7366e+01, 2.3844e+00, 8.3310e+00, 1.4857e+01, 6.5076e+00,
                2.0854e-01, 1.0425e+00, 1.7386e+00, 1.1973e+01, 5.2862e+00,
                1.7341e+00, 8.6124e-01, 9.3702e+00, 2.8545e+00, 6.0123e+00,
                1.7560e-01, 1.8128e+00, 1.3784e+00, 1.3699e+00, 2.3728e+00,
                6.2819e-01, 1.3097e+00, 4.7892e-01, 1.0268e+01, 1.2307e+00,
                5.5662e+00, 1.2867e+00, 1.2745e+00, 4.7505e+00, 8.4029e+00,
                1.8679e+00, 1.0519e+01, 1.1240e+00, 1.4975e-01, 2.3146e+00,
                4.1265e-01, 2.5896e+00, 1.4537e+00, 4.5575e+00, 7.8143e+00,
                1.4603e+01, 2.8812e+00, 1.8868e+00, 7.8131e+01, 1.9323e+00,
                7.4980e+00, 1.2446e+01, 2.1856e+00, 3.0973e+00, 4.1270e-01,
                4.9016e+01, 7.1001e-01, 7.4035e+00, 2.3395e+00, 2.9207e-01,
                2.4156e+00, 3.3211e+00, 2.1300e+00, 2.4533e-01, 1.7081e+00,
                4.6621e+00, 2.9199e+00, 1.0407e+01, 7.6207e-01, 2.7806e-01,
                3.7711e+00, 1.1852e-01, 8.8280e+00, 3.1700e-01, 6.3765e+01,
                6.6032e+00, 5.2177e+00, 4.3596e+00, 6.2965e-01, 1.0207e+00,
                1.1731e+01, 2.3935e+00, 9.2767e+00, 1.1023e-01, 3.6947e+00,
                1.3943e+00, 2.3407e+00, 1.2112e-01, 2.8518e+00, 2.8195e+00,
                1.0078e+00, 1.6614e+00, 6.5307e-01, 1.9070e+01, 2.7231e+00,
                6.0769e-01
            ])

        # criterions
        self.criterion = SegmentationLosses(se_loss=args.se_loss,
                                            aux=args.aux,
                                            nclass=self.nclass,
                                            se_weight=args.se_weight,
                                            aux_weight=args.aux_weight,
                                            weight=class_balance_weight)
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.trainloader))
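
# Note: train() below pairs bootstrapped_cross_entropy2d with a decaying top-k budget
# (topk_base). That loss is not shown in this snippet; the following is a sketch of the
# idea only (an assumption, not the actual implementation): average the cross-entropy
# over the K hardest pixels of each image.
import torch.nn.functional as F

def bootstrapped_ce_sketch(logits, target, k, ignore_index=250):
    # logits: (N, C, H, W), target: (N, H, W)
    per_pixel = F.cross_entropy(logits, target, ignore_index=ignore_index, reduction='none')
    per_pixel = per_pixel.view(per_pixel.size(0), -1)
    topk_vals, _ = per_pixel.topk(k, dim=1)  # K largest per-pixel losses per sample
    return topk_vals.mean()
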
def train(args):
    weight_dir = args.log_root  # os.path.join(args.log_root, 'weights')
    log_dir = os.path.join(
        args.log_root, 'logs',
        'SS-Net-{}'.format(time.strftime("%Y-%m-%d-%H-%M-%S",
                                         time.localtime())))

    data_dir = os.path.join(args.data_root, args.dataset)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 1. Setup DataLoader
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 0. Setting up DataLoader...")
    net_h, net_w = int(args.img_row * args.crop_ratio), int(args.img_col *
                                                            args.crop_ratio)
    augment_train = Compose([
        RandomHorizontallyFlip(),
        RandomSized((0.5, 0.75)),
        RandomRotate(5),
        RandomCrop((net_h, net_w))
    ])
    augment_valid = Compose([
        RandomHorizontallyFlip(),
        Scale((args.img_row, args.img_col)),
        CenterCrop((net_h, net_w))
    ])

    train_loader = CityscapesLoader(data_dir,
                                    gt='gtFine',
                                    split='train',
                                    img_size=(args.img_row, args.img_col),
                                    is_transform=True,
                                    augmentations=augment_train)

    valid_loader = CityscapesLoader(data_dir,
                                    gt='gtFine',
                                    split='val',
                                    img_size=(args.img_row, args.img_col),
                                    is_transform=True,
                                    augmentations=augment_valid)

    num_classes = train_loader.n_classes

    tra_loader = data.DataLoader(train_loader,
                                 batch_size=args.batch_size,
                                 num_workers=int(multiprocessing.cpu_count() /
                                                 2),
                                 shuffle=True)

    val_loader = data.DataLoader(valid_loader,
                                 batch_size=args.batch_size,
                                 num_workers=int(multiprocessing.cpu_count() /
                                                 2))

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 2. Setup Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 1. Setting up Model...")
    model = RetinaNet(num_classes=num_classes, input_size=(net_h, net_w))
    # model = torch.nn.DataParallel(model, device_ids=[0,1,2]).cuda()
    model = DataParallelModel(model,
                              device_ids=args.device_ids).cuda()  # multi-gpu

    # 2.1 Setup Optimizer
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # Check if model has custom optimizer
    if hasattr(model.module, 'optimizer'):
        print('> Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.learning_rate,
                                    momentum=0.90,
                                    weight_decay=5e-4,
                                    nesterov=True)
        # optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=1e-5)

    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    # 2.2 Setup Loss
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    class_weight = np.array([
        0.05570516, 0.32337477, 0.08998544, 1.03602707, 1.03413147, 1.68195437,
        5.58540548, 3.56563995, 0.12704978, 1., 0.46783719, 1.34551528,
        5.29974114, 0.28342531, 0.9396095, 0.81551811, 0.42679146, 3.6399074,
        2.78376194
    ], dtype=float)
    class_weight = torch.from_numpy(class_weight).float().cuda()

    sem_loss = bootstrapped_cross_entropy2d
    sem_loss = DataParallelCriterion(sem_loss, device_ids=args.device_ids)
    se_loss = SemanticEncodingLoss(num_classes=19,
                                   ignore_label=250,
                                   alpha=0.50).cuda()
    se_loss_parallel = DataParallelCriterion(se_loss,
                                             device_ids=args.device_ids)
    """
    # multi-gpu
    bootstrapped_cross_entropy2d = ContextBootstrappedCELoss2D(num_classes=num_classes,
                                                               ignore=250,
                                                               kernel_size=5,
                                                               padding=4,
                                                               dilate=2,
                                                               use_gpu=True)
    loss_sem = DataParallelCriterion(bootstrapped_cross_entropy2d, device_ids=[0, 1]) 
    """

    # 2.3 Setup Metrics
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # !!!!! Here Metrics !!!!!
    metrics = RunningScore(num_classes)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 3. Resume Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 2. Model state init or resume...")
    args.start_epoch = 1
    args.start_iter = 0
    beat_map = 0.
    if args.resume is not None:
        full_path = os.path.join(os.path.join(weight_dir, 'train_model'),
                                 args.resume)
        if os.path.isfile(full_path):
            print("> Loading model and optimizer from checkpoint '{}'".format(
                args.resume))

            checkpoint = torch.load(full_path)

            args.start_epoch = checkpoint['epoch']
            args.start_iter = checkpoint['iter']
            beat_map = checkpoint['beat_map']
            model.load_state_dict(checkpoint['model_state'])  # weights
            optimizer.load_state_dict(
                checkpoint['optimizer_state'])  # gradient state
            del checkpoint

            print("> Loaded checkpoint '{}' (epoch {}, iter {})".format(
                args.resume, args.start_epoch, args.start_iter))

        else:
            print("> No checkpoint found at '{}'".format(full_path))
            raise Exception("> No checkpoint found at '{}'".format(full_path))
    else:
        # init_weights(model, pi=0.01,
        #              pre_trained=os.path.join(args.log_root, 'resnet50_imagenet.pth'))

        if args.pre_trained is not None:
            print("> Loading weights from pre-trained model '{}'".format(
                args.pre_trained))
            full_path = os.path.join(args.log_root, args.pre_trained)

            pre_weight = torch.load(full_path)
            prefix = "module.fpn.base_net."

            model_dict = model.state_dict()
            pretrained_dict = {(prefix + k): v
                               for k, v in pre_weight.items()
                               if (prefix + k) in model_dict}

            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            del pre_weight
            del model_dict
            del pretrained_dict

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4. Train Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4.0. Setup tensor-board for visualization
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    writer = None
    if args.tensor_board:
        writer = SummaryWriter(log_dir=log_dir, comment="SSnet_Cityscapes")
        # dummy_input = Variable(torch.rand(1, 3, args.img_row, args.img_col).cuda(), requires_grad=True)
        # writer.add_graph(model, dummy_input)

    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 3. Model Training start...")
    topk_init = 512
    num_batches = int(
        math.ceil(
            len(tra_loader.dataset.files[tra_loader.dataset.split]) /
            float(tra_loader.batch_size)))

    # lr_period = 20 * num_batches

    for epoch in np.arange(args.start_epoch - 1, args.num_epochs):
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 4.1 Mini-Batch Training
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        model.train()
        topk_base = topk_init

        if epoch == args.start_epoch - 1:
            pbar = tqdm(np.arange(args.start_iter, num_batches))
            start_iter = args.start_iter
        else:
            pbar = tqdm(np.arange(num_batches))
            start_iter = 0

        lr = args.learning_rate

        # lr = adjust_learning_rate(optimizer, init_lr=args.learning_rate, decay_rate=0.1, curr_epoch=epoch,
        #                           epoch_step=20, start_decay_at_epoch=args.start_decay_at_epoch,
        #                           total_epoch=args.num_epochs, mode='exp')

        # scheduler.step()
        # for train_i, (images, gt_masks) in enumerate(tra_loader):  # one mini-batch per iteration
        for train_i, (images, gt_masks) in zip(range(start_iter, num_batches),
                                               tra_loader):

            full_iter = (epoch * num_batches) + train_i + 1

            lr = poly_lr_scheduler(optimizer,
                                   init_lr=args.learning_rate,
                                   iter=full_iter,
                                   lr_decay_iter=1,
                                   max_iter=args.num_epochs * num_batches,
                                   power=0.9)
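            # poly_lr_scheduler presumably applies the standard polynomial decay
            #   lr = init_lr * (1 - full_iter / max_iter) ** power
            # updating every iteration (lr_decay_iter=1) over the whole run.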

            # lr = args.learning_rate * cosine_annealing_lr(lr_period, full_iter)
            # optimizer = set_optimizer_lr(optimizer, lr)

            images = images.cuda().requires_grad_()
            se_labels = se_loss.unique_encode(gt_masks)
            se_labels = se_labels.cuda()
            gt_masks = gt_masks.cuda()

            topk_base = poly_topk_scheduler(init_topk=topk_init,
                                            iter=full_iter,
                                            topk_decay_iter=1,
                                            max_iter=args.num_epochs *
                                            num_batches,
                                            power=0.95)
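            # The bootstrapped CE loss presumably keeps only the K highest-loss pixels
            # per image, so topk_base is decayed with the same polynomial shape as the
            # LR: start loose (many pixels) and focus on hard pixels as training goes on.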

            optimizer.zero_grad()

            se, sem_seg_pred = model(images)

            # --------------------------------------------------- #
            # Compute loss
            # --------------------------------------------------- #
            topk = topk_base * 512
            train_loss = sem_loss(input=sem_seg_pred,
                                  target=gt_masks,
                                  K=topk,
                                  weight=None)
            train_se_loss = se_loss_parallel(predicts=se,
                                             enc_cls_target=se_labels,
                                             size_average=True,
                                             reduction='elementwise_mean')

            loss = train_loss + args.alpha * train_se_loss
            loss.backward()  # back-propagation

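            # clip the global gradient norm (a generous 1e3 cap) to guard against
            # occasional exploding gradients before the SGD step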
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1e3)
            optimizer.step()  # parameter update based on the current gradient

            pbar.update(1)
            pbar.set_description("> Epoch [%d/%d]" %
                                 (epoch + 1, args.num_epochs))
            pbar.set_postfix(Train_Loss=train_loss.item(),
                             Train_SE_Loss=train_se_loss.item(),
                             TopK=topk_base)
            # pbar.set_postfix(Train_Loss=train_loss.item(), TopK=topk_base)

            # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
            # 4.1.1 Verbose training process
            # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
            if (train_i + 1) % args.verbose_interval == 0:
                # ---------------------------------------- #
                # 1. Training Losses
                # ---------------------------------------- #
                loss_log = "Epoch [%d/%d], Iter: %d Loss1: \t %.4f " % (
                    epoch + 1, args.num_epochs, train_i + 1, loss.item())

                # ---------------------------------------- #
                # 2. Training Metrics
                # ---------------------------------------- #
                sem_seg_pred = F.softmax(sem_seg_pred, dim=1)
                pred = sem_seg_pred.data.max(1)[1].cpu().numpy()
                gt = gt_masks.data.cpu().numpy()

                metrics.update(
                    gt,
                    pred)  # accumulate the metrics (confusion_matrix and ious)
                score, _ = metrics.get_scores()

                metric_log = ""
                for k, v in score.items():
                    metric_log += " {}: \t %.4f, ".format(k) % v
                metrics.reset()  # reset running metrics after each logging step

                logs = loss_log + metric_log

                if args.tensor_board:
                    writer.add_scalar('Training/Train_Loss', train_loss.item(),
                                      full_iter)
                    writer.add_scalar('Training/Train_SE_Loss',
                                      train_se_loss.item(), full_iter)
                    writer.add_scalar('Training/Loss', loss.item(), full_iter)
                    writer.add_scalar('Training/Lr', lr, full_iter)
                    writer.add_scalars('Training/Metrics', score, full_iter)
                    writer.add_text('Training/Text', logs, full_iter)

                    for name, param in model.named_parameters():
                        writer.add_histogram(name,
                                             param.clone().cpu().data.numpy(),
                                             full_iter)
            """
            # each 2000 iterations save model
            if (train_i + 1) % args.iter_interval_save_model == 0:
                pbar.set_postfix(Loss=train_loss.item(), lr=lr)

                state = {"epoch": epoch + 1,
                         "iter": train_i + 1,
                         'beat_map': beat_map,
                         "model_state": model.state_dict(),
                         "optimizer_state": optimizer.state_dict()}

                save_dir = os.path.join(os.path.join(weight_dir, 'train_model'),
                                        "ssnet_model_sem_se_{}epoch_{}iter.pkl".format(epoch+1, train_i+1))
                torch.save(state, save_dir)
            """

        # end of this training phase
        state = {
            "epoch": epoch + 1,
            "iter": num_batches,
            'beat_map': beat_map,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict()
        }

        save_dir = os.path.join(
            args.log_root, 'train_model',
            "ssnet_model_sem_se_{}_{}epoch_{}iter.pkl".format(
                args.model_details, epoch + 1, num_batches))
        torch.save(state, save_dir)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 4.2 Mini-Batch Validation
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        model.eval()

        val_loss = 0.0
        vali_count = 0

        with torch.no_grad():
            for i_val, (images_val, gt_masks_val) in enumerate(val_loader):
                vali_count += 1

                images_val = images_val.cuda()
                se_labels_val = se_loss.unique_encode(gt_masks_val)
                se_labels_val = se_labels_val.cuda()
                gt_masks_val = gt_masks_val.cuda()

                se_val, sem_seg_pred_val = model(images_val)

                # validation loss: same bootstrapped CE + SE combination as in training
                topk_val = topk_base * 512
                loss = sem_loss(sem_seg_pred_val, gt_masks_val, topk_val, weight=None) + \
                       args.alpha * se_loss_parallel(predicts=se_val, enc_cls_target=se_labels_val,
                                                     size_average=True, reduction='elementwise_mean')
                val_loss += loss.item()

                # accumulating the confusion matrix and ious
                sem_seg_pred_val = F.softmax(sem_seg_pred_val, dim=1)
                pred = sem_seg_pred_val.data.max(1)[1].cpu().numpy()
                gt = gt_masks_val.data.cpu().numpy()
                metrics.update(gt, pred)

            # ---------------------------------------- #
            # 1. Validation Losses
            # ---------------------------------------- #
            val_loss /= vali_count

            loss_log = "Epoch [%d/%d], Loss: \t %.4f" % (
                epoch + 1, args.num_epochs, val_loss)

            # ---------------------------------------- #
            # 2. Validation Metrics
            # ---------------------------------------- #
            metric_log = ""
            score, _ = metrics.get_scores()
            for k, v in score.items():
                metric_log += " {}: \t %.4f, ".format(k) % v
            metrics.reset()  # reset the metrics

            logs = loss_log + metric_log

            pbar.set_postfix(
                Vali_Loss=val_loss, Lr=lr,
                Vali_mIoU=score['Mean_IoU'])  # Train_Loss=train_loss.item()

            if args.tensor_board:
                writer.add_scalar('Validation/Loss', val_loss, epoch)
                writer.add_scalars('Validation/Metrics', score, epoch)
                writer.add_text('Validation/Text', logs, epoch)

                for name, param in model.named_parameters():
                    writer.add_histogram(name,
                                         param.clone().cpu().data.numpy(),
                                         epoch)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 4.3 End of one Epoch
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # select the best model by validation Mean_IoU

        if score['Mean_IoU'] >= beat_map:
            beat_map = score['Mean_IoU']
            state = {
                "epoch": epoch + 1,
                "beat_map": beat_map,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict()
            }

            save_dir = os.path.join(
                weight_dir,
                "SSnet_best_sem_se_{}_model.pkl".format(args.model_details))
            torch.save(state, save_dir)

        # Note that step should be called after validate()
        pbar.close()

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4.4 End of Training process
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    if args.tensor_board:
        # export scalar data to JSON for external processing
        # writer.export_scalars_to_json("{}/all_scalars.json".format(log_dir))
        writer.close()
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> Training Done!!!")
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
Example #10
0
    def __init__(self, args):
        self.args = args
        args.log_name = str(args.checkname)
        root_dir = getattr(args, "data_root", '../datasets')
        wo_head = getattr(args, "resume_wo_head", False)

        self.logger = utils.create_logger(args.log_root, args.log_name)
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size,
            'logger': self.logger,
            'scale': args.scale
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split='train',
                                            mode='train',
                                            root=root_dir,
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='val',
                                           root=root_dir,
                                           **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class

        # model
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=BatchNorm2d,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size,
                                       multi_grid=args.multi_grid,
                                       multi_dilation=args.multi_dilation)
        #print(model)
        self.logger.info(model)
        # optimizer using different LR

        if not args.wo_backbone:
            params_list = [
                {
                    'params': model.pretrained.parameters(),
                    'lr': args.lr
                },
            ]
        else:
            params_list = []

        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        self.criterion = SegmentationMultiLosses(nclass=self.nclass)
        #self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux,nclass=self.nclass)

        self.model, self.optimizer = model, optimizer

        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()

        # finetune from a trained model
        if args.ft:
            args.start_epoch = 0
            checkpoint = torch.load(args.ft_resume)
            if wo_head:
                print("WITHout HEAD !!!!!!!!!!")
                from collections import OrderedDict
                new = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    if not k.startswith("head"):
                        new[k] = v
                checkpoint['state_dict'] = new
            else:
                print("With HEAD !!!!!!!!!!")

            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'],
                                                  strict=False)
            else:
                self.model.load_state_dict(checkpoint['state_dict'],
                                           strict=False)
            # self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.ft_resume, checkpoint['epoch']))
        # resuming checkpoint
        if args.resume:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler,
                                            args.lr,
                                            args.epochs,
                                            len(self.trainloader),
                                            logger=self.logger,
                                            lr_step=args.lr_step)
        # default best metric; when resuming, keep the value restored from the checkpoint
        if not hasattr(self, 'best_pred'):
            self.best_pred = 0.0
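
# A minimal sketch (an assumption, not the repo's utils.LR_Scheduler) of how a poly
# schedule can rescale each parameter group while preserving the 10x head multiplier
# configured above: remember each group's starting LR once, then apply the same decay
# factor to all groups every iteration.
def rescale_param_groups(optimizer, cur_iter, max_iter, power=0.9):
    decay = (1.0 - cur_iter / float(max_iter)) ** power
    for group in optimizer.param_groups:
        group.setdefault('initial_lr', group['lr'])  # base_lr or base_lr * 10
        group['lr'] = group['initial_lr'] * decay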
Example #11
0
    def __init__(self, args):
        self.args = args
        self.args.start_epoch = 0
        self.args.cuda = True
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.490, .490, .490], [.247, .247, .247])
        ])  # TODO: change mean and std

        # dataset
        train_chain = Compose([
            HorizontalFlip(p=0.5),
            OneOf([
                ElasticTransform(
                    alpha=300, sigma=300 * 0.05, alpha_affine=300 * 0.03),
                GridDistortion(),
                OpticalDistortion(distort_limit=2, shift_limit=0.5),
            ],
                  p=0.3),
            RandomSizedCrop(
                min_max_height=(900, 1024), height=1024, width=1024, p=0.5),
            ShiftScaleRotate(rotate_limit=20, p=0.5),
            Resize(self.args.size, self.args.size)
        ],
                              p=1)

        val_chain = Compose([Resize(self.args.size, self.args.size)], p=1)
        num_fold = self.args.num_fold
        df_train = pd.read_csv(os.path.join(args.imagelist_path, 'train.csv'))
        df_val = pd.read_csv(os.path.join(args.imagelist_path, 'val.csv'))
        df_full = pd.concat((df_train, df_val), ignore_index=True, axis=0)
        df_full['lbl'] = (df_full['mask_name'].astype(str) == '-1').astype(int)
        skf = StratifiedKFold(8, shuffle=True, random_state=777)
        train_ids, val_ids = list(
            skf.split(df_full['mask_name'], df_full['lbl']))[num_fold]
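        # 'lbl' flags samples whose mask_name is '-1' (presumably images with no mask),
        # so the 8-fold StratifiedKFold keeps the ratio of empty to non-empty masks
        # roughly constant across folds; num_fold picks the held-out validation fold.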

        df_test = pd.read_csv(
            os.path.join(args.imagelist_path, 'test_true.csv'))

        df_new_train = pd.concat((df_full.iloc[train_ids], df_test),
                                 ignore_index=True,
                                 axis=0,
                                 sort=False)
        df_new_val = df_full.iloc[val_ids]

        df_new_train.to_csv(f'/tmp/train_new_pneumo_{num_fold}.csv')
        df_new_val.to_csv(f'/tmp/val_new_pneumo_{num_fold}.csv')

        trainset = SegmentationDataset(f'/tmp/train_new_pneumo_{num_fold}.csv',
                                       args.image_path,
                                       args.masks_path,
                                       input_transform=input_transform,
                                       transform_chain=train_chain,
                                       base_size=1024)
        testset = SegmentationDataset(f'/tmp/val_new_pneumo_{num_fold}.csv',
                                      args.image_path,
                                      args.masks_path,
                                      input_transform=input_transform,
                                      transform_chain=val_chain,
                                      base_size=1024)

        imgs = trainset.mask_img_map[:, [0, 3]]
        weights = make_weights_for_balanced_classes(imgs, 2)
        weights = torch.DoubleTensor(weights)
        train_sampler = (torch.utils.data.sampler.WeightedRandomSampler(
            weights, len(weights)))
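        # make_weights_for_balanced_classes presumably assigns each sample a weight
        # inverse to its class frequency, so the WeightedRandomSampler draws (with
        # replacement) a roughly balanced stream of empty / non-empty mask images per
        # epoch instead of plain shuffling.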

        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.trainloader = data.DataLoader(
            trainset,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=train_sampler,  # a sampler replaces shuffle=True
            **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)

        self.nclass = 1
        if self.args.model == 'unet':
            model = UNet(n_classes=self.nclass, norm_layer=SyncBatchNorm)
            params_list = [
                {
                    'params': model.parameters(),
                    'lr': args.lr
                },
            ]
        elif self.args.model == 'encnet':
            model = EncNet(
                nclass=self.nclass,
                backbone=args.backbone,
                aux=args.aux,
                se_loss=args.se_loss,
                norm_layer=SyncBatchNorm  #nn.BatchNorm2d
            )

            # optimizer using different LR
            params_list = [
                {
                    'params': model.pretrained.parameters(),
                    'lr': args.lr
                },
            ]
            if hasattr(model, 'head'):
                params_list.append({
                    'params': model.head.parameters(),
                    'lr': args.lr * 10
                })
            if hasattr(model, 'auxlayer'):
                params_list.append({
                    'params': model.auxlayer.parameters(),
                    'lr': args.lr * 10
                })

        print(model)
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=args.wd)

        # criterions
        if self.nclass == 1:
            self.criterion = SegmentationLossesBCE(se_loss=args.se_loss,
                                                   aux=args.aux,
                                                   nclass=self.nclass,
                                                   se_weight=args.se_weight,
                                                   aux_weight=args.aux_weight,
                                                   use_dice=args.use_dice)
        else:
            self.criterion = SegmentationLosses(
                se_loss=args.se_loss,
                aux=args.aux,
                nclass=self.nclass,
                se_weight=args.se_weight,
                aux_weight=args.aux_weight,
            )
        self.model, self.optimizer = model, optimizer

        self.best_pred = 0.0
        self.model = DataParallelModel(self.model).cuda()
        self.criterion = DataParallelCriterion(self.criterion).cuda()

        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)  #, map_location='cpu')
            self.args.start_epoch = checkpoint['epoch']
            state_dict = {k: v for k, v in checkpoint['state_dict'].items()}
            self.model.load_state_dict(state_dict)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            for g in self.optimizer.param_groups:
                g['lr'] = args.lr
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            print(f'Best dice: {checkpoint["best_pred"]}')
            print(f'LR: {get_lr(self.optimizer):.5f}')

        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           mode='min',
                                           factor=0.8,
                                           patience=4,
                                           threshold=0.001,
                                           threshold_mode='abs',
                                           min_lr=0.00001)
        self.logger = Logger(args.logger_dir)
        self.step_train = 0
        self.best_loss = 20
        self.step_val = 0
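
# For reference, a minimal sketch of what make_weights_for_balanced_classes is assumed
# to do (the classic balanced-sampling recipe, not necessarily this project's exact
# helper): give every sample a weight proportional to 1 / count(its class).
def make_weights_for_balanced_classes_sketch(samples, nclasses):
    counts = [0] * nclasses
    for _, label in samples:              # samples: iterable of (item, class_index)
        counts[int(label)] += 1
    total = float(sum(counts))
    per_class = [total / c if c > 0 else 0.0 for c in counts]
    return [per_class[int(label)] for _, label in samples]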