def set_trial_param(
    self,
    trial_id: int,
    param_name: str,
    param_value_internal: float,
    distribution: "distributions.BaseDistribution",
) -> bool:
    """Record a parameter value on a running trial owned by this worker.

    The parameter is written into the locally cached trial and an
    ``_Operation.SET_TRIAL_PARAM`` operation is enqueued for the trial's study.

    Args:
        trial_id: ID of the trial to update.
        param_name: Name of the parameter.
        param_value_internal: Value in the distribution's internal
            representation; converted with ``distribution.to_external_repr``
            before being stored.
        distribution: Distribution the parameter belongs to.

    Returns:
        ``False`` if ``param_name`` is already present in ``trial.params``
        (nothing is changed); ``True`` after the parameter has been recorded.

    Raises:
        RuntimeError: If ``trial.owner`` is not this worker's ID, or the trial
            is not in ``TrialState.RUNNING``.
    """
    trial = self.get_trial(trial_id)
    worker_id = self._worker_id()
    if trial.owner != worker_id:
        raise RuntimeError(
            "Trial {} is owned by worker {!r}, not by this worker ({!r}).".format(
                trial_id, trial.owner, worker_id
            )
        )
    if trial.state != TrialState.RUNNING:
        raise RuntimeError(
            "Trial {} is not running (state: {}).".format(trial_id, trial.state)
        )
    if param_name in trial.params:
        return False

    study_id = _id.get_study_id(trial_id)
    param_value = distribution.to_external_repr(param_value_internal)
    # Build the operation payload (including distribution serialization, which
    # can raise) BEFORE touching the cached trial, so a failure here cannot
    # leave the local cache holding a param that was never enqueued.
    data = {
        "trial_id": trial_id,
        "name": param_name,
        "value": param_value,
        "distribution": optuna.distributions.distribution_to_json(distribution),
    }

    # get_trial returns the cached object itself, so these writes update the
    # local view immediately without waiting for a sync.
    trial.params[param_name] = param_value
    trial.distributions[param_name] = distribution
    self._enqueue_op(study_id, _Operation.SET_TRIAL_PARAM, data)
    return True
def set_trial_value(self, trial_id: int, value: float) -> None:
    """Set the objective value of a running trial owned by this worker.

    Enqueues an ``_Operation.SET_TRIAL_VALUE`` operation for the trial's study
    and then calls ``self._sync`` for that study.

    Args:
        trial_id: ID of the trial to update.
        value: Objective value to record.

    Raises:
        RuntimeError: If ``trial.owner`` is not this worker's ID, or the trial
            is not in ``TrialState.RUNNING``.
    """
    trial = self.get_trial(trial_id)
    worker_id = self._worker_id()
    if trial.owner != worker_id:
        raise RuntimeError(
            "Trial {} is owned by worker {!r}, not by this worker ({!r}).".format(
                trial_id, trial.owner, worker_id
            )
        )
    if trial.state != TrialState.RUNNING:
        raise RuntimeError(
            "Trial {} is not running (state: {}).".format(trial_id, trial.state)
        )

    study_id = _id.get_study_id(trial_id)
    data = {"trial_id": trial_id, "value": value}
    self._enqueue_op(study_id, _Operation.SET_TRIAL_VALUE, data)
    # Unlike set_trial_param, the cached trial is not updated in place here, so
    # without a sync a following get_trial() would return the stale value.
    # NOTE(review): assumes _sync applies queued operations to the local
    # cache, as relied upon by the other setters — confirm.
    self._sync(study_id)
def set_trial_user_attr(self, trial_id: int, key: str, value: Any) -> None:
    """Set a user attribute on a running trial owned by this worker.

    Enqueues an ``_Operation.SET_TRIAL_USER_ATTR`` operation for the trial's
    study and then calls ``self._sync`` for that study.

    Args:
        trial_id: ID of the trial to update.
        key: Attribute name.
        value: Attribute value. It is placed as-is into the operation payload;
            presumably it must be serializable by whatever ``_enqueue_op``
            uses — confirm.

    Raises:
        RuntimeError: If ``trial.owner`` is not this worker's ID, or the trial
            is not in ``TrialState.RUNNING``.
    """
    trial = self.get_trial(trial_id)
    worker_id = self._worker_id()
    if trial.owner != worker_id:
        raise RuntimeError(
            "Trial {} is owned by worker {!r}, not by this worker ({!r}).".format(
                trial_id, trial.owner, worker_id
            )
        )
    if trial.state != TrialState.RUNNING:
        raise RuntimeError(
            "Trial {} is not running (state: {}).".format(trial_id, trial.state)
        )

    study_id = _id.get_study_id(trial_id)
    data = {"trial_id": trial_id, "key": key, "value": value}
    self._enqueue_op(study_id, _Operation.SET_TRIAL_USER_ATTR, data)
    self._sync(study_id)
def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
    """Request a state transition for a trial.

    No ownership or current-state check is made before the request: an
    ``_Operation.SET_TRIAL_STATE`` operation tagged with this worker's ID is
    enqueued, the study is synced, and the trial is then re-read.

    Args:
        trial_id: ID of the trial whose state should change.
        state: Target state. When ``state.is_finished()`` is true, a
            ``datetime_complete`` POSIX timestamp is added to the payload.

    Returns:
        ``False`` if ``state`` is ``TrialState.RUNNING`` but, after syncing,
        the trial's owner is not this worker; ``True`` in every other case.
    """
    sid = _id.get_study_id(trial_id)

    payload = {}
    payload["trial_id"] = trial_id
    payload["state"] = state.value
    payload["worker_id"] = self._worker_id()
    if state.is_finished():
        payload["datetime_complete"] = datetime.now().timestamp()

    self._enqueue_op(sid, _Operation.SET_TRIAL_STATE, payload)
    self._sync(sid)

    refreshed = self.get_trial(trial_id)
    # A RUNNING request "loses" when the synced trial ended up with another
    # owner (presumably another worker claimed it first — confirm against the
    # replay logic behind _sync).
    lost_claim = state == TrialState.RUNNING and refreshed.owner != self._worker_id()
    return not lost_claim
def set_trial_intermediate_value(
    self, trial_id: int, step: int, intermediate_value: float
) -> bool:
    """Report an intermediate value for a running trial owned by this worker.

    Enqueues an ``_Operation.SET_TRIAL_INTERMEDIATE_VALUE`` operation for the
    trial's study and then calls ``self._sync`` for that study.

    Args:
        trial_id: ID of the trial to update.
        step: Step at which the value was observed.
        intermediate_value: The reported value.

    Returns:
        Always ``True``; failures are reported by raising.

    Raises:
        RuntimeError: If ``trial.owner`` is not this worker's ID, or the trial
            is not in ``TrialState.RUNNING``.
    """
    trial = self.get_trial(trial_id)
    worker_id = self._worker_id()
    if trial.owner != worker_id:
        raise RuntimeError(
            "Trial {} is owned by worker {!r}, not by this worker ({!r}).".format(
                trial_id, trial.owner, worker_id
            )
        )
    if trial.state != TrialState.RUNNING:
        raise RuntimeError(
            "Trial {} is not running (state: {}).".format(trial_id, trial.state)
        )

    study_id = _id.get_study_id(trial_id)
    data = {"trial_id": trial_id, "value": intermediate_value, "step": step}
    self._enqueue_op(study_id, _Operation.SET_TRIAL_INTERMEDIATE_VALUE, data)
    self._sync(study_id)
    return True
def study_id(self) -> int:
    """Return the study ID derived from ``self._trial_id`` via ``_id``.

    NOTE(review): ``_trial_id`` is not assigned anywhere in the visible code;
    confirm this method belongs on this class.
    """
    own_trial_id = self._trial_id
    return _id.get_study_id(own_trial_id)
def get_study_id_from_trial_id(self, trial_id: int) -> int:
    """Return the ID of the study that ``trial_id`` belongs to.

    The result comes straight from ``_id.get_study_id``; no instance state is
    consulted.
    """
    owning_study_id = _id.get_study_id(trial_id)
    return owning_study_id
def get_trial(self, trial_id: int) -> "FrozenTrial":
    """Return the cached trial object for ``trial_id``.

    The trial is located by splitting ``trial_id`` into a study ID and a trial
    number with the ``_id`` helpers, then indexing ``self._studies``. The
    object returned is the cached instance itself, not a copy, so caller
    mutations (as done in ``set_trial_param``) change the cache. Lookup
    failures propagate from the underlying containers.
    """
    owning_study = self._studies[_id.get_study_id(trial_id)]
    trial_number = _id.get_trial_number(trial_id)
    return owning_study.trials[trial_number]