Esempio n. 1
0
def test_input_requirement():
    """HyperImageAugment needs either ``input_shape`` or ``input_tensor``.

    Omitting both must raise ``ValueError``; supplying either one must
    produce a model that builds successfully.
    """
    hyperparams = hp_module.HyperParameters()

    # Neither input argument given: construction is rejected.
    with pytest.raises(ValueError, match=r".*must specify.*"):
        aug_module.HyperImageAugment()

    # Variant 1: only a (possibly partially unknown) input shape.
    shape_model = aug_module.HyperImageAugment(
        input_shape=(None, None, 3)).build(hyperparams)
    assert shape_model.built

    # Variant 2: an existing Keras input tensor.
    input_layer = keras.Input(shape=(32, 32, 3))
    tensor_model = aug_module.HyperImageAugment(
        input_tensor=input_layer).build(hyperparams)
    assert tensor_model.built
Esempio n. 2
0
def test_model_construction_factor_zero():
    """With every default left at zero, the built model is just its input.

    Each configuration is built against a fresh ``HyperParameters`` so the
    default values are what get selected.
    """
    configurations = (
        # augment_layers searches the default space [0, 4], defaulting to 0.
        {"input_shape": (None, None, 3)},
        # augment_layers pinned to 0; the factors all default to zero.
        {"input_shape": (None, None, 3), "augment_layers": 0},
    )
    for config in configurations:
        hyperparams = hp_module.HyperParameters()
        built_model = aug_module.HyperImageAugment(**config).build(hyperparams)
        # Only the input layer is expected.
        assert len(built_model.layers) == 1
Esempio n. 3
0
def test_hyperparameter_existence_and_hp_defaults_rand_aug():
    """A ``[low, high]`` augment_layers range registers a hp defaulting to low."""
    hyperparams = hp_module.HyperParameters()
    hypermodel = aug_module.HyperImageAugment(
        input_shape=(32, 32, 3),
        augment_layers=[2, 5],
        contrast=False,
    )
    hypermodel.build(hyperparams)

    # The lower bound of the [2, 5] range is the registered default.
    assert hyperparams.get("augment_layers") == 2
Esempio n. 4
0
def test_tf_version_too_low_error():
    """Constructing HyperImageAugment without ``preprocessing`` raises ImportError.

    ``aug_module.preprocessing`` is temporarily replaced with ``None``. Per the
    test name, this presumably simulates a TensorFlow version that is too old
    to provide it. The original value is restored in a ``finally`` block, so a
    failure inside the ``pytest.raises`` context (no exception, or a message
    mismatch) cannot leave the module patched for later tests.
    """
    saved_preprocessing = aug_module.preprocessing
    aug_module.preprocessing = None
    try:
        with pytest.raises(ImportError, match="HyperImageAugment requires"):
            aug_module.HyperImageAugment()
    finally:
        # Always undo the module-level patch, even when the assertion fails.
        aug_module.preprocessing = saved_preprocessing
Esempio n. 5
0
def test_hyperparameter_override_rand_aug():
    """Pre-registered randaug hyperparameters take precedence on build."""
    hyperparams = hp_module.HyperParameters()
    # Register the overrides before the hypermodel defines its own.
    hyperparams.Fixed("randaug_mag", 1.0)
    hyperparams.Choice("randaug_count", [4])

    aug_module.HyperImageAugment(
        input_shape=(32, 32, 3),
        augment_layers=[2, 4],
    ).build(hyperparams)

    expected_values = {"randaug_mag": 1.0, "randaug_count": 4}
    for hp_name, hp_value in expected_values.items():
        assert hyperparams.get(hp_name) == hp_value
Esempio n. 6
0
def test_transforms_search_space():
    """``transforms`` reflects defaults, scalar/range overrides, and removal."""
    default_model = aug_module.HyperImageAugment(input_shape=(32, 32, 3))
    # With no overrides, every transform uses its default (0, max) range.
    expected_defaults = [
        ("rotate", (0, 0.5)),
        ("translate_x", (0, 0.4)),
        ("translate_y", (0, 0.4)),
        ("contrast", (0, 0.3)),
    ]
    assert default_model.transforms == expected_defaults

    custom_model = aug_module.HyperImageAugment(
        input_shape=(32, 32, 3),
        rotate=0.3,              # scalar -> range (0, 0.3)
        translate_x=[0.1, 0.5],  # two-element list -> range (0.1, 0.5)
        contrast=None,           # None -> transform dropped
    )
    expected_custom = [
        ("rotate", (0, 0.3)),
        ("translate_x", (0.1, 0.5)),
        ("translate_y", (0, 0.4)),
    ]
    assert custom_model.transforms == expected_custom
Esempio n. 7
0
def test_hyperparameter_override_fixed_aug():
    """Pre-registered factor hyperparameters override the built-in defaults."""
    hyperparams = hp_module.HyperParameters()
    hyperparams.Fixed("factor_rotate", 0.9)
    hyperparams.Choice("factor_translate_x", [0.8])

    # augment_layers=0 exercises the fixed-factor path (cf. the *_fixed_aug tests).
    aug_module.HyperImageAugment(
        input_shape=(32, 32, 3),
        augment_layers=0,
    ).build(hyperparams)

    expected_values = {
        "factor_rotate": 0.9,       # overridden above
        "factor_translate_x": 0.8,  # overridden above
        "factor_translate_y": 0.0,  # left at its default
        "factor_contrast": 0.0,     # left at its default
    }
    for hp_name, hp_value in expected_values.items():
        assert hyperparams.get(hp_name) == hp_value, hp_name
Esempio n. 8
0
def test_hyperparameter_selection_and_hp_defaults_fixed_aug():
    """Factor hyperparameters default to the minimum of their search range."""
    hyperparams = hp_module.HyperParameters()
    aug_module.HyperImageAugment(
        input_shape=(32, 32, 3),
        translate_x=[0.2, 0.4],
        contrast=None,
        augment_layers=0,
    ).build(hyperparams)

    # Defaults sit at the lower bound: 0 for the default ranges, and 0.2
    # for the explicit [0.2, 0.4] translate_x range.
    lower_bounds = {
        "factor_rotate": 0,
        "factor_translate_x": 0.2,
        "factor_translate_y": 0,
    }
    for hp_name, lower in lower_bounds.items():
        assert hyperparams.get(hp_name) == lower
    # contrast=None removes the transform, so no factor hp is registered.
    assert "factor_contrast" not in hyperparams.values
Esempio n. 9
0
def test_model_construction_rand_aug():
    """RandAugment model builds, keeps shapes, and is identity at inference."""
    hyperparams = hp_module.HyperParameters()
    hypermodel = aug_module.HyperImageAugment(
        input_shape=(None, None, 3),
        rotate=[0.2, 0.5],
    )
    rand_aug_model = hypermodel.build(hyperparams)
    assert rand_aug_model.layers
    assert rand_aug_model.name == "image_rand_augment"

    # output_shape carries a leading (unknown) batch dimension.
    assert rand_aug_model.output_shape == (None, None, None, 3)

    images = np.ones((1, 32, 32, 3))
    augmented = rand_aug_model.predict(images)
    assert augmented.shape == images.shape
    # predict() runs in inference mode, where augmentation must be a no-op:
    # no pixel may differ from the all-ones input.
    assert not (augmented != 1).any()