def __init__(self, args):
    self.args = args
    self.device = torch.device(args.device)
    # image transform
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    # dataset and dataloader
    val_dataset = get_segmentation_dataset('eyes', split='val', mode='testval',
                                           transform=input_transform)
    val_sampler = make_data_sampler(val_dataset, False, args.distributed)
    val_batch_sampler = make_batch_data_sampler(val_sampler, images_per_batch=1)
    self.val_loader = data.DataLoader(dataset=val_dataset,
                                      batch_sampler=val_batch_sampler,
                                      num_workers=args.workers,
                                      pin_memory=True)
    # create network
    self.model = get_segmentation_model(model=args.model, dataset=args.dataset,
                                        aux=args.aux, pretrained=True,
                                        pretrained_base=False)
    if args.distributed:
        self.model = self.model.module
    self.model.to(self.device)
    self.metric = SegmentationMetric(val_dataset.num_class)
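A minimal sketch of how this evaluator's validation loop might consume the objects built above. It assumes the dataset yields (image, target, filename) tuples and that SegmentationMetric exposes the common update()/get() interface returning (pixAcc, mIoU); adjust to the actual dataset and metric classes used in this repo.

# NOTE: illustrative sketch, not part of the original file.
def eval(self):
    self.model.eval()
    with torch.no_grad():
        for image, target, _ in self.val_loader:   # item layout is an assumption
            image = image.to(self.device)
            output = self.model(image)
            # some models return (main output, aux outputs); keep the main head
            pred = output[0] if isinstance(output, (tuple, list)) else output
            self.metric.update(pred, target)
    pixAcc, mIoU = self.metric.get()
    print('pixAcc: {:.3f}, mIoU: {:.3f}'.format(pixAcc, mIoU))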
def __init__(self, args):
    self.args = args
    self.device = torch.device(args.device)
    # image transform
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    # dataset and dataloader
    val60_data_kwargs = {'transform': input_transform,
                         'base_size': args.base_size,
                         'crop_size': args.crop_size,
                         're_size': args.re_size}
    valset = get_segmentation_dataset(args.dataset, args=args, split='val',
                                      mode='val_onlyrs', **val60_data_kwargs)
    val_sampler = make_data_sampler(valset, True, args.distributed)
    val_batch_sampler = make_batch_data_sampler(val_sampler, args.batch_size)
    self.val60_loader = data.DataLoader(dataset=valset,
                                        batch_sampler=val_batch_sampler,
                                        num_workers=args.workers,
                                        pin_memory=True)
    # create network
    BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
    self.model = get_segmentation_model(args.model, dataset=args.dataset,
                                        args=self.args,
                                        norm_layer=BatchNorm2d).to(self.device)
    self.model = load_model(args.resume, self.model)
    # evaluation metrics
    self.metric_120 = SegmentationMetric(valset.num_class)
    self.metric_60 = SegmentationMetric(valset.num_class)
    self.best_pred = 0.0
def __init__(self, args):
    self.args = args
    self.device = torch.device(args.device)
    # image transform
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    # dataset and dataloader
    test_data_kwargs = {'transform': input_transform,
                        'base_size': args.base_size,
                        'crop_size': args.crop_size,
                        're_size': args.re_size}
    testset = get_segmentation_dataset(args.dataset, args=args, split='test',
                                       mode='test', **test_data_kwargs)
    test_sampler = make_data_sampler(testset, False, args.distributed)
    test_batch_sampler = make_batch_data_sampler(test_sampler, args.batch_size)
    self.test_loader = data.DataLoader(dataset=testset,
                                       batch_sampler=test_batch_sampler,
                                       num_workers=args.workers,
                                       pin_memory=True)
    # create network
    BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
    self.model = get_segmentation_model(args.model, dataset=args.dataset,
                                        args=self.args,
                                        norm_layer=BatchNorm2d).to(self.device)
    self.model = load_model(args.resume, self.model)
def __init__(self, args):
    self.args = args
    self.device = torch.device(args.device)
    # image transform
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    # dataset and dataloader
    data_kwargs = {'transform': input_transform,
                   'base_size': args.base_size,
                   'crop_size': args.crop_size}
    trainset = get_segmentation_dataset(args.dataset, split='train', mode='train', **data_kwargs)
    args.iters_per_epoch = len(trainset) // (args.num_gpus * args.batch_size)
    args.max_iters = args.epochs * args.iters_per_epoch
    train_sampler = make_data_sampler(trainset, shuffle=True, distributed=args.distributed)
    train_batch_sampler = make_batch_data_sampler(train_sampler, args.batch_size, args.max_iters)
    self.train_loader = data.DataLoader(dataset=trainset,
                                        batch_sampler=train_batch_sampler,
                                        num_workers=args.workers,
                                        pin_memory=True)
    if not args.skip_val:
        valset = get_segmentation_dataset(args.dataset, split='val', mode='val', **data_kwargs)
        val_sampler = make_data_sampler(valset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler, args.batch_size)
        self.val_loader = data.DataLoader(dataset=valset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)
    # create network
    BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
    self.model = get_segmentation_model(args.model, dataset=args.dataset,
                                        aux=args.aux, norm_layer=BatchNorm2d)
    if args.distributed:
        self.model = nn.parallel.DistributedDataParallel(self.model,
                                                         device_ids=[args.local_rank],
                                                         output_device=args.local_rank)
    self.model = self.model.to(args.device)
    # resume checkpoint if needed
    if args.resume:
        if os.path.isfile(args.resume):
            name, ext = os.path.splitext(args.resume)
            assert ext in ('.pkl', '.pth'), 'Sorry only .pth and .pkl files supported.'
            print('Resuming training, loading {}...'.format(args.resume))
            self.model.load_state_dict(torch.load(args.resume,
                                                  map_location=lambda storage, loc: storage))
    # create criterion
    if args.ohem:
        min_kept = int(args.batch_size // args.num_gpus * args.crop_size ** 2 // 16)
        self.criterion = MixSoftmaxCrossEntropyOHEMLoss(args.aux, args.aux_weight,
                                                        min_kept=min_kept,
                                                        ignore_index=-1).to(self.device)
    else:
        self.criterion = MixSoftmaxCrossEntropyLoss(args.aux, args.aux_weight,
                                                    ignore_index=-1).to(self.device)
    # optimizer
    self.optimizer = torch.optim.SGD(self.model.parameters(),
                                     lr=args.lr,
                                     momentum=args.momentum,
                                     weight_decay=args.weight_decay)
    # lr scheduling
    self.lr_scheduler = WarmupPolyLR(self.optimizer,
                                     max_iters=args.max_iters,
                                     power=0.9,
                                     warmup_factor=args.warmup_factor,
                                     warmup_iters=args.warmup_iters,
                                     warmup_method=args.warmup_method)
    # evaluation metrics
    self.metric = SegmentationMetric(trainset.num_class)
    self.best_pred = 0.0
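A sketch of one way the objects built above could be driven in a training loop. The (images, targets, filename) unpacking and the loss interface (a dict or a plain tensor from the Mix* criteria) are assumptions about this repo, not confirmed behavior.

# NOTE: illustrative sketch, not part of the original file.
def train(self):
    self.model.train()
    for iteration, (images, targets, _) in enumerate(self.train_loader, 1):
        images = images.to(self.device)
        targets = targets.to(self.device)
        outputs = self.model(images)
        loss_out = self.criterion(outputs, targets)
        # the Mix* losses may return a dict of named losses or a single tensor
        loss = sum(loss_out.values()) if isinstance(loss_out, dict) else loss_out
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.lr_scheduler.step()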
def __init__(self, args):
    self.args = args
    self.device = torch.device(args.device)
    # image transform
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    # dataset and dataloader
    train_data_kwargs = {'transform': input_transform,
                         'base_size': args.base_size,
                         'crop_size': args.crop_size,
                         're_size': args.re_size}
    trainset = get_segmentation_dataset(args.dataset, args=args, split='train',
                                        mode='train_onlyrs', **train_data_kwargs)
    args.iters_per_epoch = len(trainset) // (args.num_gpus * args.batch_size)
    args.max_iters = args.epochs * args.iters_per_epoch
    train_sampler = make_data_sampler(trainset, shuffle=True, distributed=args.distributed)
    train_batch_sampler = make_batch_data_sampler(train_sampler, args.batch_size, args.max_iters)
    self.train_loader = data.DataLoader(dataset=trainset,
                                        batch_sampler=train_batch_sampler,
                                        num_workers=args.workers,
                                        pin_memory=True)
    val60_data_kwargs = {'transform': input_transform,
                         'base_size': args.base_size,
                         'crop_size': args.crop_size,
                         're_size': args.re_size}
    valset = get_segmentation_dataset(args.dataset, args=args, split='val',
                                      mode='val_onlyrs', **val60_data_kwargs)
    val_sampler = make_data_sampler(valset, True, args.distributed)
    val_batch_sampler = make_batch_data_sampler(val_sampler, args.batch_size)
    self.val60_loader = data.DataLoader(dataset=valset,
                                        batch_sampler=val_batch_sampler,
                                        num_workers=args.workers,
                                        pin_memory=True)
    # create network
    BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
    self.model = get_segmentation_model(args.model, dataset=args.dataset,
                                        args=self.args,
                                        norm_layer=BatchNorm2d).to(self.device)
    self.model = load_modules(args, self.model)
    self.model = fix_model(args, self.model)
    # optimizer
    self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()),
                                     lr=args.lr,
                                     momentum=args.momentum,
                                     weight_decay=args.weight_decay)
    # create criterion
    if args.ohem:
        min_kept = int(args.batch_size // args.num_gpus * args.crop_size ** 2 // 16)
        self.criterion = MixSoftmaxCrossEntropyOHEMLoss(args.aux, args.aux_weight,
                                                        min_kept=min_kept,
                                                        ignore_index=-1).to(self.device)
    else:
        self.criterion = MixSoftmaxCrossEntropyLoss(args.aux, args.aux_weight,
                                                    ignore_index=-1).to(self.device)
    # lr scheduling
    self.lr_scheduler = WarmupPolyLR(self.optimizer,
                                     max_iters=args.max_iters,
                                     power=0.9,
                                     warmup_factor=args.warmup_factor,
                                     warmup_iters=args.warmup_iters,
                                     warmup_method=args.warmup_method)
    if args.use_DataParallel:
        self.model = torch.nn.DataParallel(self.model,
                                           device_ids=range(torch.cuda.device_count()))
    elif args.distributed:
        self.model = nn.parallel.DistributedDataParallel(self.model,
                                                         device_ids=[args.local_rank],
                                                         output_device=args.local_rank,
                                                         find_unused_parameters=True)
    # evaluation metrics
    self.metric_120 = SegmentationMetric(trainset.num_class)
    self.metric_60 = SegmentationMetric(trainset.num_class)
    self.best_pred = 0.0
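The load_modules() and fix_model() helpers are repo-specific and their bodies are not shown here. The reason the optimizer above filters on p.requires_grad is that fix_model() presumably freezes part of the network; a hypothetical version of such a helper is sketched below (the fixed_modules attribute is invented purely for illustration).

# NOTE: hypothetical sketch of a freezing helper; not the repo's actual fix_model().
def fix_model_sketch(args, model):
    # freeze every parameter whose name matches a prefix in args.fixed_modules
    # (args.fixed_modules is an assumed attribute, used only for this example)
    for name, param in model.named_parameters():
        if any(name.startswith(prefix) for prefix in getattr(args, 'fixed_modules', [])):
            param.requires_grad = False
    return model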