コード例 #1
0
def test_optimizer_params():
    """Check that OptimizerParamsHandler plots ``lr`` through ``vis.line``.

    The handler is exercised twice against the same mocked engine (whose
    state reports iteration 123): once without a tag and once with
    ``tag="generator"``. Each run gets a fresh mocked VisdomLogger and must
    register exactly one window and issue exactly one ``vis.line`` call
    named after the (optionally tag-prefixed) parameter group.
    """
    sgd = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    # One engine is shared by both scenarios.
    state = State()
    state.iteration = 123
    engine = MagicMock()
    engine.state = state

    # (extra constructor kwargs, expected window / trace name)
    scenarios = (
        ({}, "lr/group_0"),
        ({"tag": "generator"}, "generator/lr/group_0"),
    )

    for extra_kwargs, key in scenarios:
        handler = OptimizerParamsHandler(optimizer=sgd,
                                         param_name="lr",
                                         **extra_kwargs)

        # Fresh logger per scenario so the call count starts from zero.
        logger = MagicMock(spec=VisdomLogger)
        logger.vis = MagicMock()
        logger.executor = _DummyExecutor()

        handler(engine, logger, Events.ITERATION_STARTED)

        windows = handler.windows
        assert len(windows) == 1
        assert key in windows
        assert windows[key]["win"] is not None

        logger.vis.line.assert_called_once_with(
            X=[123],
            Y=[0.01],
            env=logger.vis.env,
            win=None,
            update=None,
            opts=windows[key]["opts"],
            name=key,
        )
コード例 #2
0
def test_optimizer_params_handler_wrong_setup():
    """Check the error paths of OptimizerParamsHandler.

    Expects a ``TypeError`` when the handler is constructed with
    ``optimizer=None``, and a ``RuntimeError`` with a VisdomLogger-specific
    message when a validly constructed handler is invoked with a logger that
    is just a plain ``MagicMock``.
    """
    # Construction with something that is not an optimizer must fail.
    with pytest.raises(TypeError):
        OptimizerParamsHandler(optimizer=None)

    # A spec'd mock is accepted by the constructor.
    fake_optimizer = MagicMock(spec=torch.optim.Optimizer)
    handler = OptimizerParamsHandler(optimizer=fake_optimizer)

    pattern = "Handler OptimizerParamsHandler works only with VisdomLogger"
    logger = MagicMock()
    engine = MagicMock()
    with pytest.raises(RuntimeError, match=pattern):
        handler(engine, logger, Events.ITERATION_STARTED)