def get_databricks_host_creds(server_uri=None):
    """
    Build the host credentials needed to issue HTTP requests to a Databricks server.

    The Databricks CLI ``provider`` module is consulted first to obtain a DatabricksConfig
    for the profile named in ``server_uri``. When that yields no usable host and the URI
    carries a path component (i.e. it looks like "databricks://profile/prefix"), the
    Databricks Secret Manager is queried instead: the scope is "<profile>" and the secret
    keys are "<prefix>-host" and "<prefix>-token". The prefix *cannot* be empty for this
    fallback to be attempted.

    If neither source yields a host plus either a username/password pair or a token,
    ``_fail_malformed_databricks_auth`` is invoked with the profile.

    :param server_uri: A URI that specifies the Databricks profile you want to use for making
                       requests.
    :return: :py:class:`mlflow.rest_utils.MlflowHostCreds` which includes the hostname and
             authentication information necessary to talk to the Databricks server.
    """
    profile, path = get_db_info_from_uri(server_uri)

    if hasattr(provider, 'get_config'):
        if profile:
            config = provider.ProfileConfigProvider(profile).get_config()
        else:
            config = provider.get_config()
    else:
        # Older databricks-cli releases only expose the per-profile accessor.
        _logger.warning(
            "Support for databricks-cli<0.8.0 is deprecated and will be removed"
            " in a future version.")
        config = provider.get_config_for_profile(profile)

    # if a path is specified, that implies a Databricks tracking URI of the form:
    # databricks://profile-name/path-specifier
    if not (config and config.host) and path:
        dbutils = _get_dbutils()
        if dbutils:
            # Prefix differentiates users and is provided as path information in the URI
            secret_host = dbutils.secrets.get(scope=profile, key=path + "-host")
            secret_token = dbutils.secrets.get(scope=profile, key=path + "-token")
            if secret_host and secret_token:
                config = provider.DatabricksConfig.from_token(
                    host=secret_host, token=secret_token, insecure=False)

    if not config or not config.host:
        _fail_malformed_databricks_auth(profile)

    # Configs that lack an ``insecure`` attribute are treated as TLS-verifying.
    insecure = config.insecure if hasattr(config, 'insecure') else False

    has_basic_auth = config.username is not None and config.password is not None
    if has_basic_auth:
        return MlflowHostCreds(config.host, username=config.username,
                               password=config.password,
                               ignore_tls_verification=insecure)
    if config.token:
        return MlflowHostCreds(config.host, token=config.token,
                               ignore_tls_verification=insecure)
    _fail_malformed_databricks_auth(profile)
def test_429_retries(request):
    """
    Exercise ``http_request`` against queued sequences of 429 / non-429 responses.

    ``request`` is the mocked request callable; each scenario queues one canned response
    per status code and asserts on the status code that ``http_request`` finally returns
    (or on the exception it raises).
    """
    class MockedResponse(object):
        def __init__(self, status_code):
            self.status_code = status_code
            self.text = "mocked text"

    host_only = MlflowHostCreds("http://my-host", ignore_tls_verification=True)

    def final_status(codes, **http_kwargs):
        # Queue the canned responses, then issue a single logical request.
        request.side_effect = [MockedResponse(code) for code in codes]
        return http_request(host_only, '/my/endpoint', **http_kwargs).status_code

    assert final_status((429, 200), max_rate_limit_interval=0) == 429
    assert final_status((429, 200), max_rate_limit_interval=1) == 200
    assert final_status((429, 429, 200), max_rate_limit_interval=1) == 429
    assert final_status((429, 429, 200), max_rate_limit_interval=2) == 200
    assert final_status((429, 429, 200), max_rate_limit_interval=3) == 200

    # Test that any non 429 code is returned
    assert final_status((429, 404, 429, 200)) == 404

    # Test that retries work as expected
    with pytest.raises(MlflowException, match="failed to return code 200"):
        final_status((429, 503, 429, 200), retries=1)
    assert final_status((429, 503, 429, 200), retries=2) == 200
def test_ignore_tls_verification_not_server_cert_path():
    """Constructing creds with both TLS options set is expected to raise MlflowException."""
    conflicting_options = dict(
        ignore_tls_verification=True,
        server_cert_path='/some/path',
    )
    with pytest.raises(MlflowException):
        MlflowHostCreds("http://my-host", **conflicting_options)
def test_malformed_json_error_response(response_mock):
    """
    ``call_endpoint`` is expected to raise MlflowException for the supplied response.

    ``response_mock`` is provided by the caller and is returned by the patched
    ``requests.request``.
    """
    with mock.patch('requests.request', return_value=response_mock):
        creds = MlflowHostCreds("http://my-host")
        empty_proto = GetRun.Response()
        with pytest.raises(MlflowException):
            call_endpoint(creds, '/my/endpoint', 'GET', "", empty_proto)
def test_init_validation_and_cleaning(self):
    """
    A 'dbfs:/' URI has its trailing slash stripped; a non-DBFS URI is rejected.
    """
    patch_target = DBFS_ARTIFACT_REPOSITORY_PACKAGE + '._get_host_creds_from_default_store'
    with mock.patch(patch_target) as get_creds_mock:
        get_creds_mock.return_value = lambda: MlflowHostCreds('http://host')

        cleaned_repo = get_artifact_repository('dbfs:/test/')
        assert cleaned_repo.artifact_uri == 'dbfs:/test'

        with pytest.raises(MlflowException):
            DbfsRestArtifactRepository('s3://test')
def test_failed_http_request_custom_handler(self, request):
    """
    A 404 error response is expected to surface as ``MyCoolException`` when the store is
    ``CustomErrorHandlingRestStore``.

    ``request`` is the mocked request callable whose return value is the canned response.
    """
    # Instantiate the mock: assigning attributes to the ``MagicMock`` class itself would
    # leak ``status_code`` / ``text`` into every other mock created in the test session.
    response = mock.MagicMock()
    response.status_code = 404
    response.text = '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "No experiment"}'
    request.return_value = response

    store = CustomErrorHandlingRestStore(lambda: MlflowHostCreds('https://hello'))
    with pytest.raises(MyCoolException):
        store.list_experiments()
def test_failed_http_request(self, request):
    """
    A 404 error response is expected to raise MlflowException carrying the error code
    and message from the JSON body.

    ``request`` is the mocked request callable whose return value is the canned response.
    """
    # Instantiate the mock: assigning attributes to the ``MagicMock`` class itself would
    # leak ``status_code`` / ``text`` into every other mock created in the test session.
    response = mock.MagicMock()
    response.status_code = 404
    response.text = '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "No experiment"}'
    request.return_value = response

    store = RestStore(lambda: MlflowHostCreds('https://hello'))
    with pytest.raises(MlflowException) as cm:
        store.list_experiments()
    assert "RESOURCE_DOES_NOT_EXIST: No experiment" in str(cm.value)
def test_http_request_hostonly(request):
    """Host-only creds: the request goes out with TLS verification on and default headers."""
    request.return_value = mock.MagicMock(status_code=200)
    creds = MlflowHostCreds("http://my-host")

    http_request(creds, '/my/endpoint')

    expected_call = dict(
        url='http://my-host/my/endpoint',
        verify=True,
        headers=_DEFAULT_HEADERS,
    )
    request.assert_called_with(**expected_call)
def test_well_formed_json_error_response():
    """A 400 response whose body is valid JSON is expected to raise RestException."""
    # well-formed JSON error response
    error_response = mock.MagicMock(status_code=400, text="{}")
    with mock.patch('requests.request', return_value=error_response):
        creds = MlflowHostCreds("http://my-host")
        empty_proto = GetRun.Response()
        with pytest.raises(RestException):
            call_endpoint(creds, '/my/endpoint', 'GET', "", empty_proto)
def test_http_request_with_insecure(request):
    """With ``ignore_tls_verification=True`` the request is issued with ``verify=False``."""
    request.return_value = mock.MagicMock(status_code=200)
    creds = MlflowHostCreds("http://my-host", ignore_tls_verification=True)

    http_request(creds, '/my/endpoint')

    expected_call = dict(
        url='http://my-host/my/endpoint',
        verify=False,
        headers=_DEFAULT_HEADERS,
    )
    request.assert_called_with(**expected_call)
def test_http_request_cleans_hostname(request):
    """A trailing slash on the host must not produce a double slash in the request URL."""
    request.return_value = mock.MagicMock(status_code=200)
    # Add a trailing slash, should be removed.
    creds = MlflowHostCreds("http://my-host/")

    http_request(creds, '/my/endpoint')

    expected_call = dict(
        url='http://my-host/my/endpoint',
        verify=True,
        headers=_DEFAULT_HEADERS,
    )
    request.assert_called_with(**expected_call)
def test_http_request_with_basic_auth(request):
    """
    Username/password creds are expected to be sent as an HTTP Basic ``Authorization``
    header alongside the default headers.

    ``request`` is the mocked request callable.
    """
    # The credentials must match the expected header below: 'dXNlcjpwYXNz' is the base64
    # encoding of 'user:pass'. Masked placeholder credentials can never produce it.
    host_only = MlflowHostCreds("http://my-host", username='user', password='pass')
    response = mock.MagicMock()
    response.status_code = 200
    request.return_value = response

    http_request(host_only, '/my/endpoint')

    headers = dict(_DEFAULT_HEADERS)
    headers['Authorization'] = 'Basic dXNlcjpwYXNz'
    request.assert_called_with(
        url='http://my-host/my/endpoint',
        verify=True,
        headers=headers,
    )
def test_http_request_with_token(request):
    """Token creds are expected to be sent as a Bearer ``Authorization`` header."""
    request.return_value = mock.MagicMock(status_code=200)
    creds = MlflowHostCreds("http://my-host", token='my-token')

    http_request(creds, '/my/endpoint')

    # Default headers plus the bearer token.
    expected_headers = dict(_DEFAULT_HEADERS, Authorization='Bearer my-token')
    request.assert_called_with(
        url='http://my-host/my/endpoint',
        verify=True,
        headers=expected_headers,
    )
def test_response_with_unknown_fields(self, request):
    """
    An experiment payload containing a field the client does not know about is still
    expected to be parsed into a single experiment with the right name.

    ``request`` is the mocked request callable whose return value is the canned response.
    """
    experiment_json = {
        "experiment_id": "1",
        "name": "My experiment",
        "artifact_location": "foo",
        "lifecycle_stage": "deleted",
        "OMG_WHAT_IS_THIS_FIELD": "Hooly cow",
    }

    # Instantiate the mock: assigning attributes to the ``MagicMock`` class itself would
    # leak ``status_code`` / ``text`` into every other mock created in the test session.
    response = mock.MagicMock()
    response.status_code = 200
    response.text = json.dumps({"experiments": [experiment_json]})
    request.return_value = response

    store = RestStore(lambda: MlflowHostCreds('https://hello'))
    experiments = store.list_experiments()
    assert len(experiments) == 1
    assert experiments[0].name == 'My experiment'
def test_databricks_rest_store_get_experiment_by_name(self):
    """
    ``DatabricksRestStore.get_experiment_by_name`` must propagate an INTERNAL_ERROR
    raised by the HTTP layer after exactly one request.
    """
    creds = MlflowHostCreds('https://hello')
    store = DatabricksRestStore(lambda: creds)

    def raise_internal_error(*args, **kwargs):  # pylint: disable=unused-argument
        raise MlflowException("Some internal error!", INTERNAL_ERROR)

    with mock.patch('mlflow.utils.rest_utils.http_request',
                    side_effect=raise_internal_error) as mock_http:
        # Verify that Databricks REST client won't fall back to ListExperiments for 500-level
        # errors that are not ENDPOINT_NOT_FOUND
        with pytest.raises(MlflowException) as exc_info:
            store.get_experiment_by_name("abc")

        raised = exc_info.value
        assert raised.error_code == ErrorCode.Name(INTERNAL_ERROR)
        assert raised.message == "Some internal error!"

        expected_request = GetExperimentByName(experiment_name="abc")
        self._verify_requests(mock_http, creds,
                              "experiments/get-by-name", "GET",
                              message_to_json(expected_request))
        assert mock_http.call_count == 1
def test_http_request_wrapper(request):
    """
    ``http_request_safe`` passes a 200 through, and raises for a 400: MlflowException when
    the body is empty, RestException when the body is a JSON error payload.
    """
    creds = MlflowHostCreds("http://my-host", ignore_tls_verification=True)
    response = mock.MagicMock(status_code=200)
    request.return_value = response

    # Success path: the call is forwarded with TLS verification disabled.
    http_request_safe(creds, '/my/endpoint')
    request.assert_called_with(
        url='http://my-host/my/endpoint',
        verify=False,
        headers=_DEFAULT_HEADERS,
    )

    # Error status with an empty body.
    response.status_code = 400
    response.text = ""
    with pytest.raises(MlflowException, match="Response body"):
        http_request_safe(creds, '/my/endpoint')

    # Error status with a structured JSON error body.
    response.text = \
        '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "Node type not supported"}'
    with pytest.raises(RestException, match="RESOURCE_DOES_NOT_EXIST: Node type not supported"):
        http_request_safe(creds, '/my/endpoint')
def test_successful_http_request(self, request):
    """
    ``RestStore.list_experiments`` is expected to issue the right GET request and parse
    the returned experiment.

    ``request`` is the mocked request callable; its side effect validates the keyword
    arguments of the outgoing call and returns a canned 200 response.
    """
    def mock_request(**kwargs):
        # Filter out None arguments
        kwargs = dict((k, v) for k, v in six.iteritems(kwargs) if v is not None)
        assert kwargs == {
            'method': 'GET',
            'params': {'view_type': 'ACTIVE_ONLY'},
            'url': 'https://hello/api/2.0/mlflow/experiments/list',
            'headers': _DEFAULT_HEADERS,
            'verify': True,
        }
        # Return a mock *instance*: assigning attributes to the ``MagicMock`` class itself
        # would leak ``status_code`` / ``text`` into every other mock in the test session.
        response = mock.MagicMock()
        response.status_code = 200
        response.text = '{"experiments": [{"name": "Exp!", "lifecycle_stage": "active"}]}'
        return response

    request.side_effect = mock_request

    store = RestStore(lambda: MlflowHostCreds('https://hello'))
    experiments = store.list_experiments()
    assert experiments[0].name == "Exp!"
def host_creds_mock():
    """
    Generator that keeps ``_get_host_creds_from_default_store`` patched while suspended.

    While the patch is active, the patched function returns a callable producing creds
    for 'http://host'. Presumably registered as a pytest fixture by a decorator outside
    this view -- confirm.
    """
    target = 'mlflow.store.artifact.dbfs_artifact_repo._get_host_creds_from_default_store'
    with mock.patch(target) as get_creds_mock:
        get_creds_mock.return_value = lambda: MlflowHostCreds('http://host')
        yield
def test_get_host_creds_from_default_store_rest_store():
    """When the default store is a RestStore, the returned getter yields MlflowHostCreds."""
    target = 'mlflow.tracking._tracking_service.utils._get_store'
    with mock.patch(target) as get_store_mock:
        get_store_mock.return_value = RestStore(lambda: MlflowHostCreds('http://host'))
        creds_getter = _get_host_creds_from_default_store()
        assert isinstance(creds_getter(), MlflowHostCreds)
def dbfs_artifact_repo():
    """
    Return the artifact repository for 'dbfs:/test/', built while
    ``_get_host_creds_from_default_store`` is patched to yield creds for 'http://host'.
    """
    target = 'mlflow.store.artifact.dbfs_artifact_repo._get_host_creds_from_default_store'
    with mock.patch(target) as get_creds_mock:
        get_creds_mock.return_value = lambda: MlflowHostCreds('http://host')
        repo = get_artifact_repository('dbfs:/test/')
        return repo
def setUp(self):
    """Create the host creds and a RestStore that looks them up through ``self.creds``."""
    host_creds = MlflowHostCreds('https://hello')
    self.creds = host_creds
    # The getter reads ``self.creds`` at call time rather than capturing the local.
    self.store = RestStore(lambda: self.creds)
def test_requestor(self, request):
    """
    Verify the endpoint, HTTP method and JSON body that ``RestStore`` sends for each
    tracking operation (create run, log param/metric/batch, tags, delete/restore,
    search runs, log model).

    ``request`` is the mocked request callable. Every store call goes through a patched
    ``mlflow.utils.rest_utils.http_request`` which returns an explicit canned response.
    """
    def ok_response(text='{}'):
        # Build a mock *instance*. Assigning attributes to the ``MagicMock`` class itself
        # would leak ``status_code`` / ``text`` into every other mock in the test session.
        return mock.MagicMock(status_code=200, text=text)

    def patch_http(text='{}'):
        # Patch the HTTP layer with an explicit 200 response instead of relying on
        # attributes inherited from a mutated ``MagicMock`` class.
        return mock.patch('mlflow.utils.rest_utils.http_request',
                          return_value=ok_response(text))

    request.return_value = ok_response()
    creds = MlflowHostCreds('https://hello')
    store = RestStore(lambda: creds)

    user_name = "mock user"
    source_name = "rest test"

    source_name_patch = mock.patch(
        "mlflow.tracking.context.default_context._get_source_name", return_value=source_name
    )
    source_type_patch = mock.patch(
        "mlflow.tracking.context.default_context._get_source_type",
        return_value=SourceType.LOCAL
    )
    with patch_http() as mock_http, \
            mock.patch('mlflow.tracking._tracking_service.utils._get_store',
                       return_value=store), \
            mock.patch('mlflow.tracking.context.default_context._get_user',
                       return_value=user_name), \
            mock.patch('time.time', return_value=13579), \
            source_name_patch, source_type_patch:
        with kiwi.start_run(experiment_id="43"):
            cr_body = message_to_json(CreateRun(
                experiment_id="43", user_id=user_name, start_time=13579000,
                tags=[ProtoRunTag(key='mlflow.source.name', value=source_name),
                      ProtoRunTag(key='mlflow.source.type', value='LOCAL'),
                      ProtoRunTag(key='mlflow.user', value=user_name)]))
            expected_kwargs = self._args(creds, "runs/create", "POST", cr_body)

            # Checked while the run is still open, i.e. before any request made on exit.
            assert mock_http.call_count == 1
            actual_kwargs = mock_http.call_args[1]

            # Test the passed tag values separately from the rest of the request
            # Tag order is inconsistent on Python 2 and 3, but the order does not matter
            expected_tags = expected_kwargs['json'].pop('tags')
            actual_tags = actual_kwargs['json'].pop('tags')
            assert (
                sorted(expected_tags, key=lambda t: t['key']) ==
                sorted(actual_tags, key=lambda t: t['key'])
            )
            assert expected_kwargs == actual_kwargs

    with patch_http() as mock_http:
        store.log_param("some_uuid", Param("k1", "v1"))
        body = message_to_json(LogParam(
            run_uuid="some_uuid", run_id="some_uuid", key="k1", value="v1"))
        self._verify_requests(mock_http, creds,
                              "runs/log-parameter", "POST", body)

    with patch_http() as mock_http:
        store.set_experiment_tag("some_id", ExperimentTag("t1", "abcd"*1000))
        body = message_to_json(SetExperimentTag(
            experiment_id="some_id",
            key="t1",
            value="abcd"*1000))
        self._verify_requests(mock_http, creds,
                              "experiments/set-experiment-tag", "POST", body)

    with patch_http() as mock_http:
        store.set_tag("some_uuid", RunTag("t1", "abcd"*1000))
        body = message_to_json(SetTag(
            run_uuid="some_uuid", run_id="some_uuid", key="t1", value="abcd"*1000))
        self._verify_requests(mock_http, creds,
                              "runs/set-tag", "POST", body)

    with patch_http() as mock_http:
        store.delete_tag("some_uuid", "t1")
        body = message_to_json(DeleteTag(run_id="some_uuid", key="t1"))
        self._verify_requests(mock_http, creds,
                              "runs/delete-tag", "POST", body)

    with patch_http() as mock_http:
        store.log_metric("u2", Metric("m1", 0.87, 12345, 3))
        body = message_to_json(LogMetric(
            run_uuid="u2", run_id="u2", key="m1", value=0.87, timestamp=12345, step=3))
        self._verify_requests(mock_http, creds,
                              "runs/log-metric", "POST", body)

    with patch_http() as mock_http:
        metrics = [Metric("m1", 0.87, 12345, 0), Metric("m2", 0.49, 12345, -1),
                   Metric("m3", 0.58, 12345, 2)]
        params = [Param("p1", "p1val"), Param("p2", "p2val")]
        tags = [RunTag("t1", "t1val"), RunTag("t2", "t2val")]
        store.log_batch(run_id="u2", metrics=metrics, params=params, tags=tags)
        metric_protos = [metric.to_proto() for metric in metrics]
        param_protos = [param.to_proto() for param in params]
        tag_protos = [tag.to_proto() for tag in tags]
        body = message_to_json(LogBatch(run_id="u2", metrics=metric_protos,
                                        params=param_protos, tags=tag_protos))
        self._verify_requests(mock_http, creds,
                              "runs/log-batch", "POST", body)

    with patch_http() as mock_http:
        store.delete_run("u25")
        self._verify_requests(mock_http, creds,
                              "runs/delete", "POST",
                              message_to_json(DeleteRun(run_id="u25")))

    with patch_http() as mock_http:
        store.restore_run("u76")
        self._verify_requests(mock_http, creds,
                              "runs/restore", "POST",
                              message_to_json(RestoreRun(run_id="u76")))

    with patch_http() as mock_http:
        store.delete_experiment("0")
        self._verify_requests(mock_http, creds,
                              "experiments/delete", "POST",
                              message_to_json(DeleteExperiment(experiment_id="0")))

    with patch_http() as mock_http:
        store.restore_experiment("0")
        self._verify_requests(mock_http, creds,
                              "experiments/restore", "POST",
                              message_to_json(RestoreExperiment(experiment_id="0")))

    search_response_text = '{"runs": ["1a", "2b", "3c"], "next_page_token": "67890fghij"}'
    with patch_http(search_response_text) as mock_http:
        result = store.search_runs(["0", "1"], "params.p1 = 'a'", ViewType.ACTIVE_ONLY,
                                   max_results=10, order_by=["a"], page_token="12345abcde")

        expected_message = SearchRuns(experiment_ids=["0", "1"], filter="params.p1 = 'a'",
                                      run_view_type=ViewType.to_proto(ViewType.ACTIVE_ONLY),
                                      max_results=10, order_by=["a"],
                                      page_token="12345abcde")
        self._verify_requests(mock_http, creds,
                              "runs/search", "POST",
                              message_to_json(expected_message))
        assert result.token == "67890fghij"

    with patch_http() as mock_http:
        run_id = "run_id"
        m = Model(artifact_path="model/path", run_id="run_id", flavors={"tf": "flavor body"})
        store.record_logged_model("run_id", m)
        expected_message = LogModel(run_id=run_id, model_json=m.to_json())
        self._verify_requests(mock_http, creds,
                              "runs/log-model", "POST",
                              message_to_json(expected_message))
def test_get_experiment_by_name(self, store_class):
    """
    Exercise ``get_experiment_by_name`` on ``store_class`` against four server behaviors:
    a successful lookup, a RESOURCE_DOES_NOT_EXIST error (expects ``None``), an
    ENDPOINT_NOT_FOUND error (expects a fallback to the experiments/list endpoint), and
    a REQUEST_LIMIT_EXCEEDED error (expects the exception to propagate, no fallback).

    :param store_class: Store class under test; it is constructed with a creds getter.
    """
    creds = MlflowHostCreds('https://hello')
    store = store_class(lambda: creds)
    with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
        experiment = Experiment(
            experiment_id="123", name="abc", artifact_location="/abc",
            lifecycle_stage=LifecycleStage.ACTIVE)

        # Each canned response is its own mock *instance*. Assigning attributes to the
        # ``MagicMock`` class would make all three "responses" the same object and leak
        # ``status_code`` / ``text`` into every other mock in the test session.
        response = mock.MagicMock()
        response.status_code = 200
        response.text = json.dumps({
            "experiment": json.loads(message_to_json(experiment.to_proto()))})
        mock_http.return_value = response
        result = store.get_experiment_by_name("abc")
        expected_message0 = GetExperimentByName(experiment_name="abc")
        self._verify_requests(mock_http, creds,
                              "experiments/get-by-name", "GET",
                              message_to_json(expected_message0))
        assert result.experiment_id == experiment.experiment_id
        assert result.name == experiment.name
        assert result.artifact_location == experiment.artifact_location
        assert result.lifecycle_stage == experiment.lifecycle_stage

        # Test GetExperimentByName against nonexistent experiment
        mock_http.reset_mock()
        nonexistent_exp_response = mock.MagicMock()
        nonexistent_exp_response.status_code = 404
        nonexistent_exp_response.text = \
            MlflowException("Exp doesn't exist!", RESOURCE_DOES_NOT_EXIST).serialize_as_json()
        mock_http.return_value = nonexistent_exp_response
        assert store.get_experiment_by_name("nonexistent-experiment") is None
        expected_message1 = GetExperimentByName(experiment_name="nonexistent-experiment")
        self._verify_requests(mock_http, creds,
                              "experiments/get-by-name", "GET",
                              message_to_json(expected_message1))
        assert mock_http.call_count == 1

        # Test REST client behavior against a mocked old server, which has handler for
        # ListExperiments but not GetExperimentByName
        mock_http.reset_mock()
        list_exp_response = mock.MagicMock()
        list_exp_response.text = json.dumps({
            "experiments": [json.loads(message_to_json(experiment.to_proto()))]})
        list_exp_response.status_code = 200

        def response_fn(*args, **kwargs):  # pylint: disable=unused-argument
            if kwargs.get('endpoint') == "/api/2.0/mlflow/experiments/get-by-name":
                raise MlflowException("GetExperimentByName is not implemented",
                                      ENDPOINT_NOT_FOUND)
            else:
                return list_exp_response

        mock_http.side_effect = response_fn
        result = store.get_experiment_by_name("abc")
        expected_message2 = ListExperiments(view_type=ViewType.ALL)
        self._verify_requests(mock_http, creds,
                              "experiments/get-by-name", "GET",
                              message_to_json(expected_message0))
        self._verify_requests(mock_http, creds,
                              "experiments/list", "GET",
                              message_to_json(expected_message2))
        assert result.experiment_id == experiment.experiment_id
        assert result.name == experiment.name
        assert result.artifact_location == experiment.artifact_location
        assert result.lifecycle_stage == experiment.lifecycle_stage

        # Verify that REST client won't fall back to ListExperiments for 429 errors (hitting
        # rate limits)
        mock_http.reset_mock()

        def rate_limit_response_fn(*args, **kwargs):  # pylint: disable=unused-argument
            raise MlflowException("Hit rate limit on GetExperimentByName",
                                  REQUEST_LIMIT_EXCEEDED)

        mock_http.side_effect = rate_limit_response_fn
        with pytest.raises(MlflowException) as exc_info:
            store.get_experiment_by_name("imspamming")
        assert exc_info.value.error_code == ErrorCode.Name(REQUEST_LIMIT_EXCEEDED)
        assert mock_http.call_count == 1