# Example #1
# 0
def test_defaults_criterion_optimizer_scheduler():
    """
    Test ConfigExperiment defaults when component params are given.

    When {criterion, optimizer, scheduler}_params are specified in the
    stage config, the respective callback should be generated
    automatically, and the corresponding components should be built.
    """
    # Local import keeps this fix self-contained within the test.
    from copy import deepcopy

    # Top-level keys only are added, so a shallow copy is enough here.
    callbacks = DEFAULT_CALLBACKS.copy()
    callbacks["_criterion"] = CriterionCallback
    callbacks["_optimizer"] = OptimizerCallback
    callbacks["_scheduler"] = SchedulerCallback

    # ``dict.copy()`` is shallow: writing into ``config["stages"]`` would
    # mutate the nested dict shared with the module-level
    # DEFAULT_MINIMAL_CONFIG and leak these params into other tests
    # (e.g. ``test_defaults`` expects no criterion/optimizer/scheduler).
    config = deepcopy(DEFAULT_MINIMAL_CONFIG)
    config["stages"]["criterion_params"] = {"criterion": "BCEWithLogitsLoss"}
    config["stages"]["optimizer_params"] = {"optimizer": "SomeOptimizer"}
    config["stages"]["scheduler_params"] = {"scheduler": "SomeScheduler"}
    exp = ConfigExperiment(config=config)

    assert exp.initial_seed == 42
    assert exp.logdir == "./logdir"
    assert exp.stages == ["train"]
    assert exp.distributed_params == {}
    assert exp.get_stage_params("train") == {
        "logdir": "./logdir",
    }
    assert isinstance(exp.get_model("train"), SomeModel)
    # With the *_params above, each component must now be constructed.
    assert exp.get_criterion("train") is not None
    assert exp.get_optimizer("train", SomeModel()) is not None
    assert exp.get_scheduler("train", None) is not None

    _test_callbacks(callbacks, exp)
# Example #2
# 0
def test_defaults():
    """
    Check the defaults of a ConfigExperiment built from the minimal config.

    This mirrors the BaseExperiment defaults test. The two classes are
    quite different and inherit from different parents, so the defaults
    are verified separately here. It is also important to pin down which
    callbacks are added by default.
    """
    stage = "train"
    experiment = ConfigExperiment(config=DEFAULT_MINIMAL_CONFIG.copy())

    # Experiment-level defaults.
    assert experiment.initial_seed == 42
    assert experiment.logdir == "./logdir"
    assert experiment.stages == [stage]
    assert experiment.distributed_params == {}

    # Stage-level defaults.
    assert experiment.get_stage_params(stage) == {"logdir": "./logdir"}
    assert isinstance(experiment.get_model(stage), SomeModel)

    # The minimal config yields no criterion, optimizer or scheduler.
    assert experiment.get_criterion(stage) is None
    assert experiment.get_optimizer(stage, SomeModel()) is None
    assert experiment.get_scheduler(stage, None) is None

    _test_callbacks(DEFAULT_CALLBACKS, experiment)