def test_input_requirement():
    """Check that HyperImageAugment needs an input spec to be constructed.

    Constructing with no arguments must raise a ``ValueError`` whose message
    matches "must specify"; passing either ``input_shape`` or ``input_tensor``
    must produce a hypermodel whose built model reports ``built``.
    """
    hps = hp_module.HyperParameters()

    # No input_shape / input_tensor given: construction is rejected.
    with pytest.raises(ValueError, match=r".*must specify.*"):
        aug_module.HyperImageAugment()

    # Input described by a shape.
    from_shape = aug_module.HyperImageAugment(input_shape=(None, None, 3))
    shape_model = from_shape.build(hps)
    assert shape_model.built

    # Input described by an existing Keras tensor.
    input_tensor = keras.Input(shape=(32, 32, 3))
    from_tensor = aug_module.HyperImageAugment(input_tensor=input_tensor)
    tensor_model = from_tensor.build(hps)
    assert tensor_model.built
def test_model_construction_factor_zero():
    """Check that default hyperparameter values yield a one-layer model.

    Covers both the default ``augment_layers`` setting and the explicit
    ``augment_layers=0`` setting.
    """
    input_shape = (None, None, 3)

    # Default augment_layers search space is [0, 4] with default zero, so the
    # built model is expected to hold a single layer.
    default_hypermodel = aug_module.HyperImageAugment(input_shape=input_shape)
    default_model = default_hypermodel.build(hp_module.HyperParameters())
    assert len(default_model.layers) == 1

    # With augment_layers=0 the factors all default to zero, so the model
    # should only have the input layer.
    fixed_hypermodel = aug_module.HyperImageAugment(
        input_shape=input_shape,
        augment_layers=0,
    )
    fixed_model = fixed_hypermodel.build(hp_module.HyperParameters())
    assert len(fixed_model.layers) == 1
def test_hyperparameter_existence_and_hp_defaults_rand_aug():
    """Check the ``augment_layers`` hyperparameter default for a given range.

    With ``augment_layers=[2, 5]`` the value recorded under "augment_layers"
    after building is expected to be 2.
    """
    hps = hp_module.HyperParameters()
    hypermodel = aug_module.HyperImageAugment(
        input_shape=(32, 32, 3),
        augment_layers=[2, 5],
        contrast=False,
    )

    # Building registers the hyperparameters on `hps`.
    hypermodel.build(hps)

    assert hps.get("augment_layers") == 2
def test_tf_version_too_low_error():
    """Check that HyperImageAugment raises ImportError without preprocessing.

    The test simulates a missing preprocessing module by setting
    ``aug_module.preprocessing`` to ``None`` and expects construction to raise
    an ``ImportError`` whose message matches "HyperImageAugment requires".
    """
    pp_module = aug_module.preprocessing
    aug_module.preprocessing = None
    try:
        with pytest.raises(ImportError, match="HyperImageAugment requires"):
            aug_module.HyperImageAugment()
    finally:
        # Always restore the module attribute, even when the assertion above
        # fails; otherwise the patched global would leak into later tests.
        aug_module.preprocessing = pp_module
def test_hyperparameter_override_rand_aug():
    """Check that pre-registered rand-augment hyperparameters are kept.

    ``randaug_mag`` and ``randaug_count`` are registered on the
    ``HyperParameters`` object before building; after building they must still
    hold the pre-registered values.
    """
    hps = hp_module.HyperParameters()
    hps.Fixed("randaug_mag", 1.0)
    hps.Choice("randaug_count", [4])

    hypermodel = aug_module.HyperImageAugment(
        input_shape=(32, 32, 3),
        augment_layers=[2, 4],
    )
    hypermodel.build(hps)

    assert hps.get("randaug_mag") == 1.0
    assert hps.get("randaug_count") == 4
def test_transforms_search_space():
    """Check how constructor arguments map onto ``transforms``.

    Verifies the default list of (name, range) pairs, then a customised one
    where a scalar becomes ``(0, value)``, a two-element list becomes a
    ``(min, max)`` tuple, and ``contrast=None`` drops the contrast entry.
    """
    # Default choice.
    expected_default = [
        ("rotate", (0, 0.5)),
        ("translate_x", (0, 0.4)),
        ("translate_y", (0, 0.4)),
        ("contrast", (0, 0.3)),
    ]
    default_hypermodel = aug_module.HyperImageAugment(input_shape=(32, 32, 3))
    assert default_hypermodel.transforms == expected_default

    # Customised search space.
    expected_custom = [
        ("rotate", (0, 0.3)),
        ("translate_x", (0.1, 0.5)),
        ("translate_y", (0, 0.4)),
    ]
    custom_hypermodel = aug_module.HyperImageAugment(
        input_shape=(32, 32, 3),
        rotate=0.3,
        translate_x=[0.1, 0.5],
        contrast=None,
    )
    assert custom_hypermodel.transforms == expected_custom
def test_hyperparameter_override_fixed_aug():
    """Check that pre-registered factor hyperparameters are kept.

    With ``augment_layers=0``, factors registered before building keep their
    values, while the remaining factors are expected to be 0.0.
    """
    hps = hp_module.HyperParameters()
    hps.Fixed("factor_rotate", 0.9)
    hps.Choice("factor_translate_x", [0.8])

    hypermodel = aug_module.HyperImageAugment(
        input_shape=(32, 32, 3),
        augment_layers=0,
    )
    hypermodel.build(hps)

    # Pre-registered values first, then the factors left at their default.
    expected_factors = (
        ("factor_rotate", 0.9),
        ("factor_translate_x", 0.8),
        ("factor_translate_y", 0.0),
        ("factor_contrast", 0.0),
    )
    for hp_name, expected_value in expected_factors:
        assert hps.get(hp_name) == expected_value
def test_hyperparameter_selection_and_hp_defaults_fixed_aug():
    """Check which factor hyperparameters exist and what they default to.

    With ``augment_layers=0``, each enabled transform gets a ``factor_*``
    hyperparameter whose value is the minimum of its range, and a transform
    disabled with ``None`` gets no hyperparameter at all.
    """
    hps = hp_module.HyperParameters()
    hypermodel = aug_module.HyperImageAugment(
        input_shape=(32, 32, 3),
        translate_x=[0.2, 0.4],
        contrast=None,
        augment_layers=0,
    )
    hypermodel.build(hps)

    # The default value of each search space is always its minimum.
    expected_defaults = (
        ("factor_rotate", 0),
        ("factor_translate_x", 0.2),
        ("factor_translate_y", 0),
    )
    for hp_name, expected_value in expected_defaults:
        assert hps.get(hp_name) == expected_value

    # contrast=None means no contrast factor is registered.
    assert "factor_contrast" not in hps.values
def test_model_construction_rand_aug():
    """Check the model built with the default ``augment_layers`` setting.

    Verifies the model has layers, is named "image_rand_augment", keeps the
    input shape on its output, and returns an all-ones image unchanged from
    ``predict``.
    """
    hps = hp_module.HyperParameters()
    hypermodel = aug_module.HyperImageAugment(
        input_shape=(None, None, 3),
        rotate=[0.2, 0.5],
    )
    model = hypermodel.build(hps)

    assert model.layers
    assert model.name == "image_rand_augment"

    # The output shape includes the batch dimension.
    assert model.output_shape == (None, None, None, 3)

    ones_image = np.ones((1, 32, 32, 3))
    out = model.predict(ones_image)
    assert out.shape == (1, 32, 32, 3)

    # Augmentation must not distort the image at inference time, so no value
    # may differ from the all-ones input.
    changed_values = (out != 1).sum()
    assert changed_values == 0