Example #1
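Example #1 appears to be the `Trainer.__init__` of a DANet-style semantic
segmentation training script built on PyTorch-Encoding. The snippet is shown
without its module header, so the imports and enclosing class below are
assumptions added only so the method parses as written; `Trainer` is a
hypothetical name, and the module paths follow the PyTorch-Encoding layout.

# Assumed header (not part of the original snippet); paths follow the
# PyTorch-Encoding / DANet layout, `Trainer` is a hypothetical class name.
import os

import torch
from torch.utils import data
import torchvision.transforms as transform

import encoding.utils as utils
from encoding.nn import BatchNorm2d, SegmentationMultiLosses
from encoding.parallel import DataParallelModel, DataParallelCriterion
from encoding.datasets import get_segmentation_dataset
from encoding.models import get_segmentation_model


class Trainer: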
    def __init__(self, args):
        self.args = args
        args.log_name = str(args.checkname)
        self.logger = utils.create_logger(args.log_root, args.log_name)
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size,
            'logger': self.logger,
            'scale': args.scale
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split='train',
                                            mode='train',
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='val',
                                           **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class
        # model
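        # multi_grid applies the multi_dilation rates across the blocks of the
        # backbone's last stage (DeepLab-style dilated convolutions)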
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=BatchNorm2d,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size,
                                       multi_grid=args.multi_grid,
                                       multi_dilation=args.multi_dilation)
        #print(model)
        self.logger.info(model)
        # optimizer using different LR
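        # the pretrained backbone trains at the base LR, while the randomly
        # initialized head and auxiliary layer train at 10x that rate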
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        self.criterion = SegmentationMultiLosses(nclass=self.nclass)
        #self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux,nclass=self.nclass)

        self.model, self.optimizer = model, optimizer
        # using cuda
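        # PyTorch-Encoding's DataParallelModel keeps each replica's outputs on
        # its own GPU, and DataParallelCriterion evaluates the loss per device,
        # avoiding a gather of all logits onto a single card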
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # finetune from a trained model
        if args.ft:
            args.start_epoch = 0
            checkpoint = torch.load(args.ft_resume)
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'],
                                                  strict=False)
            else:
                self.model.load_state_dict(checkpoint['state_dict'],
                                           strict=False)
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.ft_resume, checkpoint['epoch']))
        # resuming checkpoint
        if args.resume:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        # lr scheduler
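        # len(self.trainloader) is the number of iterations per epoch, so the
        # schedule can be stepped per iteration rather than once per epoch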
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler,
                                            args.lr,
                                            args.epochs,
                                            len(self.trainloader),
                                            logger=self.logger,
                                            lr_step=args.lr_step)
        # initialize best_pred only when not resuming, so the value restored
        # from the checkpoint above is not overwritten
        if not args.resume:
            self.best_pred = 0.0
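For orientation, a hypothetical way to drive this constructor is sketched
below. The attribute names mirror exactly what `__init__` reads off `args`;
the concrete values are illustrative assumptions only.

# Hypothetical usage sketch, assuming the header above; values are made up.
from argparse import Namespace

args = Namespace(
    checkname='danet_res101', log_root='./logs',
    dataset='cityscapes', base_size=1024, crop_size=768, scale=True,
    workers=4, cuda=True, batch_size=8,
    model='danet', backbone='resnet101', aux=False, se_loss=False,
    multi_grid=True, multi_dilation=[4, 8, 16],
    lr=0.003, momentum=0.9, weight_decay=1e-4,
    ft=False, ft_resume=None, resume=None, start_epoch=0,
    lr_scheduler='poly', epochs=240, lr_step=None,
)
trainer = Trainer(args)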
Example #2
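Example #2 is a variant of the same constructor with explicit single- and
multi-GPU paths, optional SyncBatchNorm, and multi-scale evaluation. As in
Example #1, the header below is an assumption added so the method parses as
shown.

# Assumed header (not part of the original snippet); `Trainer` is again a
# hypothetical class name.
import os

import torch
from torch.utils import data
import torchvision.transforms as transform

from encoding.nn import BatchNorm2d, SyncBatchNorm, SegmentationMultiLosses
from encoding.parallel import DataParallelModel, DataParallelCriterion
from encoding.datasets import get_dataset, test_batchify_fn
from encoding.models import get_segmentation_model


class Trainer: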
 def __init__(self, args):
     self.args = args
     # data transforms
     input_transform = transform.Compose([
         transform.ToTensor(),
         transform.Normalize([.485, .456, .406], [.229, .224, .225])
     ])
     # dataset
     data_kwargs = {
         'transform': input_transform,
         'base_size': args.base_size,
         'crop_size': args.crop_size
     }
     trainset = get_dataset(args.dataset,
                            split=args.train_split,
                            mode='train',
                            **data_kwargs)
     valset = get_dataset(
         args.dataset,
         split='val',
         mode='ms_val' if args.multi_scale_eval else 'fast_val',
         **data_kwargs)
     # dataloader
     kwargs = {'num_workers': args.workers, 'pin_memory': True}
     self.trainloader = data.DataLoader(trainset,
                                        batch_size=args.batch_size,
                                        drop_last=True,
                                        shuffle=True,
                                        **kwargs)
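      # multi-scale eval yields tensors of different sizes within one batch,
      # so samples are collated into plain lists instead of stacked tensors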
     if self.args.multi_scale_eval:
         kwargs['collate_fn'] = test_batchify_fn
     self.valloader = data.DataLoader(valset,
                                      batch_size=args.test_batch_size,
                                      drop_last=False,
                                      shuffle=False,
                                      **kwargs)
     self.nclass = trainset.num_class
     # model
     if args.norm_layer == 'bn':
         norm_layer = BatchNorm2d
     elif args.norm_layer == 'sync_bn':
         assert args.multi_gpu, "SyncBatchNorm can only be used when multi GPUs are available!"
         norm_layer = SyncBatchNorm
     else:
         raise ValueError('Invalid norm_layer {}'.format(args.norm_layer))
     model = get_segmentation_model(
         args.model,
         dataset=args.dataset,
         backbone=args.backbone,
         aux=args.aux,
         se_loss=args.se_loss,
         norm_layer=norm_layer,
         base_size=args.base_size,
         crop_size=args.crop_size,
         multi_grid=True,
         multi_dilation=[2, 4, 8],
         only_pam=True,
     )
     print(model)
      # optimizer with per-module parameter groups (all at the base LR here)
     params_list = [
         {
             'params': model.pretrained.parameters(),
             'lr': args.lr
         },
     ]
     if hasattr(model, 'head'):
         params_list.append({
             'params': model.head.parameters(),
             'lr': args.lr
         })
     if hasattr(model, 'auxlayer'):
         params_list.append({
             'params': model.auxlayer.parameters(),
             'lr': args.lr
         })
     optimizer = torch.optim.SGD(params_list,
                                 lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
     # criterions
     self.criterion = SegmentationMultiLosses()
     self.model, self.optimizer = model, optimizer
     # using cuda
     if args.multi_gpu:
         self.model = DataParallelModel(self.model).cuda()
         self.criterion = DataParallelCriterion(self.criterion).cuda()
     else:
         self.model = self.model.cuda()
         self.criterion = self.criterion.cuda()
     self.single_device_model = self.model.module if self.args.multi_gpu else self.model
     # resuming checkpoint
     if args.resume is not None:
         if not os.path.isfile(args.resume):
             raise RuntimeError("=> no checkpoint found at '{}'".format(
                 args.resume))
         checkpoint = torch.load(args.resume)
         args.start_epoch = checkpoint['epoch']
         self.single_device_model.load_state_dict(checkpoint['state_dict'])
         if not args.ft and not (args.only_val or args.only_vis
                                 or args.only_infer):
             self.optimizer.load_state_dict(checkpoint['optimizer'])
         self.best_pred = checkpoint['best_pred']
         print("=> loaded checkpoint '{}' (epoch {}), best_pred {}".format(
             args.resume, checkpoint['epoch'], checkpoint['best_pred']))
     # clear start epoch if fine-tuning
     if args.ft:
         args.start_epoch = 0
     # lr scheduler
     self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
         optimizer, 0.6)
      # initialize best_pred only when not resuming, so the value restored
      # from the checkpoint above is not overwritten
      if args.resume is None:
          self.best_pred = 0.0
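Compared with Example #1, this variant drops the 10x learning-rate multiplier
on the head and auxiliary layer, adds a single-GPU code path (with
SyncBatchNorm available for multi-GPU training), collates multi-scale
validation batches as lists, and replaces the repo's iteration-level
LR_Scheduler with PyTorch's built-in ExponentialLR (decay factor 0.6,
typically stepped once per epoch).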