Beispiel #1
0
def test_output_handler_with_wrong_logger_type():
    """Calling OutputHandler with a logger that is not a ClearMLLogger raises RuntimeError."""
    handler = OutputHandler("tag", output_transform=lambda x: x)

    # A bare MagicMock (no spec=ClearMLLogger) stands in for a foreign logger type.
    foreign_logger = MagicMock()
    engine = MagicMock()

    expected_msg = "Handler OutputHandler works only with ClearMLLogger"
    with pytest.raises(RuntimeError, match=expected_msg):
        handler(engine, foreign_logger, Events.ITERATION_STARTED)
Beispiel #2
0
def test_output_handler_output_transform(dirname):
    """The output_transform result is reported through clearml_logger.report_scalar.

    A scalar transform result is reported under the series "output". A dict
    result is reported with its key ("loss") as the series name.
    """
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    # (tag, output_transform, expected series name)
    cases = [
        ("tag", lambda x: x, "output"),
        ("another_tag", lambda x: {"loss": x}, "loss"),
    ]
    for tag, transform, series in cases:
        handler = OutputHandler(tag, output_transform=transform)
        logger = MagicMock(spec=ClearMLLogger)
        logger.clearml_logger = MagicMock()

        handler(engine, logger, Events.ITERATION_STARTED)

        # Each case uses a fresh logger, so exactly one report is expected.
        logger.clearml_logger.report_scalar.assert_called_once_with(
            iteration=123, series=series, title=tag, value=12345)
Beispiel #3
0
def test_output_handler_with_wrong_global_step_transform_output():
    """A global_step_transform that returns a non-int makes the handler raise TypeError."""

    def non_int_step(*args, **kwargs):
        # Deliberately returns a str where an int step is required.
        return "a"

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    logger = MagicMock(spec=ClearMLLogger)
    logger.clearml_logger = MagicMock()

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=non_int_step,
    )

    with pytest.raises(TypeError, match="global_step must be int"):
        handler(engine, logger, Events.EPOCH_STARTED)
Beispiel #4
0
def test_output_handler_with_global_step_transform():
    """The step returned by global_step_transform is reported as the iteration."""
    fixed_step = 10

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=lambda *args, **kwargs: fixed_step,
    )
    logger = MagicMock(spec=ClearMLLogger)
    logger.clearml_logger = MagicMock()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5  # differs from fixed_step, so the transform must win
    engine.state.output = 12345

    handler(engine, logger, Events.EPOCH_STARTED)

    report_scalar = logger.clearml_logger.report_scalar
    assert report_scalar.call_count == 1
    report_scalar.assert_has_calls(
        [call(title="tag", series="loss", iteration=fixed_step, value=12345)])
Beispiel #5
0
def test_output_handler_with_global_step_from_engine():
    """With global_step_from_engine, the reported iteration follows another engine.

    The iteration is taken from ``other_engine.state.epoch``. The value comes
    from the attached engine's output. This is checked on two successive calls.
    """
    other_engine = MagicMock()
    other_engine.state = State()
    other_engine.state.epoch = 10
    other_engine.state.output = 12.345

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_from_engine(other_engine),
    )

    logger = MagicMock(spec=ClearMLLogger)
    logger.clearml_logger = MagicMock()
    report_scalar = logger.clearml_logger.report_scalar

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 1
    engine.state.output = 0.123

    def assert_reported(expected_count):
        # Iteration must track the *other* engine; the value is the attached engine's output.
        assert report_scalar.call_count == expected_count
        report_scalar.assert_has_calls([
            call(title="tag",
                 series="loss",
                 iteration=other_engine.state.epoch,
                 value=engine.state.output)
        ])

    handler(engine, logger, Events.EPOCH_STARTED)
    assert_reported(1)

    # Advance the other engine and change the output, then log again.
    other_engine.state.epoch = 11
    engine.state.output = 1.123

    handler(engine, logger, Events.EPOCH_STARTED)
    assert_reported(2)
Beispiel #6
0
def test_output_handler_both(dirname):
    """Selected metrics and the transformed output are all reported in a single call."""
    metrics = {"a": 12.23, "b": 23.45}
    output = 12345
    epoch = 5

    handler = OutputHandler(
        "tag",
        metric_names=["a", "b"],
        output_transform=lambda x: {"loss": x},
    )
    logger = MagicMock(spec=ClearMLLogger)
    logger.clearml_logger = MagicMock()

    engine = MagicMock()
    engine.state = State(metrics=dict(metrics))
    engine.state.epoch = epoch
    engine.state.output = output

    handler(engine, logger, Events.EPOCH_STARTED)

    # One scalar per selected metric plus one for the "loss" key of the output.
    expected = {**metrics, "loss": output}
    report_scalar = logger.clearml_logger.report_scalar
    assert report_scalar.call_count == len(expected)
    report_scalar.assert_has_calls(
        [
            call(title="tag", series=name, iteration=epoch, value=val)
            for name, val in expected.items()
        ],
        any_order=True,
    )
Beispiel #7
0
def test_output_handler_state_attrs():
    """Engine state attributes listed in state_attributes are reported as scalars.

    Scalars (a float and a 0-d tensor) are reported under title "tag". The
    elements of the 1-element-per-index tensor ``gamma`` are reported under
    "tag/gamma", with the element index as the series name.
    """
    handler = OutputHandler("tag", state_attributes=["alpha", "beta", "gamma"])
    logger = MagicMock(spec=ClearMLLogger)
    logger.clearml_logger = MagicMock()

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 5
    engine.state.alpha = 3.899
    engine.state.beta = torch.tensor(12.0)
    engine.state.gamma = torch.tensor([21.0, 6.0])

    handler(engine, logger, Events.ITERATION_STARTED)

    expected_calls = [
        call(title="tag", series="alpha", iteration=5, value=3.899),
        call(title="tag", series="beta", iteration=5, value=12.0),
    ]
    expected_calls += [
        call(title="tag/gamma", series=str(idx), iteration=5, value=val)
        for idx, val in enumerate([21.0, 6.0])
    ]

    report_scalar = logger.clearml_logger.report_scalar
    assert report_scalar.call_count == len(expected_calls)
    report_scalar.assert_has_calls(expected_calls, any_order=True)
Beispiel #8
0
def test_output_handler_metric_names(dirname):
    """metric_names controls which engine metrics are reported.

    Covers four cases:
    - an explicit list of names;
    - a list that hits a non-numeric metric, which is warned about and skipped;
    - ``"all"``;
    - a 1-D tensor metric, reported element-wise under "tag/<name>".
    """

    def make_logger():
        # ClearMLLogger-shaped mock whose inner clearml_logger records report_scalar calls.
        logger = MagicMock(spec=ClearMLLogger)
        logger.clearml_logger = MagicMock()
        return logger

    def make_engine(metrics, iteration):
        engine = MagicMock()
        engine.state = State(metrics=metrics)
        engine.state.iteration = iteration
        return engine

    def assert_reported(logger, expected_calls):
        report_scalar = logger.clearml_logger.report_scalar
        assert report_scalar.call_count == len(expected_calls)
        report_scalar.assert_has_calls(expected_calls, any_order=True)

    # Explicit list of metric names.
    handler = OutputHandler("tag", metric_names=["a", "b"])
    logger = make_logger()
    engine = make_engine({"a": 12.23, "b": 23.45}, 5)
    handler(engine, logger, Events.ITERATION_STARTED)
    assert_reported(logger, [
        call(title="tag", series="a", iteration=5, value=12.23),
        call(title="tag", series="b", iteration=5, value=23.45),
    ])

    # A non-numeric metric value triggers a warning and is not reported.
    handler = OutputHandler("tag", metric_names=["a", "c"])
    engine = make_engine({"a": 55.56, "c": "Some text"}, 7)
    logger = make_logger()
    with pytest.warns(
            UserWarning,
            match=r"ClearMLLogger output_handler can not log metrics value type"
    ):
        handler(engine, logger, Events.ITERATION_STARTED)
    assert_reported(logger, [
        call(title="tag", series="a", iteration=7, value=55.56),
    ])

    # "all" selects every metric on the engine state.
    handler = OutputHandler("tag", metric_names="all")
    logger = make_logger()
    engine = make_engine({"a": 12.23, "b": 23.45}, 5)
    handler(engine, logger, Events.ITERATION_STARTED)
    assert_reported(logger, [
        call(title="tag", series="a", iteration=5, value=12.23),
        call(title="tag", series="b", iteration=5, value=23.45),
    ])

    # A 1-D tensor metric is reported element by element under "tag/vector".
    handler = OutputHandler("tag", metric_names="all")
    logger = make_logger()
    vector = torch.tensor([0.1, 0.2, 0.1, 0.2, 0.33])
    engine = make_engine({"vector": vector}, 5)
    handler(engine, logger, Events.ITERATION_STARTED)
    assert_reported(logger, [
        call(title="tag/vector", series=str(idx), iteration=5, value=vector[idx].item())
        for idx in range(len(vector))
    ])