Ejemplo n.º 1
0
def test_output_handler_output_transform(dirname):
    """Check that OutputHandler plots the engine output after ``output_transform``.

    Two scenarios share one engine:

    * an identity transform, whose raw output is plotted under ``"<tag>/output"``;
    * a transform returning ``{"loss": x}``, plotted under ``"<tag>/loss"``.

    Each scenario uses a fresh handler and a fresh mocked logger.
    """

    def _fresh_logger():
        # VisdomLogger stand-in: plotting calls are recorded on a MagicMock;
        # _DummyExecutor is defined elsewhere in the test module.
        logger = MagicMock(spec=VisdomLogger)
        logger.vis = MagicMock()
        logger.executor = _DummyExecutor()
        return logger

    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    scenarios = (
        ("tag", lambda x: x, "tag/output"),
        ("another_tag", lambda x: {"loss": x}, "another_tag/loss"),
    )

    for tag, transform, window_key in scenarios:
        handler = OutputHandler(tag, output_transform=transform)
        logger = _fresh_logger()

        handler(engine, logger, Events.ITERATION_STARTED)

        assert len(handler.windows) == 1 and window_key in handler.windows
        assert handler.windows[window_key]["win"] is not None

        # First plot for this window: no existing window id and no update mode.
        logger.vis.line.assert_called_once_with(
            X=[123],
            Y=[12345],
            env=logger.vis.env,
            win=None,
            update=None,
            opts=handler.windows[window_key]["opts"],
            name=window_key,
        )
Ejemplo n.º 2
0
def test_output_handler_with_global_step_transform():
    """Check that ``global_step_transform`` supplies the X coordinate of the plot.

    The engine's own epoch is 5, but the transform always returns 10.
    The plotted point must therefore use 10.
    """
    fixed_step = 10

    def global_step_transform(*args, **kwargs):
        return fixed_step

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_transform,
    )

    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    handler(engine, logger, Events.EPOCH_STARTED)

    assert logger.vis.line.call_count == 1
    assert len(handler.windows) == 1 and "tag/loss" in handler.windows
    window = handler.windows["tag/loss"]
    assert window["win"] is not None

    expected_call = call(
        X=[fixed_step],
        Y=[12345],
        env=logger.vis.env,
        win=None,
        update=None,
        opts=window["opts"],
        name="tag/loss",
    )
    logger.vis.line.assert_has_calls([expected_call])
Ejemplo n.º 3
0
def test_output_handler_with_global_step_from_engine():
    """Check that ``global_step_from_engine`` makes X follow another engine's epoch.

    Two logging rounds are performed:

    * The first creates the window (``win=None``, ``update=None``).
    * The second appends to it, reusing the stored window id with ``update="append"``.
    """
    other_engine = MagicMock()
    other_engine.state = State()
    other_engine.state.epoch = 10
    other_engine.state.output = 12.345

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_from_engine(other_engine),
    )

    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 1
    engine.state.output = 0.123

    # (epoch of the other engine, output of the logged engine) per round.
    rounds = ((10, 0.123), (11, 1.123))

    for round_number, (other_epoch, output) in enumerate(rounds, start=1):
        other_engine.state.epoch = other_epoch
        engine.state.output = output

        handler(engine, logger, Events.EPOCH_STARTED)

        assert logger.vis.line.call_count == round_number
        assert len(handler.windows) == 1 and "tag/loss" in handler.windows
        assert handler.windows["tag/loss"]["win"] is not None

        creates_window = round_number == 1
        logger.vis.line.assert_has_calls([
            call(
                X=[other_engine.state.epoch],
                Y=[engine.state.output],
                env=logger.vis.env,
                win=None if creates_window else handler.windows["tag/loss"]["win"],
                update=None if creates_window else "append",
                opts=handler.windows["tag/loss"]["opts"],
                name="tag/loss",
            )
        ])
Ejemplo n.º 4
0
def test_output_handler_with_wrong_logger_type():
    """OutputHandler must raise RuntimeError when given a non-VisdomLogger logger."""
    handler = OutputHandler("tag", output_transform=lambda x: x)

    # A plain MagicMock has no VisdomLogger spec, so the type check should fail.
    foreign_logger = MagicMock()
    engine = MagicMock()

    expected_message = "Handler 'OutputHandler' works only with VisdomLogger"
    with pytest.raises(RuntimeError, match=expected_message):
        handler(engine, foreign_logger, Events.ITERATION_STARTED)
Ejemplo n.º 5
0
def test_output_handler_with_wrong_global_step_transform_output():
    """A ``global_step_transform`` that returns a non-int must raise TypeError."""

    def global_step_transform(*args, **kwargs):
        # Deliberately the wrong type: a str instead of an int.
        return "a"

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_transform,
    )

    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    with pytest.raises(TypeError, match="global_step must be int"):
        handler(engine, logger, Events.EPOCH_STARTED)
Ejemplo n.º 6
0
def test_integration_with_executor_as_context_manager(visdom_server,
                                                      visdom_server_stop):
    """End-to-end: log one loss value per iteration to a live Visdom server.

    A trainer runs ``n_epochs`` over ``data``, each iteration returning the
    next value from ``losses``. The test then reads the ``training/loss``
    window back from the server. It checks the x-values (1-based iteration
    numbers) and the y-values (the logged losses).
    """
    n_epochs = 5
    data = list(range(50))
    # Computed up front; previously ``data`` was later rebound to the Visdom
    # window payload, so ``len(data)`` measured a dict, not the dataset.
    n_iterations = n_epochs * len(data)

    losses = torch.rand(n_iterations)
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    with VisdomLogger(server=visdom_server[0],
                      port=visdom_server[1],
                      num_workers=1) as vd_logger:

        # close all windows in 'main' environment
        vd_logger.vis.close()

        trainer = Engine(update_fn)
        output_handler = OutputHandler(tag="training",
                                       output_transform=lambda x: {"loss": x})
        vd_logger.attach(trainer,
                         log_handler=output_handler,
                         event_name=Events.ITERATION_COMPLETED)

        trainer.run(data, max_epochs=n_epochs)

        assert len(output_handler.windows) == 1
        assert "training/loss" in output_handler.windows
        win_name = output_handler.windows["training/loss"]["win"]

        # Distinct names so the training ``data`` list is never shadowed.
        window_content = _parse_content(
            vd_logger.vis.get_window_data(win=win_name))
        assert ("content" in window_content
                and "data" in window_content["content"])
        trace = window_content["content"]["data"][0]
        assert "x" in trace and "y" in trace
        x_vals, y_vals = trace["x"], trace["y"]

        # zip stops at the shorter sequence: only the points present in the
        # window are compared against the expected iteration numbers/losses.
        assert all(
            int(x) == x_true
            for x, x_true in zip(x_vals, range(1, n_iterations + 1)))
        assert all(y == y_true for y, y_true in zip(y_vals, losses))
Ejemplo n.º 7
0
def test_output_handler_both(dirname):
    """Metrics and transformed output are plotted together, then appended to.

    The first EPOCH_STARTED call creates three windows (``tag/a``, ``tag/b``,
    ``tag/loss``). The second call, at epoch 6, appends to the same windows
    with ``update="append"``.
    """
    handler = OutputHandler("tag",
                            metric_names=["a", "b"],
                            output_transform=lambda x: {"loss": x})

    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.epoch = 5
    engine.state.output = 12345

    # Window key -> value expected on the Y axis for every round.
    expected_values = {"tag/a": 12.23, "tag/b": 23.45, "tag/loss": 12345}

    def _check_round(step, total_calls, first_round):
        assert logger.vis.line.call_count == total_calls
        assert len(handler.windows) == 3 and all(
            key in handler.windows for key in expected_values)
        for key in expected_values:
            assert handler.windows[key]["win"] is not None

        logger.vis.line.assert_has_calls(
            [
                call(
                    X=[step],
                    Y=[value],
                    env=logger.vis.env,
                    win=None if first_round else handler.windows[key]["win"],
                    update=None if first_round else "append",
                    opts=handler.windows[key]["opts"],
                    name=key,
                )
                for key, value in expected_values.items()
            ],
            any_order=True,
        )

    handler(engine, logger, Events.EPOCH_STARTED)
    _check_round(step=5, total_calls=3, first_round=True)

    engine.state.epoch = 6
    handler(engine, logger, Events.EPOCH_STARTED)
    _check_round(step=6, total_calls=6, first_round=False)
Ejemplo n.º 8
0
def test_output_handler_metric_names(dirname):
    """Check which engine metrics OutputHandler plots for various ``metric_names``.

    Covered cases:

    * an explicit list of scalar metrics;
    * a 1-d tensor metric, split into one window per element;
    * a non-numeric metric, which is skipped with a ``UserWarning``;
    * ``metric_names="all"``.

    Every case uses a fresh handler, engine and mocked logger.
    """

    def _make_logger():
        logger = MagicMock(spec=VisdomLogger)
        logger.vis = MagicMock()
        logger.executor = _DummyExecutor()
        return logger

    def _make_engine(metrics, iteration):
        engine = MagicMock()
        engine.state = State(metrics=metrics)
        engine.state.iteration = iteration
        return engine

    def _assert_new_windows(handler, logger, step, expected):
        # ``expected`` maps window key -> value plotted at ``step`` in a
        # freshly created window (win=None, update=None).
        assert len(handler.windows) == len(expected) and all(
            key in handler.windows for key in expected)
        for key in expected:
            assert handler.windows[key]["win"] is not None

        assert logger.vis.line.call_count == len(expected)
        logger.vis.line.assert_has_calls(
            [
                call(
                    X=[step],
                    Y=[value],
                    env=logger.vis.env,
                    win=None,
                    update=None,
                    opts=handler.windows[key]["opts"],
                    name=key,
                )
                for key, value in expected.items()
            ],
            any_order=True,
        )

    # Explicit list of scalar metrics.
    handler = OutputHandler("tag", metric_names=["a", "b"])
    engine = _make_engine({"a": 12.23, "b": 23.45}, iteration=5)
    logger = _make_logger()
    handler(engine, logger, Events.ITERATION_STARTED)
    _assert_new_windows(handler, logger, 5, {"tag/a": 12.23, "tag/b": 23.45})

    # A 1-d tensor metric gets one window per element: tag/a/0 .. tag/a/3.
    handler = OutputHandler("tag", metric_names=["a"])
    engine = _make_engine({"a": torch.Tensor([0.0, 1.0, 2.0, 3.0])},
                          iteration=5)
    logger = _make_logger()
    handler(engine, logger, Events.ITERATION_STARTED)
    _assert_new_windows(handler, logger, 5,
                        {f"tag/a/{i}": float(i) for i in range(4)})

    # A non-numeric metric ("c") triggers a warning and is not plotted.
    handler = OutputHandler("tag", metric_names=["a", "c"])
    engine = _make_engine({"a": 55.56, "c": "Some text"}, iteration=7)
    logger = _make_logger()
    with pytest.warns(UserWarning):
        handler(engine, logger, Events.ITERATION_STARTED)
    _assert_new_windows(handler, logger, 7, {"tag/a": 55.56})

    # all metrics
    handler = OutputHandler("tag", metric_names="all")
    engine = _make_engine({"a": 12.23, "b": 23.45}, iteration=5)
    logger = _make_logger()
    handler(engine, logger, Events.ITERATION_STARTED)
    _assert_new_windows(handler, logger, 5, {"tag/a": 12.23, "tag/b": 23.45})
Ejemplo n.º 9
0
def test_output_handler_state_attrs():
    """Engine state attributes listed in ``state_attributes`` are plotted by name.

    ``alpha`` is a float and ``beta`` a 0-d tensor; each gets one window.
    ``gamma`` is a 1-d tensor with two elements and gets one window per
    element (``tag/gamma/0``, ``tag/gamma/1``).
    """
    handler = OutputHandler("tag", state_attributes=["alpha", "beta", "gamma"])

    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 5
    engine.state.alpha = 3.899
    engine.state.beta = torch.tensor(12.0)
    engine.state.gamma = torch.tensor([21.0, 6.0])

    handler(engine, logger, Events.ITERATION_STARTED)

    # Window key -> value expected on the Y axis.
    expected = {
        "tag/alpha": 3.899,
        "tag/beta": 12.0,
        "tag/gamma/0": 21.0,
        "tag/gamma/1": 6.0,
    }

    assert logger.vis.line.call_count == 4
    assert len(handler.windows) == 4 and all(
        key in handler.windows for key in expected)
    for key in expected:
        assert handler.windows[key]["win"] is not None

    logger.vis.line.assert_has_calls(
        [
            call(
                X=[5],
                Y=[value],
                env=logger.vis.env,
                win=None,
                update=None,
                opts=handler.windows[key]["opts"],
                name=key,
            )
            for key, value in expected.items()
        ],
        any_order=True,
    )