コード例 #1
0
def test_init_UNet():
    """
    Testing init of UNet as expected.

    Checks that the constructor arguments are stored on the instance and that
    every sub-layer attribute holds an instance of the expected layer class.
    """
    depth = 5
    local_test = u.UNet(
        image_size=[1, 2, 3],
        out_channels=3,
        num_channel_initial=3,
        depth=depth,
        out_kernel_initializer="he_normal",
        out_activation="softmax",
    )

    #  Asserting num channels initial is the same
    assert local_test._num_channel_initial == 3

    #  Asserting depth is the same
    assert local_test._depth == depth

    # Layer types are checked against the classes directly. Instantiating
    # throw-away layers (e.g. DownSampleResnetBlock(12)) only to call type()
    # on them built extra layers and tied this test to those classes'
    # unrelated constructor signatures.

    # Assert downsample blocks type is correct
    assert all(
        isinstance(item, layer.DownSampleResnetBlock)
        for item in local_test._downsample_blocks
    )
    #  Assert number of downsample blocks is correct (== depth)
    assert len(local_test._downsample_blocks) == depth

    #  Assert bottom_conv3d type is correct
    assert isinstance(local_test._bottom_conv3d, layer.Conv3dBlock)

    # Assert bottom res3d type is correct
    assert isinstance(local_test._bottom_res3d, layer.Residual3dBlock)

    # Assert upsample blocks type is correct
    assert all(
        isinstance(item, layer.UpSampleResnetBlock)
        for item in local_test._upsample_blocks
    )
    #  Assert number of upsample blocks is correct (== depth)
    assert len(local_test._upsample_blocks) == depth

    # Assert output_conv3d is correct type
    assert isinstance(local_test._output_conv3d, layer.Conv3dWithResize)
コード例 #2
0
def test_unet_return():
    """
    Testing that build_backbone func returns an object
    of type UNet from backbone module when initialised
    with the associated UNet config.
    """
    out = util.build_backbone(
        image_size=(1, 2, 3),
        out_channels=1,
        model_config={
            "backbone": "unet",
            "unet": {
                "num_channel_initial": 4,
                "depth": 4
            },
        },
        method_name="ddf",
    )
    # Check against the class itself; building a second UNet only to call
    # type() on it is wasted work and depends on UNet's positional signature.
    assert isinstance(out, u_net.UNet)
コード例 #3
0
 def test_call_unet(self, image_size, depth):
     """
     Check that calling UNet returns a tensor shaped exactly like its input.

     The input has ``out`` channels and the network is built with
     ``out_channels=out``, so the output shape must equal the input shape
     ``(5, image_size[0], image_size[1], image_size[2], out)``.
     """
     out = 3
     # initialising UNet instance; out_channels reuses `out` so the input and
     # output channel counts cannot silently drift apart
     network = u.UNet(
         image_size=image_size,
         out_channels=out,
         num_channel_initial=2,
         depth=depth,
         out_kernel_initializer="he_normal",
         out_activation="softmax",
     )
     # pass an input of all zeros
     inputs = tf.constant(
         np.zeros(
             (5, image_size[0], image_size[1], image_size[2], out), dtype=np.float32
         )
     )
     # get outputs by calling
     output = network.call(inputs)
     # compare full shapes: zip() stops at the shorter shape, so a rank
     # mismatch between input and output would otherwise pass unnoticed
     assert tuple(output.shape) == tuple(inputs.shape)
コード例 #4
0
def test_call_UNet():
    """
    Asserting that output shape of UNet call method
    is correct.

    With ``out_channels`` equal to the number of input channels, the output
    must have exactly the input shape ``(5, *im_size, out)``.
    """
    out = 3
    im_size = [1, 2, 3]
    #  Initialising UNet instance
    global_test = u.UNet(
        image_size=im_size,
        out_channels=out,
        num_channel_initial=3,
        depth=6,
        out_kernel_initializer="glorot_uniform",
        out_activation="sigmoid",
    )
    # Pass an input of all zeros. The channel count reuses `out` so input and
    # output shapes are directly comparable; float32 matches the Keras
    # default float type.
    inputs = np.zeros(
        (5, im_size[0], im_size[1], im_size[2], out), dtype=np.float32
    )
    #  Get outputs by calling
    output = global_test.call(inputs)
    # Compare full shapes: zip() stops at the shorter shape, so a rank
    # mismatch between input and output would otherwise pass unnoticed
    assert tuple(output.shape) == inputs.shape
コード例 #5
0
    def test_init(self, image_size, depth):
        """
        Check that UNet stores its configuration and builds the expected sub-layers.
        """
        unet = u.UNet(
            image_size=image_size,
            out_channels=3,
            num_channel_initial=2,
            depth=depth,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
        )

        # stored hyper-parameters mirror the constructor arguments
        assert unet._num_channel_initial == 2
        assert unet._depth == depth

        # downsampling path: `depth` DownSampleResnetBlock instances
        downs = unet._downsample_blocks
        assert all(isinstance(block, layer.DownSampleResnetBlock) for block in downs)
        assert len(downs) == depth

        # bottom of the network
        assert isinstance(unet._bottom_conv3d, layer.Conv3dBlock)
        assert isinstance(unet._bottom_res3d, layer.Residual3dBlock)

        # upsampling path: `depth` UpSampleResnetBlock instances
        ups = unet._upsample_blocks
        assert all(isinstance(block, layer.UpSampleResnetBlock) for block in ups)
        assert len(ups) == depth

        # output layer
        assert isinstance(unet._output_conv3d, layer.Conv3dWithResize)