예제 #1
0
    def _get_registered_model(cls, session, name, eager=False):
        """
        Fetch the single ``SqlRegisteredModel`` row whose name equals ``name``.

        :param session: SQLAlchemy session used to run the lookup query.
        :param name: Registered model name; validated before any query is issued.
        :param eager: If ``True``, eagerly loads the registered model's tags.
                      If ``False``, these attributes are not eagerly loaded and
                      will be loaded when their corresponding object properties
                      are accessed from the resulting ``SqlRegisteredModel`` object.
        :return: The matching ``SqlRegisteredModel`` instance.
        :raises MlflowException: ``RESOURCE_DOES_NOT_EXIST`` when no row matches, or
                                 ``INVALID_STATE`` when more than one row matches.
        """
        _validate_model_name(name)
        loader_options = []
        if eager:
            loader_options = cls._get_eager_registered_model_query_options()
        base_query = session.query(SqlRegisteredModel).options(*loader_options)
        matches = base_query.filter(SqlRegisteredModel.name == name).all()

        match_count = len(matches)
        if match_count == 0:
            raise MlflowException(
                "Registered Model with name={} not found".format(name), RESOURCE_DOES_NOT_EXIST
            )
        if match_count > 1:
            # Names are expected to be unique; more than one hit means corrupted state.
            raise MlflowException(
                "Expected only 1 registered model with name={}. "
                "Found {}.".format(name, match_count),
                INVALID_STATE,
            )
        return matches[0]
예제 #2
0
    def rename_registered_model(self, name, new_name):
        """
        Rename the registered model.

        The new name is also written to every model version of the model, and the
        ``last_updated_time`` of the model and of each of its versions is set to the
        same timestamp.

        :param name: Registered model name.
        :param new_name: New proposed name.
        :return: A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
        :raises MlflowException: ``RESOURCE_DOES_NOT_EXIST`` if ``name`` is not registered;
                                 ``RESOURCE_ALREADY_EXISTS`` if flushing the rename hits an
                                 integrity error (e.g. ``new_name`` is already taken). The
                                 original ``IntegrityError`` is chained as the cause.
        """
        _validate_model_name(new_name)
        with self.ManagedSessionMaker() as session:
            sql_registered_model = self._get_registered_model(session, name)
            try:
                updated_time = now()
                sql_registered_model.name = new_name
                # Model versions are looked up by their own ``name`` column, so it must
                # be kept in sync with the parent model's new name.
                for sql_model_version in sql_registered_model.model_versions:
                    sql_model_version.name = new_name
                    sql_model_version.last_updated_time = updated_time
                sql_registered_model.last_updated_time = updated_time
                self._save_to_db(session, [sql_registered_model] +
                                 sql_registered_model.model_versions)
                # Flush now so a uniqueness violation surfaces inside this try block.
                session.flush()
                return sql_registered_model.to_mlflow_entity()
            except sqlalchemy.exc.IntegrityError as e:
                raise MlflowException(
                    "Registered Model (name={}) already exists. "
                    "Error: {}".format(new_name, str(e)),
                    RESOURCE_ALREADY_EXISTS,
                ) from e
예제 #3
0
    def create_registered_model(self, name, tags=None):
        """
        Create a new registered model in backend store.

        :param name: Name of the new model. This is expected to be unique in the backend store.
        :param tags: A list of :py:class:`mlflow.entities.model_registry.RegisteredModelTag`
                     instances associated with this registered model. If several tags
                     share a key, the last one wins.
        :return: A single object of :py:class:`mlflow.entities.model_registry.RegisteredModel`
                 created in the backend.
        :raises MlflowException: ``RESOURCE_ALREADY_EXISTS`` if flushing the new row hits an
                                 integrity error; the ``IntegrityError`` is chained as the cause.
        """
        _validate_model_name(name)
        for tag in tags or []:
            _validate_registered_model_tag(tag.key, tag.value)
        with self.ManagedSessionMaker() as session:
            try:
                creation_time = now()
                registered_model = SqlRegisteredModel(
                    name=name,
                    creation_time=creation_time,
                    last_updated_time=creation_time)
                # Collapse duplicate keys; the last value given for a key wins.
                tags_dict = {tag.key: tag.value for tag in tags or []}
                registered_model.registered_model_tags = [
                    SqlRegisteredModelTag(key=key, value=value)
                    for key, value in tags_dict.items()
                ]
                self._save_to_db(session, registered_model)
                # Flush now so a duplicate-name violation surfaces inside this try block.
                session.flush()
                return registered_model.to_mlflow_entity()
            except sqlalchemy.exc.IntegrityError as e:
                raise MlflowException(
                    'Registered Model (name={}) already exists. '
                    'Error: {}'.format(name, str(e)), RESOURCE_ALREADY_EXISTS) from e
    def _get_sql_model_version(cls, session, name, version, eager=False):
        """
        Fetch the single ``SqlModelVersion`` row for ``(name, version)``.

        Rows whose ``current_stage`` equals ``STAGE_DELETED_INTERNAL`` are excluded.

        :param session: SQLAlchemy session used to run the lookup query.
        :param name: Registered model name; validated before querying.
        :param version: Model version; validated before querying.
        :param eager: If ``True``, eagerly loads the model version's tags.
                      If ``False``, these attributes are not eagerly loaded and
                      will be loaded when their corresponding object properties
                      are accessed from the resulting ``SqlModelVersion`` object.
        :return: The matching ``SqlModelVersion`` instance.
        :raises MlflowException: ``RESOURCE_DOES_NOT_EXIST`` when no row matches, or
                                 ``INVALID_STATE`` when more than one row matches.
        """
        _validate_model_name(name)
        _validate_model_version(version)
        if eager:
            loader_options = cls._get_eager_model_version_query_options()
        else:
            loader_options = []
        filters = [
            SqlModelVersion.name == name,
            SqlModelVersion.version == version,
            SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL,
        ]
        found = (
            session.query(SqlModelVersion)
            .options(*loader_options)
            .filter(*filters)
            .all()
        )

        if not found:
            raise MlflowException(
                'Model Version (name={}, version={}) '
                'not found'.format(name, version), RESOURCE_DOES_NOT_EXIST)
        if len(found) > 1:
            raise MlflowException(
                'Expected only 1 model version with (name={}, version={}). '
                'Found {}.'.format(name, version, len(found)),
                INVALID_STATE)
        return found[0]
    def create_model_version(self, name, source, run_id, tags=None):
        """
        Create a new model version from given source and run ID.

        The new version number is one past the highest version currently attached to
        the registered model (1 if it has none). If flushing the insert raises
        ``IntegrityError`` (presumably another writer claimed the same version number),
        the session is rolled back and the whole read-compute-insert sequence is
        retried, up to ``CREATE_MODEL_VERSION_RETRIES`` attempts in total.

        :param name: Registered model name.
        :param source: Source path where the MLflow model is stored.
        :param run_id: Run ID from MLflow tracking server that generated the model.
        :param tags: A list of :py:class:`mlflow.entities.model_registry.ModelVersionTag`
                     instances associated with this model version. If several tags
                     share a key, the last one wins.
        :return: A single object of :py:class:`mlflow.entities.model_registry.ModelVersion`
                 created in the backend.
        :raises MlflowException: If every attempt fails with an ``IntegrityError``; the
                                 last such error is chained as the cause.
        """
        def next_version(sql_registered_model):
            # One past the highest existing version number, or 1 when there are none.
            return max(
                (mv.version for mv in sql_registered_model.model_versions),
                default=0) + 1

        _validate_model_name(name)
        for tag in tags or []:
            _validate_model_version_tag(tag.key, tag.value)
        # Loop-invariant: collapse duplicate keys once; the last value for a key wins.
        tags_dict = {tag.key: tag.value for tag in tags or []}
        with self.ManagedSessionMaker() as session:
            creation_time = now()
            last_error = None
            for attempt in range(self.CREATE_MODEL_VERSION_RETRIES):
                try:
                    sql_registered_model = self._get_registered_model(
                        session, name)
                    sql_registered_model.last_updated_time = creation_time
                    version = next_version(sql_registered_model)
                    model_version = SqlModelVersion(
                        name=name,
                        version=version,
                        creation_time=creation_time,
                        last_updated_time=creation_time,
                        source=source,
                        run_id=run_id)
                    # Build fresh tag rows on every attempt.
                    model_version.model_version_tags = [
                        SqlModelVersionTag(key=key, value=value)
                        for key, value in tags_dict.items()
                    ]
                    self._save_to_db(session,
                                     [sql_registered_model, model_version])
                    session.flush()
                    return model_version.to_mlflow_entity()
                except sqlalchemy.exc.IntegrityError as e:
                    last_error = e
                    # A failed flush leaves the session unusable until it is rolled
                    # back; without this, the next attempt's query fails with a
                    # different error instead of retrying.
                    session.rollback()
                    more_retries = self.CREATE_MODEL_VERSION_RETRIES - attempt - 1
                    _logger.info(
                        'Model Version creation error (name=%s) Retrying %s more time%s.',
                        name, str(more_retries),
                        's' if more_retries > 1 else '')
        raise MlflowException(
            'Model Version creation error (name={}). Giving up after '
            '{} attempts.'.format(name, self.CREATE_MODEL_VERSION_RETRIES)) from last_error
예제 #6
0
    def set_registered_model_tag(self, name, tag):
        """
        Attach ``tag`` to the registered model called ``name``.

        The tag row is written with ``session.merge``. An existing row with the same
        identity is therefore updated rather than duplicated (identity presumably
        (name, key) — confirm against the ``SqlRegisteredModelTag`` mapping).

        :param name: Registered model name.
        :param tag: :py:class:`mlflow.entities.model_registry.RegisteredModelTag` instance to log.
        :return: None
        :raises MlflowException: If no registered model named ``name`` exists.
        """
        _validate_model_name(name)
        _validate_registered_model_tag(tag.key, tag.value)
        with self.ManagedSessionMaker() as session:
            # Existence check only: raises when the registered model is missing.
            self._get_registered_model(session, name)
            tag_row = SqlRegisteredModelTag(name=name, key=tag.key, value=tag.value)
            session.merge(tag_row)
예제 #7
0
    def delete_registered_model_tag(self, name, key):
        """
        Remove the tag ``key`` from the registered model called ``name``, if it is set.

        :param name: Registered model name.
        :param key: Registered model tag key.
        :return: None
        :raises MlflowException: If no registered model named ``name`` exists.
        """
        _validate_model_name(name)
        _validate_tag_name(key)
        with self.ManagedSessionMaker() as session:
            # Existence check only: raises when the registered model is missing.
            self._get_registered_model(session, name)
            tag_row = self._get_registered_model_tag(session, name, key)
            if tag_row is None:
                # Deleting a tag that is not set is a silent no-op.
                return
            session.delete(tag_row)
예제 #8
0
 def _get_sql_model_version(cls, session, name, version, eager=False):
     """
     Look up the ``SqlModelVersion`` for ``(name, version)``, ignoring rows whose
     ``current_stage`` is ``STAGE_DELETED_INTERNAL``.

     The query itself is delegated to ``cls._get_model_version_from_db``. Its
     not-found / duplicate handling lives there (not visible here).

     :param session: SQLAlchemy session used for the lookup.
     :param name: Registered model name; validated before querying.
     :param version: Model version; validated before querying.
     :param eager: If ``True``, eagerly loads the model version's tags.
                   If ``False``, these attributes are not eagerly loaded and
                   will be loaded when their corresponding object properties
                   are accessed from the resulting ``SqlModelVersion`` object.
     :return: Whatever ``_get_model_version_from_db`` returns for these filters.
     """
     _validate_model_name(name)
     _validate_model_version(version)
     if eager:
         loader_options = cls._get_eager_model_version_query_options()
     else:
         loader_options = []
     filters = [
         SqlModelVersion.name == name,
         SqlModelVersion.version == version,
         SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL,
     ]
     return cls._get_model_version_from_db(session, name, version, filters, loader_options)
예제 #9
0
    def delete_model_version_tag(self, name, version, key):
        """
        Remove the tag ``key`` from version ``version`` of model ``name``, if it is set.

        :param name: Registered model name.
        :param version: Registered model version.
        :param key: Tag key.
        :return: None
        :raises MlflowException: If the model version cannot be found.
        """
        _validate_model_name(name)
        _validate_model_version(version)
        _validate_tag_name(key)
        with self.ManagedSessionMaker() as session:
            # Existence check only: raises when the model version is missing.
            self._get_sql_model_version(session, name, version)
            tag_row = self._get_model_version_tag(session, name, version, key)
            if tag_row is None:
                # Deleting a tag that is not set is a silent no-op.
                return
            session.delete(tag_row)
예제 #10
0
    def set_model_version_tag(self, name, version, tag):
        """
        Attach ``tag`` to version ``version`` of the registered model ``name``.

        The tag row is written with ``session.merge``. An existing row with the same
        identity is therefore updated rather than duplicated (identity presumably
        (name, version, key) — confirm against the ``SqlModelVersionTag`` mapping).

        :param name: Registered model name.
        :param version: Registered model version.
        :param tag: :py:class:`mlflow.entities.model_registry.ModelVersionTag` instance to log.
        :return: None
        :raises MlflowException: If the model version cannot be found.
        """
        _validate_model_name(name)
        _validate_model_version(version)
        _validate_model_version_tag(tag.key, tag.value)
        with self.ManagedSessionMaker() as session:
            # Existence check only: raises when the model version is missing.
            self._get_sql_model_version(session, name, version)
            tag_row = SqlModelVersionTag(name=name, version=version, key=tag.key, value=tag.value)
            session.merge(tag_row)