def test_response_streaming_error_behavior(http_response):
    """Check that a connection error raised mid-download reaches the caller.

    Reproduces https://github.com/Azure/azure-sdk-for-python/issues/16723.
    The raw stream attached to the ``requests`` response hands out full
    ``block_size`` chunks and raises ``requests.exceptions.ConnectionError``
    as soon as the remaining byte count is no larger than one block, so the
    payload can never be read to completion.

    :param http_response: Response type forwarded to
        ``create_transport_response``.
    """
    block_size = 103
    total_response_size = 500
    req_response = requests.Response()
    req_request = requests.Request()

    class FakeStreamWithConnectionError:
        # fake object for urllib3.response.HTTPResponse
        def __init__(self):
            # Bytes not yet handed out by ``read``; ``stream`` tracks its own
            # local counter instead of this attribute.
            self.total_response_size = total_response_size

        def stream(self, chunk_size, decode_content=False):
            assert chunk_size == block_size
            left = total_response_size
            while left > 0:
                # With 500 bytes and 103-byte blocks, four chunks are
                # produced before this condition becomes true.
                if left <= block_size:
                    raise requests.exceptions.ConnectionError()
                data = b"X" * min(chunk_size, left)
                left -= len(data)
                yield data

        def read(self, chunk_size, decode_content=False):
            assert chunk_size == block_size
            if self.total_response_size > 0:
                if self.total_response_size <= block_size:
                    raise requests.exceptions.ConnectionError()
                data = b"X" * min(chunk_size, self.total_response_size)
                self.total_response_size -= len(data)
                return data

        def close(self):
            pass

    req_response.raw = FakeStreamWithConnectionError()

    response = create_transport_response(
        http_response,
        req_request,
        req_response,
        block_size,
    )

    def mock_run(self, *args, **kwargs):
        # Stand-in for ``Pipeline.run`` that returns an empty response without
        # doing any I/O. NOTE(review): presumably this keeps any follow-up
        # request issued by the downloader off the network - confirm.
        return PipelineResponse(
            None,
            requests.Response(),
            None,
        )

    transport = RequestsTransport()
    pipeline = Pipeline(transport)
    pipeline.run = mock_run
    downloader = response.stream_download(pipeline, decompress=False)
    with pytest.raises(requests.exceptions.ConnectionError):
        # Draining the downloader is what triggers the fake stream's error.
        b"".join(downloader)
    def mock_send(http_request,
                  http_response,
                  method,
                  status,
                  headers=None,
                  body=RESPONSE_BODY):
        """Build a canned ``PipelineResponse`` for a request to ``RESOURCE_URL``.

        :param http_request: Request type; when ``is_rest`` accepts it, it is
            called directly, otherwise ``CLIENT._request`` builds the request.
        :param http_response: Response type forwarded to
            ``create_transport_response``.
        :param method: HTTP method recorded on the fake request.
        :param status: Status code set on the fake response.
        :param headers: Optional response headers. A dict passed here is
            updated in place with a JSON ``content-type`` entry.
        :param body: Object serialized as the JSON response body, or ``None``
            for no content.
        :return: ``PipelineResponse`` wrapping the request and the transport
            response, with no context.
        """
        if body is None:
            payload = None
        else:
            payload = json.dumps(body).encode('ascii')

        origin = Request()
        origin.method = method
        origin.url = RESOURCE_URL
        origin.headers = {
            'x-ms-client-request-id': '67f4dd4e-6262-45e1-8bed-5c45cf23b6d9'
        }

        raw = Response()
        raw._content_consumed = True
        raw._content = payload
        raw.request = origin
        raw.status_code = status
        raw.reason = "OK"
        # The supplied dict (if any) is attached as-is and then mutated.
        raw.headers = {} if headers is None else headers
        raw.headers.update(
            {"content-type": "application/json; charset=utf8"})

        if is_rest(http_request):
            request = http_request(
                origin.method,
                origin.url,
                headers=origin.headers,
                content=body,
            )
        else:
            request = CLIENT._request(
                origin.method,
                origin.url,
                None,  # params
                origin.headers,
                body,
                None,  # form_content
                None  # stream_content
            )

        transport_response = create_transport_response(
            http_response,
            request,
            raw,
        )
        return PipelineResponse(
            request,
            transport_response,
            None  # context
        )
    def _callback(http_response, headers=None):
        """Return an ``LROBasePolling`` primed with a canned 200 response.

        :param http_response: Response type forwarded to
            ``create_transport_response``.
        :param headers: Headers to attach to the fake response. Defaults to a
            fresh empty dict on every call, so state cannot leak between
            invocations through a shared default object.
        :return: The polling object, with both ``_pipeline_response`` and
            ``_initial_response`` set to the same ``PipelineResponse``.
        """
        if headers is None:
            headers = {}

        polling = LROBasePolling()

        response = Response()
        response.headers = headers
        response.status_code = 200

        response = create_transport_response(
            http_response,
            None,
            response,
        )
        polling._pipeline_response = PipelineResponse(None, response,
                                                      PipelineContext(None))
        polling._initial_response = polling._pipeline_response
        return polling
    def mock_update(http_request, http_response, url, headers=None):
        """Build the canned ``PipelineResponse`` for a polling GET on ``url``.

        :param http_request: Request type, called with the method and URL.
        :param http_response: Response type forwarded to
            ``create_transport_response``.
        :param url: One of ``ASYNC_URL``, ``LOCATION_URL``, ``RESOURCE_URL``
            or ``ERROR``.
        :param headers: Optional response headers. A non-empty dict passed
            here is updated in place with a JSON ``content-type`` entry.
        :return: ``PipelineResponse`` wrapping the request and the transport
            response, with no context.
        :raises BadEndpointError: When ``url`` equals ``ERROR``.
        :raises Exception: When ``url`` matches none of the known URLs.
        """
        raw = Response()
        raw._content_consumed = True
        raw.request = mock.create_autospec(Request)
        raw.request.method = 'GET'
        raw.headers = headers or {}
        raw.headers.update(
            {"content-type": "application/json; charset=utf8"})
        raw.reason = "OK"

        # Select the canned body and, where one exists, the extra attribute
        # that is set to None on the raw response for that URL.
        if url == ASYNC_URL:
            canned_body = ASYNC_BODY
            marker = "randomFieldFromPollAsyncOpHeader"
        elif url == LOCATION_URL:
            canned_body = LOCATION_BODY
            marker = "randomFieldFromPollLocationHeader"
        elif url == ERROR:
            raise BadEndpointError("boom")
        elif url == RESOURCE_URL:
            canned_body = RESOURCE_BODY
            marker = None
        else:
            raise Exception('URL does not match')

        raw.request.url = url
        raw.status_code = POLLING_STATUS
        raw._content = canned_body.encode('ascii')
        if marker is not None:
            setattr(raw, marker, None)

        request = http_request(
            raw.request.method,
            raw.request.url,
        )
        transport_response = create_transport_response(
            http_response,
            request,
            raw,
        )
        return PipelineResponse(
            request,
            transport_response,
            None  # context
        )
def _create_aiohttp_response(http_response, body_bytes, headers=None):
    """Wrap a hand-built aiohttp response in a transport response.

    :param http_response: Response type forwarded to
        ``create_transport_response``.
    :param body_bytes: Payload stored on both the fake aiohttp response and
        the resulting transport response.
    :param headers: Headers stored on the fake aiohttp response.
    :return: The transport response, with ``_content`` set to ``body_bytes``.
    """

    class MockAiohttpClientResponse(aiohttp.ClientResponse):
        def __init__(self, body_bytes, headers=None):
            # The parent initializer is not called; only the attributes
            # assigned below are set on the instance.
            self.status = 200
            self.reason = "OK"
            self._headers = headers
            self._body = body_bytes
            self._cache = {}

    fake_aiohttp_response = MockAiohttpClientResponse(body_bytes, headers)

    wrapped = create_transport_response(
        http_response,
        None,  # Don't need a request here
        fake_aiohttp_response,
    )
    wrapped._content = body_bytes
    return wrapped
def _create_requests_response(http_response, body_bytes, headers=None):
    """Wrap a hand-built ``requests.Response`` in a transport response.

    :param http_response: Response type forwarded to
        ``create_transport_response``.
    :param body_bytes: Payload stored as the already-consumed content.
    :param headers: Optional headers merged into the response headers.
    :return: The transport response built around the fake ``requests``
        response.
    """
    # https://github.com/psf/requests/blob/67a7b2e8336951d527e223429672354989384197/requests/adapters.py#L255
    raw = requests.Response()
    raw.status_code = 200
    raw.reason = 'OK'
    raw._content = body_bytes
    raw._content_consumed = True
    if headers:
        # raw.headers is type CaseInsensitiveDict
        raw.headers.update(headers)
    # The encoding is derived from the headers only after they are merged.
    raw.encoding = requests.utils.get_encoding_from_headers(raw.headers)

    return create_transport_response(
        http_response,
        None,  # Don't need a request here
        raw,
    )
def test_http_client_response(port, http_request, http_response):
    """Check that an ``http.client`` response can back a transport response.

    A real GET is issued against ``localhost:<port>/get`` and the resulting
    response object is wrapped with ``create_transport_response``.

    :param port: Port of the local test server.
    :param http_request: Request type used to build the core request.
    :param http_response: Response type forwarded to
        ``create_transport_response``.
    """
    # Create a core request
    request = http_request("GET", "http://localhost:{}".format(port))

    # Fake a transport based on http.client
    conn = HTTPConnection("localhost", port)
    try:
        conn.request("GET", "/get")
        r1 = conn.getresponse()

        response = create_transport_response(http_response, request, r1)
        if is_rest(http_response):
            response.read()

        # Don't assume too much in those assert, since we reach a real server
        assert response.internal_response is r1
        assert response.reason is not None
        assert isinstance(response.status_code, int)
        assert len(response.headers.keys()) != 0
        assert len(response.text()) != 0
        assert "content-type" in response.headers
        assert "Content-Type" in response.headers
    finally:
        # Release the socket even when an assertion above fails.
        conn.close()
# Example #8
# 0
def test_raw_deserializer(http_request, http_response,
                          requests_transport_response):
    """Run ``ContentDecodePolicy`` over XML, JSON and plain-text bodies.

    Each scenario builds a pipeline response, passes it through
    ``on_response`` and inspects ``context["deserialized_data"]``. The last
    scenarios check that ``response_encoding`` can be supplied per request
    or on the policy constructor.

    :param http_request: Request type used to build the universal request.
    :param http_response: Response type subclassed by the mock responses.
    :param requests_transport_response: Response type used to wrap a
        ``requests.Response`` object.
    """
    policy = ContentDecodePolicy()
    context = PipelineContext(None, stream=False)
    request = PipelineRequest(http_request('GET', 'http://localhost/'),
                              context)

    def build_response(body, content_type=None):
        # Wrap ``body`` in a mock response derived from ``http_response`` and
        # return it inside a PipelineResponse sharing ``request``/``context``.
        if is_rest(http_response):

            class MockResponse(http_response):
                def __init__(self, body, content_type):
                    super(MockResponse, self).__init__(
                        request=None,
                        internal_response=None,
                        status_code=400,
                        reason="Bad Request",
                        content_type="application/json",
                        headers={},
                        stream_download_generator=None,
                    )
                    self._body = body
                    # Overrides the value handed to the parent initializer.
                    self.content_type = content_type

                def body(self):
                    return self._body

                def read(self):
                    self._content = self._body
                    return self.content

        else:

            class MockResponse(http_response):
                def __init__(self, body, content_type):
                    super(MockResponse, self).__init__(None, None)
                    self._body = body
                    self.content_type = content_type

                def body(self):
                    return self._body

        mock_response = MockResponse(body, content_type)
        return PipelineResponse(request, mock_response, context)

    def deserialize(active_policy, pipeline_response):
        # Run the policy on the response and hand back what it stored.
        active_policy.on_response(request, pipeline_response)
        return pipeline_response.context["deserialized_data"]

    def check_encoding(active_policy, expected_encoding):
        # Decode the UTF-8 bytes of "é" and verify which encoding was
        # recorded, then drop that entry from the shared context.
        pipeline_response = build_response(b'\xc3\xa9',
                                           content_type="text/plain")
        active_policy.on_request(request)
        active_policy.on_response(request, pipeline_response)
        assert pipeline_response.context["deserialized_data"] == u"é"
        assert (pipeline_response.context["response_encoding"]
                == expected_encoding)
        del request.context['response_encoding']

    response = build_response(b"<groot/>", content_type="application/xml")
    assert deserialize(policy, response).tag == "groot"

    response = build_response(b"\xef\xbb\xbf<utf8groot/>",
                              content_type="application/xml")
    assert deserialize(policy, response).tag == "utf8groot"

    # The basic deserializer works with unicode XML
    response = build_response(u'<groot language="français"/>'.encode('utf-8'),
                              content_type="application/xml")
    assert deserialize(policy, response).attrib["language"] == u"français"

    # Catch some weird situation where content_type is XML, but content is JSON
    response = build_response(b'{"ugly": true}',
                              content_type="application/xml")
    assert deserialize(policy, response)["ugly"] is True

    # Be sure I catch the correct exception if it's neither XML nor JSON
    for garbage in (b'gibberish', b'{{gibberish}}'):
        response = build_response(garbage, content_type="application/xml")
        with pytest.raises(DecodeError) as err:
            policy.on_response(request, response)
        assert err.value.response is response.http_response

    # Simple JSON, with BOM, and with two complex content_type values
    json_cases = (
        (b'{"success": true}', "application/json"),
        (b'\xef\xbb\xbf{"success": true}', "application/json"),
        (b'{"success": true}',
         "application/vnd.microsoft.appconfig.kv.v1+json"),
        (b'{"success": true}', "text/vnd.microsoft.appconfig.kv.v1+json"),
    )
    for payload, media_type in json_cases:
        response = build_response(payload, content_type=media_type)
        assert deserialize(policy, response)["success"] is True

    # For compat, if no content-type, decode JSON
    response = build_response(b'"data"')
    assert deserialize(policy, response) == "data"

    # Let text/plain let through, without and with a BOM
    for raw_text in (b'I am groot', b'\xef\xbb\xbfI am groot'):
        response = build_response(raw_text, content_type="text/plain")
        assert deserialize(policy, response) == "I am groot"

    # Try with a mock of requests
    req_response = requests.Response()
    req_response.headers["content-type"] = "application/json"
    req_response._content = b'{"success": true}'
    req_response._content_consumed = True
    wrapped = create_transport_response(requests_transport_response, None,
                                        req_response)
    response = PipelineResponse(None, wrapped,
                                PipelineContext(None, stream=False))
    assert deserialize(policy, response)["success"] is True

    # I can enable it per request
    request.context.options['response_encoding'] = 'utf-8'
    check_encoding(policy, "utf-8")

    # I can enable it globally
    global_policy = ContentDecodePolicy(response_encoding="utf-8")
    check_encoding(global_policy, "utf-8")

    # Per request is more important
    request.context.options['response_encoding'] = 'utf-8-sig'
    check_encoding(global_policy, "utf-8-sig")