Esempio n. 1
0
def test_observation_exists():
    # type: () -> None
    """Exercise ``ChainerPruningExtension._observation_exists`` with a mocked trigger.

    The fake trainer only reports the ``'OK'`` observation key. With the
    interval trigger patched to never fire, both keys must yield False. With
    it patched to always fire, only the reported key ``'OK'`` may yield True.
    In each trigger state the mocked trigger must be called exactly twice,
    matching the two checks performed.
    """

    study = optuna.create_study()
    trial = study._run_trial(func=lambda _: 1.0, catch=(Exception, ))
    trainer_type = namedtuple('_MockTrainer', ('observation', ))
    trainer = trainer_type(observation={'OK': 0})

    # Each entry: (value returned by the patched trigger,
    #              ((observation key, expected result), ...)).
    cases = (
        # Trigger is deactivated: False whether or not the key is observed.
        (False, (('NG', False), ('OK', False))),
        # Trigger is activated: True only for the key the trainer reports.
        (True, (('NG', False), ('OK', True))),
    )

    for fired, expectations in cases:
        with patch.object(triggers.IntervalTrigger, '__call__',
                          Mock(return_value=fired)) as mock:
            for key, expected in expectations:
                extension = ChainerPruningExtension(trial, key, (1, 'epoch'))
                assert extension._observation_exists(trainer) is expected
            assert mock.call_count == len(expectations)
Esempio n. 2
0
def test_observation_exists():
    # type: () -> None
    """Verify ``_observation_exists`` combines the trigger with key presence.

    ``trainer.observation`` only contains ``"OK"``. For each patched trigger
    state, one extension watches the missing key ``"NG"`` and one watches
    ``"OK"``. The test asserts the result of each check, and that the mocked
    trigger was called twice in total.
    """

    study = optuna.create_study()
    trial = create_running_trial(study, 1.0)
    mock_trainer_cls = namedtuple("_MockTrainer", ("observation", ))
    trainer = mock_trainer_cls(observation={"OK": 0})

    def check(trigger_fires, expected_ng, expected_ok):
        # type: (bool, bool, bool) -> None
        # Patch the interval trigger to return ``trigger_fires``, then check
        # the missing key first and the present key second.
        with patch.object(triggers.IntervalTrigger, "__call__",
                          Mock(return_value=trigger_fires)) as trigger_mock:
            missing = ChainerPruningExtension(trial, "NG", (1, "epoch"))
            assert missing._observation_exists(trainer) is expected_ng
            present = ChainerPruningExtension(trial, "OK", (1, "epoch"))
            assert present._observation_exists(trainer) is expected_ok
            assert trigger_mock.call_count == 2

    # Trigger is deactivated: False regardless of whether the key is observed.
    check(trigger_fires=False, expected_ng=False, expected_ok=False)
    # Trigger is activated: True only when the trainer reports the key.
    check(trigger_fires=True, expected_ng=False, expected_ok=True)