def test_output_handler_with_wrong_logger_type():
    """Calling the handler with a logger that is not a NeptuneLogger must raise TypeError."""
    handler = OutputHandler("tag", output_transform=lambda x: x)

    # A bare MagicMock (no ``spec=NeptuneLogger``) plays the role of the wrong logger.
    engine = MagicMock()
    not_a_neptune_logger = MagicMock()

    expected_message = "Handler OutputHandler works only with NeptuneLogger"
    with pytest.raises(TypeError, match=expected_message):
        handler(engine, not_a_neptune_logger, Events.ITERATION_STARTED)
def test_output_handler_with_global_step_from_engine():
    """The logged ``x`` must follow the epoch of the engine given to ``global_step_from_engine``.

    The handler is fired on one engine while the global step is taken from a
    second engine; the check is repeated after both engines' states change.
    """
    step_engine = MagicMock()
    step_engine.state = State()
    step_engine.state.epoch = 10
    step_engine.state.output = 12.345

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 1
    engine.state.output = 0.123

    logger = MagicMock(spec=NeptuneLogger)
    logger.log_metric = MagicMock()

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_from_engine(step_engine),
    )

    def fire_and_check(expected_call_count):
        # y is read from the firing engine, x from the step-providing engine.
        handler(engine, logger, Events.EPOCH_STARTED)
        assert logger.log_metric.call_count == expected_call_count
        logger.log_metric.assert_has_calls(
            [call("tag/loss", y=engine.state.output, x=step_engine.state.epoch)]
        )

    fire_and_check(1)

    step_engine.state.epoch = 11
    engine.state.output = 1.123
    fire_and_check(2)
def test_output_handler_output_transform():
    """The value returned by ``output_transform`` decides the logged metric name.

    Expected names: ``<tag>/output`` when the transform returns the raw value,
    ``<tag>/<key>`` when it returns a dict.
    """
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    # (tag, output_transform, expected metric name)
    scenarios = [
        ("tag", lambda x: x, "tag/output"),
        ("another_tag", lambda x: {"loss": x}, "another_tag/loss"),
    ]

    for tag, transform, expected_name in scenarios:
        handler = OutputHandler(tag, output_transform=transform)

        # A fresh logger per scenario keeps the "called once" check independent.
        logger = MagicMock(spec=NeptuneLogger)
        logger.log_metric = MagicMock()

        handler(engine, logger, Events.ITERATION_STARTED)

        logger.log_metric.assert_called_once_with(expected_name, y=12345, x=123)
def test_output_handler_with_wrong_global_step_transform_output():
    """A ``global_step_transform`` that returns a string must make the handler raise TypeError."""
    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    logger = MagicMock(spec=NeptuneLogger)

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        # Ignores its arguments and returns a non-int step on purpose.
        global_step_transform=lambda *args, **kwargs: "a",
    )

    with pytest.raises(TypeError, match="global_step must be int"):
        handler(engine, logger, Events.EPOCH_STARTED)
def test_output_handler_with_global_step_transform():
    """The value returned by a custom ``global_step_transform`` is used as ``x``."""
    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    logger = MagicMock(spec=NeptuneLogger)
    logger.log_metric = MagicMock()

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        # Constant step (10) that differs from the engine's epoch (5).
        global_step_transform=lambda *args, **kwargs: 10,
    )
    handler(engine, logger, Events.EPOCH_STARTED)

    logger.log_metric.assert_called_once_with("tag/loss", y=12345, x=10)
def test_output_handler_both():
    """With ``metric_names`` and ``output_transform`` both set, all three values are logged."""
    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.epoch = 5
    engine.state.output = 12345

    logger = MagicMock(spec=NeptuneLogger)
    logger.log_metric = MagicMock()

    handler = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})
    handler(engine, logger, Events.EPOCH_STARTED)

    expected_calls = [
        call("tag/a", y=12.23, x=5),
        call("tag/b", y=23.45, x=5),
        call("tag/loss", y=12345, x=5),
    ]
    assert logger.log_metric.call_count == 3
    # The order in which metrics and output are logged is not part of the check.
    logger.log_metric.assert_has_calls(expected_calls, any_order=True)
def test_output_handler_state_attrs():
    """Attributes listed in ``state_attributes`` are read from ``engine.state`` and logged.

    Expected entries: the float as-is, the 0-d tensor as its Python scalar,
    and the 1-d tensor as one ``<tag>/<name>/<index>`` entry per element.
    """
    scalar_tensor = torch.tensor(12.23)

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 5
    engine.state.alpha = 3.899
    engine.state.beta = scalar_tensor
    engine.state.gamma = torch.tensor([21.0, 6.0])

    logger = MagicMock(spec=NeptuneLogger)
    logger.log_metric = MagicMock()

    handler = OutputHandler("tag", state_attributes=["alpha", "beta", "gamma"])
    handler(engine, logger, Events.ITERATION_STARTED)

    expected_calls = [
        call("tag/alpha", y=3.899, x=5),
        # float32 rounding: compare against the tensor's own scalar, not the literal 12.23.
        call("tag/beta", y=scalar_tensor.item(), x=5),
        call("tag/gamma/0", y=21.0, x=5),
        call("tag/gamma/1", y=6.0, x=5),
    ]
    assert logger.log_metric.call_count == 4
    logger.log_metric.assert_has_calls(expected_calls, any_order=True)
def test_output_handler_metric_names():
    """Check how ``OutputHandler`` logs the metrics selected by ``metric_names``.

    Scenarios, in order:
      1. an explicit list of float metrics;
      2. a 1-d tensor metric, expected as one ``<tag>/<name>/<index>`` entry per element;
      3. a non-numeric metric, expected to emit a ``UserWarning`` and not be logged;
      4. ``metric_names="all"`` with float metrics;
      5. ``metric_names="all"`` with 0-d tensor metrics, expected as Python scalars.

    Each scenario builds its own logger and engine so call counts never leak
    from one scenario into the next.
    """

    def make_logger():
        # ``spec=NeptuneLogger`` lets the mock pass ``isinstance`` checks against
        # NeptuneLogger; ``log_metric`` is replaced by a plain mock so its calls
        # can be inspected.
        logger = MagicMock(spec=NeptuneLogger)
        logger.log_metric = MagicMock()
        return logger

    def make_engine(metrics, iteration):
        engine = MagicMock()
        engine.state = State(metrics=metrics)
        engine.state.iteration = iteration
        return engine

    # 1. explicit list of float metrics
    wrapper = OutputHandler("tag", metric_names=["a", "b"])
    mock_logger = make_logger()
    mock_engine = make_engine({"a": 12.23, "b": 23.45}, iteration=5)

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert mock_logger.log_metric.call_count == 2
    mock_logger.log_metric.assert_has_calls(
        [call("tag/a", y=12.23, x=5), call("tag/b", y=23.45, x=5)],
        any_order=True,
    )

    # 2. 1-d tensor metric
    wrapper = OutputHandler("tag", metric_names=["a"])
    mock_logger = make_logger()
    mock_engine = make_engine({"a": torch.Tensor([0.0, 1.0, 2.0, 3.0])}, iteration=5)

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert mock_logger.log_metric.call_count == 4
    mock_logger.log_metric.assert_has_calls(
        [
            call("tag/a/0", y=0.0, x=5),
            call("tag/a/1", y=1.0, x=5),
            call("tag/a/2", y=2.0, x=5),
            call("tag/a/3", y=3.0, x=5),
        ],
        any_order=True,
    )

    # 3. non-numeric metric: a warning is expected and only "a" is logged
    wrapper = OutputHandler("tag", metric_names=["a", "c"])
    mock_logger = make_logger()
    mock_engine = make_engine({"a": 55.56, "c": "Some text"}, iteration=7)

    with pytest.warns(UserWarning):
        wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert mock_logger.log_metric.call_count == 1
    mock_logger.log_metric.assert_has_calls([call("tag/a", y=55.56, x=7)], any_order=True)

    # 4. all metrics
    wrapper = OutputHandler("tag", metric_names="all")
    mock_logger = make_logger()
    mock_engine = make_engine({"a": 12.23, "b": 23.45}, iteration=5)

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert mock_logger.log_metric.call_count == 2
    mock_logger.log_metric.assert_has_calls(
        [call("tag/a", y=12.23, x=5), call("tag/b", y=23.45, x=5)],
        any_order=True,
    )

    # 5. log a torch tensor (ndimension = 0)
    wrapper = OutputHandler("tag", metric_names="all")
    mock_logger = make_logger()
    mock_engine = make_engine(
        {"a": torch.tensor(12.23), "b": torch.tensor(23.45)}, iteration=5
    )

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert mock_logger.log_metric.call_count == 2
    mock_logger.log_metric.assert_has_calls(
        [
            # float32 rounding: compare against the tensors' own scalars.
            call("tag/a", y=torch.tensor(12.23).item(), x=5),
            call("tag/b", y=torch.tensor(23.45).item(), x=5),
        ],
        any_order=True,
    )