예제 #1
0
def _get_run():
    """Serve a GetRun request as a JSON HTTP response.

    The incoming request is parsed into a ``GetRun`` message. The run named by
    its ``run_uuid`` is loaded from ``_get_store()``, and its proto is merged
    into a ``GetRun.Response``. That response is serialized with
    ``message_to_json`` into the body of an ``application/json`` response.
    """
    parsed_request = _get_request_message(GetRun())
    payload = GetRun.Response()
    requested_run = _get_store().get_run(parsed_request.run_uuid)
    payload.run.MergeFrom(requested_run.to_proto())

    http_response = Response(mimetype='application/json')
    http_response.set_data(message_to_json(payload))
    return http_response
예제 #2
0
def test_malformed_json_error_response(response_mock):
    """``call_endpoint`` must raise ``MlflowException`` for the mocked error response.

    ``response_mock`` is a fixture defined outside this view. Judging by the
    test name, it is presumably a response whose error body is not valid
    JSON; confirm against the fixture definition.
    """
    creds = MlflowHostCreds("http://my-host")
    target_proto = GetRun.Response()

    with mock.patch("requests.request") as patched_request:
        patched_request.return_value = response_mock
        with pytest.raises(MlflowException):
            call_endpoint(creds, "/my/endpoint", "GET", "", target_proto)
예제 #3
0
def _get_run():
    """Serve a GetRun request as a JSON HTTP response.

    The request is parsed with ``from_get=True`` (presumably so that fields
    are read from GET query parameters; confirm in ``_get_request_message``).
    The run named by ``run_uuid`` is fetched from ``_get_store()``. The
    resulting ``GetRun.Response`` is serialized with ``MessageToJson`` and
    keeps the original proto field names instead of camelCase.
    """
    parsed_request = _get_request_message(GetRun(), from_get=True)
    payload = GetRun.Response()
    requested_run = _get_store().get_run(parsed_request.run_uuid)
    payload.run.MergeFrom(requested_run.to_proto())

    http_response = Response(mimetype='application/json')
    json_body = MessageToJson(payload, preserving_proto_field_name=True)
    http_response.set_data(json_body)
    return http_response
예제 #4
0
def _get_run():
    """Serve a GetRun request as a JSON HTTP response.

    The run is identified by the request's ``run_id``. When that field is
    empty, ``run_uuid`` is used instead. The fetched run is passed to
    ``run_auth_check`` before being returned; presumably that raises when
    the caller may not read the run (confirm in its definition).
    """
    parsed_request = _get_request_message(GetRun())
    payload = GetRun.Response()

    # Use run_id when it is set (truthy); otherwise fall back to run_uuid.
    if parsed_request.run_id:
        requested_id = parsed_request.run_id
    else:
        requested_id = parsed_request.run_uuid

    requested_run = _get_tracking_store().get_run(requested_id)
    run_auth_check(requested_run)
    payload.run.MergeFrom(requested_run.to_proto())

    http_response = Response(mimetype="application/json")
    http_response.set_data(message_to_json(payload))
    return http_response
예제 #5
0
def test_well_formed_json_error_response():
    """A 400 response whose body parses as JSON must raise ``RestException``."""
    error_response = mock.MagicMock()
    error_response.status_code = 400
    error_response.text = "{}"  # well-formed JSON error response

    creds = MlflowHostCreds("http://my-host")
    target_proto = GetRun.Response()

    with mock.patch("requests.request") as patched_request:
        patched_request.return_value = error_response
        with pytest.raises(RestException):
            call_endpoint(creds, "/my/endpoint", "GET", "", target_proto)
예제 #6
0
def test_call_endpoints_raises_exceptions():
    """``call_endpoints`` must raise ``RestException`` in two failure scenarios.

    1. Every candidate endpoint raises ENDPOINT_NOT_FOUND.
    2. The first endpoint raises a ``RestException`` with no error code, and
       the second would return ``None``. Presumably only ENDPOINT_NOT_FOUND
       causes a fallback to the next endpoint; confirm in ``call_endpoints``.
    """
    creds = MlflowHostCreds("http://my-host")
    candidate_endpoints = [("/my/endpoint", "POST"), ("/my/endpoint", "GET")]
    target_proto = GetRun.Response()
    not_found_code = ErrorCode.Name(ENDPOINT_NOT_FOUND)

    with mock.patch("mlflow.utils.rest_utils.call_endpoint") as patched_call:
        # Scenario 1: both endpoints report "not found". Each exception gets
        # its own dict, matching the original construction.
        patched_call.side_effect = [
            RestException({"error_code": not_found_code}) for _ in range(2)
        ]
        with pytest.raises(RestException):
            call_endpoints(creds, candidate_endpoints, "", target_proto)

        # Scenario 2: the first endpoint fails with an error that carries no code.
        patched_call.side_effect = [RestException({}), None]
        with pytest.raises(RestException):
            call_endpoints(creds, candidate_endpoints, "", target_proto)
예제 #7
0
def test_non_json_ok_response():
    """A 200 response whose body is HTML, not JSON, must raise a descriptive ``MlflowException``."""
    html_response = mock.MagicMock()
    html_response.status_code = 200
    html_response.text = "<html></html>"

    expected_message = (
        "API request to endpoint was successful but the response body was not "
        "in a valid JSON format"
    )
    creds = MlflowHostCreds("http://my-host")
    target_proto = GetRun.Response()

    with mock.patch("requests.Session.request") as patched_request:
        patched_request.return_value = html_response
        with pytest.raises(MlflowException, match=expected_message):
            call_endpoint(creds, "/api/2.0/fetch-model", "GET", "", target_proto)
예제 #8
0
def test_call_endpoints():
    """``call_endpoints`` must move past an ENDPOINT_NOT_FOUND to the next endpoint.

    The first patched ``call_endpoint`` raises ENDPOINT_NOT_FOUND and the
    second returns ``None``. The test checks that both endpoints were called
    in order and that the second call's result is what gets returned.
    """
    creds = MlflowHostCreds("http://my-host")
    candidate_endpoints = [("/my/endpoint", "POST"), ("/my/endpoint", "GET")]
    target_proto = GetRun.Response()
    expected_calls = [
        mock.call(creds, path, verb, "", target_proto)
        for path, verb in candidate_endpoints
    ]

    with mock.patch("mlflow.utils.rest_utils.call_endpoint") as patched_call:
        patched_call.side_effect = [
            RestException({"error_code": ErrorCode.Name(ENDPOINT_NOT_FOUND)}),
            None,
        ]
        result = call_endpoints(creds, candidate_endpoints, "", target_proto)
        patched_call.assert_has_calls(expected_calls)

    assert result is None