def test_output_handler_with_wrong_logger_type():
    """Calling the handler with a logger that is not a NeptuneLogger raises TypeError."""
    handler = OutputHandler("tag", output_transform=lambda x: x)

    # Plain MagicMock instances: the logger deliberately has no NeptuneLogger spec.
    engine, unsupported_logger = MagicMock(), MagicMock()

    expected_message = "Handler OutputHandler works only with NeptuneLogger"
    with pytest.raises(TypeError, match=expected_message):
        handler(engine, unsupported_logger, Events.ITERATION_STARTED)
def test_output_handler_with_global_step_from_engine():
    """The logged ``x`` follows the epoch of the engine given to ``global_step_from_engine``.

    The handler is fired twice on one engine while a second engine provides
    the step; after each firing the logged call must carry the firing
    engine's output as ``y`` and the second engine's current epoch as ``x``.
    """
    step_engine = MagicMock()
    step_engine.state = State()
    step_engine.state.epoch = 10
    step_engine.state.output = 12.345

    handler = OutputHandler(
        "tag",
        output_transform=lambda value: {"loss": value},
        global_step_transform=global_step_from_engine(step_engine),
    )

    neptune_logger = MagicMock(spec=NeptuneLogger)
    neptune_logger.log_metric = MagicMock()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 1
    engine.state.output = 0.123

    def fire_and_check(expected_call_count):
        """Fire the handler once, then verify the call count and the logged call."""
        handler(engine, neptune_logger, Events.EPOCH_STARTED)
        assert neptune_logger.log_metric.call_count == expected_call_count
        expected = call("tag/loss", y=engine.state.output, x=step_engine.state.epoch)
        neptune_logger.log_metric.assert_has_calls([expected])

    fire_and_check(1)

    # Advance the step-providing engine and change the output before firing again.
    step_engine.state.epoch = 11
    engine.state.output = 1.123

    fire_and_check(2)
def test_output_handler_output_transform():
    """Verify the metric name used for transformed engine output.

    Asserted here: a transform returning the raw value is logged as
    ``<tag>/output``, while a transform returning a dict is logged as
    ``<tag>/<key>``. Both cases log ``y=12345`` and ``x=123``, matching the
    engine state's output and iteration.
    """
    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 123
    engine.state.output = 12345

    scenarios = [
        ("tag", lambda x: x, "tag/output"),
        ("another_tag", lambda x: {"loss": x}, "another_tag/loss"),
    ]
    for tag, transform, metric_name in scenarios:
        handler = OutputHandler(tag, output_transform=transform)
        # A fresh logger per scenario so assert_called_once_with sees one call only.
        neptune_logger = MagicMock(spec=NeptuneLogger)
        neptune_logger.log_metric = MagicMock()

        handler(engine, neptune_logger, Events.ITERATION_STARTED)
        neptune_logger.log_metric.assert_called_once_with(metric_name, y=12345, x=123)
def test_output_handler_with_wrong_global_step_transform_output():
    """A ``global_step_transform`` that returns a string makes the handler raise TypeError."""

    def string_step(*args, **kwargs):
        # Deliberately not an int.
        return "a"

    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.epoch = 5

    neptune_logger = MagicMock(spec=NeptuneLogger)
    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=string_step,
    )

    with pytest.raises(TypeError, match="global_step must be int"):
        handler(engine, neptune_logger, Events.EPOCH_STARTED)
def test_output_handler_with_global_step_transform():
    """The value returned by a custom ``global_step_transform`` is logged as ``x``."""

    def constant_step(*args, **kwargs):
        return 10

    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    # The engine epoch (5) differs from the transform's 10, so the assertion
    # below only passes if the transform's value is the one logged.
    engine.state.epoch = 5

    neptune_logger = MagicMock(spec=NeptuneLogger)
    neptune_logger.log_metric = MagicMock()

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=constant_step,
    )
    handler(engine, neptune_logger, Events.EPOCH_STARTED)

    log_metric = neptune_logger.log_metric
    assert log_metric.call_count == 1
    log_metric.assert_has_calls([call("tag/loss", y=12345, x=10)])
def test_output_handler_both():
    """With both ``metric_names`` and ``output_transform`` set, metrics and output are all logged."""
    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.output = 12345
    engine.state.epoch = 5

    neptune_logger = MagicMock(spec=NeptuneLogger)
    neptune_logger.log_metric = MagicMock()

    handler = OutputHandler(
        "tag",
        metric_names=["a", "b"],
        output_transform=lambda x: {"loss": x},
    )
    handler(engine, neptune_logger, Events.EPOCH_STARTED)

    # Two metrics plus one transformed output entry; order is not checked.
    expected_calls = [
        call("tag/a", y=12.23, x=5),
        call("tag/b", y=23.45, x=5),
        call("tag/loss", y=12345, x=5),
    ]
    assert neptune_logger.log_metric.call_count == 3
    neptune_logger.log_metric.assert_has_calls(expected_calls, any_order=True)
# Example #7
# 0
def test_output_handler_state_attrs():
    """State attributes named in ``state_attributes`` are logged under ``<tag>/<name>``.

    Covers a float, a 0-dim tensor (expected as its ``.item()`` value) and a
    1-D tensor (expected as one call per element, named ``<name>/<index>``).
    """
    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 5
    engine.state.alpha = 3.899
    engine.state.beta = torch.tensor(12.23)
    engine.state.gamma = torch.tensor([21.0, 6.0])

    neptune_logger = MagicMock(spec=NeptuneLogger)
    neptune_logger.log_metric = MagicMock()

    handler = OutputHandler("tag", state_attributes=["alpha", "beta", "gamma"])
    handler(engine, neptune_logger, Events.ITERATION_STARTED)

    expected_calls = [
        call("tag/alpha", y=3.899, x=5),
        call("tag/beta", y=torch.tensor(12.23).item(), x=5),
        call("tag/gamma/0", y=21.0, x=5),
        call("tag/gamma/1", y=6.0, x=5),
    ]
    log_metric = neptune_logger.log_metric
    assert log_metric.call_count == 4
    log_metric.assert_has_calls(expected_calls, any_order=True)
# Example #8
# 0
def test_output_handler_metric_names():
    """Check the ``log_metric`` calls OutputHandler makes for various ``metric_names`` inputs.

    Scenarios, each run against a fresh handler, engine and logger:

    1. An explicit list of scalar metrics: one call per metric.
    2. A 1-D tensor metric: one call per element, named ``tag/<name>/<index>``.
    3. A metric with a string value: a ``UserWarning`` is emitted and only the
       numeric metric is logged.
    4. ``metric_names="all"`` with scalar metrics: every metric is logged.
    5. ``metric_names="all"`` with 0-dim tensors: logged as their ``.item()`` values.
    """

    def prepare(metric_names, metrics, iteration):
        """Build a handler, a mock engine holding ``metrics``, and a mock NeptuneLogger."""
        handler = OutputHandler("tag", metric_names=metric_names)
        engine = MagicMock()
        engine.state = State(metrics=metrics)
        engine.state.iteration = iteration
        neptune_logger = MagicMock(spec=NeptuneLogger)
        neptune_logger.log_metric = MagicMock()
        return handler, engine, neptune_logger

    def check_logged(neptune_logger, expected_calls):
        """Assert that exactly ``expected_calls`` were logged, in any order."""
        assert neptune_logger.log_metric.call_count == len(expected_calls)
        neptune_logger.log_metric.assert_has_calls(expected_calls, any_order=True)

    # 1. Explicit list of scalar metrics.
    handler, engine, neptune_logger = prepare(["a", "b"], {"a": 12.23, "b": 23.45}, 5)
    handler(engine, neptune_logger, Events.ITERATION_STARTED)
    check_logged(
        neptune_logger,
        [call("tag/a", y=12.23, x=5), call("tag/b", y=23.45, x=5)],
    )

    # 2. 1-D tensor metric is expanded element-wise.
    handler, engine, neptune_logger = prepare(
        ["a"], {"a": torch.Tensor([0.0, 1.0, 2.0, 3.0])}, 5
    )
    handler(engine, neptune_logger, Events.ITERATION_STARTED)
    check_logged(
        neptune_logger,
        [
            call("tag/a/0", y=0.0, x=5),
            call("tag/a/1", y=1.0, x=5),
            call("tag/a/2", y=2.0, x=5),
            call("tag/a/3", y=3.0, x=5),
        ],
    )

    # 3. String-valued metric: a warning is expected and only "a" is logged.
    handler, engine, neptune_logger = prepare(
        ["a", "c"], {"a": 55.56, "c": "Some text"}, 7
    )
    # Only the handler call sits inside pytest.warns so that a warning from
    # setup cannot satisfy the check.
    with pytest.warns(UserWarning):
        handler(engine, neptune_logger, Events.ITERATION_STARTED)
    check_logged(neptune_logger, [call("tag/a", y=55.56, x=7)])

    # 4. All metrics.
    handler, engine, neptune_logger = prepare("all", {"a": 12.23, "b": 23.45}, 5)
    handler(engine, neptune_logger, Events.ITERATION_STARTED)
    check_logged(
        neptune_logger,
        [call("tag/a", y=12.23, x=5), call("tag/b", y=23.45, x=5)],
    )

    # 5. Log a torch tensor (ndimension = 0).
    handler, engine, neptune_logger = prepare(
        "all", {"a": torch.tensor(12.23), "b": torch.tensor(23.45)}, 5
    )
    handler(engine, neptune_logger, Events.ITERATION_STARTED)
    check_logged(
        neptune_logger,
        [
            call("tag/a", y=torch.tensor(12.23).item(), x=5),
            call("tag/b", y=torch.tensor(23.45).item(), x=5),
        ],
    )