Example #1
0
    def test_https_connect(self, ClientRequestMock):
        # CONNECT tunnelling through an HTTP proxy for an https:// target.
        # ClientRequestMock is presumably injected by a mock.patch decorator
        # outside this view; its return value is the request sent to the proxy.
        # Expected outcome (see assertions): the proxy request becomes
        # "CONNECT www.python.org:443", the target request's path is "/",
        # transport reading is paused once, and the transport is asked for its
        # underlying socket via get_extra_info('socket', default=None).
        loop_mock = unittest.mock.Mock()
        proxy_req = ClientRequest('GET', 'http://proxy.example.com',
                                  loop=loop_mock)
        ClientRequestMock.return_value = proxy_req

        proxy_resp = ClientResponse('get', 'http://proxy.example.com')
        proxy_resp._loop = loop_mock
        proxy_req.send = send_mock = unittest.mock.Mock()
        send_mock.return_value = proxy_resp
        proxy_resp.start = start_mock = unittest.mock.Mock()
        # The proxy answers the CONNECT with 200.
        self._fake_coroutine(start_mock, unittest.mock.Mock(status=200))

        connector = aiohttp.ProxyConnector(
            'http://proxy.example.com', loop=loop_mock)

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        self._fake_coroutine(loop_mock.create_connection, (tr, proto))

        req = ClientRequest('GET', 'https://www.python.org', loop=self.loop)
        try:
            self.loop.run_until_complete(connector._create_connection(req))

            self.assertEqual(req.path, '/')
            self.assertEqual(proxy_req.method, 'CONNECT')
            self.assertEqual(proxy_req.path, 'www.python.org:443')
            tr.pause_reading.assert_called_once_with()
            tr.get_extra_info.assert_called_with('socket', default=None)
        finally:
            # Close even when an assertion fails so nothing leaks into
            # later tests.
            proxy_req.close()
            proxy_resp.close()
            req.close()
Example #2
0
    def test_https_connect(self, ClientRequestMock):
        # CONNECT tunnelling through an HTTP proxy for an https:// target.
        # ClientRequestMock is presumably injected by a mock.patch decorator
        # outside this view; its return value is the request sent to the proxy.
        # Expected outcome (see assertions): the proxy request becomes
        # "CONNECT www.python.org:443", the target request's path is "/",
        # transport reading is paused once, and the transport is asked for its
        # underlying socket via get_extra_info('socket', default=None).
        loop_mock = unittest.mock.Mock()
        proxy_req = ClientRequest('GET', 'http://proxy.example.com',
                                  loop=loop_mock)
        ClientRequestMock.return_value = proxy_req

        proxy_resp = ClientResponse('get', 'http://proxy.example.com')
        proxy_resp._loop = loop_mock
        proxy_req.send = send_mock = unittest.mock.Mock()
        send_mock.return_value = proxy_resp
        proxy_resp.start = start_mock = unittest.mock.Mock()
        # The proxy answers the CONNECT with 200.
        self._fake_coroutine(start_mock, unittest.mock.Mock(status=200))

        connector = aiohttp.ProxyConnector(
            'http://proxy.example.com', loop=loop_mock)

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        self._fake_coroutine(loop_mock.create_connection, (tr, proto))

        req = ClientRequest('GET', 'https://www.python.org', loop=self.loop)
        try:
            self.loop.run_until_complete(connector._create_connection(req))

            self.assertEqual(req.path, '/')
            self.assertEqual(proxy_req.method, 'CONNECT')
            self.assertEqual(proxy_req.path, 'www.python.org:443')
            tr.pause_reading.assert_called_once_with()
            tr.get_extra_info.assert_called_with('socket', default=None)
        finally:
            # Close even when an assertion fails so nothing leaks into
            # later tests.
            proxy_req.close()
            proxy_resp.close()
            req.close()
Example #3
0
    def test_https_connect_http_proxy_error(self, ClientRequestMock):
        # When the proxy answers the CONNECT with "400 bad request",
        # _create_connection must raise aiohttp.HttpProxyError whose message
        # carries the status and reason.
        loop_mock = unittest.mock.Mock()
        proxy_req = ClientRequest('GET', 'http://proxy.example.com',
                                  loop=loop_mock)
        ClientRequestMock.return_value = proxy_req

        proxy_resp = ClientResponse('get', 'http://proxy.example.com')
        proxy_resp._loop = loop_mock
        proxy_req.send = send_mock = unittest.mock.Mock()
        send_mock.return_value = proxy_resp
        proxy_resp.start = start_mock = unittest.mock.Mock()
        self._fake_coroutine(
            start_mock, unittest.mock.Mock(status=400, reason='bad request'))

        connector = aiohttp.ProxyConnector(
            'http://proxy.example.com', loop=loop_mock)

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        tr.get_extra_info.return_value = None
        self._fake_coroutine(loop_mock.create_connection, (tr, proto))

        req = ClientRequest('GET', 'https://www.python.org', loop=self.loop)
        try:
            with self.assertRaisesRegex(
                    aiohttp.HttpProxyError, "400, message='bad request'"):
                self.loop.run_until_complete(
                    connector._create_connection(req))
        finally:
            # Close even if the wrong exception (or none) is raised.
            proxy_req.close()
            proxy_resp.close()
            req.close()
Example #4
0
    def setUp(self):
        # Fresh private loop per test; the global default loop is cleared so
        # nothing silently shares a loop between tests.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)
        self.loop = loop

        # Fixtures shared by the tests of this case.
        self.connection = unittest.mock.Mock()
        self.stream = aiohttp.StreamParser(loop=loop)
        self.response = ClientResponse('get', 'http://python.org')
Example #5
0
 def test_wait_for_100(self):
     # Passing a continue100 marker makes the response report that it is
     # waiting for "100 Continue"; omitting it does not.
     resp = ClientResponse('get', 'http://python.org',
                           continue100=object())
     self.assertTrue(resp.waiting_for_continue())

     resp = ClientResponse('get', 'http://python.org')
     self.assertFalse(resp.waiting_for_continue())
Example #6
0
    def test_https_connect_runtime_error(self, ClientRequestMock):
        # The proxy accepts the CONNECT (200), but the transport cannot supply
        # its socket (get_extra_info returns None); _create_connection must
        # then raise RuntimeError.
        loop_mock = unittest.mock.Mock()
        proxy_req = ClientRequest('GET', 'http://proxy.example.com',
                                  loop=loop_mock)
        ClientRequestMock.return_value = proxy_req

        proxy_resp = ClientResponse('get', 'http://proxy.example.com')
        proxy_resp._loop = loop_mock
        proxy_req.send = send_mock = unittest.mock.Mock()
        send_mock.return_value = proxy_resp
        proxy_resp.start = start_mock = unittest.mock.Mock()
        self._fake_coroutine(start_mock, unittest.mock.Mock(status=200))

        connector = aiohttp.ProxyConnector(
            'http://proxy.example.com', loop=loop_mock)

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        tr.get_extra_info.return_value = None
        self._fake_coroutine(loop_mock.create_connection, (tr, proto))

        req = ClientRequest('GET', 'https://www.python.org', loop=self.loop)
        try:
            with self.assertRaisesRegex(
                    RuntimeError, "Transport does not expose socket instance"):
                self.loop.run_until_complete(
                    connector._create_connection(req))
        finally:
            # Close even if the wrong exception (or none) is raised.
            proxy_req.close()
            proxy_resp.close()
            req.close()
Example #7
0
    def test_https_connect_http_proxy_error(self, ClientRequestMock):
        # When the proxy answers the CONNECT with "400 bad request",
        # _create_connection must raise aiohttp.HttpProxyError whose message
        # carries the status and reason.
        loop_mock = unittest.mock.Mock()
        proxy_req = ClientRequest('GET', 'http://proxy.example.com',
                                  loop=loop_mock)
        ClientRequestMock.return_value = proxy_req

        proxy_resp = ClientResponse('get', 'http://proxy.example.com')
        proxy_resp._loop = loop_mock
        proxy_req.send = send_mock = unittest.mock.Mock()
        send_mock.return_value = proxy_resp
        proxy_resp.start = start_mock = unittest.mock.Mock()
        self._fake_coroutine(
            start_mock, unittest.mock.Mock(status=400, reason='bad request'))

        connector = aiohttp.ProxyConnector(
            'http://proxy.example.com', loop=loop_mock)

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        tr.get_extra_info.return_value = None
        self._fake_coroutine(loop_mock.create_connection, (tr, proto))

        req = ClientRequest('GET', 'https://www.python.org', loop=self.loop)
        try:
            with self.assertRaisesRegex(
                    aiohttp.HttpProxyError, "400, message='bad request'"):
                self.loop.run_until_complete(
                    connector._create_connection(req))
        finally:
            # Close even if the wrong exception (or none) is raised.
            proxy_req.close()
            proxy_resp.close()
            req.close()
Example #8
0
    def test_https_connect_runtime_error(self, ClientRequestMock):
        # The proxy accepts the CONNECT (200), but the transport cannot supply
        # its socket (get_extra_info returns None); _create_connection must
        # then raise RuntimeError.
        loop_mock = unittest.mock.Mock()
        proxy_req = ClientRequest('GET', 'http://proxy.example.com',
                                  loop=loop_mock)
        ClientRequestMock.return_value = proxy_req

        proxy_resp = ClientResponse('get', 'http://proxy.example.com')
        proxy_resp._loop = loop_mock
        proxy_req.send = send_mock = unittest.mock.Mock()
        send_mock.return_value = proxy_resp
        proxy_resp.start = start_mock = unittest.mock.Mock()
        self._fake_coroutine(start_mock, unittest.mock.Mock(status=200))

        connector = aiohttp.ProxyConnector(
            'http://proxy.example.com', loop=loop_mock)

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        tr.get_extra_info.return_value = None
        self._fake_coroutine(loop_mock.create_connection, (tr, proto))

        req = ClientRequest('GET', 'https://www.python.org', loop=self.loop)
        try:
            with self.assertRaisesRegex(
                    RuntimeError, "Transport does not expose socket instance"):
                self.loop.run_until_complete(
                    connector._create_connection(req))
        finally:
            # Close even if the wrong exception (or none) is raised.
            proxy_req.close()
            proxy_resp.close()
            req.close()
Example #9
0
    def test_del(self):
        # Dropping the last reference to a response that owns a connection
        # must close that connection.
        resp = ClientResponse('get', 'http://python.org')
        conn_mock = unittest.mock.Mock()
        resp._setup_connection(conn_mock)

        # Presumably relies on CPython refcounting to finalize immediately.
        del resp

        conn_mock.close.assert_called_with()
Example #10
0
    def setUp(self):
        # Fresh private loop per test; the global default loop is cleared so
        # nothing silently shares a loop between tests.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)
        self.loop = loop

        self.transport = unittest.mock.Mock()
        self.stream = aiohttp.StreamParser()

        # Shared response fixture bound to this test's loop.
        self.response = resp = ClientResponse('get', 'http://base-conn.org')
        resp._post_init(loop)
Example #11
0
    def test_del(self):
        # Dropping the last reference to a response that still owns a
        # connection must emit ResourceWarning and close the connection.
        resp = ClientResponse('get', 'http://python.org')
        conn_mock = unittest.mock.Mock()
        resp._setup_connection(conn_mock)

        with self.assertWarns(ResourceWarning):
            # Presumably relies on CPython refcounting to finalize here.
            del resp

        conn_mock.close.assert_called_with()
Example #12
0
async def raise_for_error(response: ClientResponse):
    """Raise an exception if *response* does not indicate success.

    Statuses below 300 return silently. For anything else the body is read
    as JSON. If it is an object with ``code`` and ``error`` keys, a
    ``QiniuError`` built from those two values is raised. Otherwise the
    function falls back to ``response.raise_for_status()``. That covers a
    wrong content type, invalid JSON or undecodable text, a non-object JSON
    value, or missing keys.

    Raises:
        QiniuError: the reply carried a Qiniu-style JSON error body.
        aiohttp.ClientResponseError: from ``raise_for_status()`` on the
            fallback path. Per aiohttp's documentation that only happens for
            statuses >= 400, so a 3xx reply without such a body returns
            normally.
    """
    if response.status < 300:
        return
    try:
        value = await response.json()
        code, error = value['code'], value['error']
    except (KeyError, TypeError, ValueError, ContentTypeError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError. An
        # error reply that claims to be JSON but is not must still surface
        # as an HTTP error rather than as a parsing error.
        response.raise_for_status()
        return
    raise QiniuError(code=code, error=error)
Example #13
0
    def setUp(self):
        # Fresh private loop per test; the global default loop is cleared so
        # nothing silently shares a loop between tests.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)
        self.loop = loop

        # Fixtures shared by the tests of this case.
        self.connection = unittest.mock.Mock()
        self.stream = aiohttp.StreamParser(loop=loop)
        self.response = ClientResponse('get', 'http://python.org')
Example #14
0
        def cb():
            # Build a canned 200 response with a JSON content type whose body
            # is read from `filename`. It is returned inside an
            # already-resolved future so it can stand in for a coroutine result.
            def read():
                fut = asyncio.Future(loop=self._loop)
                # Decode the fixture explicitly as UTF-8. With the locale
                # default encoding, the bytes produced here would vary by
                # platform for any non-ASCII content.
                with open(filename, encoding="utf-8") as f:
                    fut.set_result(f.read().encode("utf-8"))
                return fut

            resp = ClientResponse("GET", uri)
            resp.headers = {"Content-Type": "application/json"}

            resp.status = 200
            resp.content = Mock()
            resp.content.read.side_effect = read
            resp.close = Mock()
            fut = asyncio.Future(loop=self._loop)
            fut.set_result(resp)
            return fut
Example #15
0
    def setUp(self):
        # Fresh private loop per test; the global default loop is cleared so
        # nothing silently shares a loop between tests.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)
        self.loop = loop

        self.transport = unittest.mock.Mock()
        self.stream = aiohttp.StreamParser()

        # Shared response fixture bound to this test's loop.
        self.response = resp = ClientResponse('get', 'http://base-conn.org')
        resp._post_init(loop)
Example #16
0
        def cb():
            # Build a canned response with status `status`, a JSON content
            # type and a body read from `filename`. It is returned inside an
            # already-resolved future.
            def read():
                fut = asyncio.Future(loop=self._loop)
                # Decode the fixture explicitly as UTF-8. With the locale
                # default encoding, the bytes produced here would vary by
                # platform for any non-ASCII content.
                with open(filename, encoding='utf-8') as f:
                    fut.set_result(f.read().encode('utf-8'))
                return fut

            resp = ClientResponse('GET',
                                  yarl.URL(uri),
                                  writer=Mock(),
                                  timer=TimerNoop(),
                                  continue100=None,
                                  request_info=Mock(),
                                  traces=[],
                                  loop=self._loop,
                                  session=Mock())
            resp._headers = {'Content-Type': 'application/json'}

            resp.status = status
            resp.reason = Mock()
            resp.content = Mock()
            resp.content.read.side_effect = read
            resp.close = Mock()
            fut = asyncio.Future(loop=self._loop)
            fut.set_result(resp)
            return fut
Example #17
0
        def cb():
            # Return an already-resolved future holding a bare ClientResponse
            # whose status is `status` (taken from the enclosing scope).
            fut = asyncio.Future(loop=self._loop)
            resp = ClientResponse('GET',
                                  yarl.URL('foo'),
                                  writer=Mock(),
                                  timer=TimerNoop(),
                                  continue100=None,
                                  request_info=Mock(),
                                  traces=[],
                                  loop=self._loop,
                                  session=Mock())
            resp.status = status

            # aiohttp 3.5.4 checks that `reason` is not None (see
            # aiohttp/client_reqrep.py:934), so always provide a phrase.
            # Non-standard codes (e.g. 499, 599) are missing from the stdlib
            # table, so use a generic phrase instead of failing with KeyError.
            resp.reason = http.client.responses.get(status, 'Unknown')

            fut.set_result(resp)
            return fut
Example #18
0
 def test_wait_for_100(self):
     # A response built with a continue100 marker waits for "100 Continue";
     # one built without it does not.
     resp = ClientResponse(
         'get', 'http://python.org', continue100=object())
     self.assertTrue(resp.waiting_for_continue())

     resp = ClientResponse(
         'get', 'http://python.org')
     self.assertFalse(resp.waiting_for_continue())
Example #19
0
        def cb():
            # Build a canned response with status `status`, a JSON content
            # type and a body read from `filename`. It is returned inside an
            # already-resolved future.
            def read():
                fut = asyncio.Future(loop=self._loop)
                # Decode the fixture explicitly as UTF-8. With the locale
                # default encoding, the bytes produced here would vary by
                # platform for any non-ASCII content.
                with open(filename, encoding='utf-8') as f:
                    fut.set_result(f.read().encode('utf-8'))
                return fut

            resp = ClientResponse('GET', yarl.URL(uri))
            resp.headers = {'Content-Type': 'application/json'}

            resp.status = status
            resp.content = Mock()
            resp.content.read.side_effect = read
            resp.close = Mock()
            fut = asyncio.Future(loop=self._loop)
            fut.set_result(resp)
            return fut
Example #20
0
class TestBaseConnector(unittest.TestCase):

    def setUp(self):
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)

        self.transport = unittest.mock.Mock()
        self.stream = aiohttp.StreamParser()
        self.response = ClientResponse('get', 'http://base-conn.org')
        self.response._post_init(self.loop)

    def tearDown(self):
        self.response.close()
        self.loop.close()
        gc.collect()

    def test_del(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        transp = unittest.mock.Mock()
        conn._conns['a'] = [(transp, 'proto', 123)]
        conns_impl = conn._conns

        exc_handler = unittest.mock.Mock()
        self.loop.set_exception_handler(exc_handler)

        with self.assertWarns(ResourceWarning):
            del conn
            gc.collect()

        self.assertFalse(conns_impl)
        transp.close.assert_called_with()
        msg = {'connector': unittest.mock.ANY,  # conn was deleted
               'connections': unittest.mock.ANY,
               'message': 'Unclosed connector'}
        if self.loop.get_debug():
            msg['source_traceback'] = unittest.mock.ANY
        exc_handler.assert_called_with(self.loop, msg)

    def test_del_with_scheduled_cleanup(self):
        conn = aiohttp.BaseConnector(loop=self.loop, keepalive_timeout=0.01)
        transp = unittest.mock.Mock()
        conn._conns['a'] = [(transp, 'proto', 123)]

        conns_impl = conn._conns
        conn._start_cleanup_task()
        exc_handler = unittest.mock.Mock()
        self.loop.set_exception_handler(exc_handler)

        with self.assertWarns(ResourceWarning):
            del conn
            yield from asyncio.sleep(0.01)
            gc.collect()

        self.assertFalse(conns_impl)
        transp.close.assert_called_with()
        msg = {'connector': unittest.mock.ANY,  # conn was deleted
               'message': 'Unclosed connector'}
        if self.loop.get_debug():
            msg['source_traceback'] = unittest.mock.ANY
        exc_handler.assert_called_with(self.loop, msg)

    def test_del_with_closed_loop(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        transp = unittest.mock.Mock()
        conn._conns['a'] = [(transp, 'proto', 123)]

        conns_impl = conn._conns
        conn._start_cleanup_task()
        exc_handler = unittest.mock.Mock()
        self.loop.set_exception_handler(exc_handler)
        self.loop.close()

        with self.assertWarns(ResourceWarning):
            del conn
            gc.collect()

        self.assertFalse(conns_impl)
        self.assertFalse(transp.close.called)
        self.assertTrue(exc_handler.called)

    def test_del_empty_conector(self):
        conn = aiohttp.BaseConnector(loop=self.loop)

        exc_handler = unittest.mock.Mock()
        self.loop.set_exception_handler(exc_handler)

        del conn

        self.assertFalse(exc_handler.called)

    def test_create_conn(self):

        def go():
            conn = aiohttp.BaseConnector(loop=self.loop)
            with self.assertRaises(NotImplementedError):
                yield from conn._create_connection(object())

        self.loop.run_until_complete(go())

    @unittest.mock.patch('aiohttp.connector.asyncio')
    def test_ctor_loop(self, asyncio):
        session = aiohttp.BaseConnector()
        self.assertIs(session._loop, asyncio.get_event_loop.return_value)

    def test_close(self):
        tr = unittest.mock.Mock()

        conn = aiohttp.BaseConnector(loop=self.loop)
        self.assertFalse(conn.closed)
        conn._conns[1] = [(tr, object(), object())]
        conn.close()

        self.assertFalse(conn._conns)
        self.assertTrue(tr.close.called)
        self.assertTrue(conn.closed)

    def test_get(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        self.assertEqual(conn._get(1), (None, None))

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        conn._conns[1] = [(tr, proto, self.loop.time())]
        self.assertEqual(conn._get(1), (tr, proto))
        conn.close()

    def test_get_expired(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        self.assertEqual(conn._get(1), (None, None))

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        conn._conns[1] = [(tr, proto, self.loop.time() - 1000)]
        self.assertEqual(conn._get(1), (None, None))
        self.assertFalse(conn._conns)
        conn.close()

    def test_release(self):
        self.loop.time = mock.Mock(return_value=10)

        conn = aiohttp.BaseConnector(loop=self.loop)
        conn._start_cleanup_task = unittest.mock.Mock()
        req = unittest.mock.Mock()
        resp = req.response = unittest.mock.Mock()
        resp._should_close = False

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        key = 1
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertEqual(conn._conns[1][0], (tr, proto, 10))
        self.assertTrue(conn._start_cleanup_task.called)
        conn.close()

    def test_release_close(self):
        with self.assertWarns(DeprecationWarning):
            conn = aiohttp.BaseConnector(share_cookies=True, loop=self.loop)
        req = unittest.mock.Mock()
        resp = unittest.mock.Mock()
        resp.message.should_close = True
        req.response = resp

        cookies = resp.cookies = http.cookies.SimpleCookie()
        cookies['c1'] = 'cookie1'
        cookies['c2'] = 'cookie2'

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        key = 1
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertFalse(conn._conns)
        self.assertTrue(tr.close.called)

    def test_get_pop_empty_conns(self):
        # see issue #473
        conn = aiohttp.BaseConnector(loop=self.loop)
        key = ('127.0.0.1', 80, False)
        conn._conns[key] = []
        tr, proto = conn._get(key)
        self.assertEqual((None, None), (tr, proto))
        self.assertFalse(conn._conns)

    def test_release_close_do_not_add_to_pool(self):
        # see issue #473
        conn = aiohttp.BaseConnector(loop=self.loop)
        req = unittest.mock.Mock()
        resp = unittest.mock.Mock()
        resp.message.should_close = True
        req.response = resp

        key = ('127.0.0.1', 80, False)

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertFalse(conn._conns)

    def test_release_close_do_not_delete_existing_connections(self):
        key = ('127.0.0.1', 80, False)
        tr1, proto1 = unittest.mock.Mock(), unittest.mock.Mock()

        with self.assertWarns(DeprecationWarning):
            conn = aiohttp.BaseConnector(share_cookies=True, loop=self.loop)
        conn._conns[key] = [(tr1, proto1, 1)]
        req = unittest.mock.Mock()
        resp = unittest.mock.Mock()
        resp.message.should_close = True
        req.response = resp

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        conn._acquired[key].add(tr1)
        conn._release(key, req, tr, proto)
        self.assertEqual(conn._conns[key], [(tr1, proto1, 1)])
        self.assertTrue(tr.close.called)
        conn.close()

    def test_release_not_started(self):
        self.loop.time = mock.Mock(return_value=10)

        conn = aiohttp.BaseConnector(loop=self.loop)
        req = unittest.mock.Mock()
        req.response = None

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        key = 1
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertEqual(conn._conns, {1: [(tr, proto, 10)]})
        self.assertFalse(tr.close.called)
        conn.close()

    def test_release_not_opened(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        req = unittest.mock.Mock()
        req.response = unittest.mock.Mock()
        req.response.message = None

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        key = 1
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertTrue(tr.close.called)

    def test_connect(self):
        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        proto.is_connected.return_value = True

        class Req:
            host = 'host'
            port = 80
            ssl = False
            response = unittest.mock.Mock()

        conn = aiohttp.BaseConnector(loop=self.loop)
        key = ('host', 80, False)
        conn._conns[key] = [(tr, proto, self.loop.time())]
        conn._create_connection = unittest.mock.Mock()
        conn._create_connection.return_value = helpers.create_future(self.loop)
        conn._create_connection.return_value.set_result((tr, proto))

        connection = self.loop.run_until_complete(conn.connect(Req()))
        self.assertFalse(conn._create_connection.called)
        self.assertEqual(connection._transport, tr)
        self.assertEqual(connection._protocol, proto)
        self.assertIsInstance(connection, Connection)
        connection.close()

    def test_connect_timeout(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        conn._create_connection = unittest.mock.Mock()
        conn._create_connection.return_value = helpers.create_future(self.loop)
        conn._create_connection.return_value.set_exception(
            asyncio.TimeoutError())

        with self.assertRaises(aiohttp.ClientTimeoutError):
            req = unittest.mock.Mock()
            self.loop.run_until_complete(conn.connect(req))

    def test_connect_oserr(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        conn._create_connection = unittest.mock.Mock()
        conn._create_connection.return_value = helpers.create_future(self.loop)
        err = OSError(1, 'permission error')
        conn._create_connection.return_value.set_exception(err)

        with self.assertRaises(aiohttp.ClientOSError) as ctx:
            req = unittest.mock.Mock()
            self.loop.run_until_complete(conn.connect(req))
        self.assertEqual(1, ctx.exception.errno)
        self.assertTrue(ctx.exception.strerror.startswith('Cannot connect to'))
        self.assertTrue(ctx.exception.strerror.endswith('[permission error]'))

    def test_start_cleanup_task(self):
        loop = unittest.mock.Mock()
        loop.time.return_value = 1.5
        conn = aiohttp.BaseConnector(loop=loop, keepalive_timeout=10)
        self.assertIsNone(conn._cleanup_handle)

        conn._start_cleanup_task()
        self.assertIsNotNone(conn._cleanup_handle)
        loop.call_at.assert_called_with(
            12, conn._cleanup)

    def test_cleanup(self):
        testset = {
            1: [(unittest.mock.Mock(), unittest.mock.Mock(), 10),
                (unittest.mock.Mock(), unittest.mock.Mock(), 300),
                (None, unittest.mock.Mock(), 300)],
        }
        testset[1][0][1].is_connected.return_value = True
        testset[1][1][1].is_connected.return_value = False

        loop = unittest.mock.Mock()
        loop.time.return_value = 300
        conn = aiohttp.BaseConnector(loop=loop)
        conn._conns = testset
        existing_handle = conn._cleanup_handle = unittest.mock.Mock()

        conn._cleanup()
        self.assertTrue(existing_handle.cancel.called)
        self.assertEqual(conn._conns, {})
        self.assertIsNone(conn._cleanup_handle)

    def test_cleanup2(self):
        testset = {1: [(unittest.mock.Mock(), unittest.mock.Mock(), 300)]}
        testset[1][0][1].is_connected.return_value = True

        loop = unittest.mock.Mock()
        loop.time.return_value = 300.1

        conn = aiohttp.BaseConnector(loop=loop, keepalive_timeout=10)
        conn._conns = testset
        conn._cleanup()
        self.assertEqual(conn._conns, testset)

        self.assertIsNotNone(conn._cleanup_handle)
        loop.call_at.assert_called_with(
            310, conn._cleanup)
        conn.close()

    def test_cleanup3(self):
        testset = {1: [(unittest.mock.Mock(), unittest.mock.Mock(), 290.1),
                       (unittest.mock.Mock(), unittest.mock.Mock(), 305.1)]}
        testset[1][0][1].is_connected.return_value = True

        loop = unittest.mock.Mock()
        loop.time.return_value = 308.5

        conn = aiohttp.BaseConnector(loop=loop, keepalive_timeout=10)
        conn._conns = testset

        conn._cleanup()
        self.assertEqual(conn._conns, {1: [testset[1][1]]})

        self.assertIsNotNone(conn._cleanup_handle)
        loop.call_at.assert_called_with(
            316, conn._cleanup)
        conn.close()

    def test_tcp_connector_ctor(self):
        conn = aiohttp.TCPConnector(loop=self.loop)
        self.assertTrue(conn.verify_ssl)
        self.assertIs(conn.fingerprint, None)

        with self.assertWarns(DeprecationWarning):
            self.assertFalse(conn.resolve)
        self.assertFalse(conn.use_dns_cache)

        self.assertEqual(conn.family, 0)

        with self.assertWarns(DeprecationWarning):
            self.assertEqual(conn.resolved_hosts, {})
        self.assertEqual(conn.resolved_hosts, {})

    def test_tcp_connector_ctor_fingerprint_valid(self):
        valid = b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06='
        conn = aiohttp.TCPConnector(loop=self.loop, fingerprint=valid)
        self.assertEqual(conn.fingerprint, valid)

    def test_tcp_connector_fingerprint_invalid(self):
        invalid = b'\x00'
        with self.assertRaises(ValueError):
            aiohttp.TCPConnector(loop=self.loop, fingerprint=invalid)

    def test_tcp_connector_clear_resolved_hosts(self):
        conn = aiohttp.TCPConnector(loop=self.loop)
        info = object()
        conn._cached_hosts[('localhost', 123)] = info
        conn._cached_hosts[('localhost', 124)] = info
        conn.clear_resolved_hosts('localhost', 123)
        self.assertEqual(
            conn.resolved_hosts, {('localhost', 124): info})
        conn.clear_resolved_hosts('localhost', 123)
        self.assertEqual(
            conn.resolved_hosts, {('localhost', 124): info})
        with self.assertWarns(DeprecationWarning):
            conn.clear_resolved_hosts()
        self.assertEqual(conn.resolved_hosts, {})

    def test_tcp_connector_clear_dns_cache(self):
        conn = aiohttp.TCPConnector(loop=self.loop)
        info = object()
        conn._cached_hosts[('localhost', 123)] = info
        conn._cached_hosts[('localhost', 124)] = info
        conn.clear_dns_cache('localhost', 123)
        self.assertEqual(
            conn.cached_hosts, {('localhost', 124): info})
        conn.clear_dns_cache('localhost', 123)
        self.assertEqual(
            conn.cached_hosts, {('localhost', 124): info})
        conn.clear_dns_cache()
        self.assertEqual(conn.cached_hosts, {})

    def test_tcp_connector_clear_dns_cache_bad_args(self):
        conn = aiohttp.TCPConnector(loop=self.loop)
        with self.assertRaises(ValueError):
            conn.clear_dns_cache('localhost')

    def test_ambigous_verify_ssl_and_ssl_context(self):
        with self.assertRaises(ValueError):
            aiohttp.TCPConnector(
                verify_ssl=False,
                ssl_context=ssl.SSLContext(ssl.PROTOCOL_SSLv23),
                loop=self.loop)

    def test_dont_recreate_ssl_context(self):
        conn = aiohttp.TCPConnector(loop=self.loop)
        ctx = conn.ssl_context
        self.assertIs(ctx, conn.ssl_context)

    def test_respect_precreated_ssl_context(self):
        ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        conn = aiohttp.TCPConnector(loop=self.loop, ssl_context=ctx)
        self.assertIs(ctx, conn.ssl_context)

    def test_close_twice(self):
        tr = unittest.mock.Mock()

        conn = aiohttp.BaseConnector(loop=self.loop)
        conn._conns[1] = [(tr, object(), object())]
        conn.close()

        self.assertFalse(conn._conns)
        self.assertTrue(tr.close.called)
        self.assertTrue(conn.closed)

        conn._conns = 'Invalid'  # fill with garbage
        conn.close()
        self.assertTrue(conn.closed)

    def test_close_cancels_cleanup_handle(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        conn._start_cleanup_task()

        self.assertIsNotNone(conn._cleanup_handle)
        conn.close()
        self.assertIsNone(conn._cleanup_handle)

    def test_ctor_with_default_loop(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self.addCleanup(loop.close)
        self.addCleanup(asyncio.set_event_loop, None)
        conn = aiohttp.BaseConnector()
        self.assertIs(loop, conn._loop)

    def test_connect_with_limit(self):

        @asyncio.coroutine
        def go():
            tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
            proto.is_connected.return_value = True

            class Req:
                host = 'host'
                port = 80
                ssl = False
                response = unittest.mock.Mock()

            conn = aiohttp.BaseConnector(loop=self.loop, limit=1)
            key = ('host', 80, False)
            conn._conns[key] = [(tr, proto, self.loop.time())]
            conn._create_connection = unittest.mock.Mock()
            conn._create_connection.return_value = helpers.create_future(
                self.loop)
            conn._create_connection.return_value.set_result((tr, proto))

            connection1 = yield from conn.connect(Req())
            self.assertEqual(connection1._transport, tr)

            self.assertEqual(1, len(conn._acquired[key]))

            acquired = False

            @asyncio.coroutine
            def f():
                nonlocal acquired
                connection2 = yield from conn.connect(Req())
                acquired = True
                self.assertEqual(1, len(conn._acquired[key]))
                connection2.release()

            task = asyncio.async(f(), loop=self.loop)

            yield from asyncio.sleep(0.01, loop=self.loop)
            self.assertFalse(acquired)
            connection1.release()
            yield from asyncio.sleep(0, loop=self.loop)
            self.assertTrue(acquired)
            yield from task
            conn.close()

        self.loop.run_until_complete(go())

    def test_connect_with_limit_cancelled(self):
        """A connect() blocked on an exhausted limit can be timed out."""

        @asyncio.coroutine
        def go():
            tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
            proto.is_connected.return_value = True

            class Req:
                host = 'host'
                port = 80
                ssl = False
                response = unittest.mock.Mock()

            conn = aiohttp.BaseConnector(loop=self.loop, limit=1)
            key = ('host', 80, False)
            conn._conns[key] = [(tr, proto, self.loop.time())]
            conn._create_connection = unittest.mock.Mock()
            conn._create_connection.return_value = helpers.create_future(
                self.loop)
            conn._create_connection.return_value.set_result((tr, proto))

            connection = yield from conn.connect(Req())
            self.assertEqual(connection._transport, tr)

            self.assertEqual(1, len(conn._acquired[key]))

            with self.assertRaises(asyncio.TimeoutError):
                # limit exhausted: the second connect() has to keep waiting.
                # Pass an instance, like every other connect() call here.
                yield from asyncio.wait_for(conn.connect(Req()), 0.01,
                                            loop=self.loop)
            connection.close()
            # Close the connector itself too, not just the connection.
            conn.close()

        self.loop.run_until_complete(go())

    def test_connect_with_limit_release_waiters(self):
        """A connect() that fails must not leave a waiter queued for its key."""

        def check_with_exc(err):
            conn = aiohttp.BaseConnector(limit=1, loop=self.loop)
            conn._create_connection = unittest.mock.Mock()
            conn._create_connection.return_value = \
                helpers.create_future(self.loop)
            conn._create_connection.return_value.set_exception(err)

            # Build the request outside assertRaises: only connect() itself
            # should be able to satisfy the (deliberately broad) guard.
            req = unittest.mock.Mock()
            with self.assertRaises(Exception):
                self.loop.run_until_complete(conn.connect(req))
            key = (req.host, req.port, req.ssl)
            self.assertFalse(conn._waiters[key])
            conn.close()

        check_with_exc(OSError(1, 'permission error'))
        check_with_exc(RuntimeError())
        check_with_exc(asyncio.TimeoutError())

    def test_connect_with_limit_concurrent(self):

        @asyncio.coroutine
        def go():
            proto = unittest.mock.Mock()
            proto.is_connected.return_value = True

            class Req:
                host = 'host'
                port = 80
                ssl = False
                response = unittest.mock.Mock(_should_close=False)

            max_connections = 2
            num_connections = 0

            conn = aiohttp.BaseConnector(limit=max_connections, loop=self.loop)

            # Use a real coroutine for _create_connection; a mock would mask
            # problems that only happen when the method yields.

            @asyncio.coroutine
            def create_connection(req):
                nonlocal num_connections
                num_connections += 1
                yield from asyncio.sleep(0, loop=self.loop)

                # Make a new transport mock each time because acquired
                # transports are stored in a set. Reusing the same object
                # messes with the count.
                tr = unittest.mock.Mock()

                return tr, proto

            conn._create_connection = create_connection

            # Simulate something like a crawler. It opens a connection, does
            # something with it, closes it, then creates tasks that make more
            # connections and waits for them to finish. The crawler is started
            # with multiple concurrent requests and stops when it hits a
            # predefined maximum number of requests.

            max_requests = 10
            num_requests = 0
            start_requests = max_connections + 1

            @asyncio.coroutine
            def f(start=True):
                nonlocal num_requests
                if num_requests == max_requests:
                    return
                num_requests += 1
                if not start:
                    connection = yield from conn.connect(Req())
                    yield from asyncio.sleep(0, loop=self.loop)
                    connection.release()
                tasks = [
                    asyncio.async(f(start=False), loop=self.loop)
                    for i in range(start_requests)
                ]
                yield from asyncio.wait(tasks, loop=self.loop)

            yield from f()
            conn.close()

            self.assertEqual(max_connections, num_connections)
Example #21
0
class TestBaseConnector(unittest.TestCase):

    def setUp(self):
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)

        self.transport = unittest.mock.Mock()
        self.stream = aiohttp.StreamParser()
        self.response = ClientResponse('get', 'http://base-conn.org')
        self.response._post_init(self.loop)

    def tearDown(self):
        self.response.close()
        self.loop.close()
        gc.collect()

    @unittest.skipUnless(PY_341, "Requires Python 3.4.1+")
    def test_del(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        transp = unittest.mock.Mock()
        conn._conns['a'] = [(transp, 'proto', 123)]
        conns_impl = conn._conns

        exc_handler = unittest.mock.Mock()
        self.loop.set_exception_handler(exc_handler)

        with self.assertWarns(ResourceWarning):
            del conn
            gc.collect()

        self.assertFalse(conns_impl)
        transp.close.assert_called_with()
        msg = {'connector': unittest.mock.ANY,  # conn was deleted
               'message': 'Unclosed connector'}
        if self.loop.get_debug():
            msg['source_traceback'] = unittest.mock.ANY
        exc_handler.assert_called_with(self.loop, msg)

    @unittest.skipUnless(PY_341, "Requires Python 3.4.1+")
    def test_del_with_scheduled_cleanup(self):
        conn = aiohttp.BaseConnector(loop=self.loop, keepalive_timeout=0.01)
        transp = unittest.mock.Mock()
        conn._conns['a'] = [(transp, 'proto', 123)]

        conns_impl = conn._conns
        conn._start_cleanup_task()
        exc_handler = unittest.mock.Mock()
        self.loop.set_exception_handler(exc_handler)

        with self.assertWarns(ResourceWarning):
            del conn
            yield from asyncio.sleep(0.01)
            gc.collect()

        self.assertFalse(conns_impl)
        transp.close.assert_called_with()
        msg = {'connector': unittest.mock.ANY,  # conn was deleted
               'message': 'Unclosed connector'}
        if self.loop.get_debug():
            msg['source_traceback'] = unittest.mock.ANY
        exc_handler.assert_called_with(self.loop, msg)

    @unittest.skipUnless(PY_341, "Requires Python 3.4.1+")
    def test_del_with_closed_loop(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        transp = unittest.mock.Mock()
        conn._conns['a'] = [(transp, 'proto', 123)]

        conns_impl = conn._conns
        conn._start_cleanup_task()
        exc_handler = unittest.mock.Mock()
        self.loop.set_exception_handler(exc_handler)
        self.loop.close()

        with self.assertWarns(ResourceWarning):
            del conn
            gc.collect()

        self.assertFalse(conns_impl)
        self.assertFalse(transp.close.called)
        self.assertTrue(exc_handler.called)

    @unittest.skipUnless(PY_341, "Requires Python 3.4.1+")
    def test_del_empty_conector(self):
        conn = aiohttp.BaseConnector(loop=self.loop)

        exc_handler = unittest.mock.Mock()
        self.loop.set_exception_handler(exc_handler)

        del conn

        self.assertFalse(exc_handler.called)

    def test_create_conn(self):

        def go():
            conn = aiohttp.BaseConnector(loop=self.loop)
            with self.assertRaises(NotImplementedError):
                yield from conn._create_connection(object())

        self.loop.run_until_complete(go())

    @unittest.mock.patch('aiohttp.connector.asyncio')
    def test_ctor_loop(self, asyncio):
        session = aiohttp.BaseConnector()
        self.assertIs(session._loop, asyncio.get_event_loop.return_value)

    def test_close(self):
        tr = unittest.mock.Mock()

        conn = aiohttp.BaseConnector(loop=self.loop)
        self.assertFalse(conn.closed)
        conn._conns[1] = [(tr, object(), object())]
        conn.close()

        self.assertFalse(conn._conns)
        self.assertTrue(tr.close.called)
        self.assertTrue(conn.closed)

    def test_get(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        self.assertEqual(conn._get(1), (None, None))

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        conn._conns[1] = [(tr, proto, self.loop.time())]
        self.assertEqual(conn._get(1), (tr, proto))
        conn.close()

    def test_get_expired(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        self.assertEqual(conn._get(1), (None, None))

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        conn._conns[1] = [(tr, proto, self.loop.time() - 1000)]
        self.assertEqual(conn._get(1), (None, None))
        self.assertFalse(conn._conns)
        conn.close()

    def test_release(self):
        self.loop.time = mock.Mock(return_value=10)

        conn = aiohttp.BaseConnector(loop=self.loop)
        conn._start_cleanup_task = unittest.mock.Mock()
        req = unittest.mock.Mock()
        resp = req.response = unittest.mock.Mock()
        resp._should_close = False

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        key = 1
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertEqual(conn._conns[1][0], (tr, proto, 10))
        self.assertTrue(conn._start_cleanup_task.called)
        conn.close()

    def test_release_close(self):
        with self.assertWarns(DeprecationWarning):
            conn = aiohttp.BaseConnector(share_cookies=True, loop=self.loop)
        req = unittest.mock.Mock()
        resp = unittest.mock.Mock()
        resp.message.should_close = True
        req.response = resp

        cookies = resp.cookies = http.cookies.SimpleCookie()
        cookies['c1'] = 'cookie1'
        cookies['c2'] = 'cookie2'

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        key = 1
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertFalse(conn._conns)
        self.assertTrue(tr.close.called)

    def test_get_pop_empty_conns(self):
        # see issue #473
        conn = aiohttp.BaseConnector(loop=self.loop)
        key = ('127.0.0.1', 80, False)
        conn._conns[key] = []
        tr, proto = conn._get(key)
        self.assertEqual((None, None), (tr, proto))
        self.assertFalse(conn._conns)

    def test_release_close_do_not_add_to_pool(self):
        # see issue #473
        conn = aiohttp.BaseConnector(loop=self.loop)
        req = unittest.mock.Mock()
        resp = unittest.mock.Mock()
        resp.message.should_close = True
        req.response = resp

        key = ('127.0.0.1', 80, False)

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertFalse(conn._conns)

    def test_release_close_do_not_delete_existing_connections(self):
        key = ('127.0.0.1', 80, False)
        tr1, proto1 = unittest.mock.Mock(), unittest.mock.Mock()

        with self.assertWarns(DeprecationWarning):
            conn = aiohttp.BaseConnector(share_cookies=True, loop=self.loop)
        conn._conns[key] = [(tr1, proto1, 1)]
        req = unittest.mock.Mock()
        resp = unittest.mock.Mock()
        resp.message.should_close = True
        req.response = resp

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        conn._acquired[key].add(tr1)
        conn._release(key, req, tr, proto)
        self.assertEqual(conn._conns[key], [(tr1, proto1, 1)])
        self.assertTrue(tr.close.called)
        conn.close()

    def test_release_not_started(self):
        self.loop.time = mock.Mock(return_value=10)

        conn = aiohttp.BaseConnector(loop=self.loop)
        req = unittest.mock.Mock()
        req.response = None

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        key = 1
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertEqual(conn._conns, {1: [(tr, proto, 10)]})
        self.assertFalse(tr.close.called)
        conn.close()

    def test_release_not_opened(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        req = unittest.mock.Mock()
        req.response = unittest.mock.Mock()
        req.response.message = None

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        key = 1
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertTrue(tr.close.called)

    def test_connect(self):
        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        proto.is_connected.return_value = True

        class Req:
            host = 'host'
            port = 80
            ssl = False
            response = unittest.mock.Mock()

        conn = aiohttp.BaseConnector(loop=self.loop)
        key = ('host', 80, False)
        conn._conns[key] = [(tr, proto, self.loop.time())]
        conn._create_connection = unittest.mock.Mock()
        conn._create_connection.return_value = asyncio.Future(loop=self.loop)
        conn._create_connection.return_value.set_result((tr, proto))

        connection = self.loop.run_until_complete(conn.connect(Req()))
        self.assertFalse(conn._create_connection.called)
        self.assertEqual(connection._transport, tr)
        self.assertEqual(connection._protocol, proto)
        self.assertIsInstance(connection, Connection)
        connection.close()

    def test_connect_timeout(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        conn._create_connection = unittest.mock.Mock()
        conn._create_connection.return_value = asyncio.Future(loop=self.loop)
        conn._create_connection.return_value.set_exception(
            asyncio.TimeoutError())

        with self.assertRaises(aiohttp.ClientTimeoutError):
            req = unittest.mock.Mock()
            self.loop.run_until_complete(conn.connect(req))

    def test_connect_oserr(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        conn._create_connection = unittest.mock.Mock()
        conn._create_connection.return_value = asyncio.Future(loop=self.loop)
        conn._create_connection.return_value.set_exception(OSError())

        with self.assertRaises(aiohttp.ClientOSError):
            req = unittest.mock.Mock()
            self.loop.run_until_complete(conn.connect(req))

    def test_start_cleanup_task(self):
        loop = unittest.mock.Mock()
        loop.time.return_value = 1.5
        conn = aiohttp.BaseConnector(loop=loop, keepalive_timeout=10)
        self.assertIsNone(conn._cleanup_handle)

        conn._start_cleanup_task()
        self.assertIsNotNone(conn._cleanup_handle)
        loop.call_at.assert_called_with(
            12, conn._cleanup)

    def test_cleanup(self):
        testset = {
            1: [(unittest.mock.Mock(), unittest.mock.Mock(), 10),
                (unittest.mock.Mock(), unittest.mock.Mock(), 300),
                (None, unittest.mock.Mock(), 300)],
        }
        testset[1][0][1].is_connected.return_value = True
        testset[1][1][1].is_connected.return_value = False

        loop = unittest.mock.Mock()
        loop.time.return_value = 300
        conn = aiohttp.BaseConnector(loop=loop)
        conn._conns = testset
        existing_handle = conn._cleanup_handle = unittest.mock.Mock()

        conn._cleanup()
        self.assertTrue(existing_handle.cancel.called)
        self.assertEqual(conn._conns, {})
        self.assertIsNone(conn._cleanup_handle)

    def test_cleanup2(self):
        testset = {1: [(unittest.mock.Mock(), unittest.mock.Mock(), 300)]}
        testset[1][0][1].is_connected.return_value = True

        loop = unittest.mock.Mock()
        loop.time.return_value = 300.1

        conn = aiohttp.BaseConnector(loop=loop, keepalive_timeout=10)
        conn._conns = testset
        conn._cleanup()
        self.assertEqual(conn._conns, testset)

        self.assertIsNotNone(conn._cleanup_handle)
        loop.call_at.assert_called_with(
            310, conn._cleanup)
        conn.close()

    def test_cleanup3(self):
        testset = {1: [(unittest.mock.Mock(), unittest.mock.Mock(), 290.1),
                       (unittest.mock.Mock(), unittest.mock.Mock(), 305.1)]}
        testset[1][0][1].is_connected.return_value = True

        loop = unittest.mock.Mock()
        loop.time.return_value = 308.5

        conn = aiohttp.BaseConnector(loop=loop, keepalive_timeout=10)
        conn._conns = testset

        conn._cleanup()
        self.assertEqual(conn._conns, {1: [testset[1][1]]})

        self.assertIsNotNone(conn._cleanup_handle)
        loop.call_at.assert_called_with(
            316, conn._cleanup)
        conn.close()

    def test_tcp_connector_ctor(self):
        conn = aiohttp.TCPConnector(loop=self.loop)
        self.assertTrue(conn.verify_ssl)
        self.assertIs(conn.fingerprint, None)

        with self.assertWarns(DeprecationWarning):
            self.assertFalse(conn.resolve)
        self.assertFalse(conn.use_dns_cache)

        self.assertEqual(conn.family, socket.AF_INET)

        with self.assertWarns(DeprecationWarning):
            self.assertEqual(conn.resolved_hosts, {})
        self.assertEqual(conn.resolved_hosts, {})

    def test_tcp_connector_ctor_fingerprint_valid(self):
        valid = b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06='
        conn = aiohttp.TCPConnector(loop=self.loop, fingerprint=valid)
        self.assertEqual(conn.fingerprint, valid)

    def test_tcp_connector_fingerprint_invalid(self):
        invalid = b'\x00'
        with self.assertRaises(ValueError):
            aiohttp.TCPConnector(loop=self.loop, fingerprint=invalid)

    def test_tcp_connector_fingerprint(self):
        # The even-index fingerprints below are "expect success" cases
        # for ./sample.crt.der, the cert presented by test_utils.run_server.
        # The odd-index fingerprints are "expect fail" cases.
        testcases = (
            # md5
            b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06=',
            b'\x00' * 16,

            # sha1
            b's\x93\xfd:\xed\x08\x1do\xa9\xaeq9\x1a\xe3\xc5\x7f\x89\xe7l\xf9',
            b'\x00' * 20,

            # sha256
            b'0\x9a\xc9D\x83\xdc\x91\'\x88\x91\x11\xa1d\x97\xfd\xcb~7U\x14D@L'
            b'\x11\xab\x99\xa8\xae\xb7\x14\xee\x8b',
            b'\x00' * 32,
        )
        for i, fingerprint in enumerate(testcases):
            expect_fail = i % 2
            conn = aiohttp.TCPConnector(loop=self.loop, verify_ssl=False,
                                        fingerprint=fingerprint)
            with test_utils.run_server(self.loop, use_ssl=True) as httpd:
                coro = client.request('get', httpd.url('method', 'get'),
                                      connector=conn, loop=self.loop)
                if expect_fail:
                    with self.assertRaises(FingerprintMismatch) as cm:
                        self.loop.run_until_complete(coro)
                    exc = cm.exception
                    self.assertEqual(exc.expected, fingerprint)
                    # the previous test case should be what we actually got
                    self.assertEqual(exc.got, testcases[i-1])
                else:
                    # should not raise
                    resp = self.loop.run_until_complete(coro)
                    resp.close(force=True)

            conn.close()

    def test_tcp_connector_clear_resolved_hosts(self):
        conn = aiohttp.TCPConnector(loop=self.loop)
        info = object()
        conn._cached_hosts[('localhost', 123)] = info
        conn._cached_hosts[('localhost', 124)] = info
        conn.clear_resolved_hosts('localhost', 123)
        self.assertEqual(
            conn.resolved_hosts, {('localhost', 124): info})
        conn.clear_resolved_hosts('localhost', 123)
        self.assertEqual(
            conn.resolved_hosts, {('localhost', 124): info})
        with self.assertWarns(DeprecationWarning):
            conn.clear_resolved_hosts()
        self.assertEqual(conn.resolved_hosts, {})

    def test_tcp_connector_clear_dns_cache(self):
        conn = aiohttp.TCPConnector(loop=self.loop)
        info = object()
        conn._cached_hosts[('localhost', 123)] = info
        conn._cached_hosts[('localhost', 124)] = info
        conn.clear_dns_cache('localhost', 123)
        self.assertEqual(
            conn.cached_hosts, {('localhost', 124): info})
        conn.clear_dns_cache('localhost', 123)
        self.assertEqual(
            conn.cached_hosts, {('localhost', 124): info})
        conn.clear_dns_cache()
        self.assertEqual(conn.cached_hosts, {})

    def test_tcp_connector_clear_dns_cache_bad_args(self):
        conn = aiohttp.TCPConnector(loop=self.loop)
        with self.assertRaises(ValueError):
            conn.clear_dns_cache('localhost')

    def test_ambigous_verify_ssl_and_ssl_context(self):
        with self.assertRaises(ValueError):
            aiohttp.TCPConnector(
                verify_ssl=False,
                ssl_context=ssl.SSLContext(ssl.PROTOCOL_SSLv23),
                loop=self.loop)

    def test_dont_recreate_ssl_context(self):
        conn = aiohttp.TCPConnector(loop=self.loop)
        ctx = conn.ssl_context
        self.assertIs(ctx, conn.ssl_context)

    def test_respect_precreated_ssl_context(self):
        ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        conn = aiohttp.TCPConnector(loop=self.loop, ssl_context=ctx)
        self.assertIs(ctx, conn.ssl_context)

    def test_close_twice(self):
        tr = unittest.mock.Mock()

        conn = aiohttp.BaseConnector(loop=self.loop)
        conn._conns[1] = [(tr, object(), object())]
        conn.close()

        self.assertFalse(conn._conns)
        self.assertTrue(tr.close.called)
        self.assertTrue(conn.closed)

        conn._conns = 'Invalid'  # fill with garbage
        conn.close()
        self.assertTrue(conn.closed)

    def test_close_cancels_cleanup_handle(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        conn._start_cleanup_task()

        self.assertIsNotNone(conn._cleanup_handle)
        conn.close()
        self.assertIsNone(conn._cleanup_handle)

    def test_ctor_with_default_loop(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self.addCleanup(loop.close)
        self.addCleanup(asyncio.set_event_loop, None)
        conn = aiohttp.BaseConnector()
        self.assertIs(loop, conn._loop)

    def test_connect_with_limit(self):

        @asyncio.coroutine
        def go():
            tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
            proto.is_connected.return_value = True

            class Req:
                host = 'host'
                port = 80
                ssl = False
                response = unittest.mock.Mock()

            conn = aiohttp.BaseConnector(loop=self.loop, limit=1)
            key = ('host', 80, False)
            conn._conns[key] = [(tr, proto, self.loop.time())]
            conn._create_connection = unittest.mock.Mock()
            conn._create_connection.return_value = asyncio.Future(
                loop=self.loop)
            conn._create_connection.return_value.set_result((tr, proto))

            connection1 = yield from conn.connect(Req())
            self.assertEqual(connection1._transport, tr)

            self.assertEqual(1, len(conn._acquired[key]))

            acquired = False

            @asyncio.coroutine
            def f():
                nonlocal acquired
                connection2 = yield from conn.connect(Req())
                acquired = True
                self.assertEqual(1, len(conn._acquired[key]))
                connection2.release()

            task = asyncio.async(f(), loop=self.loop)

            yield from asyncio.sleep(0.01, loop=self.loop)
            self.assertFalse(acquired)
            connection1.release()
            yield from asyncio.sleep(0, loop=self.loop)
            self.assertTrue(acquired)
            yield from task
            conn.close()

        self.loop.run_until_complete(go())

    def test_connect_with_limit_cancelled(self):

        @asyncio.coroutine
        def go():
            tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
            proto.is_connected.return_value = True

            class Req:
                host = 'host'
                port = 80
                ssl = False
                response = unittest.mock.Mock()

            conn = aiohttp.BaseConnector(loop=self.loop, limit=1)
            key = ('host', 80, False)
            conn._conns[key] = [(tr, proto, self.loop.time())]
            conn._create_connection = unittest.mock.Mock()
            conn._create_connection.return_value = asyncio.Future(
                loop=self.loop)
            conn._create_connection.return_value.set_result((tr, proto))

            connection = yield from conn.connect(Req())
            self.assertEqual(connection._transport, tr)

            self.assertEqual(1, len(conn._acquired[key]))

            with self.assertRaises(asyncio.TimeoutError):
                # limit exhausted
                yield from asyncio.wait_for(conn.connect(Req), 0.01,
                                            loop=self.loop)

            connection.close()

        self.loop.run_until_complete(go())

    def test_close_with_acquired_connection(self):

        @asyncio.coroutine
        def go():
            tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
            proto.is_connected.return_value = True

            class Req:
                host = 'host'
                port = 80
                ssl = False
                response = unittest.mock.Mock()

            conn = aiohttp.BaseConnector(loop=self.loop, limit=1)
            key = ('host', 80, False)
            conn._conns[key] = [(tr, proto, self.loop.time())]
            conn._create_connection = unittest.mock.Mock()
            conn._create_connection.return_value = asyncio.Future(
                loop=self.loop)
            conn._create_connection.return_value.set_result((tr, proto))

            connection = yield from conn.connect(Req())

            self.assertEqual(1, len(conn._acquired))
            conn.close()
            self.assertEqual(0, len(conn._acquired))
            self.assertTrue(conn.closed)
            tr.close.assert_called_with()

            self.assertFalse(connection.closed)
            connection.close()
            self.assertTrue(connection.closed)

        self.loop.run_until_complete(go())

    def test_default_force_close(self):
        connector = aiohttp.BaseConnector(loop=self.loop)
        self.assertFalse(connector.force_close)

    def test_limit_property(self):
        conn = aiohttp.BaseConnector(loop=self.loop, limit=15)
        self.assertEqual(15, conn.limit)
        conn.close()

    def test_limit_property_default(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        self.assertIsNone(conn.limit)
        conn.close()
Example #22
0
class TestBaseConnector(unittest.TestCase):

    def setUp(self):
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)

        self.transport = unittest.mock.Mock()
        self.stream = aiohttp.StreamParser()
        self.response = ClientResponse('get', 'http://base-conn.org')
        self.response._post_init(self.loop)

    def tearDown(self):
        self.response.close()
        self.loop.close()
        gc.collect()

    def test_del(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        transp = unittest.mock.Mock()
        conn._conns['a'] = [(transp, 'proto', 123)]
        conns_impl = conn._conns

        exc_handler = unittest.mock.Mock()
        self.loop.set_exception_handler(exc_handler)

        with self.assertWarns(ResourceWarning):
            del conn
            gc.collect()

        self.assertFalse(conns_impl)
        transp.close.assert_called_with()
        msg = {'connector': unittest.mock.ANY,  # conn was deleted
               'connections': unittest.mock.ANY,
               'message': 'Unclosed connector'}
        if self.loop.get_debug():
            msg['source_traceback'] = unittest.mock.ANY
        exc_handler.assert_called_with(self.loop, msg)

    def test_del_with_scheduled_cleanup(self):
        conn = aiohttp.BaseConnector(loop=self.loop, keepalive_timeout=0.01)
        transp = unittest.mock.Mock()
        conn._conns['a'] = [(transp, 'proto', 123)]

        conns_impl = conn._conns
        conn._start_cleanup_task()
        exc_handler = unittest.mock.Mock()
        self.loop.set_exception_handler(exc_handler)

        with self.assertWarns(ResourceWarning):
            del conn
            yield from asyncio.sleep(0.01)
            gc.collect()

        self.assertFalse(conns_impl)
        transp.close.assert_called_with()
        msg = {'connector': unittest.mock.ANY,  # conn was deleted
               'message': 'Unclosed connector'}
        if self.loop.get_debug():
            msg['source_traceback'] = unittest.mock.ANY
        exc_handler.assert_called_with(self.loop, msg)

    def test_del_with_closed_loop(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        transp = unittest.mock.Mock()
        conn._conns['a'] = [(transp, 'proto', 123)]

        conns_impl = conn._conns
        conn._start_cleanup_task()
        exc_handler = unittest.mock.Mock()
        self.loop.set_exception_handler(exc_handler)
        self.loop.close()

        with self.assertWarns(ResourceWarning):
            del conn
            gc.collect()

        self.assertFalse(conns_impl)
        self.assertFalse(transp.close.called)
        self.assertTrue(exc_handler.called)

    def test_del_empty_conector(self):
        conn = aiohttp.BaseConnector(loop=self.loop)

        exc_handler = unittest.mock.Mock()
        self.loop.set_exception_handler(exc_handler)

        del conn

        self.assertFalse(exc_handler.called)

    def test_create_conn(self):

        def go():
            conn = aiohttp.BaseConnector(loop=self.loop)
            with self.assertRaises(NotImplementedError):
                yield from conn._create_connection(object())

        self.loop.run_until_complete(go())

    @unittest.mock.patch('aiohttp.connector.asyncio')
    def test_ctor_loop(self, asyncio):
        session = aiohttp.BaseConnector()
        self.assertIs(session._loop, asyncio.get_event_loop.return_value)

    def test_close(self):
        tr = unittest.mock.Mock()

        conn = aiohttp.BaseConnector(loop=self.loop)
        self.assertFalse(conn.closed)
        conn._conns[1] = [(tr, object(), object())]
        conn.close()

        self.assertFalse(conn._conns)
        self.assertTrue(tr.close.called)
        self.assertTrue(conn.closed)

    def test_get(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        self.assertEqual(conn._get(1), (None, None))

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        conn._conns[1] = [(tr, proto, self.loop.time())]
        self.assertEqual(conn._get(1), (tr, proto))
        conn.close()

    def test_get_expired(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        self.assertEqual(conn._get(1), (None, None))

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        conn._conns[1] = [(tr, proto, self.loop.time() - 1000)]
        self.assertEqual(conn._get(1), (None, None))
        self.assertFalse(conn._conns)
        conn.close()

    def test_release(self):
        self.loop.time = mock.Mock(return_value=10)

        conn = aiohttp.BaseConnector(loop=self.loop)
        conn._start_cleanup_task = unittest.mock.Mock()
        req = unittest.mock.Mock()
        resp = req.response = unittest.mock.Mock()
        resp._should_close = False

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        key = 1
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertEqual(conn._conns[1][0], (tr, proto, 10))
        self.assertTrue(conn._start_cleanup_task.called)
        conn.close()

    def test_release_close(self):
        with self.assertWarns(DeprecationWarning):
            conn = aiohttp.BaseConnector(share_cookies=True, loop=self.loop)
        req = unittest.mock.Mock()
        resp = unittest.mock.Mock()
        resp.message.should_close = True
        req.response = resp

        cookies = resp.cookies = http.cookies.SimpleCookie()
        cookies['c1'] = 'cookie1'
        cookies['c2'] = 'cookie2'

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        key = 1
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertFalse(conn._conns)
        self.assertTrue(tr.close.called)

    def test_get_pop_empty_conns(self):
        # see issue #473
        conn = aiohttp.BaseConnector(loop=self.loop)
        key = ('127.0.0.1', 80, False)
        conn._conns[key] = []
        tr, proto = conn._get(key)
        self.assertEqual((None, None), (tr, proto))
        self.assertFalse(conn._conns)

    def test_release_close_do_not_add_to_pool(self):
        # see issue #473
        conn = aiohttp.BaseConnector(loop=self.loop)
        req = unittest.mock.Mock()
        resp = unittest.mock.Mock()
        resp.message.should_close = True
        req.response = resp

        key = ('127.0.0.1', 80, False)

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertFalse(conn._conns)

    def test_release_close_do_not_delete_existing_connections(self):
        key = ('127.0.0.1', 80, False)
        tr1, proto1 = unittest.mock.Mock(), unittest.mock.Mock()

        with self.assertWarns(DeprecationWarning):
            conn = aiohttp.BaseConnector(share_cookies=True, loop=self.loop)
        conn._conns[key] = [(tr1, proto1, 1)]
        req = unittest.mock.Mock()
        resp = unittest.mock.Mock()
        resp.message.should_close = True
        req.response = resp

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        conn._acquired[key].add(tr1)
        conn._release(key, req, tr, proto)
        self.assertEqual(conn._conns[key], [(tr1, proto1, 1)])
        self.assertTrue(tr.close.called)
        conn.close()

    def test_release_not_started(self):
        self.loop.time = mock.Mock(return_value=10)

        conn = aiohttp.BaseConnector(loop=self.loop)
        req = unittest.mock.Mock()
        req.response = None

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        key = 1
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertEqual(conn._conns, {1: [(tr, proto, 10)]})
        self.assertFalse(tr.close.called)
        conn.close()

    def test_release_not_opened(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        req = unittest.mock.Mock()
        req.response = unittest.mock.Mock()
        req.response.message = None

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        key = 1
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertTrue(tr.close.called)

    def test_connect(self):
        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        proto.is_connected.return_value = True

        class Req:
            host = 'host'
            port = 80
            ssl = False
            response = unittest.mock.Mock()

        conn = aiohttp.BaseConnector(loop=self.loop)
        key = ('host', 80, False)
        conn._conns[key] = [(tr, proto, self.loop.time())]
        conn._create_connection = unittest.mock.Mock()
        conn._create_connection.return_value = asyncio.Future(loop=self.loop)
        conn._create_connection.return_value.set_result((tr, proto))

        connection = self.loop.run_until_complete(conn.connect(Req()))
        self.assertFalse(conn._create_connection.called)
        self.assertEqual(connection._transport, tr)
        self.assertEqual(connection._protocol, proto)
        self.assertIsInstance(connection, Connection)
        connection.close()

    def test_connect_timeout(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        conn._create_connection = unittest.mock.Mock()
        conn._create_connection.return_value = asyncio.Future(loop=self.loop)
        conn._create_connection.return_value.set_exception(
            asyncio.TimeoutError())

        with self.assertRaises(aiohttp.ClientTimeoutError):
            req = unittest.mock.Mock()
            self.loop.run_until_complete(conn.connect(req))

    def test_connect_oserr(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        conn._create_connection = unittest.mock.Mock()
        conn._create_connection.return_value = asyncio.Future(loop=self.loop)
        err = OSError(1, 'permission error')
        conn._create_connection.return_value.set_exception(err)

        with self.assertRaises(aiohttp.ClientOSError) as ctx:
            req = unittest.mock.Mock()
            self.loop.run_until_complete(conn.connect(req))
        self.assertEqual(1, ctx.exception.errno)
        self.assertTrue(ctx.exception.strerror.startswith('Cannot connect to'))
        self.assertTrue(ctx.exception.strerror.endswith('[permission error]'))

    def test_start_cleanup_task(self):
        loop = unittest.mock.Mock()
        loop.time.return_value = 1.5
        conn = aiohttp.BaseConnector(loop=loop, keepalive_timeout=10)
        self.assertIsNone(conn._cleanup_handle)

        conn._start_cleanup_task()
        self.assertIsNotNone(conn._cleanup_handle)
        loop.call_at.assert_called_with(
            12, conn._cleanup)

    def test_cleanup(self):
        testset = {
            1: [(unittest.mock.Mock(), unittest.mock.Mock(), 10),
                (unittest.mock.Mock(), unittest.mock.Mock(), 300),
                (None, unittest.mock.Mock(), 300)],
        }
        testset[1][0][1].is_connected.return_value = True
        testset[1][1][1].is_connected.return_value = False

        loop = unittest.mock.Mock()
        loop.time.return_value = 300
        conn = aiohttp.BaseConnector(loop=loop)
        conn._conns = testset
        existing_handle = conn._cleanup_handle = unittest.mock.Mock()

        conn._cleanup()
        self.assertTrue(existing_handle.cancel.called)
        self.assertEqual(conn._conns, {})
        self.assertIsNone(conn._cleanup_handle)

    def test_cleanup2(self):
        testset = {1: [(unittest.mock.Mock(), unittest.mock.Mock(), 300)]}
        testset[1][0][1].is_connected.return_value = True

        loop = unittest.mock.Mock()
        loop.time.return_value = 300.1

        conn = aiohttp.BaseConnector(loop=loop, keepalive_timeout=10)
        conn._conns = testset
        conn._cleanup()
        self.assertEqual(conn._conns, testset)

        self.assertIsNotNone(conn._cleanup_handle)
        loop.call_at.assert_called_with(
            310, conn._cleanup)
        conn.close()

    def test_cleanup3(self):
        testset = {1: [(unittest.mock.Mock(), unittest.mock.Mock(), 290.1),
                       (unittest.mock.Mock(), unittest.mock.Mock(), 305.1)]}
        testset[1][0][1].is_connected.return_value = True

        loop = unittest.mock.Mock()
        loop.time.return_value = 308.5

        conn = aiohttp.BaseConnector(loop=loop, keepalive_timeout=10)
        conn._conns = testset

        conn._cleanup()
        self.assertEqual(conn._conns, {1: [testset[1][1]]})

        self.assertIsNotNone(conn._cleanup_handle)
        loop.call_at.assert_called_with(
            316, conn._cleanup)
        conn.close()

    def test_tcp_connector_ctor(self):
        conn = aiohttp.TCPConnector(loop=self.loop)
        self.assertTrue(conn.verify_ssl)
        self.assertIs(conn.fingerprint, None)

        with self.assertWarns(DeprecationWarning):
            self.assertFalse(conn.resolve)
        self.assertFalse(conn.use_dns_cache)

        self.assertEqual(conn.family, 0)

        with self.assertWarns(DeprecationWarning):
            self.assertEqual(conn.resolved_hosts, {})
        self.assertEqual(conn.resolved_hosts, {})

    def test_tcp_connector_ctor_fingerprint_valid(self):
        valid = b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06='
        conn = aiohttp.TCPConnector(loop=self.loop, fingerprint=valid)
        self.assertEqual(conn.fingerprint, valid)

    def test_tcp_connector_fingerprint_invalid(self):
        invalid = b'\x00'
        with self.assertRaises(ValueError):
            aiohttp.TCPConnector(loop=self.loop, fingerprint=invalid)

    def test_tcp_connector_clear_resolved_hosts(self):
        conn = aiohttp.TCPConnector(loop=self.loop)
        info = object()
        conn._cached_hosts[('localhost', 123)] = info
        conn._cached_hosts[('localhost', 124)] = info
        conn.clear_resolved_hosts('localhost', 123)
        self.assertEqual(
            conn.resolved_hosts, {('localhost', 124): info})
        conn.clear_resolved_hosts('localhost', 123)
        self.assertEqual(
            conn.resolved_hosts, {('localhost', 124): info})
        with self.assertWarns(DeprecationWarning):
            conn.clear_resolved_hosts()
        self.assertEqual(conn.resolved_hosts, {})

    def test_tcp_connector_clear_dns_cache(self):
        conn = aiohttp.TCPConnector(loop=self.loop)
        info = object()
        conn._cached_hosts[('localhost', 123)] = info
        conn._cached_hosts[('localhost', 124)] = info
        conn.clear_dns_cache('localhost', 123)
        self.assertEqual(
            conn.cached_hosts, {('localhost', 124): info})
        conn.clear_dns_cache('localhost', 123)
        self.assertEqual(
            conn.cached_hosts, {('localhost', 124): info})
        conn.clear_dns_cache()
        self.assertEqual(conn.cached_hosts, {})

    def test_tcp_connector_clear_dns_cache_bad_args(self):
        conn = aiohttp.TCPConnector(loop=self.loop)
        with self.assertRaises(ValueError):
            conn.clear_dns_cache('localhost')

    def test_ambigous_verify_ssl_and_ssl_context(self):
        with self.assertRaises(ValueError):
            aiohttp.TCPConnector(
                verify_ssl=False,
                ssl_context=ssl.SSLContext(ssl.PROTOCOL_SSLv23),
                loop=self.loop)

    def test_dont_recreate_ssl_context(self):
        conn = aiohttp.TCPConnector(loop=self.loop)
        ctx = conn.ssl_context
        self.assertIs(ctx, conn.ssl_context)

    def test_respect_precreated_ssl_context(self):
        ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        conn = aiohttp.TCPConnector(loop=self.loop, ssl_context=ctx)
        self.assertIs(ctx, conn.ssl_context)

    def test_close_twice(self):
        tr = unittest.mock.Mock()

        conn = aiohttp.BaseConnector(loop=self.loop)
        conn._conns[1] = [(tr, object(), object())]
        conn.close()

        self.assertFalse(conn._conns)
        self.assertTrue(tr.close.called)
        self.assertTrue(conn.closed)

        conn._conns = 'Invalid'  # fill with garbage
        conn.close()
        self.assertTrue(conn.closed)

    def test_close_cancels_cleanup_handle(self):
        conn = aiohttp.BaseConnector(loop=self.loop)
        conn._start_cleanup_task()

        self.assertIsNotNone(conn._cleanup_handle)
        conn.close()
        self.assertIsNone(conn._cleanup_handle)

    def test_ctor_with_default_loop(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self.addCleanup(loop.close)
        self.addCleanup(asyncio.set_event_loop, None)
        conn = aiohttp.BaseConnector()
        self.assertIs(loop, conn._loop)

    def test_connect_with_limit(self):

        @asyncio.coroutine
        def go():
            tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
            proto.is_connected.return_value = True

            class Req:
                host = 'host'
                port = 80
                ssl = False
                response = unittest.mock.Mock()

            conn = aiohttp.BaseConnector(loop=self.loop, limit=1)
            key = ('host', 80, False)
            conn._conns[key] = [(tr, proto, self.loop.time())]
            conn._create_connection = unittest.mock.Mock()
            conn._create_connection.return_value = asyncio.Future(
                loop=self.loop)
            conn._create_connection.return_value.set_result((tr, proto))

            connection1 = yield from conn.connect(Req())
            self.assertEqual(connection1._transport, tr)

            self.assertEqual(1, len(conn._acquired[key]))

            acquired = False

            @asyncio.coroutine
            def f():
                nonlocal acquired
                connection2 = yield from conn.connect(Req())
                acquired = True
                self.assertEqual(1, len(conn._acquired[key]))
                connection2.release()

            task = asyncio.async(f(), loop=self.loop)

            yield from asyncio.sleep(0.01, loop=self.loop)
            self.assertFalse(acquired)
            connection1.release()
            yield from asyncio.sleep(0, loop=self.loop)
            self.assertTrue(acquired)
            yield from task
            conn.close()

        self.loop.run_until_complete(go())

    def test_connect_with_limit_cancelled(self):

        @asyncio.coroutine
        def go():
            tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
            proto.is_connected.return_value = True

            class Req:
                host = 'host'
                port = 80
                ssl = False
                response = unittest.mock.Mock()

            conn = aiohttp.BaseConnector(loop=self.loop, limit=1)
            key = ('host', 80, False)
            conn._conns[key] = [(tr, proto, self.loop.time())]
            conn._create_connection = unittest.mock.Mock()
            conn._create_connection.return_value = asyncio.Future(
                loop=self.loop)
            conn._create_connection.return_value.set_result((tr, proto))

            connection = yield from conn.connect(Req())
            self.assertEqual(connection._transport, tr)

            self.assertEqual(1, len(conn._acquired[key]))

            with self.assertRaises(asyncio.TimeoutError):
                # limit exhausted
                yield from asyncio.wait_for(conn.connect(Req), 0.01,
                                            loop=self.loop)

            connection.close()

        self.loop.run_until_complete(go())

    def test_connect_with_limit_concurrent(self):

        @asyncio.coroutine
        def go():
            proto = unittest.mock.Mock()
            proto.is_connected.return_value = True

            class Req:
                host = 'host'
                port = 80
                ssl = False
                response = unittest.mock.Mock(_should_close=False)

            max_connections = 2
            num_connections = 0

            conn = aiohttp.BaseConnector(limit=max_connections, loop=self.loop)

            # Use a real coroutine for _create_connection; a mock would mask
            # problems that only happen when the method yields.

            @asyncio.coroutine
            def create_connection(req):
                nonlocal num_connections
                num_connections += 1
                yield from asyncio.sleep(0, loop=self.loop)

                # Make a new transport mock each time because acquired
                # transports are stored in a set. Reusing the same object
                # messes with the count.
                tr = unittest.mock.Mock()

                return tr, proto

            conn._create_connection = create_connection

            # Simulate something like a crawler. It opens a connection, does
            # something with it, closes it, then creates tasks that make more
            # connections and waits for them to finish. The crawler is started
            # with multiple concurrent requests and stops when it hits a
            # predefined maximum number of requests.

            max_requests = 10
            num_requests = 0
            start_requests = max_connections + 1

            @asyncio.coroutine
            def f(start=True):
                nonlocal num_requests
                if num_requests == max_requests:
                    return
                num_requests += 1
                if not start:
                    connection = yield from conn.connect(Req())
                    yield from asyncio.sleep(0, loop=self.loop)
                    connection.release()
                tasks = [
                    asyncio.async(f(start=False), loop=self.loop)
                    for i in range(start_requests)
                ]
                yield from asyncio.wait(tasks, loop=self.loop)

            yield from f()
            conn.close()

            self.assertEqual(max_connections, num_connections)
Example #23
0
class TestBaseConnector(unittest.TestCase):

    def setUp(self):
        """Create a private event loop plus shared mock fixtures."""
        loop = asyncio.new_event_loop()
        self.loop = loop
        # Unset the global default loop; tests pass ``self.loop`` explicitly.
        asyncio.set_event_loop(None)

        self.transport = unittest.mock.Mock()
        self.stream = aiohttp.StreamParser()
        response = ClientResponse('get', 'http://base-conn.org')
        self.response = response
        response._post_init(loop)

    def tearDown(self):
        # Close the shared response, shut down this test's loop, then force
        # a garbage-collection pass before the next test starts.
        self.response.close()
        self.loop.close()
        gc.collect()

    @unittest.skipUnless(PY_341, "Requires Python 3.4.1+")
    def test_del(self):
        """Finalizing an unclosed connector that still pools connections
        emits a ResourceWarning, closes the pooled transports and reports
        'Unclosed connector' to the loop's exception handler."""
        conn = aiohttp.BaseConnector(loop=self.loop)
        transp = unittest.mock.Mock()
        conn._conns['a'] = [(transp, 'proto', 123)]
        # Keep the pool dict alive so it can be inspected after ``conn``
        # has been finalized.
        conns_impl = conn._conns

        exc_handler = unittest.mock.Mock()
        self.loop.set_exception_handler(exc_handler)

        # Dropping the last reference and collecting triggers finalization.
        with self.assertWarns(ResourceWarning):
            del conn
            gc.collect()

        self.assertFalse(conns_impl)
        transp.close.assert_called_with()
        msg = {'connector': unittest.mock.ANY,  # conn was deleted
               'message': 'Unclosed connector'}
        # In debug mode the report also carries the creation traceback.
        if self.loop.get_debug():
            msg['source_traceback'] = unittest.mock.ANY
        exc_handler.assert_called_with(self.loop, msg)

    @unittest.skipUnless(PY_341, "Requires Python 3.4.1+")
    def test_del_with_scheduled_cleanup(self):
        """Intended to check finalization of an unclosed connector that has
        a pending cleanup handle."""
        # NOTE(review): the ``yield from`` below makes this method a
        # generator function. unittest calls it and receives a generator
        # object that is never iterated, so nothing in this body -- not even
        # the assertions -- actually executes and the test passes
        # vacuously. Running it for real would presumably also need
        # ``loop=self.loop`` on ``asyncio.sleep`` (setUp unsets the default
        # loop). TODO confirm the intended expectations before driving it
        # through ``self.loop.run_until_complete``.
        conn = aiohttp.BaseConnector(loop=self.loop, keepalive_timeout=0.01)
        transp = unittest.mock.Mock()
        conn._conns['a'] = [(transp, 'proto', 123)]

        conns_impl = conn._conns
        conn._start_cleanup_task()
        exc_handler = unittest.mock.Mock()
        self.loop.set_exception_handler(exc_handler)

        with self.assertWarns(ResourceWarning):
            del conn
            yield from asyncio.sleep(0.01)
            gc.collect()

        self.assertFalse(conns_impl)
        transp.close.assert_called_with()
        msg = {'connector': unittest.mock.ANY,  # conn was deleted
               'message': 'Unclosed connector'}
        if self.loop.get_debug():
            msg['source_traceback'] = unittest.mock.ANY
        exc_handler.assert_called_with(self.loop, msg)

    @unittest.skipUnless(PY_341, "Requires Python 3.4.1+")
    def test_del_with_closed_loop(self):
        """With the loop already closed, finalizing an unclosed connector
        still warns and calls the exception handler and empties the pool,
        but does not close the pooled transports."""
        conn = aiohttp.BaseConnector(loop=self.loop)
        transp = unittest.mock.Mock()
        conn._conns['a'] = [(transp, 'proto', 123)]

        conns_impl = conn._conns
        conn._start_cleanup_task()
        exc_handler = unittest.mock.Mock()
        self.loop.set_exception_handler(exc_handler)
        # Close the loop before the connector is finalized.
        self.loop.close()

        with self.assertWarns(ResourceWarning):
            del conn
            gc.collect()

        self.assertFalse(conns_impl)
        self.assertFalse(transp.close.called)
        self.assertTrue(exc_handler.called)

    @unittest.skipUnless(PY_341, "Requires Python 3.4.1+")
    def test_del_empty_conector(self):
        """A connector with no pooled connections is finalized without
        calling the loop's exception handler."""
        conn = aiohttp.BaseConnector(loop=self.loop)

        exc_handler = unittest.mock.Mock()
        self.loop.set_exception_handler(exc_handler)

        del conn

        self.assertFalse(exc_handler.called)

    def test_create_conn(self):
        """BaseConnector leaves ``_create_connection`` unimplemented."""

        # Decorated like every other ``go`` coroutine in this file.
        @asyncio.coroutine
        def go():
            conn = aiohttp.BaseConnector(loop=self.loop)
            with self.assertRaises(NotImplementedError):
                yield from conn._create_connection(object())

        self.loop.run_until_complete(go())

    @unittest.mock.patch('aiohttp.connector.asyncio')
    def test_ctor_loop(self, asyncio):
        """Without a ``loop`` argument the connector falls back to
        ``asyncio.get_event_loop()`` (patched here)."""
        connector = aiohttp.BaseConnector()
        default_loop = asyncio.get_event_loop.return_value
        self.assertIs(connector._loop, default_loop)

    def test_close(self):
        """close() closes pooled transports, empties the pool and marks the
        connector as closed."""
        transport = unittest.mock.Mock()

        connector = aiohttp.BaseConnector(loop=self.loop)
        self.assertFalse(connector.closed)

        connector._conns[1] = [(transport, object(), object())]
        connector.close()

        self.assertFalse(connector._conns)
        self.assertTrue(transport.close.called)
        self.assertTrue(connector.closed)

    def test_get(self):
        """_get returns (None, None) for an unknown key and the pooled
        (transport, protocol) pair for a fresh entry."""
        connector = aiohttp.BaseConnector(loop=self.loop)
        missing = connector._get(1)
        self.assertEqual(missing, (None, None))

        transport = unittest.mock.Mock()
        protocol = unittest.mock.Mock()
        connector._conns[1] = [(transport, protocol, self.loop.time())]
        self.assertEqual(connector._get(1), (transport, protocol))
        connector.close()

    def test_get_expired(self):
        """An entry timestamped far in the past is not returned by _get and
        is removed from the pool."""
        connector = aiohttp.BaseConnector(loop=self.loop)
        missing = connector._get(1)
        self.assertEqual(missing, (None, None))

        transport = unittest.mock.Mock()
        protocol = unittest.mock.Mock()
        long_ago = self.loop.time() - 1000
        connector._conns[1] = [(transport, protocol, long_ago)]

        self.assertEqual(connector._get(1), (None, None))
        self.assertFalse(connector._conns)
        connector.close()

    def test_release(self):
        """Releasing a reusable connection pools it with the current loop
        time and starts the cleanup task."""
        # Freeze the loop clock so the pooled timestamp is predictable.
        # ``unittest.mock`` is used here like everywhere else in the class.
        self.loop.time = unittest.mock.Mock(return_value=10)

        conn = aiohttp.BaseConnector(loop=self.loop)
        conn._start_cleanup_task = unittest.mock.Mock()
        req = unittest.mock.Mock()
        resp = req.response = unittest.mock.Mock()
        resp._should_close = False

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        key = 1
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertEqual(conn._conns[1][0], (tr, proto, 10))
        self.assertTrue(conn._start_cleanup_task.called)
        conn.close()

    def test_release_close(self):
        """With ``message.should_close`` set on the response, _release
        closes the transport instead of pooling it."""
        # ``share_cookies`` is deprecated and warns on construction.
        with self.assertWarns(DeprecationWarning):
            connector = aiohttp.BaseConnector(share_cookies=True,
                                              loop=self.loop)

        response = unittest.mock.Mock()
        response.message.should_close = True
        jar = response.cookies = http.cookies.SimpleCookie()
        jar['c1'] = 'cookie1'
        jar['c2'] = 'cookie2'

        request = unittest.mock.Mock()
        request.response = response

        transport = unittest.mock.Mock()
        protocol = unittest.mock.Mock()
        key = 1
        connector._acquired[key].add(transport)
        connector._release(key, request, transport, protocol)

        self.assertFalse(connector._conns)
        self.assertTrue(transport.close.called)

    def test_get_pop_empty_conns(self):
        """An empty pool list for a key is removed by _get (issue #473)."""
        key = ('127.0.0.1', 80, False)
        connector = aiohttp.BaseConnector(loop=self.loop)
        connector._conns[key] = []

        transport, protocol = connector._get(key)

        self.assertEqual((None, None), (transport, protocol))
        self.assertFalse(connector._conns)

    def test_release_close_do_not_add_to_pool(self):
        """A connection released with ``message.should_close`` set is not
        added to the pool (issue #473)."""
        key = ('127.0.0.1', 80, False)
        connector = aiohttp.BaseConnector(loop=self.loop)

        response = unittest.mock.Mock()
        response.message.should_close = True
        request = unittest.mock.Mock()
        request.response = response

        transport = unittest.mock.Mock()
        protocol = unittest.mock.Mock()
        connector._acquired[key].add(transport)
        connector._release(key, request, transport, protocol)

        self.assertFalse(connector._conns)

    def test_release_close_do_not_delete_existing_connections(self):
        """Closing a released connection leaves other pooled connections
        for the same key untouched."""
        key = ('127.0.0.1', 80, False)
        pooled_tr = unittest.mock.Mock()
        pooled_proto = unittest.mock.Mock()

        with self.assertWarns(DeprecationWarning):
            connector = aiohttp.BaseConnector(share_cookies=True,
                                              loop=self.loop)
        connector._conns[key] = [(pooled_tr, pooled_proto, 1)]

        response = unittest.mock.Mock()
        response.message.should_close = True
        request = unittest.mock.Mock()
        request.response = response

        transport = unittest.mock.Mock()
        protocol = unittest.mock.Mock()
        # NOTE(review): this marks the *pooled* transport as acquired rather
        # than the one being released; possibly a typo -- confirm.
        connector._acquired[key].add(pooled_tr)
        connector._release(key, request, transport, protocol)

        self.assertEqual(connector._conns[key],
                         [(pooled_tr, pooled_proto, 1)])
        self.assertTrue(transport.close.called)
        connector.close()

    def test_release_not_started(self):
        """A request without a response is pooled (with the current loop
        time) and its transport is left open."""
        # Freeze the loop clock; ``unittest.mock`` is used here like
        # everywhere else in the class.
        self.loop.time = unittest.mock.Mock(return_value=10)

        conn = aiohttp.BaseConnector(loop=self.loop)
        req = unittest.mock.Mock()
        req.response = None

        tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
        key = 1
        conn._acquired[key].add(tr)
        conn._release(key, req, tr, proto)
        self.assertEqual(conn._conns, {1: [(tr, proto, 10)]})
        self.assertFalse(tr.close.called)
        conn.close()

    def test_release_not_opened(self):
        """A response whose ``message`` is None causes the transport to be
        closed on release."""
        connector = aiohttp.BaseConnector(loop=self.loop)
        request = unittest.mock.Mock()
        request.response = unittest.mock.Mock()
        request.response.message = None

        transport = unittest.mock.Mock()
        protocol = unittest.mock.Mock()
        connector._acquired[1].add(transport)
        connector._release(1, request, transport, protocol)

        self.assertTrue(transport.close.called)

    def test_connect(self):
        """connect() reuses a pooled connection for the request's
        (host, port, ssl) key without calling _create_connection."""

        class Req:
            host = 'host'
            port = 80
            ssl = False
            response = unittest.mock.Mock()

        transport = unittest.mock.Mock()
        protocol = unittest.mock.Mock()
        protocol.is_connected.return_value = True

        connector = aiohttp.BaseConnector(loop=self.loop)
        connector._conns[('host', 80, False)] = [
            (transport, protocol, self.loop.time())]

        created = asyncio.Future(loop=self.loop)
        created.set_result((transport, protocol))
        connector._create_connection = unittest.mock.Mock(
            return_value=created)

        connection = self.loop.run_until_complete(connector.connect(Req()))

        self.assertFalse(connector._create_connection.called)
        self.assertEqual(connection._transport, transport)
        self.assertEqual(connection._protocol, protocol)
        self.assertIsInstance(connection, Connection)
        connection.close()

    def test_connect_timeout(self):
        """A TimeoutError from _create_connection surfaces as
        aiohttp.ClientTimeoutError."""
        connector = aiohttp.BaseConnector(loop=self.loop)
        failed = asyncio.Future(loop=self.loop)
        failed.set_exception(asyncio.TimeoutError())
        connector._create_connection = unittest.mock.Mock(
            return_value=failed)

        request = unittest.mock.Mock()
        with self.assertRaises(aiohttp.ClientTimeoutError):
            self.loop.run_until_complete(connector.connect(request))

    def test_connect_oserr(self):
        """An OSError from _create_connection is re-raised as
        aiohttp.ClientOSError with the same errno and a 'Cannot connect to'
        message ending in the original strerror."""
        connector = aiohttp.BaseConnector(loop=self.loop)
        error = OSError(1, 'permission error')
        failed = asyncio.Future(loop=self.loop)
        failed.set_exception(error)
        connector._create_connection = unittest.mock.Mock(
            return_value=failed)

        request = unittest.mock.Mock()
        with self.assertRaises(aiohttp.ClientOSError) as ctx:
            self.loop.run_until_complete(connector.connect(request))

        raised = ctx.exception
        self.assertEqual(1, raised.errno)
        self.assertTrue(raised.strerror.startswith('Cannot connect to'))
        self.assertTrue(raised.strerror.endswith('[permission error]'))

    def test_start_cleanup_task(self):
        """``_start_cleanup_task`` registers ``_cleanup`` via ``call_at``.

        With a fake clock at 1.5 and ``keepalive_timeout=10`` the callback
        is expected to be scheduled for time 12.
        """
        fake_loop = unittest.mock.Mock()
        fake_loop.time.return_value = 1.5
        connector = aiohttp.BaseConnector(loop=fake_loop, keepalive_timeout=10)
        self.assertIsNone(connector._cleanup_handle)

        connector._start_cleanup_task()

        self.assertIsNotNone(connector._cleanup_handle)
        fake_loop.call_at.assert_called_with(12, connector._cleanup)

    def test_cleanup(self):
        """``_cleanup`` removes an old entry, a disconnected entry and an
        entry without a transport; with nothing left it cancels the previous
        handle and schedules no new one."""
        old_entry = (unittest.mock.Mock(), unittest.mock.Mock(), 10)
        disconnected = (unittest.mock.Mock(), unittest.mock.Mock(), 300)
        no_transport = (None, unittest.mock.Mock(), 300)
        old_entry[1].is_connected.return_value = True
        disconnected[1].is_connected.return_value = False
        testset = {1: [old_entry, disconnected, no_transport]}

        fake_loop = unittest.mock.Mock()
        fake_loop.time.return_value = 300
        connector = aiohttp.BaseConnector(loop=fake_loop)
        connector._conns = testset
        previous_handle = unittest.mock.Mock()
        connector._cleanup_handle = previous_handle

        connector._cleanup()

        self.assertTrue(previous_handle.cancel.called)
        self.assertEqual(connector._conns, {})
        self.assertIsNone(connector._cleanup_handle)

    def test_cleanup2(self):
        """A connected entry stamped 300 survives ``_cleanup`` at clock 300.1
        (``keepalive_timeout=10``) and the next run is expected at 310."""
        transport, protocol = unittest.mock.Mock(), unittest.mock.Mock()
        protocol.is_connected.return_value = True
        testset = {1: [(transport, protocol, 300)]}

        fake_loop = unittest.mock.Mock()
        fake_loop.time.return_value = 300.1

        connector = aiohttp.BaseConnector(loop=fake_loop, keepalive_timeout=10)
        connector._conns = testset
        connector._cleanup()
        self.assertEqual(connector._conns, testset)

        self.assertIsNotNone(connector._cleanup_handle)
        fake_loop.call_at.assert_called_with(310, connector._cleanup)
        connector.close()

    def test_cleanup3(self):
        """At clock 308.5 with ``keepalive_timeout=10`` the entry stamped
        290.1 is evicted, the one stamped 305.1 is kept, and the next
        cleanup is expected at time 316."""
        expired = (unittest.mock.Mock(), unittest.mock.Mock(), 290.1)
        fresh = (unittest.mock.Mock(), unittest.mock.Mock(), 305.1)
        expired[1].is_connected.return_value = True

        fake_loop = unittest.mock.Mock()
        fake_loop.time.return_value = 308.5

        connector = aiohttp.BaseConnector(loop=fake_loop, keepalive_timeout=10)
        connector._conns = {1: [expired, fresh]}

        connector._cleanup()
        self.assertEqual(connector._conns, {1: [fresh]})

        self.assertIsNotNone(connector._cleanup_handle)
        fake_loop.call_at.assert_called_with(316, connector._cleanup)
        connector.close()

    def test_tcp_connector_ctor(self):
        """A default ``TCPConnector`` has SSL verification on, no
        fingerprint, DNS caching off, family AF_INET and an empty host cache.

        Each deprecated alias (``resolve``, ``resolved_hosts``) is read under
        ``assertWarns`` and then its replacement (``use_dns_cache``,
        ``cached_hosts``) is checked directly.
        """
        conn = aiohttp.TCPConnector(loop=self.loop)
        self.assertTrue(conn.verify_ssl)
        self.assertIs(conn.fingerprint, None)

        with self.assertWarns(DeprecationWarning):
            self.assertFalse(conn.resolve)
        self.assertFalse(conn.use_dns_cache)

        self.assertEqual(conn.family, socket.AF_INET)

        with self.assertWarns(DeprecationWarning):
            self.assertEqual(conn.resolved_hosts, {})
        # Check the non-deprecated replacement, mirroring the
        # resolve/use_dns_cache pair above.
        self.assertEqual(conn.cached_hosts, {})

    def test_tcp_connector_ctor_fingerprint_valid(self):
        """A 16-byte fingerprint is accepted and exposed unchanged."""
        expected = b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06='
        connector = aiohttp.TCPConnector(fingerprint=expected, loop=self.loop)
        self.assertEqual(connector.fingerprint, expected)

    def test_tcp_connector_fingerprint_invalid(self):
        """A one-byte fingerprint is rejected with ValueError."""
        self.assertRaises(ValueError, aiohttp.TCPConnector,
                          loop=self.loop, fingerprint=b'\x00')

    def test_tcp_connector_clear_resolved_hosts(self):
        """``clear_resolved_hosts(host, port)`` removes only that entry (a
        repeat is harmless); the no-argument form warns as deprecated and
        empties the cache."""
        connector = aiohttp.TCPConnector(loop=self.loop)
        marker = object()
        for port in (123, 124):
            connector._cached_hosts[('localhost', port)] = marker
        only_124 = {('localhost', 124): marker}

        for _ in range(2):
            connector.clear_resolved_hosts('localhost', 123)
            self.assertEqual(connector.resolved_hosts, only_124)

        with self.assertWarns(DeprecationWarning):
            connector.clear_resolved_hosts()
        self.assertEqual(connector.resolved_hosts, {})

    def test_tcp_connector_clear_dns_cache(self):
        """``clear_dns_cache(host, port)`` removes only that entry (a repeat
        is harmless); with no arguments it empties the cache."""
        connector = aiohttp.TCPConnector(loop=self.loop)
        marker = object()
        for port in (123, 124):
            connector._cached_hosts[('localhost', port)] = marker
        remaining = {('localhost', 124): marker}

        for _ in range(2):
            connector.clear_dns_cache('localhost', 123)
            self.assertEqual(connector.cached_hosts, remaining)

        connector.clear_dns_cache()
        self.assertEqual(connector.cached_hosts, {})

    def test_tcp_connector_clear_dns_cache_bad_args(self):
        """Passing a host without a port raises ValueError."""
        connector = aiohttp.TCPConnector(loop=self.loop)
        self.assertRaises(ValueError, connector.clear_dns_cache, 'localhost')

    def test_ambigous_verify_ssl_and_ssl_context(self):
        """Combining ``verify_ssl=False`` with an explicit ``ssl_context``
        is rejected with ValueError."""
        with self.assertRaises(ValueError):
            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            aiohttp.TCPConnector(loop=self.loop, verify_ssl=False,
                                 ssl_context=context)

    def test_dont_recreate_ssl_context(self):
        """Reading ``ssl_context`` twice yields the very same object."""
        connector = aiohttp.TCPConnector(loop=self.loop)
        first = connector.ssl_context
        second = connector.ssl_context
        self.assertIs(first, second)

    def test_respect_precreated_ssl_context(self):
        """A user-supplied ``ssl_context`` is returned as-is."""
        supplied = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        connector = aiohttp.TCPConnector(ssl_context=supplied, loop=self.loop)
        self.assertIs(supplied, connector.ssl_context)

    def test_close_twice(self):
        """``close()`` closes pooled transports and empties the pool.

        A second ``close()`` must succeed even when ``_conns`` holds garbage.
        """
        pooled_transport = unittest.mock.Mock()

        connector = aiohttp.BaseConnector(loop=self.loop)
        connector._conns[1] = [(pooled_transport, object(), object())]

        connector.close()
        self.assertFalse(connector._conns)
        self.assertTrue(pooled_transport.close.called)
        self.assertTrue(connector.closed)

        connector._conns = 'Invalid'  # fill with garbage
        connector.close()
        self.assertTrue(connector.closed)

    def test_close_cancels_cleanup_handle(self):
        """``close()`` clears a previously scheduled cleanup handle."""
        connector = aiohttp.BaseConnector(loop=self.loop)
        connector._start_cleanup_task()
        self.assertIsNotNone(connector._cleanup_handle)

        connector.close()

        self.assertIsNone(connector._cleanup_handle)

    def test_ctor_with_default_loop(self):
        """Without ``loop=``, the connector binds to the current event loop."""
        default_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(default_loop)
        # Cleanups run LIFO: the global loop is reset first, then ours closed.
        self.addCleanup(default_loop.close)
        self.addCleanup(asyncio.set_event_loop, None)

        connector = aiohttp.BaseConnector()
        self.assertIs(default_loop, connector._loop)

    def test_connect_with_limit(self):
        """With ``limit=1`` a second ``connect()`` for the same key waits
        until the first connection is released, and ``_acquired`` never
        holds more than one transport for that key."""

        @asyncio.coroutine
        def go():
            tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
            proto.is_connected.return_value = True

            class Req:
                host = 'host'
                port = 80
                ssl = False
                response = unittest.mock.Mock()

            conn = aiohttp.BaseConnector(loop=self.loop, limit=1)
            key = ('host', 80, False)
            conn._conns[key] = [(tr, proto, self.loop.time())]
            conn._create_connection = unittest.mock.Mock()
            conn._create_connection.return_value = asyncio.Future(
                loop=self.loop)
            conn._create_connection.return_value.set_result((tr, proto))

            connection1 = yield from conn.connect(Req())
            self.assertEqual(connection1._transport, tr)

            self.assertEqual(1, len(conn._acquired[key]))

            acquired = False

            @asyncio.coroutine
            def f():
                nonlocal acquired
                connection2 = yield from conn.connect(Req())
                acquired = True
                self.assertEqual(1, len(conn._acquired[key]))
                connection2.release()

            # ``asyncio.async`` is a SyntaxError from Python 3.7 on (``async``
            # became a keyword); ``ensure_future`` replaced it in 3.4.4.  The
            # getattr fallback keeps older interpreters working without
            # spelling the keyword in source.
            ensure_future = getattr(asyncio, 'ensure_future', None)
            if ensure_future is None:
                ensure_future = getattr(asyncio, 'async')
            task = ensure_future(f(), loop=self.loop)

            # f() must still be parked on the limit after a short wait.
            yield from asyncio.sleep(0.01, loop=self.loop)
            self.assertFalse(acquired)
            connection1.release()
            # One loop iteration after the release, f() has acquired a slot.
            yield from asyncio.sleep(0, loop=self.loop)
            self.assertTrue(acquired)
            yield from task
            conn.close()

        self.loop.run_until_complete(go())

    def test_connect_with_limit_cancelled(self):
        """With ``limit=1`` and the only slot held, a second ``connect()``
        keeps waiting, so ``wait_for`` with a short timeout raises
        ``asyncio.TimeoutError``."""

        @asyncio.coroutine
        def go():
            tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
            proto.is_connected.return_value = True

            class Req:
                host = 'host'
                port = 80
                ssl = False
                response = unittest.mock.Mock()

            conn = aiohttp.BaseConnector(loop=self.loop, limit=1)
            key = ('host', 80, False)
            conn._conns[key] = [(tr, proto, self.loop.time())]
            conn._create_connection = unittest.mock.Mock()
            conn._create_connection.return_value = asyncio.Future(
                loop=self.loop)
            conn._create_connection.return_value.set_result((tr, proto))

            connection = yield from conn.connect(Req())
            self.assertEqual(connection._transport, tr)

            self.assertEqual(1, len(conn._acquired[key]))

            with self.assertRaises(asyncio.TimeoutError):
                # limit exhausted: this connect() cannot obtain a slot.
                # Pass a request instance, like every other connect() call.
                yield from asyncio.wait_for(conn.connect(Req()), 0.01,
                                            loop=self.loop)

            connection.close()

        self.loop.run_until_complete(go())

    def test_close_with_acquired_connection(self):
        """Closing the connector while a connection is checked out empties
        ``_acquired`` and closes the transport.

        The ``Connection`` object itself reports open until it is closed
        explicitly.
        """

        @asyncio.coroutine
        def scenario():
            transport = unittest.mock.Mock()
            protocol = unittest.mock.Mock()
            protocol.is_connected.return_value = True

            class Req:
                host = 'host'
                port = 80
                ssl = False
                response = unittest.mock.Mock()

            connector = aiohttp.BaseConnector(loop=self.loop, limit=1)
            pool_key = ('host', 80, False)
            connector._conns[pool_key] = [
                (transport, protocol, self.loop.time())]
            ready = asyncio.Future(loop=self.loop)
            ready.set_result((transport, protocol))
            connector._create_connection = unittest.mock.Mock(
                return_value=ready)

            connection = yield from connector.connect(Req())

            self.assertEqual(1, len(connector._acquired))
            connector.close()
            self.assertEqual(0, len(connector._acquired))
            self.assertTrue(connector.closed)
            transport.close.assert_called_with()

            self.assertFalse(connection.closed)
            connection.close()
            self.assertTrue(connection.closed)

        self.loop.run_until_complete(scenario())

    def test_default_force_close(self):
        """``force_close`` is off unless explicitly requested."""
        self.assertFalse(aiohttp.BaseConnector(loop=self.loop).force_close)

    def test_limit_property(self):
        """``limit`` reports the value given to the constructor."""
        connector = aiohttp.BaseConnector(limit=15, loop=self.loop)
        self.assertEqual(15, connector.limit)
        connector.close()

    def test_limit_property_default(self):
        """Without ``limit=`` the connector reports no limit (``None``)."""
        connector = aiohttp.BaseConnector(loop=self.loop)
        self.assertIsNone(connector.limit)
        connector.close()
Example #24
0
class ClientResponseTests(unittest.TestCase):
    """Tests for ``ClientResponse`` with a stubbed connection and content.

    Body-reading tests share ``_stub_single_read``: ``content.read()`` yields
    one payload and then raises ``aiohttp.EofStream``.
    """

    # Non-ASCII JSON document and its cp1251 encoding, used by the
    # text/json decoding tests.
    _JSON_TEXT = '{"тест": "пройден"}'
    _JSON_CP1251 = _JSON_TEXT.encode('cp1251')

    def setUp(self):
        self.loop = asyncio.new_event_loop()
        # The global loop is unset, presumably so nothing under test can
        # fall back to it implicitly.
        asyncio.set_event_loop(None)

        self.connection = unittest.mock.Mock()
        self.stream = aiohttp.StreamParser(loop=self.loop)
        self.response = ClientResponse('get', 'http://python.org')

    def tearDown(self):
        self.loop.close()

    def _stub_single_read(self, payload):
        """Make ``self.response.content.read()`` deliver *payload* once.

        The first call returns a completed future holding *payload* and
        re-arms the mock so that every later call raises
        ``aiohttp.EofStream``.  ``self.response.close`` is replaced by a mock
        so callers can assert on it.  Returns the content mock.
        """
        content = self.response.content = unittest.mock.Mock()

        def raise_eof(*args, **kwargs):
            raise aiohttp.EofStream

        def first_read(*args, **kwargs):
            ready = asyncio.Future(loop=self.loop)
            ready.set_result(payload)
            content.read.side_effect = raise_eof
            return ready

        content.read.side_effect = first_read
        self.response.close = unittest.mock.Mock()
        return content

    def test_del(self):
        """Dropping a response that still owns a connection emits
        ResourceWarning and closes that connection."""
        resp = ClientResponse('get', 'http://python.org')

        conn = unittest.mock.Mock()
        resp._setup_connection(conn)
        with self.assertWarns(ResourceWarning):
            del resp

        conn.close.assert_called_with()

    def test_close(self):
        """``close()`` releases and forgets the connection; further calls
        are harmless."""
        self.response.connection = self.connection
        self.response.close()
        self.assertIsNone(self.response.connection)
        self.assertTrue(self.connection.release.called)
        for _ in range(2):
            self.response.close()

    def test_wait_for_100(self):
        """``waiting_for_continue()`` reflects whether ``continue100`` was
        supplied."""
        expecting = ClientResponse(
            'get', 'http://python.org', continue100=object())
        self.assertTrue(expecting.waiting_for_continue())
        not_expecting = ClientResponse('get', 'http://python.org')
        self.assertFalse(not_expecting.waiting_for_continue())

    def test_repr(self):
        self.response.status, self.response.reason = 200, 'Ok'
        expected = '<ClientResponse(http://python.org) [200 Ok]>'
        self.assertIn(expected, repr(self.response))

    def test_read_and_release_connection(self):
        """``read()`` returns the body and closes the response."""
        self._stub_single_read(b'payload')

        body = self.loop.run_until_complete(self.response.read())
        self.assertEqual(body, b'payload')
        self.assertTrue(self.response.close.called)

    def test_read_and_release_connection_with_error(self):
        """A failing ``content.read()`` propagates, and ``close(True)`` is
        called."""
        content = self.response.content = unittest.mock.Mock()
        failure = asyncio.Future(loop=self.loop)
        failure.set_exception(ValueError)
        content.read.return_value = failure
        self.response.close = unittest.mock.Mock()

        with self.assertRaises(ValueError):
            self.loop.run_until_complete(self.response.read())
        self.response.close.assert_called_with(True)

    def test_release(self):
        """``release()`` with ``readany()`` returning ``b''`` closes the
        response."""
        exhausted = asyncio.Future(loop=self.loop)
        exhausted.set_result(b'')
        content = self.response.content = unittest.mock.Mock()
        content.readany.return_value = exhausted
        self.response.close = unittest.mock.Mock()

        self.loop.run_until_complete(self.response.release())
        self.assertTrue(self.response.close.called)

    def test_read_and_close(self):
        """Deprecated ``read_and_close()`` warns and returns ``read()``'s
        result."""
        result = asyncio.Future(loop=self.loop)
        result.set_result(b'data')
        self.response.read = unittest.mock.Mock(return_value=result)

        with self.assertWarns(DeprecationWarning):
            res = self.loop.run_until_complete(self.response.read_and_close())
        self.assertEqual(res, b'data')
        self.assertTrue(self.response.read.called)

    def test_read_decode_deprecated(self):
        """``read(decode=True)`` warns and returns ``json()``'s result."""
        self.response._content = b'data'
        decoded = asyncio.Future(loop=self.loop)
        decoded.set_result('json')
        self.response.json = unittest.mock.Mock(return_value=decoded)

        with self.assertWarns(DeprecationWarning):
            res = self.loop.run_until_complete(self.response.read(decode=True))
        self.assertEqual(res, 'json')
        self.assertTrue(self.response.json.called)

    def test_text(self):
        """``text()`` decodes with the charset named in Content-Type."""
        self.response.headers = {
            'CONTENT-TYPE': 'application/json;charset=cp1251'}
        self._stub_single_read(self._JSON_CP1251)

        text = self.loop.run_until_complete(self.response.text())
        self.assertEqual(text, self._JSON_TEXT)
        self.assertTrue(self.response.close.called)

    def test_text_custom_encoding(self):
        """An explicit ``encoding=`` is used when Content-Type has none."""
        self.response.headers = {'CONTENT-TYPE': 'application/json'}
        self._stub_single_read(self._JSON_CP1251)

        text = self.loop.run_until_complete(
            self.response.text(encoding='cp1251'))
        self.assertEqual(text, self._JSON_TEXT)
        self.assertTrue(self.response.close.called)

    def test_text_detect_encoding(self):
        """Without a charset, ``text()`` succeeds only if chardet exists."""
        self.response.headers = {'CONTENT-TYPE': 'application/json'}
        self._stub_single_read(self._JSON_CP1251)

        if chardet is None:
            # Without chardet the cp1251 bytes are expected to fail decoding.
            with self.assertRaises(UnicodeDecodeError):
                self.loop.run_until_complete(self.response.text())
            return

        text = self.loop.run_until_complete(self.response.text())
        self.assertEqual(text, self._JSON_TEXT)
        self.assertTrue(self.response.close.called)

    def test_json(self):
        """``json()`` decodes with the Content-Type charset and parses."""
        self.response.headers = {
            'CONTENT-TYPE': 'application/json;charset=cp1251'}
        self._stub_single_read(self._JSON_CP1251)

        parsed = self.loop.run_until_complete(self.response.json())
        self.assertEqual(parsed, {'тест': 'пройден'})
        self.assertTrue(self.response.close.called)

    def test_json_custom_loader(self):
        """A ``loads`` callable receives the decoded body text and its
        result is returned."""
        self.response.headers = {
            'CONTENT-TYPE': 'application/json;charset=cp1251'}
        self.response._content = b'data'

        res = self.loop.run_until_complete(
            self.response.json(loads=lambda text: text + '-custom'))
        self.assertEqual(res, 'data-custom')

    @unittest.mock.patch('aiohttp.client.client_log')
    def test_json_no_content(self, logger_mock):
        """An empty non-JSON body yields ``None`` and logs a warning."""
        self.response.headers = {'CONTENT-TYPE': 'data/octet-stream'}
        self.response._content = b''
        self.response.close = unittest.mock.Mock()

        res = self.loop.run_until_complete(self.response.json())
        self.assertIsNone(res)
        logger_mock.warning.assert_called_with(
            'Attempt to decode JSON with unexpected mimetype: %s',
            'data/octet-stream')

    def test_json_override_encoding(self):
        """``encoding=`` takes precedence over the Content-Type charset."""
        self.response.headers = {
            'CONTENT-TYPE': 'application/json;charset=utf8'}
        self._stub_single_read(self._JSON_CP1251)

        parsed = self.loop.run_until_complete(
            self.response.json(encoding='cp1251'))
        self.assertEqual(parsed, {'тест': 'пройден'})
        self.assertTrue(self.response.close.called)

    def test_json_detect_encoding(self):
        """Without a charset, ``json()`` succeeds only if chardet exists."""
        self.response.headers = {'CONTENT-TYPE': 'application/json'}
        self._stub_single_read(self._JSON_CP1251)

        if chardet is None:
            # Without chardet the cp1251 bytes are expected to fail decoding.
            with self.assertRaises(UnicodeDecodeError):
                self.loop.run_until_complete(self.response.json())
            return

        parsed = self.loop.run_until_complete(self.response.json())
        self.assertEqual(parsed, {'тест': 'пройден'})
        self.assertTrue(self.response.close.called)

    def test_override_flow_control(self):
        """A subclass can pick the content queue type through
        ``flow_control_class``."""
        class MyResponse(ClientResponse):
            flow_control_class = aiohttp.FlowControlDataQueue

        resp = MyResponse('get', 'http://python.org')
        resp._setup_connection(self.connection)
        self.assertIsInstance(resp.content, aiohttp.FlowControlDataQueue)
        with self.assertWarns(ResourceWarning):
            del resp
Example #25
0
class ClientResponseTests(unittest.TestCase):
    """Tests for ``ClientResponse`` with a stubbed connection and content.

    Body-reading tests share ``_stub_read``, whose ``content.read()``
    returns an already-completed future holding the same payload on every
    call.
    """

    # Non-ASCII JSON document and its cp1251 encoding, used by the
    # text/json decoding tests.
    _JSON_TEXT = '{"тест": "пройден"}'
    _JSON_CP1251 = _JSON_TEXT.encode('cp1251')

    def setUp(self):
        self.loop = asyncio.new_event_loop()
        # The global loop is unset, presumably so nothing under test can
        # fall back to it implicitly.
        asyncio.set_event_loop(None)

        self.connection = unittest.mock.Mock()
        self.stream = aiohttp.StreamParser(loop=self.loop)
        self.response = ClientResponse('get', 'http://python.org')

    def tearDown(self):
        self.loop.close()

    def _stub_read(self, payload):
        """Install a content mock whose ``read()`` always resolves to
        *payload*, and replace ``self.response.close`` with a mock.

        Returns the content mock.
        """

        def read(*args, **kwargs):
            completed = asyncio.Future(loop=self.loop)
            completed.set_result(payload)
            return completed

        content = self.response.content = unittest.mock.Mock()
        content.read.side_effect = read
        self.response.close = unittest.mock.Mock()
        return content

    def test_del(self):
        """Dropping a response that still owns a connection emits
        ResourceWarning and closes that connection."""
        resp = ClientResponse('get', 'http://python.org')

        conn = unittest.mock.Mock()
        resp._setup_connection(conn)
        with self.assertWarns(ResourceWarning):
            del resp

        conn.close.assert_called_with()

    def test_close(self):
        """``close()`` releases and forgets the connection; further calls
        are harmless."""
        self.response.connection = self.connection
        self.response.close()
        self.assertIsNone(self.response.connection)
        self.assertTrue(self.connection.release.called)
        for _ in range(2):
            self.response.close()

    def test_wait_for_100(self):
        """``waiting_for_continue()`` reflects whether ``continue100`` was
        supplied."""
        expecting = ClientResponse('get', 'http://python.org',
                                   continue100=object())
        self.assertTrue(expecting.waiting_for_continue())
        not_expecting = ClientResponse('get', 'http://python.org')
        self.assertFalse(not_expecting.waiting_for_continue())

    def test_repr(self):
        self.response.status, self.response.reason = 200, 'Ok'
        expected = '<ClientResponse(http://python.org) [200 Ok]>'
        self.assertIn(expected, repr(self.response))

    def test_read_and_release_connection(self):
        """``read()`` returns the body and closes the response."""
        self._stub_read(b'payload')

        body = self.loop.run_until_complete(self.response.read())
        self.assertEqual(body, b'payload')
        self.assertTrue(self.response.close.called)

    def test_read_and_release_connection_with_error(self):
        """A failing ``content.read()`` propagates, and ``close(True)`` is
        called."""
        content = self.response.content = unittest.mock.Mock()
        failure = asyncio.Future(loop=self.loop)
        failure.set_exception(ValueError)
        content.read.return_value = failure
        self.response.close = unittest.mock.Mock()

        with self.assertRaises(ValueError):
            self.loop.run_until_complete(self.response.read())
        self.response.close.assert_called_with(True)

    def test_release(self):
        """``release()`` with ``readany()`` returning ``b''`` closes the
        response."""
        exhausted = asyncio.Future(loop=self.loop)
        exhausted.set_result(b'')
        content = self.response.content = unittest.mock.Mock()
        content.readany.return_value = exhausted
        self.response.close = unittest.mock.Mock()

        self.loop.run_until_complete(self.response.release())
        self.assertTrue(self.response.close.called)

    def test_read_and_close(self):
        """Deprecated ``read_and_close()`` warns and returns ``read()``'s
        result."""
        result = asyncio.Future(loop=self.loop)
        result.set_result(b'data')
        self.response.read = unittest.mock.Mock(return_value=result)

        with self.assertWarns(DeprecationWarning):
            res = self.loop.run_until_complete(self.response.read_and_close())
        self.assertEqual(res, b'data')
        self.assertTrue(self.response.read.called)

    def test_read_decode_deprecated(self):
        """``read(decode=True)`` warns and returns ``json()``'s result."""
        self.response._content = b'data'
        decoded = asyncio.Future(loop=self.loop)
        decoded.set_result('json')
        self.response.json = unittest.mock.Mock(return_value=decoded)

        with self.assertWarns(DeprecationWarning):
            res = self.loop.run_until_complete(self.response.read(decode=True))
        self.assertEqual(res, 'json')
        self.assertTrue(self.response.json.called)

    def test_text(self):
        """``text()`` decodes with the charset named in Content-Type."""
        self.response.headers = {
            'CONTENT-TYPE': 'application/json;charset=cp1251'
        }
        self._stub_read(self._JSON_CP1251)

        text = self.loop.run_until_complete(self.response.text())
        self.assertEqual(text, self._JSON_TEXT)
        self.assertTrue(self.response.close.called)

    def test_text_custom_encoding(self):
        """An explicit ``encoding=`` is used when Content-Type has none."""
        self.response.headers = {'CONTENT-TYPE': 'application/json'}
        self._stub_read(self._JSON_CP1251)

        text = self.loop.run_until_complete(
            self.response.text(encoding='cp1251'))
        self.assertEqual(text, self._JSON_TEXT)
        self.assertTrue(self.response.close.called)

    def test_text_detect_encoding(self):
        """Without a charset, ``text()`` still decodes the cp1251 body."""
        self.response.headers = {'CONTENT-TYPE': 'application/json'}
        self._stub_read(self._JSON_CP1251)

        text = self.loop.run_until_complete(self.response.text())
        self.assertEqual(text, self._JSON_TEXT)
        self.assertTrue(self.response.close.called)

    def test_json(self):
        """``json()`` decodes with the Content-Type charset and parses."""
        self.response.headers = {
            'CONTENT-TYPE': 'application/json;charset=cp1251'
        }
        self._stub_read(self._JSON_CP1251)

        parsed = self.loop.run_until_complete(self.response.json())
        self.assertEqual(parsed, {'тест': 'пройден'})
        self.assertTrue(self.response.close.called)

    def test_json_custom_loader(self):
        """A ``loads`` callable receives the decoded body text and its
        result is returned."""
        self.response.headers = {
            'CONTENT-TYPE': 'application/json;charset=cp1251'
        }
        self.response._content = b'data'

        res = self.loop.run_until_complete(
            self.response.json(loads=lambda text: text + '-custom'))
        self.assertEqual(res, 'data-custom')

    @unittest.mock.patch('aiohttp.client.client_logger')
    def test_json_no_content(self, logger_mock):
        """An empty non-JSON body yields ``None`` and logs a warning."""
        self.response.headers = {'CONTENT-TYPE': 'data/octet-stream'}
        self.response._content = b''
        self.response.close = unittest.mock.Mock()

        res = self.loop.run_until_complete(self.response.json())
        self.assertIsNone(res)
        logger_mock.warning.assert_called_with(
            'Attempt to decode JSON with unexpected mimetype: %s',
            'data/octet-stream')

    def test_json_override_encoding(self):
        """``encoding=`` takes precedence over the Content-Type charset."""
        self.response.headers = {
            'CONTENT-TYPE': 'application/json;charset=utf8'
        }
        self._stub_read(self._JSON_CP1251)

        parsed = self.loop.run_until_complete(
            self.response.json(encoding='cp1251'))
        self.assertEqual(parsed, {'тест': 'пройден'})
        self.assertTrue(self.response.close.called)

    def test_json_detect_encoding(self):
        """Without a charset, ``json()`` still decodes the cp1251 body."""
        self.response.headers = {'CONTENT-TYPE': 'application/json'}
        self._stub_read(self._JSON_CP1251)

        parsed = self.loop.run_until_complete(self.response.json())
        self.assertEqual(parsed, {'тест': 'пройден'})
        self.assertTrue(self.response.close.called)

    def test_override_flow_control(self):
        """A subclass can pick the content queue type through
        ``flow_control_class``."""
        class MyResponse(ClientResponse):
            flow_control_class = aiohttp.FlowControlDataQueue

        resp = MyResponse('get', 'http://python.org')
        resp._setup_connection(self.connection)
        self.assertIsInstance(resp.content, aiohttp.FlowControlDataQueue)
        with self.assertWarns(ResourceWarning):
            del resp
Example #26
0
def response(loop):
    """Return a GET ``ClientResponse`` for ``http://base-conn.org`` that has
    been post-initialised with *loop*."""
    result = ClientResponse('get', 'http://base-conn.org')
    result._post_init(loop)
    return result
Example #27
0
class ClientResponseTests(unittest.TestCase):
    """Unit tests for ``ClientResponse`` lifecycle and helper methods."""

    def setUp(self):
        # Private loop per test; the global default loop is unset and the
        # futures below are bound to self.loop explicitly.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)

        self.connection = unittest.mock.Mock()
        self.stream = aiohttp.StreamParser(loop=self.loop)
        self.response = ClientResponse('get', 'http://python.org')

    def tearDown(self):
        self.loop.close()

    def test_del(self):
        """Deleting a response with an attached connection closes it."""
        conn = unittest.mock.Mock()
        resp = ClientResponse('get', 'http://python.org')
        resp._setup_connection(conn)
        del resp

        conn.close.assert_called_with()

    def test_close(self):
        """close() releases and detaches the connection; repeats are safe."""
        self.response.connection = self.connection
        self.response.close()
        self.assertIsNone(self.response.connection)
        self.assertTrue(self.connection.release.called)
        # Repeated calls must not raise.
        for _ in range(2):
            self.response.close()

    def test_wait_for_100(self):
        """waiting_for_continue() reflects whether continue100 was given."""
        resp = ClientResponse(
            'get', 'http://python.org', continue100=object())
        self.assertTrue(resp.waiting_for_continue())
        resp = ClientResponse('get', 'http://python.org')
        self.assertFalse(resp.waiting_for_continue())

    def test_repr(self):
        """repr() includes the URL, status code and reason phrase."""
        self.response.status, self.response.reason = 200, 'Ok'
        self.assertIn(
            '<ClientResponse(http://python.org) [200 Ok]>',
            repr(self.response))

    def test_read_and_close(self):
        """read_and_close() returns the read() payload and calls close()."""
        body = asyncio.Future(loop=self.loop)
        body.set_result(b'payload')
        self.response.read = unittest.mock.Mock(return_value=body)
        self.response.close = unittest.mock.Mock()

        result = self.loop.run_until_complete(self.response.read_and_close())
        self.assertEqual(result, b'payload')
        self.assertTrue(self.response.read.called)
        self.assertTrue(self.response.close.called)

    def test_read_and_close_with_error(self):
        """A failing read() propagates and close() is called with True."""
        failed = asyncio.Future(loop=self.loop)
        failed.set_exception(ValueError)
        self.response.read = unittest.mock.Mock(return_value=failed)
        self.response.close = unittest.mock.Mock()

        # Create the coroutine first so only run_until_complete is guarded.
        coro = self.response.read_and_close()
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(coro)
        self.assertTrue(self.response.read.called)
        self.response.close.assert_called_with(True)
Example #28
0
 def cb():
     # Hand back an already-resolved future carrying a stub response whose
     # status is taken from the enclosing scope's `status`.
     stub = ClientResponse("GET", "foo")
     stub.status = status
     done = asyncio.Future(loop=self._loop)
     done.set_result(stub)
     return done
Example #29
0
 def cb():
     # Hand back an already-resolved future carrying a stub response (URL
     # given as a yarl.URL) whose status comes from the enclosing `status`.
     stub = ClientResponse('GET', yarl.URL('foo'))
     stub.status = status
     done = asyncio.Future(loop=self._loop)
     done.set_result(stub)
     return done