Example #1
0
def test_optimizer_params():
    """Check that OptimizerParamsHandler plots the optimizer's ``lr``.

    The handler is exercised twice against the same optimizer and engine:
    once without a tag (window ``"lr/group_0"``) and once with
    ``tag="generator"`` (window ``"generator/lr/group_0"``). In each case the
    first invocation must create a new Visdom window (``win=None``,
    ``update=None``) holding the single point ``(iteration, lr)``.
    """

    def _make_logger():
        # A fresh logger mock per scenario keeps call assertions independent.
        logger = MagicMock(spec=VisdomLogger)
        logger.vis = MagicMock()
        logger.executor = _DummyExecutor()
        return logger

    optimizer = torch.optim.SGD([torch.Tensor(0)], lr=0.01)
    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.iteration = 123

    # (extra constructor kwargs, expected window name)
    scenarios = (
        ({}, "lr/group_0"),
        ({"tag": "generator"}, "generator/lr/group_0"),
    )
    for extra_kwargs, window_name in scenarios:
        wrapper = OptimizerParamsHandler(optimizer=optimizer, param_name="lr", **extra_kwargs)
        mock_logger = _make_logger()

        wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

        assert len(wrapper.windows) == 1 and window_name in wrapper.windows
        # After the call the handler holds a window handle, even though
        # vis.line itself was invoked with win=None (checked below).
        assert wrapper.windows[window_name]["win"] is not None

        mock_logger.vis.line.assert_called_once_with(
            X=[123],
            Y=[0.01],
            env=mock_logger.vis.env,
            win=None,
            update=None,
            opts=wrapper.windows[window_name]["opts"],
            name=window_name,
        )
Example #2
0
def test_output_handler_output_transform(dirname):
    """Check that OutputHandler names the window after the transformed output.

    A bare (non-dict) transformed output is plotted under ``<tag>/output``.
    A dict output ``{"loss": ...}`` is plotted under ``<tag>/loss``.
    """
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    # (tag, output_transform, expected window name)
    scenarios = [
        ("tag", lambda out: out, "tag/output"),
        ("another_tag", lambda out: {"loss": out}, "another_tag/loss"),
    ]
    for tag, transform, window in scenarios:
        handler = OutputHandler(tag, output_transform=transform)
        logger = MagicMock(spec=VisdomLogger)
        logger.vis = MagicMock()
        logger.executor = _DummyExecutor()

        handler(engine, logger, Events.ITERATION_STARTED)

        assert len(handler.windows) == 1 and window in handler.windows
        assert handler.windows[window]["win"] is not None

        logger.vis.line.assert_called_once_with(
            X=[123],
            Y=[12345],
            env=logger.vis.env,
            win=None,
            update=None,
            opts=handler.windows[window]["opts"],
            name=window,
        )
Example #3
0
    def _test(tag=None):
        """Run GradsScalarHandler once at epoch 5 and check one line per parameter.

        ``model`` and ``norm`` come from the enclosing test's scope, which is
        not visible here. Only X, window names and opts are checked; Y is ANY.
        """
        handler = GradsScalarHandler(model, reduction=norm, tag=tag)
        logger = MagicMock(spec=VisdomLogger)
        logger.vis = MagicMock()
        logger.executor = _DummyExecutor()

        engine = MagicMock()
        engine.state = State()
        engine.state.epoch = 5

        handler(engine, logger, Events.EPOCH_STARTED)

        prefix = "" if not tag else f"{tag}/"
        window_names = [
            prefix + "grads_norm/" + param
            for param in ("fc1/weight", "fc1/bias", "fc2/weight", "fc2/bias")
        ]

        assert logger.vis.line.call_count == 4
        logger.vis.line.assert_has_calls(
            [
                call(
                    X=[5],
                    Y=ANY,
                    env=logger.vis.env,
                    win=None,
                    update=None,
                    opts=handler.windows[window]['opts'],
                    name=window,
                )
                for window in window_names
            ],
            any_order=True,
        )
def test_output_handler_with_global_step_from_engine():
    """Check that the plotted X comes from *another* engine's epoch.

    The first invocation must create the ``tag/loss`` window. The second,
    after the other engine's epoch and this engine's output change, must
    append to that same window.
    """
    other_engine = MagicMock()
    other_engine.state = State()
    other_engine.state.epoch = 10
    other_engine.state.output = 12.345

    handler = OutputHandler(
        "tag",
        output_transform=lambda out: {"loss": out},
        global_step_transform=global_step_from_engine(other_engine),
    )

    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 1
    engine.state.output = 0.123

    def check(expected_count, appended):
        # Assertions shared by both invocations. `appended` selects whether
        # the latest call must create the window or append to it.
        assert logger.vis.line.call_count == expected_count
        assert len(handler.windows) == 1 and "tag/loss" in handler.windows
        assert handler.windows["tag/loss"]["win"] is not None
        if appended:
            win, update = handler.windows["tag/loss"]["win"], "append"
        else:
            win, update = None, None
        logger.vis.line.assert_has_calls(
            [
                call(
                    X=[other_engine.state.epoch],
                    Y=[engine.state.output],
                    env=logger.vis.env,
                    win=win,
                    update=update,
                    opts=handler.windows["tag/loss"]["opts"],
                    name="tag/loss",
                )
            ]
        )

    handler(engine, logger, Events.EPOCH_STARTED)
    check(1, appended=False)

    other_engine.state.epoch = 11
    engine.state.output = 1.123

    handler(engine, logger, Events.EPOCH_STARTED)
    check(2, appended=True)
def test_output_handler_with_global_step_transform():
    """Check that X is the value returned by ``global_step_transform``.

    The transform returns 10, which deliberately differs from the engine's
    epoch (5).
    """

    def fixed_global_step(*args, **kwargs):
        # Ignores every argument and always reports step 10.
        return 10

    handler = OutputHandler(
        "tag",
        output_transform=lambda out: {"loss": out},
        global_step_transform=fixed_global_step,
    )
    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    handler(engine, logger, Events.EPOCH_STARTED)

    assert logger.vis.line.call_count == 1
    assert len(handler.windows) == 1 and "tag/loss" in handler.windows
    loss_window = handler.windows["tag/loss"]
    assert loss_window["win"] is not None

    expected = call(
        X=[10],
        Y=[12345],
        env=logger.vis.env,
        win=None,
        update=None,
        opts=loss_window["opts"],
        name="tag/loss",
    )
    logger.vis.line.assert_has_calls([expected])
Example #6
0
    def _test(tag=None):
        """Run WeightsScalarHandler once at epoch 5 and check the four plotted values.

        ``model`` comes from the enclosing test's scope, which is not visible
        here. The expected values below depend on how that model was built.
        """
        handler = WeightsScalarHandler(model, tag=tag)
        logger = MagicMock(spec=VisdomLogger)
        logger.vis = MagicMock()
        logger.executor = _DummyExecutor()

        engine = MagicMock()
        engine.state = State()
        engine.state.epoch = 5

        handler(engine, logger, Events.EPOCH_STARTED)

        prefix = "" if not tag else f"{tag}/"
        # Parameter path -> expected Y; fc2/bias is matched loosely with ANY.
        expected_y = {
            "fc1/weight": [0.0],
            "fc1/bias": [0.0],
            "fc2/weight": [12.0],
            "fc2/bias": ANY,
        }

        assert logger.vis.line.call_count == 4

        expected_calls = []
        for param, y_value in expected_y.items():
            window = f"{prefix}weights_norm/{param}"
            expected_calls.append(
                call(
                    X=[5],
                    Y=y_value,
                    env=logger.vis.env,
                    win=None,
                    update=None,
                    opts=handler.windows[window]["opts"],
                    name=window,
                )
            )
        logger.vis.line.assert_has_calls(expected_calls, any_order=True)
Example #7
0
def test_output_handler_both(dirname):
    """Check that OutputHandler plots metrics and transformed output together.

    The first invocation creates three windows (``tag/a``, ``tag/b``,
    ``tag/loss``). The second, at the next epoch, must append to those same
    windows.
    """
    handler = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda out: {"loss": out})
    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.epoch = 5
    engine.state.output = 12345

    # Window name -> value expected on the Y axis in every round.
    expected_y = {"tag/a": 12.23, "tag/b": 23.45, "tag/loss": 12345}

    def verify(epoch, total_calls, appended):
        # Shared checks after each invocation; `appended` distinguishes the
        # window-creating round from the appending round.
        assert logger.vis.line.call_count == total_calls
        assert len(handler.windows) == 3 and all(name in handler.windows for name in expected_y)
        for name in expected_y:
            assert handler.windows[name]['win'] is not None

        logger.vis.line.assert_has_calls(
            [
                call(
                    X=[epoch],
                    Y=[value],
                    env=logger.vis.env,
                    win=handler.windows[name]['win'] if appended else None,
                    update='append' if appended else None,
                    opts=handler.windows[name]['opts'],
                    name=name,
                )
                for name, value in expected_y.items()
            ],
            any_order=True,
        )

    handler(engine, logger, Events.EPOCH_STARTED)
    verify(epoch=5, total_calls=3, appended=False)

    engine.state.epoch = 6
    handler(engine, logger, Events.EPOCH_STARTED)
    verify(epoch=6, total_calls=6, appended=True)
Example #8
0
def test_output_handler_with_wrong_global_step_transform_output():
    """Check that a non-int global step makes OutputHandler raise TypeError."""

    def bad_global_step(*args, **kwargs):
        # A string is returned where an int global step is required.
        return 'a'

    handler = OutputHandler(
        "tag",
        output_transform=lambda out: {"loss": out},
        global_step_transform=bad_global_step,
    )
    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    with pytest.raises(TypeError, match="global_step must be int"):
        handler(engine, logger, Events.EPOCH_STARTED)
Example #9
0
def test_grads_scalar_handler():
    """Check that GradsScalarHandler emits one Visdom line per model parameter.

    Only X, window names and opts are checked; Y is matched with ANY.
    """

    class DummyModel(torch.nn.Module):
        # Two linear layers -> four parameters: fc1/fc2 weight and bias.
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 10)
            self.fc2 = torch.nn.Linear(12, 12)
            self.fc1.weight.data.zero_()
            self.fc1.bias.data.zero_()
            self.fc2.weight.data.fill_(1.0)
            self.fc2.bias.data.fill_(1.0)

    model = DummyModel()

    # NOTE(review): keep this helper named `norm`. The expected window names
    # contain "grads_norm", which presumably derives from the reduction's name.
    def norm(x):
        return 0.0

    handler = GradsScalarHandler(model, reduction=norm)
    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5

    handler(engine, logger, Events.EPOCH_STARTED)

    assert logger.vis.line.call_count == 4

    window_names = [
        "grads_norm/" + param
        for param in ("fc1/weight", "fc1/bias", "fc2/weight", "fc2/bias")
    ]
    logger.vis.line.assert_has_calls(
        [
            call(
                X=[5],
                Y=ANY,
                env=logger.vis.env,
                win=None,
                update=None,
                opts=handler.windows[window]['opts'],
                name=window,
            )
            for window in window_names
        ],
        any_order=True,
    )
Example #10
0
def test_output_handler_metric_names(dirname):
    """Check how OutputHandler maps ``metric_names`` onto Visdom windows.

    Four scenarios, each with a fresh handler, engine and logger:

    * explicit scalar metrics -> one window per metric;
    * a 1-D tensor metric -> one window per element (``tag/a/<i>``);
    * a non-numeric metric -> skipped with a ``UserWarning``;
    * ``metric_names="all"`` -> every metric present in the engine state.

    In every scenario each window is created on first use
    (``win=None``, ``update=None``).
    """

    def _make_logger():
        # VisdomLogger mock with a mocked `vis` client and a dummy executor.
        logger = MagicMock(spec=VisdomLogger)
        logger.vis = MagicMock()
        logger.executor = _DummyExecutor()
        return logger

    def _make_engine(metrics, iteration):
        # Engine mock whose state carries `metrics` at the given iteration.
        engine = MagicMock()
        engine.state = State(metrics=metrics)
        engine.state.iteration = iteration
        return engine

    def _check_new_windows(handler, logger, iteration, expected):
        # `expected` maps window name -> plotted Y value. Checks the window
        # set, that each window got a handle, and the exact vis.line calls.
        assert len(handler.windows) == len(expected) and all(
            name in handler.windows for name in expected
        )
        for name in expected:
            assert handler.windows[name]['win'] is not None

        assert logger.vis.line.call_count == len(expected)
        logger.vis.line.assert_has_calls(
            [
                call(
                    X=[iteration],
                    Y=[value],
                    env=logger.vis.env,
                    win=None,
                    update=None,
                    opts=handler.windows[name]['opts'],
                    name=name,
                )
                for name, value in expected.items()
            ],
            any_order=True,
        )

    # Explicit scalar metrics.
    wrapper = OutputHandler("tag", metric_names=["a", "b"])
    mock_logger = _make_logger()
    mock_engine = _make_engine({"a": 12.23, "b": 23.45}, 5)
    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
    _check_new_windows(wrapper, mock_logger, 5, {"tag/a": 12.23, "tag/b": 23.45})

    # A 1-D tensor metric is split into one window per element.
    wrapper = OutputHandler("tag", metric_names=["a"])
    mock_engine = _make_engine({"a": torch.Tensor([0.0, 1.0, 2.0, 3.0])}, 5)
    mock_logger = _make_logger()
    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
    _check_new_windows(
        wrapper, mock_logger, 5, {"tag/a/{}".format(i): float(i) for i in range(4)}
    )

    # A non-numeric metric ("c") triggers a warning and gets no window.
    wrapper = OutputHandler("tag", metric_names=["a", "c"])
    mock_engine = _make_engine({"a": 55.56, "c": "Some text"}, 7)
    mock_logger = _make_logger()
    with pytest.warns(UserWarning):
        wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
    _check_new_windows(wrapper, mock_logger, 7, {"tag/a": 55.56})

    # metric_names="all" plots every metric in the engine state.
    wrapper = OutputHandler("tag", metric_names="all")
    mock_logger = _make_logger()
    mock_engine = _make_engine({"a": 12.23, "b": 23.45}, 5)
    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
    _check_new_windows(wrapper, mock_logger, 5, {"tag/a": 12.23, "tag/b": 23.45})
Example #11
0
def test_output_handler_state_attrs():
    """Check that OutputHandler plots the listed ``engine.state`` attributes.

    A Python float and a 0-d tensor each get a single window. The 1-D tensor
    ``gamma`` gets one window per element (``tag/gamma/0``, ``tag/gamma/1``).
    """
    handler = OutputHandler("tag", state_attributes=["alpha", "beta", "gamma"])
    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 5
    engine.state.alpha = 3.899
    engine.state.beta = torch.tensor(12.0)
    engine.state.gamma = torch.tensor([21.0, 6.0])

    handler(engine, logger, Events.ITERATION_STARTED)

    # Window name -> expected Y value.
    expected = {
        "tag/alpha": 3.899,
        "tag/beta": 12.0,
        "tag/gamma/0": 21.0,
        "tag/gamma/1": 6.0,
    }

    assert logger.vis.line.call_count == 4
    assert len(handler.windows) == 4 and all(name in handler.windows for name in expected)
    for name in expected:
        assert handler.windows[name]["win"] is not None

    logger.vis.line.assert_has_calls(
        [
            call(
                X=[5],
                Y=[value],
                env=logger.vis.env,
                win=None,
                update=None,
                opts=handler.windows[name]["opts"],
                name=name,
            )
            for name, value in expected.items()
        ],
        any_order=True,
    )