def test_get_float_value() -> None:
    """Check that ``_get_float_value`` turns numbers and Chainer variables into floats.

    NaN has to come back as NaN, and a value of an unsupported type (a list)
    has to raise ``TypeError``.

    NOTE(review): another ``test_get_float_value`` is defined later in this module
    and rebinds this name, so pytest only collects the later one; consider
    deleting the duplicate.
    """
    get_float = ChainerPruningExtension._get_float_value

    assert get_float(1.0) == 1.0
    assert get_float(chainer.Variable(np.array([1.0]))) == 1.0
    assert math.isnan(get_float(float("nan")))

    # A list is neither a plain number nor a chainer.Variable.
    with pytest.raises(TypeError):
        get_float([])  # type: ignore
def test_chainer_pruning_extension_trigger() -> None:
    """Check which trigger specifications ``ChainerPruningExtension`` accepts.

    A ``(1, 'epoch')`` tuple and an ``IntervalTrigger`` both give an
    ``IntervalTrigger`` as ``pruner_trigger``. A ``ManualScheduleTrigger`` gives a
    ``ManualScheduleTrigger``. A ``TimeTrigger`` is rejected with ``TypeError``.

    NOTE(review): a later ``test_chainer_pruning_extension_trigger`` in this module
    rebinds this name, so pytest does not collect this version.
    """
    study = optuna.create_study()
    trial = create_running_trial(study, 1.0)

    # (trigger argument, expected type of the resulting pruner_trigger)
    accepted = [
        ((1, 'epoch'), triggers.IntervalTrigger),
        (triggers.IntervalTrigger(1, 'epoch'), triggers.IntervalTrigger),
        (triggers.ManualScheduleTrigger(1, 'epoch'), triggers.ManualScheduleTrigger),
    ]
    for trigger_spec, expected_cls in accepted:
        ext = ChainerPruningExtension(trial, 'main/loss', trigger_spec)
        assert isinstance(ext.pruner_trigger, expected_cls)

    with pytest.raises(TypeError):
        ChainerPruningExtension(trial, 'main/loss', triggers.TimeTrigger(1.))
def test_chainer_pruning_extension_trigger() -> None:
    """Check how ``ChainerPruningExtension`` normalises its trigger argument.

    Tuples and ``IntervalTrigger`` objects end up as an ``IntervalTrigger``.
    ``ManualScheduleTrigger`` objects end up as a ``ManualScheduleTrigger``.
    A ``TimeTrigger`` raises ``TypeError``.
    """
    study = optuna.create_study()
    trial = study.ask()

    from_tuple = ChainerPruningExtension(trial, "main/loss", (1, "epoch"))
    assert isinstance(from_tuple._pruner_trigger, triggers.IntervalTrigger)

    from_interval = ChainerPruningExtension(
        trial, "main/loss", triggers.IntervalTrigger(1, "epoch")  # type: ignore
    )
    assert isinstance(from_interval._pruner_trigger, triggers.IntervalTrigger)

    from_manual = ChainerPruningExtension(
        trial, "main/loss", triggers.ManualScheduleTrigger(1, "epoch")  # type: ignore
    )
    assert isinstance(from_manual._pruner_trigger, triggers.ManualScheduleTrigger)

    # Any other trigger class is not accepted.
    with pytest.raises(TypeError):
        ChainerPruningExtension(trial, "main/loss", triggers.TimeTrigger(1.0))  # type: ignore
def test_chainer_pruning_extension_observation_isnan() -> None:
    """Check that calling the extension on a NaN ``main/loss`` observation does not raise.

    ``_observation_exists`` is patched to return ``True``, so the extension treats
    the observation as available. The only assertion is that this existence check
    is consulted exactly once.

    NOTE(review): ``study._run_trial`` runs ``lambda _: 1.0`` as an objective, so the
    trial passed in here appears to be already finished rather than running (other
    tests use ``create_running_trial``); confirm this is intended.
    """
    study = optuna.create_study()
    trial = study._run_trial(func=lambda _: 1.0, catch=(Exception, ))
    extension = ChainerPruningExtension(trial, 'main/loss', (1, 'epoch'))

    trainer_cls = namedtuple('_MockTrainer', ('observation', ))
    trainer = trainer_cls(observation={'main/loss': float('nan')})

    exists_mock = Mock(return_value=True)
    with patch.object(extension, '_observation_exists', exists_mock):
        extension(trainer)
    assert exists_mock.call_count == 1
def test_chainer_pruning_extension_observation_nan() -> None:
    """Check that a NaN observation still ends in ``optuna.TrialPruned``.

    The study is configured with ``DeterministicPruner(True)``. ``_observation_exists``
    is patched to return ``True``. Calling the extension on a trainer whose
    ``main/loss`` is NaN must raise ``optuna.TrialPruned`` after consulting the
    existence check exactly once.

    NOTE(review): a later ``test_chainer_pruning_extension_observation_nan`` in this
    module rebinds this name, so pytest does not collect this version.
    """
    pruning_study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = pruning_study.ask()
    extension = ChainerPruningExtension(trial, "main/loss", (1, "epoch"))

    # Variable names match the namedtuple typenames, as mypy expects.
    MockTrainer = namedtuple("MockTrainer", ("observation", "updater"))
    MockUpdater = namedtuple("MockUpdater", ("epoch",))
    trainer = MockTrainer(
        observation={"main/loss": float("nan")},
        updater=MockUpdater(epoch=1),
    )

    observation_exists = Mock(return_value=True)
    with patch.object(extension, "_observation_exists", observation_exists), pytest.raises(
        optuna.TrialPruned
    ):
        extension(trainer)  # type: ignore
    assert observation_exists.call_count == 1
def test_chainer_pruning_extension_observation_nan() -> None:
    """Check that a NaN ``main/loss`` observation leads to ``TrialPruned``.

    The study uses ``DeterministicPruner(True)``, and the trial comes from
    ``create_running_trial``. With ``_observation_exists`` patched to ``True``,
    invoking the extension must raise ``TrialPruned``. The existence check must
    have been called exactly once.

    NOTE(review): this definition rebinds an earlier
    ``test_chainer_pruning_extension_observation_nan`` in this module; only this
    one is collected by pytest.
    """
    study = optuna.create_study(pruner=DeterministicPruner(True))
    running_trial = create_running_trial(study, 1.0)
    extension = ChainerPruningExtension(running_trial, 'main/loss', (1, 'epoch'))

    trainer_cls = namedtuple('_MockTrainer', ('observation', 'updater'))
    updater_cls = namedtuple('_MockUpdater', ('epoch',))
    trainer = trainer_cls(
        observation={'main/loss': float('nan')},
        updater=updater_cls(epoch=1),
    )

    exists_mock = Mock(return_value=True)
    with patch.object(extension, '_observation_exists', exists_mock), pytest.raises(TrialPruned):
        extension(trainer)
    assert exists_mock.call_count == 1
def test_observation_exists() -> None:
    """Check that ``_observation_exists`` needs both an active trigger and the key.

    ``IntervalTrigger.__call__`` is patched to return ``False`` and then ``True``.
    In each state, two extensions are queried against a trainer whose observation
    only contains ``'OK'``: one monitoring a missing key (``'NG'``) and one
    monitoring ``'OK'``. The patched trigger must be called once per query.

    NOTE(review): a later ``test_observation_exists`` in this module rebinds this
    name, so pytest does not collect this version.
    """
    study = optuna.create_study()
    trial = create_running_trial(study, 1.0)
    trainer_cls = namedtuple('_MockTrainer', ('observation', ))
    trainer = trainer_cls(observation={'OK': 0})

    # (trigger fires?, expected result for key 'NG', expected result for key 'OK')
    scenarios = [
        # Trigger is deactivated: False whether or not the observation exists.
        (False, False, False),
        # Trigger is activated: True only if the trainer has the observation.
        (True, False, True),
    ]
    for fires, expect_ng, expect_ok in scenarios:
        trigger_mock = Mock(return_value=fires)
        with patch.object(triggers.IntervalTrigger, '__call__', trigger_mock):
            for key, expected in (('NG', expect_ng), ('OK', expect_ok)):
                extension = ChainerPruningExtension(trial, key, (1, 'epoch'))
                assert extension._observation_exists(trainer) is expected
        assert trigger_mock.call_count == 2
def test_observation_exists() -> None:
    """Check ``_observation_exists`` against both trigger states.

    With ``IntervalTrigger.__call__`` patched to ``False``, the method returns
    ``False`` for both a missing (``"NG"``) and a present (``"OK"``) key. Patched to
    ``True``, it returns ``True`` only for the present key. Each patched trigger is
    expected to be invoked twice.

    NOTE(review): this definition rebinds an earlier ``test_observation_exists`` in
    this module; only this one is collected by pytest.
    """
    study = optuna.create_study()
    trial = study.ask()
    # Variable name matches the namedtuple typename, as mypy expects.
    MockTrainer = namedtuple("MockTrainer", ("observation",))
    trainer = MockTrainer(observation={"OK": 0})

    # trigger return value -> expected _observation_exists result per monitored key
    expectations = {
        False: {"NG": False, "OK": False},
        True: {"NG": False, "OK": True},
    }
    for fires, by_key in expectations.items():
        trigger_mock = Mock(return_value=fires)
        with patch.object(triggers.IntervalTrigger, "__call__", trigger_mock):
            for key, expected in by_key.items():
                extension = ChainerPruningExtension(trial, key, (1, "epoch"))
                assert extension._observation_exists(trainer) is expected  # type: ignore
        assert trigger_mock.call_count == 2
def test_get_float_value() -> None:
    """Check that ``_get_float_value`` converts numbers and Chainer variables to floats.

    Covers three cases: a plain float, a one-element ``chainer.Variable`` and NaN
    (which must stay NaN). An unsupported input type (a list) must raise ``TypeError``.

    This definition rebinds an earlier ``test_get_float_value`` in this module, so
    it is the only one pytest collects. It therefore also carries the ``TypeError``
    check that would otherwise never run.
    """
    assert 1.0 == ChainerPruningExtension._get_float_value(1.0)
    assert 1.0 == ChainerPruningExtension._get_float_value(chainer.Variable(np.array([1.0])))
    assert math.isnan(ChainerPruningExtension._get_float_value(float("nan")))

    # A list is neither a number nor a chainer.Variable and must be rejected.
    with pytest.raises(TypeError):
        ChainerPruningExtension._get_float_value([])  # type: ignore