def test_resnet_get_config_has_all_attributes():
    """Every ``ResNetBlock.__init__`` argument must appear in ``get_config()``."""
    resnet_config = blocks.ResNetBlock().get_config()
    init_arg_names = utils.get_func_args(blocks.ResNetBlock.__init__)

    assert init_arg_names.issubset(resnet_config.keys())
def test_resnet_build_return_tensor():
    """Building a default ResNetBlock on a 32x32x3 input yields one tensor."""
    block = blocks.ResNetBlock()

    # ``keras_tuner`` (not the legacy ``kerastuner`` name) matches the other
    # tests in this module.
    outputs = block.build(
        keras_tuner.HyperParameters(),
        tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32),
    )

    assert len(nest.flatten(outputs)) == 1
    # Symbolic Keras outputs are KerasTensors (TF >= 2.4), which are not
    # ``tf.Tensor`` instances; ``tf.is_tensor`` is the recommended check.
    assert tf.is_tensor(nest.flatten(outputs)[0])
def test_resnet_pretrained_error_with_two_channels():
    """A pretrained ResNetBlock raises ValueError for a 2-channel input."""
    block = blocks.ResNetBlock(pretrained=True)

    # ``keras_tuner`` (not the legacy ``kerastuner`` name) matches the other
    # tests in this module.
    with pytest.raises(ValueError) as info:
        block.build(
            keras_tuner.HyperParameters(),
            tf.keras.Input(shape=(224, 224, 2), dtype=tf.float32),
        )

    assert "When pretrained is set to True" in str(info.value)
def test_resnet_pretrained_with_one_channel_input():
    """A pretrained ResNetBlock accepts a 1-channel 28x28 input."""
    block = blocks.ResNetBlock(pretrained=True)

    # ``keras_tuner`` (not the legacy ``kerastuner`` name) matches the other
    # tests in this module.
    outputs = block.build(
        keras_tuner.HyperParameters(),
        tf.keras.Input(shape=(28, 28, 1), dtype=tf.float32),
    )

    assert len(nest.flatten(outputs)) == 1
    # Symbolic Keras outputs are KerasTensors (TF >= 2.4), which are not
    # ``tf.Tensor`` instances; ``tf.is_tensor`` is the recommended check.
    assert tf.is_tensor(nest.flatten(outputs)[0])
def test_resnet_pretrained_build_return_tensor():
    """A pretrained ResNetBlock builds on a 32x32x3 input with one output."""
    pretrained_block = blocks.ResNetBlock(pretrained=True)
    hyperparameters = keras_tuner.HyperParameters()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)

    flat_outputs = nest.flatten(
        pretrained_block.build(hyperparameters, image_input)
    )

    assert len(flat_outputs) == 1
def test_resnet_v1_return_tensor():
    """A ResNetBlock with version "v1" builds on a 32x32x3 input."""
    v1_block = blocks.ResNetBlock(version="v1")
    hyperparameters = keras_tuner.HyperParameters()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)

    flat_outputs = nest.flatten(v1_block.build(hyperparameters, image_input))

    assert len(flat_outputs) == 1
def test_resnet_init_error_with_input_shape():
    """Passing ``input_shape`` to the ResNetBlock constructor raises ValueError."""
    with pytest.raises(ValueError) as exc_info:
        blocks.ResNetBlock(input_shape=(10,))

    error_message = str(exc_info.value)
    assert 'Argument "input_shape" is not' in error_message
def test_resnet_init_error_with_include_top():
    """Passing ``include_top`` to the ResNetBlock constructor raises ValueError."""
    with pytest.raises(ValueError) as exc_info:
        blocks.ResNetBlock(include_top=True)

    error_message = str(exc_info.value)
    assert 'Argument "include_top" is not' in error_message
def test_resnet_deserialize_to_resnet():
    """A serialize/deserialize round trip yields a ResNetBlock instance."""
    payload = blocks.serialize(blocks.ResNetBlock())
    restored = blocks.deserialize(payload)

    assert isinstance(restored, blocks.ResNetBlock)
def test_resnet_wrong_version_error():
    """An unrecognised ``version`` string is rejected with ValueError."""
    with pytest.raises(ValueError) as exc_info:
        blocks.ResNetBlock(version="abc")

    error_message = str(exc_info.value)
    assert "Expect version to be" in error_message