def __init__( self, encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", decoder_use_batchnorm: bool = True, decoder_channels: List[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, in_channels: int = 3, classes: int = 1, activation: Optional[Union[str, callable]] = None, aux_params: Optional[dict] = None, ): super().__init__() encoder_depth = len(decoder_channels) self.encoder = get_encoder( encoder_name, in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, ) self.decoder = UnetPlusPlusDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, n_blocks=encoder_depth, use_batchnorm=decoder_use_batchnorm, center=True if encoder_name.startswith("vgg") else False, attention_type=decoder_attention_type, ) # self.segmentation_head = SegmentationHead_deepsuper( # in_channels=decoder_channels[-1], # out_channels=classes, # activation=activation, # kernel_size=3, # ) self.segmentation_head = SegmentationHead( in_channels=decoder_channels[-1], out_channels=classes, activation=activation, kernel_size=3, ) if aux_params is not None: self.classification_head = ClassificationHead( in_channels=self.encoder.out_channels[-1], **aux_params ) else: self.classification_head = None self.name = "unetplusplus-{}".format(encoder_name) self.initialize()
def __init__( self, encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", decoder_pyramid_channels: int = 256, decoder_segmentation_channels: int = 128, decoder_merge_policy: str = "add", decoder_dropout: float = 0.2, in_channels: int = 3, classes: int = 1, activation: Optional[str] = None, upsampling: int = 4, aux_params: Optional[dict] = None, ): super().__init__() self.encoder = get_encoder( encoder_name, in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, ) self.decoder = FPNDecoder( encoder_channels=self.encoder.out_channels, encoder_depth=encoder_depth, pyramid_channels=decoder_pyramid_channels, segmentation_channels=decoder_segmentation_channels, dropout=decoder_dropout, merge_policy=decoder_merge_policy, ) self.segmentation_head = SegmentationHead( in_channels=self.decoder.out_channels, out_channels=classes, activation=activation, kernel_size=1, upsampling=upsampling, ) if aux_params is not None: self.classification_head = ClassificationHead( in_channels=self.encoder.out_channels[-1], **aux_params) else: self.classification_head = None self.name = "fpn-{}".format(encoder_name) self.initialize()
def __init__( self, encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", encoder_output_stride: int = 16, decoder_channels: int = 256, decoder_atrous_rates: tuple = (12, 24, 36), in_channels: int = 3, classes: int = 1, activation: Optional[str] = None, upsampling: int = 4, aux_params: Optional[dict] = None, ): super().__init__() if encoder_output_stride not in [8, 16]: raise ValueError( "Encoder output stride should be 8 or 16, got {}".format( encoder_output_stride)) self.encoder = get_encoder( encoder_name, in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, output_stride=encoder_output_stride, ) self.decoder = DeepLabV3PlusDecoder( encoder_channels=self.encoder.out_channels, out_channels=decoder_channels, atrous_rates=decoder_atrous_rates, output_stride=encoder_output_stride, ) self.segmentation_head = SegmentationHead( in_channels=self.decoder.out_channels, out_channels=classes, activation=activation, kernel_size=1, upsampling=upsampling, ) if aux_params is not None: self.classification_head = ClassificationHead( in_channels=self.encoder.out_channels[-1], **aux_params) else: self.classification_head = None
def __init__( self, encoder_name: str = "resnet34", encoder_weights: Optional[str] = "imagenet", encoder_depth: int = 3, psp_out_channels: int = 512, psp_use_batchnorm: bool = True, psp_dropout: float = 0.2, in_channels: int = 3, classes: int = 1, activation: Optional[Union[str, callable]] = None, upsampling: int = 8, aux_params: Optional[dict] = None, ): super().__init__() self.encoder = get_encoder( encoder_name, in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, ) self.decoder = PSPDecoder( encoder_channels=self.encoder.out_channels, use_batchnorm=psp_use_batchnorm, out_channels=psp_out_channels, dropout=psp_dropout, ) self.segmentation_head = SegmentationHead( in_channels=psp_out_channels, out_channels=classes, kernel_size=3, activation=activation, upsampling=upsampling, ) if aux_params: self.classification_head = ClassificationHead( in_channels=self.encoder.out_channels[-1], **aux_params ) else: self.classification_head = None self.name = "psp-{}".format(encoder_name) self.initialize()
def __init__( self, encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: str = "imagenet", decoder_use_batchnorm: bool = True, decoder_channels: List[int] = (256, 128, 64, 32), decoder_attention_type: Optional[str] = None, in_channels: int = 3, classes: int = 1, activation: Optional[Union[str, callable]] = None, aux_params: Optional[dict] = None, ): super().__init__() self.encoder = get_encoder( encoder_name, in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, ) self.decoder = UnetDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, n_blocks=encoder_depth - 1, use_batchnorm=decoder_use_batchnorm, center=True, # attention\conv\Identity attention_type=decoder_attention_type, ) self.segmentation_head = SegmentationHead( in_channels=decoder_channels[-1], out_channels=classes, activation=activation, kernel_size=3, ) if aux_params is not None: self.classification_head = ClassificationHead( in_channels=self.encoder.out_channels[-1], **aux_params) else: self.classification_head = None self.name = "u-{}".format(encoder_name) self.initialize()
def __init__( self, encoder_name: str = "timm-efficientnet-b5", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", decoder_channels: List[int] = (256, 128, 64, 32, 16), squeeze_ratio: int = 1, expansion_ratio: int = 1, in_channels: int = 3, classes: int = 1, activation: Optional[Union[str, callable]] = None, aux_params: Optional[dict] = None, ): super().__init__() self.classes = classes self.encoder = get_encoder( encoder_name, in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, ) self.decoder = EfficientUnetPlusPlusDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, n_blocks=encoder_depth, squeeze_ratio=squeeze_ratio, expansion_ratio=expansion_ratio, ) self.segmentation_head = SegmentationHead( in_channels=decoder_channels[-1], out_channels=classes, activation=activation, kernel_size=3, ) if aux_params is not None: self.classification_head = ClassificationHead( in_channels=self.encoder.out_channels[-1], **aux_params) else: self.classification_head = None self.name = "EfficientUNet++-{}".format(encoder_name) self.initialize()
def __init__( self, encoder_name: str = "resnet34", encoder_weights: Optional[str] = "imagenet", encoder_output_stride: int = 16, decoder_channels: int = 32, in_channels: int = 3, classes: int = 1, activation: Optional[Union[str, callable]] = None, upsampling: int = 4, aux_params: Optional[dict] = None, ): super().__init__() if encoder_output_stride not in [16, 32]: raise ValueError("PAN support output stride 16 or 32, got {}".format(encoder_output_stride)) self.encoder = get_encoder( encoder_name, in_channels=in_channels, depth=5, weights=encoder_weights, output_stride=encoder_output_stride, ) self.decoder = PANDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, ) self.segmentation_head = SegmentationHead( in_channels=decoder_channels, out_channels=classes, activation=activation, kernel_size=3, upsampling=upsampling, ) if aux_params is not None: self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) else: self.classification_head = None self.name = "pan-{}".format(encoder_name) self.initialize()
def __init__( self, encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: str = "imagenet", in_channels: int = 3, classes: int = 1, aux_params: Optional[dict] = None, ): super().__init__() self.encoder = get_encoder( encoder_name, in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, ) self.num_classes = classes self.classification_head = ClassificationHead( in_channels=self.encoder.out_channels[-1], **aux_params) self.name = "c-{}".format(encoder_name) init.initialize_head(self.classification_head)
def __init__( self, encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", decoder_use_batchnorm: bool = True, in_channels: int = 3, classes: int = 1, activation: Optional[Union[str, callable]] = None, aux_params: Optional[dict] = None, ): super().__init__() self.encoder = get_encoder( encoder_name, in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, ) self.decoder = LinknetDecoder( encoder_channels=self.encoder.out_channels, n_blocks=encoder_depth, prefinal_channels=32, use_batchnorm=decoder_use_batchnorm, ) self.segmentation_head = SegmentationHead(in_channels=32, out_channels=classes, activation=activation, kernel_size=1) if aux_params is not None: self.classification_head = ClassificationHead( in_channels=self.encoder.out_channels[-1], **aux_params) else: self.classification_head = None self.name = "link-{}".format(encoder_name) self.initialize()
def __init__( self, encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", decoder_channels: int = 256, in_channels: int = 3, classes: int = 1, activation: Optional[str] = None, upsampling: int = 8, aux_params: Optional[dict] = None, ): super().__init__() self.encoder = get_encoder( encoder_name, in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, output_stride=8, ) self.decoder = DeepLabV3Decoder( in_channels=self.encoder.out_channels[-1], out_channels=decoder_channels, ) self.segmentation_head = SegmentationHead( in_channels=self.decoder.out_channels, out_channels=classes, activation=activation, kernel_size=1, upsampling=upsampling, ) if aux_params is not None: self.classification_head = ClassificationHead( in_channels=self.encoder.out_channels[-1], **aux_params) else: self.classification_head = None
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=1)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=bs, shuffle=True, num_workers=1)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bs, shuffle=False, num_workers=1)

model = smp.DeepLabV3Plus(encoder_name='resnet101').to(device)
# Attach a classification head sized to the current cluster count.
model.classification_head = ClassificationHead(
    in_channels=model.encoder.out_channels[-1], classes=clusters[i])
model = model.to(device)

# Fine-tune from scratch
lr = 0.001
max_val_dice = -1
bce_loss = nn.BCEWithLogitsLoss()
ce_loss = nn.CrossEntropyLoss()
adam_optimizer = Adam(model.parameters(), lr=lr)  # , weight_decay=1e-6
m = 0
s = time.time()

for epoch in range(50):
    train(epoch, model, 1)
    max_val_dice, m = validate(model, model_name, max_val_dice, 1, m)
    if m >= 5:  # early stopping: no validation improvement for 5 epochs
        break
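# Hedged sketch (not in the source): train() and validate() are defined elsewhere in
# the script. Since the loop above stops once m reaches 5, validate() presumably
# returns an updated (best dice, epochs-without-improvement) pair. The version below
# is hypothetical: dice_score is defined here only for illustration, the fourth
# positional argument's meaning is unknown (kept as `flag`), and the checkpoint path
# is an assumption.
def dice_score(pred, target, eps=1e-7):
    # Threshold logits and compute a Dice coefficient over the batch.
    pred = (torch.sigmoid(pred) > 0.5).float()
    inter = (pred * target).sum()
    return ((2 * inter + eps) / (pred.sum() + target.sum() + eps)).item()

def validate(model, model_name, max_val_dice, flag, m):
    model.eval()
    scores = []
    with torch.no_grad():
        for images, masks in val_loader:
            preds = model(images.to(device))
            if isinstance(preds, tuple):   # (masks, labels) when a classification head is attached
                preds = preds[0]
            scores.append(dice_score(preds, masks.to(device)))
    val_dice = sum(scores) / len(scores)
    if val_dice > max_val_dice:            # improvement: reset the patience counter and save
        max_val_dice, m = val_dice, 0
        torch.save(model.state_dict(), "{}.pth".format(model_name))
    else:                                  # no improvement: count towards the m >= 5 stop condition
        m += 1
    return max_val_dice, m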