def test_close(self):
        """Closing the session detaches the connector and closes it."""
        connector = self.make_open_connector()
        client = ClientSession(loop=self.loop, connector=connector)

        client.close()
        self.assertIsNone(client.connector)
        self.assertTrue(connector.closed)
 def test_init_cookies_with_list_of_tuples(self):
     """Cookies passed as (name, value) pairs are readable by name."""
     pairs = [("c1", "cookie1"), ("c2", "cookie2")]
     client = ClientSession(cookies=pairs, loop=self.loop)
     self.assertEqual(set(client.cookies), {'c1', 'c2'})
     for name, expected in (('c1', 'cookie1'), ('c2', 'cookie2')):
         self.assertEqual(client.cookies[name].value, expected)
     client.close()
    def test_close(self):
        """After close() the session no longer holds its connector, which is closed."""
        owned_connector = self.make_open_connector()
        sess = ClientSession(loop=self.loop, connector=owned_connector)

        sess.close()
        self.assertIsNone(sess.connector)
        self.assertTrue(owned_connector.closed)
 def test_init_headers_simple_dict(self):
     """Default headers given as a plain dict are stored as 'H1'/'H2'."""
     defaults = {"h1": "header1", "h2": "header2"}
     client = ClientSession(headers=defaults, loop=self.loop)
     stored = sorted(client._default_headers.items())
     self.assertEqual(stored, [("H1", "header1"), ("H2", "header2")])
     client.close()
 def test_init_cookies_with_list_of_tuples(self):
     """A list of cookie tuples populates ``session.cookies``."""
     sess = ClientSession(cookies=[("c1", "cookie1"), ("c2", "cookie2")],
                          loop=self.loop)
     self.assertEqual(set(sess.cookies), {'c1', 'c2'})
     first_value = sess.cookies['c1'].value
     self.assertEqual(first_value, 'cookie1')
     second_value = sess.cookies['c2'].value
     self.assertEqual(second_value, 'cookie2')
     sess.close()
 def test_init_headers_list_of_tuples(self):
     """Default headers may be supplied as a list of (name, value) tuples."""
     client = ClientSession(
         headers=[("h1", "header1"), ("h2", "header2"), ("h3", "header3")],
         loop=self.loop)
     expected = CIMultiDict(
         [("h1", "header1"), ("h2", "header2"), ("h3", "header3")])
     self.assertEqual(client._default_headers, expected)
     client.close()
 def test_init_cookies_with_simple_dict(self):
     """Cookies may be supplied as a plain name -> value dict."""
     client = ClientSession(cookies=dict(c1="cookie1", c2="cookie2"),
                            loop=self.loop)
     self.assertEqual(set(client.cookies), {'c1', 'c2'})
     self.assertEqual(client.cookies['c1'].value, 'cookie1')
     self.assertEqual(client.cookies['c2'].value, 'cookie2')
     client.close()
 def test_init_cookies_with_simple_dict(self):
     """Each cookie from a dict is exposed by name through ``session.cookies``."""
     expected = {"c1": "cookie1", "c2": "cookie2"}
     sess = ClientSession(cookies=dict(expected), loop=self.loop)
     self.assertEqual(set(sess.cookies), {'c1', 'c2'})
     for cookie_name in ('c1', 'c2'):
         self.assertEqual(sess.cookies[cookie_name].value,
                          expected[cookie_name])
     sess.close()
 def test_init_headers_MultiDict(self):
     """A MultiDict of default headers ends up as an equal CIMultiDict."""
     source = MultiDict([("h1", "header1"), ("h2", "header2"),
                         ("h3", "header3")])
     client = ClientSession(headers=source, loop=self.loop)
     wanted = CIMultiDict([("H1", "header1"), ("H2", "header2"),
                           ("H3", "header3")])
     self.assertEqual(client._default_headers, wanted)
     client.close()
 def test_merge_headers_with_list_of_tuples(self):
     """Per-request header tuples override a session default of the same name."""
     client = ClientSession(headers={"h1": "header1", "h2": "header2"},
                            loop=self.loop)
     merged = client._prepare_headers([("h1", "h1")])
     self.assertIsInstance(merged, CIMultiDict)
     self.assertEqual(merged,
                      CIMultiDict([("h2", "header2"), ("h1", "h1")]))
     client.close()
 def test_init_headers_simple_dict(self):
     """Sorted default headers from a dict come back with upper-case names."""
     sess = ClientSession(headers=dict(h1="header1", h2="header2"),
                          loop=self.loop)
     actual_items = sorted(sess._default_headers.items())
     self.assertEqual(actual_items,
                      [("H1", "header1"), ("H2", "header2")])
     sess.close()
 def test_merge_headers(self):
     """Per-request headers given as a simple dict override session defaults."""
     client = ClientSession(headers={"h1": "header1", "h2": "header2"},
                            loop=self.loop)
     request_headers = {"h1": "h1"}
     merged = client._prepare_headers(request_headers)
     self.assertIsInstance(merged, CIMultiDict)
     self.assertEqual(merged,
                      CIMultiDict([("h2", "header2"), ("h1", "h1")]))
     client.close()
 def test_init_headers_list_of_tuples(self):
     """Three header tuples passed at construction are kept as defaults."""
     sess = ClientSession(headers=[("h1", "header1"),
                                   ("h2", "header2"),
                                   ("h3", "header3")],
                          loop=self.loop)
     reference = CIMultiDict([("h1", "header1"),
                              ("h2", "header2"),
                              ("h3", "header3")])
     self.assertEqual(sess._default_headers, reference)
     sess.close()
 def test_init_headers_MultiDict(self):
     """MultiDict defaults compare equal to the upper-cased CIMultiDict."""
     pairs = [("h1", "header1"), ("h2", "header2"), ("h3", "header3")]
     sess = ClientSession(headers=MultiDict(pairs), loop=self.loop)
     self.assertEqual(
         sess._default_headers,
         CIMultiDict([("H1", "header1"), ("H2", "header2"), ("H3", "header3")]))
     sess.close()
 def test_merge_headers_with_list_of_tuples(self):
     """_prepare_headers accepts tuples and returns a merged CIMultiDict."""
     defaults = {"h1": "header1", "h2": "header2"}
     sess = ClientSession(headers=defaults, loop=self.loop)
     result = sess._prepare_headers([("h1", "h1")])
     self.assertIsInstance(result, CIMultiDict)
     expected = CIMultiDict([("h2", "header2"), ("h1", "h1")])
     self.assertEqual(result, expected)
     sess.close()
 def test_merge_headers(self):
     """A simple-dict request header replaces the session default 'h1'."""
     defaults = {"h1": "header1", "h2": "header2"}
     sess = ClientSession(headers=defaults, loop=self.loop)
     result = sess._prepare_headers(dict(h1="h1"))
     self.assertIsInstance(result, CIMultiDict)
     expected = CIMultiDict([("h2", "header2"), ("h1", "h1")])
     self.assertEqual(result, expected)
     sess.close()
class AIOResponsesRaiseForStatusSessionTestCase(AsyncTestCase):
    """Sessions created with ``raise_for_status=True``.

    aiohttp v2.0.0 introduced that session flag, which makes responses
    call `raise_for_status()` automatically. Since aiohttp v3.4.a0 a
    per-request ``raise_for_status`` argument overrides it.
    """

    async def setup(self):
        self.url = 'http://example.com/api?foo=bar#fragment'
        self.session = ClientSession(raise_for_status=True)

    async def teardown(self):
        # close() may hand back None or an awaitable; only await the latter.
        pending_close = self.session.close()
        if pending_close is None:
            return
        await pending_close

    @aioresponses()
    async def test_raise_for_status(self, m):
        """A 400 response raises without an explicit raise_for_status()."""
        m.get(self.url, status=400)
        with self.assertRaises(ClientResponseError) as ctx:
            await self.session.get(self.url)
        self.assertEqual(ctx.exception.message, http.RESPONSES[400][0])

    @aioresponses()
    @skipIf(condition=AIOHTTP_VERSION < '3.4.0',
            reason='aiohttp<3.4.0 does not support raise_for_status '
                   'arguments for requests')
    async def test_do_not_raise_for_status(self, m):
        """``raise_for_status=False`` on the request disables the session flag."""
        m.get(self.url, status=400)
        resp = await self.session.get(self.url, raise_for_status=False)
        self.assertEqual(resp.status, 400)
class AIOResponseRedirectTest(TestCase):
    """Redirect handling of mocked responses (generator-based coroutines)."""

    @asyncio.coroutine
    def setUp(self):
        self.url = "http://10.1.1.1:8080/redirect"
        self.session = ClientSession()
        super().setUp()

    @asyncio.coroutine
    def tearDown(self):
        # close() may return None or an awaitable; drive the awaitable so
        # the session is really closed.
        close_result = self.session.close()
        if close_result is not None:
            yield from close_result
        super().tearDown()

    @aioresponses()
    @asyncio.coroutine
    def test_redirect_followed(self, rsps):
        """A mocked 307 with a Location header is followed to the target."""
        rsps.get(
            self.url,
            status=307,
            headers={"Location": "https://httpbin.org"},
        )
        rsps.get("https://httpbin.org")
        response = yield from self.session.get(self.url, allow_redirects=True)
        self.assertEqual(response.status, 200)
        self.assertEqual(str(response.url), "https://httpbin.org")
        self.assertEqual(len(response.history), 1)
        self.assertEqual(str(response.history[0].url), self.url)

    @aioresponses()
    @asyncio.coroutine
    def test_redirect_missing_mocked_match(self, rsps):
        """Following a redirect whose target is not mocked raises."""
        rsps.get(
            self.url,
            status=307,
            headers={"Location": "https://httpbin.org"},
        )
        with self.assertRaises(ClientConnectionError) as cm:
            # Result intentionally discarded: the call is expected to raise.
            yield from self.session.get(self.url, allow_redirects=True)
        self.assertEqual(
            str(cm.exception),
            'Connection refused: GET http://10.1.1.1:8080/redirect')

    @aioresponses()
    @asyncio.coroutine
    def test_redirect_missing_location_header(self, rsps):
        """A 307 without Location is returned as-is, URL unchanged."""
        rsps.get(self.url, status=307)
        response = yield from self.session.get(self.url, allow_redirects=True)
        self.assertEqual(str(response.url), self.url)
class AIOResponsesTestCase(TestCase):
    """Core behaviour of the ``aioresponses`` mock against a real ClientSession."""
    use_default_loop = False

    @asyncio.coroutine
    def setUp(self):
        self.url = 'http://example.com/api?foo=bar#fragment'
        self.session = ClientSession()
        super().setUp()

    @asyncio.coroutine
    def tearDown(self):
        # close() may return None or an awaitable; drive the awaitable so
        # the session is really closed.
        close_result = self.session.close()
        if close_result is not None:
            yield from close_result
        super().tearDown()

    def run_async(self, coroutine: Union[Coroutine, Generator]):
        """Run *coroutine* to completion on the test loop and return its result."""
        return self.loop.run_until_complete(coroutine)

    @asyncio.coroutine
    def request(self, url: str):
        """Issue a GET for *url* through the test session."""
        return (yield from self.session.get(url))

    @data(
        hdrs.METH_HEAD,
        hdrs.METH_GET,
        hdrs.METH_POST,
        hdrs.METH_PUT,
        hdrs.METH_PATCH,
        hdrs.METH_DELETE,
        hdrs.METH_OPTIONS,
    )
    @patch('aioresponses.aioresponses.add')
    @fail_on(unused_loop=False)
    def test_shortcut_method(self, http_method, mocked):
        """Each lower-case shortcut (get, post, ...) delegates to ``add``."""
        with aioresponses() as m:
            getattr(m, http_method.lower())(self.url)
            mocked.assert_called_once_with(self.url, method=http_method)

    @aioresponses()
    def test_returned_instance(self, m):
        m.get(self.url)
        response = self.run_async(self.session.get(self.url))
        self.assertIsInstance(response, ClientResponse)

    @aioresponses()
    @asyncio.coroutine
    def test_returned_instance_and_status_code(self, m):
        m.get(self.url, status=204)
        response = yield from self.session.get(self.url)
        self.assertIsInstance(response, ClientResponse)
        self.assertEqual(response.status, 204)

    @aioresponses()
    @asyncio.coroutine
    def test_returned_response_headers(self, m):
        m.get(self.url,
              content_type='text/html',
              headers={'Connection': 'keep-alive'})
        response = yield from self.session.get(self.url)

        self.assertEqual(response.headers['Connection'], 'keep-alive')
        self.assertEqual(response.headers[hdrs.CONTENT_TYPE], 'text/html')

    @aioresponses()
    @asyncio.coroutine
    def test_returned_response_raw_headers(self, m):
        m.get(self.url,
              content_type='text/html',
              headers={'Connection': 'keep-alive'})
        response = yield from self.session.get(self.url)
        expected_raw_headers = ((b'Content-Type', b'text/html'),
                                (b'Connection', b'keep-alive'))

        self.assertEqual(response.raw_headers, expected_raw_headers)

    @aioresponses()
    @asyncio.coroutine
    def test_raise_for_status(self, m):
        m.get(self.url, status=400)
        with self.assertRaises(ClientResponseError) as cm:
            response = yield from self.session.get(self.url)
            response.raise_for_status()
        self.assertEqual(cm.exception.message, http.RESPONSES[400][0])

    @aioresponses()
    @asyncio.coroutine
    def test_returned_instance_and_params_handling(self, m):
        """Query params are merged into the URL before matching mocks."""
        expected_url = 'http://example.com/api?foo=bar&x=42#fragment'
        m.get(expected_url)
        response = yield from self.session.get(self.url, params={'x': 42})
        self.assertIsInstance(response, ClientResponse)
        self.assertEqual(response.status, 200)

        expected_url = 'http://example.com/api?x=42#fragment'
        m.get(expected_url)
        response = yield from self.session.get(
            'http://example.com/api#fragment', params={'x': 42})
        self.assertIsInstance(response, ClientResponse)
        self.assertEqual(response.status, 200)

    @aioresponses()
    def test_method_dont_match(self, m):
        m.get(self.url)
        with self.assertRaises(ClientConnectionError):
            self.run_async(self.session.post(self.url))

    @aioresponses()
    @asyncio.coroutine
    def test_streaming(self, m):
        m.get(self.url, body='Test')
        resp = yield from self.session.get(self.url)
        content = yield from resp.content.read()
        self.assertEqual(content, b'Test')

    @aioresponses()
    @asyncio.coroutine
    def test_streaming_up_to(self, m):
        m.get(self.url, body='Test')
        resp = yield from self.session.get(self.url)
        content = yield from resp.content.read(2)
        self.assertEqual(content, b'Te')
        content = yield from resp.content.read(2)
        self.assertEqual(content, b'st')

    @asyncio.coroutine
    def test_mocking_as_context_manager(self):
        with aioresponses() as aiomock:
            aiomock.add(self.url, payload={'foo': 'bar'})
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 200)
            payload = yield from resp.json()
            self.assertDictEqual(payload, {'foo': 'bar'})

    def test_mocking_as_decorator(self):
        @aioresponses()
        def foo(loop, m):
            m.add(self.url, payload={'foo': 'bar'})

            resp = loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 200)
            payload = loop.run_until_complete(resp.json())
            self.assertDictEqual(payload, {'foo': 'bar'})

        foo(self.loop)

    @asyncio.coroutine
    def test_passing_argument(self):
        @aioresponses(param='mocked')
        @asyncio.coroutine
        def foo(mocked):
            mocked.add(self.url, payload={'foo': 'bar'})
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 200)

        yield from foo()

    @fail_on(unused_loop=False)
    def test_mocking_as_decorator_wrong_mocked_arg_name(self):
        @aioresponses(param='foo')
        def foo(bar):
            # no matter what is here it should raise an error
            pass

        with self.assertRaises(TypeError) as cm:
            foo()
        exc = cm.exception
        self.assertIn("foo() got an unexpected keyword argument 'foo'",
                      str(exc))

    @asyncio.coroutine
    def test_unknown_request(self):
        with aioresponses() as aiomock:
            aiomock.add(self.url, payload={'foo': 'bar'})
            with self.assertRaises(ClientConnectionError):
                yield from self.session.get('http://example.com/foo')

    @asyncio.coroutine
    def test_raising_custom_error(self):
        with aioresponses() as aiomock:
            aiomock.get(self.url, exception=HttpProcessingError(message='foo'))
            with self.assertRaises(HttpProcessingError):
                yield from self.session.get(self.url)

    @asyncio.coroutine
    def test_multiple_requests(self):
        """Mocks for the same URL are consumed in order; calls are recorded."""
        with aioresponses() as m:
            m.get(self.url, status=200)
            m.get(self.url, status=201)
            m.get(self.url, status=202)
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 200)
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 201)
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 202)

            key = ('GET', URL(self.url))
            self.assertIn(key, m.requests)
            self.assertEqual(len(m.requests[key]), 3)
            self.assertEqual(m.requests[key][0].args, tuple())
            self.assertEqual(m.requests[key][0].kwargs,
                             {'allow_redirects': True})

    @asyncio.coroutine
    def test_address_as_instance_of_url_combined_with_pass_through(self):
        external_api = 'http://httpbin.org/status/201'

        @asyncio.coroutine
        def doit():
            api_resp = yield from self.session.get(self.url)
            # we have to hit actual url,
            # otherwise we do not test pass through option properly
            ext_rep = yield from self.session.get(URL(external_api))
            return api_resp, ext_rep

        with aioresponses(passthrough=[external_api]) as m:
            m.get(self.url, status=200)
            api, ext = yield from doit()

            self.assertEqual(api.status, 200)
            self.assertEqual(ext.status, 201)

    @aioresponses()
    @asyncio.coroutine
    def test_custom_response_class(self, m):
        class CustomClientResponse(ClientResponse):
            pass

        m.get(self.url, body='Test', response_class=CustomClientResponse)
        resp = yield from self.session.get(self.url)
        self.assertTrue(isinstance(resp, CustomClientResponse))

    @aioresponses()
    def test_exceptions_in_the_middle_of_responses(self, mocked):
        """Exception mocks interleaved with normal ones are raised in order."""
        mocked.get(self.url, payload={}, status=204)
        mocked.get(
            self.url,
            exception=ValueError('oops'),
        )
        mocked.get(self.url, payload={}, status=204)
        mocked.get(
            self.url,
            exception=ValueError('oops'),
        )
        mocked.get(self.url, payload={}, status=200)

        @asyncio.coroutine
        def doit():
            return (yield from self.session.get(self.url))

        self.assertEqual(self.run_async(doit()).status, 204)
        with self.assertRaises(ValueError):
            self.run_async(doit())
        self.assertEqual(self.run_async(doit()).status, 204)
        with self.assertRaises(ValueError):
            self.run_async(doit())
        self.assertEqual(self.run_async(doit()).status, 200)

    @aioresponses()
    @asyncio.coroutine
    def test_request_should_match_regexp(self, mocked):
        mocked.get(re.compile(r'^http://example\.com/api\?foo=.*$'),
                   payload={},
                   status=200)

        response = yield from self.request(self.url)
        self.assertEqual(response.status, 200)

    @aioresponses()
    @asyncio.coroutine
    def test_request_does_not_match_regexp(self, mocked):
        mocked.get(re.compile(r'^http://exampleexample\.com/api\?foo=.*$'),
                   payload={},
                   status=200)
        with self.assertRaises(ClientConnectionError):
            yield from self.request(self.url)

    @aioresponses()
    def test_timeout(self, mocked):
        mocked.get(self.url, timeout=True)

        with self.assertRaises(asyncio.TimeoutError):
            self.run_async(self.request(self.url))

    @aioresponses()
    def test_callback(self, m):
        """The callback receives the URL and request kwargs; its body is returned."""
        body = b'New body'

        def callback(url, **kwargs):
            self.assertEqual(str(url), self.url)
            self.assertEqual(kwargs, {'allow_redirects': True})
            return CallbackResult(body=body)

        m.get(self.url, callback=callback)
        response = self.run_async(self.request(self.url))
        # Named ``content`` so the module-level ``data`` decorator is not shadowed.
        content = self.run_async(response.read())
        # unittest assertion instead of a bare ``assert``: it is not stripped
        # under ``python -O`` and reports both values on failure.
        self.assertEqual(content, body)
 def test_closed(self):
     """``closed`` is False for a fresh session and True after close()."""
     client = ClientSession(loop=self.loop)
     self.assertFalse(client.closed)
     client.close()
     self.assertTrue(client.closed)
Beispiel #21
0
class ServiceClient:
    """REST client driven by a service specification and a chain of plugins.

    ``spec`` maps service names to description dicts; :meth:`call` reads the
    optional keys ``path``, ``method`` (default ``'GET'``), ``stream_request``
    and ``stream_response`` from them. Unknown attribute access is turned into
    a call, so ``client.get_user(...)`` behaves like
    ``client.call('get_user', ...)``.

    Plugins are plain objects. For each hook name used below
    (``assign_service_client``, ``prepare_session``, ``prepare_path``,
    ``prepare_request_params``, ``prepare_payload``, ``before_request``,
    ``on_exception``, ``on_response``, ``on_parse_exception``,
    ``on_parsed_response``) the client calls the plugin method of that name
    if the plugin defines it.
    """

    def __init__(self, rest_service_name='GenericService', spec=None, plugins=None, config=None,
                 parser=None, serializer=None, base_path='', loop=None, logger=None):
        self._plugins = []

        self.logger = logger or logging.getLogger('serviceClient.{}'.format(rest_service_name))
        self.rest_service_name = rest_service_name
        self.spec = spec or {}
        # NOTE: plugins' ``assign_service_client`` hooks run here, i.e. before
        # config/parser/serializer/base_path/loop/session are assigned below.
        self.add_plugins(plugins or [])
        self.config = config or {}
        # Defaults are identity functions: payloads/bodies pass through unchanged.
        self.parser = parser or (lambda x, *args, **kwargs: x)
        self.serializer = serializer or (lambda x, *args, **kwargs: x)
        self.base_path = base_path
        self.loop = loop or get_event_loop()

        # ``config['connector']`` supplies extra keyword arguments for TCPConnector.
        self.connector = TCPConnector(loop=self.loop, **self.config.get('connector', {}))
        self.session = ClientSession(connector=self.connector, loop=self.loop)

    @coroutine
    def call(self, service_name, payload=None, **kwargs):
        """Invoke the service ``service_name`` described in ``self.spec``.

        :param service_name: key into ``self.spec`` (``KeyError`` if missing).
        :param payload: request body. It goes through the ``prepare_payload``
            hooks and, for methods other than GET/DELETE and unless the
            service sets ``stream_request``, through ``self.serializer``.
        :param kwargs: initial keyword arguments for ``session.request``;
            plugins may add to or modify them.
        :returns: the response. Unless the service sets ``stream_response``,
            the body is read, parsed with ``self.parser`` and stored in
            ``response.data``.
        :raises: any exception raised while sending the request, after the
            ``on_exception`` hooks ran; any exception raised while parsing,
            after the ``on_parse_exception`` hooks ran and with
            ``e.response`` set to the response.
        """
        self.logger.debug("Calling service_client {0}...".format(service_name))
        service_desc = self.spec[service_name].copy()
        service_desc['service_name'] = service_name

        request_params = kwargs
        session = yield from self.prepare_session(service_desc, request_params)

        request_params['url'] = yield from self.generate_path(service_desc, session, request_params)
        request_params['method'] = service_desc.get('method', 'GET').upper()

        yield from self.prepare_request_params(service_desc, session, request_params)

        self.logger.info("Calling service_client {0} using {1} {2}".format(service_name,
                                                                           request_params['method'],
                                                                           request_params['url']))

        payload = yield from self.prepare_payload(service_desc, session, request_params, payload)
        try:
            if request_params['method'] not in ('GET', 'DELETE'):
                stream_request = service_desc.get('stream_request', False)
                # Streamed payloads are handed over untouched by the serializer.
                if payload and not stream_request:
                    request_params['data'] = self.serializer(payload, session=session,
                                                             service_desc=service_desc,
                                                             request_params=request_params)

            yield from self.before_request(service_desc, session, request_params)

            response = yield from session.request(**request_params)
        except Exception as e:
            self.logger.warning("Exception calling service_client {0}: {1}".format(service_name, e))
            yield from self.on_exception(service_desc, session, request_params, e)
            raise e

        yield from self.on_response(service_desc, session, request_params, response)

        # Streaming responses are returned unread so the caller consumes the body.
        if service_desc.get('stream_response'):
            return response

        try:
            self.logger.info("Parsing response from {0}...".format(service_name))
            response.data = self.parser((yield from response.read()),
                                        session=session,
                                        service_desc=service_desc,
                                        response=response)
            yield from self.on_parsed_response(service_desc, session, request_params, response)
        except Exception as e:
            self.logger.warning("[Response code: {0}] Exception parsing response from service_client "
                                "{1}: {2}".format(response.status, service_name, e))
            yield from self.on_parse_exception(service_desc, session, request_params, response, e)
            e.response = response
            raise e

        return response

    @coroutine
    def prepare_session(self, service_desc, request_params):
        """Wrap the shared session and let ``prepare_session`` hooks adjust it."""
        session = SessionWrapper(self.session)
        yield from self._execute_plugin_hooks('prepare_session', service_desc=service_desc, session=session,
                                              request_params=request_params)
        return session

    @coroutine
    def generate_path(self, service_desc, session, request_params):
        """Build the request URL and pass it through ``prepare_path`` hooks.

        The service's ``path`` is appended to the path of ``self.base_path``.
        The base URL's query string is kept and any fragment is dropped. Each
        hook receives the current URL as ``path`` and returns the new one.
        """
        from urllib.parse import urlsplit

        path = service_desc.get('path', '')
        base = urlsplit(self.base_path)
        joined_path = '/'.join([base.path.rstrip('/'), path.lstrip('/')])
        # urlunsplit takes exactly (scheme, netloc, path, query, fragment).
        path = urlunsplit((base.scheme, base.netloc, joined_path, base.query, ''))
        hooks = [getattr(plugin, 'prepare_path') for plugin in self._plugins
                 if hasattr(plugin, 'prepare_path')]
        self.logger.debug("Calling {0} plugin hooks...".format('prepare_path'))
        for func in hooks:
            path = yield from func(service_desc=service_desc, session=session,
                                   request_params=request_params, path=path)

        return path

    @coroutine
    def prepare_request_params(self, service_desc, session, request_params):
        """Run ``prepare_request_params`` hooks (they may mutate ``request_params``)."""
        yield from self._execute_plugin_hooks('prepare_request_params', service_desc=service_desc,
                                              session=session, request_params=request_params)

    @coroutine
    def prepare_payload(self, service_desc, session, request_params, payload):
        """Thread ``payload`` through every ``prepare_payload`` hook and return it."""
        hooks = [getattr(plugin, 'prepare_payload') for plugin in self._plugins
                 if hasattr(plugin, 'prepare_payload')]
        self.logger.debug("Calling {0} plugin hooks...".format('prepare_payload'))
        for func in hooks:
            payload = yield from func(service_desc=service_desc, session=session,
                                      request_params=request_params, payload=payload)
        return payload

    @coroutine
    def before_request(self, service_desc, session, request_params):
        """Run ``before_request`` hooks right before the request is sent."""
        yield from self._execute_plugin_hooks('before_request', service_desc=service_desc,
                                              session=session, request_params=request_params)

    @coroutine
    def on_exception(self, service_desc, session, request_params, ex):
        """Run ``on_exception`` hooks for an error raised while sending."""
        yield from self._execute_plugin_hooks('on_exception', service_desc=service_desc,
                                              session=session, request_params=request_params, ex=ex)

    @coroutine
    def on_response(self, service_desc, session, request_params, response):
        """Run ``on_response`` hooks once a response has been received."""
        yield from self._execute_plugin_hooks('on_response', service_desc=service_desc,
                                              session=session, request_params=request_params, response=response)

    @coroutine
    def on_parse_exception(self, service_desc, session, request_params, response, ex):
        """Run ``on_parse_exception`` hooks for an error raised while parsing."""
        yield from self._execute_plugin_hooks('on_parse_exception', service_desc=service_desc,
                                              session=session, request_params=request_params, response=response, ex=ex)

    @coroutine
    def on_parsed_response(self, service_desc, session, request_params, response):
        """Run ``on_parsed_response`` hooks after ``response.data`` is set."""
        yield from self._execute_plugin_hooks('on_parsed_response', service_desc=service_desc, session=session,
                                              request_params=request_params, response=response)

    @coroutine
    def _execute_plugin_hooks(self, hook, *args, **kwargs):
        """Call, in registration order, every plugin method named ``hook``."""
        hooks = [getattr(plugin, hook) for plugin in self._plugins if hasattr(plugin, hook)]
        self.logger.debug("Calling {0} plugin hooks...".format(hook))
        for func in hooks:
            yield from func(*args, **kwargs)

    def add_plugins(self, plugins):
        """Register ``plugins`` and call their ``assign_service_client`` hooks.

        The hook is invoked on every registered plugin that defines it,
        including plugins registered by earlier calls.
        """
        self._plugins.extend(plugins)

        hook = 'assign_service_client'
        hooks = [getattr(plugin, hook) for plugin in self._plugins if hasattr(plugin, hook)]
        self.logger.debug("Calling {0} plugin hooks...".format(hook))
        for func in hooks:
            func(service_client=self)

    def __getattr__(self, item):
        """Expose each service as a coroutine method: ``client.name(...)``.

        Only reached for attributes not found by normal lookup. Dunder names
        are refused so protocol probes such as ``copy.deepcopy``'s
        ``getattr(obj, '__deepcopy__', None)`` get the ``AttributeError`` they
        expect instead of a service-call wrapper.
        """
        if item.startswith('__') and item.endswith('__'):
            raise AttributeError(item)

        @coroutine
        def wrap(*args, **kwargs):

            return self.call(item, *args, **kwargs)

        return wrap

    def __del__(self):  # pragma: no cover
        # Look in __dict__ directly: if __init__ failed before ``session`` was
        # assigned, ``self.session`` would go through __getattr__ and yield a
        # service wrapper instead of a session.
        session = self.__dict__.get('session')
        if session is not None:
            session.close()
 def go():
     client = ClientSession(loop=self.loop)
     client.close()
     # Issuing a request on an already-closed session must be rejected.
     with self.assertRaises(RuntimeError):
         yield from client.request('get', '/')
 def test_borrow_connector_loop(self):
     """Without an explicit loop the session uses the connector's loop."""
     connector = self.make_open_connector()
     client = ClientSession(connector=connector)
     self.assertIs(client._loop, self.loop)
     client.close()
Beispiel #24
0
def function546(function647, function1293, arg1193):
    """Check that a session built with ``loop=None`` borrows the connector's loop.

    ``function647`` is the connector handed to the session. ``function1293``
    and ``arg1193`` are not used here (NOTE(review): presumably fixture
    parameters - confirm against the caller).
    """
    function1750 = ClientSession(connector=function647, loop=None)
    try:
        # Compare by identity: a bare truthiness check passes for any loop.
        assert function1750._loop is function647._loop
    finally:
        function1750.close()
Beispiel #25
0
class AIOResponsesTestCase(TestCase):
    """Basic ``aioresponses`` behaviour against a real ClientSession
    (variant using a fragment-free URL)."""
    use_default_loop = False

    @asyncio.coroutine
    def setUp(self):
        self.url = 'http://example.com/api'
        self.session = ClientSession()
        super().setUp()

    @asyncio.coroutine
    def tearDown(self):
        # close() may return None or an awaitable; drive the awaitable so the
        # session is really closed rather than leaving a pending coroutine.
        close_result = self.session.close()
        if close_result is not None:
            yield from close_result
        super().tearDown()

    @data(
        hdrs.METH_HEAD,
        hdrs.METH_GET,
        hdrs.METH_POST,
        hdrs.METH_PUT,
        hdrs.METH_PATCH,
        hdrs.METH_DELETE,
        hdrs.METH_OPTIONS,
    )
    @patch('aioresponses.aioresponses.add')
    @fail_on(unused_loop=False)
    def test_shortcut_method(self, http_method, mocked):
        """Each lower-case shortcut (get, post, ...) delegates to ``add``."""
        with aioresponses() as m:
            getattr(m, http_method.lower())(self.url)
            mocked.assert_called_once_with(self.url, method=http_method)

    @aioresponses()
    def test_returned_instance(self, m):
        m.get(self.url)
        response = self.loop.run_until_complete(self.session.get(self.url))
        self.assertIsInstance(response, ClientResponse)

    @aioresponses()
    @asyncio.coroutine
    def test_returned_instance_and_status_code(self, m):
        m.get(self.url, status=204)
        response = yield from self.session.get(self.url)
        self.assertIsInstance(response, ClientResponse)
        self.assertEqual(response.status, 204)

    @aioresponses()
    @asyncio.coroutine
    def test_returned_response_headers(self, m):
        m.get(self.url,
              content_type='text/html',
              headers={'Connection': 'keep-alive'})
        response = yield from self.session.get(self.url)

        self.assertEqual(response.headers['Connection'], 'keep-alive')
        self.assertEqual(response.headers[hdrs.CONTENT_TYPE], 'text/html')

    @aioresponses()
    def test_method_dont_match(self, m):
        m.get(self.url)
        with self.assertRaises(ClientConnectionError):
            self.loop.run_until_complete(self.session.post(self.url))

    @aioresponses()
    @asyncio.coroutine
    def test_streaming(self, m):
        m.get(self.url, body='Test')
        resp = yield from self.session.get(self.url)
        content = yield from resp.content.read()
        self.assertEqual(content, b'Test')

    @aioresponses()
    @asyncio.coroutine
    def test_streaming_up_to(self, m):
        m.get(self.url, body='Test')
        resp = yield from self.session.get(self.url)
        content = yield from resp.content.read(2)
        self.assertEqual(content, b'Te')
        content = yield from resp.content.read(2)
        self.assertEqual(content, b'st')

    @asyncio.coroutine
    def test_mocking_as_context_manager(self):
        with aioresponses() as aiomock:
            aiomock.add(self.url, payload={'foo': 'bar'})
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 200)
            payload = yield from resp.json()
            self.assertDictEqual(payload, {'foo': 'bar'})

    def test_mocking_as_decorator(self):
        @aioresponses()
        def foo(loop, m):
            m.add(self.url, payload={'foo': 'bar'})

            resp = loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 200)
            payload = loop.run_until_complete(resp.json())
            self.assertDictEqual(payload, {'foo': 'bar'})

        foo(self.loop)

    @asyncio.coroutine
    def test_passing_argument(self):
        @aioresponses(param='mocked')
        @asyncio.coroutine
        def foo(mocked):
            mocked.add(self.url, payload={'foo': 'bar'})
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 200)

        yield from foo()

    @fail_on(unused_loop=False)
    def test_mocking_as_decorator_wrong_mocked_arg_name(self):
        @aioresponses(param='foo')
        def foo(bar):
            # no matter what is here it should raise an error
            pass

        with self.assertRaises(TypeError) as cm:
            foo()
        exc = cm.exception
        self.assertIn("foo() got an unexpected keyword argument 'foo'",
                      str(exc))

    @asyncio.coroutine
    def test_unknown_request(self):
        with aioresponses() as aiomock:
            aiomock.add(self.url, payload={'foo': 'bar'})
            with self.assertRaises(ClientConnectionError):
                yield from self.session.get('http://example.com/foo')

    @asyncio.coroutine
    def test_raising_custom_error(self):
        with aioresponses() as aiomock:
            aiomock.get(self.url, exception=HttpProcessingError(message='foo'))
            with self.assertRaises(HttpProcessingError):
                yield from self.session.get(self.url)

    @asyncio.coroutine
    def test_multiple_requests(self):
        """Mocks for the same URL are consumed in order; calls are recorded."""
        with aioresponses() as m:
            m.get(self.url, status=200)
            m.get(self.url, status=201)
            m.get(self.url, status=202)
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 200)
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 201)
            resp = yield from self.session.get(self.url)
            self.assertEqual(resp.status, 202)

            key = ('GET', self.url)
            self.assertIn(key, m.requests)
            self.assertEqual(len(m.requests[key]), 3)
            self.assertEqual(m.requests[key][0].args, tuple())
            self.assertEqual(m.requests[key][0].kwargs,
                             {'allow_redirects': True})

    @asyncio.coroutine
    def test_address_as_instance_of_url_combined_with_pass_through(self):
        external_api = 'http://httpbin.org/status/201'

        @asyncio.coroutine
        def doit():
            api_resp = yield from self.session.get(self.url)
            # we have to hit actual url,
            # otherwise we do not test pass through option properly
            ext_rep = yield from self.session.get(URL(external_api))
            return api_resp, ext_rep

        with aioresponses(passthrough=[external_api]) as m:
            m.get(self.url, status=200)
            api, ext = yield from doit()

            self.assertEqual(api.status, 200)
            self.assertEqual(ext.status, 201)
 def test_connector(self):
     """The connector passed to the constructor is exposed unchanged."""
     tcp = TCPConnector(loop=self.loop)
     client = ClientSession(connector=tcp, loop=self.loop)
     self.assertIs(client.connector, tcp)
     client.close()
 def test_closed(self):
     """A session reports ``closed`` only after close() was called."""
     sess = ClientSession(loop=self.loop)
     self.assertFalse(sess.closed)
     sess.close()
     self.assertTrue(sess.closed)
class AIOResponseRedirectTest(AsyncTestCase):
    """Redirect handling for responses mocked with ``aioresponses``.

    Assertions use ``unittest`` methods instead of bare ``assert`` so they
    are not stripped under ``python -O`` and report both operands on failure.
    """

    async def setup(self):
        self.url = "http://10.1.1.1:8080/redirect"
        self.session = ClientSession()

    async def teardown(self):
        # close() may return an awaitable or None; await only the former.
        close_result = self.session.close()
        if close_result is not None:
            await close_result

    @aioresponses()
    async def test_redirect_followed(self, rsps):
        """A 307 with a Location header is followed to the mocked target."""
        rsps.get(
            self.url,
            status=307,
            headers={"Location": "https://httpbin.org"},
        )
        rsps.get("https://httpbin.org")
        response = await self.session.get(self.url, allow_redirects=True)
        self.assertEqual(response.status, 200)
        self.assertEqual(str(response.url), "https://httpbin.org")
        # The redirect hop is kept in the response history.
        self.assertEqual(len(response.history), 1)
        self.assertEqual(str(response.history[0].url), self.url)

    @aioresponses()
    async def test_post_redirect_followed(self, rsps):
        """A redirected POST is expected to end up as a GET to the target."""
        rsps.post(
            self.url,
            status=307,
            headers={"Location": "https://httpbin.org"},
        )
        rsps.get("https://httpbin.org")
        response = await self.session.post(self.url, allow_redirects=True)
        self.assertEqual(response.status, 200)
        self.assertEqual(str(response.url), "https://httpbin.org")
        self.assertEqual(response.method, "get")
        self.assertEqual(len(response.history), 1)
        self.assertEqual(str(response.history[0].url), self.url)

    @aioresponses()
    async def test_redirect_missing_mocked_match(self, rsps):
        """Following a redirect to an unmocked URL raises a connection error."""
        rsps.get(
            self.url,
            status=307,
            headers={"Location": "https://httpbin.org"},
        )
        with self.assertRaises(ClientConnectionError) as cm:
            await self.session.get(self.url, allow_redirects=True)
        self.assertEqual(
            str(cm.exception),
            'Connection refused: GET http://10.1.1.1:8080/redirect')

    @aioresponses()
    async def test_redirect_missing_location_header(self, rsps):
        """Without a Location header the redirect response is returned as-is."""
        rsps.get(self.url, status=307)
        response = await self.session.get(self.url, allow_redirects=True)
        self.assertEqual(str(response.url), self.url)

    @aioresponses()
    @skipIf(condition=AIOHTTP_VERSION < '3.1.0',
            reason='aiohttp<3.1.0 does not add request info on response')
    async def test_request_info(self, rsps):
        """The mocked response carries request info with empty headers."""
        rsps.get(self.url, status=200)

        response = await self.session.get(self.url)

        request_info = response.request_info
        self.assertEqual(str(request_info.url), self.url)
        self.assertEqual(request_info.headers, {})

    @aioresponses()
    @skipIf(condition=AIOHTTP_VERSION < '3.1.0',
            reason='aiohttp<3.1.0 does not add request info on response')
    async def test_request_info_with_original_request_headers(self, rsps):
        """Request info echoes the headers the caller sent."""
        headers = {"Authorization": "Bearer access-token"}
        rsps.get(self.url, status=200)

        response = await self.session.get(self.url, headers=headers)

        request_info = response.request_info
        self.assertEqual(str(request_info.url), self.url)
        self.assertEqual(request_info.headers, headers)
 def test_cookies_are_readonly(self):
     # Assigning to ``cookies`` is expected to raise AttributeError.
     sess = ClientSession(loop=self.loop)
     with self.assertRaises(AttributeError):
         sess.cookies = 123
     sess.close()
 def go():
     # A request on an already-closed session is expected to raise
     # RuntimeError.
     # NOTE(review): relies on ``self`` from an enclosing scope not visible here.
     closed_sess = ClientSession(loop=self.loop)
     closed_sess.close()
     with self.assertRaises(RuntimeError):
         yield from closed_sess.request('get', '/')
class AIOResponsesTestCase(AsyncTestCase):
    """Core behaviour of ``aioresponses`` mocking for ``ClientSession``.

    Assertions use ``unittest`` methods rather than bare ``assert`` so the
    checks are not stripped under ``python -O``.
    """

    async def setup(self):
        self.url = 'http://example.com/api?foo=bar#fragment'
        self.session = ClientSession()

    async def teardown(self):
        # close() may return an awaitable or None; await only the former.
        close_result = self.session.close()
        if close_result is not None:
            await close_result

    def run_async(self, coroutine: Union[Coroutine, Generator]):
        """Run *coroutine* to completion on the test loop; return its result."""
        return self.loop.run_until_complete(coroutine)

    async def request(self, url: str):
        """Issue a GET for *url* through the shared session."""
        return await self.session.get(url)

    @data(
        hdrs.METH_HEAD,
        hdrs.METH_GET,
        hdrs.METH_POST,
        hdrs.METH_PUT,
        hdrs.METH_PATCH,
        hdrs.METH_DELETE,
        hdrs.METH_OPTIONS,
    )
    @patch('aioresponses.aioresponses.add')
    @fail_on(unused_loop=False)
    def test_shortcut_method(self, http_method, mocked):
        """Each shortcut (``m.get`` ...) delegates to ``add`` with its method."""
        with aioresponses() as m:
            getattr(m, http_method.lower())(self.url)
            mocked.assert_called_once_with(self.url, method=http_method)

    @aioresponses()
    def test_returned_instance(self, m):
        m.get(self.url)
        response = self.run_async(self.session.get(self.url))
        self.assertIsInstance(response, ClientResponse)

    @aioresponses()
    async def test_returned_instance_and_status_code(self, m):
        m.get(self.url, status=204)
        response = await self.session.get(self.url)
        self.assertIsInstance(response, ClientResponse)
        self.assertEqual(response.status, 204)

    @aioresponses()
    async def test_returned_response_headers(self, m):
        m.get(self.url,
              content_type='text/html',
              headers={'Connection': 'keep-alive'})
        response = await self.session.get(self.url)

        self.assertEqual(response.headers['Connection'], 'keep-alive')
        self.assertEqual(response.headers[hdrs.CONTENT_TYPE], 'text/html')

    @aioresponses()
    async def test_returned_response_cookies(self, m):
        m.get(self.url, headers={'Set-Cookie': 'cookie=value'})
        response = await self.session.get(self.url)

        self.assertEqual(response.cookies['cookie'].value, 'value')

    @aioresponses()
    async def test_returned_response_raw_headers(self, m):
        m.get(self.url,
              content_type='text/html',
              headers={'Connection': 'keep-alive'})
        response = await self.session.get(self.url)
        # Content-Type comes first, followed by the explicit headers.
        expected_raw_headers = ((hdrs.CONTENT_TYPE.encode(), b'text/html'),
                                (b'Connection', b'keep-alive'))

        self.assertEqual(response.raw_headers, expected_raw_headers)

    @aioresponses()
    async def test_raise_for_status(self, m):
        m.get(self.url, status=400)
        with self.assertRaises(ClientResponseError) as cm:
            response = await self.session.get(self.url)
            response.raise_for_status()
        self.assertEqual(cm.exception.message, http.RESPONSES[400][0])

    @aioresponses()
    @skipIf(condition=AIOHTTP_VERSION < '3.4.0',
            reason='aiohttp<3.4.0 does not support raise_for_status '
            'arguments for requests')
    async def test_request_raise_for_status(self, m):
        m.get(self.url, status=400)
        with self.assertRaises(ClientResponseError) as cm:
            await self.session.get(self.url, raise_for_status=True)
        self.assertEqual(cm.exception.message, http.RESPONSES[400][0])

    @aioresponses()
    async def test_returned_instance_and_params_handling(self, m):
        """``params`` are merged into the query before matching the mock."""
        expected_url = 'http://example.com/api?foo=bar&x=42#fragment'
        m.get(expected_url)
        response = await self.session.get(self.url, params={'x': 42})
        self.assertIsInstance(response, ClientResponse)
        self.assertEqual(response.status, 200)

        expected_url = 'http://example.com/api?x=42#fragment'
        m.get(expected_url)
        response = await self.session.get('http://example.com/api#fragment',
                                          params={'x': 42})
        self.assertIsInstance(response, ClientResponse)
        self.assertEqual(response.status, 200)

    @aioresponses()
    def test_method_dont_match(self, m):
        m.get(self.url)
        with self.assertRaises(ClientConnectionError):
            self.run_async(self.session.post(self.url))

    @aioresponses()
    async def test_streaming(self, m):
        m.get(self.url, body='Test')
        resp = await self.session.get(self.url)
        content = await resp.content.read()
        self.assertEqual(content, b'Test')

    @aioresponses()
    async def test_streaming_up_to(self, m):
        m.get(self.url, body='Test')
        resp = await self.session.get(self.url)
        content = await resp.content.read(2)
        self.assertEqual(content, b'Te')
        content = await resp.content.read(2)
        self.assertEqual(content, b'st')

    async def test_mocking_as_context_manager(self):
        with aioresponses() as aiomock:
            aiomock.add(self.url, payload={'foo': 'bar'})
            resp = await self.session.get(self.url)
            self.assertEqual(resp.status, 200)
            payload = await resp.json()
            self.assertDictEqual(payload, {'foo': 'bar'})

    def test_mocking_as_decorator(self):
        @aioresponses()
        def foo(loop, m):
            m.add(self.url, payload={'foo': 'bar'})

            resp = loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 200)
            payload = loop.run_until_complete(resp.json())
            self.assertDictEqual(payload, {'foo': 'bar'})

        foo(self.loop)

    async def test_passing_argument(self):
        @aioresponses(param='mocked')
        async def foo(mocked):
            mocked.add(self.url, payload={'foo': 'bar'})
            resp = await self.session.get(self.url)
            self.assertEqual(resp.status, 200)

        await foo()

    @fail_on(unused_loop=False)
    def test_mocking_as_decorator_wrong_mocked_arg_name(self):
        @aioresponses(param='foo')
        def foo(bar):
            # no matter what is here it should raise an error
            pass

        with self.assertRaises(TypeError) as cm:
            foo()
        exc = cm.exception
        self.assertIn("foo() got an unexpected keyword argument 'foo'",
                      str(exc))

    async def test_unknown_request(self):
        with aioresponses() as aiomock:
            aiomock.add(self.url, payload={'foo': 'bar'})
            with self.assertRaises(ClientConnectionError):
                await self.session.get('http://example.com/foo')

    async def test_raising_exception(self):
        """Both exception classes and instances are raised from requests."""
        with aioresponses() as aiomock:
            url = 'http://example.com/Exception'
            aiomock.get(url, exception=Exception)
            with self.assertRaises(Exception):
                await self.session.get(url)

            url = 'http://example.com/Exception_object'
            aiomock.get(url, exception=Exception())
            with self.assertRaises(Exception):
                await self.session.get(url)

            url = 'http://example.com/BaseException'
            aiomock.get(url, exception=BaseException)
            with self.assertRaises(BaseException):
                await self.session.get(url)

            url = 'http://example.com/BaseException_object'
            aiomock.get(url, exception=BaseException())
            with self.assertRaises(BaseException):
                await self.session.get(url)

            url = 'http://example.com/CancelError'
            aiomock.get(url, exception=CancelledError)
            with self.assertRaises(CancelledError):
                await self.session.get(url)

            url = 'http://example.com/TimeoutError'
            aiomock.get(url, exception=TimeoutError)
            with self.assertRaises(TimeoutError):
                await self.session.get(url)

            url = 'http://example.com/HttpProcessingError'
            aiomock.get(url, exception=HttpProcessingError(message='foo'))
            with self.assertRaises(HttpProcessingError):
                await self.session.get(url)

    async def test_multiple_requests(self):
        """Ensure that requests are saved the way they would have been sent."""
        with aioresponses() as m:
            m.get(self.url, status=200)
            m.get(self.url, status=201)
            m.get(self.url, status=202)
            # The same list object is mutated between calls; each recorded
            # request must keep the value it had when it was sent.
            json_content_as_ref = [1]
            resp = await self.session.get(self.url, json=json_content_as_ref)
            self.assertEqual(resp.status, 200)
            json_content_as_ref[:] = [2]
            resp = await self.session.get(self.url, json=json_content_as_ref)
            self.assertEqual(resp.status, 201)
            json_content_as_ref[:] = [3]
            resp = await self.session.get(self.url, json=json_content_as_ref)
            self.assertEqual(resp.status, 202)

            key = ('GET', URL(self.url))
            self.assertIn(key, m.requests)
            self.assertEqual(len(m.requests[key]), 3)

            first_request = m.requests[key][0]
            self.assertEqual(first_request.args, tuple())
            self.assertEqual(first_request.kwargs, {
                'allow_redirects': True,
                "json": [1]
            })

            second_request = m.requests[key][1]
            self.assertEqual(second_request.args, tuple())
            self.assertEqual(second_request.kwargs, {
                'allow_redirects': True,
                "json": [2]
            })

            third_request = m.requests[key][2]
            self.assertEqual(third_request.args, tuple())
            self.assertEqual(third_request.kwargs, {
                'allow_redirects': True,
                "json": [3]
            })

    async def test_request_with_non_deepcopyable_parameter(self):
        def non_deep_copyable():
            """A generator does not allow deepcopy."""
            for line in ["header1,header2", "v1,v2", "v10,v20"]:
                yield line

        generator_value = non_deep_copyable()

        with aioresponses() as m:
            m.get(self.url, status=200)
            resp = await self.session.get(self.url, data=generator_value)
            self.assertEqual(resp.status, 200)

            key = ('GET', URL(self.url))
            self.assertIn(key, m.requests)
            self.assertEqual(len(m.requests[key]), 1)

            request = m.requests[key][0]
            self.assertEqual(request.args, tuple())
            self.assertEqual(request.kwargs, {
                'allow_redirects': True,
                "data": generator_value
            })

    async def test_request_retrieval_in_case_no_response(self):
        """Unmatched requests are still recorded before failing."""
        with aioresponses() as m:
            with self.assertRaises(ClientConnectionError):
                await self.session.get(self.url)

            key = ('GET', URL(self.url))
            self.assertIn(key, m.requests)
            self.assertEqual(len(m.requests[key]), 1)
            self.assertEqual(m.requests[key][0].args, tuple())
            self.assertEqual(m.requests[key][0].kwargs,
                             {'allow_redirects': True})

    async def test_request_failure_in_case_session_is_closed(self):
        async def do_request(session):
            return (await session.get(self.url))

        with aioresponses():
            # The coroutine is created before close() but only runs after.
            coro = do_request(self.session)
            await self.session.close()

            with self.assertRaises(RuntimeError) as exception_info:
                await coro
            self.assertEqual(str(exception_info.exception),
                             "Session is closed")

    async def test_address_as_instance_of_url_combined_with_pass_through(self):
        external_api = 'http://httpbin.org/status/201'

        async def doit():
            api_resp = await self.session.get(self.url)
            # we have to hit actual url,
            # otherwise we do not test pass through option properly
            ext_rep = await self.session.get(URL(external_api))
            return api_resp, ext_rep

        with aioresponses(passthrough=[external_api]) as m:
            m.get(self.url, status=200)
            api, ext = await doit()

            self.assertEqual(api.status, 200)
            self.assertEqual(ext.status, 201)

    async def test_pass_through_with_origin_params(self):
        external_api = 'http://httpbin.org/get'

        async def doit(params):
            # we have to hit actual url,
            # otherwise we do not test pass through option properly
            ext_rep = await self.session.get(URL(external_api), params=params)
            return ext_rep

        with aioresponses(passthrough=[external_api]) as m:
            params = {'foo': 'bar'}
            ext = await doit(params=params)
            self.assertEqual(ext.status, 200)
            self.assertEqual(str(ext.url), 'http://httpbin.org/get?foo=bar')

    @aioresponses()
    async def test_custom_response_class(self, m):
        class CustomClientResponse(ClientResponse):
            pass

        m.get(self.url, body='Test', response_class=CustomClientResponse)
        resp = await self.session.get(self.url)
        self.assertTrue(isinstance(resp, CustomClientResponse))

    @aioresponses()
    def test_exceptions_in_the_middle_of_responses(self, mocked):
        """Queued exceptions and responses are consumed strictly in order."""
        mocked.get(self.url, payload={}, status=204)
        mocked.get(
            self.url,
            exception=ValueError('oops'),
        )
        mocked.get(self.url, payload={}, status=204)
        mocked.get(
            self.url,
            exception=ValueError('oops'),
        )
        mocked.get(self.url, payload={}, status=200)

        async def doit():
            return (await self.session.get(self.url))

        self.assertEqual(self.run_async(doit()).status, 204)
        with self.assertRaises(ValueError):
            self.run_async(doit())
        self.assertEqual(self.run_async(doit()).status, 204)
        with self.assertRaises(ValueError):
            self.run_async(doit())
        self.assertEqual(self.run_async(doit()).status, 200)

    @aioresponses()
    async def test_request_should_match_regexp(self, mocked):
        mocked.get(re.compile(r'^http://example\.com/api\?foo=.*$'),
                   payload={},
                   status=200)

        response = await self.request(self.url)
        self.assertEqual(response.status, 200)

    @aioresponses()
    async def test_request_does_not_match_regexp(self, mocked):
        mocked.get(re.compile(r'^http://exampleexample\.com/api\?foo=.*$'),
                   payload={},
                   status=200)
        with self.assertRaises(ClientConnectionError):
            await self.request(self.url)

    @aioresponses()
    def test_timeout(self, mocked):
        mocked.get(self.url, timeout=True)

        with self.assertRaises(asyncio.TimeoutError):
            self.run_async(self.request(self.url))

    @aioresponses()
    def test_callback(self, m):
        """A sync callback receives the url and kwargs and builds the body."""
        body = b'New body'

        def callback(url, **kwargs):
            self.assertEqual(str(url), self.url)
            self.assertEqual(kwargs, {'allow_redirects': True})
            return CallbackResult(body=body)

        m.get(self.url, callback=callback)
        response = self.run_async(self.request(self.url))
        data = self.run_async(response.read())
        self.assertEqual(data, body)

    @aioresponses()
    def test_callback_coroutine(self, m):
        """The request stays pending until an async callback completes."""
        body = b'New body'
        event = asyncio.Event()

        async def callback(url, **kwargs):
            await event.wait()
            self.assertEqual(str(url), self.url)
            self.assertEqual(kwargs, {'allow_redirects': True})
            return CallbackResult(body=body)

        m.get(self.url, callback=callback)
        future = asyncio.ensure_future(self.request(self.url))
        # timeout=0 lets the loop spin once without blocking on the event.
        self.run_async(asyncio.wait([future], timeout=0))
        self.assertFalse(future.done())
        event.set()
        self.run_async(asyncio.wait([future], timeout=0))
        self.assertTrue(future.done())
        response = future.result()
        data = self.run_async(response.read())
        self.assertEqual(data, body)

    @aioresponses()
    async def test_exception_requests_are_tracked(self, mocked):
        kwargs = {"json": [42], "allow_redirects": True}
        mocked.get(self.url, exception=ValueError('oops'))

        with self.assertRaises(ValueError):
            await self.session.get(self.url, **kwargs)

        key = ('GET', URL(self.url))
        mocked_requests = mocked.requests[key]
        self.assertEqual(len(mocked_requests), 1)

        request = mocked_requests[0]
        self.assertEqual(request.args, ())
        self.assertEqual(request.kwargs, kwargs)

    async def test_possible_race_condition(self):
        """Concurrent requests with slow callbacks must all complete."""
        async def random_sleep_cb(url, **kwargs):
            await asyncio.sleep(uniform(0.1, 1))
            return CallbackResult(body='test')

        with aioresponses() as mocked:
            for i in range(20):
                mocked.get('http://example.org/id-{}'.format(i),
                           callback=random_sleep_cb)

            tasks = [
                self.session.get('http://example.org/id-{}'.format(i))
                for i in range(20)
            ]
            await asyncio.gather(*tasks)
# Example #32 (score: 0)
def test_create_session_outside_of_coroutine(loop):
    # Constructing a session outside a coroutine is expected to emit a
    # ResourceWarning.
    with pytest.warns(ResourceWarning):
        session = ClientSession(loop=loop)
    session.close()
# Example #33 (score: 0)
class AIOResponsesTestCase(TestCase):
    """Synchronous (``unittest.TestCase``) variant of the aioresponses tests.

    Every request is driven explicitly with ``loop.run_until_complete``.
    """

    def setUp(self):
        self.url = 'http://example.com/api'
        self.loop = asyncio.get_event_loop()
        self.session = ClientSession()

    def tearDown(self):
        # close() may return an awaitable that does nothing until it is run.
        # Drive it on the test loop so the session is really closed instead of
        # leaking an un-awaited coroutine.
        close_result = self.session.close()
        if close_result is not None:
            self.loop.run_until_complete(close_result)

    @data(
        hdrs.METH_GET,
        hdrs.METH_POST,
        hdrs.METH_PUT,
        hdrs.METH_PATCH,
        hdrs.METH_DELETE,
        hdrs.METH_OPTIONS,
    )
    @patch('aioresponses.aioresponses.add')
    def test_shortcut_method(self, http_method, mocked):
        """Each shortcut (``m.get`` ...) delegates to ``add`` with its method."""
        with aioresponses() as m:
            getattr(m, http_method.lower())(self.url)
            mocked.assert_called_once_with(self.url, method=http_method)

    @aioresponses()
    def test_returned_instance(self, m):
        m.get(self.url)
        response = self.loop.run_until_complete(self.session.get(self.url))
        self.assertIsInstance(response, ClientResponse)

    @aioresponses()
    def test_returned_response_headers(self, m):
        m.get(self.url,
              content_type='text/html',
              headers={'Connection': 'keep-alive'})
        response = self.loop.run_until_complete(self.session.get(self.url))

        self.assertEqual(response.headers['Connection'], 'keep-alive')
        self.assertEqual(response.headers[hdrs.CONTENT_TYPE], 'text/html')

    @aioresponses()
    def test_method_dont_match(self, m):
        m.get(self.url)
        with self.assertRaises(ClientConnectionError):
            self.loop.run_until_complete(self.session.post(self.url))

    def test_mocking_as_context_manager(self):
        with aioresponses() as aiomock:
            aiomock.add(self.url, payload={'foo': 'bar'})
            resp = self.loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 200)
            payload = self.loop.run_until_complete(resp.json())
            self.assertDictEqual(payload, {'foo': 'bar'})

    def test_mocking_as_decorator(self):
        @aioresponses()
        def foo(m):
            m.add(self.url, payload={'foo': 'bar'})

            resp = self.loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 200)
            payload = self.loop.run_until_complete(resp.json())
            self.assertDictEqual(payload, {'foo': 'bar'})

        foo()

    def test_passing_argument(self):
        @aioresponses(param='mocked')
        def foo(mocked):
            mocked.add(self.url, payload={'foo': 'bar'})
            resp = self.loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 200)

        foo()

    def test_mocking_as_decorator_wrong_mocked_arg_name(self):
        @aioresponses(param='foo')
        def foo(bar):
            # no matter what is here it should raise an error
            pass

        with self.assertRaises(TypeError) as cm:
            foo()
        exc = cm.exception
        self.assertIn("foo() got an unexpected keyword argument 'foo'",
                      str(exc))

    def test_unknown_request(self):
        with aioresponses() as aiomock:
            aiomock.add(self.url, payload={'foo': 'bar'})
            with self.assertRaises(ClientConnectionError):
                self.loop.run_until_complete(
                    self.session.get('http://example.com/foo'))

    def test_multiple_requests(self):
        """Mocks queued for the same URL are consumed in registration order."""
        with aioresponses() as m:
            m.get(self.url, status=200)
            m.get(self.url, status=201)
            m.get(self.url, status=202)
            resp = self.loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 200)
            resp = self.loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 201)
            resp = self.loop.run_until_complete(self.session.get(self.url))
            self.assertEqual(resp.status, 202)
 def test_connector(self):
     # A session built around a connector should hand that same object back.
     tcp_conn = TCPConnector(loop=self.loop)
     client = ClientSession(connector=tcp_conn, loop=self.loop)
     self.assertIs(client.connector, tcp_conn)
     client.close()
# Example #35 (score: 0)
 def test_close(self):
     # Swap in a mock connector; close() should forward to it exactly once.
     client = ClientSession(loop=self.loop)
     client._connector = mock.Mock(BaseConnector)
     client.close()
     connector_close = client._connector.close
     connector_close.assert_called_once_with()
 def test_cookies_are_readonly(self):
     # Rebinding ``cookies`` on a session must raise AttributeError.
     client = ClientSession(loop=self.loop)
     with self.assertRaises(AttributeError):
         client.cookies = 123
     client.close()
def test_borrow_connector_loop(connector, create_session, loop):
    """A session created with ``loop=None`` should reuse the fixture loop.

    Assumes the ``connector`` fixture is bound to ``loop`` -- TODO confirm.
    """
    session = ClientSession(connector=connector, loop=None)
    try:
        # Identity check: ``assert a, b`` would only test that ``a`` is truthy.
        assert session._loop is loop
    finally:
        # close() may return an awaitable; run it so the session really closes.
        close_result = session.close()
        if close_result is not None:
            loop.run_until_complete(close_result)
 def test_borrow_connector_loop(self):
     # With no explicit loop the session is expected to use the test loop
     # that the open connector was created on.
     client = ClientSession(connector=self.make_open_connector())
     self.assertIs(client._loop, self.loop)
     client.close()
# Example #39 (score: 0)
def test_borrow_connector_loop(connector, create_session, loop):
    """A session created with ``loop=None`` should reuse the fixture loop.

    Assumes the ``connector`` fixture is bound to ``loop`` -- TODO confirm.
    """
    session = ClientSession(connector=connector, loop=None)
    try:
        # Identity check: ``assert a, b`` would only test that ``a`` is truthy.
        assert session._loop is loop
    finally:
        # close() may return an awaitable; run it so the session really closes.
        close_result = session.close()
        if close_result is not None:
            loop.run_until_complete(close_result)
class LineStickerSpider:
    """Async scraper for LINE Store sticker packs.

    Searches the store, parses sticker product pages for sticker/audio file
    urls and downloads those files, all through one shared aiohttp session.
    Call ``await spider.close()`` when done; ``__del__`` is only a
    best-effort fallback.
    """

    # default headers for requests
    headers = {
        'User-Agent':
        'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/80.0.3987.132 Safari/537.36',
    }
    # url templates (bound ``str.format`` methods, filled in by keyword)
    urlf_page = 'https://store.line.me/stickershop/product/{id}/en'.format
    urlf_search = 'https://store.line.me/api/search/sticker' \
                  '?query={query}&offset=0&limit=36&type=ALL&includeFacets=false'.format
    # sticker resolve order: the first ``<key>Url`` present in a preview item
    # is the one collected
    sticker_order = ['popup', 'animation', 'static']
    re_sticker_id = re.compile('sticker/([^/]+)')  # sticker id in url
    re_sticker_page_id = re.compile(
        'stickershop/product/([^/]+)')  # sticker shop id in url

    def __init__(self, connection_limit=10):
        """Create the shared HTTP session.

        :param connection_limit: maximum simultaneous connections allowed by
            the underlying ``TCPConnector``.
        """
        connector = TCPConnector(limit=connection_limit)
        self.con = ClientSession(connector=connector, headers=self.headers)

    async def close(self) -> None:
        """Close the underlying HTTP session (safe to call more than once)."""
        con = getattr(self, 'con', None)
        if con is not None and not con.closed:
            await con.close()

    def __del__(self):
        # Best-effort fallback only: __del__ cannot await, so schedule the
        # close on the running loop if there is one. With no running loop
        # (e.g. at interpreter shutdown) scheduling would fail or leave an
        # un-awaited coroutine behind, so do nothing.
        con = getattr(self, 'con', None)  # absent if __init__ raised
        if con is None or con.closed:
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return
        loop.create_task(con.close())

    def _search_url(self, query: str) -> str:
        """Return the search API url for *query*.

        The query is fully percent-encoded (including ``&``, ``=``, ``#`` and
        ``/``) so user text cannot inject extra parameters or cut the url off.
        """
        from urllib.parse import quote
        return self.urlf_search(query=quote(query, safe=''))

    async def crawl_search(self, query) -> Dict[str, str]:
        """
        Crawls search and returns title-to-id dictionary, e.g.:
        {"cool sticker #1": '1234233', ...}

        Later items with a duplicate title overwrite earlier ones.
        """
        async with self.con.get(self._search_url(query)) as response:
            data = await response.json()
        return {d['title']: d['id'] for d in data['items']}

    async def crawl_pages(self,
                          page_ids: List[str]) -> Tuple[List[str], List[str]]:
        """Crawl sticker pages concurrently; return (sticker urls, audio urls)."""
        all_results = await gather(*[self.crawl_page(id_) for id_ in page_ids])
        stickers = []
        audio = []
        for page_stickers, page_audio in all_results:
            stickers.extend(page_stickers)
            audio.extend(page_audio)
        return stickers, audio

    async def crawl_page(self, page_id):
        """Crawl a single sticker page and parse out sticker file urls."""
        async with self.con.get(self.urlf_page(id=page_id)) as response:
            body = await response.text()
        return self.parse_page(body)

    async def dl_files(self, urls: List[str], output: Path) -> List[None]:
        """Download files from *urls* concurrently into directory *output*."""
        return await gather(*[self.dl_file(url, output) for url in urls])

    async def dl_file(self, url: str, output: Path) -> None:
        """Download the file at *url* and save it as ``output/<id><ext>``.

        The body is fully read before the file is opened, so a failed
        download does not leave an empty file behind.
        NOTE(review): the HTTP status is not checked; an error page would be
        written to disk as-is.
        """
        id_ = self.re_sticker_id.findall(url)[0]
        ext = os.path.splitext(urlparse(url).path)[1]
        async with self.con.get(url) as response:
            content = await response.read()
        with open(output / f'{id_}{ext}', 'wb') as f:
            f.write(content)

    def parse_page(self, body: str) -> Tuple[List[str], List[str]]:
        """Parse a product page into (sticker urls, audio urls).

        Each ``.FnStickerPreviewItem`` carries a JSON ``data-preview``
        attribute; one sticker url is taken per item following
        ``sticker_order``, plus its ``soundUrl`` when present.
        """
        sel = Selector(text=body)
        data = sel.css('.FnStickerPreviewItem::attr(data-preview)').extract()
        data = [json.loads(d) for d in data]
        stickers = []
        audio = []
        for d in data:
            sound = d.get('soundUrl')
            if sound:
                audio.append(sound)
            for key in self.sticker_order:
                value = d.get(f'{key}Url')
                if value:
                    stickers.append(value)
                    break
        return stickers, audio