예제 #1
0
    def __init__(self, artifact_uri):
        """
        Validate ``artifact_uri``, initialise the parent class with the profile-free URI and
        install the ``get_host_creds`` attribute for this instance.

        :param artifact_uri: DBFS artifact URI of the form ``dbfs:/<path>`` or
                             ``dbfs://profile@databricks/<path>``.
        :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` when ``is_valid_dbfs_uri``
                                 rejects ``artifact_uri``.
        """
        uri_is_valid = is_valid_dbfs_uri(artifact_uri)
        if not uri_is_valid:
            raise MlflowException(
                message="DBFS URI must be of the form dbfs:/<path> or "
                + "dbfs://profile@databricks/<path>",
                error_code=INVALID_PARAMETER_VALUE,
            )

        # Artifact operations work on a plain dbfs:/ path, so the Databricks profile info is
        # removed before the URI is handed to the parent class.
        profile_free_uri = remove_databricks_profile_info_from_artifact_uri(artifact_uri)
        super().__init__(profile_free_uri)

        profile_uri = get_databricks_profile_uri_from_artifact_uri(artifact_uri)
        if not profile_uri:
            # No profile in the URI: use whatever the default-store helper returns
            # (presumably a zero-argument callable, like the lambda below -- confirm).
            self.get_host_creds = _get_host_creds_from_default_store()
        else:
            # Resolve the credentials once, up front, and hand back the same object each call.
            resolved_creds = get_databricks_host_creds(profile_uri)
            self.get_host_creds = lambda: resolved_creds
예제 #2
0
 def _print_description_and_log_tags(self):
     """
     Print the status-page URL of the launched Databricks job run and record the run's
     Databricks metadata as tags on the MLflow run ``self._mlflow_run_id``.
     """
     eprint("=== Launched MLflow run as Databricks job run with ID %s. Getting run status "
            "page URL... ===" % self._databricks_run_id)
     run_details = self._job_runner.jobs_runs_get(self._databricks_run_id)
     status_page_url = run_details["run_page_url"]
     eprint("=== Check the run's status at %s ===" % status_page_url)
     creds = databricks_utils.get_databricks_host_creds(self._job_runner.databricks_profile)

     pending_tags = [
         (MLFLOW_DATABRICKS_RUN_URL, status_page_url),
         (MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID, self._databricks_run_id),
         (MLFLOW_DATABRICKS_WEBAPP_URL, creds.host),
     ]
     # In some releases of Databricks we do not return the job ID. We start including it in DB
     # releases 2.80 and above.
     shell_job_id = run_details.get('job_id')
     if shell_job_id is not None:
         pending_tags.append((MLFLOW_DATABRICKS_SHELL_JOB_ID, shell_job_id))

     # A fresh client is created for every tag.
     for tag_key, tag_value in pending_tags:
         tracking.MlflowClient().set_tag(self._mlflow_run_id, tag_key, tag_value)
def _get_docker_command(image, active_run):
    """
    Assemble the ``docker run --rm`` argument list for running ``image`` on behalf of
    ``active_run``.

    The list contains, in order: the base command, ``-v`` volume mounts for a local tracking
    directory and/or local artifact directory (when the respective helper returns a path),
    one ``-e KEY=VALUE`` pair per environment variable, and finally ``image.tags[0]``.

    :param image: Object exposing ``tags``; the first tag names the image to run.
    :param active_run: Object whose ``info`` provides ``run_id``, ``experiment_id`` and
                       ``artifact_uri``.
    :return: The command as a list of arguments.
    """
    run_info = active_run.info
    env_vars = _get_run_env_vars(run_id=run_info.run_id,
                                 experiment_id=run_info.experiment_id)
    tracking_uri = tracking.get_tracking_uri()
    host_tracking_dir, container_tracking_uri = _get_local_uri_or_none(tracking_uri)
    host_artifact_dir = get_local_path_or_none(run_info.artifact_uri)

    volume_args = []
    if host_tracking_dir is not None:
        volume_args.extend(
            ["-v", "%s:%s" % (host_tracking_dir, _MLFLOW_DOCKER_TRACKING_DIR_PATH)])
        env_vars[tracking._TRACKING_URI_ENV_VAR] = container_tracking_uri
    if host_artifact_dir is not None:
        if os.path.isabs(host_artifact_dir):
            mount_target = host_artifact_dir
        else:
            # Relative artifact paths are placed under the fixed code directory in the
            # container.
            mount_target = os.path.normpath(
                os.path.join("/mlflow/projects/code/", host_artifact_dir))
        volume_args.extend(
            ["-v", "%s:%s" % (os.path.abspath(host_artifact_dir), mount_target)])

    if tracking.utils._is_databricks_uri(tracking_uri):
        profile = mlflow.tracking.utils.get_db_profile_from_uri(tracking_uri)
        creds = databricks_utils.get_databricks_host_creds(profile)
        # We set these via environment variables so that only the current profile is exposed,
        # rather than all profiles in ~/.databrickscfg; maybe better would be to mount the
        # necessary part of ~/.databrickscfg into the container
        env_vars[tracking._TRACKING_URI_ENV_VAR] = 'databricks'
        env_vars['DATABRICKS_HOST'] = creds.host
        optional_settings = (
            ('DATABRICKS_USERNAME', 'username'),
            ('DATABRICKS_PASSWORD', 'password'),
            ('DATABRICKS_TOKEN', 'token'),
            ('DATABRICKS_INSECURE', 'ignore_tls_verification'),
        )
        # Only truthy credential fields are exported.
        for env_name, attr_name in optional_settings:
            attr_value = getattr(creds, attr_name)
            if attr_value:
                env_vars[env_name] = attr_value

    env_args = []
    for key, value in env_vars.items():
        env_args.extend(["-e", "{key}={value}".format(key=key, value=value)])

    return ["docker", "run", "--rm"] + volume_args + env_args + [image.tags[0]]
예제 #4
0
def get_databricks_env_vars(tracking_uri):
    """
    Build the environment variables that expose the Databricks credentials resolved for
    ``tracking_uri``.

    :param tracking_uri: Tracking URI to inspect.
    :return: An empty dict when ``tracking_uri`` is not a Databricks URI; otherwise a dict
             holding the tracking URI variable (set to ``"databricks"``), ``DATABRICKS_HOST``
             and one entry for each truthy credential field.
    """
    if not mlflow.utils.uri.is_databricks_uri(tracking_uri):
        return {}

    creds = databricks_utils.get_databricks_host_creds(tracking_uri)
    # We set these via environment variables so that only the current profile is exposed, rather
    # than all profiles in ~/.databrickscfg; maybe better would be to mount the necessary
    # part of ~/.databrickscfg into the container
    env_vars = {
        tracking._TRACKING_URI_ENV_VAR: "databricks",
        "DATABRICKS_HOST": creds.host,
    }
    optional_settings = (
        ("DATABRICKS_USERNAME", "username"),
        ("DATABRICKS_PASSWORD", "password"),
        ("DATABRICKS_TOKEN", "token"),
    )
    for env_name, attr_name in optional_settings:
        attr_value = getattr(creds, attr_name)
        if attr_value:
            env_vars[env_name] = attr_value
    # The TLS flag is stringified, unlike the other fields which are passed through as-is.
    if creds.ignore_tls_verification:
        env_vars["DATABRICKS_INSECURE"] = str(creds.ignore_tls_verification)
    return env_vars
예제 #5
0
파일: databricks.py 프로젝트: zge/mlflow
 def _dbfs_path_exists(self, dbfs_path):
     """
     Report whether ``dbfs_path`` exists in DBFS, using the credentials resolved from
     ``self.databricks_profile``. The path is expected to be relative to the DBFS root
     directory, e.g. 'path/to/file'; a leading slash is added before it is sent.

     :return: False when the API answers with a ``RESOURCE_DOES_NOT_EXIST`` error code, True
              when the response carries no error code.
     :raises MlflowException: if the response body is not valid JSON.
     :raises ExecutionException: if the response carries any other error code.
     """
     creds = databricks_utils.get_databricks_host_creds(self.databricks_profile)
     status_response = rest_utils.http_request(
         host_creds=creds, endpoint="/api/2.0/dbfs/get-status", method="GET",
         json={"path": "/%s" % dbfs_path})
     try:
         payload = json.loads(status_response.text)
     except ValueError:
         raise MlflowException(
             "API request to check existence of file at DBFS path %s failed with status code "
             "%s. Response body: %s"
             % (dbfs_path, status_response.status_code, status_response.text))

     # A response without an error code means the status lookup succeeded.
     if "error_code" not in payload:
         return True
     # If request fails with a RESOURCE_DOES_NOT_EXIST error, the file does not exist on DBFS
     if payload["error_code"] == "RESOURCE_DOES_NOT_EXIST":
         return False
     raise ExecutionException("Got unexpected error response when checking whether file %s "
                              "exists in DBFS: %s" % (dbfs_path, payload))
def test_databricks_params_no_verify(get_config):
    """An ``insecure=True`` config yields credentials with TLS verification disabled."""
    insecure_config = DatabricksConfig("host", "user", "pass", None, insecure=True)
    get_config.return_value = insecure_config

    host_creds = databricks_utils.get_databricks_host_creds()

    assert host_creds.ignore_tls_verification
def test_databricks_params_user_password(get_config):
    """Host, username and password from the config are carried over to the credentials."""
    basic_auth_config = DatabricksConfig("host", "user", "pass", None, insecure=False)
    get_config.return_value = basic_auth_config

    host_creds = databricks_utils.get_databricks_host_creds()

    assert host_creds.host == "host"
    assert host_creds.username == "user"
    assert host_creds.password == "pass"
def test_databricks_params_token(get_config):
    """A token-based config yields the host and token, with TLS verification left on."""
    token_config = DatabricksConfig("host", None, None, "mytoken", insecure=False)
    get_config.return_value = token_config

    host_creds = databricks_utils.get_databricks_host_creds()

    assert host_creds.host == "host"
    assert host_creds.token == "mytoken"
    assert not host_creds.ignore_tls_verification
예제 #9
0
def _get_databricks_rest_store(store_uri, **_):
    """
    Create a ``DatabricksRestStore`` for ``store_uri``; extra keyword arguments are ignored.

    The store receives a zero-argument callable, so credentials are looked up each time the
    callable is invoked rather than once at construction.
    """
    def _resolve_host_creds():
        return get_databricks_host_creds(store_uri)

    return DatabricksRestStore(_resolve_host_creds)
예제 #10
0
 def _call_endpoint(self, service, api, json_body):
     """
     Invoke the endpoint registered for ``(service, api)`` with ``json_body``, using the
     credentials resolved from ``self.databricks_profile_uri``, and return whatever
     ``call_endpoint`` returns.
     """
     # Credentials are resolved first, before the endpoint lookup.
     host_creds = get_databricks_host_creds(self.databricks_profile_uri)
     route = _SERVICE_AND_METHOD_TO_INFO[service][api]
     endpoint, method = route
     return call_endpoint(host_creds, endpoint, method, json_body, api.Response())
예제 #11
0
 def _call_endpoint(self, json, endpoint):
     """
     Issue a GET request to ``endpoint`` with ``json`` sent as query parameters, using the
     credentials resolved from ``self.databricks_profile_uri``.
     """
     request_kwargs = {
         "host_creds": get_databricks_host_creds(self.databricks_profile_uri),
         "endpoint": endpoint,
         "method": "GET",
         "params": json,
     }
     return http_request(**request_kwargs)
예제 #12
0
def test_databricks_single_slash_in_uri_scheme_throws(get_config):
    """A ``databricks:/`` URI (single slash) is rejected as incorrectly formatted."""
    get_config.return_value = None
    malformed_uri = "databricks:/profile:path"
    expected_error = pytest.raises(MlflowException, match="URI is formatted incorrectly")
    with expected_error:
        databricks_utils.get_databricks_host_creds(malformed_uri)
예제 #13
0
def test_databricks_empty_uri(get_config):
    """An empty URI is rejected as a malformed Databricks CLI profile."""
    get_config.return_value = None
    empty_uri = ""
    expected_error = pytest.raises(
        MlflowException, match="Got malformed Databricks CLI profile"
    )
    with expected_error:
        databricks_utils.get_databricks_host_creds(empty_uri)
예제 #14
0
 def _call_endpoint(self, service, api, json_body):
     """
     Invoke the endpoint registered for ``(service, api)`` with ``json_body``, using the
     credentials of the profile named by the current MLflow tracking URI, and return
     whatever ``call_endpoint`` returns.
     """
     current_tracking_uri = mlflow.tracking.get_tracking_uri()
     host_creds = get_databricks_host_creds(get_db_profile_from_uri(current_tracking_uri))
     route = _SERVICE_AND_METHOD_TO_INFO[service][api]
     endpoint, method = route
     return call_endpoint(host_creds, endpoint, method, json_body, api.Response())
def test_databricks_empty_uri(get_config):
    """Resolving credentials for an empty URI raises some exception."""
    get_config.return_value = None
    empty_uri = ""
    expected_error = pytest.raises(Exception)
    with expected_error:
        databricks_utils.get_databricks_host_creds(empty_uri)
def test_databricks_single_slash_in_uri_scheme_throws(get_config):
    """Resolving credentials for a ``databricks:/`` URI (single slash) raises some exception."""
    get_config.return_value = None
    malformed_uri = "databricks:/profile:path"
    expected_error = pytest.raises(Exception)
    with expected_error:
        databricks_utils.get_databricks_host_creds(malformed_uri)
예제 #17
0
def _get_databricks_rest_store(store_uri):
    """
    Create a ``RestStore`` for the Databricks profile named by ``store_uri``.

    The profile is extracted once here; the credentials themselves are looked up each time
    the store invokes the zero-argument callable it is given.
    """
    db_profile = get_db_profile_from_uri(store_uri)

    def _resolve_host_creds():
        return get_databricks_host_creds(db_profile)

    return RestStore(_resolve_host_creds)
예제 #18
0
 def _check_auth_available(self):
     """
     Check that credentials can be resolved for ``self.databricks_profile``.

     The resolved credentials are discarded; any exception raised by the lookup propagates
     to the caller.
     """
     profile = self.databricks_profile
     databricks_utils.get_databricks_host_creds(profile)
예제 #19
0
 def _databricks_api_request(self, endpoint, method, **kwargs):
     """
     Send a ``method`` request to ``endpoint`` through ``rest_utils.http_request_safe``,
     using the credentials resolved from ``self.databricks_profile_uri``. Additional
     keyword arguments are forwarded unchanged.
     """
     creds = databricks_utils.get_databricks_host_creds(self.databricks_profile_uri)
     send_request = rest_utils.http_request_safe
     return send_request(host_creds=creds, endpoint=endpoint, method=method, **kwargs)