def test_optimizer_params():
    """OptimizerParamsHandler plots the optimizer's ``lr`` per param group.

    Runs the handler once without a tag and once with ``tag="generator"``.
    For each run it checks:

    * the window name (``"lr/group_0"``, or ``"generator/lr/group_0"`` with the tag);
    * that exactly one ``vis.line`` call was issued;
    * that the call targets a new window (``win=None``, ``update=None``).
    """
    optimizer = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    # --- no tag -----------------------------------------------------------
    wrapper = OptimizerParamsHandler(optimizer=optimizer, param_name="lr")
    mock_logger = MagicMock(spec=VisdomLogger)
    mock_logger.vis = MagicMock()
    mock_logger.executor = _DummyExecutor()
    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.iteration = 123

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert len(wrapper.windows) == 1 and "lr/group_0" in wrapper.windows
    # NOTE(review): "win" is non-None even though vis.line was called with
    # win=None — presumably the handler stores vis.line's return value; confirm.
    assert wrapper.windows["lr/group_0"]["win"] is not None
    mock_logger.vis.line.assert_called_once_with(
        X=[123],
        Y=[0.01],
        env=mock_logger.vis.env,
        win=None,
        update=None,
        opts=wrapper.windows["lr/group_0"]["opts"],
        name="lr/group_0",
    )

    # --- with tag: window names gain a "generator/" prefix ------------------
    wrapper = OptimizerParamsHandler(optimizer=optimizer, param_name="lr", tag="generator")
    mock_logger = MagicMock(spec=VisdomLogger)
    mock_logger.vis = MagicMock()
    mock_logger.executor = _DummyExecutor()

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert len(wrapper.windows) == 1 and "generator/lr/group_0" in wrapper.windows
    assert wrapper.windows["generator/lr/group_0"]["win"] is not None
    mock_logger.vis.line.assert_called_once_with(
        X=[123],
        Y=[0.01],
        env=mock_logger.vis.env,
        win=None,
        update=None,
        opts=wrapper.windows["generator/lr/group_0"]["opts"],
        name="generator/lr/group_0",
    )
def test_output_handler_output_transform(dirname):
    """OutputHandler plots whatever ``output_transform`` extracts from ``state.output``.

    A transform that returns a bare value is plotted under ``"<tag>/output"``.
    A transform that returns a dict is plotted under ``"<tag>/<key>"``.
    The ``dirname`` fixture is not used here.
    """
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    # (tag, output_transform, expected window name)
    scenarios = (
        ("tag", lambda x: x, "tag/output"),
        ("another_tag", lambda x: {"loss": x}, "another_tag/loss"),
    )

    for tag, transform, window_name in scenarios:
        handler = OutputHandler(tag, output_transform=transform)
        logger = MagicMock(spec=VisdomLogger)
        logger.vis = MagicMock()
        logger.executor = _DummyExecutor()

        handler(engine, logger, Events.ITERATION_STARTED)

        assert len(handler.windows) == 1 and window_name in handler.windows
        assert handler.windows[window_name]["win"] is not None
        # First plot into a fresh window: no handle and no update mode yet.
        logger.vis.line.assert_called_once_with(
            X=[123],
            Y=[12345],
            env=logger.vis.env,
            win=None,
            update=None,
            opts=handler.windows[window_name]["opts"],
            name=window_name,
        )
def _test(tag=None):
    """Check that GradsScalarHandler issues one ``vis.line`` call per model parameter.

    NOTE(review): ``model`` and ``norm`` are free variables here, presumably
    provided by an enclosing test function outside this view — confirm.

    Args:
        tag: optional prefix; when truthy, window names become ``"<tag>/..."``.
    """
    handler = GradsScalarHandler(model, reduction=norm, tag=tag)
    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()
    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5

    handler(engine, logger, Events.EPOCH_STARTED)

    prefix = f"{tag}/" if tag else ""
    assert logger.vis.line.call_count == 4

    window_names = [
        prefix + "grads_norm/" + param
        for param in ("fc1/weight", "fc1/bias", "fc2/weight", "fc2/bias")
    ]
    expected_calls = [
        call(
            X=[5],
            Y=ANY,
            env=logger.vis.env,
            win=None,
            update=None,
            opts=handler.windows[window]["opts"],
            name=window,
        )
        for window in window_names
    ]
    logger.vis.line.assert_has_calls(expected_calls, any_order=True)
def test_output_handler_with_global_step_from_engine():
    """With ``global_step_from_engine``, X comes from another engine's epoch.

    The plotted Y still comes from the engine passed to the handler.
    The first call creates the window; the second appends to it.
    """
    other_engine = MagicMock()
    other_engine.state = State()
    other_engine.state.epoch = 10
    other_engine.state.output = 12.345

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_from_engine(other_engine),
    )

    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 1
    engine.state.output = 0.123

    def run_and_check(expected_call_count, creates_window):
        # Fire the handler, then verify the window bookkeeping and the last plot.
        handler(engine, logger, Events.EPOCH_STARTED)

        assert logger.vis.line.call_count == expected_call_count
        assert len(handler.windows) == 1 and "tag/loss" in handler.windows
        window = handler.windows["tag/loss"]
        assert window["win"] is not None

        if creates_window:
            win, update = None, None
        else:
            win, update = window["win"], "append"

        logger.vis.line.assert_has_calls(
            [
                call(
                    X=[other_engine.state.epoch],
                    Y=[engine.state.output],
                    env=logger.vis.env,
                    win=win,
                    update=update,
                    opts=window["opts"],
                    name="tag/loss",
                )
            ]
        )

    run_and_check(1, creates_window=True)

    other_engine.state.epoch = 11
    engine.state.output = 1.123
    run_and_check(2, creates_window=False)
def test_output_handler_with_global_step_transform():
    """A custom ``global_step_transform`` supplies the plotted X value.

    The transform always returns 10, so X is 10 even though ``state.epoch`` is 5.
    """

    def global_step_transform(*args, **kwargs):
        return 10

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_transform,
    )
    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    handler(engine, logger, Events.EPOCH_STARTED)

    assert logger.vis.line.call_count == 1
    assert len(handler.windows) == 1 and "tag/loss" in handler.windows
    window = handler.windows["tag/loss"]
    assert window["win"] is not None

    expected = call(
        X=[10],
        Y=[12345],
        env=logger.vis.env,
        win=None,
        update=None,
        opts=window["opts"],
        name="tag/loss",
    )
    logger.vis.line.assert_has_calls([expected])
def _test(tag=None):
    """Check that WeightsScalarHandler plots the expected value per model parameter.

    NOTE(review): ``model`` is a free variable here, presumably built by an
    enclosing test function outside this view.  The expected values below
    assume fc1 is zero-filled and fc2.weight yields 12.0 — confirm against
    that model.

    Args:
        tag: optional prefix; when truthy, window names become ``"<tag>/..."``.
    """
    handler = WeightsScalarHandler(model, tag=tag)
    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()
    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5

    handler(engine, logger, Events.EPOCH_STARTED)

    prefix = f"{tag}/" if tag else ""
    assert logger.vis.line.call_count == 4

    # parameter path -> expected Y argument (ANY where the value isn't pinned)
    expected_y = {
        "fc1/weight": [0.0],
        "fc1/bias": [0.0],
        "fc2/weight": [12.0],
        "fc2/bias": ANY,
    }
    expected_calls = []
    for param, y in expected_y.items():
        window = prefix + "weights_norm/" + param
        expected_calls.append(
            call(
                X=[5],
                Y=y,
                env=logger.vis.env,
                win=None,
                update=None,
                opts=handler.windows[window]["opts"],
                name=window,
            )
        )
    logger.vis.line.assert_has_calls(expected_calls, any_order=True)
def test_output_handler_both(dirname):
    """OutputHandler plots both metrics and the transformed output.

    Epoch 5 creates the three windows; epoch 6 appends to them.
    The ``dirname`` fixture is not used here.
    """
    handler = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})
    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()
    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.output = 12345

    # window name -> value expected on the Y axis
    expected_values = {"tag/a": 12.23, "tag/b": 23.45, "tag/loss": 12345}

    # (epoch, cumulative vis.line call count, is this the first plot?)
    for epoch, total_calls, first_plot in ((5, 3, True), (6, 6, False)):
        engine.state.epoch = epoch
        handler(engine, logger, Events.EPOCH_STARTED)

        assert logger.vis.line.call_count == total_calls
        assert len(handler.windows) == 3 and all(name in handler.windows for name in expected_values)
        for name in expected_values:
            assert handler.windows[name]["win"] is not None

        expected_calls = []
        for name, value in expected_values.items():
            if first_plot:
                win, update = None, None
            else:
                win, update = handler.windows[name]["win"], "append"
            expected_calls.append(
                call(
                    X=[epoch],
                    Y=[value],
                    env=logger.vis.env,
                    win=win,
                    update=update,
                    opts=handler.windows[name]["opts"],
                    name=name,
                )
            )
        logger.vis.line.assert_has_calls(expected_calls, any_order=True)
def test_output_handler_with_wrong_global_step_transform_output():
    """A non-int ``global_step_transform`` result makes the handler raise TypeError."""

    def global_step_transform(*args, **kwargs):
        return "a"

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_transform,
    )
    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    with pytest.raises(TypeError, match="global_step must be int"):
        handler(engine, logger, Events.EPOCH_STARTED)
def test_grads_scalar_handler():
    """GradsScalarHandler issues one ``vis.line`` call per parameter of a two-layer model.

    The Y values are not pinned (``ANY``): the reduction is a stub that
    always returns 0.0.
    """

    class DummyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 10)
            self.fc2 = torch.nn.Linear(12, 12)
            self.fc1.weight.data.zero_()
            self.fc1.bias.data.zero_()
            self.fc2.weight.data.fill_(1.0)
            self.fc2.bias.data.fill_(1.0)

    model = DummyModel()

    # Keep this function named `norm`: the expected window names use a
    # "grads_norm" prefix, presumably derived from the reduction's __name__.
    def norm(x):
        return 0.0

    handler = GradsScalarHandler(model, reduction=norm)
    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()
    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5

    handler(engine, logger, Events.EPOCH_STARTED)

    assert logger.vis.line.call_count == 4
    windows = ["grads_norm/fc1/weight", "grads_norm/fc1/bias", "grads_norm/fc2/weight", "grads_norm/fc2/bias"]
    logger.vis.line.assert_has_calls(
        [
            call(
                X=[5],
                Y=ANY,
                env=logger.vis.env,
                win=None,
                update=None,
                opts=handler.windows[w]["opts"],
                name=w,
            )
            for w in windows
        ],
        any_order=True,
    )
def test_output_handler_metric_names(dirname):
    """OutputHandler plots the engine metrics selected by ``metric_names``.

    Scenarios covered:

    * explicitly listed scalar metrics;
    * a 1-D tensor metric, which gets one window per element (``"tag/a/<i>"``);
    * a non-numeric metric, which is skipped with a ``UserWarning``;
    * ``metric_names="all"``.

    The ``dirname`` fixture is not used here.
    """

    def _new_logger():
        # Fresh logger per scenario so vis.line call counts start at zero.
        logger = MagicMock(spec=VisdomLogger)
        logger.vis = MagicMock()
        logger.executor = _DummyExecutor()
        return logger

    def _new_window_call(logger, wrapper, step, value, name):
        # Expected call for the first plot into a window: no handle, no update mode.
        return call(
            X=[step],
            Y=[value],
            env=logger.vis.env,
            win=None,
            update=None,
            opts=wrapper.windows[name]["opts"],
            name=name,
        )

    # --- explicitly listed scalar metrics ---------------------------------
    wrapper = OutputHandler("tag", metric_names=["a", "b"])
    mock_logger = _new_logger()
    mock_engine = MagicMock()
    mock_engine.state = State(metrics={"a": 12.23, "b": 23.45})
    mock_engine.state.iteration = 5

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert len(wrapper.windows) == 2 and "tag/a" in wrapper.windows and "tag/b" in wrapper.windows
    assert wrapper.windows["tag/a"]["win"] is not None
    assert wrapper.windows["tag/b"]["win"] is not None
    assert mock_logger.vis.line.call_count == 2
    mock_logger.vis.line.assert_has_calls(
        [
            _new_window_call(mock_logger, wrapper, 5, 12.23, "tag/a"),
            _new_window_call(mock_logger, wrapper, 5, 23.45, "tag/b"),
        ],
        any_order=True,
    )

    # --- 1-D tensor metric: one window per element -------------------------
    wrapper = OutputHandler("tag", metric_names=["a"])
    mock_engine = MagicMock()
    mock_engine.state = State(metrics={"a": torch.Tensor([0.0, 1.0, 2.0, 3.0])})
    mock_engine.state.iteration = 5
    mock_logger = _new_logger()

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    # Generator (not a list) inside all(): short-circuits and builds no temp list.
    assert len(wrapper.windows) == 4 and all(f"tag/a/{i}" in wrapper.windows for i in range(4))
    for i in range(4):
        assert wrapper.windows[f"tag/a/{i}"]["win"] is not None
    assert mock_logger.vis.line.call_count == 4
    mock_logger.vis.line.assert_has_calls(
        [_new_window_call(mock_logger, wrapper, 5, float(i), f"tag/a/{i}") for i in range(4)],
        any_order=True,
    )

    # --- non-numeric metric is skipped with a warning ----------------------
    wrapper = OutputHandler("tag", metric_names=["a", "c"])
    mock_engine = MagicMock()
    mock_engine.state = State(metrics={"a": 55.56, "c": "Some text"})
    mock_engine.state.iteration = 7
    mock_logger = _new_logger()

    with pytest.warns(UserWarning):
        wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert len(wrapper.windows) == 1 and "tag/a" in wrapper.windows
    assert wrapper.windows["tag/a"]["win"] is not None
    assert mock_logger.vis.line.call_count == 1
    mock_logger.vis.line.assert_has_calls(
        [_new_window_call(mock_logger, wrapper, 7, 55.56, "tag/a")],
        any_order=True,
    )

    # --- all metrics -------------------------------------------------------
    wrapper = OutputHandler("tag", metric_names="all")
    mock_logger = _new_logger()
    mock_engine = MagicMock()
    mock_engine.state = State(metrics={"a": 12.23, "b": 23.45})
    mock_engine.state.iteration = 5

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert len(wrapper.windows) == 2 and "tag/a" in wrapper.windows and "tag/b" in wrapper.windows
    assert wrapper.windows["tag/a"]["win"] is not None
    assert wrapper.windows["tag/b"]["win"] is not None
    assert mock_logger.vis.line.call_count == 2
    mock_logger.vis.line.assert_has_calls(
        [
            _new_window_call(mock_logger, wrapper, 5, 12.23, "tag/a"),
            _new_window_call(mock_logger, wrapper, 5, 23.45, "tag/b"),
        ],
        any_order=True,
    )
def test_output_handler_state_attrs():
    """OutputHandler plots the engine ``state`` attributes named in ``state_attributes``.

    Handles a float, a 0-d tensor and a 1-D tensor; the 1-D tensor gets
    one window per element.
    """
    handler = OutputHandler("tag", state_attributes=["alpha", "beta", "gamma"])
    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 5
    engine.state.alpha = 3.899
    engine.state.beta = torch.tensor(12.0)
    engine.state.gamma = torch.tensor([21.0, 6.0])

    handler(engine, logger, Events.ITERATION_STARTED)

    # window name -> value expected on the Y axis
    expected = {
        "tag/alpha": 3.899,
        "tag/beta": 12.0,
        "tag/gamma/0": 21.0,
        "tag/gamma/1": 6.0,
    }

    assert logger.vis.line.call_count == 4
    assert len(handler.windows) == 4 and all(name in handler.windows for name in expected)
    for name in expected:
        assert handler.windows[name]["win"] is not None

    logger.vis.line.assert_has_calls(
        [
            call(
                X=[5],
                Y=[value],
                env=logger.vis.env,
                win=None,
                update=None,
                opts=handler.windows[name]["opts"],
                name=name,
            )
            for name, value in expected.items()
        ],
        any_order=True,
    )