Example #1
0
def get_model(name, classification_head, model_weights_path=None):
    """Instantiate the segmentation network registered under ``name``.

    Args:
        name: Short model key such as ``'unet34'``, ``'unet18'``, ``'fpn50'``,
            ``'fpn50_satellite'`` or ``'fpn50_multiclass'`` (see the table
            below for the full list).
        classification_head: Only consulted for ``'unet18'``. When truthy the
            U-Net gets a single-output auxiliary head (max pooling, dropout
            0.1, sigmoid) and a randomly initialised encoder; otherwise it uses
            ImageNet encoder weights and no auxiliary head.
        model_weights_path: Only used by ``'fpn50_satellite'``; forwarded to
            ``get_satellite_pretrained_resnet`` to build the encoder.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``name`` is not a known key.
    """

    def build_unet18():
        print('classification_head:', classification_head)
        if not classification_head:
            return smp.Unet('resnet18',
                            encoder_weights='imagenet',
                            encoder_depth=2,
                            decoder_channels=(256, 128))
        head_config = dict(
            pooling='max',
            dropout=0.1,
            activation='sigmoid',
            classes=1,
        )
        return smp.Unet('resnet18',
                        aux_params=head_config,
                        encoder_weights=None,
                        encoder_depth=2,
                        decoder_channels=(256, 128))

    def build_fpn50_season():
        # Imported here so the project module is only needed for this variant.
        from clearcut_research.pytorch import FPN_double_output
        return FPN_double_output('resnet50', encoder_weights='imagenet')

    def build_fpn50_satellite():
        network = smp.FPN('resnet50', encoder_weights=None)
        network.encoder = get_satellite_pretrained_resnet(model_weights_path)
        return network

    # (key, zero-argument factory) pairs; factories run only when selected.
    table = (
        ('unet34', lambda: smp.Unet('resnet34', encoder_weights='imagenet')),
        ('unet18', build_unet18),
        ('unet50', lambda: smp.Unet('resnet50', encoder_weights='imagenet')),
        ('unet101', lambda: smp.Unet('resnet101', encoder_weights='imagenet')),
        ('linknet34',
         lambda: smp.Linknet('resnet34', encoder_weights='imagenet')),
        ('linknet50',
         lambda: smp.Linknet('resnet50', encoder_weights='imagenet')),
        ('fpn34', lambda: smp.FPN('resnet34', encoder_weights='imagenet')),
        ('fpn50', lambda: smp.FPN('resnet50', encoder_weights='imagenet')),
        ('fpn101', lambda: smp.FPN('resnet101', encoder_weights='imagenet')),
        ('pspnet34',
         lambda: smp.PSPNet('resnet34', encoder_weights='imagenet', classes=1)),
        ('pspnet50',
         lambda: smp.PSPNet('resnet50', encoder_weights='imagenet', classes=1)),
        ('fpn50_season', build_fpn50_season),
        ('fpn50_satellite', build_fpn50_satellite),
        ('fpn50_multiclass',
         lambda: smp.FPN('resnet50',
                         encoder_weights='imagenet',
                         classes=3,
                         activation='softmax')),
    )
    for key, factory in table:
        if name == key:
            return factory()
    raise ValueError("Unknown network")
Example #2
0
def get_basenet(basenet,
                backbone,
                encoder_weights,
                classes,
                decoder_channels,
                activation='sigmoid'):
    """Build an smp segmentation model for the requested base architecture.

    ``'fpn'``, ``'psp'`` and ``'deeplabv3'`` select ``smp.FPN``,
    ``smp.PSPNet`` and ``smp.DeepLabV3``; ``decoder_channels`` is ignored for
    these. Any other ``basenet`` value falls through to ``smp.Unet``, whose
    encoder depth is the number of entries in ``decoder_channels``.
    """
    shared_kwargs = dict(encoder_weights=encoder_weights,
                         classes=classes,
                         activation=activation)
    for key, class_name in (('fpn', 'FPN'),
                            ('psp', 'PSPNet'),
                            ('deeplabv3', 'DeepLabV3')):
        if basenet == key:
            return getattr(smp, class_name)(backbone, **shared_kwargs)

    # Every other name builds a U-Net.
    return smp.Unet(backbone,
                    encoder_depth=len(decoder_channels),
                    decoder_channels=decoder_channels,
                    **shared_kwargs)
def init(config):
    """Build the segmentation model and, in training mode, its loss.

    Args:
        config: Mapping with at least ``"model"`` (one of ``"UNet"``,
            ``"PSPNet"``, ``"FPN"``, ``"Linknet"``) and ``"mode"``. When
            ``config["mode"] == 'train'`` it must also contain ``"loss"``
            (``"DiceBCE"``, ``"FocalTversky"``, ``"Focal"`` or ``"Tversky"``),
            and ``"dice_weight"`` for ``"DiceBCE"``.

    Returns:
        ``(model, loss)`` when ``config["mode"] == 'train'``, otherwise just
        ``model``. Models use smp defaults with ``activation=None``.

    Raises:
        ValueError: for an unknown model or loss name. ValueError subclasses
            Exception, so existing ``except Exception`` handlers still apply.
    """
    # ---- Model Initialization ----
    model_classes = {
        "UNet": "Unet",
        "PSPNet": "PSPNet",
        "FPN": "FPN",
        "Linknet": "Linknet",
    }
    model_name = config["model"]
    if model_name not in model_classes:
        raise ValueError(f'Incorrect model name! Got {model_name!r}')
    model = getattr(smp, model_classes[model_name])(activation=None)

    if config["mode"] != 'train':
        return model

    # ---- Loss Initialization (training only) ----
    loss_name = config["loss"]
    if loss_name == "DiceBCE":
        loss = LossBinaryDice(dice_weight=config["dice_weight"])
    elif loss_name == "FocalTversky":
        loss = FocalTverskyLoss()
    elif loss_name == "Focal":
        loss = FocalLoss()
    elif loss_name == "Tversky":
        loss = TverskyLoss()
    else:
        raise ValueError(f'Incorrect loss name! Got {loss_name!r}')

    return model, loss
Example #4
0
    def __init__(self, debug=False):
        """Wrap a 4-class PSPNet with an ImageNet-initialised ResNet-34 encoder.

        The network is built with ``activation=None``, so no output
        activation is applied.

        Args:
            debug: Stored unchanged on ``self.debug``.
        """
        super().__init__()
        pspnet_config = {
            'encoder_name': 'resnet34',
            'encoder_weights': 'imagenet',
            'classes': 4,
            'activation': None,
        }
        self.PSPNet = smp.PSPNet(**pspnet_config)

        self.debug = debug
Example #5
0
    def __init__(self, in_channels=24):
        """Build the cloud and meteorological PSPNet branches.

        Args:
            in_channels: Number of input channels accepted by the cloud
                branch; its encoder stem ``conv1`` is replaced to match.
        """
        super(Cloud2Cloud, self).__init__()

        self.cloudNet = smp.PSPNet(encoder_name='resnet34',
                                   classes=4,
                                   encoder_weights='imagenet')
        # Fresh 5x5 stride-2 stem taking `in_channels` inputs. It replaces the
        # pretrained stem, so that layer starts from a new initialisation.
        stem = torch.nn.Conv2d(in_channels=in_channels,
                               out_channels=64,
                               kernel_size=(5, 5),
                               stride=(2, 2),
                               padding=(2, 2),
                               bias=False)
        self.cloudNet.encoder.conv1 = stem

        self.metNet = smp.PSPNet(encoder_name='resnet18',
                                 encoder_weights='imagenet')

        # Aliases to the same sub-module objects (not copies).
        self.cloud_encoder = self.cloudNet.encoder
        self.met_encoder = self.metNet.encoder
        self.cloud_decoder = self.cloudNet.decoder
Example #6
0
    def build_model(self):
        """Build the segmentation network selected by ``self.model_type``.

        Stores the network in ``self.unet``. When CUDA is available the
        network is wrapped in ``torch.nn.DataParallel``, and
        ``self.criterion``, ``self.criterion_stage2`` and
        ``self.criterion_stage3`` are moved to the GPU with ``.cuda()``.
        Finally the (possibly wrapped) network is moved to ``self.device``.

        Raises:
            ValueError: if ``self.model_type`` is not a supported name. An
                unknown type fails here instead of leaving ``self.unet`` unset
                or pointing at a previously built network.
        """
        model_type = self.model_type
        print("Using model: {}".format(model_type))

        if model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=self.output_ch)
        elif model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=self.output_ch, t=self.t)
        elif model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=self.output_ch)
        elif model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3,
                                   output_ch=self.output_ch,
                                   t=self.t)
        # NOTE(review): the smp variants below do not pass `self.output_ch`
        # (pspnet_resnet34 hard-codes classes=1) — confirm this is intended.
        elif model_type == 'unet_resnet34':
            self.unet = smp.Unet('resnet34',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif model_type == 'unet_resnet50':
            self.unet = smp.Unet('resnet50',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif model_type == 'unet_se_resnext50_32x4d':
            self.unet = smp.Unet('se_resnext50_32x4d',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif model_type == 'unet_densenet121':
            self.unet = smp.Unet('densenet121',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif model_type == 'unet_resnet34_t':
            self.unet = Unet_t('resnet34',
                               encoder_weights='imagenet',
                               activation=None,
                               use_ConvTranspose2d=True)
        elif model_type == 'unet_resnet34_oct':
            self.unet = OctaveUnet('resnet34',
                                   encoder_weights='imagenet',
                                   activation=None)
        elif model_type == 'linknet':
            self.unet = LinkNet34(num_classes=self.output_ch)
        elif model_type == 'deeplabv3plus':
            self.unet = DeepLabV3Plus(model_backbone='res50_atrous',
                                      num_classes=self.output_ch)
        elif model_type == 'pspnet_resnet34':
            self.unet = smp.PSPNet('resnet34',
                                   encoder_weights='imagenet',
                                   classes=1,
                                   activation=None)
        else:
            raise ValueError("Unknown model_type: {!r}".format(model_type))

        if torch.cuda.is_available():
            self.unet = torch.nn.DataParallel(self.unet)
            self.criterion = self.criterion.cuda()
            self.criterion_stage2 = self.criterion_stage2.cuda()
            self.criterion_stage3 = self.criterion_stage3.cuda()
        self.unet.to(self.device)
Example #7
0
    def __init__(self, model_name, in_channels=3, out_channels=1):
        """Build a ResNet-18 smp segmentation model plus a downsizing block.

        Args:
            model_name: Architecture selected by substring. A name containing
                ``'unet'`` builds ``smp.Unet`` (depth 3, decoder
                ``[128, 64, 32]``). Otherwise a name containing ``'uplus'``
                builds ``smp.UnetPlusPlus`` (depth 3, decoder
                ``[256, 128, 64]``). Anything else builds ``smp.PSPNet``.
            in_channels: Number of model input channels.
            out_channels: Number of output classes, used for both the mask
                and the auxiliary classification head.
        """
        super(SmpModel16, self).__init__()

        # Auxiliary classification head shared by every variant.
        # NOTE(review): softmax over a single class (the default
        # out_channels=1) is constant; confirm whether sigmoid was intended.
        head_params = dict(
            pooling='max',
            dropout=0.5,
            activation='softmax',
            classes=out_channels,
        )
        common_kwargs = dict(
            encoder_name="resnet18",
            encoder_weights="imagenet",
            in_channels=in_channels,
            classes=out_channels,
            aux_params=head_params,
        )

        if 'unet' in model_name:
            self.model = smp.Unet(encoder_depth=3,
                                  decoder_channels=[128, 64, 32],
                                  **common_kwargs)
        elif 'uplus' in model_name:
            self.model = smp.UnetPlusPlus(encoder_depth=3,
                                          decoder_channels=[256, 128, 64],
                                          **common_kwargs)
        else:
            self.model = smp.PSPNet(**common_kwargs)

        self.down_layer = DownsizeBlock(out_channels, downsize_mode=2)
def prepare_model(backbone='se_resnet50', weight='None', checkpoint_path=None):
    """Build a single-input-channel, single-class PSPNet.

    The encoder stem ``model.encoder.layer0.conv1`` is replaced by a freshly
    initialised 1-input-channel 7x7 stride-2 convolution with 64 outputs.
    This assumes the chosen backbone exposes ``encoder.layer0.conv1`` —
    confirm for backbones other than the default.

    Args:
        backbone: smp encoder name.
        weight: ``'None'`` (randomly initialised encoder), ``'imagenet'``
            (ImageNet encoder weights; the replaced stem is still new), or
            ``'pretrained'`` (load a full checkpoint from ``checkpoint_path``).
        checkpoint_path: Checkpoint file whose ``"state_dict"`` entry matches
            this single-channel model. Required when
            ``weight == 'pretrained'``.

    Returns:
        The model. For ``weight == 'pretrained'`` it is moved to ``cuda:0``.

    Raises:
        ValueError: for an unknown ``weight`` or a missing
            ``checkpoint_path``.
    """
    if weight not in ('None', 'imagenet', 'pretrained'):
        raise ValueError(
            "weight must be 'None', 'imagenet' or 'pretrained', got {!r}".format(
                weight))
    if weight == 'pretrained' and checkpoint_path is None:
        raise ValueError("checkpoint_path is required when weight='pretrained'")

    model = smp.PSPNet(backbone,
                       encoder_weights='imagenet' if weight == 'imagenet' else None,
                       classes=1,
                       activation=None)

    # Swap the stem for a 1-channel one BEFORE any checkpoint is loaded, so
    # the loaded stem weights are not overwritten afterwards.
    model.encoder.layer0.conv1 = nn.Conv2d(1,
                                           64,
                                           kernel_size=(7, 7),
                                           stride=(2, 2),
                                           padding=(3, 3),
                                           bias=False)

    if weight == 'pretrained':
        model.to(torch.device("cuda:0"))
        # Deserialise onto CPU storage; load_state_dict then copies the
        # tensors into the model's parameters.
        state = torch.load(checkpoint_path,
                           map_location=lambda storage, loc: storage)
        model.load_state_dict(state["state_dict"])

    return model
Example #9
0
def resnet50_PSPNet_noclassification(in_channels=3, classes=1, activation=None,
                                     **kwargs):
    """Build a ResNet-50 PSPNet without an auxiliary classification head.

    Args:
        in_channels: Number of input image channels.
        classes: Number of output mask channels.
        activation: Output activation name passed to smp (``None`` for none).
        **kwargs: Extra keyword arguments forwarded to ``smp.PSPNet``.

    The defaults mirror ``smp.PSPNet``'s own defaults. Previously these names
    were read but never defined, so every call failed with a NameError.

    Returns:
        The constructed ``smp.PSPNet``.
    """
    model = smp.PSPNet('resnet50',
                       in_channels=in_channels,
                       classes=classes,
                       activation=activation,
                       **kwargs)
    print("Just segmentation Model args:")
    print("in_channels:%d,classes:%d,activation:%s" %
          (in_channels, classes, activation))
    print("kwargs", kwargs)
    return model
 def PSP(self, img_ch, output_ch):
     """Return a depth-3 PSPNet built from this object's encoder settings.

     Args:
         img_ch: Number of input channels.
         output_ch: Number of output classes.

     Uses ``self.encoder`` / ``self.en_weights``. The model has no output
     activation and no auxiliary head.
     """
     settings = {
         'encoder_name': self.encoder,
         'encoder_weights': self.en_weights,
         'encoder_depth': 3,
         'psp_out_channels': 512,
         'psp_use_batchnorm': False,
         'psp_dropout': 0.2,
         'in_channels': img_ch,
         'classes': output_ch,
         'activation': None,
         'upsampling': 8,
         'aux_params': None,
     }
     return smp.PSPNet(**settings)
Example #11
0
def resnet34_psp(num_classes):
    """Return a ResNet-34 PSPNet with ImageNet encoder weights.

    Args:
        num_classes: Number of output mask channels.

    The model is built with ``activation=None``, so no output activation is
    applied.
    """
    return smp.PSPNet(encoder_name='resnet34',
                      encoder_weights='imagenet',
                      classes=num_classes,
                      activation=None)
Example #12
0
def build_model(configuration):
    """Instantiate the smp model named by ``configuration.Model.model_name``.

    The name is matched case-insensitively against ``unet``, ``linknet``,
    ``pspnet``, ``fpn``, ``pan``, ``deeplab_v3+`` and ``deeplab_v3``. Every
    model uses ``configuration.Model.encoder`` and
    ``configuration.Model.encoder_weights``, has
    ``configuration.DataSet.number_of_classes`` outputs and no output
    activation. U-Net additionally gets ``decoder_attention_type=None``.

    Raises:
        KeyError: if the model name is not recognised.
    """
    model_list = ['UNet', 'LinkNet', 'PSPNet', 'FPN', 'PAN', 'Deeplab_v3', 'Deeplab_v3+']
    smp_class_by_key = {
        'unet': 'Unet',
        'linknet': 'Linknet',
        'pspnet': 'PSPNet',
        'fpn': 'FPN',
        'pan': 'PAN',
        'deeplab_v3+': 'DeepLabV3Plus',
        'deeplab_v3': 'DeepLabV3',
    }
    requested = configuration.Model.model_name.lower()
    if requested not in smp_class_by_key:
        raise KeyError(f'Model should be one of {model_list}')

    kwargs = dict(
        encoder_name=configuration.Model.encoder,
        encoder_weights=configuration.Model.encoder_weights,
        activation=None,
        classes=configuration.DataSet.number_of_classes,
    )
    if requested == 'unet':
        kwargs['decoder_attention_type'] = None
    return getattr(smp, smp_class_by_key[requested])(**kwargs)
Example #13
0
def create_segmentation_models(encoder,
                               arch,
                               num_classes=4,
                               encoder_weights=None,
                               activation=None):
    '''Build a segmentation model for the given architecture and encoder.

    segmentation_models_pytorch (https://github.com/qubvel/segmentation_models.pytorch)
    provides the following architectures, each with many encoders (see the
    project page):

    - Unet
    - Linknet
    - FPN
    - PSPNet

    ``arch="deeplabv3plus"`` instead uses DeepLab from
    https://github.com/jfzhang95/pytorch-deeplab-xception. Its checkout
    directory is read from the ``deeplabv3plus_PATH`` environment variable.
    It supports these encoders:

    - resnet (resnet101)
    - mobilenet
    - xception
    - drn

    Args:
        encoder: Encoder / backbone name.
        arch: ``"Unet"``, ``"Linknet"``, ``"FPN"``, ``"PSPNet"`` or
            ``"deeplabv3plus"``.
        num_classes: Number of output classes.
        encoder_weights: Pretrained encoder weights for smp models (``None``
            for random init). Not used for ``"deeplabv3plus"``.
        activation: Output activation for smp models. Not used for
            ``"deeplabv3plus"``.

    Returns:
        The constructed model.

    Raises:
        ValueError: for an unknown ``arch``, or for ``"deeplabv3plus"`` when
            the ``deeplabv3plus_PATH`` environment variable is not set.
    '''
    # These arch names are exactly the smp class names.
    if arch in ("Unet", "Linknet", "FPN", "PSPNet"):
        return getattr(smp, arch)(encoder,
                                  encoder_weights=encoder_weights,
                                  classes=num_classes,
                                  activation=activation)

    if arch == "deeplabv3plus":
        env_var = 'deeplabv3plus_PATH'
        if env_var not in os.environ:
            raise ValueError(
                'Set deeplabv3plus path by environment variable {}.'.format(
                    env_var))
        deeplab_root = os.environ[env_var]
        # Avoid growing sys.path on repeated calls.
        if deeplab_root not in sys.path:
            sys.path.append(deeplab_root)
        from modeling.deeplab import DeepLab
        return DeepLab(encoder, num_classes=num_classes)

    raise ValueError(
        'arch {} is not found, set the correct arch'.format(arch))
Example #14
0
    def build_model(self):
        """Build the segmentation network selected by ``self.model_type``.

        Stores the network in ``self.unet`` and moves it to ``self.device``.

        Raises:
            ValueError: if ``self.model_type`` is not a supported name. An
                unknown type fails here instead of leaving ``self.unet`` unset
                or pointing at a previously built network.
        """
        model_type = self.model_type
        if model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=1)
        elif model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=1)
        elif model_type == 'unet_resnet34':
            self.unet = smp.Unet('resnet34',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif model_type == 'unet_resnet50':
            self.unet = smp.Unet('resnet50',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif model_type == 'unet_se_resnext50_32x4d':
            self.unet = smp.Unet('se_resnext50_32x4d',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif model_type == 'unet_densenet121':
            self.unet = smp.Unet('densenet121',
                                 encoder_weights='imagenet',
                                 activation=None)
        elif model_type == 'unet_resnet34_t':
            self.unet = Unet_t('resnet34',
                               encoder_weights='imagenet',
                               activation=None,
                               use_ConvTranspose2d=True)
        elif model_type == 'unet_resnet34_oct':
            self.unet = OctaveUnet('resnet34',
                                   encoder_weights='imagenet',
                                   activation=None)
        elif model_type == 'pspnet_resnet34':
            self.unet = smp.PSPNet('resnet34',
                                   encoder_weights='imagenet',
                                   classes=1,
                                   activation=None)
        elif model_type == 'linknet':
            self.unet = LinkNet34(num_classes=1)
        elif model_type == 'deeplabv3plus':
            self.unet = DeepLabV3Plus(model_backbone='res50_atrous',
                                      num_classes=1)
        else:
            raise ValueError("Unknown model_type: {!r}".format(model_type))

        self.unet.to(self.device)
Example #15
0
def get_model(name='fpn50', model_weights_path=None):
    """Instantiate the segmentation network registered under ``name``.

    Args:
        name: Short model key, e.g. ``'unet34'``, ``'fpn50'`` (default),
            ``'pspnet50'``, ``'fpn50_satellite'`` or ``'fpn50_multiclass'``.
        model_weights_path: Only used by ``'fpn50_satellite'``; forwarded to
            ``get_satellite_pretrained_resnet`` to build the encoder.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``name`` is not a known key.
    """

    def season_model():
        # Imported here so the project module is only needed for this variant.
        from clearcut_research.pytorch import FPN_double_output
        return FPN_double_output('resnet50', encoder_weights='imagenet')

    def satellite_model():
        fpn = smp.FPN('resnet50', encoder_weights=None)
        fpn.encoder = get_satellite_pretrained_resnet(model_weights_path)
        return fpn

    # Each factory is only invoked for the matching key.
    registry = (
        ('unet34', lambda: smp.Unet('resnet34', encoder_weights='imagenet')),
        ('unet50', lambda: smp.Unet('resnet50', encoder_weights='imagenet')),
        ('unet101', lambda: smp.Unet('resnet101', encoder_weights='imagenet')),
        ('linknet34',
         lambda: smp.Linknet('resnet34', encoder_weights='imagenet')),
        ('linknet50',
         lambda: smp.Linknet('resnet50', encoder_weights='imagenet')),
        ('fpn34', lambda: smp.FPN('resnet34', encoder_weights='imagenet')),
        ('fpn50', lambda: smp.FPN('resnet50', encoder_weights='imagenet')),
        ('fpn101', lambda: smp.FPN('resnet101', encoder_weights='imagenet')),
        ('pspnet34',
         lambda: smp.PSPNet('resnet34', encoder_weights='imagenet', classes=1)),
        ('pspnet50',
         lambda: smp.PSPNet('resnet50', encoder_weights='imagenet', classes=1)),
        ('fpn50_season', season_model),
        ('fpn50_satellite', satellite_model),
        ('fpn50_multiclass',
         lambda: smp.FPN('resnet50',
                         encoder_weights='imagenet',
                         classes=3,
                         activation='softmax')),
    )
    for registered_name, make in registry:
        if name == registered_name:
            return make()
    raise ValueError("Unknown network")
Example #16
0
def inference(config_file, model_file, input_data, output):
    """Predict masks for ``input_data`` with a trained PSPNet and save them.

    Args:
        config_file: Path to the JSON experiment config. Reads
            ``config["arch"]["args"]["encoder"]``,
            ``config["arch"]["args"]["encoder_weights"]``,
            ``config["training"]["activation"]``,
            ``config["data"]["weed_label"]`` and ``config["data"]["aug"]``.
        model_file: Checkpoint loaded with ``torch.load``. Its
            ``"model_state_dict"`` entry is loaded into the model.
        input_data: Passed as a one-element list to ``WeedDataset``.
        output: Output location forwarded to ``save_mask``.
    """
    debug = False  # set True to visualise each prediction beside its ground truth

    # Context manager so the config file handle is closed.
    with open(f"{config_file}") as config_fh:
        config = json.load(config_fh)

    encoder = config["arch"]["args"]["encoder"]
    encoder_weights = config["arch"]["args"]["encoder_weights"]

    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        encoder, encoder_weights)
    model = smp.PSPNet(
        encoder_name=encoder,
        encoder_weights=encoder_weights,
        classes=1,
        activation=config["training"]["activation"],
    )
    checkpoint = torch.load(f"{model_file}",
                            map_location=torch.device(DEVICE))
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(DEVICE)

    infer_dataset = WeedDataset(
        [input_data],
        weed_label=config["data"]["weed_label"],
        augmentation=aug.get_validation_augmentations(config["data"]["aug"]),
        preprocessing=aug.get_preprocessing(preprocessing_fn),
    )

    for i in range(len(infer_dataset)):
        # `props` is the per-image metadata dict (img_id, rotate, size).
        image, gt_mask, props = infer_dataset.get_img_and_props(i)

        gt_mask = gt_mask.squeeze()

        x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
        pr_mask = model.predict(x_tensor)
        # Drop singleton dims and round to the nearest integer; this gives a
        # binary mask when the activation outputs values in [0, 1].
        pr_mask = pr_mask.squeeze().cpu().numpy().round()

        if debug:
            visualize(image=image.squeeze().swapaxes(0, 1).swapaxes(1, 2),
                      ground_truth_mask=gt_mask,
                      predicted_mask=pr_mask)

        if props["rotate"]:
            pr_mask = cv2.rotate(pr_mask, cv2.ROTATE_90_COUNTERCLOCKWISE)
        # NOTE(review): cv2.resize takes dsize as (width, height), but this
        # passes (img_height, img_width). That is only right for square images
        # or if WeedDataset stores these keys swapped — confirm.
        pr_mask = cv2.resize(pr_mask, (props["img_height"], props["img_width"]),
                             interpolation=cv2.INTER_LINEAR)
        save_mask(pr_mask, output, os.path.basename(props["img_id"]))
Example #17
0
def _expand_input_channels(model, num_channels):
    """Replace ``model.encoder.conv1`` with a stem taking ``num_channels`` inputs.

    The first three input channels reuse the old stem's weights. Every extra
    channel is initialised with a copy of channel 0's weights. Assumes the old
    stem is a 3-input, 64-output 7x7 stride-2 convolution without bias —
    confirm for encoders other than ResNet-50.
    """
    weight = model.encoder.conv1.weight.clone()
    model.encoder.conv1 = torch.nn.Conv2d(num_channels, 64, kernel_size=7,
                                          stride=2, padding=3, bias=False)
    with torch.no_grad():
        print(f"using {num_channels}c")
        model.encoder.conv1.weight[:, :3] = weight
        for channel in range(3, num_channels):
            model.encoder.conv1.weight[:, channel] = model.encoder.conv1.weight[:, 0]
    return model


def get_model(num_classes, model_name, num_channels=None):
    """Build the segmentation network named ``model_name``.

    Args:
        num_classes: Number of output classes.
        model_name: ``"UNet"``, ``"PSPNet"``, ``"FPN"`` (ResNet-50 smp
            models), ``"AlbuNet"`` or ``"YpUnet"``.
        num_channels: Input channel count for the smp models. If more than 3,
            the encoder stem is widened (see ``_expand_input_channels``).
            ``None`` (the default) reads the module-level
            ``args.num_channels``.

    Returns:
        The model, or ``None`` (after printing ``"error in model"``) for an
        unknown ``model_name``.
    """
    if model_name in ("UNet", "PSPNet", "FPN"):
        print("using " + model_name)
        if model_name == "UNet":
            model = smp.Unet(encoder_name='resnet50', classes=num_classes,
                             activation='sigmoid')
        elif model_name == "PSPNet":
            model = smp.PSPNet(encoder_name="resnet50", classes=num_classes,
                               activation='softmax')
        else:
            model = smp.FPN(encoder_name='resnet50', classes=num_classes)
        if num_channels is None:
            num_channels = args.num_channels
        if num_channels > 3:
            _expand_input_channels(model, num_channels)
        return model
    if model_name == "AlbuNet":
        print("using AlbuNet")
        return AlbuNet(pretrained=True, num_classes=num_classes)
    if model_name == "YpUnet":
        print("using YpUnet_ASPP")
        return UNet(num_classes=num_classes)
    print("error in model")
    return None
Example #18
0
def create_smp_model(arch, **kwargs):
    """Create a segmentation_models_pytorch model of architecture ``arch``.

    ``arch`` must be in the module-level ``ARCHITECTURES`` and be one of the
    smp class names handled below; ``kwargs`` go straight to that class.
    The kwargs are also attached to the returned model as ``model.kwargs``.

    Raises:
        AssertionError: if ``arch`` is not in ``ARCHITECTURES``.
        NotImplementedError: if ``arch`` is listed but not handled here.
    """
    assert arch in ARCHITECTURES, f'Select one of {ARCHITECTURES}'

    handled = ('Unet', 'UnetPlusPlus', 'MAnet', 'FPN', 'PAN', 'PSPNet',
               'Linknet', 'DeepLabV3', 'DeepLabV3Plus')
    if arch not in handled:
        raise NotImplementedError
    # Each handled arch name equals the smp class name.
    model = getattr(smp, arch)(**kwargs)

    model.kwargs = kwargs
    return model
Example #19
0
def SkinLesionModel(model, pretrained=True):
    """Build the named skin-lesion segmentation network.

    Args:
        model: ``'deeplabv3plus'`` (ResNet-101), ``'deeplabv3plus_resnext'``
            (ResNeXt-101 32x8d), ``'pspnet'`` (ResNet-101) or
            ``'unetplusplus'`` (ResNet-101).
        pretrained: When true (the default) the encoder uses ImageNet
            weights. When false it is randomly initialised.

    Returns:
        The constructed network (no auxiliary head).

    Raises:
        Warning: for an unknown ``model`` name. This exception type is kept
            for compatibility with existing callers.
    """
    encoder_weights = 'imagenet' if pretrained else None
    # Factories are deferred so only the requested network is constructed
    # (and only its encoder weights fetched).
    models_zoo = {
        'deeplabv3plus':
        lambda: smp.DeepLabV3Plus('resnet101',
                                  encoder_weights=encoder_weights,
                                  aux_params=None),
        'deeplabv3plus_resnext':
        lambda: smp.DeepLabV3Plus('resnext101_32x8d',
                                  encoder_weights=encoder_weights,
                                  aux_params=None),
        'pspnet':
        lambda: smp.PSPNet('resnet101',
                           encoder_weights=encoder_weights,
                           aux_params=None),
        'unetplusplus':
        lambda: smp.UnetPlusPlus('resnet101',
                                 encoder_weights=encoder_weights,
                                 aux_params=None),
    }
    factory = models_zoo.get(model)
    if factory is None:
        raise Warning('Wrong Net Name!!')
    return factory()
Example #20
0
def get_model(encoder='resnet18',
              type='unet',
              encoder_weights='imagenet',
              classes=4):
    """Build an smp model and its matching input preprocessing function.

    Args:
        encoder: smp encoder name.
        type: ``'unet'``, ``'fpn'``, ``'pspnet'`` or ``'linknet'``.
        encoder_weights: Pretrained encoder weights (also selects the
            preprocessing normalisation).
        classes: Number of output channels.

    Returns:
        ``(model, preprocessing_fn)``. The model has no output activation.

    Raises:
        ValueError: for an unknown ``type``. The original ``raise "..."``
            raised a string, which Python 3 turns into a TypeError.
    """
    # My own simple wrapper around smp
    smp_class_by_type = {
        'unet': 'Unet',
        'fpn': 'FPN',
        'pspnet': 'PSPNet',
        'linknet': 'Linknet',
    }
    if type not in smp_class_by_type:
        raise ValueError(f"weird architecture: {type!r}")
    model = getattr(smp, smp_class_by_type[type])(
        encoder_name=encoder,
        encoder_weights=encoder_weights,
        classes=classes,
        activation=None,
    )
    print(f"Training on {type} architecture with {encoder} encoder")
    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        encoder, encoder_weights)
    return model, preprocessing_fn
Example #21
0
def load_model(net, ENCODER, ENCODER_WEIGHTS, ACTIVATION):
    """Build a 4-class smp segmentation model.

    Args:
        net: ``"Unet"``, ``"FPN"`` or ``"PSPNet"`` (the smp class name).
        ENCODER: smp encoder name.
        ENCODER_WEIGHTS: Pretrained encoder weights, or ``None``.
        ACTIVATION: Output activation passed to smp.

    Returns:
        The constructed model.

    Raises:
        ValueError: for an unknown ``net``. Previously this surfaced as an
            UnboundLocalError on ``model``.
    """
    if net not in ("Unet", "FPN", "PSPNet"):
        raise ValueError(
            f"Unknown net {net!r}; expected 'Unet', 'FPN' or 'PSPNet'")
    return getattr(smp, net)(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        classes=4,
        activation=ACTIVATION,
    )
Example #22
0
    def __init__(self,
                 encoder,
                 encoder_weights,
                 classes,
                 activation,
                 learning_rate=1e-3,
                 **kwargs):
        """Build the smp segmentation network and its Dice loss.

        Args:
            encoder: smp encoder name.
            encoder_weights: Pretrained encoder weights, or ``None``.
            classes: Sequence of class labels. The model gets
                ``len(classes)`` output channels.
            activation: Output activation passed to smp.
            learning_rate: Not used directly here; presumably recorded by
                ``save_hyperparameters()`` for later use — confirm.
            **kwargs: Extra hyperparameters. ``architecture`` (``'fpn'``,
                ``'pan'`` or ``'pspnet'``) is read back from ``self.hparams``;
                presumably it is passed here and captured by
                ``save_hyperparameters()`` — confirm with callers.

        Raises:
            NameError: if the architecture is not supported. This exception
                type is kept for callers that already catch it.
        """
        super().__init__()
        self.save_hyperparameters()

        self.classes = classes

        smp_class_by_arch = {'fpn': 'FPN', 'pan': 'PAN', 'pspnet': 'PSPNet'}
        architecture = self.hparams.architecture
        if architecture not in smp_class_by_arch:
            raise NameError(
                f"Unsupported architecture {architecture!r}; "
                f"expected one of {sorted(smp_class_by_arch)}")
        self.model = getattr(smp, smp_class_by_arch[architecture])(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=len(classes),
            activation=activation,
        )

        self.loss = smp.utils.losses.DiceLoss()
Example #23
0
def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into ``True``, so boolean flags are parsed explicitly here.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def _build_model(model_name, encoder):
    """Build a 4-class smp network with ImageNet encoder weights and no activation.

    Raises:
        ValueError: if ``model_name`` is not unet, fpn, pspnet or linknet.
    """
    architectures = {
        'unet': smp.Unet,
        'fpn': smp.FPN,
        'pspnet': smp.PSPNet,
        'linknet': smp.Linknet,
    }
    try:
        architecture = architectures[model_name]
    except KeyError:
        raise ValueError('unknown --model {!r}; expected one of {}'.format(
            model_name, sorted(architectures))) from None
    return architecture(encoder_name=encoder,
                        encoder_weights='imagenet',
                        classes=4,
                        activation=None)


def _tta_wrap(model, merge):
    """Wrap ``model`` in flip/rotate test-time augmentation merged with ``merge``."""
    return tta.SegmentationTTAWrapper(
        model,
        tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Rotate90(angles=[0, 180])
        ]),
        merge_mode=merge)


def _resize_to_submission(array):
    """Resize a 2-D array to 350x525 (rows x cols) unless it already has that shape."""
    if array.shape != (350, 525):
        # cv2 takes dsize as (width, height).
        array = cv2.resize(array,
                           dsize=(525, 350),
                           interpolation=cv2.INTER_LINEAR)
    return array


def _grid_search_class_params(model, valid_dataset, valid_loader):
    """Pick a per-class (threshold, min_size) pair maximising mean validation Dice.

    Returns:
        dict mapping class index 0-3 to ``(threshold, min_size)``.
    """
    runner = SupervisedRunner()
    print("INFERRING ON VALID")
    runner.infer(
        model=model,
        loaders={'valid': valid_loader},
        callbacks=[InferCallback()],
        verbose=True,
    )

    # One entry per (sample, class) pair, indexed as sample * 4 + class.
    valid_masks = []
    probabilities = np.zeros((4 * len(valid_dataset), 350, 525))
    logits = runner.callbacks[0].predictions["logits"]
    for i, (batch, output) in enumerate(tqdm(zip(valid_dataset, logits))):
        _, mask = batch
        valid_masks.extend(_resize_to_submission(m) for m in mask)
        for j, probability in enumerate(output):
            probabilities[(i * 4) + j, :, :] = _resize_to_submission(probability)

    print("RUNNING GRID SEARCH")
    class_params = {}
    for class_id in range(4):
        print(class_id)
        attempts = []
        for t in range(30, 70, 5):
            threshold = t / 100
            # 17500 continues the 2500-step sequence; the previous 175000 was
            # presumably a typo (it is ~95% of the 350x525 image).
            for min_size in [7500, 10000, 12500, 15000, 17500]:
                dices = []
                for probability, true_mask in zip(probabilities[class_id::4],
                                                  valid_masks[class_id::4]):
                    predict, _ = post_process(sigmoid(probability), threshold,
                                              min_size)
                    # Two empty masks count as a perfect match.
                    if predict.sum() == 0 and true_mask.sum() == 0:
                        dices.append(1)
                    else:
                        dices.append(dice(predict, true_mask))
                attempts.append((threshold, min_size, np.mean(dices)))

        attempts_df = pd.DataFrame(attempts,
                                   columns=['threshold', 'size', 'dice'])
        attempts_df = attempts_df.sort_values('dice', ascending=False)
        print(attempts_df.head())
        class_params[class_id] = (attempts_df['threshold'].values[0],
                                  attempts_df['size'].values[0])
    return class_params


def main():
    """Run test-set inference and write an RLE submission CSV.

    With ``--optimize``, per-class (threshold, min_size) post-processing
    parameters are first grid-searched on the validation split (optionally
    with test-time augmentation via ``--tta_pre``); otherwise every class
    uses ``--thresh`` / ``--min_size``. ``--tta_post`` applies test-time
    augmentation to the final test inference.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder', type=str, default='efficientnet-b0')
    parser.add_argument('--model', type=str, default='unet')
    # Required: torch.load(None) fails with an unhelpful error.
    parser.add_argument('--loc', type=str, required=True)
    parser.add_argument('--data_folder', type=str, default='../input/')
    parser.add_argument('--batch_size', type=int, default=2)
    # type=bool would parse "--optimize False" as True (bool("False") is True).
    parser.add_argument('--optimize', type=_str2bool, default=False)
    parser.add_argument('--tta_pre', type=_str2bool, default=False)
    parser.add_argument('--tta_post', type=_str2bool, default=False)
    parser.add_argument('--merge', type=str, default='mean')
    parser.add_argument('--min_size', type=int, default=10000)
    parser.add_argument('--thresh', type=float, default=0.5)
    # Required: DataFrame.to_csv(None) returns the CSV as a string instead of
    # writing a file, which would silently discard the predictions.
    parser.add_argument('--name', type=str, required=True)
    args = parser.parse_args()

    model = _build_model(args.model, args.encoder)
    preprocessing_fn = smp.encoders.get_preprocessing_fn(args.encoder, 'imagenet')

    test_df = prepare_dataset(get_dataset(train=False))
    test_ids = test_df['Image_Label'].apply(
        lambda x: x.split('_')[0]).drop_duplicates().values
    test_dataset = CloudDataset(
        df=test_df,
        datatype='test',
        img_ids=test_ids,
        transforms=valid1(),
        preprocessing=get_preprocessing(preprocessing_fn))
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    val_df = prepare_dataset(get_dataset(train=True))
    _, val_ids = get_train_test(val_df)
    valid_dataset = CloudDataset(
        df=val_df,
        datatype='train',
        img_ids=val_ids,
        transforms=valid1(),
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False)

    # map_location lets a GPU-saved checkpoint load on a CPU-only host;
    # load_state_dict copies the tensors into the freshly built model anyway.
    checkpoint = torch.load(args.loc, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    if args.optimize:
        print("OPTIMIZING")
        print(args.tta_pre)
        opt_model = _tta_wrap(model, args.merge) if args.tta_pre else model
        class_params = _grid_search_class_params(opt_model, valid_dataset,
                                                 valid_loader)
        del opt_model
    else:
        class_params = {class_id: (args.thresh, args.min_size)
                        for class_id in range(4)}
    # Free validation-time arrays and runners before the test pass.
    gc.collect()

    if args.tta_post:
        model = _tta_wrap(model, args.merge)
    print(args.tta_post)

    runner = SupervisedRunner()
    runner.infer(
        model=model,
        loaders={'test': test_loader},
        callbacks=[InferCallback()],
        verbose=True,
    )

    encoded_pixels = []
    # Running (image, class) counter; image_id % 4 is the class channel.
    image_id = 0
    for image in tqdm(runner.callbacks[0].predictions['logits']):
        for prob in image:
            threshold, min_size = class_params[image_id % 4]
            predict, num_predict = post_process(
                sigmoid(_resize_to_submission(prob)), threshold, min_size)
            encoded_pixels.append('' if num_predict == 0 else mask2rle(predict))
            image_id += 1

    test_df['EncodedPixels'] = encoded_pixels
    test_df.to_csv(args.name, columns=['Image_Label', 'EncodedPixels'], index=False)
Example #24
0
    # Training setup fragment:
    #   - load the pickled training tables;
    #   - build a PSPNet whose first encoder conv takes `in_channels` inputs;
    #   - optionally restore a checkpoint;
    #   - move the model to CUDA and create an Adam optimizer.
    print('==> Loading data..')
    # NOTE(review): pd.read_pickle unpickles arbitrary objects -- only load trusted files.
    data = pd.read_pickle('data/data_train.pkl')
    data_met = pd.read_pickle('data/data_train_met.pkl')
    dataset = CloudDataset(data, data_met)
    # Shuffled every epoch; drop_last=True discards a final partial batch.
    train_loader = DataLoader(dataset,
                              batch_size=Cfg.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=Cfg.num_workers)
    ###############Load Data###################################################

    ###############Building Model##############################################
    print('==> Building model..')
    import segmentation_models_pytorch as smp
    # presumably the per-sample channel count produced by CloudDataset -- TODO confirm
    in_channels = 46
    # No encoder_weights argument is passed, so smp's default for PSPNet applies.
    cloud2cloud = smp.PSPNet(encoder_name='vgg19_bn', classes=4)
    # Replace the encoder's first layer with a conv that accepts `in_channels`
    # inputs. The 64 outputs / 3x3 kernel / stride 1 / padding 1 presumably
    # mirror the stock VGG stem (confirm); note bias=False.
    cloud2cloud.encoder.features[0] = torch.nn.Conv2d(in_channels=in_channels,
                                                      out_channels=64,
                                                      kernel_size=(3, 3),
                                                      stride=(1, 1),
                                                      padding=(1, 1),
                                                      bias=False)

    # Loaded after the stem swap, so the checkpoint is expected to come from
    # this modified (46-input-channel, bias-free stem) model.
    if Cfg.checkpoint:
        cloud2cloud.load_state_dict(torch.load(Cfg.checkpoint))

    # Requires a CUDA device.
    cloud2cloud = cloud2cloud.cuda()
    ###############Building Model##############################################

    ###############Building Optim##############################################
    optim = torch.optim.Adam(cloud2cloud.parameters(), lr=Cfg.lr)
Example #25
0
def get_t_net_model():
    """Return a freshly constructed 3-class PSPNet on a ResNet-50 encoder."""
    t_net = smp.PSPNet(encoder_name='resnet50', classes=3)
    return t_net
Example #26
0
 # Set system logger
 system_logger = get_logger(name='train', 
                            file_path=os.path.join(PERFORMANCE_RECORD_DIR, 'train_log.log'))
 
 # Unet / PSPNet / DeepLabV3Plus
 if MODEL == 'unet':
     model = smp.Unet(
         encoder_name=ENCODER, 
         encoder_weights=ENCODER_WEIGHTS, 
         classes=len(CLASSES), 
         activation=ACTIVATION,
     )
 elif MODEL == 'pspnet':
     model = smp.PSPNet(
         encoder_name=ENCODER, 
         encoder_weights=ENCODER_WEIGHTS, 
         classes=len(CLASSES), 
         activation=ACTIVATION,
     )
 elif MODEL == 'deeplabv3plus':
     model = smp.DeepLabV3Plus(
         encoder_name=ENCODER, 
         encoder_weights=ENCODER_WEIGHTS, 
         classes=len(CLASSES), 
         activation=ACTIVATION,
     )
 elif  MODEL == 'pannet':
     model = smp.PAN(
         encoder_name=ENCODER,
         encoder_weights=ENCODER_WEIGHTS,
         classes=len(CLASSES),
         activation=ACTIVATION,
def get_model(num_classes):
    """Return a default-encoder PSPNet with ``num_classes`` outputs.

    The model is switched to training mode and moved to the module-level
    ``device`` (resolved at call time).
    """
    # nn.Module.train() returns the module itself, so the calls chain.
    return smp.PSPNet(classes=num_classes).train().to(device)
Example #28
0
def get_t_net_model(freeze_encoder=False):
    """Return a 3-class PSPNet with a ResNet-50 encoder.

    Args:
        freeze_encoder: when True, set ``requires_grad = False`` on every
            encoder parameter so the encoder receives no gradients. Defaults
            to False, matching the previous behaviour (where this freezing
            code was commented out).

    Returns:
        The constructed model.
    """
    model = smp.PSPNet('resnet50', classes=3)
    if freeze_encoder:
        for param in model.encoder.parameters():
            param.requires_grad = False
    return model
Example #29
0
  


    num_classes = 16
    model_name = arg.model
    learning_rate = arg.l_rate
    num_epochs = arg.n_epoch
    batch_size = arg.batch_size

    history = collections.defaultdict(list)
    # Factories rather than instances: only the selected network is built and
    # moved to `device`. The previous dict literal constructed all three
    # models (and moved each to `device`) although only one is ever used.
    model_factories = {
        'unet': lambda: UNet(num_classes=num_classes),
        'segnet': lambda: segnet(n_classes=num_classes),
        'pspnet': lambda: smp.PSPNet(classes=num_classes),
    }
    # An unknown model_name still raises KeyError, as before.
    net = model_factories[model_name]().train().to(device)

    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print("using multi gpu")
        # Use at most the first four GPUs as before, but never name a device
        # that does not exist: a fixed [0, 1, 2, 3] fails on 2- or 3-GPU hosts.
        net = torch.nn.DataParallel(net, device_ids=list(range(min(gpu_count, 4))))
    else:
        print('using one gpu')

    # if True:
    #     print("The ckp has been loaded sucessfully ")
    #     net = torch.load("./model/unet_2019-07-23.pth") # load the pretrained model
    criterion = FocalLoss2d().to(device)
    train_loader, val_loader = get_dataset_loaders(5, batch_size)
    opt = torch.optim.SGD(net.parameters(), lr=learning_rate)
Example #30
0
def get_model(config):
    """Build the segmentation model described by ``config`` and optionally load weights.

    Reads from ``config.MODEL``:
        ARCHITECTURE: one of 'unet', 'fpn', 'pan', 'pspnet', 'deeplabv3',
            'linknet'.
        BACKBONE, ENCODER_PRETRAINED_FROM, IN_CHANNELS, ACTIVATION: passed to
            every architecture.
        UNET_DECODER_CHANNELS, UNET_ENABLE_DECODER_SCSE: unet only.
        FPN_DECODER_DROPOUT: fpn only.
        PSPNET_DROPOUT: pspnet only.
        WEIGHT: optional checkpoint path; falsy or 'none' skips loading.
        DEVICE: device the model is finally moved to.
    The number of output classes is ``len(config.INPUT.CLASSES)``.

    The model is wrapped in ``torch.nn.DataParallel`` before weights are
    loaded, so the checkpoint's keys must match the wrapped model (i.e. carry
    the ``module.`` prefix). Weights are loaded onto the CPU, then the model
    is moved to ``config.MODEL.DEVICE``.

    Returns:
        The ``DataParallel``-wrapped model on ``config.MODEL.DEVICE``.

    Raises:
        ValueError: if ``config.MODEL.ARCHITECTURE`` is not supported.
    """
    model_cfg = config.MODEL
    arch = model_cfg.ARCHITECTURE
    # Constructor arguments shared by every supported smp architecture.
    common = dict(
        encoder_name=model_cfg.BACKBONE,
        encoder_weights=model_cfg.ENCODER_PRETRAINED_FROM,
        in_channels=model_cfg.IN_CHANNELS,
        classes=len(config.INPUT.CLASSES),
        activation=model_cfg.ACTIVATION,
    )

    if arch == 'unet':
        model = smp.Unet(
            decoder_channels=model_cfg.UNET_DECODER_CHANNELS,
            decoder_attention_type='scse' if model_cfg.UNET_ENABLE_DECODER_SCSE else None,
            **common
        )
    elif arch == 'fpn':
        model = smp.FPN(decoder_dropout=model_cfg.FPN_DECODER_DROPOUT, **common)
    elif arch == 'pan':
        model = smp.PAN(**common)
    elif arch == 'pspnet':
        model = smp.PSPNet(psp_dropout=model_cfg.PSPNET_DROPOUT, **common)
    elif arch == 'deeplabv3':
        model = smp.DeepLabV3(**common)
    elif arch == 'linknet':
        model = smp.Linknet(**common)
    else:
        raise ValueError(
            "Unsupported config.MODEL.ARCHITECTURE {!r}; expected one of "
            "'unet', 'fpn', 'pan', 'pspnet', 'deeplabv3', 'linknet'".format(arch))

    model = torch.nn.DataParallel(model)

    weight_path = model_cfg.WEIGHT
    if weight_path and weight_path != 'none':
        state_dict = torch.load(weight_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)

    return model.to(model_cfg.DEVICE)