def test_output_handler_with_wrong_logger_type():
    """Calling OutputHandler with a non-ClearMLLogger logger raises RuntimeError."""
    handler = OutputHandler("tag", output_transform=lambda x: x)

    # A bare MagicMock (no spec) stands in for a logger of the wrong type.
    wrong_logger = MagicMock()
    engine = MagicMock()

    expected_message = "Handler OutputHandler works only with ClearMLLogger"
    with pytest.raises(RuntimeError, match=expected_message):
        handler(engine, wrong_logger, Events.ITERATION_STARTED)
def test_output_handler_output_transform(dirname):
    """Verify how the result of ``output_transform`` is reported as a scalar.

    A bare value is expected under the series ``"output"``, while a dict is
    expected under its own key. Both scenarios share the same engine state.
    """
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    # (tag, output_transform, series expected in the reported scalar)
    scenarios = [
        ("tag", lambda x: x, "output"),
        ("another_tag", lambda x: {"loss": x}, "loss"),
    ]

    for tag, transform, expected_series in scenarios:
        handler = OutputHandler(tag, output_transform=transform)

        # Fresh logger per scenario so the "called once" assertion only
        # sees the call made by this handler.
        logger = MagicMock(spec=ClearMLLogger)
        logger.clearml_logger = MagicMock()

        handler(engine, logger, Events.ITERATION_STARTED)

        logger.clearml_logger.report_scalar.assert_called_once_with(
            iteration=123, series=expected_series, title=tag, value=12345
        )
def test_output_handler_with_wrong_global_step_transform_output():
    """A ``global_step_transform`` that returns a string must raise TypeError."""

    def string_step(*args, **kwargs):
        # Deliberately not an int: this is the invalid value under test.
        return "a"

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    logger = MagicMock(spec=ClearMLLogger)
    logger.clearml_logger = MagicMock()

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=string_step,
    )

    with pytest.raises(TypeError, match="global_step must be int"):
        handler(engine, logger, Events.EPOCH_STARTED)
def test_output_handler_with_global_step_transform():
    """The value returned by ``global_step_transform`` is used as ``iteration``.

    The engine's epoch is 5, but the reported scalar is expected to carry the
    transform's constant 10.
    """

    def constant_step(*args, **kwargs):
        return 10

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    logger = MagicMock(spec=ClearMLLogger)
    logger.clearml_logger = MagicMock()
    report_scalar = logger.clearml_logger.report_scalar

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=constant_step,
    )
    handler(engine, logger, Events.EPOCH_STARTED)

    expected_calls = [call(title="tag", series="loss", iteration=10, value=12345)]
    assert report_scalar.call_count == 1
    report_scalar.assert_has_calls(expected_calls)
def test_output_handler_with_global_step_from_engine():
    """The reported ``iteration`` follows the epoch of a second engine.

    The handler is built with ``global_step_from_engine(step_engine)``; the
    reported value still comes from the engine the handler is invoked with.
    """
    step_engine = MagicMock()
    step_engine.state = State()
    step_engine.state.epoch = 10
    step_engine.state.output = 12.345

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_from_engine(step_engine),
    )

    logger = MagicMock(spec=ClearMLLogger)
    logger.clearml_logger = MagicMock()
    report_scalar = logger.clearml_logger.report_scalar

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 1
    engine.state.output = 0.123

    # Each round sets the step engine's epoch and the calling engine's output,
    # fires the handler, and checks that the call count grew by one.
    rounds = [(10, 0.123), (11, 1.123)]
    for expected_count, (epoch, output) in enumerate(rounds, start=1):
        step_engine.state.epoch = epoch
        engine.state.output = output

        handler(engine, logger, Events.EPOCH_STARTED)

        assert report_scalar.call_count == expected_count
        report_scalar.assert_has_calls(
            [
                call(
                    title="tag",
                    series="loss",
                    iteration=step_engine.state.epoch,
                    value=engine.state.output,
                )
            ]
        )
def test_output_handler_both(dirname):
    """Metrics and transformed output are all reported when both are configured.

    With ``metric_names=["a", "b"]`` and an ``output_transform`` yielding
    ``{"loss": ...}``, three scalars are expected, in any order.
    """
    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.epoch = 5
    engine.state.output = 12345

    logger = MagicMock(spec=ClearMLLogger)
    logger.clearml_logger = MagicMock()
    report_scalar = logger.clearml_logger.report_scalar

    handler = OutputHandler(
        "tag",
        metric_names=["a", "b"],
        output_transform=lambda x: {"loss": x},
    )
    handler(engine, logger, Events.EPOCH_STARTED)

    expected_calls = [
        call(title="tag", series="a", iteration=5, value=12.23),
        call(title="tag", series="b", iteration=5, value=23.45),
        call(title="tag", series="loss", iteration=5, value=12345),
    ]
    assert report_scalar.call_count == 3
    report_scalar.assert_has_calls(expected_calls, any_order=True)
def test_output_handler_state_attrs():
    """State attributes listed in ``state_attributes`` are reported as scalars.

    The float and the 0-d tensor are expected under title ``"tag"``; the
    two-element tensor is expected element by element under ``"tag/gamma"``
    with the element index as the series.
    """
    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 5
    engine.state.alpha = 3.899
    engine.state.beta = torch.tensor(12.0)
    engine.state.gamma = torch.tensor([21.0, 6.0])

    logger = MagicMock(spec=ClearMLLogger)
    logger.clearml_logger = MagicMock()
    report_scalar = logger.clearml_logger.report_scalar

    handler = OutputHandler("tag", state_attributes=["alpha", "beta", "gamma"])
    handler(engine, logger, Events.ITERATION_STARTED)

    expected_calls = [
        call(title="tag", series="alpha", iteration=5, value=3.899),
        call(title="tag", series="beta", iteration=5, value=12.0),
        call(title="tag/gamma", series="0", iteration=5, value=21.0),
        call(title="tag/gamma", series="1", iteration=5, value=6.0),
    ]
    assert report_scalar.call_count == 4
    report_scalar.assert_has_calls(expected_calls, any_order=True)
def test_output_handler_metric_names(dirname):
    """Exercise ``metric_names`` handling of OutputHandler in four scenarios.

    1. An explicit list of numeric metrics: one scalar per metric.
    2. A metric with a string value: a UserWarning is expected and only the
       numeric metric is reported.
    3. ``metric_names="all"``: every metric in the state is reported.
    4. A 1-D tensor metric: one scalar per element under ``"tag/vector"``.
    """

    def new_logger():
        # Spec'd logger mock with a free-form inner ``clearml_logger``.
        mocked = MagicMock(spec=ClearMLLogger)
        mocked.clearml_logger = MagicMock()
        return mocked

    def new_engine(metrics, iteration):
        # Engine mock carrying a State with the given metrics and iteration.
        mocked = MagicMock()
        mocked.state = State(metrics=metrics)
        mocked.state.iteration = iteration
        return mocked

    # explicit list of numeric metrics
    handler = OutputHandler("tag", metric_names=["a", "b"])
    logger = new_logger()
    engine = new_engine({"a": 12.23, "b": 23.45}, 5)

    handler(engine, logger, Events.ITERATION_STARTED)

    report_scalar = logger.clearml_logger.report_scalar
    assert report_scalar.call_count == 2
    report_scalar.assert_has_calls(
        [
            call(title="tag", series="a", iteration=5, value=12.23),
            call(title="tag", series="b", iteration=5, value=23.45),
        ],
        any_order=True,
    )

    # a metric whose value is a string
    handler = OutputHandler("tag", metric_names=["a", "c"])
    engine = new_engine({"a": 55.56, "c": "Some text"}, 7)
    logger = new_logger()

    warning_pattern = r"ClearMLLogger output_handler can not log metrics value type"
    with pytest.warns(UserWarning, match=warning_pattern):
        handler(engine, logger, Events.ITERATION_STARTED)

    report_scalar = logger.clearml_logger.report_scalar
    assert report_scalar.call_count == 1
    report_scalar.assert_has_calls(
        [call(title="tag", series="a", iteration=7, value=55.56)],
        any_order=True,
    )

    # all metrics
    handler = OutputHandler("tag", metric_names="all")
    logger = new_logger()
    engine = new_engine({"a": 12.23, "b": 23.45}, 5)

    handler(engine, logger, Events.ITERATION_STARTED)

    report_scalar = logger.clearml_logger.report_scalar
    assert report_scalar.call_count == 2
    report_scalar.assert_has_calls(
        [
            call(title="tag", series="a", iteration=5, value=12.23),
            call(title="tag", series="b", iteration=5, value=23.45),
        ],
        any_order=True,
    )

    # log a torch vector
    handler = OutputHandler("tag", metric_names="all")
    logger = new_logger()
    vector = torch.tensor([0.1, 0.2, 0.1, 0.2, 0.33])
    engine = new_engine({"vector": vector}, 5)

    handler(engine, logger, Events.ITERATION_STARTED)

    report_scalar = logger.clearml_logger.report_scalar
    assert report_scalar.call_count == 5
    expected_calls = [
        call(title="tag/vector", series=str(index), iteration=5, value=element.item())
        for index, element in enumerate(vector)
    ]
    report_scalar.assert_has_calls(expected_calls, any_order=True)