Example #1
0
def test_merge_defaults_dataset():
    """Check how ``config._merge_defaults`` handles the "dataset" section.

    - An empty config gets ``d.TRANSFORMER`` and ``d.AUGMENTOR`` as the
      dataset transformer and augmentor.
    - An explicitly given "dataset" section comes back equal to the input.
    - A "dataset" value that is a list raises ``jsonschema.ValidationError``.
    """
    import copy

    result = config._merge_defaults({})

    assert isinstance(result, dict)
    assert result["dataset"]["transformer"] == d.TRANSFORMER
    assert result["dataset"]["augmentor"] == d.AUGMENTOR

    cfg = {
        "dataset": {
            "transformer": {
                "import": "IdentityTransformer"
            },
            "augmentor": {
                "import": "unit-test",
                "params": {
                    "unit": "test"
                }
            },
        }
    }
    # Deep-copy the input. A shallow ``cfg.copy()`` would share the nested
    # dicts with ``cfg``. If ``_merge_defaults`` then modified them in place,
    # the comparison below would check an object against itself and could
    # never fail.
    result = config._merge_defaults(copy.deepcopy(cfg))
    assert result["dataset"] == cfg["dataset"]

    # A list is rejected as a "dataset" section.
    with pytest.raises(jsonschema.ValidationError):
        config._merge_defaults({"dataset": []})
Example #2
0
def test_merge_defaults_solver():
    """Check how ``config._merge_defaults`` handles the "solver" section.

    - An empty config gets ``d.BATCH_SIZE``, ``d.EPOCHS`` and ``d.OPTIMIZER``
      as the solver settings.
    - An explicitly given "solver" section comes back equal to the input.
    - A "solver" value that is a list raises ``jsonschema.ValidationError``.
    """
    import copy

    result = config._merge_defaults({})

    assert isinstance(result, dict)
    assert result["solver"]["batch_size"] == d.BATCH_SIZE
    assert result["solver"]["epochs"] == d.EPOCHS
    assert result["solver"]["optimizer"] == d.OPTIMIZER

    cfg = {"solver": {"batch_size": 42, "epochs": 7, "optimizer": {"import": "SGD"}}}
    # Deep-copy the input so that in-place changes to the nested dicts
    # cannot leak back into ``cfg`` and make the comparison below trivially
    # true.
    result = config._merge_defaults(copy.deepcopy(cfg))
    assert result["solver"] == cfg["solver"]

    # A list is rejected as a "solver" section.
    with pytest.raises(jsonschema.ValidationError):
        config._merge_defaults({"solver": []})
Example #3
0
def test_merge_defaults_services():
    """Check how ``config._merge_defaults`` handles the "services" section.

    - An empty config gets ``d.BEST_CHECKPOINT``, ``d.TENSORBOARD``,
      ``d.TRAIN_EARLY_STOPPING`` and ``d.VALIDATION_EARLY_STOPPING`` as the
      service settings.
    - An explicitly given "services" section comes back equal to the input.
    - A "services" value that is a list raises ``jsonschema.ValidationError``.
    """
    import copy

    result = config._merge_defaults({})

    assert isinstance(result, dict)
    assert result["services"]["best_checkpoint"] == d.BEST_CHECKPOINT
    assert result["services"]["tensorboard"] == d.TENSORBOARD
    assert result["services"]["train_early_stopping"] == d.TRAIN_EARLY_STOPPING
    assert (result["services"]["validation_early_stopping"] ==
            d.VALIDATION_EARLY_STOPPING)

    cfg = {
        "services": {
            "best_checkpoint": {
                "monitor": "val_unit_test",
                "mode": "min"
            },
            "tensorboard": {
                "batch_size": 42
            },
            "train_early_stopping": {
                "monitor": "unit_test",
                "mode": "min",
                "min_delta": 1e-7,
                "patience": 42,
            },
            "validation_early_stopping": {
                "monitor": "val_unit_test",
                "mode": "max",
                "min_delta": 1e-7,
                "patience": 42,
            },
        }
    }
    # Deep-copy the input. With a shallow copy the nested service dicts
    # would be shared with ``cfg``. Any in-place default filling would then
    # also change the expected value, and the comparison could never fail.
    result = config._merge_defaults(copy.deepcopy(cfg))
    assert result["services"] == cfg["services"]

    # A list is rejected as a "services" section.
    with pytest.raises(jsonschema.ValidationError):
        config._merge_defaults({"services": []})