def _get_registered_model(cls, session, name, eager=False):
    """
    Look up the single ``SqlRegisteredModel`` row whose name equals ``name``.

    :param session: SQLAlchemy session used to run the lookup.
    :param name: Registered model name; validated before any query is issued.
    :param eager: When ``True``, the query options returned by
                  ``_get_eager_registered_model_query_options`` are applied so the
                  registered model's tags are loaded eagerly. When ``False``, no
                  extra options are applied and those attributes are loaded when
                  the corresponding properties of the returned
                  ``SqlRegisteredModel`` are accessed.
    :return: The matching ``SqlRegisteredModel`` instance.
    :raises MlflowException: With ``RESOURCE_DOES_NOT_EXIST`` when no row matches,
                             or ``INVALID_STATE`` when more than one row matches.
    """
    _validate_model_name(name)
    options = []
    if eager:
        options = cls._get_eager_registered_model_query_options()
    query = session.query(SqlRegisteredModel).options(*options)
    matches = query.filter(SqlRegisteredModel.name == name).all()
    match_count = len(matches)
    if not match_count:
        raise MlflowException(
            "Registered Model with name={} not found".format(name),
            RESOURCE_DOES_NOT_EXIST,
        )
    if match_count > 1:
        # Names are expected to be unique; several rows means the store is corrupt.
        raise MlflowException(
            "Expected only 1 registered model with name={}. Found {}.".format(
                name, match_count
            ),
            INVALID_STATE,
        )
    return matches[0]
def rename_registered_model(self, name, new_name):
    """
    Give an existing registered model (and all of its versions) a new name.

    :param name: Current registered model name.
    :param new_name: Proposed new name; validated before the session is opened.
    :return: The renamed model as a single
             :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    :raises MlflowException: With ``RESOURCE_ALREADY_EXISTS`` when persisting the
                             rename violates a database integrity constraint
                             (e.g. ``new_name`` is already taken).
    """
    _validate_model_name(new_name)
    with self.ManagedSessionMaker() as session:
        registered_model = self._get_registered_model(session, name)
        try:
            stamp = now()
            registered_model.name = new_name
            # Each version row carries the model name, so every one must follow.
            for version_row in registered_model.model_versions:
                version_row.name = new_name
                version_row.last_updated_time = stamp
            registered_model.last_updated_time = stamp
            rows_to_save = [registered_model] + registered_model.model_versions
            self._save_to_db(session, rows_to_save)
            # Flush inside the try so constraint violations surface here.
            session.flush()
            return registered_model.to_mlflow_entity()
        except sqlalchemy.exc.IntegrityError as err:
            message = "Registered Model (name={}) already exists. Error: {}".format(
                new_name, str(err)
            )
            raise MlflowException(message, RESOURCE_ALREADY_EXISTS)
def create_registered_model(self, name, tags=None):
    """
    Insert a new registered model, optionally with tags, into the backend store.

    :param name: Name of the new model. Expected to be unique in the backend store.
    :param tags: Optional list of
                 :py:class:`mlflow.entities.model_registry.RegisteredModelTag`
                 instances. When several tags share a key, the last value wins.
    :return: The created model as a single
             :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    :raises MlflowException: With ``RESOURCE_ALREADY_EXISTS`` when flushing the new
                             row violates a database integrity constraint.
    """
    _validate_model_name(name)
    for candidate in tags or []:
        _validate_registered_model_tag(candidate.key, candidate.value)
    with self.ManagedSessionMaker() as session:
        try:
            created_at = now()
            new_model = SqlRegisteredModel(
                name=name, creation_time=created_at, last_updated_time=created_at
            )
            # De-duplicate by key, keeping the value of the last occurrence.
            latest_by_key = {candidate.key: candidate.value for candidate in tags or []}
            new_model.registered_model_tags = [
                SqlRegisteredModelTag(key=tag_key, value=tag_value)
                for tag_key, tag_value in latest_by_key.items()
            ]
            self._save_to_db(session, new_model)
            session.flush()
            return new_model.to_mlflow_entity()
        except sqlalchemy.exc.IntegrityError as err:
            message = "Registered Model (name={}) already exists. Error: {}".format(
                name, str(err)
            )
            raise MlflowException(message, RESOURCE_ALREADY_EXISTS)
def _get_sql_model_version(cls, session, name, version, eager=False):
    """
    Look up the single non-deleted ``SqlModelVersion`` row for ``(name, version)``.

    :param session: SQLAlchemy session used to run the lookup.
    :param name: Registered model name; validated before querying.
    :param version: Model version; validated before querying.
    :param eager: When ``True``, the options returned by
                  ``_get_eager_model_version_query_options`` are applied so the
                  model version's tags are loaded eagerly. When ``False``, they are
                  loaded when the corresponding properties of the returned
                  ``SqlModelVersion`` are accessed.
    :return: The matching ``SqlModelVersion`` instance.
    :raises MlflowException: With ``RESOURCE_DOES_NOT_EXIST`` when no row matches,
                             or ``INVALID_STATE`` when more than one row matches.
    """
    # NOTE(review): another definition of ``_get_sql_model_version`` appears later in
    # this file; inside one class body the later definition wins — confirm which
    # implementation is intended.
    _validate_model_name(name)
    _validate_model_version(version)
    if eager:
        options = cls._get_eager_model_version_query_options()
    else:
        options = []
    # Rows soft-deleted via the internal stage marker are treated as absent.
    filters = [
        SqlModelVersion.name == name,
        SqlModelVersion.version == version,
        SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL,
    ]
    query = session.query(SqlModelVersion).options(*options)
    found = query.filter(*filters).all()
    found_count = len(found)
    if not found_count:
        raise MlflowException(
            "Model Version (name={}, version={}) not found".format(name, version),
            RESOURCE_DOES_NOT_EXIST,
        )
    if found_count > 1:
        raise MlflowException(
            "Expected only 1 model version with (name={}, version={}). "
            "Found {}.".format(name, version, found_count),
            INVALID_STATE,
        )
    return found[0]
def create_model_version(self, name, source, run_id, tags=None):
    """
    Create a new model version from given source and run ID.

    The new version number is one greater than the highest existing version of the
    registered model (or 1 for a model with no versions). If persisting the new row
    hits a database integrity error, the attempt is retried up to
    ``self.CREATE_MODEL_VERSION_RETRIES`` times in total.

    :param name: Registered model name.
    :param source: Source path where the MLflow model is stored.
    :param run_id: Run ID from MLflow tracking server that generated the model.
    :param tags: Optional list of
                 :py:class:`mlflow.entities.model_registry.ModelVersionTag`
                 instances associated with this model version. When several tags
                 share a key, the last value wins.
    :return: A single object of
             :py:class:`mlflow.entities.model_registry.ModelVersion` created in the
             backend.
    :raises MlflowException: When every attempt fails with an integrity error.
    """

    def next_version(sql_registered_model):
        # One past the highest existing version number; a model with no versions
        # starts at 1.
        existing = [mv.version for mv in sql_registered_model.model_versions]
        return max(existing) + 1 if existing else 1

    _validate_model_name(name)
    for tag in tags or []:
        _validate_model_version_tag(tag.key, tag.value)
    with self.ManagedSessionMaker() as session:
        creation_time = now()
        for attempt in range(self.CREATE_MODEL_VERSION_RETRIES):
            try:
                sql_registered_model = self._get_registered_model(session, name)
                sql_registered_model.last_updated_time = creation_time
                version = next_version(sql_registered_model)
                model_version = SqlModelVersion(
                    name=name,
                    version=version,
                    creation_time=creation_time,
                    last_updated_time=creation_time,
                    source=source,
                    run_id=run_id,
                )
                tags_dict = {tag.key: tag.value for tag in tags or []}
                model_version.model_version_tags = [
                    SqlModelVersionTag(key=key, value=value)
                    for key, value in tags_dict.items()
                ]
                self._save_to_db(session, [sql_registered_model, model_version])
                session.flush()
                return model_version.to_mlflow_entity()
            except sqlalchemy.exc.IntegrityError:
                # A failed flush leaves the session's transaction inactive: until it
                # is rolled back, any further query on this session raises instead
                # of retrying. Rolling back also discards the pending version row
                # and expires the cached registered model, so the next attempt
                # re-reads the current versions and picks a fresh version number.
                session.rollback()
                more_retries = self.CREATE_MODEL_VERSION_RETRIES - attempt - 1
                _logger.info(
                    'Model Version creation error (name=%s) Retrying %s more time%s.',
                    name,
                    str(more_retries),
                    's' if more_retries > 1 else '',
                )
    raise MlflowException(
        'Model Version creation error (name={}). Giving up after '
        '{} attempts.'.format(name, self.CREATE_MODEL_VERSION_RETRIES)
    )
def set_registered_model_tag(self, name, tag):
    """
    Attach a tag to an existing registered model.

    :param name: Registered model name.
    :param tag: :py:class:`mlflow.entities.model_registry.RegisteredModelTag`
                instance whose ``key`` and ``value`` are stored.
    :return: None
    :raises MlflowException: Propagated from validation, or when the registered
                             model does not exist.
    """
    _validate_model_name(name)
    _validate_registered_model_tag(tag.key, tag.value)
    with self.ManagedSessionMaker() as session:
        # Only called for its side effect: it raises if the model is missing.
        self._get_registered_model(session, name)
        # ``merge`` reconciles with any already-stored row for the same primary
        # key (presumably (name, key) — confirm against the model definition).
        tag_row = SqlRegisteredModelTag(name=name, key=tag.key, value=tag.value)
        session.merge(tag_row)
def delete_registered_model_tag(self, name, key):
    """
    Remove the tag with ``key`` from an existing registered model, if present.

    :param name: Registered model name.
    :param key: Registered model tag key.
    :return: None
    :raises MlflowException: Propagated from validation, or when the registered
                             model does not exist.
    """
    _validate_model_name(name)
    _validate_tag_name(key)
    with self.ManagedSessionMaker() as session:
        # Only called for its side effect: it raises if the model is missing.
        self._get_registered_model(session, name)
        tag_row = self._get_registered_model_tag(session, name, key)
        if tag_row is None:
            # Deleting an absent tag is a no-op.
            return
        session.delete(tag_row)
def _get_sql_model_version(cls, session, name, version, eager=False):
    """
    Fetch the non-deleted ``SqlModelVersion`` for ``(name, version)``.

    Validation and filter construction happen here; running the query and
    interpreting the result are delegated to ``_get_model_version_from_db``.

    :param session: SQLAlchemy session used for the lookup.
    :param name: Registered model name; validated before querying.
    :param version: Model version; validated before querying.
    :param eager: When ``True``, the options returned by
                  ``_get_eager_model_version_query_options`` are passed along so
                  the model version's tags are loaded eagerly. When ``False``, they
                  are loaded when the corresponding properties of the resulting
                  ``SqlModelVersion`` are accessed.
    :return: Whatever ``_get_model_version_from_db`` returns for these filters.
    """
    _validate_model_name(name)
    _validate_model_version(version)
    if eager:
        options = cls._get_eager_model_version_query_options()
    else:
        options = []
    # Rows soft-deleted via the internal stage marker are excluded.
    filters = [
        SqlModelVersion.name == name,
        SqlModelVersion.version == version,
        SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL,
    ]
    return cls._get_model_version_from_db(session, name, version, filters, options)
def delete_model_version_tag(self, name, version, key):
    """
    Remove the tag with ``key`` from an existing model version, if present.

    :param name: Registered model name.
    :param version: Registered model version.
    :param key: Tag key.
    :return: None
    :raises MlflowException: Propagated from validation, or when the model
                             version does not exist.
    """
    _validate_model_name(name)
    _validate_model_version(version)
    _validate_tag_name(key)
    with self.ManagedSessionMaker() as session:
        # Only called for its side effect: it raises if the version is missing.
        self._get_sql_model_version(session, name, version)
        tag_row = self._get_model_version_tag(session, name, version, key)
        if tag_row is None:
            # Deleting an absent tag is a no-op.
            return
        session.delete(tag_row)
def set_model_version_tag(self, name, version, tag):
    """
    Attach a tag to an existing model version.

    :param name: Registered model name.
    :param version: Registered model version.
    :param tag: :py:class:`mlflow.entities.model_registry.ModelVersionTag`
                instance whose ``key`` and ``value`` are stored.
    :return: None
    :raises MlflowException: Propagated from validation, or when the model
                             version does not exist.
    """
    _validate_model_name(name)
    _validate_model_version(version)
    _validate_model_version_tag(tag.key, tag.value)
    with self.ManagedSessionMaker() as session:
        # Only called for its side effect: it raises if the version is missing.
        self._get_sql_model_version(session, name, version)
        # ``merge`` reconciles with any already-stored row for the same primary
        # key (presumably (name, version, key) — confirm against the model).
        tag_row = SqlModelVersionTag(
            name=name, version=version, key=tag.key, value=tag.value
        )
        session.merge(tag_row)