Ejemplo n.º 1
0
    def _set_trial_intermediate_value_without_commit(self, session, trial_id,
                                                     step, intermediate_value):
        # type: (orm.Session, int, int, float) -> None
        """Create or overwrite the intermediate value of a trial at ``step``.

        The trial is looked up with ``find_or_raise_by_id`` and passed through
        ``self.check_trial_is_updatable`` before anything is written. This
        method only stages the change on ``session``; it does not commit.

        Args:
            session: Session used for the lookups and for staging the row.
            trial_id: Id of the trial that owns the intermediate value.
            step: Step the intermediate value belongs to.
            intermediate_value: Value to store for ``step``.
        """
        owner = models.TrialModel.find_or_raise_by_id(trial_id, session)
        self.check_trial_is_updatable(trial_id, owner.state)

        existing = models.TrialValueModel.find_by_trial_and_step(
            owner, step, session)

        # A row already exists for this step: overwrite it in place.
        if existing is not None:
            existing.value = intermediate_value
            return

        # No row yet for this step: stage a new one.
        session.add(
            models.TrialValueModel(trial_id=trial_id,
                                   step=step,
                                   value=intermediate_value))
Ejemplo n.º 2
0
    def set_trial_intermediate_value(self, trial_id, step, intermediate_value):
        # type: (int, int, float) -> bool
        """Store an intermediate value for ``step`` unless one already exists.

        Args:
            trial_id: Id of the trial that owns the intermediate value.
            step: Step the intermediate value belongs to.
            intermediate_value: Value to store for ``step``.

        Returns:
            ``False`` when a value is already recorded for ``step`` (nothing
            is written); otherwise the result of
            ``self._commit_with_integrity_check``.
        """
        session = self.scoped_session()

        owner = models.TrialModel.find_or_raise_by_id(trial_id, session)
        existing = models.TrialValueModel.find_by_trial_and_step(
            owner, step, session)

        if existing is None:
            # First value reported for this step: insert and commit.
            new_row = models.TrialValueModel(trial_id=trial_id,
                                             step=step,
                                             value=intermediate_value)
            session.add(new_row)
            return self._commit_with_integrity_check(session)

        # A value for this step already exists; leave it untouched.
        return False
Ejemplo n.º 3
0
    def _update_trial(
        self,
        trial_id: int,
        state: Optional[TrialState] = None,
        value: Optional[float] = None,
        intermediate_values: Optional[Dict[int, float]] = None,
        params: Optional[Dict[str, Any]] = None,
        distributions_: Optional[Dict[str,
                                      distributions.BaseDistribution]] = None,
        user_attrs: Optional[Dict[str, Any]] = None,
        system_attrs: Optional[Dict[str, Any]] = None,
        datetime_complete: Optional[datetime] = None,
    ) -> bool:
        """Sync latest trial updates to a database.

        The trial row is selected with ``with_for_update()`` and all changes
        are committed together at the end of the method.

        Args:
            trial_id:
                Trial id of the trial to update.
            state:
                New state. None when there are no changes.
            value:
                New value. None when there are no changes.
            intermediate_values:
                New intermediate values keyed by step. Steps that already
                have a stored value are overwritten; other steps are
                inserted. None when there are no updates.
            params:
                New parameter dictionary. None when there are no updates.
                Parameters are only written when ``distributions_`` is given
                as well, and only names not yet stored are inserted.
            distributions_:
                New parameter distributions. None when there are no updates.
            user_attrs:
                New user_attr. None when there are no updates.
            system_attrs:
                New system_attr. None when there are no updates.
            datetime_complete:
                Completion time of the trial. Set if and only if this method
                changes the state of trial into one of the finished states.

        Returns:
            True when the update was committed. False when the requested
            state is ``TrialState.RUNNING`` but the stored trial is neither
            already running nor ``TrialState.WAITING``; nothing is written
            in that case.

        Raises:
            KeyError:
                If no trial with ``trial_id`` exists.
            RuntimeError:
                If the stored trial is already in a finished state.

        """

        session = self.scoped_session()
        trial_model = (session.query(models.TrialModel).filter(
            models.TrialModel.trial_id ==
            trial_id).with_for_update().one_or_none())
        if trial_model is None:
            session.rollback()
            raise KeyError(models.NOT_FOUND_MSG)
        if trial_model.state.is_finished():
            session.rollback()
            raise RuntimeError("Cannot change attributes of finished trial.")
        # A trial may only be switched to RUNNING from WAITING; reject the
        # request otherwise so the caller can tell it did not get the trial.
        if (state and trial_model.state != state
                and state == TrialState.RUNNING
                and trial_model.state != TrialState.WAITING):
            session.rollback()
            return False

        if state:
            trial_model.state = state

        if datetime_complete:
            trial_model.datetime_complete = datetime_complete

        if value is not None:
            trial_model.value = value

        if user_attrs:
            trial_user_attrs = (session.query(
                models.TrialUserAttributeModel).filter(
                    models.TrialUserAttributeModel.trial_id ==
                    trial_id).with_for_update().all())
            trial_user_attrs_dict = {
                attr.key: attr
                for attr in trial_user_attrs
            }
            # Overwrite attributes that already exist ...
            for k, v in user_attrs.items():
                if k in trial_user_attrs_dict:
                    trial_user_attrs_dict[k].value_json = json.dumps(v)
                    session.add(trial_user_attrs_dict[k])
            # ... and append the ones that are new.
            trial_model.user_attributes.extend(
                models.TrialUserAttributeModel(key=k, value_json=json.dumps(v))
                for k, v in user_attrs.items()
                if k not in trial_user_attrs_dict)
        if system_attrs:
            trial_system_attrs = (session.query(
                models.TrialSystemAttributeModel).filter(
                    models.TrialSystemAttributeModel.trial_id ==
                    trial_id).with_for_update().all())
            trial_system_attrs_dict = {
                attr.key: attr
                for attr in trial_system_attrs
            }
            for k, v in system_attrs.items():
                if k in trial_system_attrs_dict:
                    trial_system_attrs_dict[k].value_json = json.dumps(v)
                    session.add(trial_system_attrs_dict[k])
            trial_model.system_attributes.extend(
                models.TrialSystemAttributeModel(key=k,
                                                 value_json=json.dumps(v))
                for k, v in system_attrs.items()
                if k not in trial_system_attrs_dict)
        if intermediate_values:
            value_models = (session.query(models.TrialValueModel).filter(
                models.TrialValueModel.trial_id ==
                trial_id).with_for_update().all())
            value_dict = {
                value_model.step: value_model
                for value_model in value_models
            }
            # Iterate the *incoming* values and update the stored row's
            # ``value`` column for steps that already exist.
            for s, v in intermediate_values.items():
                if s in value_dict:
                    value_dict[s].value = v
                    session.add(value_dict[s])
            trial_model.values.extend(
                models.TrialValueModel(step=s, value=v)
                for s, v in intermediate_values.items() if s not in value_dict)
        if params and distributions_:
            trial_param = (session.query(models.TrialParamModel).filter(
                models.TrialParamModel.trial_id == trial_id).all())
            param_keys = {param.param_name for param in trial_param}
            # Only parameters not stored yet are inserted; existing ones are
            # left as they are.
            trial_model.params.extend(
                models.TrialParamModel(
                    param_name=param_name,
                    param_value=param_value,
                    distribution_json=distributions.distribution_to_json(
                        distributions_[param_name]),
                ) for param_name, param_value in params.items()
                if param_name not in param_keys)
        session.add(trial_model)
        self._commit(session)

        return True