def test_augment_get_config_has_all_attributes():
    """Every ``__init__`` argument of ImageAugmentation must appear in its config."""
    init_arg_names = test_utils.get_func_args(blocks.ImageAugmentation.__init__)
    augmentation_config = blocks.ImageAugmentation().get_config()
    assert init_arg_names.issubset(augmentation_config.keys())
def test_augment_build_return_tensor():
    """Building a default ImageAugmentation block yields exactly one tensor."""
    block = blocks.ImageAugmentation()
    # Use the ``keras_tuner`` package name, consistent with the other tests in
    # this file; the legacy ``kerastuner`` alias is not used elsewhere here.
    outputs = block.build(
        keras_tuner.HyperParameters(),
        tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32),
    )
    flat_outputs = nest.flatten(outputs)
    assert len(flat_outputs) == 1
    assert isinstance(flat_outputs[0], tf.Tensor)
def test_augment_build_with_contrast_factor_return_tensor():
    """Building with a contrast factor yields exactly one tensor."""
    block = blocks.ImageAugmentation(contrast_factor=0.1)
    outputs = block.build(
        keras_tuner.HyperParameters(),
        tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32),
    )
    flat_outputs = nest.flatten(outputs)
    assert len(flat_outputs) == 1
    # The test name promises a tensor; check the type, not just the count.
    assert isinstance(flat_outputs[0], tf.Tensor)
def test_augment_build_with_vflip_only_return_tensor():
    """Building with only vertical flipping enabled yields exactly one tensor."""
    block = blocks.ImageAugmentation(vertical_flip=True, horizontal_flip=False)
    outputs = block.build(
        keras_tuner.HyperParameters(),
        tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32),
    )
    flat_outputs = nest.flatten(outputs)
    assert len(flat_outputs) == 1
    # The test name promises a tensor; check the type, not just the count.
    assert isinstance(flat_outputs[0], tf.Tensor)
def test_augment_deserialize_to_augment():
    """A serialize/deserialize round trip restores an ImageAugmentation."""
    original_block = blocks.ImageAugmentation()
    payload = blocks.serialize(original_block)
    restored_block = blocks.deserialize(payload)
    assert isinstance(restored_block, blocks.ImageAugmentation)