示例#1
0
 def test_stride1(self):
     """
     Checks that a stride of 2, given as an int, a list or a tuple, produces an output whose two
     spatial dimensions are `self.im_shape` floor-divided by 2.
     """
     halved_shape = (
         1,
         self.output_channels,
         self.im_shape[0] // 2,
         self.im_shape[1] // 2,
     )
     for stride_spec in (2, [2, 2], (2, 2)):
         unit = ResidualUnit(2, 1, self.output_channels, strides=stride_spec)
         result = unit(self.imt)
         self.assertEqual(result.shape, halved_shape)
 def _get_encode_layer(self, in_channels: int, out_channels: int,
                       strides: int, is_last: bool) -> nn.Module:
     """
     Returns a single layer of the encoder part of the network.

     A `ResidualUnit` with `self.num_res_units` subunits is built when `self.num_res_units > 0`;
     otherwise a plain `Convolution` is built.

     Args:
         in_channels: number of input channels, forwarded to the layer.
         out_channels: number of output channels, forwarded to the layer.
         strides: stride value forwarded to the layer.
         is_last: forwarded as `last_conv_only` (residual unit) or `conv_only` (convolution).

     Returns:
         The constructed `ResidualUnit` or `Convolution`.
     """
     mod: nn.Module
     if self.num_res_units > 0:
         mod = ResidualUnit(
             spatial_dims=self.dimensions,
             in_channels=in_channels,
             out_channels=out_channels,
             strides=strides,
             kernel_size=self.kernel_size,
             subunits=self.num_res_units,
             act=self.act,
             norm=self.norm,
             dropout=self.dropout,
             bias=self.bias,
             last_conv_only=is_last,
         )
         # Return immediately: falling through would replace the residual unit
         # with the plain convolution built below.
         return mod
     mod = Convolution(
         spatial_dims=self.dimensions,
         in_channels=in_channels,
         out_channels=out_channels,
         strides=strides,
         kernel_size=self.kernel_size,
         act=self.act,
         norm=self.norm,
         dropout=self.dropout,
         bias=self.bias,
         conv_only=is_last,
     )
     return mod
示例#3
0
    def _get_encode_layer(self, in_channels: int, out_channels: int,
                          strides: int, is_last: bool) -> nn.Module:
        """
        Returns a single encoder layer: a `ResidualUnit` when `self.num_res_units > 0`, otherwise a
        `Convolution`. `is_last` is forwarded as `last_conv_only` / `conv_only` respectively.
        """
        use_residual = self.num_res_units > 0

        # Keyword arguments that both layer types receive.
        shared = dict(
            dimensions=self.dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            strides=strides,
            kernel_size=self.kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
        )

        if use_residual:
            return ResidualUnit(
                subunits=self.num_res_units,
                last_conv_only=is_last,
                **shared,
            )
        return Convolution(conv_only=is_last, **shared)
示例#4
0
    def _get_layer(self, in_channels, out_channels, strides, is_last):
        """
        Builds a layer taking `in_channels` channels and producing `out_channels` channels.

        The layer starts with a transposed `Convolution` using `strides` (i.e. the upsampling factor). When
        `self.num_res_units > 0` that convolution is created with `conv_only` set and is followed, inside an
        `nn.Sequential`, by a `ResidualUnit` with `self.num_res_units` subunits. `is_last` marks the final layer,
        which is not expected to include activation and normalization layers.
        """
        with_residual = self.num_res_units > 0

        upsample = Convolution(
            in_channels=in_channels,
            strides=strides,
            is_transposed=True,
            conv_only=is_last or with_residual,
            dimensions=self.dimensions,
            out_channels=out_channels,
            kernel_size=self.kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
            bias=self.bias,
        )

        if not with_residual:
            return upsample

        # The residual unit operates on the already-upsampled output, so its
        # input and output channel counts are both `out_channels`.
        refine = ResidualUnit(
            in_channels=out_channels,
            subunits=self.num_res_units,
            last_conv_only=is_last,
            dimensions=self.dimensions,
            out_channels=out_channels,
            kernel_size=self.kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
            bias=self.bias,
        )
        return nn.Sequential(upsample, refine)
    def _get_intermediate_module(
            self, in_channels: int,
            num_inter_units: int) -> Tuple[nn.Module, int]:
        """
        Returns the intermediate block of the network which accepts input from the encoder and whose output goes
        to the decoder.
        """
        # Define some types
        intermediate: nn.Module
        unit: nn.Module

        intermediate = nn.Identity()
        layer_channels = in_channels

        if self.inter_channels:
            intermediate = nn.Sequential()

            for i, (dc, di) in enumerate(
                    zip(self.inter_channels, self.inter_dilations)):
                if self.num_inter_units > 0:
                    unit = ResidualUnit(
                        spatial_dims=self.dimensions,
                        in_channels=layer_channels,
                        out_channels=dc,
                        strides=1,
                        kernel_size=self.kernel_size,
                        subunits=self.num_inter_units,
                        act=self.act,
                        norm=self.norm,
                        dropout=self.dropout,
                        dilation=di,
                        bias=self.bias,
                    )
                else:
                    unit = Convolution(
                        spatial_dims=self.dimensions,
                        in_channels=layer_channels,
                        out_channels=dc,
                        strides=1,
                        kernel_size=self.kernel_size,
                        act=self.act,
                        norm=self.norm,
                        dropout=self.dropout,
                        dilation=di,
                        bias=self.bias,
                    )

                intermediate.add_module("inter_%i" % i, unit)
                layer_channels = dc

        return intermediate, layer_channels
示例#6
0
    def _get_intermediate_module(
            self, in_channels: int,
            num_inter_units: int) -> Tuple[nn.Module, int]:
        # Define some types
        intermediate: nn.Module
        unit: nn.Module

        intermediate = nn.Identity()
        layer_channels = in_channels

        if self.inter_channels:
            intermediate = nn.Sequential()

            for i, (dc, di) in enumerate(
                    zip(self.inter_channels, self.inter_dilations)):
                if self.num_inter_units > 0:
                    unit = ResidualUnit(
                        dimensions=self.dimensions,
                        in_channels=layer_channels,
                        out_channels=dc,
                        strides=1,
                        kernel_size=self.kernel_size,
                        subunits=self.num_inter_units,
                        act=self.act,
                        norm=self.norm,
                        dropout=self.dropout,
                        dilation=di,
                    )
                else:
                    unit = Convolution(
                        dimensions=self.dimensions,
                        in_channels=layer_channels,
                        out_channels=dc,
                        strides=1,
                        kernel_size=self.kernel_size,
                        act=self.act,
                        norm=self.norm,
                        dropout=self.dropout,
                        dilation=di,
                    )

                intermediate.add_module("inter_%i" % i, unit)
                layer_channels = dc

        return intermediate, layer_channels
    def _get_decode_layer(self, in_channels: int, out_channels: int,
                          strides: int, is_last: bool) -> nn.Sequential:
        """
        Returns a single layer of the decoder part of the network.

        The returned `nn.Sequential` always holds a transposed `Convolution` registered as ``"conv"`` (kernel
        size `self.up_kernel_size`). When `self.num_res_units > 0` a one-subunit, stride-1 `ResidualUnit`
        registered as ``"resunit"`` follows it. The convolution is created with `conv_only` set only if this is
        the last layer and `self.num_res_units == 0`.
        """
        up_conv = Convolution(
            spatial_dims=self.dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            strides=strides,
            kernel_size=self.up_kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
            bias=self.bias,
            conv_only=is_last and self.num_res_units == 0,
            is_transposed=True,
        )

        stage = nn.Sequential()
        stage.add_module("conv", up_conv)

        if self.num_res_units > 0:
            stage.add_module(
                "resunit",
                ResidualUnit(
                    spatial_dims=self.dimensions,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    strides=1,
                    kernel_size=self.kernel_size,
                    subunits=1,
                    act=self.act,
                    norm=self.norm,
                    dropout=self.dropout,
                    bias=self.bias,
                    last_conv_only=is_last,
                ),
            )

        return stage
示例#8
0
    def _get_layer(self, in_channels: int, out_channels: int, strides: int,
                   is_last: bool):
        """
        Builds a layer taking `in_channels` channels and producing `out_channels` channels.

        `strides` is the downsampling factor, i.e. the convolutional stride. A `ResidualUnit` with
        `self.num_res_units` subunits is returned when `self.num_res_units > 0`, otherwise a `Convolution`.
        `is_last` marks the final layer, which is not expected to include activation and normalization layers;
        it is forwarded as `last_conv_only` / `conv_only`.
        """
        use_residual = self.num_res_units > 0

        # Keyword arguments that both layer types receive.
        shared = dict(
            dimensions=self.dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            strides=strides,
            kernel_size=self.kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
            bias=self.bias,
        )

        if use_residual:
            return ResidualUnit(
                subunits=self.num_res_units,
                last_conv_only=is_last,
                **shared,
            )
        return Convolution(conv_only=is_last, **shared)
示例#9
0
 def test_conv_only1(self):
     """
     Checks that a `ResidualUnit` built without an explicit `strides` argument produces an output whose
     spatial dimensions equal `self.im_shape`.
     """
     unit = ResidualUnit(2, 1, self.output_channels)
     result = unit(self.imt)
     height, width = self.im_shape[0], self.im_shape[1]
     self.assertEqual(result.shape, (1, self.output_channels, height, width))