def create_model_version(self, name, source, run_id, tags=None):
        """
        Create a new model version from given source and run ID.

        The version number is one more than the highest version currently attached to the
        registered model, or 1 when the model has no versions yet. If saving raises
        ``sqlalchemy.exc.IntegrityError`` (presumably because a concurrent writer claimed the
        same version number -- confirm against the table constraints), the session is rolled
        back and the whole attempt is repeated, up to ``self.CREATE_MODEL_VERSION_RETRIES``
        times.

        :param name: Registered model name.
        :param source: Source path where the MLflow model is stored.
        :param run_id: Run ID from MLflow tracking server that generated the model.
        :param tags: A list of :py:class:`mlflow.entities.model_registry.ModelVersionTag`
                     instances associated with this model version. When several tags share
                     a key, the value of the last one is kept.
        :return: A single object of :py:class:`mlflow.entities.model_registry.ModelVersion`
                 created in the backend.
        :raises MlflowException: if every attempt ends in an ``IntegrityError``.
        """
        def next_version(sql_registered_model):
            # Versions are numbered from 1; a model without versions starts the sequence.
            if sql_registered_model.model_versions:
                return max(
                    mv.version for mv in sql_registered_model.model_versions) + 1
            return 1

        _validate_model_name(name)
        # Validate and de-duplicate the tags in a single pass, once, outside the retry
        # loop: the mapping does not change between attempts, and a single pass also
        # works when ``tags`` is a one-shot iterable.
        tags_dict = {}
        for tag in tags or []:
            _validate_model_version_tag(tag.key, tag.value)
            tags_dict[tag.key] = tag.value

        with self.ManagedSessionMaker() as session:
            creation_time = now()
            for attempt in range(self.CREATE_MODEL_VERSION_RETRIES):
                try:
                    # Re-read the registered model on every attempt so the version number
                    # is computed from the current state.
                    sql_registered_model = self._get_registered_model(
                        session, name)
                    sql_registered_model.last_updated_time = creation_time
                    version = next_version(sql_registered_model)
                    model_version = SqlModelVersion(
                        name=name,
                        version=version,
                        creation_time=creation_time,
                        last_updated_time=creation_time,
                        source=source,
                        run_id=run_id)
                    # Tag rows are rebuilt per attempt because they are attached to the
                    # freshly created model version object.
                    model_version.model_version_tags = [
                        SqlModelVersionTag(key=key, value=value)
                        for key, value in tags_dict.items()
                    ]
                    self._save_to_db(session,
                                     [sql_registered_model, model_version])
                    session.flush()
                    return model_version.to_mlflow_entity()
                except sqlalchemy.exc.IntegrityError:
                    # A failed flush leaves the session's transaction inactive; without an
                    # explicit rollback the next attempt would fail on its first query
                    # instead of actually retrying.
                    session.rollback()
                    more_retries = self.CREATE_MODEL_VERSION_RETRIES - attempt - 1
                    _logger.info(
                        'Model Version creation error (name=%s) Retrying %s more time%s.',
                        name, more_retries,
                        '' if more_retries == 1 else 's')
        raise MlflowException(
            'Model Version creation error (name={}). Giving up after '
            '{} attempts.'.format(name, self.CREATE_MODEL_VERSION_RETRIES))
# Example #2
# 0
    def set_model_version_tag(self, name, version, tag):
        """
        Attach a tag to a specific version of a registered model.

        The name, version and tag are validated first. The model version is then looked
        up inside a managed session, and the tag row is merged into that session.

        :param name: Registered model name.
        :param version: Registered model version.
        :param tag: :py:class:`mlflow.entities.model_registry.ModelVersionTag` instance to log.
        :return: None
        """
        _validate_model_name(name)
        _validate_model_version(version)
        tag_key, tag_value = tag.key, tag.value
        _validate_model_version_tag(tag_key, tag_value)

        with self.ManagedSessionMaker() as session:
            # The lookup result is intentionally discarded: the call is made only to
            # check that the model version exists (presumably it raises otherwise --
            # confirm in _get_sql_model_version).
            self._get_sql_model_version(session, name, version)
            tag_row = SqlModelVersionTag(
                name=name, version=version, key=tag_key, value=tag_value)
            session.merge(tag_row)