def _set_trial_system_attr_without_commit(self, session, trial_id, key, value):
    # type: (orm.Session, int, str, Any) -> None
    """Stage an insert-or-update of one trial system attribute.

    The trial is looked up through ``models.TrialModel.find_or_raise_by_id``
    and its state is passed to ``self.check_trial_is_updatable`` before any
    attribute row is touched. This method never commits ``session`` itself;
    presumably the caller commits afterwards.

    Args:
        session: Session used for the lookups and for staging a new row.
        trial_id: Id of the trial that owns the attribute.
        key: Attribute name.
        value: Attribute value; stored as its ``json.dumps`` serialization.
    """
    trial_model = models.TrialModel.find_or_raise_by_id(trial_id, session)
    self.check_trial_is_updatable(trial_id, trial_model.state)

    existing = models.TrialSystemAttributeModel.find_by_trial_and_key(
        trial_model, key, session
    )
    serialized = json.dumps(value)

    # An existing row is updated in place; nothing new has to be added.
    if existing is not None:
        existing.value_json = serialized
        return

    session.add(
        models.TrialSystemAttributeModel(
            trial_id=trial_id, key=key, value_json=serialized
        )
    )
def _update_trial(
    self,
    trial_id: int,
    state: Optional[TrialState] = None,
    value: Optional[float] = None,
    intermediate_values: Optional[Dict[int, float]] = None,
    params: Optional[Dict[str, Any]] = None,
    distributions_: Optional[Dict[str, distributions.BaseDistribution]] = None,
    user_attrs: Optional[Dict[str, Any]] = None,
    system_attrs: Optional[Dict[str, Any]] = None,
    datetime_complete: Optional[datetime] = None,
) -> bool:
    """Sync latest trial updates to a database.

    Args:
        trial_id:
            Trial id of the trial to update.
        state:
            New state. None when there are no changes.
        value:
            New value. None when there are no changes.
        intermediate_values:
            New intermediate values, keyed by step. None when there are no
            updates.
        params:
            New parameter dictionary. None when there are no updates.
            Parameters are only written when ``distributions_`` is given too.
        distributions_:
            New parameter distributions, keyed by parameter name. None when
            there are no updates.
        user_attrs:
            New user_attr. None when there are no updates.
        system_attrs:
            New system_attr. None when there are no updates.
        datetime_complete:
            Completion time of the trial. Set if and only if this method
            changes the state of the trial into one of the finished states.

    Returns:
        True when the updates were staged and handed to ``self._commit``.
        False, with nothing written, when ``state`` asks for
        ``TrialState.RUNNING`` but the stored trial is neither already in
        that state nor in ``TrialState.WAITING``.

    Raises:
        KeyError:
            If no trial with ``trial_id`` exists.
        RuntimeError:
            If the stored trial is already in a finished state.
    """
    session = self.scoped_session()

    trial_model = (
        session.query(models.TrialModel)
        .filter(models.TrialModel.trial_id == trial_id)
        .one_or_none()
    )
    if trial_model is None:
        session.rollback()
        raise KeyError(models.NOT_FOUND_MSG)
    if trial_model.state.is_finished():
        session.rollback()
        raise RuntimeError("Cannot change attributes of finished trial.")

    # ``state`` is compared against None explicitly (not by truthiness) so
    # that a member whose underlying value is falsy (e.g. an int-backed
    # member equal to 0) is not mistaken for "no change".
    if (
        state is not None
        and trial_model.state != state
        and state == TrialState.RUNNING
        and trial_model.state != TrialState.WAITING
    ):
        session.rollback()
        return False

    try:
        if state is not None:
            trial_model.state = state

        if datetime_complete is not None:
            trial_model.datetime_complete = datetime_complete

        if value is not None:
            trial_model.value = value

        if user_attrs:
            stored_user_attrs = {
                attr.key: attr
                for attr in session.query(models.TrialUserAttributeModel)
                .filter(models.TrialUserAttributeModel.trial_id == trial_id)
                .all()
            }
            for k, v in user_attrs.items():
                attr = stored_user_attrs.get(k)
                if attr is None:
                    trial_model.user_attributes.append(
                        models.TrialUserAttributeModel(key=k, value_json=json.dumps(v))
                    )
                else:
                    attr.value_json = json.dumps(v)
                    session.add(attr)

        if system_attrs:
            stored_system_attrs = {
                attr.key: attr
                for attr in session.query(models.TrialSystemAttributeModel)
                .filter(models.TrialSystemAttributeModel.trial_id == trial_id)
                .all()
            }
            for k, v in system_attrs.items():
                attr = stored_system_attrs.get(k)
                if attr is None:
                    trial_model.system_attributes.append(
                        models.TrialSystemAttributeModel(key=k, value_json=json.dumps(v))
                    )
                else:
                    attr.value_json = json.dumps(v)
                    session.add(attr)

        if intermediate_values:
            stored_values = {
                value_model.step: value_model
                for value_model in session.query(models.TrialValueModel)
                .filter(models.TrialValueModel.trial_id == trial_id)
                .all()
            }
            for step, step_value in intermediate_values.items():
                value_model = stored_values.get(step)
                if value_model is None:
                    trial_model.values.append(
                        models.TrialValueModel(step=step, value=step_value)
                    )
                else:
                    value_model.value = step_value
                    session.add(value_model)

        # Parameters are written only together with their distributions.
        if params and distributions_:
            stored_params = {
                param.param_name: param
                for param in session.query(models.TrialParamModel)
                .filter(models.TrialParamModel.trial_id == trial_id)
                .all()
            }
            for param_name, param_value in params.items():
                distribution_json = distributions.distribution_to_json(
                    distributions_[param_name]
                )
                param = stored_params.get(param_name)
                if param is None:
                    trial_model.params.append(
                        models.TrialParamModel(
                            param_name=param_name,
                            param_value=param_value,
                            distribution_json=distribution_json,
                        )
                    )
                else:
                    param.distribution_json = distribution_json
                    param.param_value = param_value
                    session.add(param)

        session.add(trial_model)
    except Exception:
        # Every other early exit above rolls the session back; do the same
        # here so a failure while staging (e.g. a value that cannot be
        # serialized, or a parameter without a distribution) does not leave
        # half-applied changes pending in the session.
        session.rollback()
        raise

    self._commit(session)

    return True