def __init__(
    self,
    image_size: tuple,
    out_channels: int,
    num_channel_initial: int,
    depth: int,
    out_kernel_initializer: str,
    out_activation: str,
    pooling: bool = True,
    concat_skip: bool = False,
    name: str = "Unet",
    **kwargs,
):
    """
    Build the sub-layers of the UNet backbone.

    :param image_size: (dim1, dim2, dim3), dims of the input image; also
        passed as ``output_shape`` of the final layer.
    :param out_channels: number of channels produced by the final layer.
    :param num_channel_initial: number of channels at level 0; level ``d``
        uses ``num_channel_initial * 2**d`` channels.
    :param depth: input is at level 0, bottom is at level depth; one
        down-sampling and one up-sampling block is created per level above
        the bottom.
    :param out_kernel_initializer: kernel initializer for the last layer.
    :param out_activation: activation at the last layer.
    :param pooling: for downsampling, use non-parameterized pooling if true,
        otherwise use conv3d.
    :param concat_skip: when upsampling, concatenate skipped tensor if true,
        otherwise use addition.
    :param name: name of the backbone.
    :param kwargs: additional arguments forwarded to the parent constructor.
    """
    super().__init__(
        image_size=image_size,
        out_channels=out_channels,
        num_channel_initial=num_channel_initial,
        out_kernel_initializer=out_kernel_initializer,
        out_activation=out_activation,
        name=name,
        **kwargs,
    )

    # Channel width doubles at every level; entry `depth` is the bottom level.
    level_channels = [num_channel_initial * 2**level for level in range(depth + 1)]
    self._num_channel_initial = num_channel_initial
    self._depth = depth

    # Sub-layers are created in a fixed order (down, bottom, up, output)
    # so that any auto-generated layer names stay stable.
    self._downsample_blocks = [
        layer.DownSampleResnetBlock(filters=width, pooling=pooling)
        for width in level_channels[:depth]
    ]
    self._bottom_conv3d = layer.Conv3dBlock(filters=level_channels[depth])
    self._bottom_res3d = layer.Residual3dBlock(filters=level_channels[depth])
    self._upsample_blocks = [
        layer.UpSampleResnetBlock(filters=width, concat=concat_skip)
        for width in level_channels[:depth]
    ]
    self._output_conv3d = layer.Conv3dWithResize(
        output_shape=image_size,
        filters=out_channels,
        kernel_initializer=out_kernel_initializer,
        activation=out_activation,
    )
def __init__(
    self,
    image_size,
    out_channels,
    num_channel_initial,
    depth,
    out_kernel_initializer,
    out_activation,
    pooling=True,
    concat_skip=False,
    **kwargs,
):
    """
    Initialise UNet.

    :param image_size: [f_dim1, f_dim2, f_dim3], dims of the input image; also
        passed as ``output_shape`` of the final layer
    :param out_channels: number of channels for the output
    :param num_channel_initial: number of channels at level 0; level ``d``
        uses ``num_channel_initial * 2**d`` channels
    :param depth: input is at level 0, bottom is at level depth
    :param out_kernel_initializer: kernel initializer for the last layer
    :param out_activation: activation at the last layer
    :param pooling: for downsampling, use non-parameterized pooling if true,
        otherwise use conv3d
    :param concat_skip: when upsampling, concatenate skipped tensor if true,
        otherwise use addition
    :param kwargs: additional arguments forwarded to the parent constructor
    """
    # Zero-argument super() avoids hard-coding the class name.
    super().__init__(**kwargs)

    # init layer variables
    # Channel width doubles at every level; index `depth` is the bottom level.
    num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]
    self._num_channel_initial = num_channel_initial
    self._depth = depth

    # One down-sampling block per level above the bottom.
    self._downsample_blocks = [
        layer.DownSampleResnetBlock(filters=num_channels[d], pooling=pooling)
        for d in range(depth)
    ]
    # Bottom (level `depth`) processing uses the widest channel count.
    self._bottom_conv3d = layer.Conv3dBlock(filters=num_channels[depth])
    self._bottom_res3d = layer.Residual3dBlock(filters=num_channels[depth])
    # One up-sampling block per level; block d uses the same width as
    # down-sampling block d (presumably paired through skip connections in
    # call(), which is not visible here).
    self._upsample_blocks = [
        layer.UpSampleResnetBlock(filters=num_channels[d], concat=concat_skip)
        for d in range(depth)
    ]
    self._output_conv3d = layer.Conv3dWithResize(
        output_shape=image_size,
        filters=out_channels,
        kernel_initializer=out_kernel_initializer,
        activation=out_activation,
    )
def __init__(
    self,
    image_size,
    out_channels,
    num_channel_initial,
    depth,
    out_kernel_initializer,
    out_activation,
    pooling=True,
    concat_skip=False,
    **kwargs,
):
    """
    Initialise UNet.

    :param image_size: list, [f_dim1, f_dim2, f_dim3], dims of input image;
        also passed as ``output_shape`` of the final layer.
    :param out_channels: int, number of channels for the output
    :param num_channel_initial: int, number of initial channels; level ``d``
        uses ``num_channel_initial * 2**d`` channels
    :param depth: int, input is at level 0, bottom is at level depth
    :param out_kernel_initializer: str, which kernel to use as initialiser
    :param out_activation: str, activation at last layer
    :param pooling: Boolean, for downsampling, use non-parameterized pooling
        if true, otherwise use conv3d
    :param concat_skip: Boolean, when upsampling, concatenate skipped tensor
        if true, otherwise use addition
    :param kwargs: additional arguments forwarded to the parent constructor
    """
    # Zero-argument super() avoids hard-coding the class name.
    super().__init__(**kwargs)

    # init layer variables
    # Channel width doubles at every level; index `depth` is the bottom level.
    num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]
    self._num_channel_initial = num_channel_initial
    self._depth = depth

    # One down-sampling block per level above the bottom.
    self._downsample_blocks = [
        layer.DownSampleResnetBlock(filters=num_channels[d], pooling=pooling)
        for d in range(depth)
    ]
    # Bottom (level `depth`) processing uses the widest channel count.
    self._bottom_conv3d = layer.Conv3dBlock(filters=num_channels[depth])
    self._bottom_res3d = layer.Residual3dBlock(filters=num_channels[depth])
    # One up-sampling block per level; block d uses the same width as
    # down-sampling block d (presumably paired through skip connections in
    # call(), which is not visible here).
    self._upsample_blocks = [
        layer.UpSampleResnetBlock(filters=num_channels[d], concat=concat_skip)
        for d in range(depth)
    ]
    self._output_conv3d = layer.Conv3dWithResize(
        output_shape=image_size,
        filters=out_channels,
        kernel_initializer=out_kernel_initializer,
        activation=out_activation,
    )
def test_init_UNet():
    """
    Test that UNet.__init__ stores its configuration and creates sub-layers
    of the expected types and counts.
    """
    depth = 5
    num_channel_initial = 3
    local_test = u.UNet(
        image_size=[1, 2, 3],
        out_channels=3,
        num_channel_initial=num_channel_initial,
        depth=depth,
        out_kernel_initializer="he_normal",
        out_activation="softmax",
    )

    # Constructor arguments are stored unchanged.
    assert local_test._num_channel_initial == num_channel_initial
    assert local_test._depth == depth

    # Type checks use the layer classes directly rather than instantiating
    # throwaway layers just to call type() on them.

    # One down-sampling block per level above the bottom.
    assert len(local_test._downsample_blocks) == depth
    assert all(
        isinstance(block, layer.DownSampleResnetBlock)
        for block in local_test._downsample_blocks
    )

    # Bottom-level layers.
    assert isinstance(local_test._bottom_conv3d, layer.Conv3dBlock)
    assert isinstance(local_test._bottom_res3d, layer.Residual3dBlock)

    # One up-sampling block per level above the bottom.
    assert len(local_test._upsample_blocks) == depth
    assert all(
        isinstance(block, layer.UpSampleResnetBlock)
        for block in local_test._upsample_blocks
    )

    # Final output layer.
    assert isinstance(local_test._output_conv3d, layer.Conv3dWithResize)
def test_upsample_resnet_block():
    """
    Build a layer.UpSampleResnetBlock with its default options and verify
    the attributes and sub-blocks it exposes.
    """
    n_batch, n_channels = 5, 4
    spatial_in = (32, 32, 16)
    spatial_out = (64, 64, 32)

    # The main input carries `n_channels`; the skip tensor is at the larger
    # spatial size with half as many channels.
    main_shape = (n_batch, *spatial_in, n_channels)
    skip_shape = (n_batch, *spatial_out, n_channels // 2)

    block = layer.UpSampleResnetBlock(8)
    block.build([main_shape, skip_shape])

    assert block._filters == 8
    assert block._concat is False
    for attr_name, expected_cls in (
        ("_conv3d_block", layer.Conv3dBlock),
        ("_residual_block", layer.Residual3dBlock),
        ("_deconv3d_block", layer.Deconv3dBlock),
    ):
        assert isinstance(getattr(block, attr_name), expected_cls)
def __init__(
    self,
    image_size: tuple,
    out_channels: int,
    num_channel_initial: int,
    depth: int,
    out_kernel_initializer: str,
    out_activation: str,
    pooling: bool = True,
    concat_skip: bool = False,
    # NOTE(review): "(tuple, None)" is not a valid type hint; Optional[tuple]
    # is presumably intended (needs a typing import not visible here).
    control_points: (tuple, None) = None,
    **kwargs,
):
    """
    Initialise UNet.

    :param image_size: tuple, (dim1, dim2, dim3), dims of input image; also
        passed as ``output_shape`` of the final layer and to the B-spline
        transform.
    :param out_channels: int, number of channels for the output
    :param num_channel_initial: int, number of initial channels; level ``d``
        uses ``num_channel_initial * 2**d`` channels
    :param depth: int, input is at level 0, bottom is at level depth
    :param out_kernel_initializer: str, which kernel to use as initializer
    :param out_activation: str, activation at last layer
    :param pooling: Boolean, for downsampling, use non-parameterized pooling
        if true, otherwise use conv3d
    :param concat_skip: Boolean, when upsampling, concatenate skipped tensor
        if true, otherwise use addition
    :param control_points: (tuple, None), specify the distance between
        control points (in voxels). When None, ``self.resize`` and
        ``self.interpolate`` are set to False.
    :param kwargs: additional arguments forwarded to the parent constructor
    """
    # Zero-argument super() avoids hard-coding the class name.
    super().__init__(**kwargs)

    # init layer variables
    # Channel width doubles at every level; index `depth` is the bottom level.
    num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]
    self._num_channel_initial = num_channel_initial
    self._depth = depth

    # One down-sampling block per level above the bottom.
    self._downsample_blocks = [
        layer.DownSampleResnetBlock(filters=num_channels[d], pooling=pooling)
        for d in range(depth)
    ]
    # Bottom (level `depth`) processing uses the widest channel count.
    self._bottom_conv3d = layer.Conv3dBlock(filters=num_channels[depth])
    self._bottom_res3d = layer.Residual3dBlock(filters=num_channels[depth])
    # One up-sampling block per level; block d uses the same width as
    # down-sampling block d.
    self._upsample_blocks = [
        layer.UpSampleResnetBlock(filters=num_channels[d], concat=concat_skip)
        for d in range(depth)
    ]
    self._output_conv3d = layer.Conv3dWithResize(
        output_shape=image_size,
        filters=out_channels,
        kernel_initializer=out_kernel_initializer,
        activation=out_activation,
    )

    # Optional control-point layers, created only when control_points is
    # given. Creation order (resize, then interpolate) is unchanged.
    if control_points is not None:
        self.resize = layer.ResizeCPTransform(control_points)
        self.interpolate = layer.BSplines3DTransform(control_points, image_size)
    else:
        # False (rather than None) is kept as the "disabled" marker for
        # backward compatibility with code that may check these attributes.
        self.resize = False
        self.interpolate = False