Example #1
0
def test_get_float_value() -> None:
    """Check that ``_get_float_value`` unwraps scalars and rejects unsupported inputs.

    Plain floats and a single-element ``chainer.Variable`` both come back as ``1.0``,
    NaN is passed through as NaN, and a list raises ``TypeError``.
    """

    to_float = ChainerPruningExtension._get_float_value

    # A plain Python float is returned with its value unchanged.
    assert to_float(1.0) == 1.0

    # A chainer.Variable wrapping a one-element array is reduced to that scalar.
    wrapped = chainer.Variable(np.array([1.0]))
    assert to_float(wrapped) == 1.0

    # NaN must survive the conversion instead of raising.
    assert math.isnan(to_float(float("nan")))

    # Anything that is neither a number nor a Variable is rejected.
    with pytest.raises(TypeError):
        to_float([])  # type: ignore
Example #2
0
def test_chainer_pruning_extension_trigger():
    # type: () -> None
    """Check which trigger type ``ChainerPruningExtension`` stores in ``pruner_trigger``.

    A ``(period, unit)`` tuple and an ``IntervalTrigger`` both yield an ``IntervalTrigger``,
    a ``ManualScheduleTrigger`` yields a ``ManualScheduleTrigger``, and a ``TimeTrigger``
    is rejected with ``TypeError``.
    """

    study = optuna.create_study()
    trial = create_running_trial(study, 1.0)

    # (trigger argument passed to the extension, expected type of extension.pruner_trigger)
    accepted = [
        ((1, 'epoch'), triggers.IntervalTrigger),
        (triggers.IntervalTrigger(1, 'epoch'), triggers.IntervalTrigger),
        (triggers.ManualScheduleTrigger(1, 'epoch'), triggers.ManualScheduleTrigger),
    ]
    for trigger_arg, expected_type in accepted:
        extension = ChainerPruningExtension(trial, 'main/loss', trigger_arg)
        assert isinstance(extension.pruner_trigger, expected_type)

    # Time-based triggers are rejected by the constructor.
    with pytest.raises(TypeError):
        ChainerPruningExtension(trial, 'main/loss', triggers.TimeTrigger(1.))
Example #3
0
def test_chainer_pruning_extension_trigger() -> None:
    """Check which trigger type ``ChainerPruningExtension`` stores in ``_pruner_trigger``.

    A ``(period, unit)`` tuple and an ``IntervalTrigger`` both yield an ``IntervalTrigger``,
    a ``ManualScheduleTrigger`` yields a ``ManualScheduleTrigger``, and a ``TimeTrigger``
    is rejected with ``TypeError``.
    """

    study = optuna.create_study()
    trial = study.ask()

    # (trigger argument passed to the extension, expected type of extension._pruner_trigger)
    cases = (
        ((1, "epoch"), triggers.IntervalTrigger),
        (triggers.IntervalTrigger(1, "epoch"), triggers.IntervalTrigger),
        (triggers.ManualScheduleTrigger(1, "epoch"), triggers.ManualScheduleTrigger),
    )
    for trigger_spec, expected_cls in cases:
        extension = ChainerPruningExtension(trial, "main/loss", trigger_spec)  # type: ignore
        assert isinstance(extension._pruner_trigger, expected_cls)

    # Time-based triggers are rejected by the constructor.
    with pytest.raises(TypeError):
        ChainerPruningExtension(trial, "main/loss", triggers.TimeTrigger(1.0))  # type: ignore
Example #4
0
def test_chainer_pruning_extension_observation_isnan():
    # type: () -> None
    """Check that calling the extension with a NaN observation does not raise.

    ``_observation_exists`` is replaced by a mock returning True so the extension always
    goes on to read the observation; the mock must be consulted exactly once.
    """

    study = optuna.create_study()
    trial = study._run_trial(func=lambda _: 1.0, catch=(Exception, ))
    extension = ChainerPruningExtension(trial, 'main/loss', (1, 'epoch'))

    # Minimal stand-in for a Chainer trainer: only the observation dict is provided.
    trainer_cls = namedtuple('_MockTrainer', ('observation', ))
    nan_trainer = trainer_cls(observation={'main/loss': float('nan')})

    with patch.object(extension, '_observation_exists',
                      Mock(return_value=True)) as exists_probe:
        extension(nan_trainer)
        assert exists_probe.call_count == 1
Example #5
0
def test_chainer_pruning_extension_observation_nan() -> None:
    """Check that the extension raises ``optuna.TrialPruned`` for a NaN observation.

    The study uses ``DeterministicPruner(True)`` (presumably a pruner that always asks to
    prune — confirm against its definition). ``_observation_exists`` is mocked to return
    True so the extension always reads the observation; the mock must be called once.
    """

    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = study.ask()
    extension = ChainerPruningExtension(trial, "main/loss", (1, "epoch"))

    MockTrainer = namedtuple("MockTrainer", ("observation", "updater"))
    # The trailing comma makes the field names a real one-element tuple, matching
    # MockTrainer above. A bare ("epoch") is just the string "epoch".
    MockUpdater = namedtuple("MockUpdater", ("epoch",))
    # The updater exposes ``epoch``, the unit configured for the trigger above.
    trainer = MockTrainer(observation={"main/loss": float("nan")}, updater=MockUpdater(1))

    with patch.object(extension, "_observation_exists", Mock(return_value=True)) as mock:
        with pytest.raises(optuna.TrialPruned):
            extension(trainer)  # type: ignore
        assert mock.call_count == 1
Example #6
0
def test_chainer_pruning_extension_observation_nan():
    # type: () -> None
    """Check that the extension raises ``TrialPruned`` for a NaN observation.

    The study uses ``DeterministicPruner(True)`` (presumably a pruner that always asks to
    prune — confirm against its definition). ``_observation_exists`` is mocked to return
    True so the extension always reads the observation; the mock must be called once.
    """

    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    extension = ChainerPruningExtension(trial, 'main/loss', (1, 'epoch'))

    MockTrainer = namedtuple('_MockTrainer', ('observation', 'updater'))
    # The trailing comma makes the field names a real one-element tuple, matching
    # MockTrainer above. A bare ('epoch') is just the string 'epoch'.
    MockUpdater = namedtuple('_MockUpdater', ('epoch', ))
    # The updater exposes ``epoch``, the unit configured for the trigger above.
    trainer = MockTrainer(observation={'main/loss': float('nan')}, updater=MockUpdater(1))

    with patch.object(extension, '_observation_exists', Mock(return_value=True)) as mock:
        with pytest.raises(TrialPruned):
            extension(trainer)
        assert mock.call_count == 1
Example #7
0
def test_observation_exists():
    # type: () -> None
    """Check ``_observation_exists`` for each trigger state and observation key.

    The trainer only carries an observation under ``'OK'``. While the patched
    ``IntervalTrigger.__call__`` returns False, both keys report False; once it returns
    True, only ``'OK'`` reports True. Each scenario must call the trigger exactly twice,
    once per ``_observation_exists`` call.
    """

    study = optuna.create_study()
    trial = create_running_trial(study, 1.0)
    trainer = namedtuple('_MockTrainer', ('observation', ))(observation={'OK': 0})

    # (trigger result, expected result for key 'NG', expected result for key 'OK')
    scenarios = [
        # Trigger is deactivated: False whether or not the trainer has the observation.
        (False, False, False),
        # Trigger is activated: True only when the trainer has the observation.
        (True, False, True),
    ]
    for trigger_result, expected_ng, expected_ok in scenarios:
        fake_call = Mock(return_value=trigger_result)
        with patch.object(triggers.IntervalTrigger, '__call__', fake_call) as mock:
            for key, expected in (('NG', expected_ng), ('OK', expected_ok)):
                extension = ChainerPruningExtension(trial, key, (1, 'epoch'))
                assert extension._observation_exists(trainer) is expected
            assert mock.call_count == 2
Example #8
0
def test_observation_exists() -> None:
    """Check ``_observation_exists`` for each trigger state and observation key.

    The trainer only carries an observation under ``"OK"``. While the patched
    ``IntervalTrigger.__call__`` returns False, both keys report False; once it returns
    True, only ``"OK"`` reports True. Each scenario must call the trigger exactly twice,
    once per ``_observation_exists`` call.
    """

    study = optuna.create_study()
    trial = study.ask()
    trainer = namedtuple("MockTrainer", ("observation",))(observation={"OK": 0})

    # (trigger result, expected result for key "NG", expected result for key "OK")
    scenarios = (
        # Trigger is deactivated: False whether or not the trainer has the observation.
        (False, False, False),
        # Trigger is activated: True only when the trainer has the observation.
        (True, False, True),
    )
    for trigger_result, expected_ng, expected_ok in scenarios:
        fake_call = Mock(return_value=trigger_result)
        with patch.object(triggers.IntervalTrigger, "__call__", fake_call) as mock:
            for key, expected in (("NG", expected_ng), ("OK", expected_ok)):
                extension = ChainerPruningExtension(trial, key, (1, "epoch"))
                assert extension._observation_exists(trainer) is expected  # type: ignore
            assert mock.call_count == 2
Example #9
0
def test_get_float_value() -> None:
    """Check that ``_get_float_value`` unwraps floats and single-element Chainer variables.

    Both a plain ``1.0`` and ``chainer.Variable(np.array([1.0]))`` come back as ``1.0``,
    and NaN is passed through as NaN.
    """

    to_float = ChainerPruningExtension._get_float_value
    assert to_float(1.0) == 1.0
    assert to_float(chainer.Variable(np.array([1.0]))) == 1.0
    # NaN must survive the conversion instead of raising.
    assert math.isnan(to_float(float("nan")))