def build_model(configuration):
    """Construct a segmentation_models_pytorch model from a configuration object.

    Reads ``configuration.Model.model_name`` (matched case-insensitively),
    ``configuration.Model.encoder``, ``configuration.Model.encoder_weights`` and
    ``configuration.DataSet.number_of_classes``. Every model is built with
    ``activation=None``; Unet additionally gets ``decoder_attention_type=None``.

    Raises:
        KeyError: if the model name is not one of the supported architectures.
    """
    model_list = ['UNet', 'LinkNet', 'PSPNet', 'FPN', 'PAN', 'Deeplab_v3', 'Deeplab_v3+']
    # lower-cased configuration name -> smp class name (looked up lazily below,
    # so an unknown name fails before any smp attribute is touched).
    smp_class_names = {
        'unet': 'Unet',
        'linknet': 'Linknet',
        'pspnet': 'PSPNet',
        'fpn': 'FPN',
        'pan': 'PAN',
        'deeplab_v3+': 'DeepLabV3Plus',
        'deeplab_v3': 'DeepLabV3',
    }

    requested = configuration.Model.model_name.lower()
    if requested not in smp_class_names:
        raise KeyError(f'Model should be one of {model_list}')

    model_settings = configuration.Model
    unet_only = {'decoder_attention_type': None} if requested == 'unet' else {}
    model_cls = getattr(smp, smp_class_names[requested])
    return model_cls(
        encoder_name=model_settings.encoder,
        encoder_weights=model_settings.encoder_weights,
        activation=None,
        classes=configuration.DataSet.number_of_classes,
        **unet_only,
    )
def create_smp_model(arch, **kwargs):
    """Create a segmentation_models_pytorch model from its class name.

    Args:
        arch: smp class name; must appear in the module-level ``ARCHITECTURES``.
        **kwargs: forwarded unchanged to the smp constructor.

    Returns:
        The constructed model, with the constructor kwargs attached as
        ``model.kwargs``.

    Raises:
        AssertionError: if ``arch`` is not in ``ARCHITECTURES``.
        NotImplementedError: if ``arch`` passes that check but has no
            constructor mapping here.
    """
    assert arch in ARCHITECTURES, f'Select one of {ARCHITECTURES}'

    # Each supported name is exactly the smp class name, so it can be
    # resolved with getattr once it is known to be supported.
    supported = (
        "Unet", "UnetPlusPlus", "MAnet", "FPN", "PAN",
        "PSPNet", "Linknet", "DeepLabV3", "DeepLabV3Plus",
    )
    if arch not in supported:
        raise NotImplementedError

    model = getattr(smp, arch)(**kwargs)
    # Attach the constructor arguments to the model instance.
    model.kwargs = kwargs
    return model
def __init__(self):
    """Set fixed hyperparameters, log them, and build loaders, model, optimizer, loss."""
    # Fixed training hyperparameters.
    self.backbone = 'resnet34'
    self.lr = 1e-4
    self.batch = 64
    self.crop_size = 384

    # print settings
    log = logging.getLogger('train').info
    log('\nTRAINING SETTINGS')
    log('###########################')
    log(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    log(f'backbone:{self.backbone}, lr:{self.lr}, batch:{self.batch}, crop size:{self.crop_size}')
    log('###########################\n')

    # model define
    self.train_loader, self.test_loader = make_data_loader(
        batch=self.batch, crop_size=self.crop_size)

    # Two-class PAN with ImageNet-pretrained encoder, moved to GPU and
    # wrapped in DataParallel.
    network = smp.PAN(encoder_name=self.backbone, encoder_weights='imagenet',
                      in_channels=3, classes=2)
    self.model = nn.DataParallel(network.cuda())
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
    self.criterion = nn.CrossEntropyLoss()
    self.epoch = 0
def __init__(self, encoder, encoder_weights, classes, activation, learning_rate=1e-3, **kwargs):
    """Build the smp model selected by the ``architecture`` hyperparameter.

    Args:
        encoder: smp encoder name, passed as ``encoder_name``.
        encoder_weights: pretrained-weights identifier passed to smp.
        classes: sequence of class labels; the model gets ``len(classes)``
            output channels and the sequence itself is kept on ``self.classes``.
        activation: activation passed to the smp constructor.
        learning_rate: recorded by ``save_hyperparameters``; not used directly here.
        **kwargs: expected to contain ``architecture`` (one of ``'fpn'``,
            ``'pan'``, ``'pspnet'``), which is read back from ``self.hparams``.
            NOTE(review): relies on ``save_hyperparameters`` capturing ``**kwargs``.

    Raises:
        NameError: if ``architecture`` is not a supported name (the exception
            type is kept for compatibility with existing callers).
    """
    super().__init__()
    self.save_hyperparameters()
    self.classes = classes

    # Supported architecture key -> smp class name, resolved lazily with getattr.
    smp_class_names = {'fpn': 'FPN', 'pan': 'PAN', 'pspnet': 'PSPNet'}
    architecture = self.hparams.architecture
    if architecture not in smp_class_names:
        raise NameError(
            f"Unsupported architecture {architecture!r}; "
            f"expected one of {sorted(smp_class_names)}"
        )

    self.model = getattr(smp, smp_class_names[architecture])(
        encoder_name=encoder,
        encoder_weights=encoder_weights,
        classes=len(classes),
        activation=activation,
    )
    self.loss = smp.utils.losses.DiceLoss()
encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS, classes=len(CLASSES), activation=ACTIVATION, ) elif MODEL == 'deeplabv3plus': model = smp.DeepLabV3Plus( encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS, classes=len(CLASSES), activation=ACTIVATION, ) elif MODEL == 'pannet': model = smp.PAN( encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS, classes=len(CLASSES), activation=ACTIVATION, ) elif MODEL == 'fpn': model = smp.FPN( encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS, classes=len(CLASSES), activation=ACTIVATION, ) else: raise RuntimeError('Model name Error') preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS) # ------------------------------------------------------
def get_model(config):
    """Build an smp segmentation model from ``config``, optionally load weights, and place it.

    Reads ``config.MODEL.ARCHITECTURE`` (one of ``'unet'``, ``'fpn'``, ``'pan'``,
    ``'pspnet'``, ``'deeplabv3'``, ``'linknet'``), ``BACKBONE``,
    ``ENCODER_PRETRAINED_FROM``, ``IN_CHANNELS``, ``ACTIVATION``,
    ``UNET_ENABLE_DECODER_SCSE`` and ``config.INPUT.CLASSES``. Architecture-specific
    settings (``UNET_DECODER_CHANNELS``, ``FPN_DECODER_DROPOUT``,
    ``PSPNET_DROPOUT``) are read only for the matching architecture.

    Returns:
        The model wrapped in ``torch.nn.DataParallel`` and moved to
        ``config.MODEL.DEVICE``. If ``config.MODEL.WEIGHT`` is set and not
        ``'none'``, that file is loaded (on CPU) into the DataParallel wrapper.

    Raises:
        ValueError: if ``config.MODEL.ARCHITECTURE`` is not supported.
    """
    arch = config.MODEL.ARCHITECTURE
    # Constructor arguments shared by every architecture.
    common = dict(
        encoder_name=config.MODEL.BACKBONE,
        encoder_weights=config.MODEL.ENCODER_PRETRAINED_FROM,
        in_channels=config.MODEL.IN_CHANNELS,
        classes=len(config.INPUT.CLASSES),
        activation=config.MODEL.ACTIVATION,
    )
    # unet specific (read for every architecture, as before)
    decoder_attention_type = 'scse' if config.MODEL.UNET_ENABLE_DECODER_SCSE else None

    if arch == 'unet':
        model = smp.Unet(
            decoder_channels=config.MODEL.UNET_DECODER_CHANNELS,
            decoder_attention_type=decoder_attention_type,
            **common,
        )
    elif arch == 'fpn':
        model = smp.FPN(decoder_dropout=config.MODEL.FPN_DECODER_DROPOUT, **common)
    elif arch == 'pan':
        model = smp.PAN(**common)
    elif arch == 'pspnet':
        model = smp.PSPNet(psp_dropout=config.MODEL.PSPNET_DROPOUT, **common)
    elif arch == 'deeplabv3':
        model = smp.DeepLabV3(**common)
    elif arch == 'linknet':
        model = smp.Linknet(**common)
    else:
        raise ValueError(
            f"Unsupported MODEL.ARCHITECTURE {arch!r}; expected one of "
            "'unet', 'fpn', 'pan', 'pspnet', 'deeplabv3', 'linknet'"
        )

    # Wrapped before loading, so the weight file is presumably a state dict
    # saved from a DataParallel model ('module.'-prefixed keys) -- confirm.
    model = torch.nn.DataParallel(model)

    if config.MODEL.WEIGHT and config.MODEL.WEIGHT != 'none':
        # load weight from file
        model.load_state_dict(
            torch.load(config.MODEL.WEIGHT, map_location=torch.device('cpu'))
        )

    model = model.to(config.MODEL.DEVICE)
    return model
def get_segmentation_model(
    arch: str,
    encoder_name: str,
    encoder_weights: Optional[str] = "imagenet",
    pretrained_checkpoint_path: Optional[str] = None,
    checkpoint_path: Optional[Union[str, List[str]]] = None,
    convert_bn: Optional[str] = None,
    convert_bottleneck: Tuple[int, int, int] = (0, 0, 0),
    **kwargs: Any,
) -> nn.Module:
    """
    Fetch segmentation model by its name.

    :param arch: architecture name, case-insensitive. One of ``unet``,
        ``unetplusplus`` / ``unet++``, ``linknet``, ``pspnet``, ``pan``,
        ``deeplabv3``, ``deeplabv3plus`` / ``deeplabv3+``, ``manet``.
    :param encoder_name: smp encoder name.
    :param encoder_weights: pretrained encoder weights identifier. Forced to
        ``None`` when ``encoder_name == "en_resnet34"`` or when either
        checkpoint path is given.
    :param pretrained_checkpoint_path: path to a state dict loaded (on CPU)
        into ``model.encoder``.
    :param checkpoint_path: path or list of paths to full-model checkpoints,
        each holding a ``"model_state_dict"`` entry; the states are combined
        with ``average_weights`` and loaded into the model.
    :param convert_bn: BatchNorm2d replacement: ``"instance"``, ``"group"``,
        ``"bnet"``, ``"gnet"``, or a falsy value to keep BatchNorm2d.
    :param convert_bottleneck: ``replacement`` argument forwarded to
        ``botnet.convert_resnet`` for ``model.encoder``.
    :param kwargs: forwarded to the smp architecture constructor.
    :return: the constructed, optionally converted and loaded, model.
    :raises ValueError: if ``arch`` or ``convert_bn`` is not recognised.
    """
    arch = arch.lower()
    if (encoder_name == "en_resnet34" or checkpoint_path is not None
            or pretrained_checkpoint_path is not None):
        encoder_weights = None

    # Lower-case name (including aliases) -> smp class name.
    arch_to_class = {
        "unet": "Unet",
        "unetplusplus": "UnetPlusPlus",
        "unet++": "UnetPlusPlus",
        "linknet": "Linknet",
        "pspnet": "PSPNet",
        "pan": "PAN",
        "deeplabv3": "DeepLabV3",
        "deeplabv3plus": "DeepLabV3Plus",
        "deeplabv3+": "DeepLabV3Plus",
        "manet": "MAnet",
    }
    if arch not in arch_to_class:
        raise ValueError(
            f"Unknown architecture {arch!r}; expected one of {sorted(arch_to_class)}"
        )
    model = getattr(smp, arch_to_class[arch])(
        encoder_name=encoder_name, encoder_weights=encoder_weights, **kwargs)

    if pretrained_checkpoint_path is not None:
        print(f"Loading pretrained checkpoint {pretrained_checkpoint_path}")
        state_dict = torch.load(pretrained_checkpoint_path, map_location=torch.device("cpu"))
        model.encoder.load_state_dict(state_dict)
        del state_dict

    # TODO fmap_size=16 hardcoded for input 256 (matters for positional encoding)
    botnet.convert_resnet(
        model.encoder,
        replacement=convert_bottleneck,
        fmap_size=16,
        position_encoding=None,
    )

    # TODO parametrize conversion
    print(f"Convert BN to {convert_bn}")
    if convert_bn == "instance":
        print("Converting BatchNorm2d to InstanceNorm2d")
        model = batch_norm2instance(model)
    elif convert_bn == "group":
        print("Converting BatchNorm2d to GroupNorm")
        model = batch_norm2group(model, channels_per_group=1)
    elif convert_bn == "bnet":
        print("Converting BatchNorm2d to BNet2d")
        model = batch_norm2bnet(model)
    elif convert_bn == "gnet":
        print("Converting BatchNorm2d to GNet2d")
        model = batch_norm2gnet(model, channels_per_group=1)
    elif not convert_bn:
        print("Do not convert BatchNorm2d")
    else:
        raise ValueError(
            f"Unknown convert_bn {convert_bn!r}; expected 'instance', 'group', "
            "'bnet', 'gnet' or a falsy value"
        )

    if checkpoint_path is not None:
        if not isinstance(checkpoint_path, list):
            checkpoint_path = [checkpoint_path]
        states = []
        for cp in checkpoint_path:
            # Load checkpoint
            print(f"\nLoading checkpoint {str(cp)}")
            states.append(
                torch.load(cp, map_location=torch.device("cpu"))["model_state_dict"])
        state_dict = average_weights(states)
        model.load_state_dict(state_dict)
        del state_dict

    return model
def __init__(self, architecture="Unet", encoder="resnet34", depth=5, in_channels=3, classes=2, activation="softmax"):
    """Wrap an smp architecture built without pretrained encoder weights.

    Args:
        architecture: smp class name; one of ``_ARCHITECTURES`` below.
        encoder: smp encoder name.
        depth: passed as ``encoder_depth``; also sets ``self.pad_unit = 2 ** depth``.
        in_channels: number of input channels.
        classes: number of output channels.
        activation: activation passed to the smp constructor.

    Raises:
        AssertionError: if ``architecture`` is not supported.
    """
    super(SegmentationModels, self).__init__()
    self.architecture = architecture
    self.encoder = encoder
    self.depth = depth
    self.in_channels = in_channels
    self.classes = classes
    self.activation = activation

    # define model
    _ARCHITECTURES = [
        "Unet", "UnetPlusPlus", "Linknet", "MAnet", "FPN", "PSPNet", "PAN", "DeepLabV3", "DeepLabV3Plus"
    ]
    # Explicit raise instead of ``assert`` so the check also runs under
    # ``python -O``; the AssertionError type and message are unchanged.
    if self.architecture not in _ARCHITECTURES:
        raise AssertionError("architecture=={0}, actual '{1}'".format(
            _ARCHITECTURES, self.architecture))

    # Every supported name is also the smp class name.
    self.model = getattr(smp, self.architecture)(
        encoder_name=self.encoder,
        encoder_weights=None,
        encoder_depth=self.depth,
        in_channels=self.in_channels,
        classes=self.classes,
        activation=self.activation,
    )
    # Same value for every architecture. NOTE(review): presumably the spatial
    # multiple inputs are padded to before the forward pass -- confirm.
    self.pad_unit = 2**self.depth
classes=8, activation=args.activation, ), "deeplabv3plus": smp.DeepLabV3Plus( encoder_name=args.encoder, encoder_weights=args.weight, classes=8, activation=args.activation, aux_params=aux_params_dict, ), "pan": smp.PAN( encoder_name=args.encoder, encoder_weights=args.weight, classes=8, activation=args.activation, aux_params=aux_params_dict, ), } def save_checkpoint(state, filename): torch.save(state, filename) def main(): data_dir = Path(path_dict["train_data"]) # images_dir = data_dir.joinpath("image") images_dir = Path(config["train"]["image_dir"]) logger.info(f"Loading images from {images_dir}")
def main():
    """Parse CLI arguments, build an smp model, and run the train/validate loop.

    Side effects visible below: sets the module-level ``args`` global, may
    initialise an NCCL process group, creates ``<args.work_dir>/<run name>`` on
    local rank 0, and on local rank 0 writes ``last.pth`` / ``best.pth`` there
    and calls ``plot_hist`` after every epoch.

    Raises:
        ValueError: if ``args.smodel`` or ``args.opt`` is not a supported choice.
    """
    global args
    args = parse_args()
    print(args)

    torch.backends.cudnn.benchmark = True
    if args.deterministic:
        set_seed(args.seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.set_printoptions(precision=10)

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, 'Amp requires cudnn backend to be enabled.'

    # create model
    if args.cls:
        print('With classification')
    else:
        print('Without classification')
    # NOTE(review): with ``args.cls`` the losses below unpack the model output as
    # ``(segmentation, classification)``, but no ``aux_params`` is passed to the
    # smp constructors here -- confirm the model really returns a tuple.

    # dpn92 encoders use the 'imagenet+5k' weight set; all others use 'imagenet'.
    encoder_weights = 'imagenet' if 'dpn92' not in args.encoder else 'imagenet+5k'
    if args.smodel == 'Unet':
        model = smp.Unet(encoder_name=args.encoder,
                         encoder_weights=encoder_weights,
                         classes=args.n_classes,
                         in_channels=args.in_channels,
                         decoder_attention_type=args.attention_type,
                         activation=None)
    elif args.smodel == 'FPN':
        model = smp.FPN(encoder_name=args.encoder,
                        encoder_weights=encoder_weights,
                        classes=args.n_classes,
                        in_channels=args.in_channels,
                        activation=None)
    elif args.smodel == 'PAN':
        model = smp.PAN(encoder_name=args.encoder,
                        encoder_weights=encoder_weights,
                        classes=args.n_classes,
                        in_channels=args.in_channels,
                        activation=None)
    elif args.smodel == 'PSPNet':
        model = smp.PSPNet(encoder_name=args.encoder,
                           encoder_weights=encoder_weights,
                           classes=args.n_classes,
                           in_channels=args.in_channels,
                           encoder_depth=3,
                           activation=None)
    else:
        # A bare ``raise`` outside an ``except`` only yields "No active
        # exception to reraise"; report the actual problem instead.
        raise ValueError(
            f"Unsupported args.smodel {args.smodel!r}; expected one of "
            "'Unet', 'FPN', 'PAN', 'PSPNet'")

    if args.sync_bn:
        print('using apex synced BN')
        model = apex.parallel.convert_syncbn_model(model)

    model.cuda()

    print(f'lr={args.lr}, opt={args.opt}')
    if args.opt == 'adam':
        opt = apex.optimizers.FusedAdam(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
    elif args.opt == 'sgd':
        # NOTE(review): ``add_weight_decay`` presumably builds parameter groups
        # that exempt names matching ('bn', ) from weight decay -- confirm.
        opt = torch.optim.SGD(
            add_weight_decay(model, args.weight_decay, ('bn', )),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay)
    else:
        raise ValueError(
            f"Unsupported args.opt {args.opt!r}; expected 'adam' or 'sgd'")

    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    if args.fp16:
        model, opt = apex.amp.initialize(
            model,
            opt,
            opt_level=args.opt_level,
            keep_batchnorm_fp32=args.keep_batchnorm_fp32,
            loss_scale=args.loss_scale)

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # delay_allreduce delays all communication to the end of the backward pass.
        model = apex.parallel.DistributedDataParallel(model, delay_allreduce=True)

    dice_loss = smp.utils.losses.DiceLoss(activation='sigmoid')
    bce_loss = nn.BCEWithLogitsLoss()

    if args.cls:
        # Model output is unpacked as (segmentation logits, classification logits);
        # the per-image class target is "mask has any positive pixel".
        def BCEBCE(logits, target):
            prediction_seg, prediction_cls = logits
            y_cls = (target.sum([2, 3]) > 0).float()
            return bce_loss(prediction_seg, target) + bce_loss(prediction_cls, y_cls)

        def symmetric_lovasz_fn(logits, target):
            prediction_seg, prediction_cls = logits
            y_cls = (target.sum([2, 3]) > 0).float()
            return symmetric_lovasz(prediction_seg, target) + bce_loss(prediction_cls, y_cls)
    elif not args.use_softmax:
        def BCEBCE(logits, target):
            if args.loss == 'bce':
                return bce_loss(logits, target)
            if args.loss == 'dice':
                return dice_loss(logits, target)
            # 'bce+dice' and any unrecognised value fall back to the sum.
            return bce_loss(logits, target) + dice_loss(logits, target)

        symmetric_lovasz_fn = symmetric_lovasz
    else:
        # Softmax mode: channel 0 of the target holds integer class labels.
        def BCEBCE(logits, target):
            return nn.CrossEntropyLoss()(logits, target[:, 0].long())

        def symmetric_lovasz_fn(logits, target):
            return lovasz_softmax(torch.softmax(logits, dim=1), target[:, 0].long())

    criterion = BCEBCE

    history = {k: {k_: [] for k_ in ['train', 'dev']} for k in ['loss']}
    best_score = 0

    if not args.use_softmax:
        metrics = {
            'j55': JaccardMicro(n_classes=args.n_classes, thresh=0.55, w3m=args.w3m),
            'score': JaccardMicro(n_classes=args.n_classes, thresh=0.5, w3m=args.w3m),
            'j45': JaccardMicro(n_classes=args.n_classes, thresh=0.45, w3m=args.w3m),
            'd5': Dice(n_classes=args.n_classes, thresh=0.5, w3m=args.w3m),
        }
    else:
        # NOTE(review): the 'jaccard' entry is a Dice metric; the key is kept
        # because it is stored in ``history`` and in saved checkpoints.
        metrics = {
            'score': JaccardMicro(n_classes=args.n_classes, thresh=None, w3m=args.w3m),
            'jaccard': Dice(n_classes=args.n_classes, thresh=None, w3m=args.w3m),
        }

    history.update({k: {v: [] for v in ['train', 'dev']} for k in metrics})

    base_name = f'{args.encoder}_b{args.batch_size}_{args.opt}_lr{args.lr}_w3m{int(args.w3m)}_f{args.fold}'
    work_dir = Path(args.work_dir) / base_name
    if args.local_rank == 0 and not work_dir.exists():
        work_dir.mkdir(parents=True)

    # Optionally load model from a checkpoint
    if args.load:
        def _load():
            path_to_load = Path(args.load)
            if path_to_load.is_file():
                print(f"=> loading model '{path_to_load}'")
                checkpoint = torch.load(
                    path_to_load,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                model.load_state_dict(checkpoint['state_dict'])
                if args.fp16 and checkpoint['amp'] is not None:
                    apex.amp.load_state_dict(checkpoint['amp'])
                print(f"=> loaded model '{path_to_load}'")
            else:
                print(f"=> no model found at '{path_to_load}'")

        _load()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def _resume():
            nonlocal history, best_score
            path_to_resume = Path(args.resume)
            if path_to_resume.is_file():
                print(f"=> loading resume checkpoint '{path_to_resume}'")
                checkpoint = torch.load(
                    path_to_resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch'] + 1
                history = checkpoint['history']
                best_score = checkpoint['best_score']
                model.load_state_dict(checkpoint['state_dict'])
                opt.load_state_dict(checkpoint['opt_state_dict'])
                if args.fp16 and checkpoint['amp'] is not None:
                    apex.amp.load_state_dict(checkpoint['amp'])
                print(
                    f"=> resume from checkpoint '{path_to_resume}' (epoch {checkpoint['epoch']})"
                )
            else:
                print(f"=> no checkpoint found at '{args.resume}'")

        _resume()
    # A resumed history may predate some metrics; add empty series for them.
    history.update({
        k: {v: [] for v in ['train', 'dev']}
        for k in metrics if k not in history
    })

    scheduler = None
    if args.scheduler == 'cos':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt,
            T_max=args.T_max,
            eta_min=max(args.lr * 1e-2, 1e-6),
            last_epoch=args.start_epoch if args.resume else -1)

    path_to_data = Path(args.data)
    train_gps, dev_gps = get_data_groups(path_to_data / args.csv, args)

    train_ds = CloudsDS(train_gps,
                        root=path_to_data,
                        transform=train_transform,
                        w3m=args.w3m,
                        use_softmax=args.use_softmax)
    dev_ds = CloudsDS(dev_gps,
                      root=path_to_data,
                      transform=dev_transform,
                      w3m=args.w3m,
                      use_softmax=args.use_softmax)

    train_sampler = None
    dev_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds)
        dev_sampler = torch.utils.data.distributed.DistributedSampler(dev_ds)

    batch_size = args.batch_size
    num_workers = args.workers
    train_loader = torch.utils.data.DataLoader(train_ds,
                                               batch_size=batch_size,
                                               shuffle=train_sampler is None,
                                               sampler=train_sampler,
                                               num_workers=num_workers,
                                               collate_fn=collate_fn,
                                               pin_memory=True)
    dev_loader = torch.utils.data.DataLoader(dev_ds,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             sampler=dev_sampler,
                                             num_workers=num_workers,
                                             collate_fn=collate_fn,
                                             pin_memory=True)

    def saver(path):
        # Reads ``epoch``, ``best_score`` and ``history`` at call time, i.e.
        # the values current inside the training loop below.
        torch.save(
            {
                'epoch': epoch,
                'best_score': best_score,
                'history': history,
                'state_dict': model.state_dict(),
                'opt_state_dict': opt.state_dict(),
                'amp': apex.amp.state_dict() if args.fp16 else None,
                'args': args,
            }, path)

    teachers = None
    if args.teachers is not None:
        teachers = [
            torch.jit.load(str(p)).cuda().eval()
            for p in Path(args.teachers).rglob('*.pt')
        ]
        if args.distributed:
            for i in range(len(teachers)):
                teachers[i] = apex.parallel.DistributedDataParallel(
                    teachers[i], delay_allreduce=True)
        print(f'#teachers: {len(teachers)}')

    for epoch in range(args.start_epoch, args.epochs + 1):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # From epoch ``args.lovasz`` onwards switch to the Lovasz-based loss.
        if epoch >= args.lovasz:
            criterion = symmetric_lovasz_fn

        for metric in metrics.values():
            metric.clean()
        loss = epoch_step(train_loader,
                          f'[ Training {epoch}/{args.epochs}.. ]',
                          model=model,
                          criterion=criterion,
                          metrics=metrics,
                          opt=opt,
                          batch_accum=args.batch_accum,
                          teachers=teachers)
        history['loss']['train'].append(loss)
        for k, metric in metrics.items():
            history[k]['train'].append(metric.evaluate())

        if not args.ft:
            with torch.no_grad():
                for metric in metrics.values():
                    metric.clean()
                loss = epoch_step(dev_loader,
                                  f'[ Validating {epoch}/{args.epochs}.. ]',
                                  model=model,
                                  criterion=criterion,
                                  metrics=metrics,
                                  opt=None)
                history['loss']['dev'].append(loss)
                for k, metric in metrics.items():
                    history[k]['dev'].append(metric.evaluate())
        else:
            # Fine-tuning mode: no validation pass; the training loss and
            # metrics are recorded in the 'dev' series as well.
            history['loss']['dev'].append(loss)
            for k, metric in metrics.items():
                history[k]['dev'].append(metric.evaluate())

        if scheduler is not None:
            scheduler.step()

        if args.local_rank == 0:
            saver(work_dir / 'last.pth')
            if history['score']['dev'][-1] > best_score:
                best_score = history['score']['dev'][-1]
                saver(work_dir / 'best.pth')

            plot_hist(history, work_dir)
encoder_name=checkpoint["encoder"], encoder_weights=checkpoint["encoder_weight"], classes=8, activation=checkpoint["activation"], ), "deeplabv3plus": smp.DeepLabV3Plus( encoder_name=checkpoint["encoder"], encoder_weights=checkpoint["encoder_weight"], classes=8, activation=checkpoint["activation"], ), "pan": smp.PAN( encoder_name=checkpoint["encoder"], encoder_weights=checkpoint["encoder_weight"], classes=8, activation=checkpoint["activation"], ), } device = "cuda:0" if torch.cuda.is_available() else "cpu" model = arch_dict[checkpoint["arch"]] state_dict = OrderedDict() for key, value in checkpoint["state_dict"].items(): tmp = key[7:] state_dict[tmp] = value model.load_state_dict(state_dict) model.cuda() model.eval() preprocessing_fn = smp.encoders.get_preprocessing_fn( checkpoint["encoder"], checkpoint["encoder_weight"])
def __init__(self, architecture='Unet', encoder='resnet34', depth=5, in_channels=3, classes=2, activation='softmax'):
    """Wrap an smp architecture built without pretrained encoder weights.

    Args:
        architecture: smp class name; one of ``_ARCHITECTURES`` below.
        encoder: smp encoder name.
        depth: passed as ``encoder_depth``; also sets ``self.pad_unit = 2 ** depth``.
        in_channels: number of input channels.
        classes: number of output channels.
        activation: activation passed to the smp constructor.

    Raises:
        AssertionError: if ``architecture`` is not supported.
    """
    super(SegmentationModels, self).__init__()
    self.architecture = architecture
    self.encoder = encoder
    self.depth = depth
    self.in_channels = in_channels
    self.classes = classes
    self.activation = activation

    # define model
    _ARCHITECTURES = ['Unet', 'Linknet', 'FPN', 'PSPNet', 'PAN', 'DeepLabV3', 'DeepLabV3Plus']
    # Explicit raise instead of ``assert`` so the check also runs under
    # ``python -O``; the AssertionError type and message are unchanged.
    if self.architecture not in _ARCHITECTURES:
        raise AssertionError('architecture=={0}, actual \'{1}\''.format(
            _ARCHITECTURES, self.architecture))

    # Every supported name is also the smp class name.
    self.model = getattr(smp, self.architecture)(
        encoder_name=self.encoder,
        encoder_weights=None,
        encoder_depth=self.depth,
        in_channels=self.in_channels,
        classes=self.classes,
        activation=self.activation,
    )
    # Same value for every architecture. NOTE(review): presumably the spatial
    # multiple inputs are padded to before the forward pass -- confirm.
    self.pad_unit = 2 ** self.depth