Ejemplo n.º 1
0
    def _set_trial_param_without_commit(self, session, trial_id, param_name,
                                        param_value_internal, distribution):
        # type: (orm.Session, int, str, float, distributions.BaseDistribution) -> bool
        """Stage a new parameter row for a trial on ``session``.

        Returns ``True`` when a new ``TrialParamModel`` was added to the
        session (the caller is left to commit it). Returns ``False`` when the
        trial already has a parameter called ``param_name``; in that case the
        given distribution is first checked against the stored one and the
        session is committed before returning.
        """

        owner = models.TrialModel.find_or_raise_by_id(trial_id, session)
        self.check_trial_is_updatable(trial_id, owner.state)

        existing = models.TrialParamModel.find_by_trial_and_param_name(
            owner, param_name, session)

        if existing is None:
            # First time this parameter is seen for the trial: stage a new row.
            new_param = models.TrialParamModel(
                trial_id=trial_id,
                param_name=param_name,
                param_value=param_value_internal,
                distribution_json=distributions.distribution_to_json(distribution))
            new_param.check_and_add(session)
            return True

        # The parameter is already stored; the compatibility check raises if
        # the stored distribution and the requested one do not agree.
        stored_distribution = distributions.json_to_distribution(
            existing.distribution_json)
        distributions.check_distribution_compatibility(
            stored_distribution, distribution)

        # Terminate the transaction explicitly to avoid a connection timeout
        # while it is still open.
        self._commit(session)
        return False
Ejemplo n.º 2
0
    def set_trial_param(self, trial_id, param_name, param_value_internal,
                        distribution):
        # type: (int, str, float, distributions.BaseDistribution) -> bool
        """Store a parameter for a trial and commit it.

        Returns ``False`` without writing anything when the trial already has
        a parameter called ``param_name`` (after checking that the stored
        distribution is compatible with ``distribution``). Otherwise a new
        ``TrialParamModel`` is added and the result of
        ``_commit_with_integrity_check`` is returned.
        """

        session = self.scoped_session()

        owner = models.TrialModel.find_or_raise_by_id(trial_id, session)
        existing = models.TrialParamModel.find_by_trial_and_param_name(
            owner, param_name, session)

        if existing is None:
            new_param = models.TrialParamModel(
                trial_id=trial_id,
                param_name=param_name,
                param_value=param_value_internal,
                distribution_json=distributions.distribution_to_json(distribution))
            new_param.check_and_add(session)
            return self._commit_with_integrity_check(session)

        # The parameter is already stored; the compatibility check raises if
        # the stored distribution and the requested one do not agree.
        stored_distribution = distributions.json_to_distribution(
            existing.distribution_json)
        distributions.check_distribution_compatibility(
            stored_distribution, distribution)

        # Compatible, but the parameter has already been set.
        return False
Ejemplo n.º 3
0
    def _set_trial_param_without_commit(self, session, trial_id, param_name,
                                        param_value_internal, distribution):
        # type: (orm.Session, int, str, float, distributions.BaseDistribution) -> None
        """Insert or overwrite a trial parameter on ``session`` without committing.

        When the trial has no parameter called ``param_name`` a new
        ``TrialParamModel`` is added to the session. When one already exists,
        the given distribution is checked against the stored one and the
        stored value and distribution are then overwritten in place.
        """

        owner = models.TrialModel.find_or_raise_by_id(trial_id, session)
        self.check_trial_is_updatable(trial_id, owner.state)

        existing = models.TrialParamModel.find_by_trial_and_param_name(
            owner, param_name, session)

        if existing is None:
            # Nothing stored yet for this name: stage a brand-new row.
            new_param = models.TrialParamModel(
                trial_id=trial_id,
                param_name=param_name,
                param_value=param_value_internal,
                distribution_json=distributions.distribution_to_json(distribution),
            )
            new_param.check_and_add(session)
            return

        # The compatibility check raises if the stored distribution and the
        # requested one do not agree.
        stored_distribution = distributions.json_to_distribution(
            existing.distribution_json)
        distributions.check_distribution_compatibility(
            stored_distribution, distribution)

        # Compatible: overwrite the stored value and distribution.
        existing.param_value = param_value_internal
        existing.distribution_json = distributions.distribution_to_json(
            distribution)
Ejemplo n.º 4
0
    def _check_or_set_param_distribution(
        self,
        trial_id: int,
        param_name: str,
        param_value_internal: float,
        distribution: distributions.BaseDistribution,
    ) -> None:
        """Check a parameter's distribution against its study, or record it.

        The trial row is fetched with ``for_update=True``. If any trial of the
        same study already has a parameter called ``param_name``, the given
        distribution is checked for compatibility with that stored one and
        nothing is inserted. Otherwise a new ``TrialParamModel`` is added for
        this trial. The transaction is committed on success.

        Args:
            trial_id:
                Id of the trial the parameter belongs to.
            param_name:
                Name of the parameter.
            param_value_internal:
                Internal representation of the parameter value, stored only
                when no previous record exists.
            distribution:
                Distribution the parameter was sampled from.

        Raises:
            KeyError:
                If no trial with ``trial_id`` exists.

        Any exception raised by the compatibility check, the queries or the
        commit is propagated unchanged after the transaction is rolled back.
        """

        session = self.scoped_session()

        try:
            # Acquire a lock of this trial.
            trial = models.TrialModel.find_by_id(trial_id,
                                                 session,
                                                 for_update=True)
            if trial is None:
                raise KeyError(models.NOT_FOUND_MSG)

            # Look for the same parameter name on any trial of this study.
            previous_record = (session.query(models.TrialParamModel).join(
                models.TrialModel).filter(
                    models.TrialModel.study_id == trial.study_id).filter(
                        models.TrialParamModel.param_name == param_name).first())
            if previous_record is not None:
                distributions.check_distribution_compatibility(
                    distributions.json_to_distribution(
                        previous_record.distribution_json),
                    distribution,
                )
            else:
                session.add(
                    models.TrialParamModel(
                        trial_id=trial_id,
                        param_name=param_name,
                        param_value=param_value_internal,
                        distribution_json=distributions.distribution_to_json(
                            distribution),
                    ))

            # Release lock.
            session.commit()
        except Exception:
            # End the transaction on every failure path so the row lock taken
            # above is not left held while the error propagates.
            session.rollback()
            raise
Ejemplo n.º 5
0
    def _update_trial(
        self,
        trial_id: int,
        state: Optional[TrialState] = None,
        value: Optional[float] = None,
        intermediate_values: Optional[Dict[int, float]] = None,
        params: Optional[Dict[str, Any]] = None,
        distributions_: Optional[Dict[str,
                                      distributions.BaseDistribution]] = None,
        user_attrs: Optional[Dict[str, Any]] = None,
        system_attrs: Optional[Dict[str, Any]] = None,
        datetime_complete: Optional[datetime] = None,
    ) -> bool:
        """Sync latest trial updates to a database.

        Args:
            trial_id:
                Trial id of the trial to update.
            state:
                New state. None when there are no changes.
            value:
                New value. None when there are no changes.
            intermediate_values:
                New intermediate values. None when there are no updates.
                Steps that are already stored are overwritten with the new
                value; unseen steps are inserted.
            params:
                New parameter dictionary. None when there are no updates.
                Only parameters whose name is not stored yet are inserted;
                stored parameters are left untouched.
            distributions_:
                New parameter distributions. None when there are no updates.
                Must contain an entry for every newly inserted parameter.
            user_attrs:
                New user_attr. None when there are no updates.
            system_attrs:
                New system_attr. None when there are no updates.
            datetime_complete:
                Completion time of the trial. Set if and only if this method
                changes the state of the trial into one of the finished states.

        Returns:
            True when the updates were committed. False, with the transaction
            rolled back, when ``state`` is ``TrialState.RUNNING`` but the
            stored trial is in a different state that is not
            ``TrialState.WAITING``.

        Raises:
            KeyError:
                If no trial with ``trial_id`` exists.
            RuntimeError:
                If the stored trial is already in a finished state.

        """

        session = self.scoped_session()
        # Fetch the trial row with a FOR UPDATE clause.
        trial_model = (session.query(models.TrialModel).filter(
            models.TrialModel.trial_id ==
            trial_id).with_for_update().one_or_none())
        if trial_model is None:
            session.rollback()
            raise KeyError(models.NOT_FOUND_MSG)
        if trial_model.state.is_finished():
            session.rollback()
            raise RuntimeError("Cannot change attributes of finished trial.")
        if (state and trial_model.state != state
                and state == TrialState.RUNNING
                and trial_model.state != TrialState.WAITING):
            # Reject the transition to RUNNING without writing anything.
            session.rollback()
            return False

        if state:
            trial_model.state = state

        if datetime_complete:
            trial_model.datetime_complete = datetime_complete

        if value is not None:
            trial_model.value = value

        if user_attrs:
            trial_user_attrs = (session.query(
                models.TrialUserAttributeModel).filter(
                    models.TrialUserAttributeModel.trial_id ==
                    trial_id).with_for_update().all())
            trial_user_attrs_dict = {
                attr.key: attr
                for attr in trial_user_attrs
            }
            # Overwrite attributes that already exist ...
            for k, v in user_attrs.items():
                if k in trial_user_attrs_dict:
                    trial_user_attrs_dict[k].value_json = json.dumps(v)
                    session.add(trial_user_attrs_dict[k])
            # ... and insert the ones that do not.
            trial_model.user_attributes.extend(
                models.TrialUserAttributeModel(key=k, value_json=json.dumps(v))
                for k, v in user_attrs.items()
                if k not in trial_user_attrs_dict)
        if system_attrs:
            trial_system_attrs = (session.query(
                models.TrialSystemAttributeModel).filter(
                    models.TrialSystemAttributeModel.trial_id ==
                    trial_id).with_for_update().all())
            trial_system_attrs_dict = {
                attr.key: attr
                for attr in trial_system_attrs
            }
            for k, v in system_attrs.items():
                if k in trial_system_attrs_dict:
                    trial_system_attrs_dict[k].value_json = json.dumps(v)
                    session.add(trial_system_attrs_dict[k])
            trial_model.system_attributes.extend(
                models.TrialSystemAttributeModel(key=k,
                                                 value_json=json.dumps(v))
                for k, v in system_attrs.items()
                if k not in trial_system_attrs_dict)
        if intermediate_values:
            value_models = (session.query(models.TrialValueModel).filter(
                models.TrialValueModel.trial_id ==
                trial_id).with_for_update().all())
            value_dict = {
                value_model.step: value_model
                for value_model in value_models
            }
            # Walk the NEW values (not the stored rows) so that a step which
            # is already stored gets its value overwritten.
            for s, v in intermediate_values.items():
                if s in value_dict:
                    value_dict[s].value = v
                    session.add(value_dict[s])
            trial_model.values.extend(
                models.TrialValueModel(step=s, value=v)
                for s, v in intermediate_values.items() if s not in value_dict)
        if params and distributions_:
            trial_param = (session.query(models.TrialParamModel).filter(
                models.TrialParamModel.trial_id == trial_id).all())
            param_keys = {param.param_name for param in trial_param}
            # Only parameters that are not stored yet are inserted.
            trial_model.params.extend(
                models.TrialParamModel(
                    param_name=param_name,
                    param_value=param_value,
                    distribution_json=distributions.distribution_to_json(
                        distributions_[param_name]),
                ) for param_name, param_value in params.items()
                if param_name not in param_keys)
        session.add(trial_model)
        self._commit(session)

        return True