Example #1
    def create_model_version(
        self,
        name,
        source,
        run_id,
        tags=None,
        run_link=None,
        description=None,
        await_creation_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    ):
        """
        Create a new model version from given source (artifact URI).

        :param name: Name for the containing registered model.
        :param source: Source path where the MLflow model is stored.
        :param run_id: Run ID from the MLflow tracking server that generated the model.
        :param tags: A dictionary of key-value pairs that are converted into
                     :py:class:`mlflow.entities.model_registry.ModelVersionTag` objects.
        :param run_link: Link to the run from an MLflow tracking server that generated this model.
        :param description: Description of the version.
        :param await_creation_for: Number of seconds to wait for the model version to finish being
                                    created and reach ``READY`` status. By default, the function
                                    waits for five minutes. Specify 0 or None to skip waiting.
        :return: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object created by
                 backend.
        """
        tracking_uri = self._tracking_client.tracking_uri
        if not run_link and is_databricks_uri(
                tracking_uri) and tracking_uri != self._registry_uri:
            run_link = self._get_run_link(tracking_uri, run_id)
        new_source = source
        if is_databricks_uri(
                self._registry_uri) and tracking_uri != self._registry_uri:
            # Print out some info for user since the copy may take a while for large models.
            eprint(
                "=== Copying model files from the source location to the model"
                + " registry workspace ===")
            new_source = _upload_artifacts_to_databricks(
                source, run_id, tracking_uri, self._registry_uri)
            # NOTE: we can't easily delete the target temp location due to the async nature
            # of the model version creation - printing to let the user know.
            eprint(
                "=== Source model files were copied to %s" % new_source
                + " in the model registry workspace. You may want to delete the files once the"
                + " model version is in 'READY' status. You can also find this location in the"
                + " `source` field of the created model version. ===")
        return self._get_registry_client().create_model_version(
            name=name,
            source=new_source,
            run_id=run_id,
            tags=tags,
            run_link=run_link,
            description=description,
            await_creation_for=await_creation_for,
        )
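
A minimal usage sketch for the method above via the public MlflowClient entry point. The model name, tags, and run ID below are placeholders, not values from the source, and the run is assumed to have logged a model under the "model" artifact path.

from mlflow.tracking import MlflowClient

client = MlflowClient()  # uses the active tracking and registry URIs
run_id = "0123456789abcdef0123456789abcdef"  # placeholder run ID
run = client.get_run(run_id)

# The registered model must exist before versions can be added to it.
client.create_registered_model("my-model")
model_version = client.create_model_version(
    name="my-model",
    source=run.info.artifact_uri + "/model",  # assumes the run logged a model under "model"
    run_id=run_id,
    tags={"stage": "candidate"},
    description="Version registered from run artifacts",
)
print(model_version.version, model_version.status)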
Example #2
File: __init__.py  Project: nlml/mlflow
def _get_docker_tracking_cmd_and_envs(tracking_uri):
    cmds = []
    env_vars = dict()

    local_path, container_tracking_uri = _get_local_uri_or_none(tracking_uri)
    if local_path is not None:
        cmds = ["-v", "%s:%s" % (local_path, _MLFLOW_DOCKER_TRACKING_DIR_PATH)]
        env_vars[tracking._TRACKING_URI_ENV_VAR] = container_tracking_uri
    if is_databricks_uri(tracking_uri):
        db_profile = get_db_profile_from_uri(tracking_uri)
        config = databricks_utils.get_databricks_host_creds(db_profile)
        # We set these via environment variables so that only the current profile is exposed,
        # rather than all profiles in ~/.databrickscfg; a better approach might be to mount the
        # necessary part of ~/.databrickscfg into the container.
        env_vars[tracking._TRACKING_URI_ENV_VAR] = 'databricks'
        env_vars['DATABRICKS_HOST'] = config.host
        if config.username:
            env_vars['DATABRICKS_USERNAME'] = config.username
        if config.password:
            env_vars['DATABRICKS_PASSWORD'] = config.password
        if config.token:
            env_vars['DATABRICKS_TOKEN'] = config.token
        if config.ignore_tls_verification:
            env_vars['DATABRICKS_INSECURE'] = config.ignore_tls_verification
    return cmds, env_vars
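
A sketch of how a caller might fold the helper's output into a `docker run` invocation; the profile name, image, and entry point are placeholders, and the real MLflow Projects backend may assemble the command differently.

import subprocess

volume_args, env_vars = _get_docker_tracking_cmd_and_envs("databricks://my-profile")

cmd = ["docker", "run", "--rm"]
cmd += volume_args  # e.g. ["-v", "<local mlruns>:<container tracking dir>"] for local backends
for key, value in env_vars.items():
    cmd += ["-e", "{}={}".format(key, value)]
cmd += ["my-project-image:latest", "python", "train.py"]  # placeholder image and entry point
subprocess.check_call(cmd)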
Example #3
    def create_model_version(self,
                             name,
                             source,
                             run_id,
                             tags=None,
                             run_link=None):
        """
        Create a new model version from given source (artifact URI).

        :param name: Name for the containing registered model.
        :param source: Source path where the MLflow model is stored.
        :param run_id: Run ID from the MLflow tracking server that generated the model.
        :param tags: A dictionary of key-value pairs that are converted into
                     :py:class:`mlflow.entities.model_registry.ModelVersionTag` objects.
        :param run_link: Link to the run from an MLflow tracking server that generated this model.
        :return: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object created by
                 backend.
        """
        tracking_uri = self._tracking_client.tracking_uri
        if not run_link and is_databricks_uri(
                tracking_uri) and tracking_uri != self._registry_uri:
            run_link = self._get_run_link(tracking_uri, run_id)
        new_source = source
        if is_databricks_uri(
                self._registry_uri) and tracking_uri != self._registry_uri:
            # Print out some info for user since the copy may take a while for large models.
            _logger.info(
                "=== Copying model files from the source location to the model"
                + " registry workspace ===")
            new_source = _upload_artifacts_to_databricks(
                source, run_id, tracking_uri, self._registry_uri)
            # NOTE: we can't easily delete the target temp location due to the async nature
            # of the model version creation - printing to let the user know.
            _logger.info(
                "=== Source model files were copied to %s in the model registry workspace. "
                "You may want to delete the files once the model version is in 'READY' status. "
                "You can also find this location in the `source` field of the created model "
                "version. ===", new_source)
        return self._get_registry_client().create_model_version(
            name=name,
            source=new_source,
            run_id=run_id,
            tags=tags,
            run_link=run_link)
Example #4
    def create_model_version(self,
                             name,
                             source,
                             run_id,
                             tags=None,
                             run_link=None):
        """
        Create a new model version from given source or run ID.

        :param name: Name of the containing registered model.
        :param source: Source path where the MLflow model is stored.
        :param run_id: Run ID from the MLflow tracking server that generated the model.
        :param tags: A dictionary of key-value pairs that are converted into
                     :py:class:`mlflow.entities.model_registry.ModelVersionTag` objects.
        :param run_link: Link to the run from an MLflow tracking server that generated this model.
        :return: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object created by
                 backend.
        """
        tracking_uri = self._tracking_client.tracking_uri
        # for Databricks backends, we support automatically populating the run link field
        if is_databricks_uri(
                tracking_uri
        ) and tracking_uri != self._registry_uri and not run_link:
            # if using the default Databricks tracking URI and in a notebook, we can automatically
            # figure out the run-link.
            if is_databricks_default_tracking_uri(
                    tracking_uri) and is_in_databricks_notebook():
                # use DBUtils to determine workspace information.
                workspace_host, workspace_id = get_workspace_info_from_dbutils()
            else:
                # in this scenario, we're not able to automatically extract the workspace ID
                # to proceed, and users will need to pass in a databricks profile with the scheme:
                # databricks://scope/prefix and store the host and workspace-ID as a secret in the
                # Databricks Secret Manager with scope=<scope> and key=<prefix>-workspaceid.
                workspace_host, workspace_id = \
                    get_workspace_info_from_databricks_secrets(tracking_uri)
                if not workspace_id:
                    print(
                        "No workspace ID specified; if your Databricks workspaces share the same"
                        " host URL, you may want to specify the workspace ID (along with the host"
                        " information in the secret manager) for run lineage tracking. For more"
                        " details on how to specify this information in the secret manager,"
                        " please refer to the model registry documentation.")
            # retrieve experiment ID of the run for the URL
            experiment_id = self.get_run(run_id).info.experiment_id
            if workspace_host and run_id and experiment_id:
                run_link = construct_run_url(workspace_host, experiment_id,
                                             run_id, workspace_id)
        return self._get_registry_client().create_model_version(
            name=name,
            source=source,
            run_id=run_id,
            tags=tags,
            run_link=run_link)
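
When automatic detection is not possible, a caller could build the run link explicitly and pass it in. A hedged sketch reusing the construct_run_url call shown above (assumed importable from mlflow.utils.uri); the host, workspace ID, run ID, and model name are placeholder values.

from mlflow.tracking import MlflowClient
from mlflow.utils.uri import construct_run_url  # assumed module path

client = MlflowClient()
run_id = "0123456789abcdef0123456789abcdef"  # placeholder run ID
workspace_host = "https://my-workspace.cloud.databricks.com"  # placeholder host
workspace_id = "1234567890"  # placeholder workspace ID

run = client.get_run(run_id)
run_link = construct_run_url(workspace_host, run.info.experiment_id, run_id, workspace_id)
client.create_model_version(
    name="my-model",
    source=run.info.artifact_uri + "/model",
    run_id=run_id,
    run_link=run_link,
)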
Example #5
def test_uri_types():
    assert is_local_uri("mlruns")
    assert is_local_uri("./mlruns")
    assert is_local_uri("file:///foo/mlruns")
    assert is_local_uri("file:foo/mlruns")
    assert not is_local_uri("https://whatever")
    assert not is_local_uri("http://whatever")
    assert not is_local_uri("databricks")
    assert not is_local_uri("databricks:whatever")
    assert not is_local_uri("databricks://whatever")

    assert is_databricks_uri("databricks")
    assert is_databricks_uri("databricks:whatever")
    assert is_databricks_uri("databricks://whatever")
    assert not is_databricks_uri("mlruns")
    assert not is_databricks_uri("http://whatever")

    assert is_http_uri("http://whatever")
    assert is_http_uri("https://whatever")
    assert not is_http_uri("file://whatever")
    assert not is_http_uri("databricks://whatever")
    assert not is_http_uri("mlruns")
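
For reference, an illustrative sketch of scheme-based predicates that would satisfy the assertions above; this is not the actual MLflow implementation.

from urllib.parse import urlparse


def is_local_uri(uri):
    # Plain paths and file: URIs are local; the bare string "databricks" is not.
    scheme = urlparse(uri).scheme
    return uri != "databricks" and scheme in ("", "file")


def is_databricks_uri(uri):
    # Accepts "databricks", "databricks:<profile>", and "databricks://<profile>".
    return uri == "databricks" or urlparse(uri).scheme == "databricks"


def is_http_uri(uri):
    return urlparse(uri).scheme in ("http", "https")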
Example #6
def before_run_validations(tracking_uri, backend_config):
    """Validations to perform before running a project on Databricks."""
    if backend_config is None:
        raise ExecutionException("Backend spec must be provided when launching MLflow project "
                                 "runs on Databricks.")
    if not is_databricks_uri(tracking_uri) and \
            not is_http_uri(tracking_uri):
        raise ExecutionException(
            "When running on Databricks, the MLflow tracking URI must be of the form "
            "'databricks' or 'databricks://profile', or a remote HTTP URI accessible to both the "
            "current client and code running on Databricks. Got local tracking URI %s. "
            "Please specify a valid tracking URI via mlflow.set_tracking_uri or by setting the "
            "MLFLOW_TRACKING_URI environment variable." % tracking_uri)
Example #7
def before_run_validations(tracking_uri, backend_config):
    """Validations to perform before running a project on Databricks."""
    if backend_config is None:
        raise ExecutionException(
            "Backend spec must be provided when launching MLflow project "
            "runs on Databricks.")
    elif "existing_cluster_id" in backend_config:
        raise MlflowException(message=(
            "MLflow Project runs on Databricks must provide a *new cluster* specification."
            " Project execution against existing clusters is not currently supported. For more"
            " information, see https://mlflow.org/docs/latest/projects.html"
            "#run-an-mlflow-project-on-databricks"),
                              error_code=INVALID_PARAMETER_VALUE)
    if not is_databricks_uri(tracking_uri) and \
            not is_http_uri(tracking_uri):
        raise ExecutionException(
            "When running on Databricks, the MLflow tracking URI must be of the form "
            "'databricks' or 'databricks://profile', or a remote HTTP URI accessible to both the "
            "current client and code running on Databricks. Got local tracking URI %s. "
            "Please specify a valid tracking URI via mlflow.set_tracking_uri or by setting the "
            "MLFLOW_TRACKING_URI environment variable." % tracking_uri)
Example #8
def is_using_databricks_registry(uri):
    profile_uri = get_databricks_profile_uri_from_artifact_uri(
        uri) or mlflow.get_registry_uri()
    return is_databricks_uri(profile_uri)
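
A short usage sketch: the check first looks for a Databricks profile embedded in the artifact URI and otherwise falls back to the globally configured registry URI. The profile name and paths below are placeholders.

import mlflow

mlflow.set_registry_uri("databricks://my-registry-profile")  # assumed profile name
# No profile is embedded in this artifact URI, so the registry URI above decides the answer.
print(is_using_databricks_registry("dbfs:/some/artifact/path"))  # True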