Ejemplo n.º 1
0
def test_failed_trial_callback(storage_mode: str) -> None:
    """A stale trial is handed to ``failed_trial_callback`` exactly once.

    The trial records a single heartbeat and is then left idle for longer than
    the grace period, so the following ``optimize`` call treats it as stale.
    The callback checks that it receives the study and trial that carry the
    system attributes set by this test.
    """
    interval = 1
    grace = 2

    def _check_attrs(study: Study, trial: FrozenTrial) -> None:
        # Both objects must carry the attributes set further below.
        assert study.system_attrs["test"] == "A"
        assert trial.system_attrs["test"] == "B"

    callback = Mock(wraps=_check_attrs)
    supplier = StorageSupplier(
        storage_mode,
        heartbeat_interval=interval,
        grace_period=grace,
        failed_trial_callback=callback,
    )

    with supplier as storage:
        assert is_heartbeat_enabled(storage)
        assert isinstance(storage, BaseHeartbeat)

        target_study = optuna.create_study(storage=storage)
        target_study.set_system_attr("test", "A")

        with pytest.warns(UserWarning):
            stale_trial = target_study.ask()
        stale_trial.set_system_attr("test", "B")
        storage.record_heartbeat(stale_trial._trial_id)
        # Wait past the grace period so the trial counts as stale.
        time.sleep(grace + 1)

        # `_TestableThread` captures exceptions raised in spawned threads.
        with patch("optuna.storages._heartbeat.Thread", _TestableThread):
            target_study.optimize(lambda _: 1.0, n_trials=1)
            callback.assert_called_once()
Ejemplo n.º 2
0
def test_fail_stale_trials_with_optimize(storage_mode: str) -> None:
    """``optimize`` fails only the stale trials of the study being optimized.

    Two studies share one storage and each gets a trial that goes stale.
    Optimizing the first study must fail its trial while the second study's
    trial stays RUNNING.
    """
    interval, grace = 1, 2
    supplier = StorageSupplier(storage_mode,
                               heartbeat_interval=interval,
                               grace_period=grace)

    with supplier as storage:
        assert is_heartbeat_enabled(storage)
        assert isinstance(storage, BaseHeartbeat)

        optimized = optuna.create_study(storage=storage)
        untouched = optuna.create_study(storage=storage)

        with pytest.warns(UserWarning):
            stale_trials = [optimized.ask(), untouched.ask()]
        for stale in stale_trials:
            storage.record_heartbeat(stale._trial_id)
        time.sleep(grace + 1)

        # No stale-trial check has run yet, so both trials are still running.
        assert optimized.trials[0].state is TrialState.RUNNING
        assert untouched.trials[0].state is TrialState.RUNNING

        # `_TestableThread` captures exceptions raised in spawned threads.
        with patch("optuna.storages._heartbeat.Thread", _TestableThread):
            optimized.optimize(lambda _: 1.0, n_trials=1)

        assert optimized.trials[
            0].state is TrialState.FAIL  # type: ignore [comparison-overlap]
        assert untouched.trials[0].state is TrialState.RUNNING
Ejemplo n.º 3
0
def test_retry_failed_trial_callback_repetitive_failure(
        storage_mode: str) -> None:
    """Repeatedly failing trials are retried until ``max_retry`` is exhausted.

    Every asked trial is left to go stale and is then failed through
    ``fail_stale_trials``. The assertions check how each resulting trial's
    ``failed_trial`` / ``retry_history`` system attributes chain back to the
    trial it originates from.
    """
    heartbeat_interval, grace_period = 1, 2
    max_retry, n_trials = 3, 5
    retry_callback = RetryFailedTrialCallback(max_retry=max_retry)

    with StorageSupplier(
            storage_mode,
            heartbeat_interval=heartbeat_interval,
            grace_period=grace_period,
            failed_trial_callback=retry_callback,
    ) as storage:
        assert is_heartbeat_enabled(storage)
        assert isinstance(storage, BaseHeartbeat)

        study = optuna.create_study(storage=storage)

        # Let every asked trial go stale so heartbeat handling fails it.
        for _ in range(n_trials):
            running = study.ask()
            storage.record_heartbeat(running._trial_id)
            time.sleep(grace_period + 1)
            optuna.storages.fail_stale_trials(study)

        trials = study.trials
        assert len(trials) == n_trials + 1

        # Trial 0 is an original trial, not a retry.
        assert "failed_trial" not in trials[0].system_attrs
        assert "retry_history" not in trials[0].system_attrs

        # Trials 1-3 retry trial 0; each one's history grows by one entry.
        for number in range(1, max_retry + 1):
            attrs = trials[number].system_attrs
            assert attrs["failed_trial"] == 0
            assert attrs["retry_history"] == list(range(number))

        # Once max_retry is exceeded, trial 4 starts afresh and trial 5
        # retries it.
        assert "failed_trial" not in trials[4].system_attrs
        assert "retry_history" not in trials[4].system_attrs
        assert trials[5].system_attrs["failed_trial"] == 4
        assert trials[5].system_attrs["retry_history"] == [4]
Ejemplo n.º 4
0
def test_retry_failed_trial_callback_intermediate(
        storage_mode: str, max_retry: Optional[int]) -> None:
    """A retried trial inherits the failed trial's fields when requested.

    ``max_retry=0`` must produce no retry. ``max_retry=None`` and
    ``max_retry=1`` must retry trial 0, and with
    ``inherit_intermediate_values=True`` the retry must copy its params,
    distributions, user attributes and intermediate values.
    """
    heartbeat_interval = 1
    grace_period = 2
    retry_callback = RetryFailedTrialCallback(
        max_retry=max_retry, inherit_intermediate_values=True)

    with StorageSupplier(
            storage_mode,
            heartbeat_interval=heartbeat_interval,
            grace_period=grace_period,
            failed_trial_callback=retry_callback,
    ) as storage:
        assert is_heartbeat_enabled(storage)
        assert isinstance(storage, BaseHeartbeat)

        study = optuna.create_study(storage=storage)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            stale = study.ask()
        stale.suggest_float("_", -1, -1)
        stale.report(0.5, 1)
        storage.record_heartbeat(stale._trial_id)
        time.sleep(grace_period + 1)

        # `_TestableThread` captures exceptions raised in spawned threads.
        with patch("optuna.storages._heartbeat.Thread", _TestableThread):
            study.optimize(lambda _: 1.0, n_trials=1)

        trials = study.trials
        expected_origin = None if max_retry == 0 else 0
        assert RetryFailedTrialCallback.retried_trial_number(
            trials[1]) == expected_origin

        if max_retry == 0:
            # Nothing was retried, so there is nothing inherited to compare.
            return

        for field in ("params", "distributions", "user_attrs",
                      "intermediate_values"):
            assert getattr(trials[0], field) == getattr(trials[1], field)
Ejemplo n.º 5
0
def _run_trial(
    study: "optuna.Study",
    func: "optuna.study.study.ObjectiveFuncType",
    catch: Tuple[Type[Exception], ...],
) -> trial_module.FrozenTrial:
    """Run a single trial of ``func`` on ``study`` and record its outcome.

    When the storage has heartbeats enabled, stale trials are failed first, and
    a background thread records heartbeats for the new trial while ``func``
    runs.

    Args:
        study: Study the new trial is asked from and told to.
        func: Objective function, called once with the new trial.
        catch: Exception types raised by ``func`` that only mark the trial as
            failed instead of being re-raised.

    Returns:
        The trial as returned by ``_tell_with_warning``.

    Raises:
        The exception raised by ``func`` when the trial ends up FAIL and that
        exception is not an instance of ``catch`` (this includes
        ``KeyboardInterrupt``). Exceptions from ``func`` that are neither
        ``Exception`` nor ``KeyboardInterrupt`` propagate directly. Exceptions
        raised by ``_tell_with_warning`` are re-raised after best-effort
        logging of the trial's stored state.
    """
    if is_heartbeat_enabled(study._storage):
        optuna.storages.fail_stale_trials(study)

    trial = study.ask()

    state: Optional[TrialState] = None
    value_or_values: Optional[Union[float, Sequence[float]]] = None
    func_err: Optional[Union[Exception, KeyboardInterrupt]] = None
    func_err_fail_exc_info: Optional[Any] = None
    stop_event: Optional[Event] = None
    thread: Optional[Thread] = None

    if is_heartbeat_enabled(study._storage):
        assert isinstance(study._storage, BaseHeartbeat)
        heartbeat = study._storage
        stop_event = Event()
        thread = Thread(target=_record_heartbeat,
                        args=(trial._trial_id, heartbeat, stop_event))
        thread.start()

    try:
        value_or_values = func(trial)
    except exceptions.TrialPruned as e:
        # TODO(mamu): Handle multi-objective cases.
        state = TrialState.PRUNED
        func_err = e
    except (Exception, KeyboardInterrupt) as e:
        state = TrialState.FAIL
        func_err = e
        func_err_fail_exc_info = sys.exc_info()
    finally:
        # Stop the heartbeat thread on every exit path, including exceptions
        # not caught above (e.g. ``SystemExit``). The thread is created without
        # ``daemon=True``. `_record_heartbeat` presumably runs until
        # `stop_event` is set, so skipping this could leave the thread running
        # after ``func`` aborted.
        if thread is not None:
            assert stop_event is not None
            stop_event.set()
            thread.join()

    # `_tell_with_warning` may raise during trial post-processing.
    try:
        frozen_trial = _tell_with_warning(study=study,
                                          trial=trial,
                                          values=value_or_values,
                                          state=state,
                                          suppress_warning=True)
    except Exception:
        # Log whatever the storage recorded, then propagate the original
        # error. Logging is best-effort here. An unexpected stored state
        # (still RUNNING, or FAIL without a recorded reason) must not replace
        # the real exception with an AssertionError.
        frozen_trial = study._storage.get_trial(trial._trial_id)
        _log_trial_outcome(study, frozen_trial, func_err,
                           func_err_fail_exc_info)
        raise

    # On success the told trial must be in a state we know how to report.
    logged = _log_trial_outcome(study, frozen_trial, func_err,
                                func_err_fail_exc_info)
    assert logged, "Should not reach."

    if (frozen_trial.state == TrialState.FAIL and func_err is not None
            and not isinstance(func_err, catch)):
        raise func_err
    return frozen_trial


def _log_trial_outcome(
    study: "optuna.Study",
    frozen_trial: trial_module.FrozenTrial,
    func_err: Optional[Union[Exception, KeyboardInterrupt]],
    func_err_fail_exc_info: Optional[Any],
) -> bool:
    """Log the final state of ``frozen_trial`` as reported by ``_run_trial``.

    Args:
        study: Study owning the trial; used to log completed trials.
        frozen_trial: Trial whose state is reported.
        func_err: Exception raised by the objective, if any.
        func_err_fail_exc_info: ``sys.exc_info()`` captured for ``func_err``
            when the trial failed, else ``None``.

    Returns:
        ``True`` if a message was logged. ``False`` if the state is not one
        this function can report: anything other than COMPLETE/PRUNED, or FAIL
        with neither ``func_err`` nor a stored tell-warning.
    """
    if frozen_trial.state == TrialState.COMPLETE:
        study._log_completed_trial(frozen_trial)
    elif frozen_trial.state == TrialState.PRUNED:
        _logger.info("Trial {} pruned. {}".format(frozen_trial.number,
                                                  str(func_err)))
    elif frozen_trial.state == TrialState.FAIL and func_err is not None:
        _log_failed_trial(frozen_trial,
                          repr(func_err),
                          exc_info=func_err_fail_exc_info)
    elif (frozen_trial.state == TrialState.FAIL
          and STUDY_TELL_WARNING_KEY in frozen_trial.system_attrs):
        _log_failed_trial(frozen_trial,
                          frozen_trial.system_attrs[STUDY_TELL_WARNING_KEY])
    else:
        return False
    return True