Example #1
0
def test_unet2d_encode(num_patches: int,
                       num_channels: int,
                       num_output_channels: int,
                       is_downsampling: bool,
                       image_shape: TupleInt2) -> None:
    """
    Test if the Encode block of a Unet3D correctly works when passing in kernels that only operate in X and Y.
    """
    set_random_seed(1234)
    stride = (1, 2, 2) if is_downsampling else 1
    encode_block = UNet3D.UNetEncodeBlock((num_channels, num_output_channels),
                                          kernel_size=(1, 3, 3),
                                          downsampling_stride=stride)
    # Build a random degenerate-3D input: batch x channels x Z=1 x Y x X
    batch = torch.rand(num_patches, num_channels, 1, *image_shape).float()
    result = encode_block(batch)

    # Downsampling shrinks each in-plane dimension by a factor of 2 (stride 2);
    # without downsampling the image size is unchanged.
    shrink = 2 if is_downsampling else 1
    expected = (num_patches,            # batch dimension retained unchanged
                num_output_channels,    # as many output channels as requested
                1,                      # degenerate Z dimension preserved
                image_shape[0] // shrink,
                image_shape[1] // shrink)
    assert result.shape == expected
Example #2
0
 def create_encoder(self, channels: List[int]) -> ModuleList:
     """
     Create an image encoder network.
     """
     # One encode block per consecutive channel pair, configured from the
     # per-block kernel/stride settings stored on this instance.
     encode_blocks = [
         UNet3D.UNetEncodeBlock(
             channels=(channels[i], channels[i + 1]),
             kernel_size=self.kernel_size_per_encoding_block[i],
             downsampling_stride=self.stride_size_per_encoding_block[i],
             padding_mode=self.padding_mode,
             use_residual=False,
             depth=i,
         )
         for i in range(len(channels) - 1)
     ]
     return ModuleList(encode_blocks)