Пример #1
0
def test_retry_failed_trial_callback_intermediate(
        storage_mode: str, max_retry: Optional[int]) -> None:
    """Check that a heartbeat-failed trial is retried with its history.

    A trial is started with ``study.ask()``, given a parameter, a user
    attribute and one intermediate value, and then left without heartbeats
    for longer than ``grace_period``. The following ``study.optimize`` call
    is expected to fail it as stale, and ``RetryFailedTrialCallback``
    (configured with ``inherit_intermediate_values=True``) should enqueue a
    retry that inherits those fields -- unless ``max_retry == 0``.

    Args:
        storage_mode: Backend identifier passed to ``StorageSupplier``.
        max_retry: ``max_retry`` for ``RetryFailedTrialCallback``. ``None``
            and ``1`` are expected to retry; ``0`` is expected not to.
    """
    heartbeat_interval = 1
    grace_period = 2

    with StorageSupplier(
            storage_mode,
            heartbeat_interval=heartbeat_interval,
            grace_period=grace_period,
            failed_trial_callback=RetryFailedTrialCallback(
                max_retry=max_retry, inherit_intermediate_values=True),
    ) as storage:
        assert is_heartbeat_enabled(storage)
        assert isinstance(storage, BaseHeartbeat)

        study = optuna.create_study(storage=storage)

        # Silence UserWarnings from `ask`; they are not what this test checks.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            trial = study.ask()
        trial.suggest_float("_", -1, -1)
        # Give the trial a user attribute so the `user_attrs` inheritance
        # check below compares real data instead of two empty dicts.
        trial.set_user_attr("foo", "bar")
        trial.report(0.5, 1)
        storage.record_heartbeat(trial._trial_id)
        # Outlive the grace period so the trial's heartbeat becomes stale.
        time.sleep(grace_period + 1)

        # Exceptions raised in spawned threads are caught by `_TestableThread`.
        with patch("optuna.storages._heartbeat.Thread", _TestableThread):
            study.optimize(lambda _: 1.0, n_trials=1)

        trials = study.trials
        # The stale trial plus the single trial run by `optimize`; guard the
        # `trials[1]` indexing below.
        assert len(trials) == 2

        # Trial 1 should be a retry of trial 0 for max_retry=None and
        # max_retry=1, and an ordinary (non-retried) trial for max_retry=0.
        expected_origin = None if max_retry == 0 else 0
        assert RetryFailedTrialCallback.retried_trial_number(
            trials[1]) == expected_origin

        if max_retry != 0:
            # The retry must inherit the failed trial's fields, including its
            # intermediate values (`inherit_intermediate_values=True`).
            assert trials[0].params == trials[1].params
            assert trials[0].distributions == trials[1].distributions
            assert trials[0].user_attrs == trials[1].user_attrs
            assert (trials[0].intermediate_values ==
                    trials[1].intermediate_values)
Пример #2
0
def test_retry_failed_trial_callback_repetitive_failure(
        storage_mode: str) -> None:
    """Check retry bookkeeping when trials keep failing by heartbeat.

    Each loop iteration starts a trial with ``study.ask()``, lets its
    heartbeat go stale, and fails it via ``fail_stale_trials``. Trial 0 is
    expected to be retried ``max_retry`` times; once that budget is used up,
    the next trial is a fresh one that starts a new retry chain.

    Args:
        storage_mode: Backend identifier passed to ``StorageSupplier``.
    """
    heartbeat_interval = 1
    grace_period = 2
    max_retry = 3
    # One original trial plus `max_retry` retries exhaust the first chain;
    # one more iteration starts a fresh trial whose failure is retried again.
    n_trials = max_retry + 2

    with StorageSupplier(
            storage_mode,
            heartbeat_interval=heartbeat_interval,
            grace_period=grace_period,
            failed_trial_callback=RetryFailedTrialCallback(
                max_retry=max_retry),
    ) as storage:
        assert is_heartbeat_enabled(storage)
        assert isinstance(storage, BaseHeartbeat)

        study = optuna.create_study(storage=storage)

        # Make repeatedly failed and retried trials by heartbeat.
        for _ in range(n_trials):
            # Silence UserWarnings from `ask`; they are not under test here.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=UserWarning)
                trial = study.ask()
            storage.record_heartbeat(trial._trial_id)
            # Outlive the grace period so the heartbeat becomes stale.
            time.sleep(grace_period + 1)
            optuna.storages.fail_stale_trials(study)

        trials = study.trials

        # One trial per iteration, plus the retry enqueued by the last
        # failure, which the loop never gets to run.
        assert len(trials) == n_trials + 1

        # Trial 0 is an original trial, not a retry.
        assert "failed_trial" not in trials[0].system_attrs
        assert "retry_history" not in trials[0].system_attrs

        # Trials 1..max_retry are retries originating from trial 0; each one
        # lists every earlier trial of the chain in `retry_history`.
        for number in range(1, max_retry + 1):
            system_attrs = trials[number].system_attrs
            assert system_attrs["failed_trial"] == 0
            assert system_attrs["retry_history"] == list(range(number))

        # Trial `max_retry` used up the retry budget, so the next trial is a
        # freshly started one (not a retry); its own failure is then retried
        # by the final trial.
        fresh = max_retry + 1
        assert "failed_trial" not in trials[fresh].system_attrs
        assert "retry_history" not in trials[fresh].system_attrs
        assert trials[fresh + 1].system_attrs["failed_trial"] == fresh
        assert trials[fresh + 1].system_attrs["retry_history"] == [fresh]