def test_init_get_host_creds_with_databricks_profile_uri(self):
    """A profile-qualified dbfs URI uses Databricks creds; a plain dbfs URI uses the default store."""
    databricks_host = 'https://something.databricks.com'
    default_host = 'http://host'
    default_creds_patch = mock.patch(
        DBFS_ARTIFACT_REPOSITORY_PACKAGE + '._get_host_creds_from_default_store',
        return_value=lambda: MlflowHostCreds(default_host))
    databricks_creds_patch = mock.patch(
        DBFS_ARTIFACT_REPOSITORY_PACKAGE + '.get_databricks_host_creds',
        return_value=MlflowHostCreds(databricks_host))
    with default_creds_patch, databricks_creds_patch:
        profile_repo = DbfsRestArtifactRepository('dbfs://profile@databricks/test/')
        # The profile part is stripped from the stored artifact URI.
        assert profile_repo.artifact_uri == 'dbfs:/test/'
        assert profile_repo.get_host_creds().host == databricks_host

        # No databricks_profile_uri given: creds come from the (patched) default store.
        plain_repo = DbfsRestArtifactRepository('dbfs:/test/')
        assert plain_repo.get_host_creds().host == default_host
def test_http_request_wrapper(request):
    """http_request_safe forwards 200 responses and raises on 400 responses.

    ``request`` is the mocked ``requests.request`` injected by the test harness.
    """
    creds = MlflowHostCreds("http://my-host", ignore_tls_verification=True)
    fake_response = mock.MagicMock()
    request.return_value = fake_response

    def assert_forwarded_request():
        request.assert_called_with(
            url="http://my-host/my/endpoint",
            verify=False,
            headers=_DEFAULT_HEADERS,
        )

    # A 200 response succeeds whether or not its body is valid JSON.
    fake_response.status_code = 200
    for body in ("{}", "non json"):
        fake_response.text = body
        http_request_safe(creds, "/my/endpoint")
        assert_forwarded_request()

    # A 400 with an empty body raises an MlflowException mentioning the response body.
    fake_response.status_code = 400
    fake_response.text = ""
    with pytest.raises(MlflowException, match="Response body"):
        http_request_safe(creds, "/my/endpoint")

    # A 400 with a structured error body raises a RestException carrying code and message.
    fake_response.text = (
        '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "Node type not supported"}'
    )
    with pytest.raises(
            RestException, match="RESOURCE_DOES_NOT_EXIST: Node type not supported"):
        http_request_safe(creds, "/my/endpoint")
def test_databricks_paginate_list_experiments(self):
    """_paginate_list_experiments yields one page per mocked response, following the tokens.

    Three responses are mocked with next_page_token "a", "b" and finally None. The test
    checks that exactly three pages are produced, each carrying the matching token.
    """
    creds = MlflowHostCreds("https://hello")
    store = DatabricksRestStore(lambda: creds)
    next_page_tokens = ["a", "b", None]

    def make_page_response(token):
        # One experiment per page; its name encodes the page token so pages are distinguishable.
        experiment = Experiment(
            experiment_id="123",
            name=str(token),
            artifact_location="/abc",
            lifecycle_stage=LifecycleStage.ACTIVE,
        )
        response = mock.MagicMock()
        response.text = json.dumps({
            "experiments": [json.loads(message_to_json(experiment.to_proto()))],
            "next_page_token": token,
        })
        response.status_code = 200
        return response

    list_exp_responses = [make_page_response(token) for token in next_page_tokens]
    with mock.patch("mlflow.utils.rest_utils.http_request", side_effect=list_exp_responses):
        pages = list(store._paginate_list_experiments(ViewType.ACTIVE_ONLY))
        # Without this, a paginator that yields too few pages would pass the loop below vacuously.
        assert len(pages) == len(next_page_tokens)
        for experiments, expected_token in zip(pages, next_page_tokens):
            assert experiments[0].name == str(expected_token)
            assert experiments.token == expected_token
def test_ignore_tls_verification_not_server_cert_path():
    """Passing both ignore_tls_verification=True and server_cert_path is rejected."""
    conflicting_options = {
        "ignore_tls_verification": True,
        "server_cert_path": "/some/path",
    }
    with pytest.raises(MlflowException):
        MlflowHostCreds("http://my-host", **conflicting_options)
def test_get_host_creds_from_default_store_rest_store():
    """With a RestStore as the default store, the returned provider yields MlflowHostCreds."""
    default_store = RestStore(lambda: MlflowHostCreds("http://host"))
    with mock.patch(
        "mlflow.tracking._tracking_service.utils._get_store", return_value=default_store
    ):
        creds_provider = _get_host_creds_from_default_store()
        assert isinstance(creds_provider(), MlflowHostCreds)
def test_429_retries(request):
    """http_request retries 429 responses and honours max_rate_limit_interval and retries.

    ``request`` is the mocked ``requests.request``. Each scenario queues a sequence of
    status codes, which successive calls return in order.
    """
    host_only = MlflowHostCreds("http://my-host", ignore_tls_verification=True)

    class _FakeResponse(object):
        def __init__(self, status_code):
            self.status_code = status_code
            self.text = "mocked text"

    def respond_with(*status_codes):
        request.side_effect = [_FakeResponse(code) for code in status_codes]

    def final_status(**kwargs):
        return http_request(host_only, "/my/endpoint", **kwargs).status_code

    # Larger max_rate_limit_interval values tolerate more consecutive 429s before giving up.
    rate_limit_scenarios = [
        ((429, 200), 0, 429),
        ((429, 200), 1, 200),
        ((429, 429, 200), 1, 429),
        ((429, 429, 200), 2, 200),
        ((429, 429, 200), 3, 200),
    ]
    for codes, interval, expected_status in rate_limit_scenarios:
        respond_with(*codes)
        assert final_status(max_rate_limit_interval=interval) == expected_status

    # Test that any non 429 code is returned
    respond_with(429, 404, 429, 200)
    assert final_status() == 404

    # Test that retries work as expected
    respond_with(429, 503, 429, 200)
    with pytest.raises(MlflowException, match="failed to return code 200"):
        http_request(host_only, "/my/endpoint", retries=1)
    respond_with(429, 503, 429, 200)
    assert final_status(retries=2) == 200
def test_default_host_creds():
    """HttpArtifactRepository builds its host creds from the tracking environment variables."""
    artifact_uri = "https://test.com"
    creds_fields = dict(
        username="******",
        password="******",
        token="token",
        ignore_tls_verification=False,
        client_cert_path="client_cert_path",
        server_cert_path="server_cert_path",
    )
    expected_host_creds = MlflowHostCreds(host=artifact_uri, **creds_fields)
    repo = HttpArtifactRepository(artifact_uri)

    # Every credential field is supplied through its environment variable (booleans as str).
    env_overrides = {
        _TRACKING_USERNAME_ENV_VAR: creds_fields["username"],
        _TRACKING_PASSWORD_ENV_VAR: creds_fields["password"],
        _TRACKING_TOKEN_ENV_VAR: creds_fields["token"],
        _TRACKING_INSECURE_TLS_ENV_VAR: str(creds_fields["ignore_tls_verification"]),
        _TRACKING_CLIENT_CERT_PATH_ENV_VAR: creds_fields["client_cert_path"],
        _TRACKING_SERVER_CERT_PATH_ENV_VAR: creds_fields["server_cert_path"],
    }
    with mock.patch.dict("mlflow.tracking._tracking_service.utils.os.environ", env_overrides):
        assert repo._host_creds == expected_host_creds
def test_http_request_request_headers_user_agent(request):
    """This test requires the package in tests/resources/mlflow-test-plugin to be installed"""
    from mlflow_test_plugin.request_header_provider import PluginRequestHeaderProvider

    # The test plugin's request header provider always returns False from in_context to avoid
    # polluting request headers in developers' environments. The following mock overrides this to
    # perform the integration test.
    force_in_context = mock.patch.object(
        PluginRequestHeaderProvider, "in_context", return_value=True)
    plugin_user_agent = mock.patch.object(
        PluginRequestHeaderProvider,
        "request_headers",
        return_value={_USER_AGENT: "test_user_agent"},
    )
    with force_in_context, plugin_user_agent:
        creds = MlflowHostCreds("http://my-host", server_cert_path="/some/path")
        # The plugin's user agent is appended, space-separated, to the default one.
        default_user_agent = DefaultRequestHeaderProvider().request_headers()[_USER_AGENT]
        expected_headers = {_USER_AGENT: "{} {}".format(default_user_agent, "test_user_agent")}

        ok_response = mock.MagicMock()
        ok_response.status_code = 200
        request.return_value = ok_response

        http_request(creds, "/my/endpoint", "GET")
        request.assert_called_with(
            "GET",
            "http://my-host/my/endpoint",
            verify="/some/path",
            headers=expected_headers,
            timeout=120,
        )
def test_databricks_rest_store_get_experiment_by_name(self):
    """An INTERNAL_ERROR from get-by-name propagates without a fallback request."""
    creds = MlflowHostCreds("https://hello")
    store = DatabricksRestStore(lambda: creds)
    with mock.patch("mlflow.utils.rest_utils.http_request") as mock_http:
        # Verify that Databricks REST client won't fall back to ListExperiments for 500-level
        # errors that are not ENDPOINT_NOT_FOUND
        def internal_error_response_fn(*args, **kwargs):  # pylint: disable=unused-argument
            raise MlflowException("Some internal error!", INTERNAL_ERROR)

        mock_http.side_effect = internal_error_response_fn
        with pytest.raises(MlflowException) as exc_info:
            store.get_experiment_by_name("abc")
        raised = exc_info.value
        assert raised.error_code == ErrorCode.Name(INTERNAL_ERROR)
        assert raised.message == "Some internal error!"

        self._verify_requests(
            mock_http,
            creds,
            "experiments/get-by-name",
            "GET",
            message_to_json(GetExperimentByName(experiment_name="abc")),
        )
        # Exactly one HTTP call: no fallback request was issued.
        assert mock_http.call_count == 1
def test_http_request_request_headers(request):
    """This test requires the package in tests/resources/mlflow-test-plugin to be installed"""
    from mlflow_test_plugin.request_header_provider import PluginRequestHeaderProvider

    # The test plugin's request header provider always returns False from in_context to avoid
    # polluting request headers in developers' environments. The following mock overrides this to
    # perform the integration test.
    with mock.patch.object(PluginRequestHeaderProvider, "in_context", return_value=True):
        creds = MlflowHostCreds("http://my-host", server_cert_path="/some/path")
        ok_response = mock.MagicMock()
        ok_response.status_code = 200
        request.return_value = ok_response

        http_request(creds, "/my/endpoint")

        # The plugin's extra {"test": "header"} entry is merged over the default headers.
        expected_headers = dict(_DEFAULT_HEADERS, test="header")
        request.assert_called_with(
            url="http://my-host/my/endpoint",
            verify="/some/path",
            headers=expected_headers,
        )
def test_requestor(self, request):
    """Each RestStore write operation POSTs the expected JSON message to the expected endpoint.

    ``request`` is the mocked ``requests.request``. ``http_request_safe`` is patched per case,
    so the canned response below is only a safety net.
    """
    # Configure a MagicMock *instance*. Assigning attributes on the ``mock.MagicMock`` class
    # itself would leak ``status_code``/``text`` into every MagicMock created afterwards.
    response = mock.MagicMock()
    response.status_code = 200
    response.text = '{}'
    request.return_value = response

    creds = MlflowHostCreds('https://hello')
    store = RestStore(lambda: creds)

    # (store call, endpoint, expected request message)
    cases = [
        (lambda: store.log_param("some_uuid", Param("k1", "v1")),
         "runs/log-parameter",
         LogParam(run_uuid="some_uuid", key="k1", value="v1")),
        (lambda: store.set_tag("some_uuid", RunTag("t1", "abcd" * 1000)),
         "runs/set-tag",
         SetTag(run_uuid="some_uuid", key="t1", value="abcd" * 1000)),
        (lambda: store.log_metric("u2", Metric("m1", 0.87, 12345)),
         "runs/log-metric",
         LogMetric(run_uuid="u2", key="m1", value=0.87, timestamp=12345)),
        (lambda: store.delete_run("u25"),
         "runs/delete",
         DeleteRun(run_id="u25")),
        (lambda: store.restore_run("u76"),
         "runs/restore",
         RestoreRun(run_id="u76")),
        (lambda: store.delete_experiment(0),
         "experiments/delete",
         DeleteExperiment(experiment_id=0)),
        (lambda: store.restore_experiment(0),
         "experiments/restore",
         RestoreExperiment(experiment_id=0)),
    ]
    for call_store, endpoint, expected_message in cases:
        # A fresh patch per case, so each mock records only the request under test.
        with mock.patch('mlflow.store.rest_store.http_request_safe') as mock_http:
            call_store()
            self._verify_requests(mock_http, creds, endpoint, "POST",
                                  message_to_json(expected_message))
def test_init_validation_and_cleaning(self):
    """'dbfs:/test/' resolves to artifact URI 'dbfs:/test'; non-DBFS URIs are rejected."""
    def default_creds():
        return MlflowHostCreds('http://host')

    creds_patch = mock.patch(
        DBFS_ARTIFACT_REPOSITORY_PACKAGE + '._get_host_creds_from_default_store',
        return_value=default_creds)
    with creds_patch:
        resolved_repo = get_artifact_repository('dbfs:/test/')
        assert resolved_repo.artifact_uri == 'dbfs:/test'
        with pytest.raises(MlflowException):
            DbfsRestArtifactRepository('s3://test')
def test_malformed_json_error_response(response_mock):
    """call_endpoint raises MlflowException for the response supplied by ``response_mock``.

    ``response_mock`` is a fixture defined elsewhere; presumably it carries an error status
    with a body that is not valid JSON. TODO confirm against the fixture.
    """
    with mock.patch("requests.request", return_value=response_mock):
        creds = MlflowHostCreds("http://my-host")
        response_proto = GetRun.Response()
        with pytest.raises(MlflowException):
            call_endpoint(creds, "/my/endpoint", "GET", "", response_proto)
def test_http_request_hostonly(request):
    """With only a host configured, TLS is verified and only the default headers are sent."""
    ok_response = mock.MagicMock()
    ok_response.status_code = 200
    request.return_value = ok_response

    http_request(MlflowHostCreds("http://my-host"), "/my/endpoint")

    request.assert_called_with(
        url="http://my-host/my/endpoint",
        verify=True,
        headers=_DEFAULT_HEADERS,
    )
def test_failed_http_request_custom_handler(self, request):
    """A RestStore subclass can turn a failed REST response into its own exception type.

    ``request`` is the mocked ``requests.request``.
    """
    # Configure a MagicMock *instance*. Assigning attributes on the ``mock.MagicMock`` class
    # itself would leak ``status_code``/``text`` into every MagicMock created afterwards.
    response = mock.MagicMock()
    response.status_code = 404
    response.text = '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "No experiment"}'
    request.return_value = response

    store = CustomErrorHandlingRestStore(lambda: MlflowHostCreds('https://hello'))
    with pytest.raises(MyCoolException):
        store.list_experiments()
def test_init_validation_and_cleaning(self):
    """'dbfs:/test/' is stored as 'dbfs:/test'; a non-DBFS URI is rejected."""
    def default_creds():
        return MlflowHostCreds('http://host')

    creds_patch = mock.patch(
        'mlflow.store.dbfs_artifact_repo._get_host_creds_from_default_store',
        return_value=default_creds)
    with creds_patch:
        dbfs_repo = DbfsArtifactRepository('dbfs:/test/')
        assert dbfs_repo.artifact_uri == 'dbfs:/test'
        with pytest.raises(MlflowException):
            DbfsArtifactRepository('s3://test')
def test_failed_http_request(self, request):
    """A structured 404 error body surfaces as an MlflowException with code and message.

    ``request`` is the mocked ``requests.request``.
    """
    # Configure a MagicMock *instance*. Assigning attributes on the ``mock.MagicMock`` class
    # itself would leak ``status_code``/``text`` into every MagicMock created afterwards.
    response = mock.MagicMock()
    response.status_code = 404
    response.text = '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "No experiment"}'
    request.return_value = response

    store = RestStore(lambda: MlflowHostCreds('https://hello'))
    # ``match`` is searched in str(exception), the same check as a substring test here.
    with pytest.raises(MlflowException, match="RESOURCE_DOES_NOT_EXIST: No experiment"):
        store.list_experiments()
def test_http_request_cleans_hostname(request):
    """A trailing slash on the host is removed before the endpoint path is appended."""
    creds_with_trailing_slash = MlflowHostCreds("http://my-host/")
    ok_response = mock.MagicMock()
    ok_response.status_code = 200
    request.return_value = ok_response

    http_request(creds_with_trailing_slash, "/my/endpoint")

    # No doubled slash between host and endpoint.
    request.assert_called_with(
        url="http://my-host/my/endpoint",
        verify=True,
        headers=_DEFAULT_HEADERS,
    )
def test_http_request_with_basic_auth(request):
    """Username/password creds are sent as an HTTP Basic Authorization header.

    ``request`` is the mocked ``requests.request``.
    """
    import base64

    username = "******"
    password = "******"
    host_only = MlflowHostCreds("http://my-host", username=username, password=password)
    response = mock.MagicMock()
    response.status_code = 200
    request.return_value = response

    http_request(host_only, "/my/endpoint")

    # Derive the expected token from the creds actually passed. A hard-coded
    # "dXNlcjpwYXNz" only matches username "user" / password "pass".
    encoded = base64.b64encode("{}:{}".format(username, password).encode("utf-8")).decode("ascii")
    headers = dict(_DEFAULT_HEADERS)
    headers["Authorization"] = "Basic " + encoded
    request.assert_called_with(
        url="http://my-host/my/endpoint",
        verify=True,
        headers=headers,
    )
def test_http_request_with_token(request):
    """A token is sent as a Bearer Authorization header on top of the default headers."""
    token_creds = MlflowHostCreds("http://my-host", token="my-token")
    ok_response = mock.MagicMock()
    ok_response.status_code = 200
    request.return_value = ok_response

    http_request(token_creds, "/my/endpoint")

    request.assert_called_with(
        url="http://my-host/my/endpoint",
        verify=True,
        headers={**_DEFAULT_HEADERS, "Authorization": "Bearer my-token"},
    )
def test_ignore_tls_verification_not_server_cert_path():
    """ignore_tls_verification=True together with server_cert_path raises a clear error."""
    expected_error = (
        "When 'ignore_tls_verification' is true then 'server_cert_path' must not be set"
    )
    with pytest.raises(MlflowException, match=expected_error):
        MlflowHostCreds(
            "http://my-host",
            server_cert_path="/some/path",
            ignore_tls_verification=True,
        )
def test_well_formed_json_error_response():
    """A 400 response whose body is valid JSON ("{}") is reported as a RestException."""
    error_response = mock.MagicMock()
    error_response.status_code = 400
    error_response.text = "{}"  # well-formed JSON error response
    with mock.patch("requests.request", return_value=error_response):
        creds = MlflowHostCreds("http://my-host")
        response_proto = GetRun.Response()
        with pytest.raises(RestException):
            call_endpoint(creds, "/my/endpoint", "GET", "", response_proto)
def test_http_request_with_insecure(request):
    """ignore_tls_verification=True is passed to requests as verify=False."""
    insecure_creds = MlflowHostCreds("http://my-host", ignore_tls_verification=True)
    ok_response = mock.MagicMock()
    ok_response.status_code = 200
    request.return_value = ok_response

    http_request(insecure_creds, '/my/endpoint')

    request.assert_called_with(
        url='http://my-host/my/endpoint',
        verify=False,
        headers=_DEFAULT_HEADERS,
    )
def test_http_request_server_cert_path(request):
    """A configured server_cert_path is passed to requests as the ``verify`` argument."""
    cert_creds = MlflowHostCreds("http://my-host", server_cert_path='/some/path')
    ok_response = mock.MagicMock()
    ok_response.status_code = 200
    request.return_value = ok_response

    http_request(cert_creds, '/my/endpoint')

    request.assert_called_with(
        url='http://my-host/my/endpoint',
        verify='/some/path',
        headers=_DEFAULT_HEADERS,
    )
def test_init_get_host_creds_with_databricks_profile_uri(self):
    """Creds come from the Databricks profile when the URI names one, else from the default store."""
    databricks_host = "https://something.databricks.com"
    default_host = "http://host"
    package_prefix = DBFS_ARTIFACT_REPOSITORY_PACKAGE + "."
    with mock.patch(
        package_prefix + "_get_host_creds_from_default_store",
        return_value=lambda: MlflowHostCreds(default_host),
    ), mock.patch(
        package_prefix + "get_databricks_host_creds",
        return_value=MlflowHostCreds(databricks_host),
    ):
        profile_repo = DbfsRestArtifactRepository("dbfs://profile@databricks/test/")
        assert profile_repo.artifact_uri == "dbfs:/test/"
        assert profile_repo.get_host_creds().host == databricks_host

        # No databricks_profile_uri given: fall back to the (patched) default store creds.
        default_repo = DbfsRestArtifactRepository("dbfs:/test/")
        assert default_repo.get_host_creds().host == default_host
def test_http_request_with_token(request):
    """A token is sent as a Bearer Authorization header in addition to the default headers.

    ``request`` is the mocked ``requests.request``.
    """
    host_only = MlflowHostCreds("http://my-host", token='my-token')
    response = mock.MagicMock()
    response.status_code = 200
    request.return_value = response

    http_request(host_only, '/my/endpoint')

    # http_request sends the default headers too; Authorization is added on top of them,
    # so expecting only {'Authorization': ...} would not match the actual call.
    headers = dict(_DEFAULT_HEADERS)
    headers['Authorization'] = 'Bearer my-token'
    request.assert_called_with(
        url='http://my-host/my/endpoint',
        verify=True,
        headers=headers,
    )
def test_http_request_with_basic_auth(request):
    """Username/password creds are sent as an HTTP Basic Authorization header.

    ``request`` is the mocked ``requests.request``.
    """
    import base64

    username = '******'
    password = '******'
    host_only = MlflowHostCreds("http://my-host", username=username, password=password)
    response = mock.MagicMock()
    response.status_code = 200
    request.return_value = response

    http_request(host_only, '/my/endpoint')

    # Derive the expected token from the creds actually passed. A hard-coded
    # 'dXNlcjpwYXNz' only matches username 'user' / password 'pass'.
    encoded = base64.b64encode('{}:{}'.format(username, password).encode('utf-8')).decode('ascii')
    headers = dict(_DEFAULT_HEADERS)
    headers['Authorization'] = 'Basic ' + encoded
    request.assert_called_with(
        url='http://my-host/my/endpoint',
        verify=True,
        headers=headers,
    )
def test_http_request_server_cert_path(request):
    """server_cert_path is forwarded as ``verify`` with the default headers and a 120s timeout."""
    cert_creds = MlflowHostCreds("http://my-host", server_cert_path="/some/path")
    ok_response = mock.MagicMock()
    ok_response.status_code = 200
    request.return_value = ok_response

    http_request(cert_creds, "/my/endpoint", "GET")

    request.assert_called_with(
        "GET",
        "http://my-host/my/endpoint",
        verify="/some/path",
        headers=_DEFAULT_HEADERS,
        timeout=120,
    )
def test_init_validation_and_cleaning(self):
    """'dbfs:/test/' resolves to 'dbfs:/test'; non-DBFS and non-databricks-profile URIs fail."""
    def default_creds():
        return MlflowHostCreds("http://host")

    with mock.patch(
        DBFS_ARTIFACT_REPOSITORY_PACKAGE + "._get_host_creds_from_default_store",
        return_value=default_creds,
    ):
        resolved_repo = get_artifact_repository("dbfs:/test/")
        assert resolved_repo.artifact_uri == "dbfs:/test"
        for invalid_uri in ("s3://test", "dbfs://profile@notdatabricks/test/"):
            with pytest.raises(MlflowException):
                DbfsRestArtifactRepository(invalid_uri)
def test_http_request_with_insecure(request):
    """ignore_tls_verification=True becomes verify=False; default headers and 120s timeout are used."""
    insecure_creds = MlflowHostCreds("http://my-host", ignore_tls_verification=True)
    ok_response = mock.MagicMock()
    ok_response.status_code = 200
    request.return_value = ok_response

    http_request(insecure_creds, "/my/endpoint", "GET")

    expected_headers = DefaultRequestHeaderProvider().request_headers()
    request.assert_called_with(
        "GET",
        "http://my-host/my/endpoint",
        verify=False,
        headers=expected_headers,
        timeout=120,
    )