Example #1
 def __init__(
     self,
     spatial_dims,
     in_channels,
     out_channels,
     init_features=64,
     growth_rate=32,
     block_config=(6, 12, 24, 16),
     bn_size=4,
     dropout_prob=0.0,
     aspp_conv_out_channels=5,
 ):
     # initialise normal densenet
     super().__init__(
         spatial_dims,
         in_channels,
         out_channels,
         init_features,
         growth_rate,
         block_config,
         bn_size,
         dropout_prob,
     )
     # create aspp module
     aspp_in_features = self.class_layers[-1].in_features
     self.aspp = SimpleASPP(spatial_dims=spatial_dims,
                            in_channels=aspp_in_features,
                            conv_out_channels=aspp_conv_out_channels)
     # replace last linear component with updated number of input channels
     aspp_out_features = self.aspp.conv_k1.out_channels
     lin_out_channels = self.class_layers[-1].out_features
     self.class_layers[-1] = nn.Linear(aspp_out_features, lin_out_channels)
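
The snippet above only swaps the classifier's final linear layer; the ASPP module still has to be applied somewhere in the forward pass. A minimal sketch of a matching forward override, assuming the parent class is MONAI's DenseNet and therefore exposes self.features and self.class_layers:

 def forward(self, x):
     # extract dense features, enrich them with multi-scale context, then classify
     x = self.features(x)
     x = self.aspp(x)
     return self.class_layers(x)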
Example #2
 def test_ill_args(self, input_param, input_data, error_type):
     with self.assertRaises(error_type):
         SimpleASPP(**input_param)
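
For context, test_ill_args is typically parameterized with constructor arguments that SimpleASPP is expected to reject. A hypothetical case, assuming (as in MONAI) that mismatched lengths of kernel_sizes and dilations raise a ValueError; names and values here are illustrative only:

 # illustrative parameterization; the exact error condition is an assumption
 TEST_ILL_CASE = [
     {"spatial_dims": 3, "in_channels": 2, "conv_out_channels": 2,
      "kernel_sizes": [1, 3, 3], "dilations": [1, 2]},  # 3 kernel sizes vs. 2 dilations
     None,         # input_data is unused when construction itself fails
     ValueError,
 ]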
Example #3
 def test_shape(self, input_param, input_data, expected_shape):
     net = SimpleASPP(**input_param)
     net.eval()
     with torch.no_grad():
         result = net(input_data)
         self.assertEqual(result.shape, expected_shape)
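
A plausible test case for test_shape, under the assumption that SimpleASPP concatenates its four dilated branches so the output carries conv_out_channels * len(kernel_sizes) channels while preserving the spatial size:

 # illustrative case; the channel arithmetic is an assumption about SimpleASPP
 TEST_CASE_SHAPE = [
     {"spatial_dims": 2, "in_channels": 4, "conv_out_channels": 3},  # 4 default branches
     torch.randn(2, 4, 32, 32),
     (2, 12, 32, 32),  # 3 channels per branch * 4 branches = 12
 ]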
Example #4
    def __init__(
        self,
        spatial_dims: int = 2,
        in_channels: int = 1,
        out_channels: int = 2,
        features=(32, 64, 128, 256, 512),
        dropout=(0.0, 0.0, 0.3, 0.4, 0.5),
        bilinear: bool = True,
    ):
        """
        A UNet implementation with 1D/2D/3D supports.

        Based on:

            Falk et al. "U-Net – Deep Learning for Cell Counting, Detection, and
            Morphometry". Nature Methods 16, 67–70 (2019), DOI:
            http://dx.doi.org/10.1038/s41592-018-0261-2

        Args:
            dimensions: number of spatial dimensions. Defaults to 3 for spatial 3D inputs.
            in_channels: number of input channels. Defaults to 1.
            out_channels: number of output channels. Defaults to 2.
            features: six integers as numbers of features.
                Defaults to ``(32, 32, 64, 128, 256, 32)``,

                - the first five values correspond to the five-level encoder feature sizes.
                - the last value corresponds to the feature size after the last upsampling.

            act: activation type and arguments. Defaults to LeakyReLU.
            norm: feature normalization type and arguments. Defaults to instance norm.
            dropout: dropout ratio. Defaults to no dropout.
            upsample: upsampling mode, available options are
                ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.

        Examples::

            # for spatial 2D
            >>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128))

            # for spatial 2D, with group norm
            >>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4}))

            # for spatial 3D
            >>> net = BasicUNet(dimensions=3, features=(32, 32, 64, 128, 256, 32))

        See Also

            - :py:class:`monai.networks.nets.DynUNet`
            - :py:class:`monai.networks.nets.UNet`

        """
        super().__init__()

        ft_chns = ensure_tuple_rep(features, 5)

        f0_half = int(ft_chns[0] / 2)
        f1_half = int(ft_chns[1] / 2)
        f2_half = int(ft_chns[2] / 2)
        f3_half = int(ft_chns[3] / 2)

        self.in_conv = ConvBNActBlock(in_channels, ft_chns[0], dropout[0],
                                      spatial_dims)
        self.down1 = DownBlock(ft_chns[0], ft_chns[1], dropout[1],
                               spatial_dims)
        self.down2 = DownBlock(ft_chns[1], ft_chns[2], dropout[2],
                               spatial_dims)
        self.down3 = DownBlock(ft_chns[2], ft_chns[3], dropout[3],
                               spatial_dims)
        self.down4 = DownBlock(ft_chns[3], ft_chns[4], dropout[4],
                               spatial_dims)

        self.bridge0 = Convolution(spatial_dims,
                                   ft_chns[0],
                                   f0_half,
                                   kernel_size=1,
                                   norm=Norm.BATCH,
                                   act=Act.LEAKYRELU)
        self.bridge1 = Convolution(spatial_dims,
                                   ft_chns[1],
                                   f1_half,
                                   kernel_size=1,
                                   norm=Norm.BATCH,
                                   act=Act.LEAKYRELU)
        self.bridge2 = Convolution(spatial_dims,
                                   ft_chns[2],
                                   f2_half,
                                   kernel_size=1,
                                   norm=Norm.BATCH,
                                   act=Act.LEAKYRELU)
        self.bridge3 = Convolution(spatial_dims,
                                   ft_chns[3],
                                   f3_half,
                                   kernel_size=1,
                                   norm=Norm.BATCH,
                                   act=Act.LEAKYRELU)

        self.up1 = UpBlock(ft_chns[4], f3_half, ft_chns[3], bilinear,
                           dropout[3], spatial_dims)
        self.up2 = UpBlock(ft_chns[3], f2_half, ft_chns[2], bilinear,
                           dropout[2], spatial_dims)
        self.up3 = UpBlock(ft_chns[2], f1_half, ft_chns[1], bilinear,
                           dropout[1], spatial_dims)
        self.up4 = UpBlock(ft_chns[1], f0_half, ft_chns[0], bilinear,
                           dropout[0], spatial_dims)

        self.aspp = SimpleASPP(spatial_dims,
                               ft_chns[4],
                               int(ft_chns[4] / 4),
                               kernel_sizes=[1, 3, 3, 3],
                               dilations=[1, 2, 4, 6])

        self.out_conv = Convolution(spatial_dims,
                                    ft_chns[0],
                                    out_channels,
                                    conv_only=True)
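
The constructor above only builds the modules; how they are wired together is easiest to see from the forward pass. A plausible sketch, assuming each UpBlock takes the deeper feature map first and the bridged skip connection second:

    def forward(self, x):
        # encoder path with 1x1 "bridge" convolutions halving the skip channels
        x0 = self.in_conv(x)
        x0b = self.bridge0(x0)
        x1 = self.down1(x0)
        x1b = self.bridge1(x1)
        x2 = self.down2(x1)
        x2b = self.bridge2(x2)
        x3 = self.down3(x2)
        x3b = self.bridge3(x3)
        x4 = self.down4(x3)
        # multi-scale context at the bottleneck
        x4 = self.aspp(x4)
        # decoder path consuming the bridged skip connections
        x = self.up1(x4, x3b)
        x = self.up2(x, x2b)
        x = self.up3(x, x1b)
        x = self.up4(x, x0b)
        return self.out_conv(x)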
Example #5
    def __init__(
            self,
            spatial_dims: int = 2,
            in_channels: int = 1,
            out_channels: int = 2,
            feature_channels=(32, 64, 128, 256, 512),
            dropout=(0.0, 0.0, 0.3, 0.4, 0.5),
            bilinear: bool = True,
    ):
        """
        Args:
            spatial_dims: number of spatial dimensions of the operators. Defaults to 2,
                i.e. 2D operators such as Conv2d are used for all convolutions;
                set it to 3 for 3D images.
            in_channels: number of channels of the input image. Defaults to 1.
            out_channels: number of segmentation classes (2 for foreground/background segmentation).
                Defaults to 2.
            feature_channels: number of intermediate feature channels
                (must have 5 elements corresponding to five conv. stages).
                Defaults to (32, 64, 128, 256, 512).
            dropout: a sequence of 5 dropout ratios. Defaults to (0.0, 0.0, 0.3, 0.4, 0.5).
            bilinear: whether to use bilinear upsampling. Defaults to True.
        """
        super().__init__()
        ft_chns = ensure_tuple_rep(feature_channels, 5)

        f0_half = int(ft_chns[0] / 2)
        f1_half = int(ft_chns[1] / 2)
        f2_half = int(ft_chns[2] / 2)
        f3_half = int(ft_chns[3] / 2)

        self.in_conv = ConvBNActBlock(in_channels, ft_chns[0], dropout[0],
                                      spatial_dims)
        self.down1 = DownBlock(ft_chns[0], ft_chns[1], dropout[1],
                               spatial_dims)
        self.down2 = DownBlock(ft_chns[1], ft_chns[2], dropout[2],
                               spatial_dims)
        self.down3 = DownBlock(ft_chns[2], ft_chns[3], dropout[3],
                               spatial_dims)
        self.down4 = DownBlock(ft_chns[3], ft_chns[4], dropout[4],
                               spatial_dims)

        self.bridge0 = Convolution(spatial_dims,
                                   ft_chns[0],
                                   f0_half,
                                   kernel_size=1,
                                   norm=Norm.BATCH,
                                   act=Act.LEAKYRELU)
        self.bridge1 = Convolution(spatial_dims,
                                   ft_chns[1],
                                   f1_half,
                                   kernel_size=1,
                                   norm=Norm.BATCH,
                                   act=Act.LEAKYRELU)
        self.bridge2 = Convolution(spatial_dims,
                                   ft_chns[2],
                                   f2_half,
                                   kernel_size=1,
                                   norm=Norm.BATCH,
                                   act=Act.LEAKYRELU)
        self.bridge3 = Convolution(spatial_dims,
                                   ft_chns[3],
                                   f3_half,
                                   kernel_size=1,
                                   norm=Norm.BATCH,
                                   act=Act.LEAKYRELU)

        self.up1 = UpBlock(ft_chns[4], f3_half, ft_chns[3], bilinear,
                           dropout[3], spatial_dims)
        self.up2 = UpBlock(ft_chns[3], f2_half, ft_chns[2], bilinear,
                           dropout[2], spatial_dims)
        self.up3 = UpBlock(ft_chns[2], f1_half, ft_chns[1], bilinear,
                           dropout[1], spatial_dims)
        self.up4 = UpBlock(ft_chns[1], f0_half, ft_chns[0], bilinear,
                           dropout[0], spatial_dims)

        self.aspp = SimpleASPP(spatial_dims,
                               ft_chns[4],
                               int(ft_chns[4] / 4),
                               kernel_sizes=[1, 3, 3, 3],
                               dilations=[1, 2, 4, 6])

        self.out_conv = Convolution(spatial_dims,
                                    ft_chns[0],
                                    out_channels,
                                    conv_only=True)
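
A quick smoke test of such a network could look as follows; the class name SegNet2D is only a placeholder for whichever class wraps this constructor:

    # hypothetical usage sketch; "SegNet2D" is a placeholder class name
    net = SegNet2D(spatial_dims=2, in_channels=1, out_channels=2)
    net.eval()
    with torch.no_grad():
        out = net(torch.randn(1, 1, 64, 64))
    print(out.shape)  # expected: torch.Size([1, 2, 64, 64])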
Example #6
 def test_shape(self, input_param, input_shape, expected_shape):
     net = SimpleASPP(**input_param)
     with eval_mode(net):
         result = net(torch.randn(input_shape))
         self.assertEqual(result.shape, expected_shape)
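
Example #6 performs the same check as Example #3 but uses MONAI's eval_mode helper, which, to my understanding, switches the given modules to eval(), disables gradient tracking, and restores the previous training mode on exit. A rough, simplified equivalent of the body above:

 # simplified sketch of what `with eval_mode(net):` does; not MONAI's actual code
 was_training = net.training
 net.eval()
 try:
     with torch.no_grad():
         result = net(torch.randn(input_shape))
         self.assertEqual(result.shape, expected_shape)
 finally:
     net.train(mode=was_training)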