def test_output_handler_output_transform(dirname):
    """OutputHandler forwards the transformed engine output to ``writer.add_scalar``.

    Two transforms are exercised: the identity transform (scalar logged as
    ``<tag>/output``) and a dict transform (the dict key is appended to the tag).
    """
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    # Identity transform: the bare scalar is logged under "<tag>/output".
    handler = OutputHandler("tag", output_transform=lambda x: x)
    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()
    handler(engine, logger, Events.ITERATION_STARTED)
    logger.writer.add_scalar.assert_called_once_with("tag/output", 12345, 123)

    # Dict transform: the key "loss" is appended to the tag.
    handler = OutputHandler("another_tag", output_transform=lambda x: {"loss": x})
    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()
    handler(engine, logger, Events.ITERATION_STARTED)
    logger.writer.add_scalar.assert_called_once_with("another_tag/loss", 12345, 123)
def test_grads_hist_handler(dummy_model_factory):
    """GradsHistHandler emits a gradient histogram for each fc1/fc2 weight and bias."""
    model = dummy_model_factory(with_grads=True, with_frozen_layer=False)
    handler = GradsHistHandler(model)

    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5

    handler(engine, logger, Events.EPOCH_STARTED)

    add_histogram = logger.writer.add_histogram
    assert add_histogram.call_count == 4
    tags = ("grads/fc1/weight", "grads/fc1/bias", "grads/fc2/weight", "grads/fc2/bias")
    add_histogram.assert_has_calls(
        [call(tag=t, values=ANY, global_step=5) for t in tags],
        any_order=True,
    )
def _test(tag=None):
    """Check the weight norms WeightsScalarHandler logs to Neptune, optionally under ``tag``.

    ``model`` is a free name taken from the enclosing scope, which is not
    visible here.
    """
    handler = WeightsScalarHandler(model, tag=tag)

    logger = MagicMock(spec=NeptuneLogger)
    logger.experiment = MagicMock()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5

    handler(engine, logger, Events.EPOCH_STARTED)

    # An empty or None tag means no prefix at all.
    prefix = "" if not tag else "{}/".format(tag)
    log_metric = logger.experiment.log_metric
    assert log_metric.call_count == 4
    expected_norms = [
        ("weights_norm/fc1/weight", 0.0),
        ("weights_norm/fc1/bias", 0.0),
        ("weights_norm/fc2/weight", 12.0),
        ("weights_norm/fc2/bias", math.sqrt(12.0)),
    ]
    log_metric.assert_has_calls(
        [call(prefix + metric, y=norm, x=5) for metric, norm in expected_norms],
        any_order=True,
    )
def test_no_grad():
    """``Metric.update``, when reached through ``iteration_completed``, builds no autograd graph."""
    prediction = torch.zeros(4, requires_grad=True)
    target = torch.zeros(4, requires_grad=False)

    class DummyMetric(Metric):
        def reset(self):
            pass

        def compute(self):
            pass

        def update(self, output):
            y_pred, y = output
            mse = torch.pow(y_pred - y.view_as(y_pred), 2)
            # The input still requires grad but the derived tensor does not,
            # which only happens when gradient tracking is disabled here.
            assert y_pred.requires_grad
            assert not mse.requires_grad

    metric = DummyMetric()
    engine = MagicMock(state=State(output=(prediction, target)))
    metric.iteration_completed(engine)
def test_neptune_saver_integration():
    """With n_saved=1, two checkpoints log two artifacts and delete one of them."""
    to_save_serializable = {"model": torch.nn.Module()}

    logger = MagicMock(spec=NeptuneLogger)
    logger.log_artifact = MagicMock()
    logger.delete_artifacts = MagicMock()

    checkpoint = Checkpoint(
        to_save=to_save_serializable,
        save_handler=NeptuneSaver(logger),
        n_saved=1,
    )

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)
    for iteration in (0, 1):
        trainer.state.iteration = iteration
        checkpoint(trainer)

    assert logger.log_artifact.call_count == 2
    assert logger.delete_artifacts.call_count == 1
def test_output_handler_output_transform():
    """OutputHandler sends the transformed output to ``PolyaxonLogger.log_metrics``.

    Covers the identity transform (key ``<tag>/output``) and a dict transform
    (the dict key is appended to the tag).
    """
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    cases = [
        ("tag", lambda x: x, {"tag/output": 12345}),
        ("another_tag", lambda x: {"loss": x}, {"another_tag/loss": 12345}),
    ]
    for tag, transform, expected in cases:
        handler = OutputHandler(tag, output_transform=transform)
        logger = MagicMock(spec=PolyaxonLogger)
        logger.log_metrics = MagicMock()
        handler(engine, logger, Events.ITERATION_STARTED)
        logger.log_metrics.assert_called_once_with(step=123, **expected)
def _test_checkpoint_with_ddp(device):
    """Checkpointing a DDP-wrapped model saves the wrapped module's ``state_dict``."""
    torch.manual_seed(0)
    model = DummyModel().to(device)

    # CPU modules get device_ids=None; otherwise a one-element device list.
    if "cpu" in device.type:
        device_ids = None
    else:
        device_ids = [device]
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=device_ids)

    save_handler = MagicMock(spec=BaseSaveHandler)
    checkpointer = Checkpoint({"model": ddp_model}, save_handler=save_handler)

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)
    checkpointer(trainer)

    assert save_handler.call_count == 1
    expected_metadata = {"basename": "model", "score_name": None, "priority": 0}
    save_handler.assert_called_with(model.state_dict(), "model_0.pt", expected_metadata)
def test_output_handler_both():
    """Selected metrics and the transformed output are each logged to Neptune."""
    handler = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})

    logger = MagicMock(spec=NeptuneLogger)
    logger.experiment = MagicMock()

    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.epoch = 5
    engine.state.output = 12345

    handler(engine, logger, Events.EPOCH_STARTED)

    log_metric = logger.experiment.log_metric
    assert log_metric.call_count == 3
    expected = [
        call("tag/a", y=12.23, x=5),
        call("tag/b", y=23.45, x=5),
        call("tag/loss", y=12345, x=5),
    ]
    log_metric.assert_has_calls(expected, any_order=True)
def test_weights_scalar_handler(dummy_model_factory):
    """WeightsScalarHandler logs a weight-norm scalar for each fc1/fc2 parameter."""
    model = dummy_model_factory(with_grads=True, with_frozen_layer=False)
    handler = WeightsScalarHandler(model)

    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5

    handler(engine, logger, Events.EPOCH_STARTED)

    add_scalar = logger.writer.add_scalar
    assert add_scalar.call_count == 4
    expected_norms = {
        "weights_norm/fc1/weight": 0.0,
        "weights_norm/fc1/bias": 0.0,
        "weights_norm/fc2/weight": 12.0,
        "weights_norm/fc2/bias": math.sqrt(12.0),
    }
    add_scalar.assert_has_calls(
        [call(metric, norm, 5) for metric, norm in expected_norms.items()],
        any_order=True,
    )
def _test(tag=None):
    """Check the weight norms WeightsScalarHandler writes to TensorBoard, optionally under ``tag``.

    ``model`` is a free name taken from the enclosing scope, which is not
    visible here.
    """
    handler = WeightsScalarHandler(model, tag=tag)

    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5

    handler(engine, logger, Events.EPOCH_STARTED)

    if tag:
        prefix = "{}/".format(tag)
    else:
        prefix = ""

    add_scalar = logger.writer.add_scalar
    assert add_scalar.call_count == 4
    expected_norms = {
        "weights_norm/fc1/weight": 0.0,
        "weights_norm/fc1/bias": 0.0,
        "weights_norm/fc2/weight": 12.0,
        "weights_norm/fc2/bias": math.sqrt(12.0),
    }
    add_scalar.assert_has_calls(
        [call(prefix + metric, norm, 5) for metric, norm in expected_norms.items()],
        any_order=True,
    )
def test_output_handler_output_transform(dirname):
    """OutputHandler opens one Visdom window per transformed output series.

    Covers the identity transform (``<tag>/output``) and a dict transform
    (the dict key is appended to the tag).
    """
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    def _check(tag, output_transform, series):
        # A fresh logger per case, so each assert_called_once_with is isolated.
        handler = OutputHandler(tag, output_transform=output_transform)
        logger = MagicMock(spec=VisdomLogger)
        logger.vis = MagicMock()
        logger.executor = _DummyExecutor()

        handler(engine, logger, Events.ITERATION_STARTED)

        assert len(handler.windows) == 1 and series in handler.windows
        assert handler.windows[series]["win"] is not None
        logger.vis.line.assert_called_once_with(
            X=[123],
            Y=[12345],
            env=logger.vis.env,
            win=None,
            update=None,
            opts=handler.windows[series]["opts"],
            name=series,
        )

    _check("tag", lambda x: x, "tag/output")
    _check("another_tag", lambda x: {"loss": x}, "another_tag/loss")
def _test_gpu_info(device="cpu"):
    """Exercise GpuInfo directly via ``compute()`` and as an Engine metric for GPU 0."""
    gpu_info = GpuInfo()

    # These calls exist only to raise code coverage.
    gpu_info.reset()
    gpu_info.update(None)

    t = torch.rand(4, 10, 100, 100).to(device)
    data = gpu_info.compute()
    assert len(data) > 0
    first_gpu = data[0]

    assert "fb_memory_usage" in first_gpu
    mem_report = first_gpu["fb_memory_usage"]
    assert "used" in mem_report and "total" in mem_report
    assert mem_report["total"] > 0.0
    # The tensor's element count, scaled by 1/1024**2, is the lower bound for "used".
    assert mem_report["used"] > t.numel() / 1024.0 / 1024.0

    assert "utilization" in first_gpu
    util_report = first_gpu["utilization"]
    assert "gpu_util" in util_report

    # Same readings, surfaced through an Engine's metrics dict.
    engine = Engine(lambda engine, batch: 0.0)
    engine.state = State(metrics={})
    gpu_info.completed(engine, name="gpu")
    metrics = engine.state.metrics

    mem_key = "gpu:0 mem(%)"
    assert mem_key in metrics
    assert isinstance(metrics[mem_key], int)
    assert int(mem_report["used"] * 100.0 / mem_report["total"]) == metrics[mem_key]

    util_key = "gpu:0 util(%)"
    if util_report["gpu_util"] == "N/A":
        assert util_key not in metrics
    else:
        assert util_key in metrics
        assert isinstance(metrics[util_key], int)
        assert int(util_report["gpu_util"]) == metrics[util_key]
def test_optimizer_params():
    """OptimizerParamsHandler plots the lr of param group 0 in Visdom, with or without a tag."""
    optimizer = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 123

    def _check(handler, series):
        logger = MagicMock(spec=VisdomLogger)
        logger.vis = MagicMock()
        logger.executor = _DummyExecutor()

        handler(engine, logger, Events.ITERATION_STARTED)

        assert len(handler.windows) == 1 and series in handler.windows
        assert handler.windows[series]["win"] is not None
        logger.vis.line.assert_called_once_with(
            X=[123],
            Y=[0.01],
            env=logger.vis.env,
            win=None,
            update=None,
            opts=handler.windows[series]["opts"],
            name=series,
        )

    _check(OptimizerParamsHandler(optimizer=optimizer, param_name="lr"), "lr/group_0")
    _check(
        OptimizerParamsHandler(optimizer=optimizer, param_name="lr", tag="generator"),
        "generator/lr/group_0",
    )
def _test(to_save, obj, name):
    """With default settings, a second Checkpoint save removes the first file."""
    save_handler = MagicMock(spec=BaseSaveHandler)
    checkpointer = Checkpoint(to_save, save_handler=save_handler)
    assert checkpointer.last_checkpoint is None

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    first = "{}_0.pt".format(name)
    checkpointer(trainer)
    assert save_handler.call_count == 1
    save_handler.assert_called_with(obj, first)

    trainer.state.epoch = 12
    trainer.state.iteration = 1234
    second = "{}_1234.pt".format(name)
    checkpointer(trainer)
    assert save_handler.call_count == 2
    save_handler.assert_called_with(obj, second)

    assert save_handler.remove.call_count == 1
    save_handler.remove.assert_called_with(first)
    assert checkpointer.last_checkpoint == second
def test_completed_on_cuda():
    """``Metric.completed`` stores a CUDA tensor result on the CPU.

    Regression check for
    https://github.com/pytorch/ignite/issues/1635#issuecomment-863026919
    """

    class DummyMetric(Metric):
        def reset(self):
            pass

        def compute(self):
            return torch.tensor([1.0, 2.0, 3.0], device="cuda")

        def update(self, output):
            pass

    metric = DummyMetric()
    engine = MagicMock(state=State(metrics={}))
    metric.completed(engine, "metric")

    stored = engine.state.metrics
    assert "metric" in stored
    result = stored["metric"]
    assert isinstance(result, torch.Tensor)
    assert result.device.type == "cpu"
def test_output_handler_both(dirname):
    """Selected metrics and the transformed output are each written as TensorBoard scalars."""
    handler = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})

    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()

    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.epoch = 5
    engine.state.output = 12345

    handler(engine, logger, Events.EPOCH_STARTED)

    add_scalar = logger.writer.add_scalar
    assert add_scalar.call_count == 3
    add_scalar.assert_has_calls(
        [call("tag/a", 12.23, 5), call("tag/b", 23.45, 5), call("tag/loss", 12345, 5)],
        any_order=True,
    )
def _test(filename_prefix, to_save, obj, name):
    """Checkpoint filenames combine ``filename_prefix`` with an epoch-based global step."""
    save_handler = MagicMock(spec=BaseSaveHandler)
    checkpointer = Checkpoint(
        to_save,
        save_handler=save_handler,
        filename_prefix=filename_prefix,
        global_step_transform=lambda e, _: e.state.epoch,
    )

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=1, iteration=1)
    checkpointer(trainer)
    assert save_handler.call_count == 1

    # A non-empty prefix is joined to the object name with an underscore.
    prefix = filename_prefix + "_" if len(filename_prefix) > 0 else filename_prefix

    metadata = {"basename": "{}{}".format(prefix, name), "score_name": None, "priority": 1}
    save_handler.assert_called_with(obj, "{}{}_1.pt".format(prefix, name), metadata)

    trainer.state.epoch = 12
    trainer.state.iteration = 1234
    checkpointer(trainer)
    assert save_handler.call_count == 2

    # The step in the filename follows the epoch, while priority follows the iteration.
    metadata["priority"] = 1234
    save_handler.assert_called_with(obj, "{}{}_12.pt".format(prefix, name), metadata)

    assert save_handler.remove.call_count == 1
    save_handler.remove.assert_called_with("{}{}_1.pt".format(prefix, name))
    assert checkpointer.last_checkpoint == "{}{}_12.pt".format(prefix, name)
def test_trains_disk_saver_integration_no_logger():
    """Checkpoint through a logger-less TrainsSaver with n_saved=1.

    In atomic mode, two uploads are expected. Otherwise only ``model_1.pt``
    must remain in the saver's directory.
    """
    to_save_serializable = {"model": torch.nn.Module()}

    # NOTE(review): both assignments below monkeypatch `trains` globally and are
    # never restored. Consider unittest.mock.patch.object once it is confirmed
    # that no later test depends on these leaked mocks.
    with pytest.warns(UserWarning, match="TrainsSaver created a temporary checkpoints directory"):
        trains.Task.current_task = Mock(return_value=object())
        trains.binding.frameworks.WeightsFileHandler.create_output_model = MagicMock()
        trains_saver = TrainsSaver()

    checkpoint = Checkpoint(to_save=to_save_serializable, save_handler=trains_saver, n_saved=1)

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)
    checkpoint(trainer)
    trainer.state.iteration = 1
    checkpoint(trainer)

    if not trains_saver._atomic:
        saved_files = list(os.listdir(trains_saver.dirname))
        assert len(saved_files) == 1
        assert saved_files[0] == "model_1.pt"
    else:
        assert trains.binding.frameworks.WeightsFileHandler.create_output_model.call_count == 2
def test_output_handler_both(dirname):
    """Selected metrics and the transformed output are each reported as Trains scalars."""
    handler = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})

    logger = MagicMock(spec=TrainsLogger)
    logger.trains_logger = MagicMock()

    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.epoch = 5
    engine.state.output = 12345

    handler(engine, logger, Events.EPOCH_STARTED)

    report_scalar = logger.trains_logger.report_scalar
    assert report_scalar.call_count == 3
    expected = [
        call(title="tag", series=series, iteration=5, value=value)
        for series, value in (("a", 12.23), ("b", 23.45), ("loss", 12345))
    ]
    report_scalar.assert_has_calls(expected, any_order=True)
def test_mlflow_bad_metric_name_handling(dirname):
    """An invalid MLflow metric name triggers a warning; valid metrics are still stored.

    ``"metric:0 in %"`` is expected to trigger the invalid-name warning.
    ``"metric 0"`` is logged on every epoch and must be retrievable, in order
    and with no entries missing, as the ``"training metric 0"`` history.
    """
    import mlflow

    true_values = [123.0, 23.4, 333.4]
    with MLflowLogger(os.path.join(dirname, "mlruns")) as mlflow_logger:
        active_run = mlflow.active_run()

        handler = OutputHandler(tag="training", metric_names="all")
        engine = Engine(lambda e, b: None)
        engine.state = State(metrics={"metric:0 in %": 123.0, "metric 0": 1000.0})

        with pytest.warns(UserWarning, match=r"MLflowLogger output_handler encountered an invalid metric name"):
            engine.state.epoch = 1
            handler(engine, mlflow_logger, event_name=Events.EPOCH_COMPLETED)

            for v in true_values:
                engine.state.epoch += 1
                engine.state.metrics["metric 0"] = v
                handler(engine, mlflow_logger, event_name=Events.EPOCH_COMPLETED)

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=os.path.join(dirname, "mlruns"))
    stored_values = client.get_metric_history(active_run.info.run_id, "training metric 0")

    expected_values = [1000.0] + true_values
    # zip() stops at the shorter sequence. Without this check, a history that
    # is missing some (or all) of the logged values would still pass.
    assert len(stored_values) == len(expected_values)
    for t, s in zip(expected_values, stored_values):
        assert t == s.value
def test_output_handler_both():
    """Selected metrics and the transformed output go to Polyaxon in a single call."""
    handler = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})

    logger = MagicMock(spec=PolyaxonLogger)
    logger.log_metrics = MagicMock()

    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.epoch = 5
    engine.state.output = 12345

    handler(engine, logger, Events.EPOCH_STARTED)

    assert logger.log_metrics.call_count == 1
    expected = {"tag/a": 12.23, "tag/b": 23.45, "tag/loss": 12345}
    logger.log_metrics.assert_called_once_with(step=5, **expected)
def test_transform():
    """``output_transform`` is applied before ``update()`` receives the engine output."""
    y_pred = torch.Tensor([[2.0], [-2.0]])
    y = torch.zeros(2)

    class DummyMetric(Metric):
        def reset(self):
            pass

        def compute(self):
            pass

        def update(self, output):
            # Tuple equality checks element identity first, so this holds when
            # the transform hands back the very same tensor objects.
            assert output == (y_pred, y)

    def transform(output):
        pred_dict, target_dict = output
        return pred_dict["y"], target_dict["y"]

    metric = DummyMetric(output_transform=transform)
    engine = MagicMock(state=State(output=({"y": y_pred}, {"y": y})))
    metric.iteration_completed(engine)
def test_checkpoint_last_checkpoint_on_score():
    """With n_saved=None and rising scores, ``last_checkpoint`` names the final (0.9) file."""
    save_handler = MagicMock(spec=BaseSaveHandler)
    checkpointer = Checkpoint(
        {"model": DummyModel()},
        save_handler=save_handler,
        n_saved=None,
        score_name="val_acc",
        score_function=lambda e: e.state.metrics["val_acc"],
    )

    trainer = Engine(lambda e, b: None)
    for step in range(10):
        trainer.state = State(epoch=1, iteration=step, metrics={"val_acc": step * 0.1})
        checkpointer(trainer)

    assert save_handler.call_count == 10
    assert checkpointer.last_checkpoint == "{}_val_acc=0.9000.pt".format("model")
def test_output_handler_state_attrs():
    """Engine state attributes, including tensors, are logged to MLflow under space-joined keys."""
    handler = OutputHandler("tag", state_attributes=["alpha", "beta", "gamma"])

    logger = MagicMock(spec=MLflowLogger)
    logger.log_metrics = MagicMock()

    state = State()
    state.iteration = 5
    state.alpha = 3.899
    state.beta = torch.tensor(12.21)
    state.gamma = torch.tensor([21.0, 6.0])
    engine = MagicMock()
    engine.state = state

    handler(engine, logger, Events.ITERATION_STARTED)

    # The 0-d tensor appears as its .item() value; the 1-D tensor is split
    # into one entry per index.
    expected = {
        "tag alpha": 3.899,
        "tag beta": torch.tensor(12.21).item(),
        "tag gamma 0": 21.0,
        "tag gamma 1": 6.0,
    }
    logger.log_metrics.assert_called_once_with(expected, step=5)
def test_best_k_with_suffix(dirname):
    """With n_saved=2, only the files for the two highest scores (epochs 1 and 3) remain."""
    scores = [0.3456789, 0.1234, 0.4567, 0.134567]
    remaining_scores = iter(scores)

    def score_function(engine):
        return next(remaining_scores)

    h = ModelCheckpoint(
        dirname,
        _PREFIX,
        create_dir=False,
        n_saved=2,
        score_function=score_function,
        score_name="val_loss",
    )

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)
    to_save = {"model": DummyModel()}

    for _ in scores:
        engine.state.epoch += 1
        h(engine, to_save)

    best_epochs = [1, 3]
    expected = ['{}_{}_val_loss={:.4}.pth'.format(_PREFIX, 'model', scores[epoch - 1]) for epoch in best_epochs]
    assert sorted(os.listdir(dirname)) == expected
def _test(ext, require_empty, archived):
    """ModelCheckpoint writes its file beside an existing ``_PREFIX`` file and leaves that file in place.

    ``dirname``, ``_PREFIX`` and ``DummyModel`` are free names from the
    surrounding scope.
    """
    previous_fname = os.path.join(dirname, '{}_{}_{}{}'.format(_PREFIX, 'obj', 1, ext))
    with open(previous_fname, 'w') as f:
        f.write("test")

    h = ModelCheckpoint(dirname, _PREFIX, create_dir=True, require_empty=require_empty, archived=archived)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=1)
    model = DummyModel()
    h(engine, {'model': model})

    fname = h.last_checkpoint
    # The saved extension depends only on `archived`; `ext` above is used only
    # for the pre-existing file.
    saved_ext = ".pth.tar" if archived else ".pth"
    assert isinstance(fname, str)
    assert os.path.join(dirname, '{}_{}_{}{}'.format(_PREFIX, 'model', 1, saved_ext)) == fname
    assert os.path.exists(fname)
    assert os.path.exists(previous_fname)

    loaded_objects = torch.load(fname)
    assert loaded_objects == model.state_dict()
    os.remove(fname)
def test_optimizer_params():
    """OptimizerParamsHandler logs the lr of param group 0 to Neptune, with or without a tag."""
    optimizer = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 123

    cases = [
        (OptimizerParamsHandler(optimizer=optimizer, param_name="lr"), "lr/group_0"),
        (OptimizerParamsHandler(optimizer, param_name="lr", tag="generator"), "generator/lr/group_0"),
    ]
    for handler, series in cases:
        logger = MagicMock(spec=NeptuneLogger)
        logger.log_metric = MagicMock()
        handler(engine, logger, Events.ITERATION_STARTED)
        logger.log_metric.assert_called_once_with(series, y=0.01, x=123)
def test_grads_scalar_handler(dummy_model_factory, norm_mock):
    """GradsScalarHandler calls the reduction once per fc1/fc2 parameter and logs each result."""
    model = dummy_model_factory(with_grads=True, with_frozen_layer=False)
    handler = GradsScalarHandler(model, reduction=norm_mock)

    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5

    # Clear any calls recorded earlier so the reduction count below reflects
    # only this handler invocation.
    norm_mock.reset_mock()
    handler(engine, logger, Events.EPOCH_STARTED)

    add_scalar = logger.writer.add_scalar
    names = (
        "grads_norm/fc1/weight",
        "grads_norm/fc1/bias",
        "grads_norm/fc2/weight",
        "grads_norm/fc2/bias",
    )
    add_scalar.assert_has_calls([call(n, ANY, 5) for n in names], any_order=True)
    assert add_scalar.call_count == 4
    assert norm_mock.call_count == 4
def _test(to_save, obj, name):
    """With include_self=True, Checkpoint saves its own state, and a restored instance continues from it."""
    save_handler = MagicMock(spec=BaseSaveHandler)
    checkpointer = Checkpoint(to_save, save_handler=save_handler, include_self=True)
    assert checkpointer.last_checkpoint is None

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)
    checkpointer(trainer)
    assert save_handler.call_count == 1

    first_fname = "{}_0.pt".format(name)
    # The saved payload embeds the checkpointer's own list of saved files.
    obj["checkpointer"] = OrderedDict([("saved", [(0, first_fname)])])
    metadata = {"basename": name, "score_name": None, "priority": 0}
    save_handler.assert_called_with(obj, first_fname, metadata)

    # A fresh Checkpoint loaded from the first one's state must pick up where it left off.
    restored = Checkpoint(to_save, save_handler=save_handler, include_self=True)
    restored.load_state_dict(checkpointer.state_dict())
    assert restored.last_checkpoint == first_fname

    trainer.state.epoch = 12
    trainer.state.iteration = 1234
    restored(trainer)
    assert save_handler.call_count == 2
    metadata["priority"] = 1234

    # The old file is removed only if the restored state recorded it.
    save_handler.remove.assert_called_with("{}_0.pt".format(name))

    second_fname = "{}_1234.pt".format(name)
    obj["checkpointer"] = OrderedDict([("saved", [(1234, second_fname)])])
    save_handler.assert_called_with(obj, second_fname, metadata)
    assert save_handler.remove.call_count == 1
    assert restored.last_checkpoint == second_fname
def test_output_handler_with_global_step_transform():
    """The value returned by ``global_step_transform`` is used as X in the Visdom plot, not the epoch."""

    def global_step_transform(*args, **kwargs):
        return 10

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_transform,
    )

    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    handler(engine, logger, Events.EPOCH_STARTED)

    line = logger.vis.line
    windows = handler.windows
    assert line.call_count == 1
    assert len(windows) == 1 and "tag/loss" in windows
    assert windows["tag/loss"]["win"] is not None
    expected_call = call(
        X=[10],
        Y=[12345],
        env=logger.vis.env,
        win=None,
        update=None,
        opts=windows["tag/loss"]["opts"],
        name="tag/loss",
    )
    line.assert_has_calls([expected_call])