Beispiel #1
0
 def _print_description_and_log_tags(self):
     """
     Log the status page URL of the Databricks job run and tag the MLflow run with its metadata.

     The run info for ``self._databricks_run_id`` is fetched through ``self._job_runner``. The
     MLflow run ``self._mlflow_run_id`` is then tagged with the run page URL, the Databricks
     run ID, the host of the resolved Databricks credentials and, if the run info includes
     one, the job ID.
     """
     def _tag_mlflow_run(tag_key, tag_value):
         # Each tag is written through its own freshly created client.
         tracking.MlflowClient().set_tag(self._mlflow_run_id, tag_key, tag_value)

     databricks_run_id = self._databricks_run_id
     _logger.info(
         "=== Launched MLflow run as Databricks job run with ID %s."
         " Getting run status page URL... ===",
         databricks_run_id,
     )
     run_details = self._job_runner.jobs_runs_get(databricks_run_id)
     run_page_url = run_details["run_page_url"]
     _logger.info("=== Check the run's status at %s ===", run_page_url)
     creds = databricks_utils.get_databricks_host_creds(
         self._job_runner.databricks_profile_uri)

     _tag_mlflow_run(MLFLOW_DATABRICKS_RUN_URL, run_page_url)
     _tag_mlflow_run(MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID, databricks_run_id)
     _tag_mlflow_run(MLFLOW_DATABRICKS_WEBAPP_URL, creds.host)

     # The job ID is only returned by Databricks releases 2.80 and above; older releases omit
     # it, in which case no job ID tag is written.
     databricks_job_id = run_details.get('job_id')
     if databricks_job_id is None:
         return
     _tag_mlflow_run(MLFLOW_DATABRICKS_SHELL_JOB_ID, databricks_job_id)
Beispiel #2
0
 def _databricks_api_request(self, endpoint, method, **kwargs):
     """
     Issue a REST request with credentials resolved from ``self.databricks_profile_uri``.

     :param endpoint: Endpoint passed through to ``rest_utils.http_request_safe``.
     :param method: HTTP method passed through to ``rest_utils.http_request_safe``.
     :param kwargs: Additional keyword arguments forwarded unchanged to
                    ``rest_utils.http_request_safe``.
     :return: Whatever ``rest_utils.http_request_safe`` returns.
     """
     creds = databricks_utils.get_databricks_host_creds(self.databricks_profile_uri)
     return rest_utils.http_request_safe(
         host_creds=creds, endpoint=endpoint, method=method, **kwargs)
Beispiel #3
0
def test_databricks_params_throws_errors(ProfileConfigProvider):
    """Credential resolution must raise when the config lacks a host or any authentication."""
    incomplete_config_args = (
        # No hostname
        (None, "user", "pass", None),
        # No authentication
        ("host", None, None, None),
    )
    for config_args in incomplete_config_args:
        provider = mock.MagicMock()
        provider.get_config.return_value = DatabricksConfig(*config_args, insecure=True)
        ProfileConfigProvider.return_value = provider
        with pytest.raises(Exception):
            databricks_utils.get_databricks_host_creds()
Beispiel #4
0
 def _dbfs_path_exists(self, dbfs_path):
     """
     Return True if the passed-in path exists in DBFS for the workspace whose credentials are
     resolved from ``self.databricks_profile_uri``. The path is expected to be a relative path
     to the DBFS root directory, e.g. 'path/to/file'.

     :param dbfs_path: Path relative to the DBFS root; a leading '/' is prepended before the
                       request is sent.
     :return: False when the API reports RESOURCE_DOES_NOT_EXIST, True when the response
              carries no error code.
     :raises MlflowException: if the response body cannot be parsed as JSON.
     :raises ExecutionException: if the response carries any other error code.
     """
     creds = databricks_utils.get_databricks_host_creds(self.databricks_profile_uri)
     response = rest_utils.http_request(
         host_creds=creds,
         endpoint="/api/2.0/dbfs/get-status",
         method="GET",
         json={"path": "/%s" % dbfs_path},
     )
     try:
         parsed_body = json.loads(response.text)
     except Exception:  # pylint: disable=broad-except
         failure_details = (dbfs_path, response.status_code, response.text)
         raise MlflowException(
             "API request to check existence of file at DBFS path %s failed with status code "
             "%s. Response body: %s" % failure_details)

     error_key = "error_code"
     if error_key not in parsed_body:
         # No error reported, so the status lookup succeeded and the path exists.
         return True
     if parsed_body[error_key] == "RESOURCE_DOES_NOT_EXIST":
         # A RESOURCE_DOES_NOT_EXIST error means the file does not exist on DBFS.
         return False
     raise ExecutionException(
         "Got unexpected error response when checking whether file %s "
         "exists in DBFS: %s" % (dbfs_path, parsed_body))
Beispiel #5
0
def test_databricks_params_user_password(get_config):
    """A username/password config must surface its host, username and password unchanged."""
    basic_auth_config = DatabricksConfig("host", "user", "pass", None, insecure=False)
    get_config.return_value = basic_auth_config
    creds = databricks_utils.get_databricks_host_creds()
    assert (creds.host, creds.username, creds.password) == ('host', 'user', 'pass')
Beispiel #6
0
def test_databricks_params_token(get_config):
    """A token config must surface its host and token and must not disable TLS verification."""
    token_config = DatabricksConfig("host", None, None, "mytoken", insecure=False)
    get_config.return_value = token_config
    creds = databricks_utils.get_databricks_host_creds()
    assert creds.host == 'host'
    assert creds.token == 'mytoken'
    assert not creds.ignore_tls_verification
Beispiel #7
0
def test_databricks_params_custom_profile(ProfileConfigProvider):
    """A URI built from a profile name must make the provider be created for that profile."""
    profile_name = "profile"
    provider = mock.MagicMock()
    provider.get_config.return_value = DatabricksConfig(
        "host", "user", "pass", None, insecure=True)
    ProfileConfigProvider.return_value = provider

    profile_uri = construct_db_uri_from_profile(profile_name)
    creds = databricks_utils.get_databricks_host_creds(profile_uri)

    assert creds.ignore_tls_verification
    ProfileConfigProvider.assert_called_with(profile_name)
Beispiel #8
0
def test_databricks_registry_profile(ProfileConfigProvider):
    """
    With no profile config available, a 'databricks://profile/prefix' URI must read the host
    and token from dbutils secrets using scope 'profile' and keys prefixed with 'prefix-'.
    """
    provider = mock.MagicMock()
    provider.get_config.return_value = None
    ProfileConfigProvider.return_value = provider

    dbutils_stub = mock.MagicMock()
    dbutils_stub.secrets.get.return_value = 'random'
    dbutils_patch = mock.patch(
        "mlflow.utils.databricks_utils._get_dbutils", return_value=dbutils_stub)

    with dbutils_patch:
        creds = databricks_utils.get_databricks_host_creds("databricks://profile/prefix")
        for secret_key in ('prefix-host', 'prefix-token'):
            dbutils_stub.secrets.get.assert_any_call(key=secret_key, scope='profile')
        assert creds.host == 'random'
        assert creds.token == 'random'
Beispiel #9
0
def _get_databricks_env_vars(tracking_uri):
    """
    Build the environment variables that expose the Databricks credentials for ``tracking_uri``.

    :param tracking_uri: Tracking URI to resolve credentials for.
    :return: An empty dict when ``tracking_uri`` is not a Databricks URI. Otherwise a dict
             holding the tracking URI variable (set to 'databricks'), the host, and each of
             username, password, token and the insecure flag that is truthy in the resolved
             credentials.
    """
    if not kiwi.utils.uri.is_databricks_uri(tracking_uri):
        return {}

    creds = databricks_utils.get_databricks_host_creds(tracking_uri)
    # We set these via environment variables so that only the current profile is exposed, rather
    # than all profiles in ~/.databrickscfg; maybe better would be to mount the necessary
    # part of ~/.databrickscfg into the container
    env_vars = {
        tracking._TRACKING_URI_ENV_VAR: 'databricks',
        'DATABRICKS_HOST': creds.host,
    }
    optional_credentials = (
        ('DATABRICKS_USERNAME', 'username'),
        ('DATABRICKS_PASSWORD', 'password'),
        ('DATABRICKS_TOKEN', 'token'),
    )
    for env_name, creds_attr in optional_credentials:
        # Only export the credential fields that are actually set.
        if getattr(creds, creds_attr):
            env_vars[env_name] = getattr(creds, creds_attr)
    if creds.ignore_tls_verification:
        # Environment values must be strings, hence the explicit conversion.
        env_vars['DATABRICKS_INSECURE'] = str(creds.ignore_tls_verification)
    return env_vars
Beispiel #10
0
def _get_databricks_rest_store(store_uri, **_):
    """
    Create a ``DatabricksRestStore`` for ``store_uri``.

    The store receives a zero-argument callable that resolves the host credentials for
    ``store_uri`` each time it is invoked. Any extra keyword arguments are accepted and ignored.
    """
    def _resolve_host_creds():
        return get_databricks_host_creds(store_uri)

    return DatabricksRestStore(_resolve_host_creds)
 def _call_endpoint(self, service, api, json_body):
     """
     Call the REST endpoint registered for ``api`` of ``service`` with ``json_body``.

     Credentials are resolved from the current tracking URI, and the endpoint and HTTP method
     are looked up in ``_SERVICE_AND_METHOD_TO_INFO``. A fresh ``api.Response()`` instance is
     handed to ``call_endpoint``, whose result is returned.
     """
     creds = get_databricks_host_creds(kiwi.tracking.get_tracking_uri())
     endpoint_info = _SERVICE_AND_METHOD_TO_INFO[service][api]
     endpoint, http_method = endpoint_info
     empty_response = api.Response()
     return call_endpoint(creds, endpoint, http_method, json_body, empty_response)
Beispiel #12
0
def test_databricks_single_slash_in_uri_scheme_throws(get_config):
    """A URI with a single slash after the 'databricks:' scheme must be rejected."""
    get_config.return_value = None
    single_slash_uri = "databricks:/profile/path"
    with pytest.raises(Exception):
        databricks_utils.get_databricks_host_creds(single_slash_uri)
Beispiel #13
0
def test_databricks_empty_uri(get_config):
    """An empty URI with no config available must make credential resolution raise."""
    get_config.return_value = None
    empty_uri = ""
    with pytest.raises(Exception):
        databricks_utils.get_databricks_host_creds(empty_uri)
Beispiel #14
0
def test_databricks_params_no_verify(get_config):
    """A config created with insecure=True must yield credentials that skip TLS verification."""
    insecure_config = DatabricksConfig("host", "user", "pass", None, insecure=True)
    get_config.return_value = insecure_config
    creds = databricks_utils.get_databricks_host_creds()
    assert creds.ignore_tls_verification