Example #1
    def __init__(self,
                 image_size, out_channels,
                 num_channel_initial, depth,
                 out_kernel_initializer, out_activation,
                 pooling=True, concat_skip=False, **kwargs):
        """
        :param image_size: [f_dim1, f_dim2, f_dim3], dims of input image
        :param out_channels: number of channels for the output
        :param num_channel_initial: number of initial channels
        :param depth: input is at level 0, bottom is at level depth
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param pooling: for downsampling, use non-parameterized pooling if true, otherwise use conv3d
        :param concat_skip: when upsampling, concatenate skipped tensor if true, otherwise use addition
        :param kwargs: additional arguments
        """
        super(UNet, self).__init__(**kwargs)

        # init layer variables
        nc = [num_channel_initial * (2 ** d) for d in range(depth + 1)]

        self._num_channel_initial = num_channel_initial
        self._depth = depth
        self._downsample_blocks = [layer.DownSampleResnetBlock(filters=nc[d], pooling=pooling)
                                   for d in range(depth)]
        self._bottom_conv3d = layer.Conv3dBlock(filters=nc[depth])
        self._bottom_res3d = layer.Residual3dBlock(filters=nc[depth])
        self._upsample_blocks = [layer.UpSampleResnetBlock(filters=nc[d], concat=concat_skip)
                                 for d in range(depth)]
        self._output_conv3d = layer.Conv3dWithResize(output_shape=image_size, filters=out_channels,
                                                     kernel_initializer=out_kernel_initializer,
                                                     activation=out_activation)
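The channel list nc above doubles at every level. A quick standalone check of the formula (not part of the original snippet), using illustrative values num_channel_initial=4 and depth=3:

num_channel_initial, depth = 4, 3
nc = [num_channel_initial * (2 ** d) for d in range(depth + 1)]
print(nc)  # [4, 8, 16, 32]: channels at level 0 (input) up to level 3 (bottom)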
Example #2
    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: str,
        out_activation: str,
        pooling: bool = True,
        concat_skip: bool = False,
        name: str = "Unet",
        **kwargs,
    ):
        """
        Initialise UNet.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param out_channels: number of channels for the output
        :param num_channel_initial: number of initial channels
        :param depth: input is at level 0, bottom is at level depth
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param pooling: for downsampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when upsampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # init layer variables
        num_channels = [num_channel_initial * (2**d) for d in range(depth + 1)]

        self._num_channel_initial = num_channel_initial
        self._depth = depth
        self._downsample_blocks = [
            layer.DownSampleResnetBlock(filters=num_channels[d],
                                        pooling=pooling) for d in range(depth)
        ]
        self._bottom_conv3d = layer.Conv3dBlock(filters=num_channels[depth])
        self._bottom_res3d = layer.Residual3dBlock(filters=num_channels[depth])
        self._upsample_blocks = [
            layer.UpSampleResnetBlock(filters=num_channels[d],
                                      concat=concat_skip) for d in range(depth)
        ]
        self._output_conv3d = layer.Conv3dWithResize(
            output_shape=image_size,
            filters=out_channels,
            kernel_initializer=out_kernel_initializer,
            activation=out_activation,
        )
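A minimal instantiation sketch for this signature. The import path is an assumption (the tests in the later examples refer to the module as u), and the argument values mirror those used in test_init_UNet below:

from deepreg.model.backbone import u_net as u  # assumed import path

unet = u.UNet(
    image_size=(1, 2, 3),
    out_channels=3,
    num_channel_initial=3,
    depth=5,
    out_kernel_initializer="he_normal",
    out_activation="softmax",
)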
Example #3
    def build_decode_conv_block(
            self, filters: int, kernel_size: int,
            padding: str) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a conv block for up-sampling

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return tf.keras.Sequential([
            layer.Conv3dBlock(
                filters=filters,
                kernel_size=kernel_size,
                padding=padding,
            ),
            layer.ResidualConv3dBlock(
                filters=filters,
                kernel_size=kernel_size,
                padding=padding,
            ),
        ])
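A minimal shape check (plain tf.keras, not DeepReg's layer module) illustrating the docstring's claim: a "same"-padded, stride-1 conv keeps the spatial dims and only changes the channel count.

import tensorflow as tf

conv = tf.keras.layers.Conv3D(filters=16, kernel_size=3, padding="same")
x = tf.zeros((2, 8, 8, 4, 32))  # (batch, dim1, dim2, dim3, channels_in)
print(conv(x).shape)            # (2, 8, 8, 4, 16): same spatial shape, new channel count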
Example #4
    def build_down_sampling_block(
            self, filters: int, kernel_size: int, padding: str,
            strides: int) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for down-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param filters: number of channels for output, arg for conv3d
        :param kernel_size: arg for pool3d or conv3d
        :param padding: arg for pool3d or conv3d
        :param strides: arg for pool3d or conv3d
        :return: a block consisting of one or multiple layers
        """
        if self._pooling:
            return tfkl.MaxPool3D(pool_size=kernel_size,
                                  strides=strides,
                                  padding=padding)
        else:
            return layer.Conv3dBlock(
                filters=filters,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
            )
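A minimal sketch (plain tf.keras, not DeepReg's layer module) contrasting the two down-sampling options: both halve the spatial dims with stride 2, but only the conv variant has trainable weights.

import tensorflow as tf

x = tf.zeros((2, 8, 8, 4, 16))  # (batch, dim1, dim2, dim3, channels)
pool = tf.keras.layers.MaxPool3D(pool_size=2, strides=2, padding="same")
conv = tf.keras.layers.Conv3D(filters=16, kernel_size=3, strides=2, padding="same")
print(pool(x).shape)  # (2, 4, 4, 2, 16), no trainable parameters
print(conv(x).shape)  # (2, 4, 4, 2, 16), learnable strided convolution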
Example #5
def test_localNetUpSampleResnetBlock():
    """
    Test the layer.LocalNetUpSampleResnetBlock class, its default attributes and its call() function.
    """
    batch_size = 5
    channels = 4
    input_size = (32, 32, 16)
    output_size = (64, 64, 32)

    nonskip_tensor_size = (batch_size, ) + input_size + (channels, )
    skip_tensor_size = (batch_size, ) + output_size + (channels, )

    # Test __init__() and build()
    model = layer.LocalNetUpSampleResnetBlock(8)
    model.build([nonskip_tensor_size, skip_tensor_size])

    assert model._filters == 8
    assert model._use_additive_upsampling is True

    assert isinstance(model._deconv3d_block, type(layer.Deconv3dBlock(8)))
    assert isinstance(model._additive_upsampling,
                      type(layer.AdditiveUpSampling(output_size)))
    assert isinstance(model._conv3d_block, type(layer.Conv3dBlock(8)))
    assert isinstance(model._residual_block,
                      type(layer.LocalNetResidual3dBlock(8)))
Example #6
def test_init_GlobalNet():
    """
    Test that GlobalNet is initialised as expected.
    """
    #  Initialising GlobalNet instance
    global_test = g.GlobalNet(
        image_size=[1, 2, 3],
        out_channels=3,
        num_channel_initial=3,
        extract_levels=[1, 2, 3],
        out_kernel_initializer="softmax",
        out_activation="softmax",
    )

    # Asserting initialised var for extract_levels is the same - Pass
    assert global_test._extract_levels == [1, 2, 3]
    # Asserting initialised var for extract_max_level is the same - Pass
    assert global_test._extract_max_level == 3

    # self reference grid
    # assert global_test.reference_grid correct shape, Pass
    assert global_test.reference_grid.shape == [1, 2, 3, 3]
    #  assert correct reference grid returned, Pass
    expected_ref_grid = tf.convert_to_tensor(
        [[
            [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0]],
            [[0.0, 1.0, 0.0], [0.0, 1.0, 1.0], [0.0, 1.0, 2.0]],
        ]],
        dtype=tf.float32,
    )
    assert is_equal_tf(global_test.reference_grid, expected_ref_grid)

    # Testing constant initializer
    #  We initialize the expected tensor and initialise another from the
    #  class variable using tf.Variable
    test_tensor_return = tf.convert_to_tensor(
        [[1.0, 0.0], [0.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0],
         [1.0, 0.0]],
        dtype=tf.float32,
    )
    global_return = tf.Variable(
        global_test.transform_initial(shape=[6, 2], dtype=tf.float32))

    # Asserting they are equal - Pass
    assert is_equal_tf(test_tensor_return,
                       tf.convert_to_tensor(global_return, dtype=tf.float32))

    # Assert downsample blocks type is correct, Pass
    assert all(
        isinstance(item, type(layer.DownSampleResnetBlock(12)))
        for item in global_test._downsample_blocks)
    #  Assert number of downsample blocks is correct (== max level), Pass
    assert len(global_test._downsample_blocks) == 3

    #  Assert conv3dBlock type is correct, Pass
    assert isinstance(global_test._conv3d_block, type(layer.Conv3dBlock(12)))

    #  Asserting type is dense_layer, Pass
    assert isinstance(global_test._dense_layer, type(layer.Dense(12)))
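A minimal sketch (not layer_util.get_reference_grid itself) that reproduces the expected reference grid asserted above for image_size=[1, 2, 3]: each voxel stores its own (i, j, k) coordinate along the last axis.

import tensorflow as tf

def make_reference_grid(image_size):
    # stack the per-axis voxel indices into a (dim1, dim2, dim3, 3) coordinate grid
    coords = tf.meshgrid(
        tf.range(image_size[0]),
        tf.range(image_size[1]),
        tf.range(image_size[2]),
        indexing="ij",
    )
    return tf.cast(tf.stack(coords, axis=-1), tf.float32)

print(make_reference_grid([1, 2, 3]).shape)  # (1, 2, 3, 3), matching expected_ref_grid above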
Example #7
def test_downsampleResnetBlock():
    """
    Test the layer.DownSampleResnetBlock class and its default attributes. There is no need to test the call()
    function since it is a concatenation of TensorFlow classes.
    """
    model = layer.DownSampleResnetBlock(8)

    assert model._pooling is True

    assert isinstance(model._conv3d_block, type(layer.Conv3dBlock(8)))
    assert isinstance(model._residual_block, type(layer.Residual3dBlock(8)))
    assert isinstance(model._max_pool3d, type(layer.MaxPool3d(2)))
    assert model._conv3d_block3 is None

    model = layer.DownSampleResnetBlock(8, pooling=False)
    assert model._max_pool3d is None
    assert isinstance(model._conv3d_block3, type(layer.Conv3dBlock(8)))
Example #8
    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        name: str = "GlobalNet",
        **kwargs,
    ):
        """
        Image is encoded gradually, i from level 0 to E.
        Then, a densely-connected layer outputs an affine
        transformation.

        :param image_size: tuple, such as (dim1, dim2, dim3)
        :param out_channels: int, number of channels for the output
        :param num_channel_initial: int, number of initial channels
        :param extract_levels: list, which levels from net to extract
        :param out_kernel_initializer: not used
        :param out_activation: not used
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        assert out_channels == 3
        self._extract_levels = extract_levels
        self._extract_max_level = max(self._extract_levels)  # E
        self.reference_grid = layer_util.get_reference_grid(image_size)
        self.transform_initial = tf.constant_initializer(
            value=list(np.eye(4, 3).reshape((-1))))
        # init layer variables
        num_channels = [
            num_channel_initial * (2**level)
            for level in range(self._extract_max_level + 1)
        ]  # level 0 to E
        self._downsample_blocks = [
            layer.DownSampleResnetBlock(filters=num_channels[i],
                                        kernel_size=7 if i == 0 else 3)
            for i in range(self._extract_max_level)
        ]  # level 0 to E-1
        self._conv3d_block = layer.Conv3dBlock(
            filters=num_channels[-1])  # level E
        self._dense_layer = layer.Dense(
            units=12, bias_initializer=self.transform_initial)
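A quick standalone check (not part of the snippet) of the bias initializer: np.eye(4, 3) is the identity affine in (4, 3) form, so the 12-unit dense layer starts out predicting the identity transform.

import numpy as np

theta_init = np.eye(4, 3)      # 3x3 identity stacked on a zero translation row
print(theta_init.reshape(-1))  # [1. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0.]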
Example #9
    def __init__(
        self,
        image_size,
        out_channels,
        num_channel_initial,
        depth,
        out_kernel_initializer,
        out_activation,
        pooling=True,
        concat_skip=False,
        **kwargs,
    ):
        """
        Initialise UNet.

        :param image_size: list, [f_dim1, f_dim2, f_dim3], dims of input image.
        :param out_channels: int, number of channels for the output
        :param num_channel_initial: int, number of initial channels
        :param depth: int, input is at level 0, bottom is at level depth
        :param out_kernel_initializer: str, which kernel to use as initialiser
        :param out_activation: str, activation at last layer
        :param pooling: Boolean, for downsampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: Boolean, when upsampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param kwargs:
        """
        super(UNet, self).__init__(**kwargs)

        # init layer variables
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]

        self._num_channel_initial = num_channel_initial
        self._depth = depth
        self._downsample_blocks = [
            layer.DownSampleResnetBlock(filters=num_channels[d], pooling=pooling)
            for d in range(depth)
        ]
        self._bottom_conv3d = layer.Conv3dBlock(filters=num_channels[depth])
        self._bottom_res3d = layer.Residual3dBlock(filters=num_channels[depth])
        self._upsample_blocks = [
            layer.UpSampleResnetBlock(filters=num_channels[d], concat=concat_skip)
            for d in range(depth)
        ]
        self._output_conv3d = layer.Conv3dWithResize(
            output_shape=image_size,
            filters=out_channels,
            kernel_initializer=out_kernel_initializer,
            activation=out_activation,
        )
Example #10
    def __init__(
        self,
        image_size,
        out_channels,
        num_channel_initial,
        extract_levels,
        out_kernel_initializer,
        out_activation,
        **kwargs,
    ):
        """
        Image is encoded gradually, i from level 0 to E.
        Then, a densely-connected layer outputs an affine
        transformation.

        :param image_size: (dim1, dim2, dim3), dims of input image
        :param out_channels: int, number of channels for the output
        :param num_channel_initial: int, number of initial channels
        :param extract_levels: list, which levels from net to extract
        :param out_activation: str, activation at last layer
        :param out_kernel_initializer: str, which kernel to use as initialiser
        :param kwargs:
        """
        super(GlobalNet, self).__init__(**kwargs)

        # save parameters
        self._extract_levels = extract_levels
        self._extract_max_level = max(self._extract_levels)  # E
        self.reference_grid = layer_util.get_reference_grid(image_size)
        self.transform_initial = tf.constant_initializer(
            value=[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
        )

        # init layer variables
        num_channels = [
            num_channel_initial * (2 ** level)
            for level in range(self._extract_max_level + 1)
        ]  # level 0 to E
        self._downsample_blocks = [
            layer.DownSampleResnetBlock(
                filters=num_channels[i], kernel_size=7 if i == 0 else 3
            )
            for i in range(self._extract_max_level)
        ]  # level 0 to E-1
        self._conv3d_block = layer.Conv3dBlock(filters=num_channels[-1])  # level E
        self._dense_layer = layer.Dense(
            units=12, bias_initializer=self.transform_initial
        )
        self._reshape = tf.keras.layers.Reshape(target_shape=(4, 3))
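A minimal sketch of how a (4, 3) theta can be applied to the reference grid in homogeneous coordinates. This is an assumption about downstream usage (DeepReg applies theta via layer_util.warp_grid); the sketch only shows the arithmetic, using the identity affine so the grid is left unchanged.

import numpy as np
import tensorflow as tf

grid = tf.cast(
    tf.stack(tf.meshgrid(tf.range(1), tf.range(2), tf.range(3), indexing="ij"), axis=-1),
    tf.float32,
)                                                    # (1, 2, 3, 3) reference grid
theta = tf.constant(np.eye(4, 3), dtype=tf.float32)  # identity affine, shape (4, 3)
ones = tf.ones_like(grid[..., :1])
warped = tf.einsum("ijkd,dc->ijkc", tf.concat([grid, ones], axis=-1), theta)
print(bool(tf.reduce_all(warped == grid)))           # True: identity theta does not move the grid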
Example #11
def test_conv3d_block():
    """
    Test the layer.Conv3dBlock class and its default attributes.
    """

    conv3d_block = layer.Conv3dBlock(8)

    assert isinstance(conv3d_block._conv3d, layer.Conv3d)

    assert conv3d_block._conv3d._conv3d.kernel_size == (3, 3, 3)
    assert conv3d_block._conv3d._conv3d.strides == (1, 1, 1)
    assert conv3d_block._conv3d._conv3d.padding == "same"
    assert conv3d_block._conv3d._conv3d.use_bias is False

    assert isinstance(conv3d_block._act._act, type(tf.keras.activations.relu))
    assert isinstance(conv3d_block._norm._norm, tf.keras.layers.BatchNormalization)
Example #12
File: local_net.py Project: vsaase/DeepReg
    def build_bottom_block(self, filters: int, kernel_size: int,
                           padding: str) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for bottom layer.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return layer.Conv3dBlock(filters=filters,
                                 kernel_size=kernel_size,
                                 padding=padding)
Example #13
def test_init_UNet():
    """
    Test that UNet is initialised as expected.
    """
    local_test = u.UNet(
        image_size=[1, 2, 3],
        out_channels=3,
        num_channel_initial=3,
        depth=5,
        out_kernel_initializer="he_normal",
        out_activation="softmax",
    )

    #  Asserting num channels initial is the same, Pass
    assert local_test._num_channel_initial == 3

    #  Asserting depth is the same, Pass
    assert local_test._depth == 5

    # Assert downsample blocks type is correct, Pass
    assert all(
        isinstance(item, type(layer.DownSampleResnetBlock(12)))
        for item in local_test._downsample_blocks
    )
    #  Assert number of downsample blocks is correct (== depth), Pass
    assert len(local_test._downsample_blocks) == 5

    #  Assert bottom_conv3d type is correct, Pass
    assert isinstance(local_test._bottom_conv3d, type(layer.Conv3dBlock(5)))

    # Assert bottom res3d type is correct, Pass
    assert isinstance(local_test._bottom_res3d, type(layer.Residual3dBlock(5)))
    # Assert upsample blocks type is correct, Pass
    assert all(
        isinstance(item, type(layer.UpSampleResnetBlock(12)))
        for item in local_test._upsample_blocks
    )
    #  Assert number of upsample blocks is correct (== depth), Pass
    assert len(local_test._upsample_blocks) == 5

    # Assert output_conv3d is correct type, Pass
    assert isinstance(
        local_test._output_conv3d, type(layer.Conv3dWithResize([1, 2, 3], 3))
    )
Example #14
def test_residual3dBlock():
    """
    Test the layer.Residual3dBlock class and its default attributes. There is no need to test the call()
    function since it is a concatenation of TensorFlow classes.
    """
    res3dBlock = layer.Residual3dBlock(8)

    assert isinstance(res3dBlock._conv3d_block, type(layer.Conv3dBlock(8)))
    assert res3dBlock._conv3d_block._conv3d._conv3d.kernel_size == (3, 3, 3)
    assert res3dBlock._conv3d_block._conv3d._conv3d.strides == (1, 1, 1)

    assert isinstance(res3dBlock._conv3d, type(layer.Conv3d(8)))
    assert res3dBlock._conv3d._conv3d.use_bias is False
    assert res3dBlock._conv3d._conv3d.kernel_size == (3, 3, 3)
    assert res3dBlock._conv3d._conv3d.strides == (1, 1, 1)

    assert isinstance(res3dBlock._act._act, type(tf.keras.activations.relu))
    assert isinstance(res3dBlock._norm._norm,
                      type(tf.keras.layers.BatchNormalization()))
Example #15
    def __init__(self,
                 image_size, out_channels,
                 num_channel_initial, extract_levels,
                 out_kernel_initializer, out_activation,
                 **kwargs):
        """
        Image is encoded gradually, i from level 0 to E,
        then it is decoded gradually, j from level E to D.
        Some of the decoded levels are used for generating extractions,
        so extract_levels are between [0, E] with E = max(extract_levels)
        and D = min(extract_levels).

        :param image_size: [dim1, dim2, dim3], dims of input image
        :param out_channels: number of channels for the extractions
        :param num_channel_initial: number of initial channels
        :param extract_levels: list of levels from which to extract
        :param out_kernel_initializer: initializer to use for kernels
        :param out_activation: activation to use at end layer
        :param kwargs: additional arguments
        """
        super(LocalNet, self).__init__(**kwargs)

        # save parameters
        self._extract_levels = extract_levels
        self._extract_max_level = max(self._extract_levels)  # E
        self._extract_min_level = min(self._extract_levels)  # D

        # init layer variables

        nc = [num_channel_initial * (2 ** level) for level in range(self._extract_max_level + 1)]  # level 0 to E
        self._downsample_blocks = [layer.DownSampleResnetBlock(filters=nc[i], kernel_size=7 if i == 0 else 3)
                                   for i in range(self._extract_max_level)]  # level 0 to E-1
        self._conv3d_block = layer.Conv3dBlock(filters=nc[-1])  # level E

        self._upsample_blocks = [layer.LocalNetUpSampleResnetBlock(nc[level]) for level in
                                 range(self._extract_max_level - 1, self._extract_min_level - 1, -1)]  # level D to E-1

        self._extract_layers = [
            # if kernels are not initialized by zeros, with init NN, extract may be too large
            layer.Conv3dWithResize(output_shape=image_size, filters=out_channels,
                                   kernel_initializer=out_kernel_initializer,
                                   activation=out_activation)
            for _ in self._extract_levels]
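A quick standalone check (not part of the snippet) of the level bookkeeping for an illustrative extract_levels=[2, 3, 4]: E is the deepest encoded level, D the shallowest extracted level, and the decoder builds one upsample block per level from E-1 down to D.

extract_levels = [2, 3, 4]
E, D = max(extract_levels), min(extract_levels)
print(E, D)                           # 4 2
print(list(range(E - 1, D - 1, -1)))  # [3, 2]: levels covered by the upsample blocks
print(len(extract_levels))            # 3 extraction heads, one per requested level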
Example #16
def test_upsampleResnetBlock():
    """
    Test the layer.UpSampleResnetBlock class and its default attributes. There is no need to test the call()
    function since it is a concatenation of TensorFlow classes.
    """
    batch_size = 5
    channels = 4
    input_size = (32, 32, 16)
    output_size = (64, 64, 32)

    input_tensor_size = (batch_size, ) + input_size + (channels, )
    skip_tensor_size = (batch_size, ) + output_size + (channels // 2, )

    model = layer.UpSampleResnetBlock(8)
    model.build([input_tensor_size, skip_tensor_size])

    assert model._filters == 8
    assert model._concat is False
    assert isinstance(model._conv3d_block, type(layer.Conv3dBlock(8)))
    assert isinstance(model._residual_block, type(layer.Residual3dBlock(8)))
    assert isinstance(model._deconv3d_block, type(layer.Deconv3dBlock(8)))
Example #17
    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Image is encoded gradually, i from level 0 to E,
        then it is decoded gradually, j from level E to D.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, E] with E = max(extract_levels),
        and D = min(extract_levels).

        :param image_size: such as (dim1, dim2, dim3)
        :param out_channels: number of channels for the extractions
        :param num_channel_initial: number of initial channels.
        :param extract_levels: list of levels from which to extract.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        self._extract_levels = extract_levels
        self._extract_max_level = max(self._extract_levels)  # E
        self._extract_min_level = min(self._extract_levels)  # D

        # init layer variables
        num_channels = [
            num_channel_initial * (2**level)
            for level in range(self._extract_max_level + 1)
        ]  # level 0 to E
        self._downsample_blocks = [
            layer.DownSampleResnetBlock(filters=num_channels[i],
                                        kernel_size=7 if i == 0 else 3)
            for i in range(self._extract_max_level)
        ]  # level 0 to E-1
        self._conv3d_block = layer.Conv3dBlock(
            filters=num_channels[-1])  # level E

        self._upsample_blocks = [
            layer.LocalNetUpSampleResnetBlock(num_channels[level])
            for level in range(self._extract_max_level -
                               1, self._extract_min_level - 1, -1)
        ]  # level D to E-1

        self._extract_layers = [
            # if kernels are not initialized by zeros, with init NN, extract may be too large
            layer.Conv3dWithResize(
                output_shape=image_size,
                filters=out_channels,
                kernel_initializer=out_kernel_initializer,
                activation=out_activation,
            ) for _ in self._extract_levels
        ]
Example #18
File: u_net.py Project: zy20030535/DeepReg
    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: str,
        out_activation: str,
        pooling: bool = True,
        concat_skip: bool = False,
        control_points: (tuple, None) = None,
        **kwargs,
    ):
        """
        Initialise UNet.

        :param image_size: tuple, (dim1, dim2, dim3), dims of input image.
        :param out_channels: int, number of channels for the output
        :param num_channel_initial: int, number of initial channels
        :param depth: int, input is at level 0, bottom is at level depth
        :param out_kernel_initializer: str, which kernel to use as initializer
        :param out_activation: str, activation at last layer
        :param pooling: Boolean, for downsampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: Boolean, when upsampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param control_points: (tuple, None), specify the distance between control points (in voxels).
        :param kwargs:
        """
        super(UNet, self).__init__(**kwargs)

        # init layer variables
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]

        self._num_channel_initial = num_channel_initial
        self._depth = depth
        self._downsample_blocks = [
            layer.DownSampleResnetBlock(filters=num_channels[d], pooling=pooling)
            for d in range(depth)
        ]
        self._bottom_conv3d = layer.Conv3dBlock(filters=num_channels[depth])
        self._bottom_res3d = layer.Residual3dBlock(filters=num_channels[depth])
        self._upsample_blocks = [
            layer.UpSampleResnetBlock(filters=num_channels[d], concat=concat_skip)
            for d in range(depth)
        ]
        self._output_conv3d = layer.Conv3dWithResize(
            output_shape=image_size,
            filters=out_channels,
            kernel_initializer=out_kernel_initializer,
            activation=out_activation,
        )

        self.resize = (
            layer.ResizeCPTransform(control_points)
            if control_points is not None
            else False
        )
        self.interpolate = (
            layer.BSplines3DTransform(control_points, image_size)
            if control_points is not None
            else False
        )
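The constructor above uses False as an "off" sentinel for the control-point layers. A small standalone sketch (hypothetical function name; an assumption about how such flags are typically consumed at call time, not DeepReg's call() code):

def postprocess_ddf(ddf, resize, interpolate):
    # resize / interpolate are either callables or False, as set in the constructor above
    if resize:
        ddf = resize(ddf)
        ddf = interpolate(ddf)
    return ddf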
Example #19
    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        control_points: (tuple, None) = None,
        **kwargs,
    ):
        """
        Image is encoded gradually, i from level 0 to E,
        then it is decoded gradually, j from level E to D.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, E] with E = max(extract_levels),
        and D = min(extract_levels).

        :param image_size: tuple, such as (dim1, dim2, dim3)
        :param out_channels: int, number of channels for the extractions
        :param num_channel_initial: int, number of initial channels.
        :param extract_levels: list of int, levels from which to extract.
        :param out_kernel_initializer: str, initializer to use for kernels.
        :param out_activation: str, activation to use at end layer.
        :param control_points: (tuple, None), specify the distance between control points (in voxels).
        :param kwargs:
        """
        super(LocalNet, self).__init__(**kwargs)

        # save parameters
        self._extract_levels = extract_levels
        self._extract_max_level = max(self._extract_levels)  # E
        self._extract_min_level = min(self._extract_levels)  # D

        # init layer variables

        num_channels = [
            num_channel_initial * (2**level)
            for level in range(self._extract_max_level + 1)
        ]  # level 0 to E
        self._downsample_blocks = [
            layer.DownSampleResnetBlock(filters=num_channels[i],
                                        kernel_size=7 if i == 0 else 3)
            for i in range(self._extract_max_level)
        ]  # level 0 to E-1
        self._conv3d_block = layer.Conv3dBlock(
            filters=num_channels[-1])  # level E

        self._upsample_blocks = [
            layer.LocalNetUpSampleResnetBlock(num_channels[level])
            for level in range(self._extract_max_level -
                               1, self._extract_min_level - 1, -1)
        ]  # level D to E-1

        self._extract_layers = [
            # if kernels are not initialized by zeros, with init NN, extract may be too large
            layer.Conv3dWithResize(
                output_shape=image_size,
                filters=out_channels,
                kernel_initializer=out_kernel_initializer,
                activation=out_activation,
            ) for _ in self._extract_levels
        ]

        self.resize = (layer.ResizeCPTransform(control_points)
                       if control_points is not None else False)
        self.interpolate = (layer.BSplines3DTransform(control_points,
                                                      image_size)
                            if control_points is not None else False)