def get_model(name, classification_head, model_weights_path=None):
    """Instantiate a segmentation network by its short name.

    Args:
        name: architecture key, e.g. 'unet34', 'fpn50', 'pspnet50', ...
        classification_head: for 'unet18' only — attach an auxiliary
            single-label classification head when truthy.
        model_weights_path: encoder checkpoint path, used by 'fpn50_satellite'.

    Returns:
        A freshly constructed torch.nn.Module.

    Raises:
        ValueError: if ``name`` is not a known architecture key.
    """
    # Lazy factories: only the requested model (and its imagenet weights)
    # is ever constructed.
    factories = {
        'unet34': lambda: smp.Unet('resnet34', encoder_weights='imagenet'),
        'unet50': lambda: smp.Unet('resnet50', encoder_weights='imagenet'),
        'unet101': lambda: smp.Unet('resnet101', encoder_weights='imagenet'),
        'linknet34': lambda: smp.Linknet('resnet34', encoder_weights='imagenet'),
        'linknet50': lambda: smp.Linknet('resnet50', encoder_weights='imagenet'),
        'fpn34': lambda: smp.FPN('resnet34', encoder_weights='imagenet'),
        'fpn50': lambda: smp.FPN('resnet50', encoder_weights='imagenet'),
        'fpn101': lambda: smp.FPN('resnet101', encoder_weights='imagenet'),
        'pspnet34': lambda: smp.PSPNet('resnet34', encoder_weights='imagenet', classes=1),
        'pspnet50': lambda: smp.PSPNet('resnet50', encoder_weights='imagenet', classes=1),
        'fpn50_multiclass': lambda: smp.FPN('resnet50', encoder_weights='imagenet', classes=3, activation='softmax'),
    }
    if name in factories:
        return factories[name]()
    if name == 'unet18':
        print('classification_head:', classification_head)
        if not classification_head:
            return smp.Unet('resnet18', encoder_weights='imagenet', encoder_depth=2, decoder_channels=(256, 128))
        aux_params = dict(
            pooling='max',  # one of 'avg', 'max'
            dropout=0.1,  # dropout ratio, default is None
            activation='sigmoid',  # activation function, default is None
            classes=1,  # define number of output labels
        )
        return smp.Unet('resnet18', aux_params=aux_params, encoder_weights=None,
                        encoder_depth=2, decoder_channels=(256, 128))
    if name == 'fpn50_season':
        from clearcut_research.pytorch import FPN_double_output
        return FPN_double_output('resnet50', encoder_weights='imagenet')
    if name == 'fpn50_satellite':
        # Plain FPN whose encoder is swapped for one pretrained on satellite imagery.
        fpn_resnet50 = smp.FPN('resnet50', encoder_weights=None)
        fpn_resnet50.encoder = get_satellite_pretrained_resnet(model_weights_path)
        return fpn_resnet50
    raise ValueError("Unknown network")
def get_basenet(basenet, backbone, encoder_weights, classes, decoder_channels, activation='sigmoid'):
    """Build a segmentation model of the requested family.

    'fpn', 'psp' and 'deeplabv3' map directly onto the corresponding smp
    class; any other value falls back to a U-Net whose encoder depth is
    derived from ``decoder_channels``.
    """
    families = {
        'fpn': smp.FPN,
        'psp': smp.PSPNet,
        'deeplabv3': smp.DeepLabV3,
    }
    arch_cls = families.get(basenet)
    if arch_cls is not None:
        return arch_cls(backbone, encoder_weights=encoder_weights,
                        classes=classes, activation=activation)
    # Default: U-Net with a depth matching the decoder channel spec.
    return smp.Unet(backbone, encoder_weights=encoder_weights,
                    encoder_depth=len(decoder_channels), classes=classes,
                    decoder_channels=decoder_channels, activation=activation)
def init(config):
    """Create the model (and, in train mode, the loss) described by ``config``.

    Returns:
        ``(model, loss)`` when ``config["mode"] == 'train'``, else ``model``.

    Raises:
        Exception: for an unrecognised model or loss name.
    """
    # ---- Model Initialization ----
    model_factories = {
        "UNet": lambda: smp.Unet(activation=None),  # UNet2D variants were tried previously
        "PSPNet": lambda: smp.PSPNet(activation=None),
        "FPN": lambda: smp.FPN(activation=None),
        "Linknet": lambda: smp.Linknet(activation=None),
    }
    model_factory = model_factories.get(config["model"])
    if model_factory is None:
        raise Exception('Incorrect model name!')
    model = model_factory()

    # ---- Loss Initialization (train mode only) ----
    if config["mode"] != 'train':
        return model
    loss_factories = {
        # lambda keeps the config["dice_weight"] lookup lazy, as before
        "DiceBCE": lambda: LossBinaryDice(dice_weight=config["dice_weight"]),
        "FocalTversky": FocalTverskyLoss,
        "Focal": FocalLoss,
        "Tversky": TverskyLoss,
    }
    loss_factory = loss_factories.get(config["loss"])
    if loss_factory is None:
        raise Exception('Incorrect loss name!')
    return model, loss_factory()
def __init__(self, debug=False):
    """Wrap an imagenet-pretrained resnet34 PSPNet with 4 output classes.

    Args:
        debug: when True, callers may enable extra diagnostics via self.debug.
    """
    super().__init__()
    self.debug = debug
    # Raw logits (activation=None); downstream code applies its own activation.
    self.PSPNet = smp.PSPNet(encoder_name='resnet34',
                             encoder_weights='imagenet',
                             classes=4,
                             activation=None)
def __init__(self, in_channels=24):
    """Two-branch model: a cloud PSPNet and a meteorology PSPNet encoder.

    Args:
        in_channels: number of input channels for the cloud branch's stem
            (default 24; the stock resnet34 stem is replaced to accept it).
    """
    super(Cloud2Cloud, self).__init__()
    self.cloudNet = smp.PSPNet(encoder_name='resnet34', classes=4, encoder_weights='imagenet')
    # Replace the 3-channel imagenet stem with an `in_channels`-input conv.
    # NOTE(review): this discards the pretrained conv1 weights — presumably intended.
    self.cloudNet.encoder.conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False)
    self.metNet = smp.PSPNet(encoder_name='resnet18', encoder_weights='imagenet')
    # Expose the sub-networks' pieces as attributes; the forward pass
    # (not visible here) presumably wires cloud/met encoders into cloud_decoder.
    self.cloud_encoder = self.cloudNet.encoder
    self.met_encoder = self.metNet.encoder
    self.cloud_decoder = self.cloudNet.decoder
def build_model(self):
    # Instantiate the network selected by self.model_type into self.unet,
    # wrap it in DataParallel when CUDA is available, and move it to device.
    print("Using model: {}".format(self.model_type))
    """Build generator and discriminator."""
    if self.model_type == 'U_Net':
        self.unet = U_Net(img_ch=3, output_ch=self.output_ch)
    elif self.model_type == 'R2U_Net':
        self.unet = R2U_Net(img_ch=3, output_ch=self.output_ch, t=self.t)
    elif self.model_type == 'AttU_Net':
        self.unet = AttU_Net(img_ch=3, output_ch=self.output_ch)
    elif self.model_type == 'R2AttU_Net':
        self.unet = R2AttU_Net(img_ch=3, output_ch=self.output_ch, t=self.t)
    elif self.model_type == 'unet_resnet34':
        # self.unet = Unet(backbone_name='resnet34', pretrained=True, classes=self.output_ch)
        self.unet = smp.Unet('resnet34', encoder_weights='imagenet', activation=None)
    elif self.model_type == 'unet_resnet50':
        self.unet = smp.Unet('resnet50', encoder_weights='imagenet', activation=None)
    elif self.model_type == 'unet_se_resnext50_32x4d':
        self.unet = smp.Unet('se_resnext50_32x4d', encoder_weights='imagenet', activation=None)
    elif self.model_type == 'unet_densenet121':
        self.unet = smp.Unet('densenet121', encoder_weights='imagenet', activation=None)
    elif self.model_type == 'unet_resnet34_t':
        self.unet = Unet_t('resnet34', encoder_weights='imagenet', activation=None, use_ConvTranspose2d=True)
    elif self.model_type == 'unet_resnet34_oct':
        self.unet = OctaveUnet('resnet34', encoder_weights='imagenet', activation=None)
    elif self.model_type == 'linknet':
        self.unet = LinkNet34(num_classes=self.output_ch)
    elif self.model_type == 'deeplabv3plus':
        self.unet = DeepLabV3Plus(model_backbone='res50_atrous', num_classes=self.output_ch)
    elif self.model_type == 'pspnet_resnet34':
        self.unet = smp.PSPNet('resnet34', encoder_weights='imagenet', classes=1, activation=None)
    # NOTE(review): no final else — an unknown model_type leaves self.unet
    # unset and the DataParallel wrap below raises AttributeError.
    if torch.cuda.is_available():
        self.unet = torch.nn.DataParallel(self.unet)
        # Losses are moved to GPU alongside the model.
        self.criterion = self.criterion.cuda()
        self.criterion_stage2 = self.criterion_stage2.cuda()
        self.criterion_stage3 = self.criterion_stage3.cuda()
    self.unet.to(self.device)
def __init__(self, model_name, in_channels=3, out_channels=1):
    """Build a resnet18-backed smp model with an auxiliary classification head.

    Args:
        model_name: 'unet' → smp.Unet, 'uplus' → smp.UnetPlusPlus
            (substring match); anything else → smp.PSPNet.
        in_channels: model input channels (1 for grayscale, 3 for RGB, ...).
        out_channels: number of output classes (segmentation and aux head).
    """
    super(SmpModel16, self).__init__()
    # Shared auxiliary classification head configuration.
    aux_params = dict(
        pooling='max',  # one of 'avg', 'max'
        dropout=0.5,  # dropout ratio, default is None
        activation='softmax',  # activation function, default is None
        classes=out_channels,  # define number of output labels
    )
    if 'unet' in model_name:
        self.model = smp.Unet(
            encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
            encoder_depth=3,
            encoder_weights="imagenet",  # use `imagenet` pretrained weights for encoder initialization
            decoder_channels=[128, 64, 32],  # [256, 128, 64, 32]
            in_channels=in_channels,  # model input channels (1 for grayscale images, 3 for RGB, etc.)
            classes=out_channels,  # model output channels (number of classes in your dataset)
            aux_params=aux_params,
        )
    elif 'uplus' in model_name:
        self.model = smp.UnetPlusPlus(
            encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
            encoder_depth=3,
            encoder_weights="imagenet",  # use `imagenet` pretrained weights for encoder initialization
            decoder_channels=[256, 128, 64],
            in_channels=in_channels,  # model input channels (1 for grayscale images, 3 for RGB, etc.)
            classes=out_channels,  # model output channels (number of classes in your dataset)
            aux_params=aux_params,
        )
    else:
        self.model = smp.PSPNet(
            encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
            # encoder_depth=4,
            encoder_weights="imagenet",  # use `imagenet` pretrained weights for encoder initialization
            # decoder_channels=[256, 128, 64, 32],
            in_channels=in_channels,  # model input channels (1 for grayscale images, 3 for RGB, etc.)
            classes=out_channels,  # model output channels (number of classes in your dataset)
            aux_params=aux_params,
        )
    # Extra downsizing stage applied to the segmentation output.
    self.down_layer = DownsizeBlock(out_channels, downsize_mode=2)
def prepare_model(backbone='se_resnet50', weight='None', checkpoint_path=None):
    '''Build a 1-input-channel PSPNet, optionally from a checkpoint.

    change channel 3 to channel 1
    original as the following
    # first layer
    model.encoder.layer0.conv1 = nn.Conv2d(3, 64,...)
    # last layer
    model.decoder.final_conv = nn.Conv2d(512, 3,...)

    Args:
        backbone: smp encoder name.
        weight: one of 'None' (random init), 'imagenet', or 'pretrained'
            (load a 1-channel checkpoint from ``checkpoint_path``).
        checkpoint_path: state-dict file, required when weight='pretrained'.

    Raises:
        ValueError: for an unrecognised ``weight`` option (previously this
        fell through and crashed with NameError on the undefined ``model``).
    '''
    if weight == 'None':
        model = smp.PSPNet(backbone, encoder_weights=None, classes=1, activation=None)
    elif weight == 'imagenet':
        model = smp.PSPNet(backbone, encoder_weights='imagenet', classes=1, activation=None)
    elif weight == 'pretrained':
        model = smp.PSPNet(backbone, encoder_weights=None, classes=1, activation=None)
        # Swap to a 1-channel stem *before* loading so the checkpoint
        # (trained with 1 input channel) matches the module shapes.
        model.encoder.layer0.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        model.to(torch.device("cuda:0"))
        state = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        model.load_state_dict(state["state_dict"])
    else:
        raise ValueError("unknown weight option: {!r}".format(weight))
    # NOTE(review): this replaces conv1 again for every branch — for
    # 'pretrained' it re-randomises the just-loaded first layer. Kept as-is
    # (may be intentional re-initialisation); confirm before changing.
    model.encoder.layer0.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    return model
def resnet50_PSPNet_noclassification(**kwargs):
    """Build a plain resnet50 PSPNet (segmentation only, no aux classifier).

    NOTE(review): ``in_channels``, ``classes`` and ``activation`` are free
    variables here — presumably module-level globals defined elsewhere in
    the file; confirm, otherwise this raises NameError at call time.
    """
    model = smp.PSPNet('resnet50', in_channels=in_channels, classes=classes, activation=activation, **kwargs)
    # Echo the effective configuration for logging purposes.
    print("Just segmentation Model args:")
    print("in_channels:%d,classes:%d,activation:%s" % (in_channels, classes, activation))
    print("kwargs", kwargs)
    return model
def PSP(self, img_ch, output_ch):
    """Return a PSPNet built from this object's encoder settings.

    Args:
        img_ch: number of input channels.
        output_ch: number of output classes (raw logits, no activation).
    """
    psp_config = dict(
        encoder_name=self.encoder,
        encoder_weights=self.en_weights,
        encoder_depth=3,
        psp_out_channels=512,
        psp_use_batchnorm=False,
        psp_dropout=0.2,
        in_channels=img_ch,
        classes=output_ch,
        activation=None,
        upsampling=8,  # upscale factor matching the depth-3 encoder stride
        aux_params=None,
    )
    return smp.PSPNet(**psp_config)
def resnet34_psp(num_classes):
    """Imagenet-pretrained resnet34 PSPNet emitting raw logits."""
    return smp.PSPNet(
        encoder_name='resnet34',
        encoder_weights='imagenet',
        classes=num_classes,
        activation=None,
    )
def build_model(configuration):
    """Construct the segmentation model named in ``configuration``.

    The encoder, pretrained weights and class count all come from the
    configuration object; activation is always None (raw logits).

    Raises:
        KeyError: if the configured model name is not supported.
    """
    model_list = ['UNet', 'LinkNet', 'PSPNet', 'FPN', 'PAN', 'Deeplab_v3', 'Deeplab_v3+']
    requested = configuration.Model.model_name.lower()
    # Keyword arguments shared by every architecture.
    common = dict(
        encoder_name=configuration.Model.encoder,
        encoder_weights=configuration.Model.encoder_weights,
        activation=None,
        classes=configuration.DataSet.number_of_classes,
    )
    if requested == 'unet':
        # U-Net additionally takes a decoder attention option (disabled here).
        return smp.Unet(decoder_attention_type=None, **common)
    architectures = {
        'linknet': smp.Linknet,
        'pspnet': smp.PSPNet,
        'fpn': smp.FPN,
        'pan': smp.PAN,
        'deeplab_v3+': smp.DeepLabV3Plus,
        'deeplab_v3': smp.DeepLabV3,
    }
    if requested in architectures:
        return architectures[requested](**common)
    raise KeyError(f'Model should be one of {model_list}')
def create_segmentation_models(encoder, arch, num_classes=4, encoder_weights=None, activation=None):
    '''Create a segmentation model from segmentation_models_pytorch or DeepLabV3+.

    segmentation_models_pytorch
    https://github.com/qubvel/segmentation_models.pytorch
    has following architectures:
    - Unet
    - Linknet
    - FPN
    - PSPNet
    encoders: A lot! see the above github page.

    Deeplabv3+
    https://github.com/jfzhang95/pytorch-deeplab-xception
    has for encoders:
    - resnet (resnet101)
    - mobilenet
    - xception
    - drn

    Raises:
        ValueError: unknown ``arch``, or deeplabv3plus path not configured.
    '''
    if arch == "Unet":
        return smp.Unet(encoder, encoder_weights=encoder_weights, classes=num_classes, activation=activation)
    elif arch == "Linknet":
        # FIX: these three branches referenced the misspelled name
        # 'encoder_weghts', which raised NameError whenever they were taken.
        return smp.Linknet(encoder, encoder_weights=encoder_weights, classes=num_classes, activation=activation)
    elif arch == "FPN":
        return smp.FPN(encoder, encoder_weights=encoder_weights, classes=num_classes, activation=activation)
    elif arch == "PSPNet":
        return smp.PSPNet(encoder, encoder_weights=encoder_weights, classes=num_classes, activation=activation)
    elif arch == "deeplabv3plus":
        # NOTE(review): deeplabv3plus_PATH is a free variable used as an
        # environment-variable key — presumably a module-level string
        # constant; confirm it is defined (a literal key seems intended).
        if deeplabv3plus_PATH in os.environ:
            sys.path.append(os.environ[deeplabv3plus_PATH])
            from modeling.deeplab import DeepLab
            return DeepLab(encoder, num_classes=4)
        else:
            raise ValueError('Set deeplabv3plus path by environment variable.')
    else:
        # (the unreachable sys.exit() that followed this raise was removed)
        raise ValueError(
            'arch {} is not found, set the correct arch'.format(arch))
def build_model(self):
    """Build generator and discriminator."""
    # Instantiate the network selected by self.model_type into self.unet
    # and move it to the configured device. All smp variants emit raw
    # logits (activation=None) with a single output channel.
    if self.model_type == 'U_Net':
        self.unet = U_Net(img_ch=3, output_ch=1)
    elif self.model_type == 'AttU_Net':
        self.unet = AttU_Net(img_ch=3, output_ch=1)
    elif self.model_type == 'unet_resnet34':
        # self.unet = Unet(backbone_name='resnet34', classes=1)
        self.unet = smp.Unet('resnet34', encoder_weights='imagenet', activation=None)
    elif self.model_type == 'unet_resnet50':
        self.unet = smp.Unet('resnet50', encoder_weights='imagenet', activation=None)
    elif self.model_type == 'unet_se_resnext50_32x4d':
        self.unet = smp.Unet('se_resnext50_32x4d', encoder_weights='imagenet', activation=None)
    elif self.model_type == 'unet_densenet121':
        self.unet = smp.Unet('densenet121', encoder_weights='imagenet', activation=None)
    elif self.model_type == 'unet_resnet34_t':
        self.unet = Unet_t('resnet34', encoder_weights='imagenet', activation=None, use_ConvTranspose2d=True)
    elif self.model_type == 'unet_resnet34_oct':
        self.unet = OctaveUnet('resnet34', encoder_weights='imagenet', activation=None)
    elif self.model_type == 'pspnet_resnet34':
        self.unet = smp.PSPNet('resnet34', encoder_weights='imagenet', classes=1, activation=None)
    elif self.model_type == 'linknet':
        self.unet = LinkNet34(num_classes=1)
    elif self.model_type == 'deeplabv3plus':
        self.unet = DeepLabV3Plus(model_backbone='res50_atrous', num_classes=1)
        # self.unet = DeepLabV3Plus(num_classes=1)
    # NOTE(review): no final else — an unknown model_type leaves self.unet
    # unset and the .to() call below raises AttributeError.
    # print('build model done!')
    self.unet.to(self.device)
def get_model(name='fpn50', model_weights_path=None):
    """Instantiate a segmentation network by its short name.

    Args:
        name: architecture key (default 'fpn50').
        model_weights_path: encoder checkpoint path, used by 'fpn50_satellite'.

    Raises:
        ValueError: if ``name`` is not a known architecture key.
    """
    # Lazy factories: only the requested model is ever constructed.
    factories = {
        'unet34': lambda: smp.Unet('resnet34', encoder_weights='imagenet'),
        'unet50': lambda: smp.Unet('resnet50', encoder_weights='imagenet'),
        'unet101': lambda: smp.Unet('resnet101', encoder_weights='imagenet'),
        'linknet34': lambda: smp.Linknet('resnet34', encoder_weights='imagenet'),
        'linknet50': lambda: smp.Linknet('resnet50', encoder_weights='imagenet'),
        'fpn34': lambda: smp.FPN('resnet34', encoder_weights='imagenet'),
        'fpn50': lambda: smp.FPN('resnet50', encoder_weights='imagenet'),
        'fpn101': lambda: smp.FPN('resnet101', encoder_weights='imagenet'),
        'pspnet34': lambda: smp.PSPNet('resnet34', encoder_weights='imagenet', classes=1),
        'pspnet50': lambda: smp.PSPNet('resnet50', encoder_weights='imagenet', classes=1),
        'fpn50_multiclass': lambda: smp.FPN('resnet50', encoder_weights='imagenet', classes=3, activation='softmax'),
    }
    if name in factories:
        return factories[name]()
    if name == 'fpn50_season':
        from clearcut_research.pytorch import FPN_double_output
        return FPN_double_output('resnet50', encoder_weights='imagenet')
    if name == 'fpn50_satellite':
        # Plain FPN whose encoder is swapped for one pretrained on satellite imagery.
        fpn_resnet50 = smp.FPN('resnet50', encoder_weights=None)
        fpn_resnet50.encoder = get_satellite_pretrained_resnet(model_weights_path)
        return fpn_resnet50
    raise ValueError("Unknown network")
def inference(config_file, model_file, input_data, output):
    """Run single-class PSPNet inference on one input and save the mask.

    Args:
        config_file: path to the training JSON config (encoder, aug, labels).
        model_file: checkpoint containing "model_state_dict".
        input_data: one input item understood by WeedDataset.
        output: directory where the predicted mask is written.
    """
    debug = False
    # NOTE(review): file handle from open() is never closed — consider `with`.
    config = json.load(open(f"{config_file}"))
    encoder = config["arch"]["args"]["encoder"]
    encoder_weights = config["arch"]["args"]["encoder_weights"]
    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        encoder, encoder_weights)
    model = smp.PSPNet(
        encoder_name=encoder,
        encoder_weights=encoder_weights,
        classes=1,
        activation=config["training"]["activation"],
    )
    model.load_state_dict(
        torch.load(f"{model_file}", map_location=torch.device(DEVICE))["model_state_dict"])
    model.to(DEVICE)
    infer_dataset = WeedDataset(
        [input_data],
        weed_label=config["data"]["weed_label"],
        augmentation=aug.get_validation_augmentations(config["data"]["aug"]),
        preprocessing=aug.get_preprocessing(preprocessing_fn),
    )
    for i in range(len(infer_dataset)):
        image, gt_mask, id = infer_dataset.get_img_and_props(i)
        gt_mask = gt_mask.squeeze()
        # Add a batch dimension, predict, then binarise via round().
        x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
        pr_mask = model.predict(x_tensor)
        pr_mask = (pr_mask.squeeze().cpu().numpy().round())
        if debug:
            visualize(image=image.squeeze().swapaxes(0, 1).swapaxes(1, 2),
                      ground_truth_mask=gt_mask,
                      predicted_mask=pr_mask)
        # Undo the loader's rotation, then resize back to the original frame.
        if id["rotate"]:
            pr_mask = cv2.rotate(pr_mask, cv2.ROTATE_90_COUNTERCLOCKWISE)
        pr_mask = cv2.resize(pr_mask, (id["img_height"], id["img_width"]),
                             interpolation=cv2.INTER_LINEAR)
        save_mask(pr_mask, output, os.path.basename(id["img_id"]))
def _patch_first_conv_for_4ch(model):
    """Widen the encoder's first conv from 3 to 4 input channels in place.

    The original RGB kernel weights are preserved and the 4th-channel
    kernel is initialised as a copy of channel 0 (red).
    """
    weight = model.encoder.conv1.weight.clone()
    model.encoder.conv1 = torch.nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
    with torch.no_grad():
        print("using 4c")
        model.encoder.conv1.weight[:, :3] = weight
        model.encoder.conv1.weight[:, 3] = model.encoder.conv1.weight[:, 0]


def get_model(num_classes, model_name):
    """Build the segmentation model named by ``model_name``.

    For the smp architectures, when the global ``args.num_channels`` exceeds
    3 the encoder stem is widened to accept a 4th input band.
    (Previously the widening code was copy-pasted into all three branches;
    it is now factored into ``_patch_first_conv_for_4ch``.)

    NOTE(review): relies on a module-level ``args`` namespace; an unknown
    model name prints an error and returns None (kept for compatibility).
    """
    if model_name == "UNet":
        print("using UNet")
        model = smp.Unet(encoder_name='resnet50', classes=num_classes, activation='sigmoid')
        if args.num_channels > 3:
            _patch_first_conv_for_4ch(model)
        return model
    elif model_name == "PSPNet":
        print("using PSPNet")
        model = smp.PSPNet(encoder_name="resnet50", classes=num_classes, activation='softmax')
        if args.num_channels > 3:
            _patch_first_conv_for_4ch(model)
        return model
    elif model_name == "FPN":
        print("using FPN")
        model = smp.FPN(encoder_name='resnet50', classes=num_classes)
        if args.num_channels > 3:
            _patch_first_conv_for_4ch(model)
        return model
    elif model_name == "AlbuNet":
        print("using AlbuNet")
        model = AlbuNet(pretrained=True, num_classes=num_classes)
        return model
    elif model_name == "YpUnet":
        print("using YpUnet_ASPP")
        model = UNet(num_classes=num_classes)
        return model
    else:
        print("error in model")
        return None
def create_smp_model(arch, **kwargs):
    'Create segmentation_models_pytorch model'
    assert arch in ARCHITECTURES, f'Select one of {ARCHITECTURES}'
    # Dispatch table instead of a long elif ladder; classes are referenced
    # lazily so only the selected model is constructed.
    builders = {
        "Unet": smp.Unet,
        "UnetPlusPlus": smp.UnetPlusPlus,
        "MAnet": smp.MAnet,
        "FPN": smp.FPN,
        "PAN": smp.PAN,
        "PSPNet": smp.PSPNet,
        "Linknet": smp.Linknet,
        "DeepLabV3": smp.DeepLabV3,
        "DeepLabV3Plus": smp.DeepLabV3Plus,
    }
    builder = builders.get(arch)
    if builder is None:
        raise NotImplementedError
    model = builder(**kwargs)
    # Remember the construction kwargs on the instance (used elsewhere,
    # e.g. for checkpoint round-tripping).
    setattr(model, 'kwargs', kwargs)
    return model
def SkinLesionModel(model, pretrained=True):
    """Return the segmentation network named by ``model``.

    FIX: the zoo previously instantiated all four networks eagerly (each
    downloading imagenet encoder weights) just to return one of them; the
    factories below defer construction to the selected entry only.

    Args:
        model: one of 'deeplabv3plus', 'deeplabv3plus_resnext', 'pspnet',
            'unetplusplus'.
        pretrained: unused, kept for caller compatibility.

    Raises:
        Warning: for an unknown name (type kept so existing callers that
        catch Warning still work; ValueError would be more idiomatic).
    """
    models_zoo = {
        'deeplabv3plus': lambda: smp.DeepLabV3Plus('resnet101', encoder_weights='imagenet', aux_params=None),
        'deeplabv3plus_resnext': lambda: smp.DeepLabV3Plus('resnext101_32x8d', encoder_weights='imagenet', aux_params=None),
        'pspnet': lambda: smp.PSPNet('resnet101', encoder_weights='imagenet', aux_params=None),
        'unetplusplus': lambda: smp.UnetPlusPlus('resnet101', encoder_weights='imagenet', aux_params=None),
    }
    factory = models_zoo.get(model)
    if factory is None:
        raise Warning('Wrong Net Name!!')
    return factory()
def get_model(encoder='resnet18', type='unet', encoder_weights='imagenet', classes=4):
    """My own simple wrapper around smp.

    Args:
        encoder: smp encoder name.
        type: architecture family — 'unet', 'fpn', 'pspnet' or 'linknet'.
            (Parameter name shadows the builtin; kept for caller compatibility.)
        encoder_weights: pretrained weight source for the encoder.
        classes: number of output classes (raw logits, no activation).

    Returns:
        (model, preprocessing_fn) tuple.

    Raises:
        ValueError: for an unknown ``type``. FIX: previously this executed
        ``raise "weird architecture"`` — raising a plain string is a
        TypeError in Python 3, masking the real error.
    """
    if type == 'unet':
        model = smp.Unet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=classes,
            activation=None,
        )
    elif type == 'fpn':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=classes,
            activation=None,
        )
    elif type == 'pspnet':
        model = smp.PSPNet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=classes,
            activation=None,
        )
    elif type == 'linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=classes,
            activation=None,
        )
    else:
        raise ValueError("weird architecture")
    print(f"Training on {type} architecture with {encoder} encoder")
    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        encoder, encoder_weights)
    return model, preprocessing_fn
def load_model(net, ENCODER, ENCODER_WEIGHTS, ACTIVATION):
    """Build a 4-class Unet / FPN / PSPNet with the given encoder settings.

    Raises:
        ValueError: for an unknown ``net``. FIX: previously an unmatched
        name fell through to ``return model`` and crashed with
        UnboundLocalError instead of a meaningful error.
    """
    if net == "Unet":
        model = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=4,
            activation=ACTIVATION,
        )
    elif net == "FPN":
        model = smp.FPN(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=4,
            activation=ACTIVATION,
        )
    elif net == "PSPNet":
        model = smp.PSPNet(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=4,
            activation=ACTIVATION,
        )
    else:
        raise ValueError("unknown net: {!r}".format(net))
    return model
def __init__(self, encoder, encoder_weights, classes, activation, learning_rate=1e-3, **kwargs):
    """LightningModule wrapping an smp segmentation network + Dice loss.

    Args:
        encoder: smp encoder name.
        encoder_weights: pretrained weight source for the encoder.
        classes: iterable of class names/ids; len(classes) sets output channels.
        activation: output activation passed to smp.
        learning_rate: stored in hparams for the optimizer (configured elsewhere).
        **kwargs: extra hyper-parameters; NOTE(review): ``architecture`` is
            expected to arrive here and be exposed via save_hyperparameters()
            as self.hparams.architecture — confirm callers pass it.
    """
    super().__init__()
    # Captures all ctor arguments (incl. kwargs) into self.hparams.
    self.save_hyperparameters()
    self.classes = classes
    if self.hparams.architecture == 'fpn':
        self.model = smp.FPN(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=len(classes),
            activation=activation,
        )
    elif self.hparams.architecture == 'pan':
        self.model = smp.PAN(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=len(classes),
            activation=activation,
        )
    elif self.hparams.architecture == 'pspnet':
        self.model = smp.PSPNet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=len(classes),
            activation=activation,
        )
    else:
        # NOTE(review): empty message makes this hard to debug.
        raise NameError('')
    self.loss = smp.utils.losses.DiceLoss()
def main():
    """Cloud-segmentation inference: optionally grid-search per-class
    threshold/min-size on the validation set, then predict the test set,
    post-process, RLE-encode and write a submission CSV."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder', type=str, default='efficientnet-b0')
    parser.add_argument('--model', type=str, default='unet')
    parser.add_argument('--loc', type=str)
    parser.add_argument('--data_folder', type=str, default='../input/')
    parser.add_argument('--batch_size', type=int, default=2)
    # NOTE(review): type=bool is an argparse footgun — any non-empty string
    # (including "False") parses as True.
    parser.add_argument('--optimize', type=bool, default=False)
    parser.add_argument('--tta_pre', type=bool, default=False)
    parser.add_argument('--tta_post', type=bool, default=False)
    parser.add_argument('--merge', type=str, default='mean')
    parser.add_argument('--min_size', type=int, default=10000)
    parser.add_argument('--thresh', type=float, default=0.5)
    parser.add_argument('--name', type=str)
    args = parser.parse_args()
    encoder = args.encoder
    model = args.model
    loc = args.loc
    data_folder = args.data_folder
    bs = args.batch_size
    optimize = args.optimize
    tta_pre = args.tta_pre
    tta_post = args.tta_post
    merge = args.merge
    min_size = args.min_size
    thresh = args.thresh
    name = args.name
    # `model` is rebound from the name string to the module instance by the
    # first matching branch; later comparisons then compare a module to a
    # string (False), so at most one branch fires.
    if model == 'unet':
        model = smp.Unet(encoder_name=encoder,
                         encoder_weights='imagenet',
                         classes=4,
                         activation=None)
    if model == 'fpn':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    if model == 'pspnet':
        model = smp.PSPNet(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    if model == 'linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet')
    # ---- test set ----
    test_df = get_dataset(train=False)
    test_df = prepare_dataset(test_df)
    test_ids = test_df['Image_Label'].apply(
        lambda x: x.split('_')[0]).drop_duplicates().values
    test_dataset = CloudDataset(
        df=test_df,
        datatype='test',
        img_ids=test_ids,
        transforms=valid1(),
        preprocessing=get_preprocessing(preprocessing_fn))
    test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)
    # ---- validation set (for threshold search) ----
    val_df = get_dataset(train=True)
    val_df = prepare_dataset(val_df)
    _, val_ids = get_train_test(val_df)
    valid_dataset = CloudDataset(
        df=val_df,
        datatype='train',
        img_ids=val_ids,
        transforms=valid1(),
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset, batch_size=bs, shuffle=False)
    model.load_state_dict(torch.load(loc)['model_state_dict'])
    # Per-class (threshold, min component size) used at post-processing time.
    class_params = {
        0: (thresh, min_size),
        1: (thresh, min_size),
        2: (thresh, min_size),
        3: (thresh, min_size)
    }
    if optimize:
        print("OPTIMIZING")
        print(tta_pre)
        if tta_pre:
            # Test-time augmentation wrapper for the validation pass.
            opt_model = tta.SegmentationTTAWrapper(
                model,
                tta.Compose([
                    tta.HorizontalFlip(),
                    tta.VerticalFlip(),
                    tta.Rotate90(angles=[0, 180])
                ]),
                merge_mode=merge)
        else:
            opt_model = model
        tta_runner = SupervisedRunner()
        print("INFERRING ON VALID")
        tta_runner.infer(
            model=opt_model,
            loaders={'valid': valid_loader},
            callbacks=[InferCallback()],
            verbose=True,
        )
        # Collect ground-truth masks and predicted probabilities, all
        # resized to the 350x525 evaluation resolution. Four classes per
        # image are stored interleaved: index (i * 4) + class.
        valid_masks = []
        probabilities = np.zeros((4 * len(valid_dataset), 350, 525))
        for i, (batch, output) in enumerate(
                tqdm(
                    zip(valid_dataset,
                        tta_runner.callbacks[0].predictions["logits"]))):
            _, mask = batch
            for m in mask:
                if m.shape != (350, 525):
                    m = cv2.resize(m,
                                   dsize=(525, 350),
                                   interpolation=cv2.INTER_LINEAR)
                valid_masks.append(m)
            for j, probability in enumerate(output):
                if probability.shape != (350, 525):
                    probability = cv2.resize(probability,
                                             dsize=(525, 350),
                                             interpolation=cv2.INTER_LINEAR)
                probabilities[(i * 4) + j, :, :] = probability
        print("RUNNING GRID SEARCH")
        # For each class, sweep thresholds 0.30..0.65 and min-size options,
        # keeping the pair with the best mean Dice.
        for class_id in range(4):
            print(class_id)
            attempts = []
            for t in range(30, 70, 5):
                t /= 100
                # NOTE(review): 175000 looks like a typo for 17500 — confirm.
                for ms in [7500, 10000, 12500, 15000, 175000]:
                    masks = []
                    for i in range(class_id, len(probabilities), 4):
                        probability = probabilities[i]
                        predict, num_predict = post_process(
                            sigmoid(probability), t, ms)
                        masks.append(predict)
                    d = []
                    for i, j in zip(masks, valid_masks[class_id::4]):
                        # Both empty counts as a perfect match.
                        if (i.sum() == 0) & (j.sum() == 0):
                            d.append(1)
                        else:
                            d.append(dice(i, j))
                    attempts.append((t, ms, np.mean(d)))
            attempts_df = pd.DataFrame(attempts, columns=['threshold', 'size', 'dice'])
            attempts_df = attempts_df.sort_values('dice', ascending=False)
            print(attempts_df.head())
            best_threshold = attempts_df['threshold'].values[0]
            best_size = attempts_df['size'].values[0]
            class_params[class_id] = (best_threshold, best_size)
        # Free validation-stage memory before test inference.
        del opt_model
        del tta_runner
        del valid_masks
        del probabilities
        gc.collect()
    if tta_post:
        model = tta.SegmentationTTAWrapper(model,
                                           tta.Compose([
                                               tta.HorizontalFlip(),
                                               tta.VerticalFlip(),
                                               tta.Rotate90(angles=[0, 180])
                                           ]),
                                           merge_mode=merge)
    else:
        model = model
    print(tta_post)
    runner = SupervisedRunner()
    runner.infer(
        model=model,
        loaders={'test': test_loader},
        callbacks=[InferCallback()],
        verbose=True,
    )
    # RLE-encode each per-class prediction; image_id % 4 selects the class
    # whose tuned (threshold, min_size) applies.
    encoded_pixels = []
    image_id = 0
    for i, image in enumerate(tqdm(runner.callbacks[0].predictions['logits'])):
        for i, prob in enumerate(image):
            if prob.shape != (350, 525):
                prob = cv2.resize(prob,
                                  dsize=(525, 350),
                                  interpolation=cv2.INTER_LINEAR)
            predict, num_predict = post_process(sigmoid(prob),
                                                class_params[image_id % 4][0],
                                                class_params[image_id % 4][1])
            if num_predict == 0:
                encoded_pixels.append('')
            else:
                r = mask2rle(predict)
                encoded_pixels.append(r)
            image_id += 1
    test_df['EncodedPixels'] = encoded_pixels
    test_df.to_csv(name, columns=['Image_Label', 'EncodedPixels'], index=False)
# ---- Load training data and build the Cloud2Cloud model/optimizer ----
print('==> Loading data..')
data = pd.read_pickle('data/data_train.pkl')
data_met = pd.read_pickle('data/data_train_met.pkl')
dataset = CloudDataset(data, data_met)
train_loader = DataLoader(dataset,
                          batch_size=Cfg.batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=Cfg.num_workers)
###############Load Data###################################################
###############Building Model##############################################
print('==> Building model..')
import segmentation_models_pytorch as smp
# 46 stacked input channels (cloud frames + met data — see CloudDataset).
in_channels = 46
cloud2cloud = smp.PSPNet(encoder_name='vgg19_bn', classes=4)
# Replace the VGG stem so the network accepts the 46-channel input.
cloud2cloud.encoder.features[0] = torch.nn.Conv2d(in_channels=in_channels,
                                                  out_channels=64,
                                                  kernel_size=(3, 3),
                                                  stride=(1, 1),
                                                  padding=(1, 1),
                                                  bias=False)
# Optionally resume from a checkpoint path configured in Cfg.
if Cfg.checkpoint:
    cloud2cloud.load_state_dict(torch.load(Cfg.checkpoint))
cloud2cloud = cloud2cloud.cuda()
###############Building Model##############################################
###############Building Optim##############################################
optim = torch.optim.Adam(cloud2cloud.parameters(), lr=Cfg.lr)
def get_t_net_model():
    """Teacher network: resnet50-encoder PSPNet with 3 output classes."""
    return smp.PSPNet('resnet50', classes=3)
# Set system logger system_logger = get_logger(name='train', file_path=os.path.join(PERFORMANCE_RECORD_DIR, 'train_log.log')) # Unet / PSPNet / DeepLabV3Plus if MODEL == 'unet': model = smp.Unet( encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS, classes=len(CLASSES), activation=ACTIVATION, ) elif MODEL == 'pspnet': model = smp.PSPNet( encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS, classes=len(CLASSES), activation=ACTIVATION, ) elif MODEL == 'deeplabv3plus': model = smp.DeepLabV3Plus( encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS, classes=len(CLASSES), activation=ACTIVATION, ) elif MODEL == 'pannet': model = smp.PAN( encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS, classes=len(CLASSES), activation=ACTIVATION,
def get_model(num_classes):
    """Create a PSPNet segmenter in train mode on the global ``device``.

    (Earlier UNet/segnet alternatives were left commented out upstream.)
    """
    net = smp.PSPNet(classes=num_classes)
    net.train()
    return net.to(device)
def get_t_net_model():
    """Teacher network: resnet50-encoder PSPNet with 3 output classes.

    The encoder is left trainable (a frozen-encoder variant was previously
    considered but not enabled).
    """
    return smp.PSPNet('resnet50', classes=3)
# ---- Training-script setup: hyper-parameters, model, loss, data, optim ----
num_classes = 16
model_name = arg.model
learning_rate = arg.l_rate
num_epochs = arg.n_epoch
batch_size = arg.batch_size
# Per-metric training history accumulated during the run.
history = collections.defaultdict(list)
# NOTE(review): all three models are constructed eagerly just to pick one
# by name below — wasteful but kept as-is.
model_dict = {
    'unet': UNet(num_classes=num_classes).train().to(device),
    'segnet': segnet(n_classes=num_classes).train().to(device),
    'pspnet': smp.PSPNet(classes=num_classes).train().to(device),
}
net = model_dict[model_name]
if torch.cuda.device_count() > 1:
    print("using multi gpu")
    net = torch.nn.DataParallel(net, device_ids=[0, 1, 2, 3])
else:
    print('using one gpu')
# if True:
#     print("The ckp has been loaded sucessfully ")
#     net = torch.load("./model/unet_2019-07-23.pth")  # load the pretrained model
criterion = FocalLoss2d().to(device)
train_loader, val_loader = get_dataset_loaders(5, batch_size)
opt = torch.optim.SGD(net.parameters(), lr=learning_rate)
def get_model(config):
    """Build, wrap and (optionally) restore the configured segmentation model.

    Reads architecture, backbone, pretrained-encoder source, input channels,
    class count and activation from ``config``, wraps the network in
    ``torch.nn.DataParallel``, loads weights from ``config.MODEL.WEIGHT``
    when set (and not 'none'), and moves it to ``config.MODEL.DEVICE``.

    Raises:
        ValueError: for an unsupported architecture. FIX: this was a bare
        ``raise ValueError()`` with no message, which made config errors
        needlessly hard to diagnose.
    """
    arch = config.MODEL.ARCHITECTURE
    backbone = config.MODEL.BACKBONE
    encoder_weights = config.MODEL.ENCODER_PRETRAINED_FROM
    in_channels = config.MODEL.IN_CHANNELS
    n_classes = len(config.INPUT.CLASSES)
    activation = config.MODEL.ACTIVATION
    # unet specific
    decoder_attention_type = 'scse' if config.MODEL.UNET_ENABLE_DECODER_SCSE else None
    if arch == 'unet':
        model = smp.Unet(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            decoder_channels=config.MODEL.UNET_DECODER_CHANNELS,
            decoder_attention_type=decoder_attention_type,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'fpn':
        model = smp.FPN(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            decoder_dropout=config.MODEL.FPN_DECODER_DROPOUT,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'pan':
        model = smp.PAN(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'pspnet':
        model = smp.PSPNet(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            psp_dropout=config.MODEL.PSPNET_DROPOUT,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'deeplabv3':
        model = smp.DeepLabV3(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'linknet':
        model = smp.Linknet(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    else:
        raise ValueError(f'unsupported architecture: {arch!r}')
    model = torch.nn.DataParallel(model)
    if config.MODEL.WEIGHT and config.MODEL.WEIGHT != 'none':
        # load weight from file (CPU first; moved to the target device below)
        model.load_state_dict(
            torch.load(
                config.MODEL.WEIGHT,
                map_location=torch.device('cpu')
            )
        )
    model = model.to(config.MODEL.DEVICE)
    return model