def test_optimizer_params_handler_wrong_setup():
    """OptimizerParamsHandler rejects invalid setups with ``TypeError``.

    Two failure modes are checked:

    * constructing the handler with ``optimizer=None``;
    * invoking a validly constructed handler with a logger that is not an
      ``MLflowLogger`` (a plain, unspec'd ``MagicMock`` here).
    """
    # A missing optimizer must be rejected at construction time.
    with pytest.raises(TypeError):
        OptimizerParamsHandler(optimizer=None)

    params_handler = OptimizerParamsHandler(optimizer=MagicMock(spec=torch.optim.Optimizer))

    # A bare MagicMock is not an MLflowLogger, so calling the handler must fail.
    not_an_mlflow_logger = MagicMock()
    engine = MagicMock()
    expected_msg = "Handler OptimizerParamsHandler works only with MLflowLogger"
    with pytest.raises(TypeError, match=expected_msg):
        params_handler(engine, not_an_mlflow_logger, Events.ITERATION_STARTED)
def test_optimizer_params():
    """OptimizerParamsHandler logs the optimizer's ``lr`` through ``log_metrics``.

    The handler is exercised twice against the same engine state (iteration
    123). The first run has no tag. The second uses ``tag="generator"``,
    which is expected to prefix the logged metric name. Each run gets a fresh
    logger mock so ``assert_called_once_with`` checks that run in isolation.
    """
    sgd = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 123

    # (extra constructor kwargs, expected metric key)
    cases = [
        ({}, "lr group_0"),
        ({"tag": "generator"}, "generator lr group_0"),
    ]
    for extra_kwargs, metric_name in cases:
        handler = OptimizerParamsHandler(optimizer=sgd, param_name="lr", **extra_kwargs)

        # log_metrics is attached explicitly so its call can be inspected below.
        logger = MagicMock(spec=MLflowLogger)
        logger.log_metrics = MagicMock()

        handler(engine, logger, Events.ITERATION_STARTED)
        logger.log_metrics.assert_called_once_with({metric_name: 0.01}, step=123)