def test_observation_exists():
    # type: () -> None
    """Assert ``_observation_exists`` is True only when the trigger fires and the key is observed.

    NOTE(review): a later definition in this module uses the same name, so pytest
    only collects that one; this definition is shadowed. Consider removing one copy.
    """

    study = optuna.create_study()
    # Use a still-running trial, as the sibling test does. The previous
    # ``study._run_trial(...)`` call used a private API and handed the extension a
    # trial whose objective had already completed.
    trial = create_running_trial(study, 1.0)
    MockTrainer = namedtuple("_MockTrainer", ("observation", ))
    trainer = MockTrainer(observation={"OK": 0})

    # Trigger is deactivated. Return False regardless of whether the trainer has the observation.
    with patch.object(
        triggers.IntervalTrigger, "__call__", Mock(return_value=False)
    ) as mock:
        extension = ChainerPruningExtension(trial, "NG", (1, "epoch"))
        assert extension._observation_exists(trainer) is False
        extension = ChainerPruningExtension(trial, "OK", (1, "epoch"))
        assert extension._observation_exists(trainer) is False
        # The patched trigger must have been consulted once per extension.
        assert mock.call_count == 2

    # Trigger is activated. Return True only if the trainer has the observation.
    with patch.object(
        triggers.IntervalTrigger, "__call__", Mock(return_value=True)
    ) as mock:
        extension = ChainerPruningExtension(trial, "NG", (1, "epoch"))
        assert extension._observation_exists(trainer) is False
        extension = ChainerPruningExtension(trial, "OK", (1, "epoch"))
        assert extension._observation_exists(trainer) is True
        assert mock.call_count == 2
def test_observation_exists():
    # type: () -> None
    """Check ``ChainerPruningExtension._observation_exists`` under both trigger states.

    The assertions below require it to return True only when the patched
    ``IntervalTrigger.__call__`` returns True AND the monitored key is present
    in ``trainer.observation``; otherwise it must return False.
    """

    study = optuna.create_study()
    trial = create_running_trial(study, 1.0)
    MockTrainer = namedtuple("_MockTrainer", ("observation", ))
    trainer = MockTrainer(observation={"OK": 0})

    # Each case pairs the value the patched trigger returns with the
    # (observation key, expected result) checks to run under it, in order.
    cases = [
        # Trigger inactive: False regardless of whether the key is observed.
        (False, [("NG", False), ("OK", False)]),
        # Trigger active: True only for the key present in the observation.
        (True, [("NG", False), ("OK", True)]),
    ]

    for trigger_fires, checks in cases:
        with patch.object(
            triggers.IntervalTrigger, "__call__", Mock(return_value=trigger_fires)
        ) as mock:
            for key, expected in checks:
                extension = ChainerPruningExtension(trial, key, (1, "epoch"))
                assert extension._observation_exists(trainer) is expected
            # The trigger must be consulted exactly once per extension checked.
            assert mock.call_count == len(checks)