def _test(to_save, obj, name):
    """Check score-based checkpoint filenames and removal of the first file.

    Saves with score -0.77 and then -0.76. After each save it verifies the mocked
    handler's call count and filename, that ``remove`` was invoked on the first
    filename, and that ``last_checkpoint`` is the second filename.
    """
    first_fname = "{}_loss=-0.7700.pth".format(name)
    second_fname = "{}_loss=-0.7600.pth".format(name)

    handler = MagicMock()
    handler.remove = MagicMock()
    ckpt = Checkpoint(
        to_save,
        save_handler=handler,
        score_name="loss",
        score_function=lambda e: e.state.score,
    )

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=1, iteration=1, score=-0.77)

    ckpt(engine)
    assert handler.call_count == 1
    handler.assert_called_with(obj, first_fname)

    # Advance the engine and report a higher score for the second save.
    engine.state.epoch = 12
    engine.state.iteration = 1234
    engine.state.score = -0.76
    ckpt(engine)
    assert handler.call_count == 2
    handler.assert_called_with(obj, second_fname)
    assert handler.remove.call_count == 1
    handler.remove.assert_called_with(first_fname)
    assert ckpt.last_checkpoint == second_fname
def _test(to_save, obj, name, score_name=None):
    """Check epoch-scored checkpoint filenames, optionally tagged with ``score_name``.

    The score is the engine epoch (1, then 12). Expected filenames are
    ``<name>_<score_name>=<epoch>.pth``, or ``<name>_<epoch>.pth`` when
    ``score_name`` is None.
    """
    handler = MagicMock()
    handler.remove = MagicMock()
    ckpt = Checkpoint(
        to_save,
        save_handler=handler,
        score_name=score_name,
        score_function=lambda e: e.state.epoch,
    )
    # Filename fragment contributed by the score name: "" when absent, else "<score_name>=".
    tag = "" if score_name is None else score_name + "="

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=1, iteration=1)
    ckpt(engine)
    assert handler.call_count == 1
    handler.assert_called_with(obj, "{}_{}1.pth".format(name, tag))

    engine.state.epoch = 12
    engine.state.iteration = 1234
    ckpt(engine)
    assert handler.call_count == 2
    handler.assert_called_with(obj, "{}_{}12.pth".format(name, tag))
    assert handler.remove.call_count == 1
    handler.remove.assert_called_with("{}_{}1.pth".format(name, tag))
    assert ckpt.last_checkpoint == "{}_{}12.pth".format(name, tag)
def test_mlflow_bad_metric_name_handling(dirname):
    """An invalid MLflow metric name warns, while a valid one is still logged every epoch.

    ``"metric:0 in %"`` is expected to trigger the ``UserWarning``. The history of
    ``"training metric 0"`` must contain the initial value plus one value per loop
    iteration, in order.
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    true_values = [123.0, 23.4, 333.4]
    with MLflowLogger(str(dirname / "mlruns")) as mlflow_logger:
        active_run = mlflow.active_run()

        handler = OutputHandler(tag="training", metric_names="all")
        engine = Engine(lambda e, b: None)
        engine.state = State(metrics={"metric:0 in %": 123.0, "metric 0": 1000.0})

        with pytest.warns(UserWarning, match=r"MLflowLogger output_handler encountered an invalid metric name"):
            engine.state.epoch = 1
            handler(engine, mlflow_logger, event_name=Events.EPOCH_COMPLETED)

            for v in true_values:
                engine.state.epoch += 1
                engine.state.metrics["metric 0"] = v
                handler(engine, mlflow_logger, event_name=Events.EPOCH_COMPLETED)

    client = MlflowClient(tracking_uri=str(dirname / "mlruns"))
    stored_values = client.get_metric_history(active_run.info.run_id, "training metric 0")

    expected = [1000.0] + true_values
    # zip() stops at the shorter sequence, so missing history entries would otherwise go unnoticed.
    assert len(stored_values) == len(expected)
    for t, s in zip(expected, stored_values):
        assert t == s.value
def _test(filename_prefix, to_save, obj, name):
    """Check checkpoint filenames built from ``filename_prefix`` and an epoch global step.

    Expected names are ``<prefix>_<name>_<epoch>.pth``, or ``<name>_<epoch>.pth``
    when the prefix is empty.
    """
    handler = MagicMock()
    handler.remove = MagicMock()
    ckpt = Checkpoint(
        to_save,
        save_handler=handler,
        filename_prefix=filename_prefix,
        global_step_transform=lambda e, _: e.state.epoch,
    )

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=1, iteration=1)
    ckpt(engine)
    assert handler.call_count == 1

    # A non-empty prefix is joined to the object name with "_".
    sep = "_" if len(filename_prefix) > 0 else ""
    handler.assert_called_with(obj, "{}{}{}_1.pth".format(filename_prefix, sep, name))

    engine.state.epoch = 12
    engine.state.iteration = 1234
    ckpt(engine)
    assert handler.call_count == 2
    handler.assert_called_with(obj, "{}{}{}_12.pth".format(filename_prefix, sep, name))
    assert handler.remove.call_count == 1
    handler.remove.assert_called_with("{}{}{}_1.pth".format(filename_prefix, sep, name))
    assert ckpt.last_checkpoint == "{}{}{}_12.pth".format(filename_prefix, sep, name)
def _test(to_save, obj, name):
    """Check score-based filenames and metadata passed to a ``BaseSaveHandler`` mock.

    Two saves with scores -0.77 and -0.76. The metadata ``priority`` is expected
    to equal the score, and the first file is expected to be removed.
    """
    old_fname = "{}_loss=-0.7700.pt".format(name)
    new_fname = "{}_loss=-0.7600.pt".format(name)

    handler = MagicMock(spec=BaseSaveHandler)
    ckpt = Checkpoint(
        to_save,
        save_handler=handler,
        score_name="loss",
        score_function=lambda e: e.state.score,
    )

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=1, iteration=1, score=-0.77)
    ckpt(engine)
    assert handler.call_count == 1
    handler.assert_called_with(obj, old_fname, {"basename": name, "score_name": "loss", "priority": -0.77})

    engine.state.epoch = 12
    engine.state.iteration = 1234
    engine.state.score = -0.76
    ckpt(engine)
    assert handler.call_count == 2
    handler.assert_called_with(obj, new_fname, {"basename": name, "score_name": "loss", "priority": -0.76})
    assert handler.remove.call_count == 1
    handler.remove.assert_called_with(old_fname)
    assert ckpt.last_checkpoint == new_fname
def _test_neptune_saver_integration(device):
    """Save twice through ``NeptuneSaver`` with ``n_saved=1`` and count artifact calls.

    Only rank 0 gets a mocked ``NeptuneLogger``; other ranks pass ``None``.
    On rank 0, two uploads and one deletion are expected.
    """
    is_rank_zero = idist.get_rank() == 0

    model = torch.nn.Module().to(device)
    to_save_serializable = {"model": model}

    mock_logger = None
    if is_rank_zero:
        mock_logger = MagicMock(spec=NeptuneLogger)
        mock_logger.log_artifact = MagicMock()
        mock_logger.delete_artifacts = MagicMock()

    saver = NeptuneSaver(mock_logger)
    ckpt = Checkpoint(to_save=to_save_serializable, save_handler=saver, n_saved=1)

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)
    ckpt(trainer)
    trainer.state.iteration = 1
    ckpt(trainer)

    if is_rank_zero:
        assert mock_logger.log_artifact.call_count == 2
        assert mock_logger.delete_artifacts.call_count == 1
def _test(to_save, obj, name):
    """Check iteration-numbered filenames and metadata when no score is configured.

    ``last_checkpoint`` starts as None. Saves at iterations 0 and 1234 are expected
    to use the iteration as both filename suffix and metadata ``priority``.
    """

    def _meta(priority):
        return {"basename": name, "score_name": None, "priority": priority}

    handler = MagicMock(spec=BaseSaveHandler)
    ckpt = Checkpoint(to_save, save_handler=handler)
    assert ckpt.last_checkpoint is None

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)
    ckpt(engine)
    assert handler.call_count == 1
    handler.assert_called_with(obj, "{}_0.pt".format(name), _meta(0))

    engine.state.epoch = 12
    engine.state.iteration = 1234
    ckpt(engine)
    assert handler.call_count == 2
    handler.assert_called_with(obj, "{}_1234.pt".format(name), _meta(1234))
    assert handler.remove.call_count == 1
    handler.remove.assert_called_with("{}_0.pt".format(name))
    assert ckpt.last_checkpoint == "{}_1234.pt".format(name)
def test_load_checkpoint_with_different_num_classes(dirname):
    """Load a full-model checkpoint into a sub-module.

    Strict loading is expected to raise ``RuntimeError``. Non-strict loading must
    copy the ``features`` weights.
    """
    model = DummyPretrainedModel()
    to_save_single_object = {"model": model}

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    handler(trainer, to_save_single_object)
    loaded_checkpoint = torch.load(handler.last_checkpoint)

    to_load_single_object = {"pretrained_features": model.features}

    with pytest.raises(RuntimeError):
        Checkpoint.load_objects(to_load_single_object, loaded_checkpoint)

    # UserWarnings from the non-strict load (and the extra "blah" kwarg) are silenced here.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        Checkpoint.load_objects(to_load_single_object, loaded_checkpoint, strict=False, blah="blah")

    features_weight = to_load_single_object["pretrained_features"].state_dict()["weight"]
    assert torch.all(model.state_dict()["features.weight"].eq(features_weight))
def test_base_output_handler_setup_output_metrics():
    """Check ``_setup_output_metrics`` for combinations of metric names and output transform."""
    engine = Engine(lambda engine, batch: None)
    true_metrics = {"a": 0, "b": 1}
    engine.state = State(metrics=true_metrics)
    engine.state.output = 12345

    def _setup(handler):
        return handler._setup_output_metrics(engine=engine)

    # Only metric_names
    handler = DummyOutputHandler("tag", metric_names=["a", "b"], output_transform=None)
    assert _setup(handler) == true_metrics

    # Only metric_names, one of them unknown -> warning
    handler = DummyOutputHandler("tag", metric_names=["a", "c"], output_transform=None)
    with pytest.warns(UserWarning):
        collected = _setup(handler)
    assert collected == {"a": 0}

    # Only output, reported as "output"
    handler = DummyOutputHandler("tag", metric_names=None, output_transform=lambda x: x)
    assert _setup(handler) == {"output": engine.state.output}

    # Only output, reported as "loss"
    handler = DummyOutputHandler("tag", metric_names=None, output_transform=lambda x: {"loss": x})
    assert _setup(handler) == {"loss": engine.state.output}

    # Metrics and output together
    handler = DummyOutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})
    assert _setup(handler) == {"a": 0, "b": 1, "loss": engine.state.output}
def _test(to_save, obj, name, score_name=None):
    """Check epoch-scored filenames and metadata with an optional ``score_name``.

    The score is the engine epoch (1, then 12). It appears in the filename and as
    the metadata ``priority``.
    """
    handler = MagicMock(spec=BaseSaveHandler)
    ckpt = Checkpoint(
        to_save, save_handler=handler, score_name=score_name, score_function=lambda e: e.state.epoch
    )
    # Filename fragment contributed by the score name: "" when absent, else "<score_name>=".
    tag = "" if score_name is None else score_name + "="

    def _meta(priority):
        # Expected metadata carries the score name exactly as given (None when absent).
        return {"basename": name, "score_name": score_name, "priority": priority}

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=1, iteration=1)
    ckpt(engine)
    assert handler.call_count == 1
    handler.assert_called_with(obj, "{}_{}1.pt".format(name, tag), _meta(1))

    engine.state.epoch = 12
    engine.state.iteration = 1234
    ckpt(engine)
    assert handler.call_count == 2
    handler.assert_called_with(obj, "{}_{}12.pt".format(name, tag), _meta(12))
    assert handler.remove.call_count == 1
    handler.remove.assert_called_with("{}_{}1.pt".format(name, tag))
    assert ckpt.last_checkpoint == "{}_{}12.pt".format(name, tag)
def _test(filename_prefix, to_save, obj, name):
    """Check prefixed filenames and ``basename`` metadata with an epoch global step.

    Saves at epochs 2 and 12. A non-empty prefix is joined to the name with "_".
    """
    handler = MagicMock(spec=BaseSaveHandler)
    ckpt = Checkpoint(
        to_save,
        save_handler=handler,
        filename_prefix=filename_prefix,
        global_step_transform=lambda e, _: e.state.epoch,
    )

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=2, iteration=1)
    ckpt(engine)
    assert handler.call_count == 1

    sep = "_" if len(filename_prefix) > 0 else ""
    basename = "{}{}{}".format(filename_prefix, sep, name)

    def _meta(priority):
        return {"basename": basename, "score_name": None, "priority": priority}

    handler.assert_called_with(obj, "{}_2.pt".format(basename), _meta(2))

    engine.state.epoch = 12
    engine.state.iteration = 1234
    ckpt(engine)
    assert handler.call_count == 2
    handler.assert_called_with(obj, "{}_12.pt".format(basename), _meta(12))
    assert handler.remove.call_count == 1
    handler.remove.assert_called_with("{}_2.pt".format(basename))
    assert ckpt.last_checkpoint == "{}_12.pt".format(basename)
def test_checkpoint_last_checkpoint_on_score():
    """``last_checkpoint`` follows the most recent save when ranking by a score.

    Ten saves are made with val_acc = 0.0, 0.1, ..., 0.9, using ``n_saved=None``
    (presumably no cap on retained checkpoints; confirm against Checkpoint docs).
    """
    save_handler = MagicMock()
    save_handler.remove = MagicMock()
    to_save = {"model": DummyModel()}
    checkpointer = Checkpoint(
        to_save,
        save_handler=save_handler,
        n_saved=None,
        score_name="val_acc",
        score_function=lambda e: e.state.metrics["val_acc"],
    )

    trainer = Engine(lambda e, b: None)
    for i in range(10):
        trainer.state = State(epoch=1, iteration=i, metrics={"val_acc": i * 0.1})
        checkpointer(trainer)

    assert save_handler.call_count == 10
    # 9 * 0.1 is 0.9000000000000001, so this also pins the 4-decimal score formatting.
    assert checkpointer.last_checkpoint == "model_val_acc=0.9000.pth"
def _test(ext, require_empty, archived):
    """Save into a directory that already holds a file and check both files exist.

    ``dirname`` comes from the enclosing scope. The pre-existing file uses the
    caller's ``ext``. The new checkpoint is expected to use ``.pth.tar`` when
    ``archived`` is true and ``.pth`` otherwise.
    """
    previous_fname = os.path.join(dirname, "{}_{}_{}{}".format(_PREFIX, "obj", 1, ext))
    with open(previous_fname, "w") as f:
        f.write("test")

    h = ModelCheckpoint(dirname, _PREFIX, create_dir=True, require_empty=require_empty, archived=archived)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=1)

    model = DummyModel()
    h(engine, {"model": model})

    fname = h.last_checkpoint
    saved_ext = ".pth.tar" if archived else ".pth"
    assert isinstance(fname, str)
    assert os.path.join(dirname, "{}_{}_{}{}".format(_PREFIX, "model", 1, saved_ext)) == fname
    assert os.path.exists(fname)
    assert os.path.exists(previous_fname)
    assert torch.load(fname) == model.state_dict()
    os.remove(fname)
def test_best_k(dirname):
    """With ``n_saved=2``, only the two highest-scoring checkpoints remain on disk."""
    score_iter = iter([1.2, -2.0, 3.1, -4.0])

    def score_function(_):
        return next(score_iter)

    h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2, score_function=score_function)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    to_save = {"model": DummyModel()}
    for _ in range(4):
        h(engine, to_save)

    expected = ["{}_{}_{:.4f}.pth".format(_PREFIX, "model", score) for score in (1.2, 3.1)]
    assert sorted(os.listdir(dirname)) == expected
def test_trains_disk_saver_integration_no_logger():
    """``TrainsSaver`` without a logger warns about its temporary directory and saves.

    ``trains.Task.current_task`` and ``WeightsFileHandler.create_output_model`` are
    patched only for the duration of this test. The original code assigned the
    mocks onto the modules and never restored them, leaking them into other tests.
    """
    from unittest.mock import patch

    model = torch.nn.Module()
    to_save_serializable = {"model": model}

    weights_file_handler = trains.binding.frameworks.WeightsFileHandler
    with patch.object(trains.Task, "current_task", Mock(return_value=object())), patch.object(
        weights_file_handler, "create_output_model", MagicMock()
    ) as create_output_model:
        with pytest.warns(UserWarning, match="TrainsSaver created a temporary checkpoints directory"):
            trains_saver = TrainsSaver()
            checkpoint = Checkpoint(to_save=to_save_serializable, save_handler=trains_saver, n_saved=1)

        trainer = Engine(lambda e, b: None)
        trainer.state = State(epoch=0, iteration=0)
        checkpoint(trainer)
        trainer.state.iteration = 1
        checkpoint(trainer)

        if trains_saver._atomic:
            assert create_output_model.call_count == 2
        else:
            saved_files = list(os.listdir(trains_saver.dirname))
            assert len(saved_files) == 1
            assert saved_files[0] == "model_1.pt"
def test_best_k_with_suffix(dirname):
    """With ``n_saved=2`` and ``score_name``, the two best scores remain.

    The survivors are the scores from epochs 1 and 3, formatted with 4
    significant digits in the filename.
    """
    scores = [0.3456789, 0.1234, 0.4567, 0.134567]
    remaining = iter(scores)

    def score_function(engine):
        return next(remaining)

    h = ModelCheckpoint(
        dirname, _PREFIX, create_dir=False, n_saved=2, score_function=score_function, score_name="val_loss"
    )
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    to_save = {"model": DummyModel()}
    for _ in range(4):
        engine.state.epoch += 1
        h(engine, to_save)

    expected = ["{}_{}_val_loss={:.4}.pth".format(_PREFIX, "model", scores[epoch - 1]) for epoch in (1, 3)]
    assert sorted(os.listdir(dirname)) == expected
def _test(
    to_save,
    filename_prefix="",
    score_function=None,
    score_name=None,
    global_step_transform=None,
    filename_pattern=None,
):
    """Build a ``Checkpoint`` with the given naming options, save once, and return the filename.

    The engine state is fixed at epoch=12, iteration=203, score=0.9999.
    """
    naming_options = dict(
        filename_prefix=filename_prefix,
        score_function=score_function,
        score_name=score_name,
        global_step_transform=global_step_transform,
        filename_pattern=filename_pattern,
    )
    save_handler = MagicMock(spec=BaseSaveHandler)
    checkpointer = Checkpoint(to_save, save_handler=save_handler, **naming_options)

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=12, iteration=203, score=0.9999)
    checkpointer(trainer)
    return checkpointer.last_checkpoint
def test_base_output_handler_with_another_engine():
    """Passing ``another_engine`` to an output handler emits a ``DeprecationWarning``."""
    engine = Engine(lambda engine, batch: None)
    engine.state = State(metrics={"a": 0, "b": 1})
    engine.state.output = 12345

    with pytest.warns(DeprecationWarning, match="Use of another_engine is deprecated"):
        DummyOutputHandler("tag", metric_names=["a", "b"], output_transform=None, another_engine=engine)
def test__setup_engine():
    """``_setup_engine`` records the resume iteration (10) as the single ``_init_iter`` entry."""
    engine = Engine(lambda e, b: 1)
    engine.state = State(iteration=10, epoch=1, max_epochs=100, epoch_length=100)
    engine.state.dataloader = list(range(100))

    engine._setup_engine()

    init_iter = engine._init_iter
    assert len(init_iter) == 1
    assert init_iter[0] == 10
def _test(to_save, obj, name):
    """Checkpoint an evaluator score, using the trainer epoch as the global step.

    Filenames are expected as ``<name>_<trainer epoch>_val_acc=<score:.4f>.pt``,
    with ``priority`` equal to the score.
    """
    handler = MagicMock(spec=BaseSaveHandler)
    trainer = Engine(lambda e, b: None)
    evaluator = Engine(lambda e, b: None)
    trainer.state = State(epoch=11, iteration=1)

    ckpt = Checkpoint(
        to_save,
        save_handler=handler,
        global_step_transform=lambda _1, _2: trainer.state.epoch,
        score_name="val_acc",
        score_function=lambda e: e.state.metrics["val_acc"],
    )

    def _meta(priority):
        return {"basename": name, "score_name": "val_acc", "priority": priority}

    first_fname = "{}_11_val_acc=0.7700.pt".format(name)
    second_fname = "{}_12_val_acc=0.7800.pt".format(name)

    evaluator.state = State(epoch=1, iteration=1000, metrics={"val_acc": 0.77})
    ckpt(evaluator)
    assert handler.call_count == 1
    handler.assert_called_with(obj, first_fname, _meta(0.77))

    # Advance the trainer epoch (the global step) and improve the evaluator score.
    trainer.state.epoch = 12
    evaluator.state.metrics["val_acc"] = 0.78
    ckpt(evaluator)
    assert handler.call_count == 2
    handler.assert_called_with(obj, second_fname, _meta(0.78))
    assert handler.remove.call_count == 1
    handler.remove.assert_called_with(first_fname)
    assert ckpt.last_checkpoint == second_fname
def test_base_output_handler_setup_output_state_attrs():
    """Check that state attributes are flattened into ``tag/...`` keys.

    Attributes appear alone, alongside metrics, and alongside metrics plus output.
    Expected flattening: a scalar tensor is kept as-is, and a 1-D tensor is split
    into ``/0``, ``/1`` entries.
    """
    engine = Engine(lambda engine, batch: None)
    true_metrics = {"a": 0, "b": 1}
    engine.state = State(metrics=true_metrics)
    engine.state.alpha = 3.899
    engine.state.beta = torch.tensor(5.499)
    engine.state.gamma = torch.tensor([2106.0, 6.0])
    engine.state.output = 12345

    def _expected_attrs():
        # A fresh dict (with a fresh tensor) per comparison.
        return {
            "tag/alpha": 3.899,
            "tag/beta": torch.tensor(5.499),
            "tag/gamma/0": 2106.0,
            "tag/gamma/1": 6.0,
        }

    def _collect(handler):
        return handler._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)

    # Only state attributes
    handler = DummyOutputHandler(
        tag="tag", metric_names=None, output_transform=None, state_attributes=["alpha", "beta", "gamma"]
    )
    assert _collect(handler) == _expected_attrs()

    # Metrics and attributes
    handler = DummyOutputHandler(
        tag="tag", metric_names=["a", "b"], output_transform=None, state_attributes=["alpha", "beta", "gamma"]
    )
    assert _collect(handler) == {"tag/a": 0, "tag/b": 1, **_expected_attrs()}

    # Metrics, attributes and output
    handler = DummyOutputHandler(
        tag="tag",
        metric_names="all",
        output_transform=lambda x: {"loss": x},
        state_attributes=["alpha", "beta", "gamma"],
    )
    assert _collect(handler) == {
        "tag/a": 0,
        "tag/b": 1,
        **_expected_attrs(),
        "tag/loss": engine.state.output,
    }
def test_checkpoint_load_objects_from_saved_file(dirname):
    """Round-trip objects through ``ModelCheckpoint`` and ``Checkpoint.load_objects``.

    Cases: save several objects and load them all; save several objects and load
    only the model; save and load a single object.
    """

    def _get_single_obj_to_save():
        return {"model": DummyModel()}

    def _get_multiple_objs_to_save():
        model = DummyModel()
        optim = torch.optim.SGD(model.parameters(), lr=0.001)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.5)
        return {"model": model, "optimizer": optim, "lr_scheduler": lr_scheduler}

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    def _save_then_load(to_save, to_load):
        """Save ``to_save`` with a fresh handler, load the file into ``to_load``, then delete it."""
        handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
        handler(trainer, to_save)
        fname = handler.last_checkpoint
        assert isinstance(fname, str)
        assert os.path.join(dirname, _PREFIX) in fname
        assert os.path.exists(fname)
        loaded_objects = torch.load(fname)
        Checkpoint.load_objects(to_load, loaded_objects)
        # Leave dirname empty for the next handler. ModelCheckpoint may refuse a
        # directory that already holds checkpoints (require_empty default; confirm).
        os.remove(fname)

    # case: multiple objects
    to_save = _get_multiple_objs_to_save()
    _save_then_load(to_save, to_save)

    # case: saved multiple objects, loaded single object
    to_save = _get_multiple_objs_to_save()
    _save_then_load(to_save, {"model": to_save["model"]})

    # case: single object
    to_save = _get_single_obj_to_save()
    _save_then_load(to_save, to_save)
def test_checkpoint_score_function_wrong_output():
    """A ``score_function`` returning a non-number makes the checkpoint call raise ``ValueError``."""
    to_save = {"model": DummyModel()}
    checkpointer = Checkpoint(
        to_save,
        lambda x: x,
        score_function=lambda e: {"1": 1},
        score_name="acc",
    )

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    with pytest.raises(ValueError, match=r"Output of score_function should be a number"):
        checkpointer(trainer)
def test__setup_engine():
    """``_setup_engine`` records the length of the state's dataloader."""
    engine = Engine(lambda e, b: 1)
    engine.state = State(iteration=10, epoch=1, max_epochs=100, epoch_length=100, seed=12)

    data = list(range(100))
    engine.state.dataloader = data

    engine._setup_engine()
    assert engine._dataloader_len == len(data)
def test_base_output_handler_setup_output_metrics():
    """Check ``_setup_output_metrics_state_attrs`` for metric-name and output-transform combinations.

    All keys are expected to be prefixed with ``tag/``.
    """
    engine = Engine(lambda engine, batch: None)
    true_metrics = {"a": 0, "b": 1}
    engine.state = State(metrics=true_metrics)
    engine.state.output = 12345

    def _collect(handler):
        return handler._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)

    # Only metric_names
    handler = DummyOutputHandler("tag", metric_names=["a", "b"], output_transform=None)
    assert _collect(handler) == {"tag/a": 0, "tag/b": 1}

    # Only metric_names, one of them unknown -> warning
    handler = DummyOutputHandler("tag", metric_names=["a", "c"], output_transform=None)
    with pytest.warns(UserWarning):
        collected = _collect(handler)
    assert collected == {"tag/a": 0}

    # Only output, reported as "output"
    handler = DummyOutputHandler("tag", metric_names=None, output_transform=lambda x: x)
    assert _collect(handler) == {"tag/output": engine.state.output}

    # Only output, reported as "loss"
    handler = DummyOutputHandler("tag", metric_names=None, output_transform=lambda x: {"loss": x})
    assert _collect(handler) == {"tag/loss": engine.state.output}

    # Metrics and output together
    handler = DummyOutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})
    assert _collect(handler) == {"tag/a": 0, "tag/b": 1, "tag/loss": engine.state.output}

    # All metrics
    handler = DummyOutputHandler("tag", metric_names="all", output_transform=None)
    assert _collect(handler) == {"tag/a": 0, "tag/b": 1}
def test_checkpoint_save_handler_callable():
    """A plain callable works as ``save_handler`` and receives the expected filename.

    Filenames are recorded and asserted after the checkpoint call. The original
    asserted only inside the callback, so the test passed silently if the handler
    was never invoked.
    """
    saved_filenames = []

    def save_handler(c, f):
        saved_filenames.append(f)

    to_save = {"model": DummyModel()}
    checkpointer = Checkpoint(to_save, save_handler=save_handler)

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=1, iteration=12)
    checkpointer(trainer)

    assert saved_filenames == ["model_12.pt"]
def _test(to_save, obj, name):
    """Checkpoint an evaluator score, using the trainer epoch as the global step.

    Uses a plain ``MagicMock`` handler. Filenames are expected as
    ``<name>_<trainer epoch>_val_acc=<score>.pth``.
    """
    handler = MagicMock()
    handler.remove = MagicMock()
    trainer = Engine(lambda e, b: None)
    evaluator = Engine(lambda e, b: None)
    trainer.state = State(epoch=11, iteration=1)

    ckpt = Checkpoint(
        to_save,
        save_handler=handler,
        global_step_transform=lambda _1, _2: trainer.state.epoch,
        score_name="val_acc",
        score_function=lambda e: e.state.metrics["val_acc"],
    )

    first_fname = "{}_11_val_acc=0.77.pth".format(name)
    second_fname = "{}_12_val_acc=0.78.pth".format(name)

    evaluator.state = State(epoch=1, iteration=1000, metrics={"val_acc": 0.77})
    ckpt(evaluator)
    assert handler.call_count == 1
    handler.assert_called_with(obj, first_fname)

    # Advance the trainer epoch (the global step) and improve the evaluator score.
    trainer.state.epoch = 12
    evaluator.state.metrics["val_acc"] = 0.78
    ckpt(evaluator)
    assert handler.call_count == 2
    handler.assert_called_with(obj, second_fname)
    assert handler.remove.call_count == 1
    handler.remove.assert_called_with(first_fname)
    assert ckpt.last_checkpoint == second_fname
def test_checkpoint_last_checkpoint():
    """After ten saves with ``n_saved=None``, ``last_checkpoint`` names the final iteration."""
    handler = MagicMock(spec=BaseSaveHandler)
    checkpointer = Checkpoint({"model": DummyModel()}, save_handler=handler, n_saved=None)

    trainer = Engine(lambda e, b: None)
    for iteration in range(10):
        trainer.state = State(epoch=1, iteration=iteration)
        checkpointer(trainer)

    assert handler.call_count == 10
    assert checkpointer.last_checkpoint == "model_9.pt"
def test_checkpoint_last_checkpoint():
    """``last_checkpoint`` is None before any save and the saved filename afterwards."""
    handler = MagicMock()
    handler.__call__ = MagicMock()
    checkpointer = Checkpoint({"model": DummyModel()}, save_handler=handler)
    assert checkpointer.last_checkpoint is None

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)
    checkpointer(engine)

    assert checkpointer.last_checkpoint == "model_0.pth"
def test_valid_state_dict_save(dirname):
    """ModelCheckpoint rejects objects without ``state_dict`` and accepts a model."""
    model = DummyModel()
    h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    # An int has no state_dict, so saving it must raise TypeError.
    with pytest.raises(TypeError, match=r"should have `state_dict` method"):
        h(engine, {"name": 42})

    # A real model must save without a ValueError.
    try:
        h(engine, {"name": model})
    except ValueError:
        pytest.fail("Unexpected ValueError")