def test_optimizer_params():
    """Check that OptimizerParamsHandler plots the optimizer's learning rate to Visdom.

    For a single-param-group SGD optimizer the handler is expected to create
    exactly one Visdom line window, named ``lr/group_0`` when no tag is given
    and ``generator/lr/group_0`` when ``tag="generator"`` is given, and to
    send one ``vis.line`` call carrying the current iteration and lr value.
    """
    lr = 0.01
    iteration = 123
    optimizer = torch.optim.SGD([torch.Tensor(0)], lr=lr)

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.iteration = iteration

    def check_single_window(wrapper, window_name):
        """Run ``wrapper`` once against a fresh mock logger and verify its single window."""
        # A fresh mock logger per handler so assert_called_once_with only
        # observes the calls made by this particular handler.
        mock_logger = MagicMock(spec=VisdomLogger)
        mock_logger.vis = MagicMock()
        mock_logger.executor = _DummyExecutor()

        wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

        # Separate asserts so a failure reports exactly which condition broke.
        assert len(wrapper.windows) == 1
        assert window_name in wrapper.windows
        assert wrapper.windows[window_name]["win"] is not None
        # The plot call is made with win=None/update=None (no window existed
        # yet), while the stored "win" entry is non-None afterwards —
        # presumably populated from vis.line's return value.
        mock_logger.vis.line.assert_called_once_with(
            X=[iteration],
            Y=[lr],
            env=mock_logger.vis.env,
            win=None,
            update=None,
            opts=wrapper.windows[window_name]["opts"],
            name=window_name,
        )

    check_single_window(
        OptimizerParamsHandler(optimizer=optimizer, param_name="lr"),
        "lr/group_0",
    )
    check_single_window(
        OptimizerParamsHandler(optimizer=optimizer, param_name="lr", tag="generator"),
        "generator/lr/group_0",
    )
def test_optimizer_params_handler_wrong_setup():
    """Verify OptimizerParamsHandler rejects invalid construction and non-Visdom loggers.

    Passing ``optimizer=None`` must raise ``TypeError``. A handler built on an
    optimizer-spec mock must raise ``RuntimeError`` when invoked with a logger
    that is not a ``VisdomLogger``.
    """
    with pytest.raises(TypeError):
        OptimizerParamsHandler(optimizer=None)

    # A spec'd mock is enough to satisfy the optimizer type check.
    fake_optimizer = MagicMock(spec=torch.optim.Optimizer)
    params_handler = OptimizerParamsHandler(optimizer=fake_optimizer)

    # A plain MagicMock is deliberately *not* a VisdomLogger.
    plain_logger = MagicMock()
    engine_stub = MagicMock()

    expected_msg = "Handler OptimizerParamsHandler works only with VisdomLogger"
    with pytest.raises(RuntimeError, match=expected_msg):
        params_handler(engine_stub, plain_logger, Events.ITERATION_STARTED)