コード例 #1
0
 def test_init_get_host_creds_with_databricks_profile_uri(self):
     """A ``dbfs://profile@databricks`` URI takes its creds from get_databricks_host_creds;
     a plain ``dbfs:/`` URI uses the creds from _get_host_creds_from_default_store."""
     databricks_host = 'https://something.databricks.com'
     default_host = 'http://host'
     default_creds_target = DBFS_ARTIFACT_REPOSITORY_PACKAGE + '._get_host_creds_from_default_store'
     databricks_creds_target = DBFS_ARTIFACT_REPOSITORY_PACKAGE + '.get_databricks_host_creds'
     with mock.patch(default_creds_target,
                     return_value=lambda: MlflowHostCreds(default_host)):
         with mock.patch(databricks_creds_target,
                         return_value=MlflowHostCreds(databricks_host)):
             # The profile part is dropped from the stored artifact URI.
             profile_repo = DbfsRestArtifactRepository('dbfs://profile@databricks/test/')
             assert profile_repo.artifact_uri == 'dbfs:/test/'
             assert profile_repo.get_host_creds().host == databricks_host

             # No databricks_profile_uri given: fall back to the default store's creds.
             plain_repo = DbfsRestArtifactRepository('dbfs:/test/')
             assert plain_repo.get_host_creds().host == default_host
コード例 #2
0
def test_http_request_wrapper(request):
    """http_request_safe accepts 200 responses (JSON or not) and raises on 400 responses."""
    creds = MlflowHostCreds("http://my-host", ignore_tls_verification=True)
    fake_response = mock.MagicMock()
    request.return_value = fake_response

    # A 200 succeeds regardless of whether the body parses as JSON.
    for body in ("{}", "non json"):
        fake_response.status_code = 200
        fake_response.text = body
        http_request_safe(creds, "/my/endpoint")
        request.assert_called_with(
            url="http://my-host/my/endpoint",
            verify=False,
            headers=_DEFAULT_HEADERS,
        )

    # A 400 with an empty body raises an MlflowException mentioning the response body.
    fake_response.status_code = 400
    fake_response.text = ""
    with pytest.raises(MlflowException, match="Response body"):
        http_request_safe(creds, "/my/endpoint")

    # A 400 with a structured error body surfaces as a RestException.
    fake_response.text = (
        '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "Node type not supported"}'
    )
    with pytest.raises(RestException,
                       match="RESOURCE_DOES_NOT_EXIST: Node type not supported"):
        http_request_safe(creds, "/my/endpoint")
コード例 #3
0
    def test_databricks_paginate_list_experiments(self):
        """_paginate_list_experiments yields one page per mocked ListExperiments response,
        each page exposing that response's next_page_token as ``token``."""
        creds = MlflowHostCreds("https://hello")
        store = DatabricksRestStore(lambda: creds)

        # One mocked HTTP response per page; the last one carries no next_page_token.
        next_page_tokens = ["a", "b", None]
        list_exp_responses = []
        for next_page_token in next_page_tokens:
            experiment = Experiment(
                experiment_id="123",
                name=str(next_page_token),
                artifact_location="/abc",
                lifecycle_stage=LifecycleStage.ACTIVE,
            )
            list_exp_response = mock.MagicMock()
            list_exp_response.text = json.dumps({
                "experiments": [json.loads(message_to_json(experiment.to_proto()))],
                "next_page_token": next_page_token,
            })
            list_exp_response.status_code = 200
            list_exp_responses.append(list_exp_response)

        with mock.patch("mlflow.utils.rest_utils.http_request",
                        side_effect=list_exp_responses):
            # Consume the generator while the patch is still active.
            pages = list(store._paginate_list_experiments(ViewType.ACTIVE_ONLY))

        # Check the page count first so the test cannot pass vacuously when
        # the generator yields nothing.
        assert len(pages) == len(next_page_tokens)
        for experiments, expected_token in zip(pages, next_page_tokens):
            assert experiments[0].name == str(expected_token)
            assert experiments.token == expected_token
コード例 #4
0
def test_ignore_tls_verification_not_server_cert_path():
    """MlflowHostCreds rejects ignore_tls_verification=True combined with server_cert_path.

    The ``match`` pins the specific validation error, so an unrelated
    MlflowException raised during construction does not satisfy the test.
    """
    with pytest.raises(
        MlflowException,
        match="When 'ignore_tls_verification' is true then 'server_cert_path' must not be set",
    ):
        MlflowHostCreds(
            "http://my-host",
            ignore_tls_verification=True,
            server_cert_path="/some/path",
        )
コード例 #5
0
def test_get_host_creds_from_default_store_rest_store():
    """With a RestStore as the default store, the returned creds provider yields MlflowHostCreds."""
    with mock.patch("mlflow.tracking._tracking_service.utils._get_store") as get_store_mock:
        get_store_mock.return_value = RestStore(lambda: MlflowHostCreds("http://host"))
        creds_provider = _get_host_creds_from_default_store()
        creds = creds_provider()
        assert isinstance(creds, MlflowHostCreds)
コード例 #6
0
def test_429_retries(request):
    """Checks how max_rate_limit_interval bounds retries of 429 responses, that any
    non-429 code is returned, and that ``retries`` limits retries of other failures."""
    host_only = MlflowHostCreds("http://my-host", ignore_tls_verification=True)

    class MockedResponse(object):
        def __init__(self, status_code):
            self.status_code = status_code
            self.text = "mocked text"

    def serve(*codes):
        # Each mocked request call returns the next response from this list.
        request.side_effect = [MockedResponse(code) for code in codes]

    # (status codes served, max_rate_limit_interval, expected final status code)
    rate_limit_cases = [
        ((429, 200), 0, 429),
        ((429, 200), 1, 200),
        ((429, 429, 200), 1, 429),
        ((429, 429, 200), 2, 200),
        ((429, 429, 200), 3, 200),
    ]
    for codes, interval, expected_status in rate_limit_cases:
        serve(*codes)
        result = http_request(host_only, "/my/endpoint", max_rate_limit_interval=interval)
        assert result.status_code == expected_status

    # Any non-429 code is returned as-is.
    serve(429, 404, 429, 200)
    assert http_request(host_only, "/my/endpoint").status_code == 404

    # `retries` caps how many non-429 failures are retried.
    serve(429, 503, 429, 200)
    with pytest.raises(MlflowException, match="failed to return code 200"):
        http_request(host_only, "/my/endpoint", retries=1)
    serve(429, 503, 429, 200)
    assert http_request(host_only, "/my/endpoint", retries=2).status_code == 200
コード例 #7
0
def test_default_host_creds():
    """HttpArtifactRepository._host_creds matches creds built from the patched tracking env vars."""
    artifact_uri = "https://test.com"
    creds_fields = dict(
        username="******",
        password="******",
        token="token",
        ignore_tls_verification=False,
        client_cert_path="client_cert_path",
        server_cert_path="server_cert_path",
    )
    expected_host_creds = MlflowHostCreds(host=artifact_uri, **creds_fields)

    repo = HttpArtifactRepository(artifact_uri)

    # Each tracking env var carries the value of the matching MlflowHostCreds field.
    tracking_env = {
        _TRACKING_USERNAME_ENV_VAR: creds_fields["username"],
        _TRACKING_PASSWORD_ENV_VAR: creds_fields["password"],
        _TRACKING_TOKEN_ENV_VAR: creds_fields["token"],
        _TRACKING_INSECURE_TLS_ENV_VAR: str(creds_fields["ignore_tls_verification"]),
        _TRACKING_CLIENT_CERT_PATH_ENV_VAR: creds_fields["client_cert_path"],
        _TRACKING_SERVER_CERT_PATH_ENV_VAR: creds_fields["server_cert_path"],
    }
    with mock.patch.dict("mlflow.tracking._tracking_service.utils.os.environ", tracking_env):
        assert repo._host_creds == expected_host_creds
コード例 #8
0
ファイル: test_rest_utils.py プロジェクト: bkbonde/mlflow
def test_http_request_request_headers_user_agent(request):
    """This test requires the package in tests/resources/mlflow-test-plugin to be installed.

    A plugin-provided User-Agent is appended, space-separated, to the default one.
    """
    from mlflow_test_plugin.request_header_provider import PluginRequestHeaderProvider

    plugin_user_agent = "test_user_agent"
    # The plugin's provider normally reports in_context() == False so it does not leak
    # headers into developers' environments; force it on for this integration test.
    with mock.patch.object(PluginRequestHeaderProvider, "in_context", return_value=True):
        with mock.patch.object(
            PluginRequestHeaderProvider,
            "request_headers",
            return_value={_USER_AGENT: plugin_user_agent},
        ):
            host_only = MlflowHostCreds("http://my-host", server_cert_path="/some/path")
            default_user_agent = DefaultRequestHeaderProvider().request_headers()[_USER_AGENT]
            expected_headers = {
                _USER_AGENT: "{} {}".format(default_user_agent, plugin_user_agent)
            }

            fake_response = mock.MagicMock()
            fake_response.status_code = 200
            request.return_value = fake_response

            http_request(host_only, "/my/endpoint", "GET")
            request.assert_called_with(
                "GET",
                "http://my-host/my/endpoint",
                verify="/some/path",
                headers=expected_headers,
                timeout=120,
            )
コード例 #9
0
    def test_databricks_rest_store_get_experiment_by_name(self):
        """An INTERNAL_ERROR from GetExperimentByName is re-raised unchanged, and the
        Databricks store does not fall back to ListExperiments (exactly one request)."""
        creds = MlflowHostCreds("https://hello")
        store = DatabricksRestStore(lambda: creds)
        with mock.patch("mlflow.utils.rest_utils.http_request") as mock_http:

            def raise_internal_error(*args, **kwargs):
                # pylint: disable=unused-argument
                raise MlflowException("Some internal error!", INTERNAL_ERROR)

            # A 500-level error that is not ENDPOINT_NOT_FOUND must not trigger a fallback.
            mock_http.side_effect = raise_internal_error
            with pytest.raises(MlflowException) as exc_info:
                store.get_experiment_by_name("abc")
            raised = exc_info.value
            assert raised.error_code == ErrorCode.Name(INTERNAL_ERROR)
            assert raised.message == "Some internal error!"

            expected_body = message_to_json(GetExperimentByName(experiment_name="abc"))
            self._verify_requests(
                mock_http, creds, "experiments/get-by-name", "GET", expected_body
            )
            assert mock_http.call_count == 1
コード例 #10
0
def test_http_request_request_headers(request):
    """This test requires the package in tests/resources/mlflow-test-plugin to be installed.

    Headers from an in-context plugin provider are merged on top of the defaults.
    """
    from mlflow_test_plugin.request_header_provider import PluginRequestHeaderProvider

    # The plugin's provider normally reports in_context() == False so it does not leak
    # headers into developers' environments; force it on for this integration test.
    with mock.patch.object(PluginRequestHeaderProvider, "in_context", return_value=True):
        host_only = MlflowHostCreds("http://my-host", server_cert_path="/some/path")

        fake_response = mock.MagicMock()
        fake_response.status_code = 200
        request.return_value = fake_response

        http_request(host_only, "/my/endpoint")

        expected_headers = {**_DEFAULT_HEADERS, "test": "header"}
        request.assert_called_with(
            url="http://my-host/my/endpoint",
            verify="/some/path",
            headers=expected_headers,
        )
コード例 #11
0
ファイル: test_rest_store.py プロジェクト: zorrotrying/mlflow
    def test_requestor(self, request):
        """Each RestStore mutation sends the expected endpoint, HTTP method and JSON body."""
        # Must be an *instance*: assigning attributes on the ``mock.MagicMock`` class
        # itself would leak ``status_code``/``text`` into every MagicMock created later.
        response = mock.MagicMock()
        response.status_code = 200
        response.text = '{}'
        request.return_value = response

        creds = MlflowHostCreds('https://hello')
        store = RestStore(lambda: creds)
        http_target = 'mlflow.store.rest_store.http_request_safe'

        with mock.patch(http_target) as mock_http:
            store.log_param("some_uuid", Param("k1", "v1"))
            body = message_to_json(
                LogParam(run_uuid="some_uuid", key="k1", value="v1"))
            self._verify_requests(mock_http, creds, "runs/log-parameter",
                                  "POST", body)

        with mock.patch(http_target) as mock_http:
            # A long tag value is sent through unchanged.
            store.set_tag("some_uuid", RunTag("t1", "abcd" * 1000))
            body = message_to_json(
                SetTag(run_uuid="some_uuid", key="t1", value="abcd" * 1000))
            self._verify_requests(mock_http, creds, "runs/set-tag", "POST",
                                  body)

        with mock.patch(http_target) as mock_http:
            store.log_metric("u2", Metric("m1", 0.87, 12345))
            body = message_to_json(
                LogMetric(run_uuid="u2", key="m1", value=0.87,
                          timestamp=12345))
            self._verify_requests(mock_http, creds, "runs/log-metric", "POST",
                                  body)

        with mock.patch(http_target) as mock_http:
            store.delete_run("u25")
            self._verify_requests(mock_http, creds, "runs/delete", "POST",
                                  message_to_json(DeleteRun(run_id="u25")))

        with mock.patch(http_target) as mock_http:
            store.restore_run("u76")
            self._verify_requests(mock_http, creds, "runs/restore", "POST",
                                  message_to_json(RestoreRun(run_id="u76")))

        with mock.patch(http_target) as mock_http:
            store.delete_experiment(0)
            self._verify_requests(
                mock_http, creds, "experiments/delete", "POST",
                message_to_json(DeleteExperiment(experiment_id=0)))

        with mock.patch(http_target) as mock_http:
            store.restore_experiment(0)
            self._verify_requests(
                mock_http, creds, "experiments/restore", "POST",
                message_to_json(RestoreExperiment(experiment_id=0)))
コード例 #12
0
 def test_init_validation_and_cleaning(self):
     """get_artifact_repository drops the trailing slash of a ``dbfs:/`` URI, and
     DbfsRestArtifactRepository rejects a non-DBFS URI."""
     creds_target = DBFS_ARTIFACT_REPOSITORY_PACKAGE + '._get_host_creds_from_default_store'
     with mock.patch(creds_target) as get_creds_mock:
         get_creds_mock.return_value = lambda: MlflowHostCreds('http://host')
         dbfs_repo = get_artifact_repository('dbfs:/test/')
         assert dbfs_repo.artifact_uri == 'dbfs:/test'
         # An s3:// URI must be refused at construction time.
         with pytest.raises(MlflowException):
             DbfsRestArtifactRepository('s3://test')
コード例 #13
0
def test_malformed_json_error_response(response_mock):
    """call_endpoint raises MlflowException for the ``response_mock`` fixture's response
    (presumably an error whose body is not valid JSON — see the fixture)."""
    with mock.patch("requests.request") as patched_request:
        creds = MlflowHostCreds("http://my-host")
        patched_request.return_value = response_mock

        get_run_response = GetRun.Response()
        with pytest.raises(MlflowException):
            call_endpoint(creds, "/my/endpoint", "GET", "", get_run_response)
def test_http_request_hostonly(request):
    """Host-only creds send the default headers with TLS verification enabled."""
    fake_response = mock.MagicMock()
    fake_response.status_code = 200
    request.return_value = fake_response

    creds = MlflowHostCreds("http://my-host")
    http_request(creds, "/my/endpoint")

    request.assert_called_with(
        url="http://my-host/my/endpoint",
        headers=_DEFAULT_HEADERS,
        verify=True,
    )
コード例 #15
0
    def test_failed_http_request_custom_handler(self, request):
        """CustomErrorHandlingRestStore raises MyCoolException for a 404 error response."""
        # Must be an *instance*: assigning attributes on the ``mock.MagicMock`` class
        # itself would leak ``status_code``/``text`` into every MagicMock created later.
        response = mock.MagicMock()
        response.status_code = 404
        response.text = '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "No experiment"}'
        request.return_value = response

        store = CustomErrorHandlingRestStore(lambda: MlflowHostCreds('https://hello'))
        with pytest.raises(MyCoolException):
            store.list_experiments()
コード例 #16
0
ファイル: test_dbfs_artifact_repo.py プロジェクト: zge/mlflow
 def test_init_validation_and_cleaning(self):
     """DbfsArtifactRepository drops the trailing slash of a ``dbfs:/`` URI and
     rejects a non-DBFS URI."""
     creds_target = 'mlflow.store.dbfs_artifact_repo._get_host_creds_from_default_store'
     with mock.patch(creds_target) as get_creds_mock:
         get_creds_mock.return_value = lambda: MlflowHostCreds('http://host')
         dbfs_repo = DbfsArtifactRepository('dbfs:/test/')
         assert dbfs_repo.artifact_uri == 'dbfs:/test'
         # An s3:// URI must be refused at construction time.
         with pytest.raises(MlflowException):
             DbfsArtifactRepository('s3://test')
コード例 #17
0
    def test_failed_http_request(self, request):
        """A 404 error response surfaces as an MlflowException carrying the error code and message."""
        # Must be an *instance*: assigning attributes on the ``mock.MagicMock`` class
        # itself would leak ``status_code``/``text`` into every MagicMock created later.
        response = mock.MagicMock()
        response.status_code = 404
        response.text = '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "No experiment"}'
        request.return_value = response

        store = RestStore(lambda: MlflowHostCreds('https://hello'))
        with pytest.raises(MlflowException) as cm:
            store.list_experiments()
        assert "RESOURCE_DOES_NOT_EXIST: No experiment" in str(cm.value)
def test_http_request_cleans_hostname(request):
    """A trailing slash on the host is dropped before the endpoint path is appended."""
    fake_response = mock.MagicMock()
    fake_response.status_code = 200
    request.return_value = fake_response

    creds_with_slash = MlflowHostCreds("http://my-host/")
    http_request(creds_with_slash, "/my/endpoint")

    # No double slash in the final URL.
    request.assert_called_with(
        url="http://my-host/my/endpoint",
        headers=_DEFAULT_HEADERS,
        verify=True,
    )
def test_http_request_with_basic_auth(request):
    """Username/password creds are sent as an HTTP Basic Authorization header on top of the defaults."""
    # "dXNlcjpwYXNz" is base64("user:pass"); the credentials must match it exactly.
    host_only = MlflowHostCreds("http://my-host", username="user", password="pass")
    response = mock.MagicMock()
    response.status_code = 200
    request.return_value = response
    http_request(host_only, "/my/endpoint")
    headers = dict(_DEFAULT_HEADERS)
    headers["Authorization"] = "Basic dXNlcjpwYXNz"
    request.assert_called_with(
        url="http://my-host/my/endpoint", verify=True, headers=headers,
    )
def test_http_request_with_token(request):
    """A token in the creds is sent as a Bearer Authorization header on top of the defaults."""
    fake_response = mock.MagicMock()
    fake_response.status_code = 200
    request.return_value = fake_response

    http_request(MlflowHostCreds("http://my-host", token="my-token"), "/my/endpoint")

    expected_headers = {**_DEFAULT_HEADERS, "Authorization": "Bearer my-token"}
    request.assert_called_with(
        url="http://my-host/my/endpoint",
        verify=True,
        headers=expected_headers,
    )
コード例 #21
0
ファイル: test_rest_utils.py プロジェクト: bkbonde/mlflow
def test_ignore_tls_verification_not_server_cert_path():
    """MlflowHostCreds refuses ignore_tls_verification=True together with server_cert_path."""
    expected_error = (
        "When 'ignore_tls_verification' is true then 'server_cert_path' must not be set"
    )
    with pytest.raises(MlflowException, match=expected_error):
        MlflowHostCreds(
            "http://my-host",
            server_cert_path="/some/path",
            ignore_tls_verification=True,
        )
コード例 #22
0
def test_well_formed_json_error_response():
    """A 400 response whose body is valid JSON makes call_endpoint raise RestException."""
    error_response = mock.MagicMock()
    error_response.status_code = 400
    error_response.text = "{}"  # well-formed JSON error response
    with mock.patch("requests.request", return_value=error_response):
        creds = MlflowHostCreds("http://my-host")
        get_run_response = GetRun.Response()
        with pytest.raises(RestException):
            call_endpoint(creds, "/my/endpoint", "GET", "", get_run_response)
コード例 #23
0
def test_http_request_with_insecure(request):
    """ignore_tls_verification=True turns TLS verification off for the request."""
    fake_response = mock.MagicMock()
    fake_response.status_code = 200
    request.return_value = fake_response

    insecure_creds = MlflowHostCreds("http://my-host", ignore_tls_verification=True)
    http_request(insecure_creds, '/my/endpoint')

    request.assert_called_with(
        url='http://my-host/my/endpoint',
        headers=_DEFAULT_HEADERS,
        verify=False,
    )
コード例 #24
0
def test_http_request_server_cert_path(request):
    """server_cert_path is forwarded as the ``verify`` argument of the request."""
    cert_path = '/some/path'
    fake_response = mock.MagicMock()
    fake_response.status_code = 200
    request.return_value = fake_response

    http_request(MlflowHostCreds("http://my-host", server_cert_path=cert_path), '/my/endpoint')

    request.assert_called_with(
        url='http://my-host/my/endpoint',
        headers=_DEFAULT_HEADERS,
        verify=cert_path,
    )
コード例 #25
0
 def test_init_get_host_creds_with_databricks_profile_uri(self):
     """A ``dbfs://profile@databricks`` URI takes its creds from get_databricks_host_creds;
     a plain ``dbfs:/`` URI uses the creds from _get_host_creds_from_default_store."""
     databricks_host = "https://something.databricks.com"
     default_host = "http://host"
     default_creds_target = DBFS_ARTIFACT_REPOSITORY_PACKAGE + "._get_host_creds_from_default_store"
     databricks_creds_target = DBFS_ARTIFACT_REPOSITORY_PACKAGE + ".get_databricks_host_creds"
     with mock.patch(default_creds_target,
                     return_value=lambda: MlflowHostCreds(default_host)):
         with mock.patch(databricks_creds_target,
                         return_value=MlflowHostCreds(databricks_host)):
             # The profile part is dropped from the stored artifact URI.
             profile_repo = DbfsRestArtifactRepository("dbfs://profile@databricks/test/")
             assert profile_repo.artifact_uri == "dbfs:/test/"
             assert profile_repo.get_host_creds().host == databricks_host

             # No databricks_profile_uri given: fall back to the default store's creds.
             plain_repo = DbfsRestArtifactRepository("dbfs:/test/")
             assert plain_repo.get_host_creds().host == default_host
コード例 #26
0
def test_http_request_with_token(request):
    """A token in the creds is sent as a Bearer Authorization header on top of the default headers."""
    host_only = MlflowHostCreds("http://my-host", token='my-token')
    response = mock.MagicMock()
    response.status_code = 200
    request.return_value = response
    http_request(host_only, '/my/endpoint')
    # Authorization is added to the default headers rather than replacing them,
    # so compare against the full merged set.
    headers = dict(_DEFAULT_HEADERS)
    headers['Authorization'] = 'Bearer my-token'
    request.assert_called_with(
        url='http://my-host/my/endpoint',
        verify=True,
        headers=headers,
    )
コード例 #27
0
def test_http_request_with_basic_auth(request):
    """Username/password creds are sent as an HTTP Basic Authorization header on top of the defaults."""
    # 'dXNlcjpwYXNz' is base64('user:pass'); the credentials must match it exactly.
    host_only = MlflowHostCreds("http://my-host", username='user', password='pass')
    response = mock.MagicMock()
    response.status_code = 200
    request.return_value = response
    http_request(host_only, '/my/endpoint')
    headers = dict(_DEFAULT_HEADERS)
    headers['Authorization'] = 'Basic dXNlcjpwYXNz'
    request.assert_called_with(
        url='http://my-host/my/endpoint',
        verify=True,
        headers=headers,
    )
コード例 #28
0
def test_http_request_server_cert_path(request):
    """server_cert_path is forwarded as ``verify``; the request uses the 120s timeout."""
    cert_path = "/some/path"
    fake_response = mock.MagicMock()
    fake_response.status_code = 200
    request.return_value = fake_response

    creds = MlflowHostCreds("http://my-host", server_cert_path=cert_path)
    http_request(creds, "/my/endpoint", "GET")

    request.assert_called_with(
        "GET",
        "http://my-host/my/endpoint",
        headers=_DEFAULT_HEADERS,
        verify=cert_path,
        timeout=120,
    )
コード例 #29
0
 def test_init_validation_and_cleaning(self):
     """get_artifact_repository drops the trailing slash of a ``dbfs:/`` URI, and
     DbfsRestArtifactRepository rejects non-DBFS URIs and non-databricks profile hosts."""
     creds_target = DBFS_ARTIFACT_REPOSITORY_PACKAGE + "._get_host_creds_from_default_store"
     with mock.patch(creds_target) as get_creds_mock:
         get_creds_mock.return_value = lambda: MlflowHostCreds("http://host")
         dbfs_repo = get_artifact_repository("dbfs:/test/")
         assert dbfs_repo.artifact_uri == "dbfs:/test"

         # Each of these URIs must be refused at construction time.
         for bad_uri in ("s3://test", "dbfs://profile@notdatabricks/test/"):
             with pytest.raises(MlflowException):
                 DbfsRestArtifactRepository(bad_uri)
コード例 #30
0
ファイル: test_rest_utils.py プロジェクト: bkbonde/mlflow
def test_http_request_with_insecure(request):
    """ignore_tls_verification=True sends ``verify=False`` along with the provider's default headers."""
    fake_response = mock.MagicMock()
    fake_response.status_code = 200
    request.return_value = fake_response

    insecure_creds = MlflowHostCreds("http://my-host", ignore_tls_verification=True)
    http_request(insecure_creds, "/my/endpoint", "GET")

    default_headers = DefaultRequestHeaderProvider().request_headers()
    request.assert_called_with(
        "GET",
        "http://my-host/my/endpoint",
        verify=False,
        headers=default_headers,
        timeout=120,
    )