def __init__(
    self,
    encoder_name="resnet34",
    encoder_depth=5,
    encoder_weights="imagenet",
    encoder_output_stride=16,
    decoder_channels=256,
    decoder_atrous_rates=(12, 24, 36),
    in_channels=3,
    classes=10,
    activation=None,
    upsampling=4,
    aux_params=None,
):
    super().__init__()
    self.model = smp.DeepLabV3Plus(
        encoder_name=encoder_name,
        encoder_depth=encoder_depth,
        encoder_weights=encoder_weights,
        encoder_output_stride=encoder_output_stride,
        decoder_channels=decoder_channels,
        decoder_atrous_rates=decoder_atrous_rates,
        in_channels=in_channels,
        classes=classes,
        activation=activation,
        upsampling=upsampling,
        aux_params=aux_params,
    )
def __init__(self, config):
    super(LesionModel, self).__init__()
    if config["A"] == "unet":
        self.model = smp.Unet(
            encoder_name=config["EN"],                  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
            encoder_weights=config["encoder_weights"],  # use `imagenet` pre-trained weights for encoder initialization
            in_channels=3,                              # model input channels (1 for gray-scale images, 3 for RGB, etc.)
            classes=config["OC"],                       # model output channels (number of classes in your dataset)
        )
    elif config["A"] == "DeepLabV3Plus":
        self.model = smp.DeepLabV3Plus(
            encoder_name=config["EN"],
            encoder_weights=config["encoder_weights"],
            in_channels=3,
            classes=config["OC"],
        )
    self.model_cfg = config
    self.loss_function = DiceLoss()
    self.save_hyperparameters()
    self.iou_function = IOU()
    self.checkpoint_path = ""
def createDeepLabv3Plus(outputchannels=1):
    """Creates the DeepLabV3+ model with the segmentation models library.

    Args:
        outputchannels (int, optional): Number of output channels.

    Returns:
        model: Returns the DeepLabv3+ model with the MobileNetV2 backbone.
    """
    # aux_params = dict(
    #     pooling='avg',            # one of 'avg', 'max'
    #     dropout=0.5,              # dropout ratio, default is None
    #     activation='sigmoid',     # activation function, default is None
    #     classes=1,                # define number of output labels
    # )
    model = smp.DeepLabV3Plus(
        encoder_name="mobilenet_v2",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
        in_channels=3,                # model input channels (1 for gray-scale images, 3 for RGB, etc.)
        classes=outputchannels,       # model output channels (number of classes in your dataset)
    )
    return model
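
# --- Usage sketch (added for illustration, not part of the original snippet):
# instantiate the factory above and run a dummy forward pass. Input size and
# batch shape are assumptions; output resolution matches the input.
import torch

model = createDeepLabv3Plus(outputchannels=1)
model.eval()
with torch.no_grad():
    dummy = torch.randn(1, 3, 256, 256)   # (batch, channels, height, width)
    mask_logits = model(dummy)            # -> shape (1, 1, 256, 256)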
def deeplab(encoder, encoder_weights):
    model = smp.DeepLabV3Plus(
        encoder_name=encoder,
        encoder_weights=encoder_weights,
        classes=1,
        activation='sigmoid',
    )
    return model
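
# --- Usage sketch (added for illustration, not part of the original snippet);
# the encoder name and weights below are assumed values supported by
# segmentation_models_pytorch.
model = deeplab(encoder='resnet34', encoder_weights='imagenet')
# Because activation='sigmoid', the network output is already in [0, 1], so a
# plain BCE/Dice loss (not the *WithLogits* variants) should be used downstream.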
def get_model_factory(self):
    """Get the model corresponding to self.architecture_name and self.encoder_name.
    Design pattern: Factory"""
    if self.architecture_name == "Unet":
        return smp.Unet(encoder_name=self.encoder_name,
                        in_channels=self.in_channels,
                        classes=self.num_classes)
    elif self.architecture_name == "DeepLabV3+":
        return smp.DeepLabV3Plus(encoder_name=self.encoder_name,
                                 in_channels=self.in_channels,
                                 classes=self.num_classes)
    elif self.architecture_name == "SwinUnet":
        return models.swin_unet_2d(
            input_size=self.input_size,
            filter_num_begin=self.filter_num_begin,
            n_labels=self.n_labels,
            depth=self.depth,
            stack_num_down=self.stack_num_down,
            stack_num_up=self.stack_num_up,
            patch_size=self.patch_size,
            num_heads=self.num_heads,
            window_size=self.window_size,
            num_mlp=self.num_mlp,
            output_activation=self.output_activation,
            shift_window=self.shift_window)
    else:
        raise Exception(
            "Architecture {0:s} has not been implemented. Please select among "
            "[Unet, DeepLabV3+, SwinUnet]".format(self.architecture_name))
def get_model():
    model = smp.DeepLabV3Plus(
        encoder_name="resnext101_32x8d",
        encoder_weights='imagenet',
        in_channels=3,
        classes=1,
    )
    return model
def build_model(configuration):
    model_list = ['UNet', 'LinkNet', 'PSPNet', 'FPN', 'PAN', 'Deeplab_v3', 'Deeplab_v3+']
    if configuration.Model.model_name.lower() == 'unet':
        return smp.Unet(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes,
            decoder_attention_type=None,
        )
    if configuration.Model.model_name.lower() == 'linknet':
        return smp.Linknet(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes,
        )
    if configuration.Model.model_name.lower() == 'pspnet':
        return smp.PSPNet(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes,
        )
    if configuration.Model.model_name.lower() == 'fpn':
        return smp.FPN(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes,
        )
    if configuration.Model.model_name.lower() == 'pan':
        return smp.PAN(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes,
        )
    if configuration.Model.model_name.lower() == 'deeplab_v3+':
        return smp.DeepLabV3Plus(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes,
        )
    if configuration.Model.model_name.lower() == 'deeplab_v3':
        return smp.DeepLabV3(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes,
        )
    raise KeyError(f'Model should be one of {model_list}')
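
# --- Illustration only (not from the original snippet): build_model() expects a
# nested config object exposing Model.* and DataSet.* attributes. The real project
# likely uses a YAML/config library; a minimal stand-in with the attribute names
# assumed above could look like this.
from types import SimpleNamespace

configuration = SimpleNamespace(
    Model=SimpleNamespace(model_name='Deeplab_v3+',
                          encoder='resnet34',
                          encoder_weights='imagenet'),
    DataSet=SimpleNamespace(number_of_classes=4),
)
model = build_model(configuration)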
def SkinLesionModel(model, pretrained=True):
    models_zoo = {
        'deeplabv3plus': smp.DeepLabV3Plus('resnet101', encoder_weights='imagenet', aux_params=None),
        'deeplabv3plus_resnext': smp.DeepLabV3Plus('resnext101_32x8d', encoder_weights='imagenet', aux_params=None),
        'pspnet': smp.PSPNet('resnet101', encoder_weights='imagenet', aux_params=None),
        'unetplusplus': smp.UnetPlusPlus('resnet101', encoder_weights='imagenet', aux_params=None),
    }
    net = models_zoo.get(model)
    if net is None:
        raise ValueError('Wrong Net Name!!')
    return net
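
# --- Note (added for illustration, not part of the original snippet): the dict
# above instantiates every model (and downloads every pretrained encoder) before
# a single one is selected. A lazy variant, assuming the same keys, defers
# construction until lookup:
def SkinLesionModelLazy(model):
    models_zoo = {
        'deeplabv3plus': lambda: smp.DeepLabV3Plus('resnet101', encoder_weights='imagenet'),
        'deeplabv3plus_resnext': lambda: smp.DeepLabV3Plus('resnext101_32x8d', encoder_weights='imagenet'),
        'pspnet': lambda: smp.PSPNet('resnet101', encoder_weights='imagenet'),
        'unetplusplus': lambda: smp.UnetPlusPlus('resnet101', encoder_weights='imagenet'),
    }
    if model not in models_zoo:
        raise ValueError('Wrong Net Name!!')
    return models_zoo[model]()  # build only the requested network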
def get_model_factory(self):
    """Get the model corresponding to self.architecture_name and self.encoder_name.
    Design pattern: Factory"""
    if self.architecture_name == "Unet":
        return smp.Unet(encoder_name=self.encoder_name,
                        in_channels=self.in_channels,
                        classes=self.num_classes)
    elif self.architecture_name == "DeepLabV3+":
        return smp.DeepLabV3Plus(encoder_name=self.encoder_name,
                                 in_channels=self.in_channels,
                                 classes=self.num_classes)
    else:
        raise Exception("Architecture {0:s} has not been implemented. "
                        "Please select among [Unet, DeepLabV3+]".format(self.architecture_name))
def __init__(self):
    super().__init__()
    # Auxiliary classification head configuration
    aux_params = dict(
        pooling='max',         # one of 'avg', 'max'
        dropout=0.5,           # dropout ratio, default is None
        activation='sigmoid',  # activation function, default is None
        classes=1,             # define number of output labels
    )
    self.model = smp.DeepLabV3Plus(encoder_name="efficientnet-b6",
                                   encoder_weights="imagenet",
                                   in_channels=3,
                                   classes=1,
                                   aux_params=aux_params)
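
# --- Usage sketch (added for illustration, not part of the original snippet):
# when aux_params is given, segmentation_models_pytorch attaches a classification
# head and forward() returns a (mask, label) tuple instead of a single tensor.
# Weights are set to None here just to avoid a download in this sketch.
import torch

m = smp.DeepLabV3Plus(encoder_name="efficientnet-b6", encoder_weights=None,
                      in_channels=3, classes=1,
                      aux_params=dict(pooling='max', dropout=0.5,
                                      activation='sigmoid', classes=1))
mask, label = m(torch.randn(1, 3, 224, 224))  # mask: (1, 1, 224, 224), label: (1, 1)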
def __init__(self):
    super().__init__()
    pretrained_model = smp.DeepLabV3Plus(encoder_name='resnet152',
                                         encoder_depth=5,
                                         encoder_weights='imagenet')
    self.model = pretrained_model
    self.sigmoid = nn.Sigmoid()
    self.iou_loss = IoUBCELoss()
    self.train_acc = pl.metrics.Accuracy()
    self.train_f1 = pl.metrics.F1()
    self.train_f2 = pl.metrics.FBeta(num_classes=2, beta=2)
    self.val_acc = pl.metrics.Accuracy()
    self.val_f1 = pl.metrics.F1()
    self.val_f2 = pl.metrics.FBeta(num_classes=2, beta=2)
def create_smp_model(arch, **kwargs):
    'Create segmentation_models_pytorch model'
    assert arch in ARCHITECTURES, f'Select one of {ARCHITECTURES}'
    if arch == "Unet":
        model = smp.Unet(**kwargs)
    elif arch == "UnetPlusPlus":
        model = smp.UnetPlusPlus(**kwargs)
    elif arch == "MAnet":
        model = smp.MAnet(**kwargs)
    elif arch == "FPN":
        model = smp.FPN(**kwargs)
    elif arch == "PAN":
        model = smp.PAN(**kwargs)
    elif arch == "PSPNet":
        model = smp.PSPNet(**kwargs)
    elif arch == "Linknet":
        model = smp.Linknet(**kwargs)
    elif arch == "DeepLabV3":
        model = smp.DeepLabV3(**kwargs)
    elif arch == "DeepLabV3Plus":
        model = smp.DeepLabV3Plus(**kwargs)
    else:
        raise NotImplementedError
    setattr(model, 'kwargs', kwargs)
    return model
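
# --- Usage sketch (added for illustration, not part of the original snippet);
# assumes ARCHITECTURES contains "DeepLabV3Plus". The kwargs are forwarded
# unchanged to smp and stored on the model so the configuration can be rebuilt.
model = create_smp_model("DeepLabV3Plus",
                         encoder_name="resnet34",
                         encoder_weights="imagenet",
                         in_channels=3,
                         classes=2)
print(model.kwargs)  # {'encoder_name': 'resnet34', 'encoder_weights': 'imagenet', ...}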
def __init__(self, num_classes=12, encoder="resnext101_32x8d",
             pretrain_weight="imagenet", decoder="DeepLabV3Plus"):
    super(smpModel, self).__init__()
    if decoder == "DeepLabV3Plus":
        self.backbone = smp.DeepLabV3Plus(encoder_name=encoder,
                                          encoder_weights=pretrain_weight,
                                          in_channels=3,
                                          classes=num_classes)
    elif decoder == "DeepLabV3":
        self.backbone = smp.DeepLabV3(encoder_name=encoder,
                                      encoder_weights=pretrain_weight,
                                      in_channels=3,
                                      classes=num_classes)
    elif decoder == "UnetPlusPlus":
        self.backbone = smp.UnetPlusPlus(encoder_name=encoder,
                                         encoder_weights=pretrain_weight,
                                         in_channels=3,
                                         classes=num_classes)
def get_model(num_classes, encoder='resnet50', encoder_weight='imagenet', activation='sigmoid'):
    '''Return the model and an optional preprocessing function.'''
    # create segmentation model with pretrained encoder
    model = smp.DeepLabV3Plus(
        encoder_name=encoder,
        encoder_weights=None,
        classes=num_classes,
        activation=activation,
    )
    if encoder_weight is None:
        preprocessing_fn = no_pretrain_precessing
    else:
        preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, encoder_weight)
    return model, preprocessing_fn
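
# --- Usage sketch (added for illustration, not part of the original snippet):
# the function returned by smp.encoders.get_preprocessing_fn() normalizes an
# HxWxC image with the statistics the chosen encoder was pretrained with.
# The dummy image below is an assumption.
import numpy as np
import torch

model, preprocessing_fn = get_model(num_classes=1, encoder='resnet50', encoder_weight='imagenet')
image = np.random.randint(0, 255, (256, 256, 3)).astype('float32')  # dummy HxWxC image
image = preprocessing_fn(image)                                      # normalize per encoder stats
x = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float()    # -> (1, 3, 256, 256)
with torch.no_grad():
    mask = model(x)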
def choose_network_architecture(self):
    # UNet
    if self.cfg.ARCH == 'UNet':
        print(f"===> Network Architecture: {self.cfg.ARCH}")
        # create segmentation model with pretrained encoder
        unet = smp.Unet(
            encoder_name=self.cfg.ENCODER,
            encoder_weights=self.cfg.ENCODER_WEIGHTS,
            classes=len(self.cfg.CLASSES),
            activation=self.cfg.ACTIVATION,
        )
        return unet

    # DeepLabV3+
    if self.cfg.ARCH == 'DeepLabV3+':
        print(f"===> Network Architecture: {self.cfg.ARCH}")
        # create segmentation model with pretrained encoder
        deeplab = smp.DeepLabV3Plus(
            encoder_name=self.cfg.ENCODER,
            encoder_weights=self.cfg.ENCODER_WEIGHTS,
            classes=len(self.cfg.CLASSES),
            activation=self.cfg.ACTIVATION,
            # in_channels=1,
        )
        return deeplab

    # FCN
    if self.cfg.ARCH == 'FCN':
        print(f"===> Network Architecture: {self.cfg.ARCH}")
        if self.cfg.ENCODER == 'resnet50':
            fcn = models.segmentation.fcn_resnet50(pretrained=True)
        if self.cfg.ENCODER == 'resnet101':
            fcn = models.segmentation.fcn_resnet101(pretrained=True)
        # Replace the FCN classifier output layer with an smp segmentation head.
        # FCNHead = models.segmentation.fcn.FCNHead
        UnetSegHead = smp.base.heads.SegmentationHead
        fcn.classifier[4] = UnetSegHead(512, 1, kernel_size=1, activation='sigmoid')  # FCNHead(2048, 1)
        return fcn
def __init__(self):
    super(model1, self).__init__()
    # model = smp.Unet()
    self.deeplab_model = smp.DeepLabV3Plus(encoder_name='resnet101')

    # encoder
    self.deeplab_encoder = self.deeplab_model.encoder

    # self-attention generation
    self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
    self.mlp1a = nn.Linear(2048, 4096)
    self.mlp2a = nn.Linear(4096, 2048)
    self.mlp1b = nn.Linear(256, 4096)
    self.mlp2b = nn.Linear(4096, 256)
    self.upsample = nn.Upsample(16)  # .cuda(0)

    # decoder
    self.deeplab_decoder = self.deeplab_model.decoder  # .cuda(1)
    self.deeplab_seghead = self.deeplab_model.segmentation_head
def test_1(model_path, output_dir, test_loader, addNDVI):
    in_channels = 4
    if addNDVI:
        in_channels += 1
    # model = smp.UnetPlusPlus(
    #     encoder_name="resnet101",
    #     encoder_weights="imagenet",
    #     in_channels=4,
    #     classes=10,
    # )
    model = smp.DeepLabV3Plus(
        encoder_name="resnet101",
        encoder_weights="imagenet",
        in_channels=4,
        classes=10,
    )
    # If the checkpoint comes from stochastic weight averaging (SWA), wrap the model
    if "swa" in model_path:
        model = AveragedModel(model)
    model.to(DEVICE)
    model.load_state_dict(torch.load(model_path))
    model.eval()
    for image, image_stretch, image_path, ndvi in test_loader:
        with torch.no_grad():
            image = image.cuda()
            image_stretch = image_stretch.cuda()
            output1 = model(image).cpu().data.numpy()
            output2 = model(image_stretch).cpu().data.numpy()
            output = (output1 + output2) / 2.0
            for i in range(output.shape[0]):
                pred = output[i]
                pred = np.argmax(pred, axis=0) + 1
                pred = np.uint8(pred)
                save_path = os.path.join(output_dir, image_path[i][-10:].replace('.tif', '.png'))
                print(save_path)
                cv2.imwrite(save_path, pred)
import torch
import numpy as np
import segmentation_models_pytorch as smp

device = torch.device("cuda")

model = smp.DeepLabV3Plus(
    encoder_name="efficientnet-b7",
    encoder_weights=None,
    encoder_depth=5,
    in_channels=3,
    classes=2,
).to(device)
model.eval()
print(model)

a = np.random.randn(1, 3, 256, 256).astype(np.float32)
a = torch.from_numpy(a).to(device)
b = model(a)
print(b.shape)
        torch.save({'model_dict': model.state_dict()}, checkpoint_save_path)
        print('model saved')
        m = 0
    else:
        m += 1
        if m >= max_patient:
            torch.save({'model_dict': model.state_dict()}, checkpoint_save_path)
            return
    return


# In[ ]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
lr = 0.001
model = smp.DeepLabV3Plus(encoder_name='resnet101', aux_params={'classes': num_class})
model = model.to(device)
opt = Adam(model.parameters(), lr=lr)
bce_loss = nn.BCEWithLogitsLoss()
ce_loss = nn.CrossEntropyLoss()
train(model, opt, bce_loss, ce_loss, device, a1_class_path, max_patient=5)
print('A1+class training completed.')
        data, target = sam[0].to(device=device, dtype=torch.float), sam[1].to(device=device, dtype=torch.float)
        out = model(data)
        loss = bce_loss(out, target) * 100.0
        test_loss.append(loss.item())
        mask.append(target.cpu().detach().numpy())
        pred.append(out.cpu().detach().numpy())
    avg_test_loss = sum(test_loss) / len(test_loss)
    print('Test loss = {:.{prec}f}'.format(avg_test_loss, prec=4))
    return mask, pred


# In[ ]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimal_model = smp.DeepLabV3Plus(encoder_name='resnet101').to(device)
checkpoint = torch.load(a1_path)
optimal_model.load_state_dict(checkpoint['model_dict'])
optimal_model.eval()
bce_loss = nn.BCEWithLogitsLoss()
mask, pred = test(optimal_model, test_loader)
gt_map = np.concatenate(mask, axis=0)
pred_map = np.concatenate(pred, axis=0)


# In[ ]:

pred_map_sig = nn.Sigmoid()(torch.tensor(pred_map))
threshold = 0.5
folder_path, file_name = os.path.split(save_path)
# Create the output folder if it does not exist
os.makedirs(folder_path, exist_ok=True)

# model = smp.FPN(
#     encoder_name=ENCODER,
#     encoder_weights=ENCODER_WEIGHTS,
#     classes=len(CLASSES),
#     activation=ACTIVATION,
# )
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
    in_channels=4,
)
# model = smp.Unet(
#     encoder_name=ENCODER,
#     encoder_weights=ENCODER_WEIGHTS,
#     classes=len(CLASSES),
#     activation=ACTIVATION,
#     in_channels=4,
# )
# model = torch.load('/media/xingshi2/data/matting/Unet3/vgg_current_model.pth')
if os.path.exists(init_pre_path):
    model1 = torch.load(init_pre_path)
                         batch_size=1, shuffle=False, num_workers=4)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

loss = sm.utils.losses.DiceLoss()
metrics = [
    sm.utils.metrics.Accuracy(threshold=0.5),
]

model = sm.DeepLabV3Plus('resnet34',
                         classes=1,
                         decoder_channels=512,
                         in_channels=3,
                         activation='sigmoid',
                         encoder_weights='imagenet')

optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=0.0001),
])

train_epoch = sm.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=device,
)

valid_epoch = sm.utils.train.ValidEpoch(
valid_dataset = CloudDataset(path=path,
                             df=train,
                             datatype='valid',
                             img_ids=valid_ids,
                             transforms=utils.get_validation_augmentation(),
                             preprocessing=utils.get_preprocessing(preprocessing_fn))

train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_dataset, batch_size=bs, shuffle=False, num_workers=num_workers)

loaders = {
    "train": train_loader,
    "valid": valid_loader
}

print("setting for training...")

ACTIVATION = None
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=4,
    activation=ACTIVATION,
)
wandb.watch(model)

num_epochs = 50
logdir = "./logs/segmentation"

# model, criterion, optimizer
optimizer = torch.optim.Adam([
    {'params': model.decoder.parameters(), 'lr': 1e-2},
    {'params': model.encoder.parameters(), 'lr': 1e-3},
])
scheduler = ReduceLROnPlateau(optimizer, factor=0.15, patience=2)
loss = BCEDiceLoss()  # or DiceLoss()
metrics = [
def __init__(self, num_classes=12):
    super(SMP_DeepLabV3Plus_efficientnet_b1, self).__init__()
    self.seg_model = smp.DeepLabV3Plus(encoder_name="efficientnet-b1",
                                       encoder_weights="imagenet",
                                       in_channels=3,
                                       classes=12)
def __init__(self, num_classes=12):
    super(SMP_DeepLabV3Plus_timm_resnest101e, self).__init__()
    self.seg_model = smp.DeepLabV3Plus(encoder_name="timm-resnest101e",
                                       encoder_weights="imagenet",
                                       in_channels=3,
                                       classes=12)
        print('model saved')
        m = 0
    else:
        m += 1
        if m >= max_patient:
            return
    return


# In[11]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = smp.DeepLabV3Plus(encoder_name='resnet101')


# In[ ]:

lr = 0.001
model = model.to(device)
opt = Adam(model.parameters(), lr=lr)
bce_loss = nn.BCEWithLogitsLoss()
train(model, opt, bce_loss, device, a1_path)
print('A1 training completed.')
def __init__(self, num_classes=12):
    super(SMP_DeepLabV3Plus_se_resnext101_32x4d, self).__init__()
    self.seg_model = smp.DeepLabV3Plus(encoder_name="se_resnext101_32x4d",
                                       encoder_weights="imagenet",
                                       in_channels=3,
                                       classes=12)
def test_calculate_metric(iter_nums):
    if args.net == 'unet':
        # timm-efficientnet performs slightly worse.
        if not args.vis_mode:
            backbone_type = re.sub("^eff", "efficientnet", args.backbone_type)
            net = smp.Unet(backbone_type, classes=args.num_classes, encoder_weights='imagenet')
        else:
            net = VanillaUNet(n_channels=3, num_classes=args.num_classes)
    elif args.net == 'unet-scratch':
        # net = UNet(num_classes=args.num_classes)
        net = VanillaUNet(n_channels=3, num_classes=args.num_classes,
                          use_polyformer=args.polyformer_mode,
                          num_modes=args.num_modes)
    elif args.net == 'nestedunet':
        net = NestedUNet(num_classes=args.num_classes)
    elif args.net == 'unet3plus':
        net = UNet_3Plus(num_classes=args.num_classes)
    elif args.net == 'pranet':
        net = PraNet(num_classes=args.num_classes - 1)
    elif args.net == 'attunet':
        net = AttU_Net(output_ch=args.num_classes)
    elif args.net == 'attr2unet':
        net = R2AttU_Net(output_ch=args.num_classes)
    elif args.net == 'dunet':
        net = DeformableUNet(n_channels=3, n_classes=args.num_classes)
    elif args.net == 'setr':
        # Install mmcv first:
        # pip install mmcv-full==1.2.2 -f https://download.openmmlab.com/mmcv/dist/cu{Your CUDA Version}/torch{Your Pytorch Version}/index.html
        # E.g.: pip install mmcv-full==1.2.2 -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.7.1/index.html
        from mmcv.utils import Config
        sys.path.append("networks/setr")
        from mmseg.models import build_segmentor
        task2config = {'refuge': 'SETR_PUP_288x288_10k_refuge_context_bs_4.py',
                       'polyp':  'SETR_PUP_320x320_10k_polyp_context_bs_4.py'}
        setr_cfg = Config.fromfile("networks/setr/configs/SETR/{}".format(task2config[args.task_name]))
        net = build_segmentor(setr_cfg.model, train_cfg=setr_cfg.train_cfg, test_cfg=setr_cfg.test_cfg)
        # By default, net() calls forward_train(), which receives extra parameters, and returns losses.
        # net.forward_dummy() receives/returns the traditional input/output.
        # Relevant file: mmseg/models/segmentors/encoder_decoder.py
        net.forward = net.forward_dummy
    elif args.net == 'transunet':
        transunet_config = TransUNet_CONFIGS[args.backbone_type]
        transunet_config.n_classes = args.num_classes
        if args.backbone_type.find('R50') != -1:
            # The "patch" in TransUNet means grid-like patches of the input image.
            # The "patch" in our code means the whole input image after cropping/resizing (part of the augmentation).
            transunet_config.patches.grid = (int(args.patch_size[0] / transunet_config.patches.size[0]),
                                             int(args.patch_size[1] / transunet_config.patches.size[1]))
        net = TransUNet(transunet_config, img_size=args.patch_size, num_classes=args.num_classes)
    elif args.net.startswith('deeplab'):
        use_smp_deeplab = args.net.endswith('smp')
        if use_smp_deeplab:
            backbone_type = re.sub("^eff", "efficientnet", args.backbone_type)
            net = smp.DeepLabV3Plus(backbone_type, classes=args.num_classes, encoder_weights='imagenet')
        else:
            model_name = args.net + "_" + args.backbone_type
            model_map = {
                'deeplabv3_resnet50': deeplab.deeplabv3_resnet50,
                'deeplabv3plus_resnet50': deeplab.deeplabv3plus_resnet50,
                'deeplabv3_resnet101': deeplab.deeplabv3_resnet101,
                'deeplabv3plus_resnet101': deeplab.deeplabv3plus_resnet101,
                'deeplabv3_mobilenet': deeplab.deeplabv3_mobilenet,
                'deeplabv3plus_mobilenet': deeplab.deeplabv3plus_mobilenet
            }
            net = model_map[model_name](num_classes=args.num_classes, output_stride=8)
    elif args.net == 'nnunet':
        from nnunet.network_architecture.initialization import InitWeights_He
        from nnunet.network_architecture.generic_UNet import Generic_UNet
        net = Generic_UNet(
            input_channels=3,
            base_num_features=32,
            num_classes=args.num_classes,
            num_pool=7,
            num_conv_per_stage=2,
            feat_map_mul_on_downscale=2,
            norm_op=nn.InstanceNorm2d,
            norm_op_kwargs={'eps': 1e-05, 'affine': True},
            dropout_op_kwargs={'p': 0, 'inplace': True},
            nonlin_kwargs={'negative_slope': 0.01, 'inplace': True},
            final_nonlin=(lambda x: x),
            weightInitializer=InitWeights_He(1e-2),
            pool_op_kernel_sizes=[[2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]],
            conv_kernel_sizes=([[3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]),
            upscale_logits=False,
            convolutional_pooling=True,
            convolutional_upsampling=True,
        )
        net.inference_apply_nonlin = (lambda x: F.softmax(x, 1))
    elif args.net == 'segtran':
        get_default(args, 'num_modes', default_settings, -1, [args.net, 'num_modes', args.in_fpn_layers])
        set_segtran2d_config(args)
        print(args)
        net = Segtran2d(config)
    else:
        breakpoint()

    net.cuda()
    net.eval()

    if args.robust_ref_cp_path:
        refnet = copy.deepcopy(net)
        print("Reference network created")
        load_model(refnet, args, args.robust_ref_cp_path)
    else:
        refnet = None

    # Currently colormap is used only for OCT task.
    colormap = get_seg_colormap(args.num_classes, return_torch=True).cuda()

    # prepred: pre-prediction. postpred: post-prediction.
    task2mask_prepred = {'refuge': refuge_map_mask,
                         'polyp':  polyp_map_mask,
                         'oct':    partial(index_to_onehot, num_classes=args.num_classes)}
    task2mask_postpred = {'refuge': refuge_inv_map_mask,
                          'polyp':  polyp_inv_map_mask,
                          'oct':    partial(onehot_inv_map, colormap=colormap)}
    mask_prepred_mapping_func = task2mask_prepred[args.task_name]
    mask_postpred_mapping_funcs = [task2mask_postpred[args.task_name]]
    if args.do_remove_frag:
        remove_frag = lambda segmap: remove_fragmentary_segs(segmap, 255)
        mask_postpred_mapping_funcs.append(remove_frag)

    if not args.checkpoint_dir:
        if args.vis_mode is not None:
            visualize_model(net, args.vis_mode, args.vis_layers, args.patch_size, db_test)
            return
        if args.eval_robustness:
            eval_robustness(args, net, refnet, testloader, mask_prepred_mapping_func)
            return

    allcls_avg_metric = None
    all_results = np.zeros((args.num_classes, len(iter_nums)))

    for iter_idx, iter_num in enumerate(iter_nums):
        if args.checkpoint_dir:
            checkpoint_path = os.path.join(args.checkpoint_dir, 'iter_' + str(iter_num) + '.pth')
            load_model(net, args, checkpoint_path)

        if args.vis_mode is not None:
            visualize_model(net, args.vis_mode, args.vis_layers, args.patch_size, db_test)
            continue

        if args.eval_robustness:
            eval_robustness(args, net, refnet, testloader, mask_prepred_mapping_func)
            continue

        save_results = args.save_results and (not args.test_interp)
        if save_results:
            test_save_paths = []
            test_save_dirs = []
            test_save_dir_tmpl = "%s-%s-%s-%d" % (args.net, args.job_name, timestamp, iter_num)
            for suffix in ("-soft", "-%.1f" % args.mask_thres):
                test_save_dir = test_save_dir_tmpl + suffix
                test_save_path = "../prediction/%s" % (test_save_dir)
                if not os.path.exists(test_save_path):
                    os.makedirs(test_save_path)
                test_save_dirs.append(test_save_dir)
                test_save_paths.append(test_save_path)
        else:
            test_save_paths = None
            test_save_dirs = None

        if args.save_features_img_count > 0:
            args.save_features_file_path = "%s-%s-feat-%s.pth" % (args.net, args.job_name, timestamp)
        else:
            args.save_features_file_path = None

        allcls_avg_metric, allcls_metric_count = \
            test_all_cases(net, testloader,
                           task_name=args.task_name,
                           num_classes=args.num_classes,
                           mask_thres=args.mask_thres,
                           model_type=args.net,
                           orig_input_size=args.orig_input_size,
                           patch_size=args.patch_size,
                           stride=(args.orig_input_size[0] // 2, args.orig_input_size[1] // 2),
                           test_save_paths=test_save_paths,
                           out_origsize=args.out_origsize,
                           mask_prepred_mapping_func=mask_prepred_mapping_func,
                           mask_postpred_mapping_funcs=mask_postpred_mapping_funcs,
                           reload_mask=args.reload_mask,
                           test_interp=args.test_interp,
                           save_features_img_count=args.save_features_img_count,
                           save_features_file_path=args.save_features_file_path,
                           verbose=args.verbose_output)

        print("Iter-%d scores on %d images:" % (iter_num, allcls_metric_count[0]))
        dice_sum = 0
        for cls in range(1, args.num_classes):
            dice = allcls_avg_metric[cls - 1]
            print('class %d: dice = %.3f' % (cls, dice))
            dice_sum += dice
            all_results[cls, iter_idx] = dice
        avg_dice = dice_sum / (args.num_classes - 1)
        print("Average dice: %.3f" % avg_dice)

        if args.net == 'segtran':
            max_attn, avg_attn, clamp_count, call_count = \
                [segtran_shared.__dict__[v] for v in ('max_attn', 'avg_attn', 'clamp_count', 'call_count')]
            print("max_attn={:.2f}, avg_attn={:.2f}, clamp_count={}, call_count={}".format(
                max_attn, avg_attn, clamp_count, call_count))

        if save_results:
            FNULL = open(os.devnull, 'w')
            for pred_type, test_save_dir, test_save_path in zip(('soft', 'hard'), test_save_dirs, test_save_paths):
                do_tar = subprocess.run(["tar", "cvf", "%s.tar" % test_save_dir, test_save_dir],
                                        cwd="../prediction", stdout=FNULL, stderr=subprocess.STDOUT)
                # print(do_tar)
                print("{} tarball:\n{}.tar".format(pred_type, os.path.abspath(test_save_path)))

    np.set_printoptions(precision=3, suppress=True)
    print(all_results[1:])
    return allcls_avg_metric
def __init__(self, num_classes=12):
    super(SMP_DeepLabV3Plus_xception, self).__init__()
    self.seg_model = smp.DeepLabV3Plus(encoder_name="xception",
                                       encoder_weights="imagenet",
                                       in_channels=3,
                                       classes=num_classes)
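
# --- Note (added for illustration, not part of the original snippets): only the
# __init__ methods of the SMP_DeepLabV3Plus_* wrappers are shown above. A minimal
# self-contained sketch, assuming the wrappers simply delegate to the wrapped smp
# model in forward():
import torch.nn as nn
import segmentation_models_pytorch as smp


class SMPDeepLabV3PlusWrapper(nn.Module):
    def __init__(self, encoder_name="xception", num_classes=12):
        super().__init__()
        self.seg_model = smp.DeepLabV3Plus(encoder_name=encoder_name,
                                           encoder_weights="imagenet",
                                           in_channels=3,
                                           classes=num_classes)

    def forward(self, x):
        return self.seg_model(x)  # (N, num_classes, H, W) logits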
def main(args, config):
    logger.info("-------------------- Hyperparameters --------------------")
    logger.info(f"Seed: {args.seed}")
    logger.info(f"Loss: {config.loss}")
    logger.info(f"Optimizer: {config.optimizer}")
    logger.info(f"Weight decay: {config.weight_decay}")
    logger.info(f"Epochs: {config.epochs}")
    logger.info(f"Batch size: {config.batch_size}")
    logger.info(f"Learning rate: {config.learning_rate}")
    logger.info("--------------------------------------------------\n")

    logger.info("--------------------------------------------------")
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logger.info(f'There are {torch.cuda.device_count()} GPU(s) available.')
        logger.info(f'We will use the GPU:{torch.cuda.get_device_name(0)}')
    else:
        device = torch.device("cpu")
        logger.info('No GPU available, using the CPU instead.')
    logger.info("--------------------------------------------------\n")

    """ Dataset """
    # Augmentation
    train_transform, val_transform, _ = get_transforms(config.aug)

    # Train dataset
    train_dataset = CustomDataLoader(data_dir=train_data_path, mode='train', transform=train_transform)
    # Validation dataset
    val_dataset = CustomDataLoader(data_dir=val_data_path, mode='val', transform=val_transform)

    # DataLoader
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=config.batch_size,
                                               shuffle=True,
                                               num_workers=1,
                                               collate_fn=collate_fn)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=8,
                                             shuffle=False,
                                             num_workers=1,
                                             collate_fn=collate_fn)

    """ Model """
    if args.api == 'smp':
        if config.model == 'DeepLabV3Plus':
            model = smp.DeepLabV3Plus(encoder_name=config.enc,
                                      encoder_weights=config.enc_weights,
                                      classes=12)
    else:
        if config.model == 'DeepLabV3EffiB7Timm':
            model = DeepLabV3EffiB7Timm(n_classes=12,
                                        n_blocks=[3, 4, 23, 3],
                                        atrous_rates=[6, 12, 18, 24])

    logger.info("-------------------- Model Test --------------------")
    x = torch.randn([2, 3, 512, 512])
    logger.info(f"Input shape : {x.shape}")
    out = model(x).to(device)
    logger.info(f"Output shape: {out.shape}")
    logger.info("--------------------------------------------------\n")
    model = model.to(device)

    # Loss function
    if config.loss == 'CE':
        criterion = nn.CrossEntropyLoss()
    elif config.loss == 'SoftCE':
        criterion = SoftCrossEntropyLoss(smooth_factor=config.smooth_factor)
    elif config.loss == 'Focal':
        criterion = FocalLoss('multiclass', gamma=config.focal_gamma)
    elif config.loss == 'DiceCE':
        pass
    elif config.loss == 'RMI':
        # criterion = RMILoss(num_classes=12, loss_weight_lambda=config.RMI_weight)
        criterion = [
            SoftCrossEntropyLoss(smooth_factor=config.smooth_factor),
            FocalLoss('multiclass', gamma=config.focal_gamma),
            # RMILoss(num_classes=12, rmi_only=True)
            RMILoss(num_classes=12)
        ]
    # TODO: apply split and getattr()
    elif config.loss == 'SoftCE+Focal+RMI':
        criterion = [
            SoftCrossEntropyLoss(smooth_factor=config.smooth_factor),
            FocalLoss('multiclass', gamma=config.focal_gamma),
            # RMILoss(num_classes=12, rmi_only=True)
            RMILoss(num_classes=12)
        ]
    else:
        raise Exception('[ERROR] Invalid loss')

    # Optimizer
    learning_rate = config.lr_min if config.lr_scheduler == 'SGDR' else config.learning_rate
    if config.optimizer == 'Adam':
        optimizer = torch.optim.Adam(params=model.parameters(),
                                     lr=learning_rate,
                                     weight_decay=config.weight_decay)
    elif config.optimizer == 'AdamP':
        optimizer = AdamP(params=model.parameters(),
                          lr=learning_rate,
                          weight_decay=config.weight_decay)
    else:
        raise Exception('[ERROR] Invalid optimizer')

    # Learning rate scheduler
    if config.lr_scheduler == 'no':
        scheduler = None
    elif config.lr_scheduler == 'SGDR':
        scheduler = CosineAnnealingWarmUpRestarts(optimizer,
                                                  T_0=config.T,
                                                  T_up=config.T_warmup,
                                                  T_mult=config.T_mult,
                                                  eta_max=config.lr_max,
                                                  gamma=config.lr_max_decay)
    else:
        raise Exception('[ERROR] Invalid learning rate scheduler')

    """ Train the model """
    train(epochs=config.epochs,
          model=model,
          data_loader=train_loader,
          val_loader=val_loader,
          criterion=criterion,
          optimizer=optimizer,
          scheduler=scheduler,
          val_every=1,
          device=device)