Esempio n. 1
0
def test_efficientnet_b0_return_tensor():
    """Building an untrained EfficientNet-B0 block yields a single output."""
    effnet_block = blocks.EfficientNetBlock(version="b0", pretrained=False)
    hyperparams = keras_tuner.HyperParameters()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)

    built = effnet_block.build(hyperparams, image_input)
    flat_outputs = nest.flatten(built)

    # The block should produce exactly one output, not a nested structure.
    assert len(flat_outputs) == 1
Esempio n. 2
0
def test_efficientnet_wrong_version_error():
    """An unknown EfficientNet version string is rejected with ValueError."""
    # `match` is applied with re.search to str(exception); the expected
    # fragment contains no regex metacharacters, so this is a plain
    # substring check.
    with pytest.raises(ValueError, match="Expect version to be"):
        blocks.EfficientNetBlock(version="abc")