Example #1
0
    def _get_prepared_new_trial(self, study_id: int,
                                template_trial: Optional[FrozenTrial],
                                session: orm.Session) -> models.TrialModel:
        """Build a new trial row inside ``session`` without committing it.

        Args:
            study_id:
                Id of the study the new trial belongs to.
            template_trial:
                Optional trial whose value, timestamps, params, user/system attributes,
                intermediate values and state are copied into the new row.
            session:
                Session the row is added to; this method flushes it but never commits.

        Returns:
            The new ``TrialModel``, already numbered via ``count_past_trials``.
        """
        # Every write helper below only accepts RUNNING trials, so the row starts out RUNNING
        # even when a template is supplied; the template's real state is applied last.
        columns = {
            "study_id": study_id,
            "number": None,
            "state": TrialState.RUNNING,
        }
        if template_trial is not None:
            columns.update(
                value=template_trial.value,
                datetime_start=template_trial.datetime_start,
                datetime_complete=template_trial.datetime_complete,
            )
        trial = models.TrialModel(**columns)
        session.add(trial)

        # The pending INSERT must reach the current transaction before the helpers below run;
        # otherwise they would target a trial that does not exist in storage yet.
        session.flush()

        if template_trial is not None:
            for name, external_value in template_trial.params.items():
                dist = template_trial.distributions[name]
                internal_value = dist.to_internal_repr(external_value)
                self._set_trial_param_without_commit(
                    session, trial.trial_id, name, internal_value, dist)

            attr_writers = (
                (template_trial.user_attrs, self._set_trial_user_attr_without_commit),
                (template_trial.system_attrs, self._set_trial_system_attr_without_commit),
            )
            for attrs, write_attr in attr_writers:
                for attr_key, attr_value in attrs.items():
                    write_attr(session, trial.trial_id, attr_key, attr_value)

            for step, reported in template_trial.intermediate_values.items():
                self._set_trial_intermediate_value_without_commit(
                    session, trial.trial_id, step, reported)

            # All fields are in place, so the temporary RUNNING state can now be replaced.
            trial.state = template_trial.state

        trial.number = trial.count_past_trials(session)
        session.add(trial)

        return trial
Example #2
0
    def _create_new_trial(
            self,
            study_id: int,
            template_trial: Optional[FrozenTrial] = None) -> FrozenTrial:
        """Create a new trial, commit it, and return it as a :class:`~optuna.trial.FrozenTrial`.

        Args:
            study_id:
                Study id.
            template_trial:
                A :class:`~optuna.trial.FrozenTrial` with default values for trial attributes.
                When given, its value, timestamps, params, user/system attributes,
                intermediate values and final state are copied into the new trial.

        Returns:
            A :class:`~optuna.trial.FrozenTrial` instance describing the stored trial, carrying
            the trial number, start time and trial id assigned by the storage.

        Raises:
            Whatever ``models.StudyModel.find_or_raise_by_id`` raises when ``study_id`` does not
            name an existing study (exception type defined outside this method).
        """

        session = self.scoped_session()

        # Ensure that the study exists.
        models.StudyModel.find_or_raise_by_id(study_id, session)

        if template_trial is None:
            trial = models.TrialModel(study_id=study_id,
                                      number=None,
                                      state=TrialState.RUNNING)
        else:
            # Because only `RUNNING` trials can be updated,
            # we temporarily set the state of the new trial to `RUNNING`.
            # After all fields of the trial have been updated,
            # the state is set to `template_trial.state`.
            trial = models.TrialModel(
                study_id=study_id,
                number=None,
                state=TrialState.RUNNING,
                value=template_trial.value,
                datetime_start=template_trial.datetime_start,
                datetime_complete=template_trial.datetime_complete,
            )

        session.add(trial)

        # Flush the session cache to reflect the above addition operation to
        # the current RDB transaction.
        #
        # Without flushing, the following operations (e.g, `_set_trial_param_without_commit`)
        # will fail because the target trial doesn't exist in the storage yet.
        session.flush()

        if template_trial is not None:
            for param_name, param_value in template_trial.params.items():
                distribution = template_trial.distributions[param_name]
                param_value_in_internal_repr = distribution.to_internal_repr(
                    param_value)
                self._set_trial_param_without_commit(
                    session, trial.trial_id, param_name,
                    param_value_in_internal_repr, distribution)

            for key, value in template_trial.user_attrs.items():
                self._set_trial_user_attr_without_commit(
                    session, trial.trial_id, key, value)

            for key, value in template_trial.system_attrs.items():
                self._set_trial_system_attr_without_commit(
                    session, trial.trial_id, key, value)

            intermediate_values = template_trial.intermediate_values
            for step, intermediate_value in intermediate_values.items():
                self._set_trial_intermediate_value_without_commit(
                    session, trial.trial_id, step, intermediate_value)

            # Every field is written; replace the temporary RUNNING state with the real one.
            trial.state = template_trial.state

        trial.number = trial.count_past_trials(session)
        session.add(trial)

        self._commit(session)

        # Use the same explicit `None` test as above rather than relying on the
        # truthiness of a FrozenTrial object.
        if template_trial is not None:
            # The template already holds every user-supplied field; only the values
            # assigned by the storage need to be overwritten on the copy.
            frozen = copy.deepcopy(template_trial)
            frozen.number = trial.number
            frozen.datetime_start = trial.datetime_start
            frozen._trial_id = trial.trial_id
        else:
            frozen = FrozenTrial(
                number=trial.number,
                state=trial.state,
                value=None,
                datetime_start=trial.datetime_start,
                datetime_complete=None,
                params={},
                distributions={},
                user_attrs={},
                system_attrs={},
                intermediate_values={},
                trial_id=trial.trial_id,
            )

        return frozen