Exemplo n.º 1
0
def test_optimizer_params_handler_wrong_setup():
    """Reject a non-optimizer argument and a logger that is not a WandBLogger."""
    # Passing None instead of an optimizer must be refused at construction time.
    with pytest.raises(TypeError):
        OptimizerParamsHandler(optimizer=None)

    handler = OptimizerParamsHandler(
        optimizer=MagicMock(spec=torch.optim.Optimizer)
    )
    engine, logger = MagicMock(), MagicMock()

    # A bare MagicMock is not a WandBLogger, so invoking the handler must fail.
    expected_msg = "Handler OptimizerParamsHandler works only with WandBLogger"
    with pytest.raises(RuntimeError, match=expected_msg):
        handler(engine, logger, Events.ITERATION_STARTED)
Exemplo n.º 2
0
def test_wandb_close():
    """Run the handler against a mocked WandBLogger, then close the logger.

    The handler is expected to emit exactly one ``log`` call per invocation.
    Closing the logger afterwards must not raise. Because the mock is
    spec'd on WandBLogger, calling ``close`` fails if that attribute
    does not exist.
    """
    optimizer = torch.optim.SGD([torch.Tensor(0)], lr=0.01)
    wrapper = OptimizerParamsHandler(optimizer=optimizer, param_name="lr")
    mock_logger = MagicMock(spec=WandBLogger)
    mock_logger.log = MagicMock()
    mock_engine = MagicMock()

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
    # Without this assertion the test checks nothing about the handler.
    mock_logger.log.assert_called_once()

    mock_logger.close()
Exemplo n.º 3
0
def test_optimizer_params():
    """Check the learning-rate payload and step passed to ``WandBLogger.log``.

    The handler runs twice against the same engine state (iteration 123):
    once without a tag, and once with ``tag="generator"``, which should
    prefix the logged key.
    """
    optimizer = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 123

    # (extra constructor kwargs, key expected in the logged dict)
    cases = (
        ({}, "lr/group_0"),
        ({"tag": "generator"}, "generator/lr/group_0"),
    )
    for extra_kwargs, expected_key in cases:
        handler = OptimizerParamsHandler(
            optimizer=optimizer, param_name="lr", **extra_kwargs
        )
        # Fresh logger per case so assert_called_once_with sees only this call.
        logger = MagicMock(spec=WandBLogger)
        logger.log = MagicMock()

        handler(engine, logger, Events.ITERATION_STARTED)
        logger.log.assert_called_once_with(
            {expected_key: 0.01}, step=123, sync=None
        )