Ejemplo n.º 1
0
 def test_ill_shape(self, input_param):
     """Forwarding a tensor whose spatial extent is 5 per axis must raise ValueError.

     Presumably 5 is rejected because the block cannot halve it evenly; the
     check itself lives in ``LocalNetDownSampleBlock``, not in this test.
     """
     block = LocalNetDownSampleBlock(**input_param)
     spatial = [5] * input_param["spatial_dims"]
     bad_shape = (1, input_param["in_channels"], *spatial)
     # Same nesting as before: assertRaises outside, eval_mode inside.
     with self.assertRaises(ValueError), eval_mode(block):
         block(torch.randn(bad_shape))
Ejemplo n.º 2
0
 def test_ill_arg(self):
     """Constructing the block with an even ``kernel_size`` must raise NotImplementedError."""
     even_kernel_kwargs = dict(
         spatial_dims=2,
         in_channels=2,
         out_channels=4,
         kernel_size=4,  # even kernel size is the unsupported argument under test
     )
     with self.assertRaises(NotImplementedError):
         LocalNetDownSampleBlock(**even_kernel_kwargs)
Ejemplo n.º 3
0
 def test_shape(self, input_param):
     """Check the ``(x, mid)`` output shapes of ``LocalNetDownSampleBlock``.

     The test expects ``mid`` to keep the input spatial size with
     ``out_channels`` channels, and ``x`` to have every spatial dimension
     halved. ``in_size`` is a module-level constant defined outside this
     snippet (assumed even -- TODO confirm at its definition).
     """
     net = LocalNetDownSampleBlock(**input_param)
     spatial_dims = input_param["spatial_dims"]
     input_shape = (1, input_param["in_channels"],
                    *([in_size] * spatial_dims))
     expect_mid_shape = (1, input_param["out_channels"],
                         *([in_size] * spatial_dims))
     # Floor division keeps the expected sizes integral, like the entries of a
     # torch.Size; true division would put floats into the expected shape.
     expect_x_shape = (1, input_param["out_channels"],
                       *([in_size // 2] * spatial_dims))
     with eval_mode(net):
         x, mid = net(torch.randn(input_shape))
         self.assertEqual(x.shape, expect_x_shape)
         self.assertEqual(mid.shape, expect_mid_shape)
Ejemplo n.º 4
0
    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_channel_initial: int,
        extract_levels: List[int],
        out_activation: Optional[Union[Tuple, str]],
        out_initializer: str = "kaiming_uniform",
    ) -> None:
        """
        Build the LocalNet encoder, bottleneck, decoder and extraction layers.

        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels of each extraction layer.
            num_channel_initial: number of channels at level 0; level ``k``
                uses ``num_channel_initial * 2**k`` channels.
            extract_levels: list of levels for which an extraction layer is
                built. Must be non-empty, contain only levels >= 0, and have a
                maximum of at least 1 (the deepest level is reached by
                ``max(extract_levels)`` downsample blocks).
            out_activation: activation used by the extraction layers.
            out_initializer: initializer for extraction layers.

        Raises:
            ValueError: if ``extract_levels`` is empty, contains a negative
                level, or its maximum is smaller than 1.
        """
        # Validate before building anything. Without these checks:
        # - an empty list fails inside ``max`` with an unhelpful message;
        # - a maximum of 0 makes ``num_channels[-2]`` below raise IndexError;
        # - negative levels silently wrap around when used as list indices.
        if not extract_levels:
            raise ValueError("extract_levels must contain at least one level.")
        if min(extract_levels) < 0:
            raise ValueError(
                "extract_levels must be non-negative, got {}.".format(extract_levels))
        if max(extract_levels) < 1:
            raise ValueError(
                "the maximum of extract_levels must be at least 1, got {}.".format(
                    extract_levels))

        super().__init__()
        self.extract_levels = extract_levels
        self.extract_max_level = max(self.extract_levels)  # E: deepest level
        self.extract_min_level = min(self.extract_levels)  # D: shallowest extracted level

        # Channel count doubles per level: level 0 to E.
        num_channels = [
            num_channel_initial * (2**level)
            for level in range(self.extract_max_level + 1)
        ]

        # Encoder: one downsample block per level 0 .. E-1. The first block
        # reads the raw input and uses a larger kernel (7 vs 3).
        self.downsample_blocks = nn.ModuleList([
            LocalNetDownSampleBlock(
                spatial_dims=spatial_dims,
                in_channels=in_channels if i == 0 else num_channels[i - 1],
                out_channels=num_channels[i],
                kernel_size=7 if i == 0 else 3,
            ) for i in range(self.extract_max_level)
        ])
        # Bottleneck at level E. The attribute name is kept unchanged for
        # compatibility (e.g. state_dict keys); it is built with
        # ``spatial_dims``, so it is not necessarily 3D.
        self.conv3d_block = get_conv_block(
            spatial_dims=spatial_dims,
            in_channels=num_channels[-2],
            out_channels=num_channels[-1])

        # Decoder: levels E-1 down to D, each mapping level+1 channels to
        # level channels.
        self.upsample_blocks = nn.ModuleList([
            LocalNetUpSampleBlock(
                spatial_dims=spatial_dims,
                in_channels=num_channels[level + 1],
                out_channels=num_channels[level],
            ) for level in range(self.extract_max_level -
                                 1, self.extract_min_level - 1, -1)
        ])

        # One extraction layer per requested level, in the order given by
        # extract_levels.
        self.extract_layers = nn.ModuleList([
            # if kernels are not initialized by zeros, with init NN, extract may be too large
            LocalNetFeatureExtractorBlock(
                spatial_dims=spatial_dims,
                in_channels=num_channels[level],
                out_channels=out_channels,
                act=out_activation,
                initializer=out_initializer,
            ) for level in self.extract_levels
        ])