def test_output_handler_output_transform(dirname):
    """The transformed engine output is plotted in a single window per tag.

    A non-dict transform result is plotted under ``<tag>/output``; a dict
    result is plotted under ``<tag>/<key>``.
    """

    def _fresh_logger():
        # VisdomLogger double whose executor is the synchronous test dummy.
        logger = MagicMock(spec=VisdomLogger)
        logger.vis = MagicMock()
        logger.executor = _DummyExecutor()
        return logger

    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    cases = (
        ("tag", lambda x: x, "tag/output"),
        ("another_tag", lambda x: {"loss": x}, "another_tag/loss"),
    )
    for tag, transform, window_name in cases:
        handler = OutputHandler(tag, output_transform=transform)
        logger = _fresh_logger()

        handler(engine, logger, Events.ITERATION_STARTED)

        assert len(handler.windows) == 1 and window_name in handler.windows
        assert handler.windows[window_name]["win"] is not None
        # First plot for a window: it is created (win=None) rather than appended to.
        logger.vis.line.assert_called_once_with(
            X=[123],
            Y=[12345],
            env=logger.vis.env,
            win=None,
            update=None,
            opts=handler.windows[window_name]["opts"],
            name=window_name,
        )
def test_output_handler_with_global_step_transform():
    """A custom ``global_step_transform`` supplies the X coordinate of the plot."""

    def constant_step(*args, **kwargs):
        return 10

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=constant_step,
    )

    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    handler(engine, logger, Events.EPOCH_STARTED)

    assert logger.vis.line.call_count == 1
    assert len(handler.windows) == 1 and "tag/loss" in handler.windows
    window = handler.windows["tag/loss"]
    assert window["win"] is not None

    # X is the transform's value (10), not engine.state.epoch (5).
    expected = call(
        X=[10],
        Y=[12345],
        env=logger.vis.env,
        win=None,
        update=None,
        opts=window["opts"],
        name="tag/loss",
    )
    logger.vis.line.assert_has_calls([expected])
def test_output_handler_with_global_step_from_engine():
    """``global_step_from_engine`` takes X from another engine's current epoch.

    The handler is invoked twice. The first call creates the window, and the
    second appends to it using the other engine's updated epoch.
    """
    other_engine = MagicMock()
    other_engine.state = State()
    other_engine.state.epoch = 10
    other_engine.state.output = 12.345

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_from_engine(other_engine),
    )

    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 1

    # (epoch of the other engine, output of the logged engine) per round
    rounds = ((10, 0.123), (11, 1.123))
    for n_calls, (other_epoch, output) in enumerate(rounds, start=1):
        other_engine.state.epoch = other_epoch
        engine.state.output = output

        handler(engine, logger, Events.EPOCH_STARTED)

        assert logger.vis.line.call_count == n_calls
        assert len(handler.windows) == 1 and "tag/loss" in handler.windows
        window = handler.windows["tag/loss"]
        assert window["win"] is not None

        creating = n_calls == 1
        logger.vis.line.assert_has_calls([
            call(
                X=[other_epoch],
                Y=[output],
                env=logger.vis.env,
                win=None if creating else window["win"],
                update=None if creating else "append",
                opts=window["opts"],
                name="tag/loss",
            )
        ])
def test_output_handler_with_wrong_logger_type():
    """Calling the handler with a non-VisdomLogger logger raises RuntimeError."""
    handler = OutputHandler("tag", output_transform=lambda x: x)
    not_a_visdom_logger = MagicMock()
    engine = MagicMock()

    expected_message = "Handler 'OutputHandler' works only with VisdomLogger"
    with pytest.raises(RuntimeError, match=expected_message):
        handler(engine, not_a_visdom_logger, Events.ITERATION_STARTED)
def test_output_handler_with_wrong_global_step_transform_output():
    """A ``global_step_transform`` returning a non-int is rejected with TypeError."""

    def string_step(*args, **kwargs):
        return "a"

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=string_step,
    )

    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    with pytest.raises(TypeError, match="global_step must be int"):
        handler(engine, logger, Events.EPOCH_STARTED)
def test_integration_with_executor_as_context_manager(visdom_server, visdom_server_stop):
    """End-to-end: per-iteration losses end up in one "training/loss" visdom window.

    ``visdom_server[0]`` / ``visdom_server[1]`` are passed as the server host and
    port. ``visdom_server_stop`` is requested but not used in the body; it is
    presumably needed for fixture ordering/teardown (NOTE(review): confirm).
    """
    n_epochs = 5
    data = list(range(50))
    n_points = n_epochs * len(data)
    losses = torch.rand(n_points)
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    with VisdomLogger(server=visdom_server[0], port=visdom_server[1], num_workers=1) as vd_logger:
        # close all windows in 'main' environment
        vd_logger.vis.close()

        trainer = Engine(update_fn)

        output_handler = OutputHandler(tag="training", output_transform=lambda x: {"loss": x})
        vd_logger.attach(trainer, log_handler=output_handler, event_name=Events.ITERATION_COMPLETED)

        trainer.run(data, max_epochs=n_epochs)

        assert len(output_handler.windows) == 1
        assert "training/loss" in output_handler.windows
        win_name = output_handler.windows["training/loss"]["win"]

        # Use distinct names here: the original code rebound `data`, which made
        # the expected x-range below depend on the size of the trace dict
        # instead of the dataset length.
        window = _parse_content(vd_logger.vis.get_window_data(win=win_name))
        assert "content" in window and "data" in window["content"]
        trace = window["content"]["data"][0]
        assert "x" in trace and "y" in trace

        x_vals, y_vals = trace["x"], trace["y"]
        assert len(x_vals) == len(y_vals)
        # zip() below tolerates a window holding fewer points than were logged.
        # For example, queued appends may not have reached the server yet —
        # NOTE(review): confirm. It must never hold more than were logged.
        assert 0 < len(x_vals) <= n_points

        # Global step is the iteration counter, which starts at 1.
        assert all(int(x) == x_true for x, x_true in zip(x_vals, range(1, n_points + 1)))
        assert all(y == y_true for y, y_true in zip(y_vals, losses))
def test_output_handler_both(dirname):
    """Metrics and transformed output are plotted together.

    The first event creates one window per name. The second event appends to
    those same windows.
    """
    handler = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})

    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.epoch = 5
    engine.state.output = 12345

    expected_values = {"tag/a": 12.23, "tag/b": 23.45, "tag/loss": 12345}

    def expected_calls(step, creating):
        # One vis.line call per window. Creation passes win=None/update=None;
        # later plots append to the window id stored by the handler.
        return [
            call(
                X=[step],
                Y=[value],
                env=logger.vis.env,
                win=None if creating else handler.windows[name]["win"],
                update=None if creating else "append",
                opts=handler.windows[name]["opts"],
                name=name,
            )
            for name, value in expected_values.items()
        ]

    # (epoch, cumulative vis.line call count, whether windows are being created)
    for epoch, total_calls, creating in ((5, 3, True), (6, 6, False)):
        engine.state.epoch = epoch
        handler(engine, logger, Events.EPOCH_STARTED)

        assert logger.vis.line.call_count == total_calls
        assert len(handler.windows) == 3 and all(name in handler.windows for name in expected_values)
        for name in expected_values:
            assert handler.windows[name]["win"] is not None
        logger.vis.line.assert_has_calls(expected_calls(epoch, creating), any_order=True)
def test_output_handler_metric_names(dirname):
    """``metric_names`` selects which engine metrics are plotted.

    Covers four cases:
    - an explicit list of scalar metrics;
    - a 1-D tensor metric, split into one window per element;
    - a non-numeric metric, skipped with a UserWarning;
    - ``"all"``.
    """

    def make_logger():
        logger = MagicMock(spec=VisdomLogger)
        logger.vis = MagicMock()
        logger.executor = _DummyExecutor()
        return logger

    def make_engine(metrics, iteration):
        engine = MagicMock()
        engine.state = State(metrics=metrics)
        engine.state.iteration = iteration
        return engine

    def check_first_plots(handler, logger, step, expected):
        # Exactly the expected windows exist, each created by one vis.line call.
        assert len(handler.windows) == len(expected) and all(name in handler.windows for name in expected)
        for name in expected:
            assert handler.windows[name]["win"] is not None
        assert logger.vis.line.call_count == len(expected)
        logger.vis.line.assert_has_calls(
            [
                call(
                    X=[step],
                    Y=[value],
                    env=logger.vis.env,
                    win=None,
                    update=None,
                    opts=handler.windows[name]["opts"],
                    name=name,
                )
                for name, value in expected.items()
            ],
            any_order=True,
        )

    # Explicit list of scalar metrics.
    handler = OutputHandler("tag", metric_names=["a", "b"])
    logger = make_logger()
    engine = make_engine({"a": 12.23, "b": 23.45}, 5)
    handler(engine, logger, Events.ITERATION_STARTED)
    check_first_plots(handler, logger, 5, {"tag/a": 12.23, "tag/b": 23.45})

    # A tensor metric is plotted element-wise as "tag/a/<index>".
    handler = OutputHandler("tag", metric_names=["a"])
    engine = make_engine({"a": torch.Tensor([0.0, 1.0, 2.0, 3.0])}, 5)
    logger = make_logger()
    handler(engine, logger, Events.ITERATION_STARTED)
    check_first_plots(handler, logger, 5, {f"tag/a/{i}": float(i) for i in range(4)})

    # A non-numeric metric triggers a warning and is not plotted.
    handler = OutputHandler("tag", metric_names=["a", "c"])
    engine = make_engine({"a": 55.56, "c": "Some text"}, 7)
    logger = make_logger()
    with pytest.warns(UserWarning):
        handler(engine, logger, Events.ITERATION_STARTED)
    check_first_plots(handler, logger, 7, {"tag/a": 55.56})

    # all metrics
    handler = OutputHandler("tag", metric_names="all")
    logger = make_logger()
    engine = make_engine({"a": 12.23, "b": 23.45}, 5)
    handler(engine, logger, Events.ITERATION_STARTED)
    check_first_plots(handler, logger, 5, {"tag/a": 12.23, "tag/b": 23.45})
def test_output_handler_state_attrs():
    """Selected engine-state attributes are plotted.

    Floats and 0-d tensors get one window each. A 1-D tensor gets one window
    per element.
    """
    handler = OutputHandler("tag", state_attributes=["alpha", "beta", "gamma"])

    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 5
    engine.state.alpha = 3.899
    engine.state.beta = torch.tensor(12.0)
    engine.state.gamma = torch.tensor([21.0, 6.0])

    handler(engine, logger, Events.ITERATION_STARTED)

    expected = {
        "tag/alpha": 3.899,
        "tag/beta": 12.0,
        "tag/gamma/0": 21.0,
        "tag/gamma/1": 6.0,
    }

    assert logger.vis.line.call_count == 4
    assert len(handler.windows) == 4 and all(name in handler.windows for name in expected)
    for name in expected:
        assert handler.windows[name]["win"] is not None

    logger.vis.line.assert_has_calls(
        [
            call(
                X=[5],
                Y=[value],
                env=logger.vis.env,
                win=None,
                update=None,
                opts=handler.windows[name]["opts"],
                name=name,
            )
            for name, value in expected.items()
        ],
        any_order=True,
    )