コード例 #1
0
ファイル: databricks_utils.py プロジェクト: iPieter/kiwi
def get_databricks_host_creds(server_uri=None):
    """
    Resolve the host and credentials needed to issue HTTP requests to a Databricks server.

    The configuration is first looked up through the Databricks CLI ``provider`` module, using
    the profile parsed from ``server_uri`` when one is present. When that lookup yields no host
    and the URI carries a path (i.e. has the form "databricks://profile/prefix"), the Databricks
    Secret Manager is queried through ``dbutils`` instead: the profile is used as the secret
    scope and the keys "<prefix>-host" and "<prefix>-token" are read. The prefix therefore
    *cannot* be empty for this fallback to be attempted.

    :param server_uri: A URI that specifies the Databricks profile you want to use for making
                       requests.
    :return: :py:class:`mlflow.rest_utils.MlflowHostCreds` carrying the hostname and the
             authentication information (username/password or token) needed to talk to the
             Databricks server. When no usable host or credentials are found,
             ``_fail_malformed_databricks_auth`` is invoked with the profile, which is
             expected to raise.
    """
    profile, path = get_db_info_from_uri(server_uri)

    if hasattr(provider, 'get_config'):
        if profile:
            config = provider.ProfileConfigProvider(profile).get_config()
        else:
            config = provider.get_config()
    else:
        # Older databricks-cli releases only expose the per-profile lookup.
        _logger.warning(
            "Support for databricks-cli<0.8.0 is deprecated and will be removed"
            " in a future version.")
        config = provider.get_config_for_profile(profile)

    # A path in the URI implies a Databricks tracking URI of the form
    # databricks://profile-name/path-specifier, which enables the secrets fallback.
    if (not config or not config.host) and path:
        dbutils = _get_dbutils()
        if dbutils:
            # The path acts as a per-user key prefix inside the secret scope.
            secret_host = dbutils.secrets.get(scope=profile, key=path + "-host")
            secret_token = dbutils.secrets.get(scope=profile, key=path + "-token")
            if secret_host and secret_token:
                config = provider.DatabricksConfig.from_token(
                    host=secret_host, token=secret_token, insecure=False)

    if not config or not config.host:
        _fail_malformed_databricks_auth(profile)

    # Not every config object carries an ``insecure`` attribute; treat its absence as False.
    insecure = config.insecure if hasattr(config, 'insecure') else False

    has_basic_auth = config.username is not None and config.password is not None
    if has_basic_auth:
        return MlflowHostCreds(
            config.host,
            username=config.username,
            password=config.password,
            ignore_tls_verification=insecure,
        )
    if config.token:
        return MlflowHostCreds(
            config.host,
            token=config.token,
            ignore_tls_verification=insecure,
        )
    _fail_malformed_databricks_auth(profile)
コード例 #2
0
def test_429_retries(request):
    """
    Exercise ``http_request`` against sequences of mocked status codes.

    ``request`` is the mock whose ``side_effect`` supplies the successive responses. The test
    covers the 429 handling bounded by ``max_rate_limit_interval``, pass-through of other
    non-429 codes, and the ``retries`` limit.
    """
    creds = MlflowHostCreds("http://my-host", ignore_tls_verification=True)
    endpoint = '/my/endpoint'

    class MockedResponse(object):
        def __init__(self, status_code):
            self.status_code = status_code
            self.text = "mocked text"

    def queue(*status_codes):
        # Install one fresh mocked response per status code, in order.
        request.side_effect = [MockedResponse(code) for code in status_codes]

    # (queued status codes, max_rate_limit_interval, expected final status code)
    rate_limit_cases = [
        ((429, 200), 0, 429),
        ((429, 200), 1, 200),
        ((429, 429, 200), 1, 429),
        ((429, 429, 200), 2, 200),
        ((429, 429, 200), 3, 200),
    ]
    for status_codes, interval, expected_status in rate_limit_cases:
        queue(*status_codes)
        result = http_request(creds, endpoint, max_rate_limit_interval=interval)
        assert result.status_code == expected_status

    # Test that any non 429 code is returned
    queue(429, 404, 429, 200)
    assert http_request(creds, endpoint).status_code == 404

    # Test that retries work as expected
    queue(429, 503, 429, 200)
    with pytest.raises(MlflowException, match="failed to return code 200"):
        http_request(creds, endpoint, retries=1)
    queue(429, 503, 429, 200)
    assert http_request(creds, endpoint, retries=2).status_code == 200
コード例 #3
0
def test_ignore_tls_verification_not_server_cert_path():
    """Constructing credentials with both TLS options set must raise ``MlflowException``."""
    conflicting_tls_options = {
        'ignore_tls_verification': True,
        'server_cert_path': '/some/path',
    }
    with pytest.raises(MlflowException):
        MlflowHostCreds("http://my-host", **conflicting_tls_options)
コード例 #4
0
def test_malformed_json_error_response(response_mock):
    """
    Check that ``call_endpoint`` raises ``MlflowException`` for the supplied response.

    ``response_mock`` is provided by the caller (presumably a parametrized fixture of
    malformed error responses -- confirm against the decorator outside this snippet).
    """
    creds = MlflowHostCreds("http://my-host")
    proto = GetRun.Response()
    with mock.patch('requests.request', return_value=response_mock):
        with pytest.raises(MlflowException):
            call_endpoint(creds, '/my/endpoint', 'GET', "", proto)
コード例 #5
0
 def test_init_validation_and_cleaning(self):
     """Check trailing-slash cleanup of a dbfs URI and rejection of a non-dbfs URI."""
     creds_patch = mock.patch(
         DBFS_ARTIFACT_REPOSITORY_PACKAGE + '._get_host_creds_from_default_store')
     with creds_patch as get_creds_mock:
         get_creds_mock.return_value = lambda: MlflowHostCreds('http://host')

         # The trailing slash is expected to be stripped from the stored URI.
         cleaned_repo = get_artifact_repository('dbfs:/test/')
         assert cleaned_repo.artifact_uri == 'dbfs:/test'

         # A non-dbfs scheme must be refused by the REST repository.
         with pytest.raises(MlflowException):
             DbfsRestArtifactRepository('s3://test')
コード例 #6
0
ファイル: test_rest_store.py プロジェクト: iPieter/kiwi
    def test_failed_http_request_custom_handler(self, request):
        """
        Check that a 404 error response surfaces as the store's custom exception.

        ``request`` is the mock standing in for the HTTP request call (presumably injected by
        a ``mock.patch`` decorator outside this snippet).
        """
        # Use a MagicMock *instance*. Assigning attributes on the MagicMock class itself would
        # leak ``status_code``/``text`` into every other mock created in the same process.
        response = mock.MagicMock()
        response.status_code = 404
        response.text = '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "No experiment"}'
        request.return_value = response

        store = CustomErrorHandlingRestStore(lambda: MlflowHostCreds('https://hello'))
        with pytest.raises(MyCoolException):
            store.list_experiments()
コード例 #7
0
ファイル: test_rest_store.py プロジェクト: iPieter/kiwi
    def test_failed_http_request(self, request):
        """
        Check that a 404 error response is raised as ``MlflowException`` with its error text.

        ``request`` is the mock standing in for the HTTP request call (presumably injected by
        a ``mock.patch`` decorator outside this snippet).
        """
        # Use a MagicMock *instance*. Assigning attributes on the MagicMock class itself would
        # leak ``status_code``/``text`` into every other mock created in the same process.
        response = mock.MagicMock()
        response.status_code = 404
        response.text = '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "No experiment"}'
        request.return_value = response

        store = RestStore(lambda: MlflowHostCreds('https://hello'))
        with pytest.raises(MlflowException) as cm:
            store.list_experiments()
        # The error code and message from the JSON body must both appear in the exception.
        assert "RESOURCE_DOES_NOT_EXIST: No experiment" in str(cm.value)
コード例 #8
0
def test_http_request_hostonly(request):
    """Host-only credentials must yield a TLS-verified request with the default headers."""
    request.return_value = mock.MagicMock(status_code=200)
    creds = MlflowHostCreds("http://my-host")

    http_request(creds, '/my/endpoint')

    expected_call = {
        'url': 'http://my-host/my/endpoint',
        'verify': True,
        'headers': _DEFAULT_HEADERS,
    }
    request.assert_called_with(**expected_call)
コード例 #9
0
def test_well_formed_json_error_response():
    """A 400 response whose body is valid JSON must be raised as ``RestException``."""
    # "{}" is a well-formed JSON error response.
    error_response = mock.MagicMock(status_code=400, text="{}")
    creds = MlflowHostCreds("http://my-host")
    proto = GetRun.Response()

    with mock.patch('requests.request', return_value=error_response):
        with pytest.raises(RestException):
            call_endpoint(creds, '/my/endpoint', 'GET', "", proto)
コード例 #10
0
def test_http_request_with_insecure(request):
    """``ignore_tls_verification=True`` must be forwarded to the request as ``verify=False``."""
    request.return_value = mock.MagicMock(status_code=200)
    creds = MlflowHostCreds("http://my-host", ignore_tls_verification=True)

    http_request(creds, '/my/endpoint')

    expected_call = {
        'url': 'http://my-host/my/endpoint',
        'verify': False,
        'headers': _DEFAULT_HEADERS,
    }
    request.assert_called_with(**expected_call)
コード例 #11
0
def test_http_request_cleans_hostname(request):
    """A trailing slash on the host must not produce a double slash in the request URL."""
    request.return_value = mock.MagicMock(status_code=200)
    # The host is given with a trailing slash, which should be removed.
    creds = MlflowHostCreds("http://my-host/")

    http_request(creds, '/my/endpoint')

    expected_call = {
        'url': 'http://my-host/my/endpoint',
        'verify': True,
        'headers': _DEFAULT_HEADERS,
    }
    request.assert_called_with(**expected_call)
コード例 #12
0
def test_http_request_with_basic_auth(request):
    """
    Username/password credentials must be sent as an HTTP Basic ``Authorization`` header.

    ``request`` is the mock standing in for the HTTP request call.
    """
    # The expected header below is base64("user:pass"), so the credentials must be exactly
    # these values; masked placeholders would make the assertion impossible to satisfy.
    host_only = MlflowHostCreds("http://my-host", username='user', password='pass')
    response = mock.MagicMock()
    response.status_code = 200
    request.return_value = response
    http_request(host_only, '/my/endpoint')
    # Copy the defaults so the shared module-level header dict is not mutated.
    headers = dict(_DEFAULT_HEADERS)
    headers['Authorization'] = 'Basic dXNlcjpwYXNz'
    request.assert_called_with(
        url='http://my-host/my/endpoint',
        verify=True,
        headers=headers,
    )
コード例 #13
0
def test_http_request_with_token(request):
    """Token credentials must be sent as a ``Bearer`` ``Authorization`` header."""
    request.return_value = mock.MagicMock(status_code=200)
    creds = MlflowHostCreds("http://my-host", token='my-token')

    http_request(creds, '/my/endpoint')

    # Build the expectation on a copy so the shared default headers stay untouched.
    expected_headers = dict(_DEFAULT_HEADERS)
    expected_headers.update({'Authorization': 'Bearer my-token'})
    request.assert_called_with(
        url='http://my-host/my/endpoint',
        verify=True,
        headers=expected_headers,
    )
コード例 #14
0
ファイル: test_rest_store.py プロジェクト: iPieter/kiwi
    def test_response_with_unknown_fields(self, request):
        """
        Check that an experiment payload containing an unrecognised field is still parsed.

        ``request`` is the mock standing in for the HTTP request call (presumably injected by
        a ``mock.patch`` decorator outside this snippet).
        """
        experiment_json = {
            "experiment_id": "1",
            "name": "My experiment",
            "artifact_location": "foo",
            "lifecycle_stage": "deleted",
            # Deliberately not part of the experiment schema.
            "OMG_WHAT_IS_THIS_FIELD": "Hooly cow",
        }

        # Use a MagicMock *instance*. Assigning attributes on the MagicMock class itself would
        # leak ``status_code``/``text`` into every other mock created in the same process.
        response = mock.MagicMock()
        response.status_code = 200
        response.text = json.dumps({"experiments": [experiment_json]})
        request.return_value = response

        store = RestStore(lambda: MlflowHostCreds('https://hello'))
        experiments = store.list_experiments()
        assert len(experiments) == 1
        assert experiments[0].name == 'My experiment'
コード例 #15
0
ファイル: test_rest_store.py プロジェクト: iPieter/kiwi
    def test_databricks_rest_store_get_experiment_by_name(self):
        """
        Check that an INTERNAL_ERROR from get-by-name propagates after a single request.

        The Databricks REST client must not fall back to ListExperiments for 500-level errors
        that are not ENDPOINT_NOT_FOUND.
        """
        creds = MlflowHostCreds('https://hello')
        store = DatabricksRestStore(lambda: creds)

        def internal_error_response_fn(*args, **kwargs):
            # pylint: disable=unused-argument
            raise MlflowException("Some internal error!", INTERNAL_ERROR)

        with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
            mock_http.side_effect = internal_error_response_fn

            with pytest.raises(MlflowException) as exc_info:
                store.get_experiment_by_name("abc")

            raised = exc_info.value
            assert raised.error_code == ErrorCode.Name(INTERNAL_ERROR)
            assert raised.message == "Some internal error!"

            # Exactly one request, to the get-by-name endpoint, must have been issued.
            expected_request = GetExperimentByName(experiment_name="abc")
            self._verify_requests(
                mock_http, creds, "experiments/get-by-name", "GET",
                message_to_json(expected_request))
            assert mock_http.call_count == 1
コード例 #16
0
def test_http_request_wrapper(request):
    """
    Exercise ``http_request_safe`` on a success, an empty-body error and a JSON error.

    ``request`` is the mock standing in for the HTTP request call.
    """
    creds = MlflowHostCreds("http://my-host", ignore_tls_verification=True)
    endpoint = '/my/endpoint'
    fake_response = mock.MagicMock()
    request.return_value = fake_response

    # 200: the call goes through and carries the expected arguments.
    fake_response.status_code = 200
    http_request_safe(creds, endpoint)
    request.assert_called_with(
        url='http://my-host/my/endpoint',
        verify=False,
        headers=_DEFAULT_HEADERS,
    )

    # 400 with an empty body: reported as a generic MlflowException.
    fake_response.status_code = 400
    fake_response.text = ""
    with pytest.raises(MlflowException, match="Response body"):
        http_request_safe(creds, endpoint)

    # 400 with a JSON error body: reported as a RestException carrying code and message.
    fake_response.text = (
        '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "Node type not supported"}'
    )
    with pytest.raises(RestException, match="RESOURCE_DOES_NOT_EXIST: Node type not supported"):
        http_request_safe(creds, endpoint)
コード例 #17
0
ファイル: test_rest_store.py プロジェクト: iPieter/kiwi
    def test_successful_http_request(self, request):
        """
        Check the exact request issued by ``list_experiments`` and the parsing of its result.

        ``request`` is the mock standing in for the HTTP request call (presumably injected by
        a ``mock.patch`` decorator outside this snippet).
        """
        def mock_request(**kwargs):
            # Filter out None arguments
            kwargs = dict((k, v) for k, v in six.iteritems(kwargs) if v is not None)
            assert kwargs == {
                'method': 'GET',
                'params': {'view_type': 'ACTIVE_ONLY'},
                'url': 'https://hello/api/2.0/mlflow/experiments/list',
                'headers': _DEFAULT_HEADERS,
                'verify': True,
            }
            # Use a MagicMock *instance*. Assigning attributes on the MagicMock class itself
            # would leak them into every other mock created in the same process.
            response = mock.MagicMock()
            response.status_code = 200
            response.text = '{"experiments": [{"name": "Exp!", "lifecycle_stage": "active"}]}'
            return response

        request.side_effect = mock_request

        store = RestStore(lambda: MlflowHostCreds('https://hello'))
        experiments = store.list_experiments()
        assert experiments[0].name == "Exp!"
コード例 #18
0
def host_creds_mock():
    """
    Generator that keeps the dbfs repo's host-credential lookup patched while suspended.

    The patched lookup returns a callable producing ``MlflowHostCreds('http://host')``.
    """
    target = 'mlflow.store.artifact.dbfs_artifact_repo._get_host_creds_from_default_store'
    with mock.patch(target) as get_creds_mock:
        get_creds_mock.return_value = lambda: MlflowHostCreds('http://host')
        yield
コード例 #19
0
def test_get_host_creds_from_default_store_rest_store():
    """With a ``RestStore`` as default store, the returned getter must yield host creds."""
    rest_store = RestStore(lambda: MlflowHostCreds('http://host'))
    store_patch = mock.patch(
        'mlflow.tracking._tracking_service.utils._get_store', return_value=rest_store)
    with store_patch:
        creds_getter = _get_host_creds_from_default_store()
        assert isinstance(creds_getter(), MlflowHostCreds)
コード例 #20
0
def dbfs_artifact_repo():
    """Build the artifact repository for 'dbfs:/test/' with the host-creds lookup patched."""
    target = 'mlflow.store.artifact.dbfs_artifact_repo._get_host_creds_from_default_store'
    with mock.patch(target) as get_creds_mock:
        get_creds_mock.return_value = lambda: MlflowHostCreds('http://host')
        repo = get_artifact_repository('dbfs:/test/')
    return repo
コード例 #21
0
 def setUp(self):
     """Create the host credentials and the ``RestStore`` used by the tests."""
     host_creds = MlflowHostCreds('https://hello')
     self.creds = host_creds
     # The store resolves its credentials lazily through ``self.creds``.
     self.store = RestStore(lambda: self.creds)
コード例 #22
0
ファイル: test_rest_store.py プロジェクト: iPieter/kiwi
    def test_requestor(self, request):
        """
        Check the endpoint, HTTP method and JSON body sent for each ``RestStore`` operation.

        Each section patches ``mlflow.utils.rest_utils.http_request``, performs one store
        call and compares the recorded request with the expected protobuf message serialised
        to JSON.

        ``request`` is the mock standing in for the HTTP request call (presumably injected by
        a ``mock.patch`` decorator outside this snippet).
        """
        # NOTE(review): ``mock.MagicMock`` is the class, not an instance, so the assignments
        # below set class-level attributes. The ``http_request`` patches further down set no
        # ``return_value`` and appear to rely on these class attributes for ``status_code``
        # and ``text`` -- confirm before switching to ``mock.MagicMock()``.
        response = mock.MagicMock
        response.status_code = 200
        response.text = '{}'
        request.return_value = response

        creds = MlflowHostCreds('https://hello')
        store = RestStore(lambda: creds)

        user_name = "mock user"
        source_name = "rest test"

        # Pin the run context so the tags in the create-run request are predictable.
        source_name_patch = mock.patch(
            "mlflow.tracking.context.default_context._get_source_name", return_value=source_name
        )
        source_type_patch = mock.patch(
            "mlflow.tracking.context.default_context._get_source_type",
            return_value=SourceType.LOCAL
        )
        # time.time is frozen at 13579 (seconds); the expected start_time is 13579000.
        with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http, \
                mock.patch('mlflow.tracking._tracking_service.utils._get_store',
                           return_value=store), \
                mock.patch('mlflow.tracking.context.default_context._get_user',
                           return_value=user_name), \
                mock.patch('time.time', return_value=13579), \
                source_name_patch, source_type_patch:
            with kiwi.start_run(experiment_id="43"):
                cr_body = message_to_json(CreateRun(experiment_id="43",
                                                    user_id=user_name, start_time=13579000,
                                                    tags=[ProtoRunTag(key='mlflow.source.name',
                                                                      value=source_name),
                                                          ProtoRunTag(key='mlflow.source.type',
                                                                      value='LOCAL'),
                                                          ProtoRunTag(key='mlflow.user',
                                                                      value=user_name)]))
                expected_kwargs = self._args(creds, "runs/create", "POST", cr_body)

                assert mock_http.call_count == 1
                # call_args[1] holds the keyword arguments of the recorded call.
                actual_kwargs = mock_http.call_args[1]

                # Test the passed tag values separately from the rest of the request
                # Tag order is inconsistent on Python 2 and 3, but the order does not matter
                expected_tags = expected_kwargs['json'].pop('tags')
                actual_tags = actual_kwargs['json'].pop('tags')
                assert (
                    sorted(expected_tags, key=lambda t: t['key']) ==
                    sorted(actual_tags, key=lambda t: t['key'])
                )
                assert expected_kwargs == actual_kwargs

        # log_param -> runs/log-parameter
        with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
            store.log_param("some_uuid", Param("k1", "v1"))
            body = message_to_json(LogParam(
                run_uuid="some_uuid", run_id="some_uuid", key="k1", value="v1"))
            self._verify_requests(mock_http, creds,
                                  "runs/log-parameter", "POST", body)

        # set_experiment_tag -> experiments/set-experiment-tag (4000-character value)
        with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
            store.set_experiment_tag("some_id", ExperimentTag("t1", "abcd"*1000))
            body = message_to_json(SetExperimentTag(
                experiment_id="some_id",
                key="t1",
                value="abcd"*1000))
            self._verify_requests(mock_http, creds,
                                  "experiments/set-experiment-tag", "POST", body)

        # set_tag -> runs/set-tag (4000-character value)
        with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
            store.set_tag("some_uuid", RunTag("t1", "abcd"*1000))
            body = message_to_json(SetTag(
                run_uuid="some_uuid", run_id="some_uuid", key="t1", value="abcd"*1000))
            self._verify_requests(mock_http, creds,
                                  "runs/set-tag", "POST", body)

        # delete_tag -> runs/delete-tag
        with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
            store.delete_tag("some_uuid", "t1")
            body = message_to_json(DeleteTag(run_id="some_uuid", key="t1"))
            self._verify_requests(mock_http, creds,
                                  "runs/delete-tag", "POST", body)

        # log_metric -> runs/log-metric
        with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
            store.log_metric("u2", Metric("m1", 0.87, 12345, 3))
            body = message_to_json(LogMetric(
                run_uuid="u2", run_id="u2", key="m1", value=0.87, timestamp=12345, step=3))
            self._verify_requests(mock_http, creds,
                                  "runs/log-metric", "POST", body)

        # log_batch -> runs/log-batch, with metrics, params and tags converted to protos
        with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
            metrics = [Metric("m1", 0.87, 12345, 0), Metric("m2", 0.49, 12345, -1),
                       Metric("m3", 0.58, 12345, 2)]
            params = [Param("p1", "p1val"), Param("p2", "p2val")]
            tags = [RunTag("t1", "t1val"), RunTag("t2", "t2val")]
            store.log_batch(run_id="u2", metrics=metrics, params=params, tags=tags)
            metric_protos = [metric.to_proto() for metric in metrics]
            param_protos = [param.to_proto() for param in params]
            tag_protos = [tag.to_proto() for tag in tags]
            body = message_to_json(LogBatch(run_id="u2", metrics=metric_protos,
                                            params=param_protos, tags=tag_protos))
            self._verify_requests(mock_http, creds,
                                  "runs/log-batch", "POST", body)

        # delete_run -> runs/delete
        with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
            store.delete_run("u25")
            self._verify_requests(mock_http, creds,
                                  "runs/delete", "POST",
                                  message_to_json(DeleteRun(run_id="u25")))

        # restore_run -> runs/restore
        with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
            store.restore_run("u76")
            self._verify_requests(mock_http, creds,
                                  "runs/restore", "POST",
                                  message_to_json(RestoreRun(run_id="u76")))

        # delete_experiment -> experiments/delete
        with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
            store.delete_experiment("0")
            self._verify_requests(mock_http, creds,
                                  "experiments/delete", "POST",
                                  message_to_json(DeleteExperiment(experiment_id="0")))

        # restore_experiment -> experiments/restore
        with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
            store.restore_experiment("0")
            self._verify_requests(mock_http, creds,
                                  "experiments/restore", "POST",
                                  message_to_json(RestoreExperiment(experiment_id="0")))

        # search_runs -> runs/search; the next_page_token from the response must be returned
        with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
            # NOTE(review): this is again the MagicMock class, so ``text`` is overwritten at
            # class level and stays overwritten for the code that follows -- confirm intent.
            response = mock.MagicMock
            response.text = '{"runs": ["1a", "2b", "3c"], "next_page_token": "67890fghij"}'
            mock_http.return_value = response
            result = store.search_runs(["0", "1"], "params.p1 = 'a'", ViewType.ACTIVE_ONLY,
                                       max_results=10, order_by=["a"], page_token="12345abcde")

            expected_message = SearchRuns(experiment_ids=["0", "1"], filter="params.p1 = 'a'",
                                          run_view_type=ViewType.to_proto(ViewType.ACTIVE_ONLY),
                                          max_results=10, order_by=["a"], page_token="12345abcde")
            self._verify_requests(mock_http, creds,
                                  "runs/search", "POST",
                                  message_to_json(expected_message))
            assert result.token == "67890fghij"

        # record_logged_model -> runs/log-model, carrying the model serialised as JSON
        with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
            run_id = "run_id"
            m = Model(artifact_path="model/path", run_id="run_id", flavors={"tf": "flavor body"})
            result = store.record_logged_model("run_id", m)
            expected_message = LogModel(run_id=run_id,
                                        model_json=m.to_json())
            self._verify_requests(mock_http, creds,
                                  "runs/log-model", "POST",
                                  message_to_json(expected_message))
コード例 #23
0
ファイル: test_rest_store.py プロジェクト: iPieter/kiwi
    def test_get_experiment_by_name(self, store_class):
        """
        Exercise ``get_experiment_by_name`` of ``store_class`` in four scenarios.

        The scenarios are: an existing experiment, a nonexistent experiment, an old server
        without the get-by-name endpoint (fallback to ListExperiments), and a rate-limit
        error (no fallback).

        ``store_class`` is the store type under test (presumably supplied by a parametrize
        decorator outside this snippet).
        """
        creds = MlflowHostCreds('https://hello')
        store = store_class(lambda: creds)
        with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
            # Each response below is its own MagicMock *instance*. Binding the MagicMock
            # class instead would make all three names alias one shared object and would
            # leak the attributes into every other mock created in the same process.
            response = mock.MagicMock()
            response.status_code = 200
            experiment = Experiment(
                experiment_id="123", name="abc", artifact_location="/abc",
                lifecycle_stage=LifecycleStage.ACTIVE)
            response.text = json.dumps({
                "experiment": json.loads(message_to_json(experiment.to_proto()))})
            mock_http.return_value = response
            result = store.get_experiment_by_name("abc")
            expected_message0 = GetExperimentByName(experiment_name="abc")
            self._verify_requests(mock_http, creds,
                                  "experiments/get-by-name", "GET",
                                  message_to_json(expected_message0))
            assert result.experiment_id == experiment.experiment_id
            assert result.name == experiment.name
            assert result.artifact_location == experiment.artifact_location
            assert result.lifecycle_stage == experiment.lifecycle_stage
            # Test GetExperimentByName against nonexistent experiment
            mock_http.reset_mock()
            nonexistent_exp_response = mock.MagicMock()
            nonexistent_exp_response.status_code = 404
            nonexistent_exp_response.text =\
                MlflowException("Exp doesn't exist!", RESOURCE_DOES_NOT_EXIST).serialize_as_json()
            mock_http.return_value = nonexistent_exp_response
            assert store.get_experiment_by_name("nonexistent-experiment") is None
            expected_message1 = GetExperimentByName(experiment_name="nonexistent-experiment")
            self._verify_requests(mock_http, creds,
                                  "experiments/get-by-name", "GET",
                                  message_to_json(expected_message1))
            assert mock_http.call_count == 1

            # Test REST client behavior against a mocked old server, which has handler for
            # ListExperiments but not GetExperimentByName
            mock_http.reset_mock()
            list_exp_response = mock.MagicMock()
            list_exp_response.text = json.dumps({
                "experiments": [json.loads(message_to_json(experiment.to_proto()))]})
            list_exp_response.status_code = 200

            def response_fn(*args, **kwargs):
                # pylint: disable=unused-argument
                if kwargs.get('endpoint') == "/api/2.0/mlflow/experiments/get-by-name":
                    raise MlflowException("GetExperimentByName is not implemented",
                                          ENDPOINT_NOT_FOUND)
                else:
                    return list_exp_response

            mock_http.side_effect = response_fn
            result = store.get_experiment_by_name("abc")
            expected_message2 = ListExperiments(view_type=ViewType.ALL)
            # Both the failed get-by-name attempt and the list fallback must be recorded.
            self._verify_requests(mock_http, creds,
                                  "experiments/get-by-name", "GET",
                                  message_to_json(expected_message0))
            self._verify_requests(mock_http, creds,
                                  "experiments/list", "GET",
                                  message_to_json(expected_message2))
            assert result.experiment_id == experiment.experiment_id
            assert result.name == experiment.name
            assert result.artifact_location == experiment.artifact_location
            assert result.lifecycle_stage == experiment.lifecycle_stage

            # Verify that REST client won't fall back to ListExperiments for 429 errors (hitting
            # rate limits)
            mock_http.reset_mock()

            def rate_limit_response_fn(*args, **kwargs):
                # pylint: disable=unused-argument
                raise MlflowException("Hit rate limit on GetExperimentByName",
                                      REQUEST_LIMIT_EXCEEDED)

            mock_http.side_effect = rate_limit_response_fn
            with pytest.raises(MlflowException) as exc_info:
                store.get_experiment_by_name("imspamming")
            assert exc_info.value.error_code == ErrorCode.Name(REQUEST_LIMIT_EXCEEDED)
            assert mock_http.call_count == 1