Example #1
 def test_shape(self, input_param, input_shape, expected_shape):
     for net in [
             UnetResBlock(**input_param),
             UnetBasicBlock(**input_param)
     ]:
         with eval_mode(net):
             result = net(torch.randn(input_shape))
             self.assertEqual(result.shape, expected_shape)
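For context, here is a minimal sketch of one parameterization this test could run with. The concrete values are illustrative assumptions, not taken from the original test file; with stride=1 both block types keep the spatial size, so only the channel count changes.

    # Illustrative parameterization (assumed values, not from the original test suite).
    input_param = {
        "spatial_dims": 2,
        "in_channels": 4,
        "out_channels": 8,
        "kernel_size": 3,
        "stride": 1,
        "norm_name": "batch",
    }
    input_shape = (1, 4, 32, 32)     # N, C_in, H, W
    expected_shape = (1, 8, 32, 32)  # stride=1 keeps H and W, channels become out_channels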
Example #2
 def test_shape(self, input_param, input_shape, expected_shape):
     for net in [
             UnetResBlock(**input_param),
             UnetBasicBlock(**input_param)
     ]:
         net.eval()
         with torch.no_grad():
             result = net(torch.randn(input_shape))
             self.assertEqual(result.shape, expected_shape)
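This variant achieves the same effect as Example #1 with plain PyTorch calls: net.eval() switches the module to evaluation mode (affecting layers such as batch norm and dropout) and torch.no_grad() disables gradient tracking for the forward pass. The eval_mode context manager used in Example #1 bundles both steps and restores the module's previous training mode when the block exits.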
Example #3
    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
            norm_name: feature normalization type and arguments.
            res_block: bool argument to determine if residual block is used.

        """

        super().__init__()
        upsample_stride = upsample_kernel_size
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )

        if res_block:
            self.conv_block = UnetResBlock(
                spatial_dims,
                out_channels + out_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=1,
                norm_name=norm_name,
            )
        else:
            self.conv_block = UnetBasicBlock(  # type: ignore
                spatial_dims,
                out_channels + out_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=1,
                norm_name=norm_name,
            )
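The constructor above only wires the layers up. The forward pass that typically accompanies it (this mirrors how MONAI's UnetUpBlock uses these attributes; it is a sketch, since the snippet itself stops at __init__) upsamples the input with the transposed convolution, concatenates it with the encoder skip connection along the channel axis (which is why the conv block takes out_channels + out_channels input channels), and then fuses the result:

    def forward(self, inp, skip):
        # upsample the decoder feature map to the skip connection's resolution
        out = self.transp_conv(inp)
        # channel-wise concatenation with the encoder skip connection
        out = torch.cat((out, skip), dim=1)
        # fuse the 2 * out_channels back down to out_channels
        out = self.conv_block(out)
        return out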
Example #4
    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            stride: convolution stride.
            norm_name: feature normalization type and arguments.
            res_block: bool argument to determine if residual block is used.

        """

        super().__init__()

        if res_block:
            self.layer = UnetResBlock(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                norm_name=norm_name,
            )
        else:
            self.layer = UnetBasicBlock(  # type: ignore
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                norm_name=norm_name,
            )
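A minimal usage sketch of the two block types this wrapper chooses between; the concrete sizes below are illustrative assumptions, not values from the source.

    import torch
    from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock

    # Shared constructor arguments (illustrative values).
    common = dict(
        spatial_dims=3, in_channels=4, out_channels=8,
        kernel_size=3, stride=2, norm_name="instance",
    )
    res_variant = UnetResBlock(**common)      # res_block=True path
    basic_variant = UnetBasicBlock(**common)  # res_block=False path

    x = torch.randn(1, 4, 16, 16, 16)
    print(res_variant(x).shape)    # torch.Size([1, 8, 8, 8, 8]); stride=2 halves each spatial dim
    print(basic_variant(x).shape)  # same output shape, but without the residual shortcut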
Example #5
 def test_ill_arg(self):
     with self.assertRaises(ValueError):
         UnetBasicBlock(3, 4, 2, kernel_size=3, stride=1, norm_name="norm")
     with self.assertRaises(AssertionError):
         UnetResBlock(3, 4, 2, kernel_size=1, stride=4, norm_name="batch")
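Both failures come from argument validation rather than from the forward pass: "norm" is not a recognized normalization type, so building the normalization layer raises a ValueError, while kernel_size=1 with stride=4 would imply a negative "same" padding, which the padding computation rejects with an AssertionError (at least in the MONAI version these snippets appear to be taken from).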
Example #6
    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_layer: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        conv_block: bool = False,
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            num_layer: number of upsampling blocks.
            kernel_size: convolution kernel size.
            stride: convolution stride.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
            norm_name: feature normalization type and arguments.
            conv_block: bool argument to determine if convolutional block is used.
            res_block: bool argument to determine if residual block is used.

        """

        super().__init__()

        upsample_stride = upsample_kernel_size
        self.transp_conv_init = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )
        if conv_block:
            if res_block:
                self.blocks = nn.ModuleList([
                    nn.Sequential(
                        get_conv_layer(
                            spatial_dims,
                            out_channels,
                            out_channels,
                            kernel_size=upsample_kernel_size,
                            stride=upsample_stride,
                            conv_only=True,
                            is_transposed=True,
                        ),
                        UnetResBlock(
                            spatial_dims=spatial_dims,
                            in_channels=out_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            norm_name=norm_name,
                        ),
                    ) for i in range(num_layer)
                ])
            else:
                self.blocks = nn.ModuleList([
                    nn.Sequential(
                        get_conv_layer(
                            spatial_dims,
                            out_channels,
                            out_channels,
                            kernel_size=upsample_kernel_size,
                            stride=upsample_stride,
                            conv_only=True,
                            is_transposed=True,
                        ),
                        UnetBasicBlock(
                            spatial_dims=spatial_dims,
                            in_channels=out_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            norm_name=norm_name,
                        ),
                    ) for i in range(num_layer)
                ])
        else:
            self.blocks = nn.ModuleList([
                get_conv_layer(
                    spatial_dims,
                    out_channels,
                    out_channels,
                    kernel_size=upsample_kernel_size,
                    stride=upsample_stride,
                    conv_only=True,
                    is_transposed=True,
                ) for i in range(num_layer)
            ])
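This constructor only assembles the module list: one initial transposed convolution followed by num_layer upsampling stages, each of which is either a bare transposed convolution (conv_block=False), a transposed convolution plus a UnetBasicBlock, or a transposed convolution plus a UnetResBlock. A sketch of the forward pass that typically accompanies it (this mirrors MONAI's UnetrPrUpBlock; it is an assumption, since the snippet only shows __init__):

    def forward(self, x):
        # initial upsampling step
        x = self.transp_conv_init(x)
        # apply each (transposed conv [+ conv/res block]) stage in turn
        for blk in self.blocks:
            x = blk(x)
        return x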