Ejemplo n.º 1
0
    def __init__(
        self,
        encoder_name='resnet34',
        encoder_weights='imagenet',
        decoder_pyramid_channels=256,
        decoder_segmentation_channels=128,
        classes=1,
        dropout=0.2,
        activation='sigmoid',
    ):
        """Build an encoder plus ``FPNDecoder`` and pass both to the base class.

        Args:
            encoder_name: Backbone identifier. Names containing
                ``'efficientnet'`` are loaded via
                ``EfficientNetEncoder.from_pretrained``; all others go
                through ``get_encoder``.
            encoder_weights: Forwarded to ``get_encoder`` only; it is not
                consulted on the EfficientNet path.
            decoder_pyramid_channels: Decoder ``pyramid_channels``.
            decoder_segmentation_channels: Decoder ``segmentation_channels``.
            classes: Decoder ``final_channels`` and output size of the extra
                ``self.linear`` layer.
            dropout: Forwarded to the decoder.
            activation: Forwarded unchanged to the base-class constructor.
        """
        uses_efficientnet = 'efficientnet' in encoder_name
        if not uses_efficientnet:
            backbone = get_encoder(encoder_name,
                                   encoder_weights=encoder_weights)
        else:
            backbone = EfficientNetEncoder.from_pretrained(encoder_name)

        fpn_decoder = FPNDecoder(
            encoder_channels=backbone.out_shapes,
            pyramid_channels=decoder_pyramid_channels,
            segmentation_channels=decoder_segmentation_channels,
            final_channels=classes,
            dropout=dropout,
        )

        super().__init__(backbone, fpn_decoder, activation)

        # Extra linear layer sized from the first entry of the encoder's
        # ``out_shapes``; it is registered after the base class is set up.
        first_stage_width = backbone.out_shapes[0]
        self.linear = torch.nn.Linear(first_stage_width, classes)
        self.name = f'fpn-{encoder_name}'
Ejemplo n.º 2
0
    def __init__(
        self,
        encoder_name='resnet34',
        encoder_weights='imagenet',
        decoder_use_batchnorm=True,
        decoder_channels=(256, 128, 64, 32, 16),
        classes=1,
        activation='sigmoid',
        center=False,  # useful for VGG models
        use_oc_module=False,
    ):
        """Build an encoder plus ``UnetDecoder`` and pass both to the base class.

        Args:
            encoder_name: Backbone identifier. Names containing
                ``'efficientnet'`` are loaded via
                ``EfficientNetEncoder.from_pretrained``; all others go
                through ``get_encoder``.
            encoder_weights: Forwarded to ``get_encoder`` only.
            decoder_use_batchnorm: Decoder ``use_batchnorm`` flag.
            decoder_channels: Forwarded to the decoder.
            classes: Decoder ``final_channels``.
            activation: Forwarded unchanged to the base-class constructor.
            center: Forwarded to the decoder.
            use_oc_module: Accepted but not read anywhere in this constructor.
        """
        uses_efficientnet = 'efficientnet' in encoder_name
        if not uses_efficientnet:
            backbone = get_encoder(encoder_name,
                                   encoder_weights=encoder_weights)
        else:
            backbone = EfficientNetEncoder.from_pretrained(encoder_name)

        unet_decoder = UnetDecoder(
            encoder_channels=backbone.out_shapes,
            decoder_channels=decoder_channels,
            final_channels=classes,
            use_batchnorm=decoder_use_batchnorm,
            center=center,
        )

        super().__init__(backbone, unet_decoder, activation)

        self.name = f'u-{encoder_name}'
Ejemplo n.º 3
0
    def __init__(self,
                 encoder_name='resnet34',
                 encoder_weights='imagenet',
                 decoder_use_batchnorm=True,
                 decoder_channels=(256, 128, 64, 32, 16),
                 classes=1,
                 activation='sigmoid',
                 center=False,  # useful for VGG models
                 ):
        """Build an encoder plus ``UnetSCSEDecoder`` and pass both to the base class.

        Args:
            encoder_name: Backbone identifier passed to ``get_encoder``.
            encoder_weights: Forwarded to ``get_encoder``.
            decoder_use_batchnorm: Decoder ``use_batchnorm`` flag.
            decoder_channels: Forwarded to the decoder.
            classes: Decoder ``final_channels``.
            activation: Forwarded unchanged to the base-class constructor.
            center: Forwarded to the decoder.
        """
        backbone = get_encoder(encoder_name, encoder_weights=encoder_weights)

        scse_decoder = UnetSCSEDecoder(
            encoder_channels=backbone.out_shapes,
            decoder_channels=decoder_channels,
            final_channels=classes,
            use_batchnorm=decoder_use_batchnorm,
            center=center,
        )

        super().__init__(backbone, scse_decoder, activation)

        self.name = f'uscse-{encoder_name}'
Ejemplo n.º 4
0
    def __init__(self, cfg, input_shape):
        """Set up a transformer encoder, a convolutional encoder and projections.

        All hyper-parameters are hard-coded in the body; ``cfg`` and
        ``input_shape`` are accepted but not read here.

        Args:
            cfg: Unused in this constructor.
            input_shape: Unused in this constructor.
        """
        super().__init__()

        # Transformer branch configuration.
        self.emb_dim = 768
        self.pretrained = True
        self.pretrained_trans_model = 'vit_base_patch16_384'
        self.patch_size = 16

        self.transformer = Transformer_Encoder(
            pretrained=self.pretrained,
            img_size=384,
            pretrained_model=self.pretrained_trans_model,
            patch_size=self.patch_size,
            embed_dim=self.emb_dim,
            depth=12,
            num_heads=12,
            qkv_bias=True,
        )

        # Convolutional branch configuration.
        self.encoder_name = 'timm-efficientnet-b5'
        self.in_channels = 3
        self.encoder_depth = 5
        self.encoder_weights = 'noisy-student'

        self.conv_encoder = get_encoder(
            self.encoder_name,
            in_channels=self.in_channels,
            depth=self.encoder_depth,
            weights=self.encoder_weights,
        )

        self.conv_channels = self.conv_encoder.out_channels
        # One stride-2 3x3 conv per encoder stage except the first, each
        # mapping that stage's channel count to ``emb_dim``.
        stage_projections = [
            nn.Conv2d(stage_channels, self.emb_dim, 3, stride=2, padding=1)
            for stage_channels in self.conv_channels[1:]
        ]
        self.conv_final = nn.ModuleList(stage_projections)

        # "p2" .. "p6"
        self.names = [f"p{level}" for level in range(2, 7)]
        self.resize = Resize((384, 384))
        self.Wq = nn.Linear(self.emb_dim, self.emb_dim, bias=False)
        self.Wk = nn.Linear(self.emb_dim, self.emb_dim, bias=False)
Ejemplo n.º 5
0
    def __init__(
            self,
            encoder_name='resnet34',
            encoder_weights='imagenet',
            decoder_use_batchnorm=True,
            decoder_channels=(256, 128, 64, 32, 16),
            classes=1,
            activation='sigmoid',
            center=False,  # useful for VGG models
            pretrained=None):
        """Build an encoder plus ``UnetDecoder`` and optionally load a checkpoint.

        Args:
            encoder_name: Backbone identifier passed to ``get_encoder``.
            encoder_weights: Forwarded to ``get_encoder``.
            decoder_use_batchnorm: Decoder ``use_batchnorm`` flag.
            decoder_channels: Forwarded to the decoder.
            classes: Decoder ``final_channels``.
            activation: Forwarded unchanged to the base-class constructor.
            center: Forwarded to the decoder.
            pretrained: When truthy, passed to ``torch.load``; the
                ``'model_state_dict'`` entry of the result is loaded into
                this model.
        """
        backbone = get_encoder(encoder_name, encoder_weights=encoder_weights)

        unet_decoder = UnetDecoder(
            encoder_channels=backbone.out_shapes,
            decoder_channels=decoder_channels,
            final_channels=classes,
            use_batchnorm=decoder_use_batchnorm,
            center=center,
        )

        super().__init__(backbone, unet_decoder, activation)

        if pretrained:
            # NOTE(review): torch.load unpickles the file, so only trusted
            # checkpoints should be passed here.
            loaded = torch.load(pretrained)
            state_dict = loaded['model_state_dict']
            self.load_state_dict(state_dict)
            print("\n********************************************")
            print(f"Loaded checkpoint: {pretrained}")

        self.name = f'u-{encoder_name}'
Ejemplo n.º 6
0
    def __init__(
            self,
            encoder_name='resnet34',
            encoder_weights='imagenet',
            group_norm=True,
            decoder_channels=(256, 128, 64, 32, 16),
            classes=1,
            activation='sigmoid',
            center='none',  # useful for VGG models
            attention_type=None,
            reslink=False,
            multi_task=False):
        """Build an encoder plus ``UnetDecoder`` and pass both to the base class.

        The chosen configuration is printed to stdout before the encoder is
        created.

        Args:
            encoder_name: Backbone identifier passed to ``get_encoder``.
            encoder_weights: Forwarded to ``get_encoder``.
            group_norm: Forwarded to the decoder.
            decoder_channels: Forwarded to the decoder.
            classes: Decoder ``final_channels``.
            activation: Forwarded unchanged to the base-class constructor.
            center: One of ``'none'``, ``'normal'`` or ``'aspp'``.
            attention_type: One of ``'none'``, ``'cbam'`` or ``'scse'``.
                ``None`` (the default) is treated as ``'none'``.
            reslink: Forwarded to the decoder.
            multi_task: Forwarded to the decoder.

        Raises:
            AssertionError: If ``center`` or ``attention_type`` is not one of
                the accepted values.
        """
        # The default used to be rejected by the check below, so a plain
        # no-argument construction always failed. Map it to the explicit
        # "no attention" value instead.
        if attention_type is None:
            attention_type = 'none'

        assert center in ['none', 'normal', 'aspp'], \
            "center must be one of 'none', 'normal', 'aspp'"
        assert attention_type in ['none', 'cbam', 'scse'], \
            "attention_type must be one of 'none', 'cbam', 'scse'"

        print("**" * 50)
        print("Encoder name: \t\t{}".format(encoder_name))
        print("Center: \t\t{}".format(center))
        print("Attention type: \t\t{}".format(attention_type))
        print("Reslink: \t\t{}".format(reslink))

        encoder = get_encoder(encoder_name, encoder_weights=encoder_weights)

        decoder = UnetDecoder(encoder_channels=encoder.out_shapes,
                              decoder_channels=decoder_channels,
                              final_channels=classes,
                              group_norm=group_norm,
                              center=center,
                              attention_type=attention_type,
                              reslink=reslink,
                              multi_task=multi_task)

        super().__init__(encoder, decoder, activation)

        self.name = 'vnet-{}'.format(encoder_name)
Ejemplo n.º 7
0
 def __init__(
         self,
         encoder_name='resnet34',
         encoder_weights='imagenet',
         classes=4,
         activation=None,
 ):
     """Create an encoder followed by global average pooling and a linear layer.

     Args:
         encoder_name: Backbone identifier passed to ``get_encoder``; also
             stored as ``self.name``.
         encoder_weights: Forwarded to ``get_encoder``.
         classes: Output size of the final linear layer.
         activation: Accepted but not read in this constructor.
     """
     super().__init__()
     self.encoder = get_encoder(encoder_name, encoder_weights=encoder_weights)
     self.name = encoder_name
     self.avg_pool = nn.AdaptiveAvgPool2d(1)
     # The linear layer's input width is the first entry of the encoder's
     # ``out_shapes``.
     feature_width = self.encoder.out_shapes[0]
     self.fc = nn.Linear(feature_width, classes)
Ejemplo n.º 8
0
    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: str = "imagenet",
        decoder_use_batchnorm: bool = True,
        decoder_channels: List[int] = (256, 128, 64, 32, 16),
        decoder_attention_type: Optional[str] = None,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        pretext_classes=-1,
        domain_classes=2,
    ):
        """Build a U-Net with extra domain and pretext classification heads.

        Args:
            encoder_name: Backbone identifier passed to ``get_encoder``.
            encoder_depth: Encoder ``depth`` and decoder ``n_blocks``.
            encoder_weights: Forwarded to ``get_encoder`` as ``weights``.
            decoder_use_batchnorm: Decoder ``use_batchnorm`` flag.
            decoder_channels: Forwarded to the decoder; its last entry is the
                segmentation head's input width.
            decoder_attention_type: Decoder ``attention_type``.
            in_channels: Forwarded to ``get_encoder``.
            classes: Output channels of the segmentation head.
            activation: Forwarded to the segmentation head.
            pretext_classes: Forwarded to ``PretextClassifier``. The default
                ``-1`` is a "not configured" sentinel and is rejected.
            domain_classes: Forwarded to ``DomainClassifier``.

        Raises:
            ValueError: If ``pretext_classes`` is left at ``-1``.
        """
        super().__init__()

        # Fail fast: reject the sentinel before any encoder or decoder is
        # built, instead of constructing the network only to throw it away.
        if pretext_classes == -1:
            raise ValueError('initialize pretext_classes')

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )

        self.decoder = UnetDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
            center=encoder_name.startswith("vgg"),
            attention_type=decoder_attention_type,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        # The domain head reads the second-to-last encoder stage.
        self.domain_layer = -2
        self.domain_classification_head = DomainClassifier(
            in_channels=self.encoder.out_channels[self.domain_layer],
            domain_classes=domain_classes)

        # The pretext head reads the last encoder stage.
        self.pretext_classification_head = PretextClassifier(
            in_channels=self.encoder.out_channels[-1],
            pretext_classes=pretext_classes)

        self.name = "u-{}".format(encoder_name)
        self.initialize()
    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        encoder_output_stride: int = 16,
        decoder_channels: int = 256,
        decoder_atrous_rates: tuple = (12, 24, 36),
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[str] = None,
        upsampling: int = 4,
        aux_params: Optional[dict] = None,
    ):
        """Build encoder, ``DeepLabV3PlusDecoder`` and the output heads.

        Args:
            encoder_name: Backbone identifier passed to ``get_encoder``.
            encoder_depth: Forwarded to ``get_encoder`` as ``depth``.
            encoder_weights: Forwarded to ``get_encoder`` as ``weights``.
            encoder_output_stride: Must be 8 or 16; forwarded to both the
                encoder and the decoder.
            decoder_channels: Decoder ``out_channels``.
            decoder_atrous_rates: Decoder ``atrous_rates``.
            in_channels: Forwarded to ``get_encoder``.
            classes: Output channels of the segmentation head.
            activation: Forwarded to the segmentation head.
            upsampling: Forwarded to the segmentation head.
            aux_params: When not ``None``, keyword arguments for an extra
                ``ClassificationHead`` fed by the last encoder stage.

        Raises:
            ValueError: If ``encoder_output_stride`` is neither 8 nor 16.
        """
        super().__init__()

        stride_is_supported = encoder_output_stride in (8, 16)
        if not stride_is_supported:
            message = "Encoder output stride should be 8 or 16, got {}".format(
                encoder_output_stride)
            raise ValueError(message)

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            output_stride=encoder_output_stride,
        )

        self.decoder = DeepLabV3PlusDecoder(
            encoder_channels=self.encoder.out_channels,
            out_channels=decoder_channels,
            atrous_rates=decoder_atrous_rates,
            output_stride=encoder_output_stride,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=self.decoder.out_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=upsampling,
        )

        # The auxiliary classifier is optional.
        if aux_params is None:
            self.classification_head = None
        else:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params)
    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_pyramid_channels: int = 256,
        decoder_segmentation_channels: int = 128,
        decoder_merge_policy: str = "add",
        decoder_dropout: float = 0.2,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[str] = None,
        upsampling: int = 4,
        aux_params: Optional[dict] = None,
    ):
        """Build encoder, ``FPNDecoder`` and the output heads, then initialize.

        Args:
            encoder_name: Backbone identifier passed to ``get_encoder``.
            encoder_depth: Forwarded to both the encoder and the decoder.
            encoder_weights: Forwarded to ``get_encoder`` as ``weights``.
            decoder_pyramid_channels: Decoder ``pyramid_channels``.
            decoder_segmentation_channels: Decoder ``segmentation_channels``.
            decoder_merge_policy: Decoder ``merge_policy``.
            decoder_dropout: Decoder ``dropout``.
            in_channels: Forwarded to ``get_encoder``.
            classes: Output channels of the segmentation head.
            activation: Forwarded to the segmentation head.
            upsampling: Forwarded to the segmentation head.
            aux_params: When not ``None``, keyword arguments for an extra
                ``ClassificationHead`` fed by the last encoder stage.
        """
        super().__init__()

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )

        self.decoder = FPNDecoder(
            encoder_channels=self.encoder.out_channels,
            encoder_depth=encoder_depth,
            pyramid_channels=decoder_pyramid_channels,
            segmentation_channels=decoder_segmentation_channels,
            dropout=decoder_dropout,
            merge_policy=decoder_merge_policy,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=self.decoder.out_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=upsampling,
        )

        # The auxiliary classifier is optional.
        if aux_params is None:
            self.classification_head = None
        else:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params)

        self.name = f"fpn-{encoder_name}"
        self.initialize()
Ejemplo n.º 11
0
    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: str = "imagenet",
        decoder_use_batchnorm: bool = True,
        decoder_channels: List[int] = (256, 128, 64, 32, 16),
        in_channels: int = 3,
    ):
        """Build a U-Net backbone with four regression-style heads.

        Args:
            encoder_name: Backbone identifier passed to ``get_encoder``.
            encoder_depth: Encoder ``depth`` and decoder ``n_blocks``.
            encoder_weights: Forwarded to ``get_encoder`` as ``weights``.
            decoder_use_batchnorm: Decoder ``use_batchnorm`` flag.
            decoder_channels: Forwarded to the decoder; its last entry is the
                input width of the height and magnitude heads.
            in_channels: Forwarded to ``get_encoder``.
        """
        super().__init__()

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )

        self.decoder = UnetDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
            center=encoder_name.startswith("vgg"),
            attention_type=None,
        )

        # Two-channel head fed by the last encoder stage.
        self.xydir_head = EncoderRegressionHead(
            in_channels=self.encoder.out_channels[-1],
            out_channels=2,
        )

        # Two single-channel heads fed by the last decoder stage.
        decoder_width = decoder_channels[-1]
        self.height_head = RegressionHead(
            in_channels=decoder_width,
            out_channels=1,
            kernel_size=3,
        )
        self.mag_head = RegressionHead(
            in_channels=decoder_width,
            out_channels=1,
            kernel_size=3,
        )

        self.scale_head = ScaleHead()

        self.name = f"u-{encoder_name}"
        self.initialize()
    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_weights: Optional[str] = "imagenet",
        encoder_depth: int = 3,
        psp_out_channels: int = 512,
        psp_use_batchnorm: bool = True,
        psp_dropout: float = 0.2,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        upsampling: int = 8,
        aux_params: Optional[dict] = None,
    ):
        """Build encoder, ``PSPDecoder`` and the output heads, then initialize.

        Args:
            encoder_name: Backbone identifier passed to ``get_encoder``.
            encoder_weights: Forwarded to ``get_encoder`` as ``weights``.
            encoder_depth: Forwarded to ``get_encoder`` as ``depth``.
            psp_out_channels: Decoder ``out_channels`` and segmentation-head
                input width.
            psp_use_batchnorm: Decoder ``use_batchnorm`` flag.
            psp_dropout: Decoder ``dropout``.
            in_channels: Forwarded to ``get_encoder``.
            classes: Output channels of the segmentation head.
            activation: Forwarded to the segmentation head.
            upsampling: Forwarded to the segmentation head.
            aux_params: When truthy, keyword arguments for an extra
                ``ClassificationHead`` fed by the last encoder stage. An
                empty dict counts as "no auxiliary head".
        """
        super().__init__()

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )

        self.decoder = PSPDecoder(
            encoder_channels=self.encoder.out_channels,
            use_batchnorm=psp_use_batchnorm,
            out_channels=psp_out_channels,
            dropout=psp_dropout,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=psp_out_channels,
            out_channels=classes,
            kernel_size=3,
            activation=activation,
            upsampling=upsampling,
        )

        # Truthiness check on purpose: both None and {} skip the head.
        if not aux_params:
            self.classification_head = None
        else:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )

        self.name = f"psp-{encoder_name}"
        self.initialize()
Ejemplo n.º 13
0
  def __init__(self, cfg, input_shape):
    """Create a fixed EfficientNet encoder plus one projection conv per stage.

    All hyper-parameters are hard-coded; ``cfg`` and ``input_shape`` are
    accepted but not read here.

    Args:
        cfg: Unused in this constructor.
        input_shape: Unused in this constructor.
    """
    super().__init__()

    self.encoder = get_encoder(
        'timm-efficientnet-b5',
        in_channels=3,
        depth=5,
        weights='noisy-student',
    )
    self.channels = self.encoder.out_channels

    # One stride-2 3x3 conv per encoder stage, each producing 256 channels.
    stage_projections = [
        nn.Conv2d(stage_channels, 256, 3, stride=2, padding=1)
        for stage_channels in self.channels
    ]
    self.conv = nn.ModuleList(stage_projections)

    # "p1" .. "p6"
    self.names = [f"p{level}" for level in range(1, 7)]
Ejemplo n.º 14
0
    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_use_batchnorm: bool = True,
        decoder_channels: List[int] = (256, 128, 64, 32, 16),
        decoder_attention_type: Optional[str] = None,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        aux_params: Optional[dict] = None,
    ):
        """Build encoder, ``UnetDecoder`` and the output heads, then initialize.

        Args:
            encoder_name: Backbone identifier passed to ``get_encoder``. Names
                starting with ``"vgg"`` switch on the decoder's ``center``.
            encoder_depth: Encoder ``depth`` and decoder ``n_blocks``.
            encoder_weights: Forwarded to ``get_encoder`` as ``weights``.
            decoder_use_batchnorm: Decoder ``use_batchnorm`` flag.
            decoder_channels: Forwarded to the decoder; its last entry is the
                segmentation head's input width.
            decoder_attention_type: Decoder ``attention_type``.
            in_channels: Forwarded to ``get_encoder``.
            classes: Output channels of the segmentation head.
            activation: Forwarded to the segmentation head.
            aux_params: When not ``None``, keyword arguments for an extra
                ``ClassificationHead`` fed by the last encoder stage.
        """
        super().__init__()

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )

        self.decoder = UnetDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
            center=encoder_name.startswith("vgg"),
            attention_type=decoder_attention_type,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        # The auxiliary classifier is optional.
        if aux_params is None:
            self.classification_head = None
        else:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params)

        self.name = f"u-{encoder_name}"
        self.initialize()
Ejemplo n.º 15
0
    def __init__(self, encoder_name, classes=5, *args, **kwargs):
        r'''Build the nested (U-Net++ style) up-sampling grid on top of an encoder.

        Args:
            encoder_name: name of classification model (without last dense layers) used as feature
                extractor to build segmentation model. Passed to ``get_encoder``.
            classes: a number of classes for output; used as the output width of the four
                segmentation heads and of the 1x1 ``deepsupervise`` convolution that merges them.
            *args: accepted for call compatibility; not used by this constructor.
            **kwargs: accepted for call compatibility; not used by this constructor.

        Raises:
            ValueError: if ``encoder.out_channels[1:]`` does not contain exactly five entries
                (raised by the tuple unpacking below).

        Architecture:

        00-->01-->02-->03-->04
         \  /   /    /    /
          10-->11-->12-->13
           \  /   /    /
            20-->21---22
             \  /    /
              30---->31
               \   /
                 40
        '''
        # The docstring is a raw string: the diagram contains backslashes that
        # would otherwise be parsed as (invalid) escape sequences.
        super(UnetPP, self).__init__()
        self.encoder: EncoderMixin = get_encoder(encoder_name)

        # Channel widths of encoder stages 1..5; stage 0 is skipped.
        zero, first, second, third, fourth = self.encoder.out_channels[1:]

        # Node "uRC" sits in row R, column C of the diagram. Its input width is
        # C copies of the row's own width plus the width of the row below.
        self.u01 = UpSample([zero + first, zero])
        self.u02 = UpSample([zero * 2 + first, zero])
        self.u03 = UpSample([zero * 3 + first, zero])
        self.u04 = UpSample([zero * 4 + first, zero])
        self.u11 = UpSample([first + second, first])
        self.u12 = UpSample([first * 2 + second, first])
        self.u13 = UpSample([first * 3 + second, first])
        self.u21 = UpSample([second + third, second])
        self.u22 = UpSample([second * 2 + third, second])
        self.u31 = UpSample([third + fourth, third])

        # Four heads on top-row features, merged by a 1x1 convolution.
        self.cl1 = SegmentHead(zero, classes)
        self.cl2 = SegmentHead(zero, classes)
        self.cl3 = SegmentHead(zero, classes)
        self.cl4 = SegmentHead(zero, classes)
        self.deepsupervise = nn.Conv2d(classes * 4, classes, 1)
        initialize_decoder(self)
Ejemplo n.º 16
0
    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_weights: Optional[str] = "imagenet",
        encoder_output_stride: int = 16,
        decoder_channels: int = 32,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        upsampling: int = 4,
        aux_params: Optional[dict] = None,
    ):
        """Build encoder, ``PANDecoder`` and the output heads, then initialize.

        Args:
            encoder_name: Backbone identifier passed to ``get_encoder``.
            encoder_weights: Forwarded to ``get_encoder`` as ``weights``.
            encoder_output_stride: Must be 16 or 32; forwarded to the encoder.
            decoder_channels: Decoder width and segmentation-head input width.
            in_channels: Forwarded to ``get_encoder``.
            classes: Output channels of the segmentation head.
            activation: Forwarded to the segmentation head.
            upsampling: Forwarded to the segmentation head.
            aux_params: When not ``None``, keyword arguments for an extra
                ``ClassificationHead`` fed by the last encoder stage.

        Raises:
            ValueError: If ``encoder_output_stride`` is neither 16 nor 32.
        """
        super().__init__()

        stride_is_supported = encoder_output_stride in (16, 32)
        if not stride_is_supported:
            message = "PAN support output stride 16 or 32, got {}".format(encoder_output_stride)
            raise ValueError(message)

        # The encoder depth is fixed at 5 for this model.
        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=5,
            weights=encoder_weights,
            output_stride=encoder_output_stride,
        )

        self.decoder = PANDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=3,
            upsampling=upsampling,
        )

        # The auxiliary classifier is optional.
        if aux_params is None:
            self.classification_head = None
        else:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params)

        self.name = f"pan-{encoder_name}"
        self.initialize()
Ejemplo n.º 17
0
    def __init__(
        self,
        encoder_name: str = "timm-efficientnet-b5",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_channels: List[int] = (256, 128, 64, 32, 16),
        squeeze_ratio: int = 1,
        expansion_ratio: int = 1,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        aux_params: Optional[dict] = None,
    ):
        """Build encoder, ``EfficientUnetPlusPlusDecoder`` and the output heads.

        Args:
            encoder_name: Backbone identifier passed to ``get_encoder``.
            encoder_depth: Encoder ``depth`` and decoder ``n_blocks``.
            encoder_weights: Forwarded to ``get_encoder`` as ``weights``.
            decoder_channels: Forwarded to the decoder; its last entry is the
                segmentation head's input width.
            squeeze_ratio: Forwarded to the decoder.
            expansion_ratio: Forwarded to the decoder.
            in_channels: Forwarded to ``get_encoder``.
            classes: Output channels of the segmentation head; also stored as
                ``self.classes``.
            activation: Forwarded to the segmentation head.
            aux_params: When not ``None``, keyword arguments for an extra
                ``ClassificationHead`` fed by the last encoder stage.
        """
        super().__init__()
        self.classes = classes

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )

        self.decoder = EfficientUnetPlusPlusDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            squeeze_ratio=squeeze_ratio,
            expansion_ratio=expansion_ratio,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        # The auxiliary classifier is optional.
        if aux_params is None:
            self.classification_head = None
        else:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params)

        self.name = f"EfficientUNet++-{encoder_name}"
        self.initialize()
Ejemplo n.º 18
0
    def __init__(self, lam=0.5, dense_head='5x5', combination=None):
        """Set up the transform, truncated encoder, projection and loss modules.

        Args:
            lam: Stored as ``self.Lambda``.
            dense_head: Passed to ``ProjectionModule``.
            combination: When not ``None``, passed to ``CombiTransform``;
                otherwise ``CombiTransform`` is built with no arguments.
        """
        super().__init__()

        self.Lambda = lam

        self.transformer = (
            CombiTransform()
            if combination is None
            else CombiTransform(combination)
        )

        # NOTE(review): the last positional argument is None here; it looks
        # like the weights slot (an 'imagenet' variant was previously tried)
        # -- confirm against get_encoder's signature.
        self.encoder = get_encoder('resnet18', 3, 5, None)
        # The encoder's ``layer4`` is replaced by a pass-through.
        self.encoder.layer4 = torch.nn.Identity()

        self.project = ProjectionModule(dense_head)

        self.dense_loss = DenseContrastiveLoss()
        self.glob_loss = GlobalContrastiveLoss()
Ejemplo n.º 19
0
 def __init__(
     self,
     encoder_name: str = "resnet34",
     encoder_depth: int = 5,
     encoder_weights: str = "imagenet",
     in_channels: int = 3,
     classes: int = 1,
     aux_params: Optional[dict] = None,
 ):
     """Build an encoder followed by a ``ClassificationHead``.

     Args:
         encoder_name: Backbone identifier passed to ``get_encoder``.
         encoder_depth: Forwarded to ``get_encoder`` as ``depth``.
         encoder_weights: Forwarded to ``get_encoder`` as ``weights``.
         in_channels: Forwarded to ``get_encoder``.
         classes: Stored as ``self.num_classes`` and used as the head's
             default ``classes`` value.
         aux_params: Optional extra keyword arguments for
             ``ClassificationHead``. Entries here (including ``classes``)
             override the defaults; ``None`` means "no overrides".
     """
     super().__init__()
     self.encoder = get_encoder(
         encoder_name,
         in_channels=in_channels,
         depth=encoder_depth,
         weights=encoder_weights,
     )
     self.num_classes = classes

     # ``aux_params`` defaults to None, which cannot be **-unpacked. Start
     # from the ``classes`` argument and overlay any caller overrides.
     # NOTE(review): assumes ClassificationHead accepts a ``classes``
     # keyword, as in segmentation_models_pytorch -- confirm.
     head_params = {"classes": classes}
     if aux_params:
         head_params.update(aux_params)

     self.classification_head = ClassificationHead(
         in_channels=self.encoder.out_channels[-1], **head_params)
     self.name = "c-{}".format(encoder_name)
     init.initialize_head(self.classification_head)
    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_channels: int = 256,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[str] = None,
        upsampling: int = 8,
        aux_params: Optional[dict] = None,
    ):
        """Build encoder, ``DeepLabV3Decoder`` and the output heads.

        Args:
            encoder_name: Backbone identifier passed to ``get_encoder``.
            encoder_depth: Forwarded to ``get_encoder`` as ``depth``.
            encoder_weights: Forwarded to ``get_encoder`` as ``weights``.
            decoder_channels: Decoder ``out_channels``.
            in_channels: Forwarded to ``get_encoder``.
            classes: Output channels of the segmentation head.
            activation: Forwarded to the segmentation head.
            upsampling: Forwarded to the segmentation head.
            aux_params: When not ``None``, keyword arguments for an extra
                ``ClassificationHead`` fed by the last encoder stage.
        """
        super().__init__()

        # The encoder is always requested with an output stride of 8.
        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            output_stride=8,
        )

        # The decoder consumes only the last encoder stage.
        self.decoder = DeepLabV3Decoder(
            in_channels=self.encoder.out_channels[-1],
            out_channels=decoder_channels,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=self.decoder.out_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=upsampling,
        )

        # The auxiliary classifier is optional.
        if aux_params is None:
            self.classification_head = None
        else:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params)
    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_use_batchnorm: bool = True,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        aux_params: Optional[dict] = None,
    ):
        """Build encoder, ``LinknetDecoder`` and the output heads, then initialize.

        Args:
            encoder_name: Backbone identifier passed to ``get_encoder``.
            encoder_depth: Encoder ``depth`` and decoder ``n_blocks``.
            encoder_weights: Forwarded to ``get_encoder`` as ``weights``.
            decoder_use_batchnorm: Decoder ``use_batchnorm`` flag.
            in_channels: Forwarded to ``get_encoder``.
            classes: Output channels of the segmentation head.
            activation: Forwarded to the segmentation head.
            aux_params: When not ``None``, keyword arguments for an extra
                ``ClassificationHead`` fed by the last encoder stage.
        """
        super().__init__()

        # Decoder output width and segmentation-head input width share the
        # same fixed value.
        prefinal_width = 32

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )

        self.decoder = LinknetDecoder(
            encoder_channels=self.encoder.out_channels,
            n_blocks=encoder_depth,
            prefinal_channels=prefinal_width,
            use_batchnorm=decoder_use_batchnorm,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=prefinal_width,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
        )

        # The auxiliary classifier is optional.
        if aux_params is None:
            self.classification_head = None
        else:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params)

        self.name = f"link-{encoder_name}"
        self.initialize()
Ejemplo n.º 22
0
    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: str = "imagenet",
        decoder_use_batchnorm: bool = True,
        decoder_channels: List[int] = (256, 128, 64, 32, 16),
        decoder_attention_type: Optional[str] = None,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        pretext_classes=-1,
        domain_classes=2,
        domain_layer=-2,
        domain_classifier='DomainClassifierFlatten',
        separate=False,
        input_shape=80,
    ):
        """Build a U-Net with a configurable domain head and a pretext head.

        Args:
            encoder_name: Backbone identifier passed to the encoder factory.
            encoder_depth: Encoder ``depth`` and decoder ``n_blocks``.
            encoder_weights: Forwarded to the encoder factory as ``weights``.
            decoder_use_batchnorm: Decoder ``use_batchnorm`` flag.
            decoder_channels: Forwarded to the decoder; its last entry is the
                segmentation head's input width.
            decoder_attention_type: Decoder ``attention_type``.
            in_channels: Forwarded to the encoder factory.
            classes: Output channels of the segmentation head.
            activation: Forwarded to the segmentation head.
            pretext_classes: Forwarded to ``PretextClassifier``. The default
                ``-1`` is a "not configured" sentinel and is rejected.
            domain_classes: Forwarded to the domain classifier.
            domain_layer: Index, or list of indices, into
                ``encoder.out_channels`` selecting the stage(s) that feed the
                domain classifier. For a list, the channel counts are summed.
            domain_classifier: One of the names ``'DomainClassifier'``,
                ``'DomainClassifierFlatten'``,
                ``'DomainClassifierFlattenSimple'``,
                ``'DomainClassifierFlattenCat'``,
                ``'DomainClassifierReduceFlatten'``, or a non-string callable
                that is used as the classifier factory directly.
            separate: ``False`` selects ``get_encoder``; anything else selects
                ``get_separate_encoder``.
            input_shape: Forwarded to the domain classifier only when
                ``domain_layer`` is not a list.

        Raises:
            ValueError: If ``pretext_classes`` is left at ``-1`` or
                ``domain_classifier`` is an unrecognized name.
        """
        super().__init__()

        # Fail fast on configuration errors, before any network is built.
        if pretext_classes == -1:
            raise ValueError('initialize pretext_classes')

        if domain_classifier == 'DomainClassifier':
            domain_classifier = DomainClassifier
        elif domain_classifier == 'DomainClassifierFlatten':
            domain_classifier = DomainClassifierFlatten
        elif domain_classifier == 'DomainClassifierFlattenSimple':
            domain_classifier = DomainClassifierFlattenSimple
        elif domain_classifier == 'DomainClassifierFlattenCat':
            domain_classifier = DomainClassifierFlattenCat
        elif domain_classifier == 'DomainClassifierReduceFlatten':
            domain_classifier = DomainClassifierReduceFlatten
        elif isinstance(domain_classifier, str):
            # Previously an unknown name stayed a string and failed later
            # with "'str' object is not callable".
            raise ValueError(
                'unknown domain_classifier: {!r}'.format(domain_classifier))

        # ``is False`` is kept on purpose so that only the literal False
        # picks the shared encoder, as before.
        encoder_factory = (
            get_encoder if separate is False else get_separate_encoder
        )
        self.encoder = encoder_factory(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )
        self.separate = separate

        self.decoder = UnetDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
            center=encoder_name.startswith("vgg"),
            attention_type=decoder_attention_type,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        self.domain_layer = domain_layer

        if isinstance(self.domain_layer, list):
            # Several stages feed the head: its input width is the sum of
            # their channel counts. ``input_shape`` is not passed here.
            self.domain_classification_head = domain_classifier(
                in_channels=sum(
                    self.encoder.out_channels[idx_dl]
                    for idx_dl in self.domain_layer
                ),
                domain_classes=domain_classes)
            if separate:
                self.encoder.set_domain_layer(self.domain_layer)
        else:
            self.domain_classification_head = domain_classifier(
                in_channels=self.encoder.out_channels[self.domain_layer],
                domain_classes=domain_classes,
                input_shape=input_shape,
            )

        # The pretext head reads the last encoder stage.
        self.pretext_classification_head = PretextClassifier(
            in_channels=self.encoder.out_channels[-1],
            pretext_classes=pretext_classes)

        self.name = "u-{}".format(encoder_name)
        self.initialize()