def test_defaults_criterion_optimizer_scheduler():
    """
    Test on ConfigExperiment defaults.
    When {criterion, optimizer, scheduler}_params are specified
    the respective callback should be generated automatically.
    """
    # Imported locally so this test does not rely on a module-level import.
    from copy import deepcopy

    # Adding top-level keys only, so a shallow copy of the mapping is enough.
    callbacks = DEFAULT_CALLBACKS.copy()
    callbacks["_criterion"] = CriterionCallback
    callbacks["_optimizer"] = OptimizerCallback
    callbacks["_scheduler"] = SchedulerCallback

    # Deep copy on purpose: the assignments below write into the nested
    # "stages" mapping. With a shallow ``.copy()`` that nested mapping is the
    # same object as the one inside DEFAULT_MINIMAL_CONFIG, so the params set
    # here would leak into every test that uses the shared constant afterwards.
    config = deepcopy(DEFAULT_MINIMAL_CONFIG)
    config["stages"]["criterion_params"] = {"criterion": "BCEWithLogitsLoss"}
    config["stages"]["optimizer_params"] = {"optimizer": "SomeOptimizer"}
    config["stages"]["scheduler_params"] = {"scheduler": "SomeScheduler"}

    exp = ConfigExperiment(config=config)

    assert exp.initial_seed == 42
    assert exp.logdir == "./logdir"
    assert exp.stages == ["train"]
    assert exp.distributed_params == {}
    assert exp.get_stage_params("train") == {
        "logdir": "./logdir",
    }
    assert isinstance(exp.get_model("train"), SomeModel)
    assert exp.get_criterion("train") is not None
    assert exp.get_optimizer("train", SomeModel()) is not None
    assert exp.get_scheduler("train", None) is not None

    _test_callbacks(callbacks, exp)
def test_defaults():
    """
    Check the defaults of a ConfigExperiment built from the minimal config.

    Similar to the BaseExperiment test, but the two classes are very
    different and inherit from different parents, so they are tested
    separately. It also verifies which callbacks are added by default.
    """
    # NOTE(review): another ``test_defaults`` is defined later in this file.
    # If both live at module level, the later one replaces this one and pytest
    # never runs this body — confirm whether this version should be kept.
    # NOTE(review): this body calls ``get_state_params`` while the other tests
    # in the file call ``get_stage_params`` — confirm which name is current.
    stage = "train"
    experiment = ConfigExperiment(config=DEFAULT_MINIMAL_CONFIG)

    # Scalar / container defaults exposed directly on the experiment.
    assert experiment.initial_seed == 42
    assert experiment.logdir is None
    assert experiment.stages == [stage]
    assert experiment.distributed_params == {}
    assert experiment.monitoring_params == {}

    # Per-stage getters: only the model is expected to be configured.
    assert experiment.get_state_params(stage) == {"logdir": None}
    assert isinstance(experiment.get_model(stage), SomeModel)
    assert experiment.get_criterion(stage) is None
    assert experiment.get_optimizer(stage, SomeModel()) is None
    assert experiment.get_scheduler(stage, None) is None

    # Callback names must match the defaults exactly ...
    assert experiment.get_callbacks(stage).keys() == DEFAULT_CALLBACKS.keys()

    # ... and each callback must be an instance of the class registered under
    # the same position in DEFAULT_CALLBACKS.
    created = experiment.get_callbacks(stage).values()
    expected_classes = DEFAULT_CALLBACKS.values()
    for callback, expected_cls in zip(created, expected_classes):
        assert isinstance(callback, expected_cls)
def test_defaults():
    """
    Check the defaults of a ConfigExperiment built from the minimal config.

    Similar to the BaseExperiment test, but the two classes are very
    different and inherit from different parents, so they are tested
    separately. It also verifies which callbacks are added by default.
    """
    stage = "train"
    expected_logdir = "./logdir"
    expected_stage_params = {"logdir": expected_logdir}

    minimal_config = DEFAULT_MINIMAL_CONFIG.copy()
    experiment = ConfigExperiment(config=minimal_config)

    # Scalar / container defaults exposed directly on the experiment.
    assert experiment.initial_seed == 42
    assert experiment.logdir == expected_logdir
    assert experiment.stages == [stage]
    assert experiment.distributed_params == {}

    # Per-stage getters: only the model is expected to be configured.
    assert experiment.get_stage_params(stage) == expected_stage_params
    assert isinstance(experiment.get_model(stage), SomeModel)
    assert experiment.get_criterion(stage) is None
    assert experiment.get_optimizer(stage, SomeModel()) is None
    assert experiment.get_scheduler(stage, None) is None

    # Delegate the callback comparison to the shared helper.
    _test_callbacks(DEFAULT_CALLBACKS, experiment)