# Example #1
# 0
def test_call_global_net():
    """
    Check the shapes of the two tensors returned by GlobalNet.call.

    A batch of all-zero float32 volumes is pushed through the network.
    The first output (ddf) must have shape (batch, *image_size, 3).
    The second output (theta) must have shape (batch, 4, 3).
    """
    channels = 3
    spatial = (1, 2, 3)
    n_batch = 5
    # build the network under test
    net = g.GlobalNet(
        image_size=spatial,
        out_channels=channels,
        num_channel_initial=3,
        extract_levels=[1, 2, 3],
        out_kernel_initializer="softmax",
        out_activation="softmax",
    )
    # all-zero input of shape (batch, d1, d2, d3, channels)
    zeros = np.zeros((n_batch, *spatial, channels), dtype=np.float32)
    ddf, theta = net.call(tf.constant(zeros))

    expected_ddf_shape = (n_batch, *spatial, 3)
    expected_theta_shape = (n_batch, 4, 3)
    assert ddf.shape == expected_ddf_shape
    assert theta.shape == expected_theta_shape
# Example #2
# 0
def test_init_GlobalNet():
    """
    Check that constructing a GlobalNet sets up its attributes and sub-layers.

    Covers the stored extraction levels, the reference grid (shape and
    values), the values produced by ``transform_initial``, and the types and
    count of the internal layers.
    """
    net = g.GlobalNet(
        image_size=[1, 2, 3],
        out_channels=3,
        num_channel_initial=3,
        extract_levels=[1, 2, 3],
        out_kernel_initializer="softmax",
        out_activation="softmax",
    )

    # Extraction levels are kept as given; the max level is expected to be 3.
    assert net._extract_levels == [1, 2, 3]
    assert net._extract_max_level == 3

    # Reference grid: for image_size [1, 2, 3] the expected grid has shape
    # [1, 2, 3, 3] and entry (0, j, k) holds the coordinates [0, j, k].
    assert net.reference_grid.shape == [1, 2, 3, 3]
    grid_values = [
        [[[0.0, float(j), float(k)] for k in range(3)] for j in range(2)]
    ]
    ref_grid_expected = tf.convert_to_tensor(grid_values, dtype=tf.float32)
    assert is_equal_tf(net.reference_grid, ref_grid_expected)

    # transform_initial is asked for a [6, 2] tensor; the result is wrapped
    # in a tf.Variable and compared against the hard-coded expected values.
    init_expected = tf.convert_to_tensor(
        [
            [1.0, 0.0],
            [0.0, 0.0],
            [0.0, 1.0],
            [0.0, 0.0],
            [0.0, 0.0],
            [1.0, 0.0],
        ],
        dtype=tf.float32,
    )
    init_var = tf.Variable(
        net.transform_initial(shape=[6, 2], dtype=tf.float32))
    init_actual = tf.convert_to_tensor(init_var, dtype=tf.float32)
    assert is_equal_tf(init_expected, init_actual)

    # One DownSampleResnetBlock per level up to the max level (3 here).
    down_blocks = net._downsample_blocks
    assert all(
        isinstance(blk, layer.DownSampleResnetBlock) for blk in down_blocks)
    assert len(down_blocks) == 3

    # Remaining sub-layers.
    assert isinstance(net._conv3d_block, layer.Conv3dBlock)
    assert isinstance(net._dense_layer, layer.Dense)
# Example #3
# 0
def test_call_GlobalNet():
    """
    Asserting that the outputs of the GlobalNet call method have the
    expected shapes.

    ``call`` returns a ``(ddf, theta)`` pair. ddf is expected to have shape
    (batch, *image_size, 3) and theta shape (batch, 4, 3).
    """
    out = 3
    im_size = [1, 2, 3]
    batch_size = 5
    #  Initialising GlobalNet instance
    global_test = g.GlobalNet(
        image_size=im_size,
        out_channels=out,
        num_channel_initial=3,
        extract_levels=[1, 2, 3],
        out_kernel_initializer="softmax",
        out_activation="softmax",
    )
    # Pass an input of all zeros. Use an explicit float32 tf tensor rather
    # than numpy's default float64.
    inputs = tf.constant(
        np.zeros((batch_size, *im_size, out), dtype=np.float32)
    )
    # call returns a (ddf, theta) pair, not a single tensor, so unpack it;
    # reading ``.shape`` off the tuple would raise AttributeError.
    ddf, theta = global_test.call(inputs)
    # Compare complete shapes. The old zip()-based check silently ignored a
    # rank mismatch, and it took the expected ddf shape from ``inputs``.
    assert tuple(ddf.shape) == (batch_size, *im_size, 3)
    assert tuple(theta.shape) == (batch_size, 4, 3)
def test_global_return():
    """
    Testing that build_backbone func returns an object
    of type GlobalNet from backbone module when initialised
    with the associated GlobalNet config.
    """
    out = util.build_backbone(
        image_size=(1, 2, 3),
        out_channels=1,
        model_config={
            "backbone": "global",
            "global": {
                "num_channel_initial": 4,
                "extract_levels": [1, 2, 3],
            },
        },
        method_name="ddf",
    )
    # Check against the class itself. The previous version constructed a
    # throwaway GlobalNet only to call type() on it, which is equivalent
    # but does extra work and can fail for unrelated reasons.
    assert isinstance(out, global_net.GlobalNet)