def __init__(self, artifact_uri):
    """Create a DBFS-backed artifact repository rooted at ``artifact_uri``.

    Raises an ``MlflowException`` with ``INVALID_PARAMETER_VALUE`` when the
    URI is not a valid DBFS URI.
    """
    if not is_valid_dbfs_uri(artifact_uri):
        raise MlflowException(
            message="DBFS URI must be of the form dbfs:/<path> or "
                    "dbfs://profile@databricks/<path>",
            error_code=INVALID_PARAMETER_VALUE,
        )
    # The dbfs:/ path ultimately used for artifact operations must not carry
    # Databricks profile info, so strip it before handing it to the base class.
    stripped_uri = remove_databricks_profile_info_from_artifact_uri(artifact_uri)
    super().__init__(stripped_uri)
    profile_uri = get_databricks_profile_uri_from_artifact_uri(artifact_uri)
    if profile_uri:
        # Resolve credentials once for the explicit profile and capture them
        # in the closure so every call returns the same credentials object.
        resolved_creds = get_databricks_host_creds(profile_uri)
        self.get_host_creds = lambda: resolved_creds
    else:
        self.get_host_creds = _get_host_creds_from_default_store()
def _print_description_and_log_tags(self):
    """Print the Databricks job run's status page URL and record
    Databricks-specific tags on the corresponding MLflow run."""
    eprint(
        "=== Launched MLflow run as Databricks job run with ID %s. "
        "Getting run status page URL... ===" % self._databricks_run_id
    )
    run_info = self._job_runner.jobs_runs_get(self._databricks_run_id)
    status_page_url = run_info["run_page_url"]
    eprint("=== Check the run's status at %s ===" % status_page_url)
    creds = databricks_utils.get_databricks_host_creds(
        self._job_runner.databricks_profile
    )
    # One client per tag write, mirroring independent set_tag calls.
    for tag_key, tag_value in [
        (MLFLOW_DATABRICKS_RUN_URL, status_page_url),
        (MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID, self._databricks_run_id),
        (MLFLOW_DATABRICKS_WEBAPP_URL, creds.host),
    ]:
        tracking.MlflowClient().set_tag(self._mlflow_run_id, tag_key, tag_value)
    job_id = run_info.get("job_id")
    # Some Databricks releases do not return the job ID; it is included
    # starting with DB release 2.80.
    if job_id is not None:
        tracking.MlflowClient().set_tag(
            self._mlflow_run_id, MLFLOW_DATABRICKS_SHELL_JOB_ID, job_id
        )
def _get_docker_command(image, active_run):
    """Assemble the ``docker run`` command line for executing a project run.

    Mounts the tracking store and artifact store into the container when they
    are local-filesystem backed, and forwards tracking configuration
    (including Databricks credentials) via ``-e`` environment variables.
    """
    cmd = ["docker", "run", "--rm"]
    env_vars = _get_run_env_vars(
        run_id=active_run.info.run_id, experiment_id=active_run.info.experiment_id
    )
    tracking_uri = tracking.get_tracking_uri()
    local_path, container_tracking_uri = _get_local_uri_or_none(tracking_uri)
    artifact_local_path = get_local_path_or_none(active_run.info.artifact_uri)
    if local_path is not None:
        # Mount the local tracking store and point the container at it.
        cmd += ["-v", "%s:%s" % (local_path, _MLFLOW_DOCKER_TRACKING_DIR_PATH)]
        env_vars[tracking._TRACKING_URI_ENV_VAR] = container_tracking_uri
    if artifact_local_path is not None:
        # Mount the artifact store so artifact URIs recorded on the run
        # resolve to the same path inside the container.
        container_path = artifact_local_path
        if not os.path.isabs(container_path):
            container_path = os.path.join("/mlflow/projects/code/", artifact_local_path)
            container_path = os.path.normpath(container_path)
        host_abspath = os.path.abspath(artifact_local_path)
        cmd += ["-v", "%s:%s" % (host_abspath, container_path)]
    if tracking.utils._is_databricks_uri(tracking_uri):
        db_profile = mlflow.tracking.utils.get_db_profile_from_uri(tracking_uri)
        config = databricks_utils.get_databricks_host_creds(db_profile)
        # Expose only the current profile via environment variables rather
        # than all profiles in ~/.databrickscfg; maybe better would be to
        # mount the necessary part of ~/.databrickscfg into the container.
        env_vars[tracking._TRACKING_URI_ENV_VAR] = 'databricks'
        env_vars['DATABRICKS_HOST'] = config.host
        for var_name, var_value in (
            ('DATABRICKS_USERNAME', config.username),
            ('DATABRICKS_PASSWORD', config.password),
            ('DATABRICKS_TOKEN', config.token),
            ('DATABRICKS_INSECURE', config.ignore_tls_verification),
        ):
            if var_value:
                env_vars[var_name] = var_value
    for key, value in env_vars.items():
        cmd += ["-e", "%s=%s" % (key, value)]
    cmd += [image.tags[0]]
    return cmd
def get_databricks_env_vars(tracking_uri):
    """Return environment variables that let a child process reach the
    Databricks workspace behind ``tracking_uri``.

    Returns an empty dict when ``tracking_uri`` is not a Databricks URI.
    """
    if not mlflow.utils.uri.is_databricks_uri(tracking_uri):
        return {}
    config = databricks_utils.get_databricks_host_creds(tracking_uri)
    # Expose only the current profile via environment variables rather than
    # all profiles in ~/.databrickscfg; maybe better would be to mount the
    # necessary part of ~/.databrickscfg into the container.
    env_vars = {
        tracking._TRACKING_URI_ENV_VAR: "databricks",
        "DATABRICKS_HOST": config.host,
    }
    for var_name, var_value in (
        ("DATABRICKS_USERNAME", config.username),
        ("DATABRICKS_PASSWORD", config.password),
        ("DATABRICKS_TOKEN", config.token),
    ):
        if var_value:
            env_vars[var_name] = var_value
    if config.ignore_tls_verification:
        env_vars["DATABRICKS_INSECURE"] = str(config.ignore_tls_verification)
    return env_vars
def _dbfs_path_exists(self, dbfs_path):
    """
    Return True if the passed-in path exists in DBFS for the workspace
    corresponding to the default Databricks CLI profile. The path is expected
    to be a relative path to the DBFS root directory, e.g. 'path/to/file'.
    """
    host_creds = databricks_utils.get_databricks_host_creds(self.databricks_profile)
    response = rest_utils.http_request(
        host_creds=host_creds,
        endpoint="/api/2.0/dbfs/get-status",
        method="GET",
        json={"path": "/%s" % dbfs_path},
    )
    try:
        response_json = json.loads(response.text)
    except ValueError:
        raise MlflowException(
            "API request to check existence of file at DBFS path %s failed with status code "
            "%s. Response body: %s" % (dbfs_path, response.status_code, response.text)
        )
    # No error_code in the response means the get-status call succeeded, i.e.
    # the file exists.
    if "error_code" not in response_json:
        return True
    # A RESOURCE_DOES_NOT_EXIST error simply means the file is absent.
    if response_json["error_code"] == "RESOURCE_DOES_NOT_EXIST":
        return False
    raise ExecutionException(
        "Got unexpected error response when checking whether file %s "
        "exists in DBFS: %s" % (dbfs_path, response_json)
    )
def test_databricks_params_no_verify(get_config):
    """Credentials built from a profile with insecure=True disable TLS
    verification."""
    get_config.return_value = DatabricksConfig("host", "user", "pass", None, insecure=True)
    creds = databricks_utils.get_databricks_host_creds()
    assert creds.ignore_tls_verification
def test_databricks_params_user_password(get_config):
    """Host, username, and password from the CLI profile are carried over
    into the returned credentials."""
    get_config.return_value = DatabricksConfig("host", "user", "pass", None, insecure=False)
    creds = databricks_utils.get_databricks_host_creds()
    assert creds.host == "host"
    assert creds.username == "user"
    assert creds.password == "pass"
def test_databricks_params_token(get_config):
    """A token-based profile yields token credentials with TLS verification
    enabled."""
    get_config.return_value = DatabricksConfig("host", None, None, "mytoken", insecure=False)
    creds = databricks_utils.get_databricks_host_creds()
    assert creds.host == "host"
    assert creds.token == "mytoken"
    assert not creds.ignore_tls_verification
def _get_databricks_rest_store(store_uri, **_):
    """Build a ``DatabricksRestStore`` whose host credentials are resolved
    from ``store_uri`` lazily, on each request."""
    def _host_creds():
        return get_databricks_host_creds(store_uri)

    return DatabricksRestStore(_host_creds)
def _call_endpoint(self, service, api, json_body):
    """Call the REST endpoint registered for ``api`` of ``service`` against
    the Databricks workspace identified by this object's profile URI, and
    return the parsed response proto."""
    endpoint, method = _SERVICE_AND_METHOD_TO_INFO[service][api]
    creds = get_databricks_host_creds(self.databricks_profile_uri)
    return call_endpoint(creds, endpoint, method, json_body, api.Response())
def _call_endpoint(self, json, endpoint):
    """Issue an authenticated GET request to ``endpoint`` with ``json`` sent
    as query parameters.

    NOTE: the ``json`` parameter name shadows the stdlib module, but it is
    part of the caller-facing signature and must stay unchanged.
    """
    creds = get_databricks_host_creds(self.databricks_profile_uri)
    return http_request(host_creds=creds, endpoint=endpoint, method="GET", params=json)
def test_databricks_single_slash_in_uri_scheme_throws(get_config):
    """A databricks URI with a single slash after the scheme is rejected."""
    get_config.return_value = None
    expected_message = "URI is formatted incorrectly"
    with pytest.raises(MlflowException, match=expected_message):
        databricks_utils.get_databricks_host_creds("databricks:/profile:path")
def test_databricks_empty_uri(get_config):
    """An empty profile URI is rejected when no CLI config is available."""
    get_config.return_value = None
    expected_message = "Got malformed Databricks CLI profile"
    with pytest.raises(MlflowException, match=expected_message):
        databricks_utils.get_databricks_host_creds("")
def _call_endpoint(self, service, api, json_body):
    """Resolve Databricks credentials from the active tracking URI's profile
    and call the REST endpoint mapped to ``service``/``api``, returning the
    parsed response proto."""
    profile = get_db_profile_from_uri(mlflow.tracking.get_tracking_uri())
    creds = get_databricks_host_creds(profile)
    endpoint, method = _SERVICE_AND_METHOD_TO_INFO[service][api]
    return call_endpoint(creds, endpoint, method, json_body, api.Response())
def test_databricks_empty_uri(get_config):
    """An empty profile URI is rejected when no CLI config is available.

    Previously asserted only ``pytest.raises(Exception)``, which any failure
    (e.g. an unrelated TypeError) would satisfy; pin the error message so the
    test checks the intended rejection path.
    """
    get_config.return_value = None
    with pytest.raises(Exception, match="Got malformed Databricks CLI profile"):
        databricks_utils.get_databricks_host_creds("")
def test_databricks_single_slash_in_uri_scheme_throws(get_config):
    """A databricks URI with a single slash after the scheme is rejected.

    Previously asserted only ``pytest.raises(Exception)``, which any failure
    would satisfy; pin the error message so the test checks the intended
    malformed-URI rejection path.
    """
    get_config.return_value = None
    with pytest.raises(Exception, match="URI is formatted incorrectly"):
        databricks_utils.get_databricks_host_creds("databricks:/profile:path")
def _get_databricks_rest_store(store_uri):
    """Create a ``RestStore`` that lazily resolves host credentials for the
    Databricks profile embedded in ``store_uri``."""
    profile = get_db_profile_from_uri(store_uri)

    def _host_creds():
        return get_databricks_host_creds(profile)

    return RestStore(_host_creds)
def _check_auth_available(self):
    """
    Verify that information for making API requests to Databricks is
    available to MLflow, raising an exception if not.
    """
    # Invoked purely for its side effect: credential resolution raises when
    # no usable authentication exists for the configured profile.
    databricks_utils.get_databricks_host_creds(self.databricks_profile)
def _databricks_api_request(self, endpoint, method, **kwargs):
    """Send an authenticated request to a Databricks REST ``endpoint`` via
    ``rest_utils.http_request_safe``, forwarding extra request kwargs."""
    creds = databricks_utils.get_databricks_host_creds(self.databricks_profile_uri)
    return rest_utils.http_request_safe(
        host_creds=creds, endpoint=endpoint, method=method, **kwargs
    )