Exemplo n.º 1
0
 def __init__(self, n_dimensions, n_channels, n_classes, base_filters, final_convolution_layer, residualConnections = False):
     """Assemble a four-level U-Net: input block, encoder, bottleneck, decoder, output block.

     Channel widths double at each downsampling step, from ``base_filters``
     at the input level to ``16 * base_filters`` at the bottleneck, and are
     halved again by each upsampling step.

     Args:
         n_dimensions: Forwarded to the base-class constructor.
         n_channels: Forwarded to the base-class constructor; the input
             block reads the resulting ``self.n_channels``.
         n_classes: Number of output channels of the final block.
         base_filters: Channel width of the shallowest level.
         final_convolution_layer: Forwarded to the base-class constructor;
             the output block reads the resulting
             ``self.final_convolution_layer``.
         residualConnections: Passed as ``res`` to the input, encoding,
             decoding and output blocks.
     """
     super(unet, self).__init__(n_dimensions, n_channels, n_classes, base_filters, final_convolution_layer)

     # Layer factories prepared by the base class.
     conv, dropout, norm = self.Conv, self.Dropout, self.InstanceNorm
     res = residualConnections

     # widths[k] is the channel count at depth k (0 = input level, 4 = bottleneck).
     widths = [base_filters, base_filters*2, base_filters*4, base_filters*8, base_filters*16]

     self.ins = in_conv(self.n_channels, widths[0], conv, dropout, norm, res=res)

     # Encoder: ds_{k-1} doubles the width, en_k keeps it at depth k.
     # Submodules are created in the order ds_0, en_1, ds_1, en_2, ...
     for level in range(1, 5):
         setattr(self, "ds_%d" % (level - 1),
                 DownsamplingModule(widths[level - 1], widths[level], conv, dropout, norm))
         setattr(self, "en_%d" % level,
                 EncodingModule(widths[level], widths[level], conv, dropout, norm, res=res))

     # Decoder: us_k halves the width. de_k takes twice its output width as input,
     # presumably the upsampled map concatenated with the matching skip
     # connection (forward() is not visible here). There is no de_0; the
     # output block takes widths[1] channels instead.
     for level in range(3, -1, -1):
         setattr(self, "us_%d" % level,
                 UpsamplingModule(widths[level + 1], widths[level], conv, dropout, norm))
         if level > 0:
             setattr(self, "de_%d" % level,
                     DecodingModule(widths[level + 1], widths[level], conv, dropout, norm, res=res))

     self.out = out_conv(widths[1], n_classes, conv, dropout, norm,
                         final_convolution_layer=self.final_convolution_layer, res=res)
Exemplo n.º 2
0
    def __init__(
        self,
        parameters: dict,
        residualConnections=False,
    ):
        """Build a four-level U-Net from a parameter dictionary.

        Channel widths start at ``self.base_filters`` and double at each of
        the four downsampling steps, reaching ``16 * self.base_filters`` at the
        bottleneck. They are then halved by each upsampling step in the decoder.

        Args:
            parameters: Configuration dictionary forwarded to the base-class
                constructor. This method also reads ``parameters["patch_size"]``
                and, on error, ``parameters["model"]["architecture"]``.
            residualConnections: Stored as ``{"res": residualConnections}`` in
                ``self.network_kwargs``. That dict is passed to the input,
                encoding, decoding and output blocks.

        Raises:
            SystemExit: if ``checkPatchDivisibility(parameters["patch_size"])``
                returns a falsy value.
        """
        # Assigned before the base-class constructor runs. The base class may
        # read it, so do not move it after super() (TODO confirm).
        self.network_kwargs = {"res": residualConnections}
        super(unet, self).__init__(parameters)

        if not checkPatchDivisibility(parameters["patch_size"]):
            # sys.exit() accepts at most one argument. Passing the architecture
            # as a second argument raised TypeError instead of exiting with
            # this message, so it is joined into a single string.
            sys.exit(
                "The patch size is not divisible by 16, which is required for "
                + str(parameters["model"]["architecture"])
            )

        # Input block: n_channels -> base_filters.
        self.ins = in_conv(
            input_channels=self.n_channels,
            output_channels=self.base_filters,
            conv=self.Conv,
            dropout=self.Dropout,
            norm=self.Norm,
            network_kwargs=self.network_kwargs,
        )

        # Encoder: each DownsamplingModule doubles the width, and the
        # EncodingModule that follows keeps it.
        self.ds_0 = DownsamplingModule(
            input_channels=self.base_filters,
            output_channels=self.base_filters * 2,
            conv=self.Conv,
            norm=self.Norm,
        )
        self.en_1 = EncodingModule(
            input_channels=self.base_filters * 2,
            output_channels=self.base_filters * 2,
            conv=self.Conv,
            dropout=self.Dropout,
            norm=self.Norm,
            network_kwargs=self.network_kwargs,
        )
        self.ds_1 = DownsamplingModule(
            input_channels=self.base_filters * 2,
            output_channels=self.base_filters * 4,
            conv=self.Conv,
            norm=self.Norm,
        )
        self.en_2 = EncodingModule(
            input_channels=self.base_filters * 4,
            output_channels=self.base_filters * 4,
            conv=self.Conv,
            dropout=self.Dropout,
            norm=self.Norm,
            network_kwargs=self.network_kwargs,
        )
        self.ds_2 = DownsamplingModule(
            input_channels=self.base_filters * 4,
            output_channels=self.base_filters * 8,
            conv=self.Conv,
            norm=self.Norm,
        )
        self.en_3 = EncodingModule(
            input_channels=self.base_filters * 8,
            output_channels=self.base_filters * 8,
            conv=self.Conv,
            dropout=self.Dropout,
            norm=self.Norm,
            network_kwargs=self.network_kwargs,
        )
        self.ds_3 = DownsamplingModule(
            input_channels=self.base_filters * 8,
            output_channels=self.base_filters * 16,
            conv=self.Conv,
            norm=self.Norm,
        )
        # Bottleneck at 16 * base_filters.
        self.en_4 = EncodingModule(
            input_channels=self.base_filters * 16,
            output_channels=self.base_filters * 16,
            conv=self.Conv,
            dropout=self.Dropout,
            norm=self.Norm,
            network_kwargs=self.network_kwargs,
        )

        # Decoder: each UpsamplingModule halves the width. Each DecodingModule
        # takes twice its output width as input, presumably the upsampled map
        # concatenated with the matching skip connection (forward() is not
        # visible here).
        self.us_3 = UpsamplingModule(
            input_channels=self.base_filters * 16,
            output_channels=self.base_filters * 8,
            conv=self.Conv,
            interpolation_mode=self.linear_interpolation_mode,
        )
        self.de_3 = DecodingModule(
            input_channels=self.base_filters * 16,
            output_channels=self.base_filters * 8,
            conv=self.Conv,
            norm=self.Norm,
            network_kwargs=self.network_kwargs,
        )
        self.us_2 = UpsamplingModule(
            input_channels=self.base_filters * 8,
            output_channels=self.base_filters * 4,
            conv=self.Conv,
            interpolation_mode=self.linear_interpolation_mode,
        )
        self.de_2 = DecodingModule(
            input_channels=self.base_filters * 8,
            output_channels=self.base_filters * 4,
            conv=self.Conv,
            norm=self.Norm,
            network_kwargs=self.network_kwargs,
        )
        self.us_1 = UpsamplingModule(
            input_channels=self.base_filters * 4,
            output_channels=self.base_filters * 2,
            conv=self.Conv,
            interpolation_mode=self.linear_interpolation_mode,
        )
        self.de_1 = DecodingModule(
            input_channels=self.base_filters * 4,
            output_channels=self.base_filters * 2,
            conv=self.Conv,
            norm=self.Norm,
            network_kwargs=self.network_kwargs,
        )
        self.us_0 = UpsamplingModule(
            input_channels=self.base_filters * 2,
            output_channels=self.base_filters,
            conv=self.Conv,
            interpolation_mode=self.linear_interpolation_mode,
        )
        # de_0 keeps 2 * base_filters channels; out_conv then maps them to
        # n_classes.
        self.de_0 = DecodingModule(
            input_channels=self.base_filters * 2,
            output_channels=self.base_filters * 2,
            conv=self.Conv,
            norm=self.Norm,
            network_kwargs=self.network_kwargs,
        )
        self.out = out_conv(
            input_channels=self.base_filters * 2,
            output_channels=self.n_classes,
            conv=self.Conv,
            norm=self.Norm,
            network_kwargs=self.network_kwargs,
            final_convolution_layer=self.final_convolution_layer,
            sigmoid_input_multiplier=self.sigmoid_input_multiplier,
        )
Exemplo n.º 3
0
    def __init__(
        self,
        parameters: dict,
        residualConnections=False,
    ):
        """Build a light multi-level U-Net in which every level has ``base_filters`` channels.

        The requested number of levels is ``parameters["model"]["depth"]``. When
        missing, it is set to 4 in place and a message is printed. The number
        actually built is whatever ``checkPatchDimensions`` reports the patch
        size can support.

        Args:
            parameters: Configuration dictionary forwarded to the base-class
                constructor. ``parameters["model"]`` may be modified in place.
            residualConnections: Stored as ``{"res": residualConnections}`` in
                ``self.network_kwargs``. That dict is passed to the input,
                decoding, encoding and output blocks.

        Raises:
            SystemExit: if ``checkPatchDimensions`` reports fewer than 2
                usable levels.
        """
        self.network_kwargs = {"res": residualConnections}
        super(light_unet_multilayer, self).__init__(parameters)

        model_cfg = parameters["model"]
        if "depth" not in model_cfg:
            model_cfg["depth"] = 4
            print("Default depth set to 4.")

        usable_depth = checkPatchDimensions(
            parameters["patch_size"], numlay=model_cfg["depth"]
        )

        # Fewer than two usable levels is fatal. Any other shortfall only
        # warns, and the network is built with the usable depth.
        if usable_depth < 2:
            sys.exit(
                "The patch size is not large enough for desired depth. It is expected that each dimension of the patch size is divisible by 2^i, where i is in a integer greater than or equal to 2."
            )
        elif usable_depth != model_cfg["depth"]:
            print(
                "The patch size is not large enough for desired depth. It is expected that each dimension of the patch size is divisible by 2^i, where i is in a integer greater than or equal to 2. Only the first %d layers will run."
                % usable_depth
            )

        self.num_layers = usable_depth

        self.ins = in_conv(
            input_channels=self.n_channels,
            output_channels=self.base_filters,
            conv=self.Conv,
            dropout=self.Dropout,
            norm=self.Norm,
            network_kwargs=self.network_kwargs,
        )

        self.ds, self.en, self.us, self.de = (
            ModuleList([]),
            ModuleList([]),
            ModuleList([]),
            ModuleList([]),
        )

        width = self.base_filters
        for _level in range(self.num_layers):
            # Within each level, modules are created in ds, us, de, en order.
            down = DownsamplingModule(
                input_channels=width,
                output_channels=width,
                conv=self.Conv,
                norm=self.Norm,
            )
            self.ds.append(down)

            up = UpsamplingModule(
                input_channels=width,
                output_channels=width,
                conv=self.Conv,
                interpolation_mode=self.linear_interpolation_mode,
            )
            self.us.append(up)

            # Input is twice the level width, presumably the upsampled map
            # plus a skip connection (forward() is not visible here).
            decode = DecodingModule(
                input_channels=width * 2,
                output_channels=width,
                conv=self.Conv,
                norm=self.Norm,
                network_kwargs=self.network_kwargs,
            )
            self.de.append(decode)

            encode = EncodingModule(
                input_channels=width,
                output_channels=width,
                conv=self.Conv,
                dropout=self.Dropout,
                norm=self.Norm,
                network_kwargs=self.network_kwargs,
            )
            self.en.append(encode)

        self.out = out_conv(
            input_channels=width,
            output_channels=self.n_classes,
            conv=self.Conv,
            norm=self.Norm,
            network_kwargs=self.network_kwargs,
            final_convolution_layer=self.final_convolution_layer,
            sigmoid_input_multiplier=self.sigmoid_input_multiplier,
        )
Exemplo n.º 4
0
    def __init__(
        self,
        parameters: dict,
    ):
        """Build a UNETR-style network: a transformer encoder with convolutional decoder branches.

        Missing entries in ``parameters["model"]`` are filled in place with
        defaults, and a message is printed for each one applied:
        ``inner_patch_size`` (first entry of ``parameters["patch_size"]``),
        ``num_heads`` (12) and ``embed_dim`` (768).

        Args:
            parameters: Configuration dictionary forwarded to the base-class
                constructor. This method reads ``parameters["patch_size"]`` and
                ``parameters["model"]``.

        Raises:
            SystemExit: in any of these cases:
                - the inner patch size is not a power of two;
                - ``checkPatchDimensions`` reports fewer than 2 usable levels;
                - the inner patch size does not evenly divide every image
                  dimension, or exceeds one of them.
        """
        super(unetr, self).__init__(parameters)

        if "inner_patch_size" not in parameters["model"]:
            parameters["model"]["inner_patch_size"] = parameters["patch_size"][0]
            print("Default inner patch size set to %d." % parameters["patch_size"][0])

        # The key is guaranteed to exist here, so the former
        # `if "inner_patch_size" in ...` guard was always true and is dropped.
        # A power of two has an integral log2, so ceil and floor agree.
        inner_log2 = np.log2(parameters["model"]["inner_patch_size"])
        if np.ceil(inner_log2) != np.floor(inner_log2):
            sys.exit("The inner patch size must be a power of 2.")

        self.patch_size = parameters["model"]["inner_patch_size"]
        # Number of decoder levels: log2 of the inner patch size.
        self.depth = int(np.log2(self.patch_size))
        patch_check = checkPatchDimensions(parameters["patch_size"], self.depth)

        # NOTE(review): unlike the message says, patch_check does not reduce
        # self.depth here. A shortfall is instead expected to be caught by the
        # divisibility check further below — confirm this is intended.
        if patch_check != self.depth and patch_check >= 2:
            print(
                "The image size is not large enough for desired depth. It is expected that each dimension of the image is divisible by 2^i, where i is in a integer greater than or equal to 2. Only the first %d layers will run."
                % patch_check
            )
        elif patch_check < 2:
            sys.exit(
                "The image size is not large enough for desired depth. It is expected that each dimension of the image is divisible by 2^i, where i is in a integer greater than or equal to 2."
            )

        if "num_heads" not in parameters["model"]:
            parameters["model"]["num_heads"] = 12
            print(
                "Default number of heads in multi-head self-attention (MSA) set to 12."
            )

        if "embed_dim" not in parameters["model"]:
            parameters["model"]["embed_dim"] = 768
            print("Default size of embedded dimension set to 768.")

        # NOTE(review): self.img_size is only assigned for 2 or 3 dimensions.
        # Any other value relies on something else having set it — confirm.
        if self.n_dimensions == 2:
            self.img_size = parameters["patch_size"][0:2]
        elif self.n_dimensions == 3:
            self.img_size = parameters["patch_size"]

        self.num_layers = 3 * self.depth  # number of transformer layers
        # Indices of the transformer layers whose outputs are tapped: 2, 5, 8, ...
        self.out_layers = np.arange(2, self.num_layers, 3)
        self.num_heads = parameters["model"]["num_heads"]
        self.embed_size = parameters["model"]["embed_dim"]
        # Number of inner patches along each image dimension.
        self.patch_dim = [i // self.patch_size for i in self.img_size]

        if not all([i % self.patch_size == 0 for i in self.img_size]):
            sys.exit(
                "The image size is not divisible by the patch size in at least 1 dimension. UNETR is not defined in this case."
            )
        # NOTE(review): the message says "smaller", but equality is accepted.
        if not all([self.patch_size <= i for i in self.img_size]):
            sys.exit("The inner patch size must be smaller than the input image.")

        self.transformer = _Transformer(
            img_size=self.img_size,
            patch_size=self.patch_size,
            in_feats=self.n_channels,
            embed_size=self.embed_size,
            num_heads=self.num_heads,
            mlp_dim=2048,
            num_layers=self.num_layers,
            out_layers=self.out_layers,
            Conv=self.Conv,
            Norm=self.Norm,
        )

        self.upsampling = ModuleList([])
        self.convs = ModuleList([])

        for i in range(0, self.depth - 1):
            # Deconv branch for level i: embed_size -> 32 * 2**depth channels
            # (= 128 * 2**(depth - 2)), then halving blocks down to
            # 128 * 2**i channels.
            tempconvs = nn.Sequential()
            tempconvs.add_module(
                "conv0",
                _DeconvConvBlock(
                    self.embed_size,
                    32 * 2**self.depth,
                    self.Norm,
                    self.Conv,
                    self.ConvTranspose,
                ),
            )

            for j in range(self.depth - 2, i, -1):
                tempconvs.add_module(
                    "conv%d" % j,
                    _DeconvConvBlock(
                        128 * 2**j,
                        128 * 2 ** (j - 1),
                        self.Norm,
                        self.Conv,
                        self.ConvTranspose,
                    ),
                )

            self.convs.append(tempconvs)

            # Upsampling block for level i.
            self.upsampling.append(
                _UpsampleBlock(
                    128 * 2 ** (i + 1), self.Norm, self.Conv, self.ConvTranspose
                )
            )

        # Upsampling for the transformer output itself (no extra convs).
        self.upsampling.append(
            self.ConvTranspose(
                in_channels=self.embed_size,
                out_channels=32 * 2**self.depth,
                kernel_size=2,
                stride=2,
                padding=0,
                output_padding=0,
            )
        )

        # Convolutions applied directly to the input: n_channels -> 32 -> 64.
        self.input_conv = nn.Sequential()
        self.input_conv.add_module(
            "conv1", _ConvBlock(self.n_channels, 32, self.Norm, self.Conv)
        )
        self.input_conv.add_module("conv2", _ConvBlock(32, 64, self.Norm, self.Conv))

        # Output head: 128 -> 64 -> 64 -> n_classes (1x1 conv, no bias).
        self.output_conv = nn.Sequential()
        self.output_conv.add_module("conv1", _ConvBlock(128, 64, self.Norm, self.Conv))
        self.output_conv.add_module("conv2", _ConvBlock(64, 64, self.Norm, self.Conv))
        self.output_conv.add_module(
            "conv3",
            out_conv(
                64,
                self.n_classes,
                conv_kwargs={
                    "kernel_size": 1,
                    "stride": 1,
                    "padding": 0,
                    "bias": False,
                },
                norm=self.Norm,
                conv=self.Conv,
                final_convolution_layer=self.final_convolution_layer,
                sigmoid_input_multiplier=self.sigmoid_input_multiplier,
            ),
        )