Ejemplo n.º 1
0
    def _suggest(self, name: str, distribution: BaseDistribution) -> Any:
        """Return the value of parameter ``name``, choosing one if none is stored yet.

        When the trial already has a distribution registered under ``name``, that
        distribution is checked for compatibility with ``distribution`` and the
        stored value is returned through ``distribution.to_external_repr``.

        Otherwise a value is picked from the first source that applies, in this
        order: fixed parameter, single-valued distribution, relative parameter,
        independent sampling. The picked value is then written to the storage
        in its internal representation.

        Args:
            name: Name of the parameter.
            distribution: Distribution associated with the parameter.

        Returns:
            The stored or newly chosen parameter value.
        """
        backend = self.storage
        tid = self._trial_id
        frozen = backend.get_trial(tid)

        # Already suggested: validate the distribution and hand back the stored value.
        if name in frozen.distributions:
            distributions.check_distribution_compatibility(frozen.distributions[name], distribution)
            stored = backend.get_trial_param(tid, name)
            return distribution.to_external_repr(stored)

        if self._is_fixed_param(name, distribution):
            system_attrs = backend.get_trial_system_attrs(tid)
            chosen = system_attrs["fixed_params"][name]
        elif distribution.single():
            chosen = distributions._get_single_value(distribution)
        elif self._is_relative_param(name, distribution):
            chosen = self.relative_params[name]
        else:
            filtered_study = pruners._filter_study(self.study, frozen)
            chosen = self.study.sampler.sample_independent(
                filtered_study, frozen, name, distribution
            )

        # Persist the new value using the storage-side (internal) representation.
        internal = distribution.to_internal_repr(chosen)
        backend.set_trial_param(tid, name, internal, distribution)

        return chosen
Ejemplo n.º 2
0
    def _after_func(self, state: TrialState, values: Optional[Sequence[float]]) -> None:
        """Invoke the sampler's ``after_trial`` hook for this trial.

        This method is called right before ``Study._tell``.

        Args:
            state: State passed on to the sampler hook.
            values: Objective values passed on to the sampler hook, or ``None``.
        """
        frozen_trial = self.storage.get_trial(self._trial_id)
        filtered_study = pruners._filter_study(self.study, frozen_trial)

        self.study.sampler.after_trial(filtered_study, frozen_trial, state, values)
Ejemplo n.º 3
0
    def _init_relative_params(self) -> None:
        """Populate ``relative_search_space`` and ``relative_params`` for this trial.

        The sampler first infers the relative search space and then draws the
        relative parameters from it; both results are stored on ``self``.
        """
        frozen_trial = self.storage.get_trial(self._trial_id)
        filtered_study = pruners._filter_study(self.study, frozen_trial)

        search_space = self.study.sampler.infer_relative_search_space(
            filtered_study, frozen_trial
        )
        self.relative_search_space = search_space

        self.relative_params = self.study.sampler.sample_relative(
            filtered_study, frozen_trial, self.relative_search_space
        )
Ejemplo n.º 4
0
    def _suggest(self, name, distribution):
        # type: (str, BaseDistribution) -> Any
        """Pick a value for ``name`` and register it, or reuse the existing one.

        The candidate value comes from the first source that applies: a fixed
        parameter, a relative parameter, or independent sampling. The candidate
        is then handed to ``_set_new_param_or_get_existing``, whose result is
        returned.
        """
        if self._is_fixed_param(name, distribution):
            system_attrs = self.storage.get_trial_system_attrs(self._trial_id)
            candidate = system_attrs["fixed_params"][name]
        elif self._is_relative_param(name, distribution):
            candidate = self.relative_params[name]
        else:
            # Neither fixed nor relative: fall back to independent sampling.
            frozen_trial = self.storage.get_trial(self._trial_id)
            filtered_study = pruners._filter_study(self.study, frozen_trial)
            candidate = self.study.sampler.sample_independent(
                filtered_study, frozen_trial, name, distribution
            )

        return self._set_new_param_or_get_existing(name, candidate, distribution)
Ejemplo n.º 5
0
    def tell(
        self,
        trial: Union[trial_module.Trial, int],
        values: Optional[Union[float, Sequence[float]]] = None,
        state: TrialState = TrialState.COMPLETE,
    ) -> None:
        """Finish a trial created with :func:`~optuna.study.Study.ask`.

        Example:

            .. testcode::

                import optuna
                from optuna.trial import TrialState


                def f(x):
                    return (x - 2) ** 2


                def df(x):
                    return 2 * x - 4


                study = optuna.create_study()

                n_trials = 30

                for _ in range(n_trials):
                    trial = study.ask()

                    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)

                    # Iterative gradient descent objective function.
                    x = 3  # Initial value.
                    for step in range(128):
                        y = f(x)

                        trial.report(y, step=step)

                        if trial.should_prune():
                            # Finish the trial with the pruned state.
                            study.tell(trial, state=TrialState.PRUNED)
                            break

                        gy = df(x)
                        x -= gy * lr
                    else:
                        # Finish the trial with the final value after all iterations.
                        study.tell(trial, y)

        Args:
            trial:
                A :class:`~optuna.trial.Trial` object or a trial number.
            values:
                Optional objective value or a sequence of such values in case the study is used
                for multi-objective optimization. Argument must be provided if ``state`` is
                :class:`~optuna.trial.TrialState.COMPLETE` and should be :obj:`None` if ``state``
                is :class:`~optuna.trial.TrialState.FAIL` or
                :class:`~optuna.trial.TrialState.PRUNED`.
            state:
                State to be reported. Must be :class:`~optuna.trial.TrialState.COMPLETE`,
                :class:`~optuna.trial.TrialState.FAIL` or
                :class:`~optuna.trial.TrialState.PRUNED`.

        Raises:
            TypeError:
                If ``trial`` is not a :class:`~optuna.trial.Trial` or an :obj:`int`.
            ValueError:
                If any of the following.
                ``values`` is a sequence but its length does not match the number of objectives
                for its associated study.
                ``state`` is :class:`~optuna.trial.TrialState.COMPLETE` but
                ``values`` is :obj:`None`.
                ``state`` is :class:`~optuna.trial.TrialState.FAIL` or
                :class:`~optuna.trial.TrialState.PRUNED` but
                ``values`` is not :obj:`None`.
                ``state`` is not
                :class:`~optuna.trial.TrialState.COMPLETE`,
                :class:`~optuna.trial.TrialState.FAIL` or
                :class:`~optuna.trial.TrialState.PRUNED`.
                ``trial`` is a trial number but no
                trial exists with that number.
        """

        if not isinstance(trial, (trial_module.Trial, int)):
            raise TypeError("Trial must be a trial object or trial number.")

        # Reject inconsistent (state, values) combinations before touching the storage.
        if state == TrialState.COMPLETE:
            if values is None:
                raise ValueError(
                    "No values were told. Values are required when state is TrialState.COMPLETE."
                )
        elif state in (TrialState.PRUNED, TrialState.FAIL):
            if values is not None:
                raise ValueError(
                    "Values were told. Values cannot be specified when state is "
                    "TrialState.PRUNED or TrialState.FAIL."
                )
        else:
            raise ValueError(f"Cannot tell with state {state}.")

        # Resolve the storage-level trial ID.
        if isinstance(trial, trial_module.Trial):
            trial_number = trial.number
            trial_id = trial._trial_id
        else:
            # The type check at the top guarantees that ``trial`` is an ``int`` here.
            trial_number = trial
            try:
                trial_id = self._storage.get_trial_id_from_study_id_trial_number(
                    self._study_id, trial_number
                )
            except NotImplementedError as e:
                warnings.warn(
                    "Study.tell may be slow because the trial was represented by its number but "
                    f"the storage {self._storage.__class__.__name__} does not implement the "
                    "method required to map numbers back. Please provide the trial object "
                    "to avoid performance degradation."
                )

                trials = self.get_trials(deepcopy=False)

                # Check both bounds: a negative number would otherwise index from the end of
                # the list and silently finish an unrelated trial.
                if not 0 <= trial_number < len(trials):
                    raise ValueError(
                        f"Cannot tell for trial with number {trial_number} since it has not been "
                        "created."
                    ) from e

                trial_id = trials[trial_number]._trial_id
            except KeyError as e:
                raise ValueError(
                    f"Cannot tell for trial with number {trial_number} since it has not been "
                    "created."
                ) from e

        frozen_trial = self._storage.get_trial(trial_id)

        if state == TrialState.PRUNED:
            # Register the last intermediate value if present as the value of the trial.
            # TODO(hvy): Whether a pruned trials should have an actual value can be discussed.
            assert values is None

            last_step = frozen_trial.last_step
            if last_step is not None:
                values = [frozen_trial.intermediate_values[last_step]]

        if values is not None:
            values, values_conversion_failure_message = _check_and_convert_to_values(
                len(self.directions), values, trial_number
            )
            # When called from `Study.optimize` and `state` is pruned, the given `values` contains
            # the intermediate value with the largest step so far. In this case, the value is
            # allowed to be NaN and errors should not be raised.
            if state != TrialState.PRUNED and values_conversion_failure_message is not None:
                raise ValueError(values_conversion_failure_message)

        try:
            # Sampler defined trial post-processing.
            study = pruners._filter_study(self, frozen_trial)
            self.sampler.after_trial(study, frozen_trial, state, values)
        finally:
            # The outcome is written to the storage even when the sampler hook raises; the
            # exception still propagates to the caller afterwards.
            if values is not None:
                self._storage.set_trial_values(trial_id, values)

            self._storage.set_trial_state(trial_id, state)
Ejemplo n.º 6
0
def _tell_with_warning(
    study: "optuna.Study",
    trial: Union[trial_module.Trial, int],
    values: Optional[Union[float, Sequence[float]]] = None,
    state: Optional[TrialState] = None,
    skip_if_finished: bool = False,
    suppress_warning: bool = False,
) -> FrozenTrial:
    """Internal method of :func:`~optuna.study.Study.tell`.

    Refer to the document for :func:`~optuna.study.Study.tell` for the reference.
    This method has one additional parameter ``suppress_warning``.

    Args:
        study:
            Study that owns the trial.
        trial:
            A trial object or a trial number.
        values:
            Objective value(s) to record, or :obj:`None`.
        state:
            State to record. If :obj:`None`, the state is derived from ``values``: the trial
            becomes ``COMPLETE`` when the values convert without a failure message and
            ``FAIL`` otherwise.
        skip_if_finished:
            If :obj:`True` and the trial is already finished, an info message is logged and a
            copy of the stored trial is returned without changing it.
        suppress_warning:
            If :obj:`True`, tell will not emit a warning when it receives invalid values;
            the message is attached to the returned trial as a system attribute instead.
            This flag is expected to be :obj:`True` only when it is invoked by
            Study.optimize.

    Returns:
        A deep copy of the trial as stored after the update (or of the already finished trial
        when it was skipped).

    Raises:
        TypeError:
            If ``trial`` is neither a trial object nor an :obj:`int`.
        ValueError:
            If ``state`` and ``values`` are inconsistent, if ``state`` is not one of
            ``COMPLETE``, ``PRUNED``, ``FAIL`` or :obj:`None`, or if ``trial`` is a trial number
            for which no trial has been created.
    """

    if not isinstance(trial, (trial_module.Trial, int)):
        raise TypeError("Trial must be a trial object or trial number.")

    # Reject inconsistent (state, values) combinations before touching the storage.
    if state == TrialState.COMPLETE:
        if values is None:
            raise ValueError(
                "No values were told. Values are required when state is TrialState.COMPLETE."
            )
    elif state in (TrialState.PRUNED, TrialState.FAIL):
        if values is not None:
            raise ValueError(
                "Values were told. Values cannot be specified when state is "
                "TrialState.PRUNED or TrialState.FAIL."
            )
    elif state is not None:
        raise ValueError(f"Cannot tell with state {state}.")

    # Resolve the storage-level trial ID.
    if isinstance(trial, trial_module.Trial):
        trial_number = trial.number
        trial_id = trial._trial_id
    else:
        # The type check at the top guarantees that ``trial`` is an ``int`` here.
        trial_number = trial
        try:
            trial_id = study._storage.get_trial_id_from_study_id_trial_number(
                study._study_id, trial_number
            )
        except NotImplementedError as e:
            warnings.warn(
                "Study.tell may be slow because the trial was represented by its number but "
                f"the storage {study._storage.__class__.__name__} does not implement the "
                "method required to map numbers back. Please provide the trial object "
                "to avoid performance degradation."
            )

            trials = study.get_trials(deepcopy=False)

            # Check both bounds: a negative number would otherwise index from the end of the
            # list and silently finish an unrelated trial.
            if not 0 <= trial_number < len(trials):
                raise ValueError(
                    f"Cannot tell for trial with number {trial_number} since it has not been "
                    "created."
                ) from e

            trial_id = trials[trial_number]._trial_id
        except KeyError as e:
            raise ValueError(
                f"Cannot tell for trial with number {trial_number} since it has not been "
                "created."
            ) from e

    frozen_trial = study._storage.get_trial(trial_id)
    warning_message = None

    if frozen_trial.state.is_finished() and skip_if_finished:
        _logger.info(
            f"Skipped telling trial {trial_number} with values "
            f"{values} and state {state} since trial was already finished. "
            f"Finished trial has values {frozen_trial.values} and state {frozen_trial.state}."
        )
        return copy.deepcopy(frozen_trial)

    if state == TrialState.PRUNED:
        # Register the last intermediate value if present as the value of the trial.
        # TODO(hvy): Whether a pruned trials should have an actual value can be discussed.
        assert values is None

        last_step = frozen_trial.last_step
        if last_step is not None:
            values = [frozen_trial.intermediate_values[last_step]]

    values, values_conversion_failure_message = _check_and_convert_to_values(
        len(study.directions), values, trial_number
    )

    # NOTE(review): a conversion failure is only acted upon when ``state`` was not given
    # explicitly; with an explicit state the message is ignored here. Confirm this is intended.
    if state is None:
        if values_conversion_failure_message is None:
            state = TrialState.COMPLETE
        else:
            state = TrialState.FAIL
            if not suppress_warning:
                warnings.warn(values_conversion_failure_message)
            else:
                # Deferred: attached to the returned trial below instead of being emitted.
                warning_message = values_conversion_failure_message

    assert state is not None

    try:
        # Sampler defined trial post-processing.
        # ``study`` is deliberately rebound: the storage calls below go through whatever
        # object `_filter_study` returned, exactly as before.
        study = pruners._filter_study(study, frozen_trial)
        study.sampler.after_trial(study, frozen_trial, state, values)
    finally:
        # The outcome is written to the storage even when the sampler hook raises; the
        # exception still propagates to the caller afterwards.
        study._storage.set_trial_state_values(trial_id, state, values)

    frozen_trial = copy.deepcopy(study._storage.get_trial(trial_id))

    if warning_message is not None:
        frozen_trial.set_system_attr(STUDY_TELL_WARNING_KEY, warning_message)
    return frozen_trial