Esempio n. 1
0
def test_create_trial(state: TrialState) -> None:
    """Check that ``create_trial`` returns a ``FrozenTrial`` echoing its inputs.

    Args:
        state:
            Trial state to build the trial with (presumably supplied by a
            parametrized fixture that is not visible in this snippet).
    """
    objective = 0.2
    # Keyword arguments handed to ``create_trial``; each one must be readable
    # back from the resulting trial under the attribute of the same name.
    expected_fields = {
        "value": objective,
        "params": {"x": 10},
        "distributions": {"x": UniformDistribution(5, 12)},
        "user_attrs": {"foo": "bar"},
        "system_attrs": {"baz": "qux"},
        "intermediate_values": {0: 0.0, 1: 0.1, 2: 0.1},
    }

    trial = create_trial(state=state, **expected_fields)

    assert isinstance(trial, FrozenTrial)
    assert trial.state == state
    for field_name, expected in expected_fields.items():
        assert getattr(trial, field_name) == expected

    # A start time is always present; a completion time only for finished states.
    assert trial.datetime_start is not None
    assert (trial.datetime_complete is not None) == state.is_finished()

    # Passing both ``value`` and ``values`` must be rejected.
    with pytest.raises(ValueError):
        create_trial(state=state, value=objective, values=(objective,))
Esempio n. 2
0
    def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
        """Assign ``state`` to a trial and persist the pickled trial in Redis.

        Args:
            trial_id:
                ID of the trial to update.
            state:
                State to assign.

        Returns:
            :obj:`False` if ``state`` is ``RUNNING`` while the stored trial is
            not ``WAITING`` (nothing is written in that case), :obj:`True`
            otherwise.

        Exceptions raised by ``_check_trial_id`` and
        ``check_trial_is_updatable`` propagate to the caller.
        """
        self._check_trial_id(trial_id)
        stored = self.get_trial(trial_id)
        self.check_trial_is_updatable(trial_id, stored.state)

        # Only a WAITING trial may be switched to RUNNING.
        if state == TrialState.RUNNING and stored.state != TrialState.WAITING:
            return False

        stored.state = state
        if state == TrialState.RUNNING:
            stored.datetime_start = datetime.now()

        finished = state.is_finished()
        if finished:
            stored.datetime_complete = datetime.now()

        # Timestamps are stamped above so that they are part of the pickle.
        self._redis.set(self._key_trial(trial_id), pickle.dumps(stored))
        if not finished:
            return True

        self._update_cache(trial_id)

        # To ensure that there are no failed trials with heartbeats in the DB
        # under any circumstances
        owning_study = self.get_study_id_from_trial_id(trial_id)
        self._redis.hdel(self._key_study_heartbeats(owning_study), str(trial_id))
        return True
Esempio n. 3
0
    def set_trial_state_values(
        self, trial_id: int, state: TrialState, values: Optional[Sequence[float]] = None
    ) -> bool:
        """Assign ``state`` (and optionally ``values``) to a trial under the lock.

        The stored trial is shallow-copied, the copy is modified, and the copy
        is then written back with ``_set_trial``.

        Args:
            trial_id:
                ID of the trial to update.
            state:
                State to assign.
            values:
                Objective values to store; left untouched when :obj:`None`.

        Returns:
            :obj:`False` if ``state`` is ``RUNNING`` while the trial is not
            ``WAITING`` (nothing is written), :obj:`True` otherwise.
        """
        with self._lock:
            updated = copy.copy(self._get_trial(trial_id))
            self.check_trial_is_updatable(trial_id, updated.state)

            # Only a WAITING trial may be switched to RUNNING.
            if state == TrialState.RUNNING and updated.state != TrialState.WAITING:
                return False

            updated.state = state
            if values is not None:
                updated.values = values
            if state == TrialState.RUNNING:
                updated.datetime_start = datetime.now()

            finished = state.is_finished()
            if finished:
                updated.datetime_complete = datetime.now()

            self._set_trial(trial_id, updated)

            if finished:
                # First element of the mapping entry is the owning study ID.
                owning_study = self._trial_id_to_study_id_and_number[trial_id][0]
                self._update_cache(trial_id, owning_study)

            return True
Esempio n. 4
0
    def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
        """Assign ``state`` to a trial while holding the storage lock.

        The stored trial is shallow-copied, the copy is modified, and the copy
        is then written back with ``_set_trial``.

        Args:
            trial_id:
                ID of the trial to update.
            state:
                State to assign.

        Returns:
            :obj:`False` if ``state`` is ``RUNNING`` while the trial is not
            ``WAITING`` (nothing is written), :obj:`True` otherwise.

        Exceptions raised by ``check_trial_is_updatable`` propagate to the
        caller.
        """
        with self._lock:
            # A shallow copy has the same ``state`` as the stored trial, so a
            # single updatability check is sufficient.
            trial = copy.copy(self._get_trial(trial_id))
            self.check_trial_is_updatable(trial_id, trial.state)

            # Only a WAITING trial may be switched to RUNNING.
            if state == TrialState.RUNNING and trial.state != TrialState.WAITING:
                return False

            trial.state = state

            if state == TrialState.RUNNING:
                trial.datetime_start = datetime.now()

            finished = state.is_finished()
            if finished:
                trial.datetime_complete = datetime.now()

            self._set_trial(trial_id, trial)

            if finished:
                # First element of the mapping entry is the owning study ID.
                study_id = self._trial_id_to_study_id_and_number[trial_id][0]
                self._update_cache(trial_id, study_id)

            return True
Esempio n. 5
0
    def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
        """Assign ``state`` to a trial and persist the pickled trial in Redis.

        Args:
            trial_id:
                ID of the trial to update.
            state:
                State to assign.

        Returns:
            :obj:`False` if ``state`` is ``RUNNING`` while the stored trial is
            not ``WAITING`` (nothing is written), :obj:`True` otherwise.

        Exceptions raised by ``_check_trial_id`` and
        ``check_trial_is_updatable`` propagate to the caller.
        """
        self._check_trial_id(trial_id)
        stored = self.get_trial(trial_id)
        self.check_trial_is_updatable(trial_id, stored.state)

        # Only a WAITING trial may be switched to RUNNING.
        if state == TrialState.RUNNING and stored.state != TrialState.WAITING:
            return False

        stored.state = state
        finished = state.is_finished()
        if finished:
            stored.datetime_complete = datetime.now()

        # The completion time is stamped above so that it is part of the pickle.
        self._redis.set(self._key_trial(trial_id), pickle.dumps(stored))
        if finished:
            self._update_cache(trial_id)

        return True
Esempio n. 6
0
    def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
        """Enqueue a state-change operation for a trial and sync the study.

        Args:
            trial_id:
                ID of the trial to update.
            state:
                State to assign.

        Returns:
            :obj:`False` if ``state`` is ``RUNNING`` and, after syncing, the
            trial's owner is not this worker; :obj:`True` otherwise.
        """
        study_id = _id.get_study_id(trial_id)

        payload = {"trial_id": trial_id, "state": state.value, "worker_id": self._worker_id()}
        if state.is_finished():
            # The completion time travels with the operation as a POSIX timestamp.
            payload["datetime_complete"] = datetime.now().timestamp()

        self._enqueue_op(study_id, _Operation.SET_TRIAL_STATE, payload)
        self._sync(study_id)

        synced_trial = self.get_trial(trial_id)
        # A RUNNING request counts as successful only if this worker owns the trial.
        lost_claim = state == TrialState.RUNNING and synced_trial.owner != self._worker_id()
        return not lost_claim
Esempio n. 7
0
    def check_trial_is_updatable(self, trial_id: int, trial_state: TrialState) -> None:
        """Reject updates to a trial whose state is already finished.

        Args:
            trial_id:
                ID of the trial.
                Used only to look up the trial number for the error message.
            trial_state:
                Trial state to check.

        Raises:
            :exc:`RuntimeError`:
                If ``trial_state`` is a finished state.
        """
        if not trial_state.is_finished():
            return

        finished_trial = self.get_trial(trial_id)
        message = "Trial#{} has already finished and can not be updated.".format(
            finished_trial.number
        )
        raise RuntimeError(message)
Esempio n. 8
0
    def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
        """Assign ``state`` to a trial row inside a scoped database session.

        Args:
            trial_id:
                ID of the trial to update.
            state:
                State to assign.

        Returns:
            :obj:`False` if ``state`` is ``RUNNING`` while the row is not
            ``WAITING``, or if an ``IntegrityError`` is raised while the
            session block runs or exits; :obj:`True` otherwise.

        Exceptions raised by ``find_or_raise_by_id`` and
        ``check_trial_is_updatable`` propagate to the caller.
        """
        try:
            with _create_scoped_session(self.scoped_session) as session:
                # ``for_update=True`` presumably takes a row lock; verify in the model.
                record = models.TrialModel.find_or_raise_by_id(trial_id, session, for_update=True)
                self.check_trial_is_updatable(trial_id, record.state)

                # Only a WAITING trial may be switched to RUNNING.
                if state == TrialState.RUNNING and record.state != TrialState.WAITING:
                    return False

                record.state = state

                if state == TrialState.RUNNING:
                    record.datetime_start = datetime.now()

                if state.is_finished():
                    record.datetime_complete = datetime.now()
        except IntegrityError:
            return False
        else:
            return True
Esempio n. 9
0
    def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
        """Assign ``state`` to a trial row and commit the change.

        Args:
            trial_id:
                ID of the trial to update.
            state:
                State to assign.

        Returns:
            :obj:`False` if ``state`` is ``RUNNING`` while the row is not
            ``WAITING`` (the session is rolled back); otherwise the result of
            ``_commit_with_integrity_check``.

        Raises:
            :exc:`KeyError`:
                If no trial row with ``trial_id`` exists (the session is
                rolled back first).
        """
        session = self.scoped_session()

        record = models.TrialModel.find_by_id(trial_id, session, for_update=True)
        if record is None:
            session.rollback()
            raise KeyError(models.NOT_FOUND_MSG)

        self.check_trial_is_updatable(trial_id, record.state)

        # Only a WAITING trial may be switched to RUNNING.
        rejected = state == TrialState.RUNNING and record.state != TrialState.WAITING
        if rejected:
            session.rollback()
            return False

        record.state = state
        if state.is_finished():
            record.datetime_complete = datetime.now()

        committed = self._commit_with_integrity_check(session)
        return committed
Esempio n. 10
0
    def _set_trial_state(self, data: Dict[str, Any], worker_id: str) -> None:
        """Apply a recorded state-change operation to the in-memory trial.

        Args:
            data:
                Operation payload. Reads ``"trial_id"`` and ``"state"``, plus
                ``"worker_id"`` and ``"datetime_complete"`` on the branches
                that need them.
            worker_id:
                ID compared against ``data["worker_id"]``: a rejected
                operation raises when the two are equal and is silently
                skipped otherwise.

        Raises:
            :exc:`RuntimeError`:
                If the operation is rejected and ``data["worker_id"]`` equals
                ``worker_id``.
        """
        trial_number = _id.get_trial_number(data["trial_id"])
        target = self.trials[trial_number]
        new_state = TrialState(data["state"])

        if new_state == TrialState.RUNNING:
            if target.owner != data["worker_id"]:
                if data["worker_id"] != worker_id:
                    return
                raise RuntimeError(
                    "Trial {} cannot be modified from the owner.".format(trial_number)
                )

            # Only a WAITING trial may be switched to RUNNING.
            if self.trials[trial_number].state != TrialState.WAITING:
                return

        if target.state.is_finished():
            if data["worker_id"] != worker_id:
                return
            raise RuntimeError("Trial {} has already been finished.".format(trial_number))

        target.state = new_state
        if new_state.is_finished():
            target.datetime_complete = datetime.fromtimestamp(data["datetime_complete"])
            # A finished trial no longer has an owner.
            target.owner = None

        if new_state == TrialState.RUNNING:
            self.trials[trial_number].owner = data["worker_id"]

        if new_state != TrialState.COMPLETE:
            return

        # Promote the trial to best trial when none exists yet or when it
        # beats the current one in the study's direction.
        if self.best_trial is None:
            beats_best = True
        elif (
            self.direction == optuna.study.StudyDirection.MINIMIZE
            and target.value < self.best_trial.value
        ):
            beats_best = True
        elif (
            self.direction == optuna.study.StudyDirection.MAXIMIZE
            and target.value > self.best_trial.value
        ):
            beats_best = True
        else:
            beats_best = False

        if beats_best:
            self.best_trial = target