def test_create_trial(state: TrialState) -> None:
    """Check that ``create_trial`` builds a FrozenTrial carrying every given field.

    Also checks that start/complete timestamps are populated consistently with
    ``state``, and that passing both ``value`` and ``values`` is rejected.
    """
    objective = 0.2
    expected = {
        "value": objective,
        "params": {"x": 10},
        "distributions": {"x": UniformDistribution(5, 12)},
        "user_attrs": {"foo": "bar"},
        "system_attrs": {"baz": "qux"},
        "intermediate_values": {0: 0.0, 1: 0.1, 2: 0.1},
    }

    trial = create_trial(state=state, **expected)

    assert isinstance(trial, FrozenTrial)
    assert trial.state == state
    for attr_name, wanted in expected.items():
        assert getattr(trial, attr_name) == wanted
    assert trial.datetime_start is not None
    # Only finished states are expected to carry a completion timestamp.
    assert (trial.datetime_complete is not None) == state.is_finished()

    # ``value`` and ``values`` are mutually exclusive.
    with pytest.raises(ValueError):
        create_trial(state=state, value=objective, values=(objective,))
def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
    """Move ``trial_id`` to ``state`` and store the pickled trial in Redis.

    Args:
        trial_id: ID of the trial to update.
        state: Target state.

    Returns:
        ``False`` when ``state`` is RUNNING but the trial is not WAITING (nothing
        is written in that case); ``True`` otherwise.

    Raises:
        Whatever ``_check_trial_id`` / ``check_trial_is_updatable`` raise, e.g.
        ``RuntimeError`` when the trial has already finished.
    """
    self._check_trial_id(trial_id)
    current = self.get_trial(trial_id)
    self.check_trial_is_updatable(trial_id, current.state)

    # Only a WAITING trial may be claimed as RUNNING.
    if state == TrialState.RUNNING and current.state != TrialState.WAITING:
        return False

    current.state = state
    if state == TrialState.RUNNING:
        current.datetime_start = datetime.now()
    finished = state.is_finished()
    if finished:
        current.datetime_complete = datetime.now()

    self._redis.set(self._key_trial(trial_id), pickle.dumps(current))
    if not finished:
        return True

    self._update_cache(trial_id)
    # To ensure that there are no failed trials with heartbeats in the DB
    # under any circumstances, drop this trial's heartbeat entry.
    owning_study = self.get_study_id_from_trial_id(trial_id)
    self._redis.hdel(self._key_study_heartbeats(owning_study), str(trial_id))
    return True
def set_trial_state_values(
    self, trial_id: int, state: TrialState, values: Optional[Sequence[float]] = None
) -> bool:
    """Transition an in-memory trial to ``state``, optionally recording ``values``.

    The whole update runs while holding ``self._lock``.

    Args:
        trial_id: ID of the trial to update.
        state: Target state.
        values: Objective values to store on the trial; left untouched when ``None``.

    Returns:
        ``False`` when ``state`` is RUNNING but the trial is not WAITING (nothing is
        stored in that case); ``True`` otherwise.

    Raises:
        Whatever ``check_trial_is_updatable`` raises, e.g. ``RuntimeError`` for an
        already-finished trial.
    """
    with self._lock:
        # Mutate a shallow copy; the stored entry is replaced through _set_trial.
        updated = copy.copy(self._get_trial(trial_id))
        self.check_trial_is_updatable(trial_id, updated.state)

        if state == TrialState.RUNNING and updated.state != TrialState.WAITING:
            return False

        updated.state = state
        if values is not None:
            updated.values = values

        if state == TrialState.RUNNING:
            updated.datetime_start = datetime.now()

        finished = state.is_finished()
        if finished:
            updated.datetime_complete = datetime.now()

        self._set_trial(trial_id, updated)
        if finished:
            owning_study = self._trial_id_to_study_id_and_number[trial_id][0]
            self._update_cache(trial_id, owning_study)

        return True
def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
    """Transition an in-memory trial to ``state``.

    The whole update runs while holding ``self._lock``.

    Args:
        trial_id: ID of the trial to update.
        state: Target state.

    Returns:
        ``False`` when ``state`` is RUNNING but the trial is not WAITING (nothing is
        stored in that case); ``True`` otherwise.

    Raises:
        Whatever ``check_trial_is_updatable`` raises, e.g. ``RuntimeError`` for an
        already-finished trial.
    """
    with self._lock:
        # Mutate a shallow copy; the stored entry is replaced through _set_trial.
        # The copy has the same state as the stored trial, so one updatability
        # check suffices (the previous version checked twice).
        trial = copy.copy(self._get_trial(trial_id))
        self.check_trial_is_updatable(trial_id, trial.state)

        # Only a WAITING trial may be claimed as RUNNING.
        if state == TrialState.RUNNING and trial.state != TrialState.WAITING:
            return False

        trial.state = state

        if state == TrialState.RUNNING:
            trial.datetime_start = datetime.now()

        finished = state.is_finished()
        if finished:
            trial.datetime_complete = datetime.now()

        self._set_trial(trial_id, trial)
        if finished:
            study_id = self._trial_id_to_study_id_and_number[trial_id][0]
            self._update_cache(trial_id, study_id)

        return True
def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
    """Move ``trial_id`` to ``state`` and store the pickled trial in Redis.

    Args:
        trial_id: ID of the trial to update.
        state: Target state.

    Returns:
        ``False`` when ``state`` is RUNNING but the trial is not WAITING (nothing
        is written in that case); ``True`` otherwise.

    Raises:
        Whatever ``_check_trial_id`` / ``check_trial_is_updatable`` raise, e.g.
        ``RuntimeError`` when the trial has already finished.
    """
    self._check_trial_id(trial_id)
    trial = self.get_trial(trial_id)
    self.check_trial_is_updatable(trial_id, trial.state)

    # Only a WAITING trial may be claimed as RUNNING.
    if state == TrialState.RUNNING and trial.state != TrialState.WAITING:
        return False

    trial.state = state

    if state == TrialState.RUNNING:
        # Record when the trial actually started running; a WAITING trial may
        # have been created long before it is claimed.
        trial.datetime_start = datetime.now()

    finished = state.is_finished()
    if finished:
        trial.datetime_complete = datetime.now()

    self._redis.set(self._key_trial(trial_id), pickle.dumps(trial))
    if finished:
        self._update_cache(trial_id)

    return True
def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
    """Append a SET_TRIAL_STATE operation to the journal, then sync and check it.

    Args:
        trial_id: ID of the trial to update.
        state: Target state.

    Returns:
        ``False`` when ``state`` is RUNNING and, after syncing, the trial's owner is
        not this worker (i.e. another worker claimed it); ``True`` otherwise.
    """
    study_id = _id.get_study_id(trial_id)
    payload = dict(
        trial_id=trial_id,
        state=state.value,
        worker_id=self._worker_id(),
    )
    if state.is_finished():
        # Serialized as a POSIX timestamp for the journal.
        payload["datetime_complete"] = datetime.now().timestamp()

    self._enqueue_op(study_id, _Operation.SET_TRIAL_STATE, payload)
    self._sync(study_id)

    synced = self.get_trial(trial_id)
    lost_claim = state == TrialState.RUNNING and synced.owner != self._worker_id()
    return not lost_claim
def check_trial_is_updatable(self, trial_id: int, trial_state: TrialState) -> None:
    """Raise if a trial in ``trial_state`` may no longer be modified.

    Args:
        trial_id: ID of the trial. Only looked up to build the error message.
        trial_state: Current state of the trial.

    Raises:
        :exc:`RuntimeError`: If ``trial_state`` is a finished state.
    """
    if not trial_state.is_finished():
        return

    # The user-facing trial number, not the storage ID, goes into the message.
    finished_number = self.get_trial(trial_id).number
    message = "Trial#{} has already finished and can not be updated.".format(finished_number)
    raise RuntimeError(message)
def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
    """Update a trial's state inside a scoped RDB session.

    Args:
        trial_id: ID of the trial to update.
        state: Target state.

    Returns:
        ``False`` when ``state`` is RUNNING but the trial is not WAITING, or when an
        ``IntegrityError`` is raised inside the session block; ``True`` otherwise.

    Raises:
        Whatever ``find_or_raise_by_id`` / ``check_trial_is_updatable`` raise, e.g.
        ``RuntimeError`` for an already-finished trial.
    """
    try:
        with _create_scoped_session(self.scoped_session) as session:
            # Row is fetched with for_update=True.
            row = models.TrialModel.find_or_raise_by_id(trial_id, session, for_update=True)
            self.check_trial_is_updatable(trial_id, row.state)

            claiming = state == TrialState.RUNNING
            if claiming and row.state != TrialState.WAITING:
                return False

            row.state = state
            if claiming:
                row.datetime_start = datetime.now()
            if state.is_finished():
                row.datetime_complete = datetime.now()
    except IntegrityError:
        return False
    else:
        return True
def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
    """Update a trial's state using the thread's scoped RDB session.

    Args:
        trial_id: ID of the trial to update.
        state: Target state.

    Returns:
        ``False`` when ``state`` is RUNNING but the trial is not WAITING (the
        session is rolled back); otherwise the result of
        ``_commit_with_integrity_check``.

    Raises:
        KeyError: If no trial with ``trial_id`` exists.
        Whatever ``check_trial_is_updatable`` raises (e.g. ``RuntimeError`` for an
        already-finished trial); the session is rolled back first.
    """
    session = self.scoped_session()

    # The row is fetched with for_update=True, so every exit path below must
    # either roll back or commit the session.
    trial = models.TrialModel.find_by_id(trial_id, session, for_update=True)
    if trial is None:
        session.rollback()
        raise KeyError(models.NOT_FOUND_MSG)

    try:
        self.check_trial_is_updatable(trial_id, trial.state)
    except Exception:
        # Previously this path left the session's transaction open.
        session.rollback()
        raise

    if state == TrialState.RUNNING and trial.state != TrialState.WAITING:
        session.rollback()
        return False

    trial.state = state

    if state == TrialState.RUNNING:
        # Record when the trial actually started running, not when it was created.
        trial.datetime_start = datetime.now()

    if state.is_finished():
        trial.datetime_complete = datetime.now()

    return self._commit_with_integrity_check(session)
def _set_trial_state(self, data: Dict[str, Any], worker_id: str) -> None:
    """Apply a replayed SET_TRIAL_STATE journal operation to the in-memory trial.

    Args:
        data: Operation payload. Keys read here: ``"trial_id"``, ``"state"``,
            ``"worker_id"`` and, for finished states, ``"datetime_complete"``
            (passed to ``datetime.fromtimestamp``).
        worker_id: ID of the worker replaying the log. Conflicting operations raise
            only when ``data["worker_id"]`` equals this ID; conflicting operations
            issued by other workers are silently ignored.

    Raises:
        RuntimeError: If this worker issued a RUNNING transition for a trial whose
            owner is not this worker, or issued a state change for a trial that
            has already finished.
    """
    number = _id.get_trial_number(data["trial_id"])
    trial = self.trials[number]
    state = TrialState(data["state"])
    if state == TrialState.RUNNING:
        # A RUNNING transition is honoured only when the op's worker already
        # equals trial.owner. NOTE(review): owner is assigned further below only
        # after a RUNNING transition succeeds; confirm how a WAITING trial's
        # owner gets initialized, otherwise this check may reject every claim.
        # The error text ("from the owner") also reads inverted - verify intent.
        if trial.owner != data["worker_id"]:
            if data["worker_id"] == worker_id:
                raise RuntimeError(
                    "Trial {} cannot be modified from the owner.".format(
                        number))
            else:
                return
        # Only a WAITING trial can start running; otherwise drop the op. This
        # also drops RUNNING ops for finished trials before the check below.
        if self.trials[number].state != TrialState.WAITING:
            return
    if trial.state.is_finished():
        # Finished trials are immutable: error for our own ops, ignore others'.
        if data["worker_id"] == worker_id:
            raise RuntimeError(
                "Trial {} has already been finished.".format(number))
        else:
            return
    trial.state = state
    if state.is_finished():
        # fromtimestamp yields a naive local-time datetime.
        trial.datetime_complete = datetime.fromtimestamp(
            data["datetime_complete"])
        # A finished trial no longer has an owning worker.
        trial.owner = None
    if state == TrialState.RUNNING:
        # Same object as ``trial``; the issuing worker becomes the owner.
        self.trials[number].owner = data["worker_id"]
    if state == TrialState.COMPLETE:
        # Track the best completed trial by ``value`` and study direction.
        # best_trial keeps a reference to the same trial object, not a copy.
        # NOTE(review): assumes trial.value was set before COMPLETE is replayed.
        if (self.best_trial is None or
                (self.direction == optuna.study.StudyDirection.MINIMIZE and
                 trial.value < self.best_trial.value) or
                (self.direction == optuna.study.StudyDirection.MAXIMIZE and
                 trial.value > self.best_trial.value)):
            self.best_trial = trial