def __init__(self, args):
    self.args = args
    args.log_name = str(args.checkname)
    self.logger = utils.create_logger(args.log_root, args.log_name)

    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])

    # dataset
    data_kwargs = {
        'transform': input_transform,
        'base_size': args.base_size,
        'crop_size': args.crop_size,
        'logger': self.logger,
        'scale': args.scale
    }
    trainset = get_segmentation_dataset(args.dataset, split='train',
                                        mode='train', **data_kwargs)
    testset = get_segmentation_dataset(args.dataset, split='val',
                                       mode='val', **data_kwargs)

    # dataloader
    kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                       drop_last=True, shuffle=True, **kwargs)
    self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
                                     drop_last=False, shuffle=False, **kwargs)
    self.nclass = trainset.num_class

    # model
    model = get_segmentation_model(args.model, dataset=args.dataset,
                                   backbone=args.backbone, aux=args.aux,
                                   se_loss=args.se_loss,
                                   norm_layer=BatchNorm2d,
                                   base_size=args.base_size,
                                   crop_size=args.crop_size,
                                   multi_grid=args.multi_grid,
                                   multi_dilation=args.multi_dilation)
    self.logger.info(model)

    # optimizer: pretrained backbone at the base LR, newly initialized
    # head/aux layers at 10x the base LR
    params_list = [
        {'params': model.pretrained.parameters(), 'lr': args.lr},
    ]
    if hasattr(model, 'head'):
        params_list.append({'params': model.head.parameters(),
                            'lr': args.lr * 10})
    if hasattr(model, 'auxlayer'):
        params_list.append({'params': model.auxlayer.parameters(),
                            'lr': args.lr * 10})
    optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # criterion
    self.criterion = SegmentationMultiLosses(nclass=self.nclass)
    # self.criterion = SegmentationLosses(se_loss=args.se_loss,
    #                                     aux=args.aux, nclass=self.nclass)

    self.model, self.optimizer = model, optimizer

    # using cuda
    if args.cuda:
        self.model = DataParallelModel(self.model).cuda()
        self.criterion = DataParallelCriterion(self.criterion).cuda()

    # default best prediction; may be overwritten when resuming below
    # (initializing here, rather than at the end, keeps a resumed
    # best_pred from being reset back to 0.0)
    self.best_pred = 0.0

    # finetune from a trained model
    if args.ft:
        args.start_epoch = 0
        checkpoint = torch.load(args.ft_resume)
        if args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'],
                                              strict=False)
        else:
            self.model.load_state_dict(checkpoint['state_dict'],
                                       strict=False)
        self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            args.ft_resume, checkpoint['epoch']))

    # resuming checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            raise RuntimeError(
                "=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        if args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])
        if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
        self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))

    # lr scheduler
    self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                        args.epochs, len(self.trainloader),
                                        logger=self.logger,
                                        lr_step=args.lr_step)
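# --- Sketch (not part of the original trainer): the params_list above
# relies on torch.optim.SGD treating each dict as a separate param group
# whose 'lr' overrides the constructor-level lr. A minimal, self-contained
# check of that behavior, with hypothetical stand-in modules in place of
# the real pretrained backbone / segmentation head split:
import torch
import torch.nn as nn

backbone = nn.Linear(8, 8)  # stand-in for model.pretrained
head = nn.Linear(8, 2)      # stand-in for model.head

base_lr = 0.01
opt = torch.optim.SGD(
    [
        {'params': backbone.parameters(), 'lr': base_lr},
        {'params': head.parameters(), 'lr': base_lr * 10},
    ],
    lr=base_lr, momentum=0.9, weight_decay=1e-4)

# the group-level 'lr' wins over the constructor default
print([g['lr'] for g in opt.param_groups])  # -> [0.01, 0.1]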
def __init__(self, args):
    self.args = args

    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])

    # dataset
    data_kwargs = {
        'transform': input_transform,
        'base_size': args.base_size,
        'crop_size': args.crop_size
    }
    trainset = get_dataset(args.dataset, split=args.train_split,
                           mode='train', **data_kwargs)
    valset = get_dataset(
        args.dataset,
        split='val',
        mode='ms_val' if args.multi_scale_eval else 'fast_val',
        **data_kwargs)

    # dataloader
    kwargs = {'num_workers': args.workers, 'pin_memory': True}
    self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                       drop_last=True, shuffle=True, **kwargs)
    if self.args.multi_scale_eval:
        # multi-scale inputs vary in size, so batches cannot be stacked
        kwargs['collate_fn'] = test_batchify_fn
    self.valloader = data.DataLoader(valset, batch_size=args.test_batch_size,
                                     drop_last=False, shuffle=False, **kwargs)
    self.nclass = trainset.num_class

    # model
    if args.norm_layer == 'bn':
        norm_layer = BatchNorm2d
    elif args.norm_layer == 'sync_bn':
        assert args.multi_gpu, \
            "SyncBatchNorm can only be used when multiple GPUs are available!"
        norm_layer = SyncBatchNorm
    else:
        raise ValueError('Invalid norm_layer {}'.format(args.norm_layer))
    model = get_segmentation_model(
        args.model,
        dataset=args.dataset,
        backbone=args.backbone,
        aux=args.aux,
        se_loss=args.se_loss,
        norm_layer=norm_layer,
        base_size=args.base_size,
        crop_size=args.crop_size,
        multi_grid=True,
        multi_dilation=[2, 4, 8],
        only_pam=True,
    )
    print(model)

    # optimizer: separate param groups for backbone, head, and aux layer
    # (all at the same base LR in this variant)
    params_list = [
        {'params': model.pretrained.parameters(), 'lr': args.lr},
    ]
    if hasattr(model, 'head'):
        params_list.append({'params': model.head.parameters(),
                            'lr': args.lr})
    if hasattr(model, 'auxlayer'):
        params_list.append({'params': model.auxlayer.parameters(),
                            'lr': args.lr})
    optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # criterion
    self.criterion = SegmentationMultiLosses()

    self.model, self.optimizer = model, optimizer

    # using cuda
    if args.multi_gpu:
        self.model = DataParallelModel(self.model).cuda()
        self.criterion = DataParallelCriterion(self.criterion).cuda()
    else:
        self.model = self.model.cuda()
        self.criterion = self.criterion.cuda()
    self.single_device_model = (self.model.module
                                if self.args.multi_gpu else self.model)

    # default best prediction; may be overwritten when resuming below
    # (initializing here, rather than at the end, keeps a resumed
    # best_pred from being reset back to 0.0)
    self.best_pred = 0.0

    # resuming checkpoint
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError(
                "=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        self.single_device_model.load_state_dict(checkpoint['state_dict'])
        if not args.ft and not (args.only_val or args.only_vis
                                or args.only_infer):
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {}), best_pred {}".format(
            args.resume, checkpoint['epoch'], checkpoint['best_pred']))

    # clear start epoch if fine-tuning
    if args.ft:
        args.start_epoch = 0

    # lr scheduler: multiply the LR by 0.6 on every scheduler step
    self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer, 0.6)
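# --- Sketch (not part of the original trainer): ExponentialLR multiplies
# every param group's LR by gamma=0.6 on each scheduler.step() call, so
# the training loop is presumably expected to call step() once per epoch.
# A minimal illustration with a hypothetical stand-in model
# (get_last_lr() requires PyTorch >= 1.4):
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical stand-in for the segmentation model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.6)

for epoch in range(3):
    # ... one epoch of forward/backward/optimizer.step() would go here ...
    scheduler.step()
    print(epoch, scheduler.get_last_lr())  # [0.006], [0.0036], [0.00216]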