Example #1
0
def test_create_trial(state: TrialState) -> None:
    """``create_trial`` returns a FrozenTrial mirroring every supplied field.

    Also checks that supplying both ``value`` and ``values`` is rejected with
    ``ValueError``.
    """
    fields = {
        "value": 0.2,
        "params": {"x": 10},
        "distributions": {"x": UniformDistribution(5, 12)},
        "user_attrs": {"foo": "bar"},
        "system_attrs": {"baz": "qux"},
        "intermediate_values": {0: 0.0, 1: 0.1, 2: 0.1},
    }

    trial = create_trial(state=state, **fields)

    assert isinstance(trial, FrozenTrial)
    assert trial.state == state
    # Each keyword passed to create_trial must round-trip to the same attribute.
    for attr_name, expected in fields.items():
        assert getattr(trial, attr_name) == expected, attr_name
    assert trial.datetime_start is not None
    # A completion timestamp is expected exactly when the state is finished.
    assert (trial.datetime_complete is not None) == state.is_finished()

    with pytest.raises(ValueError):
        create_trial(state=state, value=fields["value"], values=(fields["value"],))
Example #2
0
    def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
        """Move a trial to ``state`` and persist the pickled trial to Redis.

        Args:
            trial_id: ID of the trial to update.
            state: Target state.

        Returns:
            ``False`` (and nothing is written) when ``state`` is RUNNING but the
            trial is not currently WAITING; ``True`` otherwise.

        Raises:
            RuntimeError: If the trial has already finished (raised by
                ``check_trial_is_updatable``).
        """
        self._check_trial_id(trial_id)
        current = self.get_trial(trial_id)
        self.check_trial_is_updatable(trial_id, current.state)

        starting = state == TrialState.RUNNING
        if starting and current.state != TrialState.WAITING:
            return False

        current.state = state
        if starting:
            current.datetime_start = datetime.now()

        finished = state.is_finished()
        if finished:
            current.datetime_complete = datetime.now()

        self._redis.set(self._key_trial(trial_id), pickle.dumps(current))
        if not finished:
            return True

        self._update_cache(trial_id)

        # Drop this trial's heartbeat entry so no finished trial keeps a
        # heartbeat in the DB under any circumstances.
        owning_study = self.get_study_id_from_trial_id(trial_id)
        self._redis.hdel(self._key_study_heartbeats(owning_study), str(trial_id))
        return True
Example #3
0
    def set_trial_state_values(
        self,
        trial_id: int,
        state: TrialState,
        values: Optional[Sequence[float]] = None,
    ) -> bool:
        """Atomically set a trial's state and, optionally, its values.

        The stored trial is shallow-copied, modified, and written back through
        ``_set_trial`` while ``self._lock`` is held.

        Args:
            trial_id: ID of the trial to update.
            state: Target state.
            values: New objective values; left untouched when ``None``.

        Returns:
            ``False`` (nothing written) when ``state`` is RUNNING but the trial
            is not WAITING; ``True`` otherwise.

        Raises:
            RuntimeError: If the trial has already finished (raised by
                ``check_trial_is_updatable``).
        """
        with self._lock:
            updated = copy.copy(self._get_trial(trial_id))
            self.check_trial_is_updatable(trial_id, updated.state)

            starting = state == TrialState.RUNNING
            if starting and updated.state != TrialState.WAITING:
                return False

            updated.state = state
            if values is not None:
                updated.values = values
            if starting:
                updated.datetime_start = datetime.now()

            finished = state.is_finished()
            if finished:
                updated.datetime_complete = datetime.now()

            self._set_trial(trial_id, updated)

            if finished:
                owning_study_id = self._trial_id_to_study_id_and_number[trial_id][0]
                self._update_cache(trial_id, owning_study_id)

            return True
Example #4
0
    def __to_frozen_trial(self, trial):
        """Build a ``FrozenTrial`` from a stored trial record.

        Args:
            trial: Mapping with the keys ``"number"``, ``"state"``,
                ``"datetime_start"``, ``"datetime_complete"``, ``"params"``,
                ``"distributions"`` (serialized, keyed by param name),
                ``"user_attrs"``, ``"system_attrs"``,
                ``"intermediate_values"``, ``"trial_id"`` and ``"value"``.

        Returns:
            The equivalent ``FrozenTrial``. Intermediate-value step keys are
            converted to ``int``.
        """
        # Only the distributions of recorded params are decoded.
        distributions = {
            name: json_to_distribution(trial["distributions"][name])
            for name in trial["params"]
        }

        # Step keys are normalized to int. The record presumably stores them as
        # strings (e.g. a JSON/document store limitation) — confirm against the
        # writer. The previous loop computed int(k) but then stored the value
        # under the unconverted key.
        raw_intermediate = trial["intermediate_values"]
        intermediate_values = {int(step): raw_intermediate[step] for step in raw_intermediate}

        return FrozenTrial(
            number=trial["number"],
            state=TrialState(trial["state"]),
            datetime_start=trial["datetime_start"],
            datetime_complete=trial["datetime_complete"],
            params=trial["params"],
            distributions=distributions,
            user_attrs=trial["user_attrs"],
            system_attrs=trial["system_attrs"],
            intermediate_values=intermediate_values,
            trial_id=trial["trial_id"],
            value=trial["value"],
        )
Example #5
0
    def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
        """Move a trial to ``state`` while holding ``self._lock``.

        The stored trial is shallow-copied, modified, and written back through
        ``_set_trial``. Finished trials also refresh the study cache.

        Args:
            trial_id: ID of the trial to update.
            state: Target state.

        Returns:
            ``False`` (nothing written) when ``state`` is RUNNING but the trial
            is not WAITING; ``True`` otherwise.

        Raises:
            RuntimeError: If the trial has already finished (raised by
                ``check_trial_is_updatable``).
        """
        with self._lock:
            # Copy before mutating so the stored object is only replaced via
            # _set_trial. A single updatability check suffices: the copy carries
            # the same state as the stored trial.
            trial = copy.copy(self._get_trial(trial_id))
            self.check_trial_is_updatable(trial_id, trial.state)

            starting = state == TrialState.RUNNING
            if starting and trial.state != TrialState.WAITING:
                return False

            trial.state = state
            if starting:
                trial.datetime_start = datetime.now()

            finished = state.is_finished()
            if finished:
                trial.datetime_complete = datetime.now()

            self._set_trial(trial_id, trial)

            if finished:
                study_id = self._trial_id_to_study_id_and_number[trial_id][0]
                self._update_cache(trial_id, study_id)

            return True
Example #6
0
    def _create_trial(self, data: Dict[str, Any], worker_id: str) -> None:
        """Apply a trial-creation operation by appending a new ``_Trial``.

        Args:
            data:
                Operation payload. ``"datetime_start"`` and ``"worker_id"`` are
                required. ``"state"`` (defaults to RUNNING), ``"value"``,
                ``"datetime_complete"``, ``"params"``, ``"distributions"``,
                ``"user_attrs"``, ``"system_attrs"`` and
                ``"intermediate_values"`` are optional. Timestamps are POSIX
                timestamps. ``data`` itself is not modified.
            worker_id:
                ID of the local worker. When it matches ``data["worker_id"]``,
                the new trial ID is recorded in ``last_created_trial_ids``.
        """
        number = len(self.trials)
        trial_id = _id.make_trial_id(self.study_id, number)

        # Convert into a local rather than writing back into ``data``. The caller
        # still holds that dict, and converting an already-converted value a
        # second time would make ``datetime.fromtimestamp`` raise.
        datetime_complete = None
        if "datetime_complete" in data:
            datetime_complete = datetime.fromtimestamp(data["datetime_complete"])

        state = TrialState(data.get("state", TrialState.RUNNING.value))

        # Only a trial created directly in RUNNING state gets an owner.
        owner = data["worker_id"] if state == TrialState.RUNNING else None

        trial = _Trial(
            trial_id=trial_id,
            number=number,
            state=state,
            value=data.get("value"),
            datetime_start=datetime.fromtimestamp(data["datetime_start"]),
            datetime_complete=datetime_complete,
            params=data.get("params", {}),
            distributions=data.get("distributions", {}),
            user_attrs=data.get("user_attrs", {}),
            system_attrs=data.get("system_attrs", {}),
            intermediate_values=data.get("intermediate_values", {}),
            owner=owner,
        )
        self.trials.append(trial)

        if data["worker_id"] == worker_id:
            self.last_created_trial_ids[worker_id] = trial_id
Example #7
0
    def _set_trial_state(self, data: Dict[str, Any], worker_id: str) -> None:
        """Apply a SET_TRIAL_STATE operation to the in-memory ``self.trials`` list.

        Args:
            data:
                Operation payload. Reads ``"trial_id"``, ``"state"`` and
                ``"worker_id"``. For finished states it also reads
                ``"datetime_complete"`` (a POSIX timestamp).
            worker_id:
                ID of the local worker. Conflicts caused by the local worker
                raise; the same conflicts caused by other workers are ignored
                (the operation is dropped).

        Raises:
            RuntimeError: If the local worker requests RUNNING on a trial whose
                owner differs from it, or updates an already-finished trial.
        """
        number = _id.get_trial_number(data["trial_id"])
        trial = self.trials[number]

        state = TrialState(data["state"])
        if state == TrialState.RUNNING:
            # NOTE(review): if WAITING trials have no owner (owner is None), this
            # check rejects every WAITING -> RUNNING transition, and the owner
            # assignment further below can never run. Confirm the intended
            # semantics; perhaps only a non-None owner should be compared.
            if trial.owner != data["worker_id"]:
                if data["worker_id"] == worker_id:
                    raise RuntimeError(
                        "Trial {} cannot be modified from the owner.".format(
                            number))
                else:
                    return

            # Only a WAITING trial may be promoted to RUNNING; otherwise the
            # operation is silently dropped. (self.trials[number] is `trial`.)
            if self.trials[number].state != TrialState.WAITING:
                return

        # Finished trials are immutable: error for the local worker, no-op for
        # operations coming from other workers.
        if trial.state.is_finished():
            if data["worker_id"] == worker_id:
                raise RuntimeError(
                    "Trial {} has already been finished.".format(number))
            else:
                return

        trial.state = state
        if state.is_finished():
            trial.datetime_complete = datetime.fromtimestamp(
                data["datetime_complete"])
            # A finished trial no longer belongs to any worker.
            trial.owner = None

        if state == TrialState.RUNNING:
            self.trials[number].owner = data["worker_id"]

        # Track the best COMPLETE trial according to the study direction.
        # NOTE(review): assumes trial.value was set before the COMPLETE
        # transition; a None value would make the `<`/`>` comparison raise
        # TypeError.
        if state == TrialState.COMPLETE:
            if (self.best_trial is None
                    or (self.direction == optuna.study.StudyDirection.MINIMIZE
                        and trial.value < self.best_trial.value)
                    or (self.direction == optuna.study.StudyDirection.MAXIMIZE
                        and trial.value > self.best_trial.value)):
                self.best_trial = trial
Example #8
0
    def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
        """Move a trial to ``state`` and store the pickled trial in Redis.

        Args:
            trial_id: ID of the trial to update.
            state: Target state.

        Returns:
            ``False`` (nothing written) when ``state`` is RUNNING but the trial
            is not WAITING; ``True`` otherwise.

        Raises:
            RuntimeError: If the trial has already finished (raised by
                ``check_trial_is_updatable``).
        """
        self._check_trial_id(trial_id)
        current = self.get_trial(trial_id)
        self.check_trial_is_updatable(trial_id, current.state)

        if state == TrialState.RUNNING and current.state != TrialState.WAITING:
            return False

        current.state = state
        finished = state.is_finished()
        if finished:
            current.datetime_complete = datetime.now()

        # The trial is written in both cases; only finished trials refresh the cache.
        self._redis.set(self._key_trial(trial_id), pickle.dumps(current))
        if finished:
            self._update_cache(trial_id)

        return True
Example #9
0
    def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
        """Enqueue a SET_TRIAL_STATE operation, sync, and report the outcome.

        Args:
            trial_id: ID of the trial to update.
            state: Target state.

        Returns:
            ``False`` when RUNNING was requested but, after syncing, the trial's
            owner is not this worker; ``True`` otherwise.
        """
        study_id = _id.get_study_id(trial_id)

        op_payload = {
            "trial_id": trial_id,
            "state": state.value,
            "worker_id": self._worker_id(),
        }
        if state.is_finished():
            # The completion time is shipped with the operation as a POSIX timestamp.
            op_payload["datetime_complete"] = datetime.now().timestamp()

        self._enqueue_op(study_id, _Operation.SET_TRIAL_STATE, op_payload)
        self._sync(study_id)

        # Presumably syncing applies pending operations from all workers, so the
        # owner read here tells whether this worker's RUNNING request won.
        synced = self.get_trial(trial_id)
        owned_elsewhere = state == TrialState.RUNNING and synced.owner != self._worker_id()
        return not owned_elsewhere
Example #10
0
    def check_trial_is_updatable(self, trial_id: int, trial_state: TrialState) -> None:
        """Raise if a trial in ``trial_state`` may no longer be modified.

        Args:
            trial_id:
                ID of the trial. Used only to look up the trial number for the
                error message.
            trial_state:
                State of the trial to check.

        Raises:
            :exc:`RuntimeError`:
                If ``trial_state`` is a finished state.
        """
        if not trial_state.is_finished():
            return

        trial_number = self.get_trial(trial_id).number
        raise RuntimeError(
            f"Trial#{trial_number} has already finished and can not be updated."
        )
Example #11
0
    def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
        """Move a trial to ``state`` inside a scoped database session.

        The trial row is fetched with ``for_update=True``, and the changes are
        left to ``_create_scoped_session`` to finalize.

        Args:
            trial_id: ID of the trial to update.
            state: Target state.

        Returns:
            ``False`` when RUNNING was requested for a non-WAITING trial or an
            ``IntegrityError`` occurred; ``True`` otherwise.

        Raises:
            RuntimeError: If the trial has already finished (raised by
                ``check_trial_is_updatable``).
        """
        try:
            with _create_scoped_session(self.scoped_session) as session:
                row = models.TrialModel.find_or_raise_by_id(trial_id, session, for_update=True)
                self.check_trial_is_updatable(trial_id, row.state)

                starting = state == TrialState.RUNNING
                if starting and row.state != TrialState.WAITING:
                    return False

                row.state = state
                if starting:
                    row.datetime_start = datetime.now()
                if state.is_finished():
                    row.datetime_complete = datetime.now()
        except IntegrityError:
            return False
        else:
            return True
Example #12
0
    def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
        """Move a trial to ``state`` using the storage's scoped session.

        The trial row is fetched with ``for_update=True``. Every early exit
        (missing trial, finished trial, rejected RUNNING transition) rolls the
        session back before leaving.

        Args:
            trial_id: ID of the trial to update.
            state: Target state.

        Returns:
            ``False`` when RUNNING was requested for a non-WAITING trial;
            otherwise the result of ``_commit_with_integrity_check``.

        Raises:
            KeyError: If no trial with ``trial_id`` exists.
            RuntimeError: If the trial has already finished (raised by
                ``check_trial_is_updatable``).
        """
        session = self.scoped_session()

        trial = models.TrialModel.find_by_id(trial_id, session, for_update=True)
        if trial is None:
            session.rollback()
            raise KeyError(models.NOT_FOUND_MSG)

        try:
            self.check_trial_is_updatable(trial_id, trial.state)
        except Exception:
            # End the transaction opened by the locking read (and, presumably,
            # release its FOR UPDATE row lock) before propagating, as the other
            # early exits do. The exception is re-raised unchanged.
            session.rollback()
            raise

        if state == TrialState.RUNNING and trial.state != TrialState.WAITING:
            session.rollback()
            return False

        trial.state = state
        if state.is_finished():
            trial.datetime_complete = datetime.now()

        return self._commit_with_integrity_check(session)