Пример #1
0
def test_resnet_get_config_has_all_attributes():
    """Check that a default ResNetBlock's config covers its ``__init__`` args.

    The expected names are whatever ``utils.get_func_args`` reports for
    ``ResNetBlock.__init__``; each one must be a key of ``get_config()``.
    """
    config = blocks.ResNetBlock().get_config()
    init_arg_names = utils.get_func_args(blocks.ResNetBlock.__init__)

    assert init_arg_names.issubset(config.keys())
Пример #2
0
def test_resnet_build_return_tensor():
    """Check that building a default ResNetBlock yields one ``tf.Tensor``.

    The block is built on a float32 Keras input of shape (32, 32, 3).
    """
    resnet = blocks.ResNetBlock()
    hp = kerastuner.HyperParameters()
    input_node = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)

    # Flatten once so both assertions inspect the same list.
    flat_outputs = nest.flatten(resnet.build(hp, input_node))

    assert len(flat_outputs) == 1
    assert isinstance(flat_outputs[0], tf.Tensor)
Пример #3
0
def test_resnet_pretrained_error_with_two_channels():
    """Check that a pretrained ResNetBlock rejects a two-channel input.

    Building on a (224, 224, 2) input must raise ``ValueError`` whose
    message contains 'When pretrained is set to True'.
    """
    resnet = blocks.ResNetBlock(pretrained=True)

    with pytest.raises(ValueError) as excinfo:
        resnet.build(
            kerastuner.HyperParameters(),
            tf.keras.Input(shape=(224, 224, 2), dtype=tf.float32),
        )

    error_message = str(excinfo.value)
    assert 'When pretrained is set to True' in error_message
Пример #4
0
def test_resnet_pretrained_with_one_channel_input():
    """Check that a pretrained ResNetBlock builds on a one-channel input.

    Building on a float32 (28, 28, 1) input must yield exactly one
    ``tf.Tensor``.
    """
    resnet = blocks.ResNetBlock(pretrained=True)
    hp = kerastuner.HyperParameters()
    input_node = tf.keras.Input(shape=(28, 28, 1), dtype=tf.float32)

    # Flatten once so both assertions inspect the same list.
    flat_outputs = nest.flatten(resnet.build(hp, input_node))

    assert len(flat_outputs) == 1
    assert isinstance(flat_outputs[0], tf.Tensor)
Пример #5
0
def test_resnet_pretrained_build_return_tensor():
    """Check that a pretrained ResNetBlock yields a single output.

    The block is built on a float32 (32, 32, 3) Keras input; only the
    number of flattened outputs is checked.
    """
    resnet = blocks.ResNetBlock(pretrained=True)
    hp = keras_tuner.HyperParameters()
    input_node = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)

    flat_outputs = nest.flatten(resnet.build(hp, input_node))

    assert len(flat_outputs) == 1
Пример #6
0
def test_resnet_v1_return_tensor():
    """Check that a ResNetBlock with ``version="v1"`` yields a single output.

    The block is built on a float32 (32, 32, 3) Keras input; only the
    number of flattened outputs is checked.
    """
    resnet = blocks.ResNetBlock(version="v1")
    hp = keras_tuner.HyperParameters()
    input_node = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)

    flat_outputs = nest.flatten(resnet.build(hp, input_node))

    assert len(flat_outputs) == 1
Пример #7
0
def test_resnet_init_error_with_input_shape():
    """Check that passing ``input_shape`` to ResNetBlock raises ``ValueError``.

    The error message must contain 'Argument "input_shape" is not'.
    """
    with pytest.raises(ValueError) as excinfo:
        blocks.ResNetBlock(input_shape=(10,))

    error_message = str(excinfo.value)
    assert 'Argument "input_shape" is not' in error_message
Пример #8
0
def test_resnet_init_error_with_include_top():
    """Check that passing ``include_top`` to ResNetBlock raises ``ValueError``.

    The error message must contain 'Argument "include_top" is not'.
    """
    with pytest.raises(ValueError) as excinfo:
        blocks.ResNetBlock(include_top=True)

    error_message = str(excinfo.value)
    assert 'Argument "include_top" is not' in error_message
Пример #9
0
def test_resnet_deserialize_to_resnet():
    """Check that a serialized ResNetBlock deserializes to a ResNetBlock."""
    restored = blocks.deserialize(blocks.serialize(blocks.ResNetBlock()))

    assert isinstance(restored, blocks.ResNetBlock)
Пример #10
0
def test_resnet_wrong_version_error():
    """Check that an unknown ``version`` makes ResNetBlock raise ``ValueError``.

    The error message must contain "Expect version to be".
    """
    with pytest.raises(ValueError) as excinfo:
        blocks.ResNetBlock(version="abc")

    error_message = str(excinfo.value)
    assert "Expect version to be" in error_message