def test_efficientnet_b0_return_tensor():
    """Check that building a "b0" EfficientNetBlock yields exactly one output.

    The block is created without pretrained weights and built on a
    32x32x3 float32 Keras input; the flattened build result must contain
    a single element.
    """
    efficientnet = blocks.EfficientNetBlock(version="b0", pretrained=False)

    hyperparameters = keras_tuner.HyperParameters()
    input_node = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    built = efficientnet.build(hyperparameters, input_node)

    flat_outputs = nest.flatten(built)
    assert len(flat_outputs) == 1
def test_efficientnet_wrong_version_error():
    """Check that version="abc" makes EfficientNetBlock raise ValueError.

    The raised error's message must contain "Expect version to be".
    """
    with pytest.raises(ValueError) as info:
        blocks.EfficientNetBlock(version="abc")

    # Inspect the captured exception only after the context has exited;
    # statements placed after the raising call inside it would never run.
    error_message = str(info.value)
    assert "Expect version to be" in error_message