示例#1
0
class ServerTests(unittest.TestCase):
    """End-to-end tests for ``RESTServer`` routes over a loopback socket.

    Each test starts the server on an ephemeral 127.0.0.1 port of a private
    event loop, issues one request with ``aiohttp.request`` and checks the
    response.  Server shutdown always happens in ``finally`` (see
    ``_run_and_stop``) so a failing assertion cannot leak a listening socket.
    """

    def setUp(self):
        # Private loop per test; clearing the default loop means every
        # asyncio/aiohttp call below must receive ``loop=`` explicitly.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)
        self.server = RESTServer(debug=True,
                                 keep_alive=75,
                                 hostname='127.0.0.1')
        self.port = None
        rest = REST(self)
        self.server.add_url('POST', '/post/{id}', rest.func_POST)
        self.server.add_url('POST', '/create', rest.create)
        self.server.add_url('GET', '/post/{id}', rest.func_GET)
        self.server.add_url('GET', '/post/{id}/2', rest.func_GET2)
        self.server.add_url('GET', '/cookie/{value}', rest.coro_set_cookie)
        self.server.add_url('GET', '/get_cookie/', rest.func_get_cookie)
        self.server.add_url('GET', '/check/no/session', rest.check_session)

    def tearDown(self):
        self.loop.close()

    def _start_server(self):
        """Start the REST server on an ephemeral 127.0.0.1 port.

        Returns:
            ``(srv, port)`` -- the asyncio server and its bound port.  The
            port is also stored on ``self.port``.
        """
        srv = self.loop.run_until_complete(
            self.loop.create_server(self.server.make_handler(loop=self.loop),
                                    '127.0.0.1', 0))
        self.port = port = server_port(srv)
        return srv, port

    def _run_and_stop(self, srv, coro):
        """Run *coro* on the test loop, then always close *srv*.

        The ``finally`` releases the listening socket even when an assertion
        inside *coro* fails, so ``tearDown`` never closes the loop underneath
        a still-open server.
        """
        try:
            self.loop.run_until_complete(coro)
        finally:
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_simple_POST(self):
        """POST a JSON body to /post/123; expect 200 and {"success": true}."""
        srv, port = self._start_server()
        url = 'http://127.0.0.1:{}/post/123'.format(port)

        @asyncio.coroutine
        def query():
            response = yield from aiohttp.request(
                'POST',
                url,
                data=json.dumps({
                    'q': 'val'
                }).encode('utf-8'),
                headers={'Content-Type': 'application/json'},
                loop=self.loop)
            self.assertEqual(200, response.status)
            data = yield from response.read()
            self.assertEqual(b'{"success": true}', data)

        self._run_and_stop(srv, query())

    def test_simple_GET(self):
        """GET /post/123; expect 200 and {"success": true}."""
        srv, port = self._start_server()
        url = 'http://127.0.0.1:{}/post/123'.format(port)

        @asyncio.coroutine
        def query():
            response = yield from aiohttp.request('GET', url, loop=self.loop)
            self.assertEqual(200, response.status)
            data = yield from response.read()
            self.assertEqual(b'{"success": true}', data)

        self._run_and_stop(srv, query())

    def test_GET_with_query_string(self):
        """GET /post/123/2?a=1&b=2; expect 'args' == ['a', 'b'] in the body."""
        srv, port = self._start_server()
        url = 'http://127.0.0.1:{}/post/123/2?a=1&b=2'.format(port)

        @asyncio.coroutine
        def query():
            response = yield from aiohttp.request('GET', url, loop=self.loop)
            self.assertEqual(200, response.status)
            data = yield from response.read()
            dct = json.loads(data.decode('utf-8'))
            self.assertEqual({
                'success': True,
                'args': ['a', 'b'],
            }, dct)

        self._run_and_stop(srv, query())

    def test_set_cookie(self):
        """GET /cookie/123; expect the response to set test_cookie=123."""
        srv, port = self._start_server()
        url = 'http://127.0.0.1:{}/cookie/123'.format(port)

        @asyncio.coroutine
        def query():
            response = yield from aiohttp.request('GET', url, loop=self.loop)
            yield from response.read()
            self.assertEqual(200, response.status)
            self.assertIn('test_cookie', response.cookies)
            self.assertEqual(response.cookies['test_cookie'].value, '123')

        self._run_and_stop(srv, query())

    def test_get_cookie(self):
        """Send test_cookie=value; expect it reported back under 'cookie'."""
        srv, port = self._start_server()
        url = 'http://127.0.0.1:{}/get_cookie/'.format(port)

        @asyncio.coroutine
        def query():
            response = yield from aiohttp.request(
                'GET', url, cookies={'test_cookie': 'value'}, loop=self.loop)
            self.assertEqual(200, response.status)
            data = yield from response.read()
            dct = json.loads(data.decode('utf-8'))
            self.assertEqual({
                'success': True,
                'cookie': 'value',
            }, dct)

        self._run_and_stop(srv, query())

    def test_accept_encoding__deflate(self):
        """Ask for deflate; expect Content-Encoding: deflate and valid JSON."""
        srv, port = self._start_server()
        url = 'http://127.0.0.1:{}/post/123'.format(port)

        @asyncio.coroutine
        def query():
            response = yield from aiohttp.request(
                'GET',
                url,
                headers={'ACCEPT-ENCODING': 'deflate'},
                loop=self.loop)
            self.assertEqual(200, response.status)
            data = yield from response.read()
            dct = json.loads(data.decode('utf-8'))
            self.assertEqual({'success': True}, dct)
            headers = response.message.headers
            enc = headers['CONTENT-ENCODING']
            self.assertEqual('deflate', enc)

        # Previously this test never closed the server.
        self._run_and_stop(srv, query())

    def test_accept_encoding__gzip(self):
        """Ask for gzip; expect Content-Encoding: gzip."""
        srv, port = self._start_server()
        url = 'http://127.0.0.1:{}/post/123'.format(port)

        @asyncio.coroutine
        def query():
            response = yield from aiohttp.request(
                'GET',
                url,
                headers={'ACCEPT-ENCODING': 'gzip'},
                loop=self.loop)
            self.assertEqual(200, response.status)
            # NOTE(review): the body is drained but not checked here; the
            # original JSON-content assertions were commented out.
            yield from response.read()
            headers = response.message.headers
            enc = headers['CONTENT-ENCODING']
            self.assertEqual('gzip', enc)

        # Previously this test never closed the server.
        self._run_and_stop(srv, query())

    def test_POST_without_body(self):
        """An empty JSON POST body is rejected with a 400 error payload."""
        srv, port = self._start_server()
        url = 'http://127.0.0.1:{}/post/123'.format(port)

        @asyncio.coroutine
        def query():
            response = yield from aiohttp.request(
                'POST',
                url,
                data='',
                headers={'Content-Type': 'application/json'},
                loop=self.loop)
            self.assertEqual(400, response.status)
            data = yield from response.read()
            j = json.loads(data.decode('utf-8'))
            self.assertEqual(
                {
                    "error_code": 400,
                    "error_reason": "Request has no body",
                    "error": {}
                }, j)

        self._run_and_stop(srv, query())

    def test_POST_malformed_json(self):
        """An undecodable JSON POST body is rejected with a 400 payload."""
        srv, port = self._start_server()
        url = 'http://127.0.0.1:{}/post/123'.format(port)

        @asyncio.coroutine
        def query():
            response = yield from aiohttp.request(
                'POST',
                url,
                data='{dfsdf}2',
                headers={'Content-Type': 'application/json'},
                loop=self.loop)
            self.assertEqual(400, response.status)
            data = yield from response.read()
            j = json.loads(data.decode('utf-8'))
            self.assertEqual(
                {
                    "error_code": 400,
                    "error_reason": "JSON body can not be decoded",
                    "error": {}
                }, j)

        self._run_and_stop(srv, query())

    def test_POST_nonutf8_json(self):
        """A cp1251-encoded JSON body is rejected with a 400 payload."""
        srv, port = self._start_server()
        url = 'http://127.0.0.1:{}/post/123'.format(port)

        @asyncio.coroutine
        def query():
            response = yield from aiohttp.request(
                'POST',
                url,
                data='{"русский": "текст"}'.encode('1251'),
                headers={'Content-Type': 'application/json'},
                loop=self.loop)
            self.assertEqual(400, response.status)
            data = yield from response.read()
            j = json.loads(data.decode('utf-8'))
            self.assertEqual(
                {
                    "error_code": 400,
                    "error_reason": "JSON body is not utf-8 encoded",
                    "error": {}
                }, j)

        self._run_and_stop(srv, query())

    def test_status_code(self):
        """POST /create; expect 201 and {"created": true}."""
        srv, port = self._start_server()
        url = 'http://127.0.0.1:{}/create'.format(port)

        @asyncio.coroutine
        def query():
            response = yield from aiohttp.request(
                'POST',
                url,
                headers={'Content-Type': 'application/json'},
                loop=self.loop)
            self.assertEqual(201, response.status)
            data = yield from response.read()
            self.assertEqual(b'{"created": true}', data)

        self._run_and_stop(srv, query())

    def test_no_session(self):
        """GET /check/no/session; expect 200."""
        srv, port = self._start_server()
        url = 'http://127.0.0.1:{}/check/no/session'.format(port)

        @asyncio.coroutine
        def query():
            resp = yield from aiohttp.request('GET', url, loop=self.loop)
            self.assertEqual(200, resp.status)
            yield from resp.read()

        # Previously this test never closed the server.
        self._run_and_stop(srv, query())
示例#2
0
class ServerTests(unittest.TestCase):
    """Authorization tests using CookieIdentityPolicy and a dictionary policy.

    The authorization policy is built from ``{'chris': ('read', )}`` and the
    ``/auth/{permission}`` route is served by ``REST.handler``.  Server
    shutdown always happens in ``finally`` (see ``_run_and_stop``).
    """

    def setUp(self):
        # Private loop per test; no default loop, so pass ``loop=`` everywhere.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)
        auth_policy = DictionaryAuthorizationPolicy({'chris': ('read', )})
        self.server = RESTServer(debug=True,
                                 keep_alive=75,
                                 hostname='127.0.0.1',
                                 identity_policy=CookieIdentityPolicy(),
                                 auth_policy=auth_policy)
        self.port = None
        rest = REST()
        self.server.add_url('GET', '/auth/{permission}', rest.handler)

    def tearDown(self):
        self.loop.close()

    def _start_server(self):
        """Start the server on an ephemeral 127.0.0.1 port.

        Returns:
            ``(srv, port)``; the port is also stored on ``self.port``.
        """
        srv = self.loop.run_until_complete(
            self.loop.create_server(self.server.make_handler(loop=self.loop),
                                    '127.0.0.1', 0))
        self.port = port = server_port(srv)
        return srv, port

    def _run_and_stop(self, srv, coro):
        """Run *coro*, then always close *srv*.

        The ``finally`` ensures a failing assertion does not leave the server
        listening when ``tearDown`` closes the loop.
        """
        try:
            self.loop.run_until_complete(coro)
        finally:
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_identity_missing(self):
        """With no cookies, expect the error 'Identity not found'."""
        srv, port = self._start_server()
        url = 'http://127.0.0.1:{}/auth'.format(port)

        @asyncio.coroutine
        def query():
            resp = yield from aiohttp.request('GET',
                                              url + '/read',
                                              cookies={},
                                              loop=self.loop)
            json_data = yield from resp.json()
            self.assertEqual(json_data['error'], 'Identity not found')

        self._run_and_stop(srv, query())

    def test_user_missing(self):
        """An unknown user_id cookie yields the error 'User not found'."""
        srv, port = self._start_server()
        url = 'http://127.0.0.1:{}/auth'.format(port)

        @asyncio.coroutine
        def query():
            resp = yield from aiohttp.request(
                'GET',
                url + '/read',
                cookies={'user_id': 'john'},  # not chris
                loop=self.loop)
            json_data = yield from resp.json()
            self.assertEqual(json_data['error'], 'User not found')

        self._run_and_stop(srv, query())

    def test_permission_missing(self):
        """chris requesting 'write' gets allowed == False."""
        srv, port = self._start_server()
        url = 'http://127.0.0.1:{}/auth'.format(port)

        @asyncio.coroutine
        def query():
            resp = yield from aiohttp.request(
                'GET',
                url + '/write',  # not read
                cookies={'user_id': 'chris'},
                loop=self.loop)
            json_data = yield from resp.json()
            self.assertEqual(json_data['allowed'], False)

        self._run_and_stop(srv, query())

    def test_permission_present(self):
        """chris requesting 'read' gets allowed == True."""
        srv, port = self._start_server()
        url = 'http://127.0.0.1:{}/auth'.format(port)

        @asyncio.coroutine
        def query():
            resp = yield from aiohttp.request('GET',
                                              url + '/read',
                                              cookies={'user_id': 'chris'},
                                              loop=self.loop)
            json_data = yield from resp.json()
            self.assertEqual(json_data['allowed'], True)

        self._run_and_stop(srv, query())

    def test_coockie_policy(self):
        """Exercise CookieIdentityPolicy identify/remember/forget directly."""
        identity_policy = CookieIdentityPolicy()
        request = Request('host',
                          aiohttp.RawRequestMessage('GET', '/post/123/', '1.1',
                                                    {}, True, None),
                          None,
                          loop=self.loop)

        @asyncio.coroutine
        def f():
            # A fresh request carries no identity.
            user_id = yield from identity_policy.identify(request)
            self.assertFalse(user_id)

            yield from identity_policy.remember(request, 'anton')
            # emulate response-request cycle: feed the response cookies
            # back in as the next request's cookies
            request._cookies = request.response.cookies.copy()
            user_id = yield from identity_policy.identify(request)
            self.assertEqual(user_id.value, 'anton')

            yield from identity_policy.forget(request)
            # emulate response-request cycle
            request._cookies = request.response.cookies.copy()
            user_id = yield from identity_policy.identify(request)
            self.assertFalse(user_id.value)

        self.loop.run_until_complete(f())
示例#3
0
class CookieSessionTests(unittest.TestCase):
    """Tests for cookie-backed sessions created by ``CookieSessionFactory``.

    Sessions use the cookie name ``test_cookie`` and are (de)serialized with
    ``json.dumps``/``json.loads``.
    """

    def setUp(self):
        # Private loop per test; no default loop, so pass ``loop=`` everywhere.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)

        session_factory = CookieSessionFactory(secret_key=b'secret',
                                               cookie_name='test_cookie',
                                               dumps=json.dumps,
                                               loads=json.loads)
        self.server = RESTServer(debug=True,
                                 keep_alive=75,
                                 hostname='localhost',
                                 session_factory=session_factory)
        rest = REST(self)

        self.server.add_url('GET', '/init', rest.init_session)
        self.server.add_url('GET', '/get', rest.get_from_session)
        self.server.add_url('GET', '/counter', rest.counter)
        self.server.add_url('GET', '/counter/{start}', rest.counter)

    def tearDown(self):
        self.loop.close()

    @contextlib.contextmanager
    def run_server(self):
        """Serve the app on an ephemeral 127.0.0.1 port for a ``with`` block.

        Yields:
            ``(srv, base_url)`` where *base_url* is ``'http://host:port'``.

        Shutdown is in ``finally``: without it, an exception raised in the
        ``with`` body (e.g. a failed assertion) is thrown into the generator
        at ``yield`` and the close/wait_closed lines would never run.
        """
        srv = self.loop.run_until_complete(
            self.loop.create_server(self.server.make_handler(loop=self.loop),
                                    '127.0.0.1', 0))
        try:
            sock = next(iter(srv.sockets))
            host, port = sock.getsockname()
            base_url = 'http://{}:{}'.format(host, port)

            yield (srv, base_url)
        finally:
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    @mock.patch('aiorest.session.cookie_session.time')
    def test_init_session(self, time_mock):
        """GET /init sets test_cookie to the encoding of {'foo': 'bar'}."""
        time_mock.time.return_value = 1
        with self.run_server() as (srv, base_url):
            url = base_url + '/init'

            @asyncio.coroutine
            def query():
                resp = yield from aiohttp.request('GET', url, loop=self.loop)
                yield from resp.read()
                self.assertEqual(resp.status, 200)
                cookies = {k: v.value for k, v in resp.cookies.items()}
                value = make_cookie({'foo': 'bar'}, 1)
                self.assertEqual(cookies, {'test_cookie': value})

            self.loop.run_until_complete(query())

    @mock.patch('aiorest.session.cookie_session.time')
    def test_get_from_session(self, time_mock):
        """GET /get with a pre-built session cookie returns 200."""
        time_mock.time.return_value = 1
        with self.run_server() as (srv, base_url):

            url = base_url + '/get'

            @asyncio.coroutine
            def query():
                resp = yield from aiohttp.request(
                    'GET',
                    url,
                    cookies={'test_cookie': make_cookie({'foo': 'bar'}, 1)},
                    loop=self.loop)
                yield from resp.read()
                self.assertEqual(resp.status, 200)

            self.loop.run_until_complete(query())

    def test_full_cycle(self):
        """Drive /counter through start, increment, ignored override, reset.

        All requests go through one ``TCPConnector(share_cookies=True)``,
        which presumably carries the session cookie between requests.
        """
        with self.run_server() as (srv, base_url):
            url = base_url + '/counter'

            @asyncio.coroutine
            def queries():
                connector = aiohttp.TCPConnector(share_cookies=True,
                                                 loop=self.loop)

                @asyncio.coroutine
                def fetch(path):
                    # GET url + path over the shared connector and return
                    # (status, decoded JSON body).
                    resp = yield from aiohttp.request('GET',
                                                      url + path,
                                                      connector=connector,
                                                      loop=self.loop)
                    data = yield from resp.json()
                    return resp.status, data

                try:
                    # initiate session; set start value to 2
                    status, data = yield from fetch("/2")
                    self.assertEqual(status, 200)
                    self.assertEqual(data, {'result': 3})

                    # do increment
                    status, data = yield from fetch('')
                    self.assertEqual(status, 200)
                    self.assertEqual(data, {'result': 4})

                    # try to override start value
                    status, data = yield from fetch('/3')
                    self.assertEqual(status, 200)
                    self.assertEqual(data, {'result': 5})

                    # session deleted; try count
                    status, data = yield from fetch('')
                    self.assertEqual(status, 200)
                    self.assertEqual(data, {'result': 1})
                finally:
                    # Release pooled keep-alive transports before the loop
                    # is closed in tearDown.
                    connector.close()

            self.loop.run_until_complete(queries())
示例#4
0
class CookieSessionTests(unittest.TestCase):
    """Tests for cookie-backed sessions created by ``CookieSessionFactory``.

    Sessions use the cookie name ``test_cookie`` and are (de)serialized with
    ``json.dumps``/``json.loads``.
    """

    def setUp(self):
        # Private loop per test; no default loop, so pass ``loop=`` everywhere.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)

        session_factory = CookieSessionFactory(secret_key=b'secret',
                                               cookie_name='test_cookie',
                                               dumps=json.dumps,
                                               loads=json.loads)
        self.server = RESTServer(debug=True, keep_alive=75,
                                 hostname='localhost',
                                 session_factory=session_factory)
        rest = REST(self)

        self.server.add_url('GET', '/init', rest.init_session)
        self.server.add_url('GET', '/get', rest.get_from_session)
        self.server.add_url('GET', '/counter', rest.counter)
        self.server.add_url('GET', '/counter/{start}', rest.counter)

    def tearDown(self):
        self.loop.close()

    @contextlib.contextmanager
    def run_server(self):
        """Serve the app on an ephemeral 127.0.0.1 port for a ``with`` block.

        Yields:
            ``(srv, base_url)`` where *base_url* is ``'http://host:port'``.

        Shutdown is in ``finally`` so it also runs when the ``with`` body
        raises; a bare ``yield`` would skip it on a failed assertion.
        """
        srv = self.loop.run_until_complete(self.loop.create_server(
            self.server.make_handler(loop=self.loop),
            '127.0.0.1', 0))
        try:
            sock = next(iter(srv.sockets))
            host, port = sock.getsockname()
            base_url = 'http://{}:{}'.format(host, port)

            yield (srv, base_url)
        finally:
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    @mock.patch('aiorest.session.cookie_session.time')
    def test_init_session(self, time_mock):
        """GET /init sets test_cookie to the encoding of {'foo': 'bar'}."""
        time_mock.time.return_value = 1
        with self.run_server() as (srv, base_url):
            url = base_url + '/init'

            @asyncio.coroutine
            def query():
                resp = yield from aiohttp.request('GET', url, loop=self.loop)
                yield from resp.read()
                self.assertEqual(resp.status, 200)
                cookies = {k: v.value for k, v in resp.cookies.items()}
                value = make_cookie({'foo': 'bar'}, 1)
                self.assertEqual(cookies, {'test_cookie': value})

            self.loop.run_until_complete(query())

    @mock.patch('aiorest.session.cookie_session.time')
    def test_get_from_session(self, time_mock):
        """GET /get with a pre-built session cookie returns 200."""
        time_mock.time.return_value = 1
        with self.run_server() as (srv, base_url):

            url = base_url + '/get'

            @asyncio.coroutine
            def query():
                resp = yield from aiohttp.request(
                    'GET', url,
                    cookies={'test_cookie': make_cookie({'foo': 'bar'}, 1)},
                    loop=self.loop)
                yield from resp.read()
                self.assertEqual(resp.status, 200)

            self.loop.run_until_complete(query())

    def test_full_cycle(self):
        """Drive /counter through start, increment, ignored override, reset.

        All requests go through one ``TCPConnector(share_cookies=True)``,
        which presumably carries the session cookie between requests.
        """
        with self.run_server() as (srv, base_url):
            url = base_url + '/counter'

            @asyncio.coroutine
            def queries():
                connector = aiohttp.TCPConnector(share_cookies=True,
                                                 loop=self.loop)

                @asyncio.coroutine
                def fetch(path):
                    # GET url + path over the shared connector and return
                    # (status, decoded JSON body).
                    resp = yield from aiohttp.request('GET', url + path,
                                                      connector=connector,
                                                      loop=self.loop)
                    data = yield from resp.json()
                    return resp.status, data

                try:
                    # initiate session; set start value to 2
                    status, data = yield from fetch("/2")
                    self.assertEqual(status, 200)
                    self.assertEqual(data, {'result': 3})

                    # do increment
                    status, data = yield from fetch('')
                    self.assertEqual(status, 200)
                    self.assertEqual(data, {'result': 4})

                    # try to override start value
                    status, data = yield from fetch('/3')
                    self.assertEqual(status, 200)
                    self.assertEqual(data, {'result': 5})

                    # session deleted; try count
                    status, data = yield from fetch('')
                    self.assertEqual(status, 200)
                    self.assertEqual(data, {'result': 1})
                finally:
                    # Release pooled keep-alive transports before the loop
                    # is closed in tearDown.
                    connector.close()

            self.loop.run_until_complete(queries())
示例#5
0
class ServerTests(unittest.TestCase):

    def setUp(self):
        """Create a private event loop and a RESTServer with the test routes."""
        self.loop = asyncio.new_event_loop()
        # Leave no default loop installed: every call must pass ``loop=``.
        asyncio.set_event_loop(None)
        self.server = RESTServer(debug=True, keep_alive=75,
                                 hostname='127.0.0.1')
        self.port = None
        handlers = REST(self)
        # (HTTP method, URL pattern, handler) registered in this order.
        route_table = (
            ('POST', '/post/{id}', handlers.func_POST),
            ('POST', '/create', handlers.create),
            ('GET', '/post/{id}', handlers.func_GET),
            ('GET', '/post/{id}/2', handlers.func_GET2),
            ('GET', '/cookie/{value}', handlers.coro_set_cookie),
            ('GET', '/get_cookie/', handlers.func_get_cookie),
            ('GET', '/check/no/session', handlers.check_session),
        )
        for http_method, pattern, handler in route_table:
            self.server.add_url(http_method, pattern, handler)

    def tearDown(self):
        self.loop.close()

    def test_simple_POST(self):
        """POST a JSON body to /post/123; expect 200 and {"success": true}."""
        srv = self.loop.run_until_complete(self.loop.create_server(
            self.server.make_handler(loop=self.loop),
            '127.0.0.1', 0))
        try:
            self.port = port = server_port(srv)
            url = 'http://127.0.0.1:{}/post/123'.format(port)

            @asyncio.coroutine
            def query():
                response = yield from aiohttp.request(
                    'POST', url,
                    data=json.dumps({'q': 'val'}).encode('utf-8'),
                    headers={'Content-Type': 'application/json'},
                    loop=self.loop)
                self.assertEqual(200, response.status)
                data = yield from response.read()
                self.assertEqual(b'{"success": true}', data)

            self.loop.run_until_complete(query())
        finally:
            # Close even if an assertion failed, so tearDown never closes the
            # loop under a still-listening server.
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_simple_GET(self):
        """GET /post/123; expect 200 and {"success": true}."""
        srv = self.loop.run_until_complete(
            self.loop.create_server(self.server.make_handler(loop=self.loop),
                                    '127.0.0.1', 0))
        try:
            self.port = port = server_port(srv)
            url = 'http://127.0.0.1:{}/post/123'.format(port)

            @asyncio.coroutine
            def query():
                response = yield from aiohttp.request('GET', url,
                                                      loop=self.loop)
                self.assertEqual(200, response.status)
                data = yield from response.read()
                self.assertEqual(b'{"success": true}', data)

            self.loop.run_until_complete(query())
        finally:
            # Close even if an assertion failed.
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_GET_with_query_string(self):
        """GET /post/123/2?a=1&b=2; expect 'args' == ['a', 'b'] in the body."""
        srv = self.loop.run_until_complete(self.loop.create_server(
            self.server.make_handler(loop=self.loop),
            '127.0.0.1', 0))
        try:
            self.port = port = server_port(srv)
            url = 'http://127.0.0.1:{}/post/123/2?a=1&b=2'.format(port)

            @asyncio.coroutine
            def query():
                response = yield from aiohttp.request('GET', url,
                                                      loop=self.loop)
                self.assertEqual(200, response.status)
                data = yield from response.read()
                dct = json.loads(data.decode('utf-8'))
                self.assertEqual({'success': True,
                                  'args': ['a', 'b'],
                                  }, dct)

            self.loop.run_until_complete(query())
        finally:
            # Close even if an assertion failed.
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_set_cookie(self):
        """GET /cookie/123; expect the response to set test_cookie=123."""
        srv = self.loop.run_until_complete(self.loop.create_server(
            self.server.make_handler(loop=self.loop),
            '127.0.0.1', 0))
        try:
            self.port = port = server_port(srv)
            url = 'http://127.0.0.1:{}/cookie/123'.format(port)

            @asyncio.coroutine
            def query():
                response = yield from aiohttp.request('GET', url,
                                                      loop=self.loop)
                yield from response.read()
                self.assertEqual(200, response.status)
                self.assertIn('test_cookie', response.cookies)
                self.assertEqual(response.cookies['test_cookie'].value, '123')

            self.loop.run_until_complete(query())
        finally:
            # Close even if an assertion failed.
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_get_cookie(self):
        """Send test_cookie=value; expect it reported back under 'cookie'."""
        srv = self.loop.run_until_complete(self.loop.create_server(
            self.server.make_handler(loop=self.loop),
            '127.0.0.1', 0))
        try:
            self.port = port = server_port(srv)
            url = 'http://127.0.0.1:{}/get_cookie/'.format(port)

            @asyncio.coroutine
            def query():
                response = yield from aiohttp.request(
                    'GET', url,
                    cookies={'test_cookie': 'value'},
                    loop=self.loop)
                self.assertEqual(200, response.status)
                data = yield from response.read()
                dct = json.loads(data.decode('utf-8'))
                self.assertEqual({'success': True,
                                  'cookie': 'value',
                                  }, dct)

            self.loop.run_until_complete(query())
        finally:
            # Close even if an assertion failed.
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_accept_encoding__deflate(self):
        """Ask for deflate; expect Content-Encoding: deflate and valid JSON."""
        srv = self.loop.run_until_complete(self.loop.create_server(
            self.server.make_handler(loop=self.loop),
            '127.0.0.1', 0))
        try:
            self.port = port = server_port(srv)
            url = 'http://127.0.0.1:{}/post/123'.format(port)

            @asyncio.coroutine
            def query():
                response = yield from aiohttp.request(
                    'GET', url, headers={'ACCEPT-ENCODING': 'deflate'},
                    loop=self.loop)
                self.assertEqual(200, response.status)
                data = yield from response.read()
                dct = json.loads(data.decode('utf-8'))
                self.assertEqual({'success': True}, dct)
                headers = response.message.headers
                enc = headers['CONTENT-ENCODING']
                self.assertEqual('deflate', enc)

            self.loop.run_until_complete(query())
        finally:
            # The original never closed this server at all; always shut it
            # down before tearDown closes the loop.
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_accept_encoding__gzip(self):
        """Ask for gzip; expect Content-Encoding: gzip."""
        srv = self.loop.run_until_complete(self.loop.create_server(
            self.server.make_handler(loop=self.loop),
            '127.0.0.1', 0))
        try:
            self.port = port = server_port(srv)
            url = 'http://127.0.0.1:{}/post/123'.format(port)

            @asyncio.coroutine
            def query():
                response = yield from aiohttp.request(
                    'GET', url, headers={'ACCEPT-ENCODING': 'gzip'},
                    loop=self.loop)
                self.assertEqual(200, response.status)
                # NOTE(review): the body is drained but its JSON content is
                # not checked; the original assertions were commented out.
                yield from response.read()
                headers = response.message.headers
                enc = headers['CONTENT-ENCODING']
                self.assertEqual('gzip', enc)

            self.loop.run_until_complete(query())
        finally:
            # The original never closed this server at all; always shut it
            # down before tearDown closes the loop.
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_POST_without_body(self):
        """A JSON POST with an empty body is rejected with a 400 error payload.

        Expects ``error_reason`` to be ``"Request has no body"``. The server
        is closed in ``finally`` so a failed assertion does not leave it
        listening.
        """
        srv = self.loop.run_until_complete(self.loop.create_server(
            self.server.make_handler(loop=self.loop),
            '127.0.0.1', 0))
        try:
            self.port = port = server_port(srv)
            url = 'http://127.0.0.1:{}/post/123'.format(port)

            # Decorated for consistency with the other generator-based
            # coroutines in this file.
            @asyncio.coroutine
            def query():
                response = yield from aiohttp.request(
                    'POST', url,
                    data='',
                    headers={'Content-Type': 'application/json'},
                    loop=self.loop)
                self.assertEqual(400, response.status)
                data = yield from response.read()
                j = json.loads(data.decode('utf-8'))
                self.assertEqual({"error_code": 400,
                                  "error_reason": "Request has no body",
                                  "error": {}}, j)

            self.loop.run_until_complete(query())
        finally:
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_POST_malformed_json(self):
        """A POST whose body is not valid JSON is rejected with a 400 payload.

        Expects ``error_reason`` to be ``"JSON body can not be decoded"``.
        The server is closed in ``finally`` so a failed assertion does not
        leave it listening.
        """
        srv = self.loop.run_until_complete(self.loop.create_server(
            self.server.make_handler(loop=self.loop),
            '127.0.0.1', 0))
        try:
            self.port = port = server_port(srv)
            url = 'http://127.0.0.1:{}/post/123'.format(port)

            # Decorated for consistency with the other generator-based
            # coroutines in this file.
            @asyncio.coroutine
            def query():
                response = yield from aiohttp.request(
                    'POST', url,
                    data='{dfsdf}2',
                    headers={'Content-Type': 'application/json'},
                    loop=self.loop)
                self.assertEqual(400, response.status)
                data = yield from response.read()
                j = json.loads(data.decode('utf-8'))
                self.assertEqual({"error_code": 400,
                                  "error_reason": "JSON body can not be decoded",
                                  "error": {}}, j)

            self.loop.run_until_complete(query())
        finally:
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_POST_nonutf8_json(self):
        """A JSON body encoded as cp1251 (not UTF-8) is rejected with 400.

        Expects ``error_reason`` to be ``"JSON body is not utf-8 encoded"``.
        The server is closed in ``finally`` so a failed assertion does not
        leave it listening.
        """
        srv = self.loop.run_until_complete(self.loop.create_server(
            self.server.make_handler(loop=self.loop),
            '127.0.0.1', 0))
        try:
            self.port = port = server_port(srv)
            url = 'http://127.0.0.1:{}/post/123'.format(port)

            # Decorated for consistency with the other generator-based
            # coroutines in this file.
            @asyncio.coroutine
            def query():
                response = yield from aiohttp.request(
                    'POST', url,
                    data='{"русский": "текст"}'.encode('1251'),
                    headers={'Content-Type': 'application/json'},
                    loop=self.loop)
                self.assertEqual(400, response.status)
                data = yield from response.read()
                j = json.loads(data.decode('utf-8'))
                self.assertEqual({"error_code": 400,
                                  "error_reason": "JSON body is not utf-8 encoded",
                                  "error": {}}, j)

            self.loop.run_until_complete(query())
        finally:
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_status_code(self):
        """POST to ``/create`` answers 201 with body ``{"created": true}``.

        The server is closed in ``finally`` so a failed assertion does not
        leave it listening.
        """
        srv = self.loop.run_until_complete(self.loop.create_server(
            self.server.make_handler(loop=self.loop),
            '127.0.0.1', 0))
        try:
            self.port = port = server_port(srv)
            url = 'http://127.0.0.1:{}/create'.format(port)

            # Decorated for consistency with the other generator-based
            # coroutines in this file.
            @asyncio.coroutine
            def query():
                response = yield from aiohttp.request(
                    'POST', url,
                    headers={'Content-Type': 'application/json'},
                    loop=self.loop)
                self.assertEqual(201, response.status)
                data = yield from response.read()
                self.assertEqual(b'{"created": true}', data)

            self.loop.run_until_complete(query())
        finally:
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_no_session(self):
        """GET ``/check/no/session`` answers 200.

        The listening server is closed in ``finally`` so it is released
        whether or not the assertion passes (previously it was never closed
        at all).
        """
        srv = self.loop.run_until_complete(self.loop.create_server(
            self.server.make_handler(loop=self.loop),
            '127.0.0.1', 0))
        try:
            self.port = port = server_port(srv)
            url = 'http://127.0.0.1:{}/check/no/session'.format(port)

            @asyncio.coroutine
            def query():
                resp = yield from aiohttp.request('GET', url, loop=self.loop)
                self.assertEqual(200, resp.status)
                yield from resp.read()

            self.loop.run_until_complete(query())
        finally:
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())
示例#6
0
class CorsTests(unittest.TestCase):

    def setUp(self):
        """Create a private event loop and a CORS-enabled server with two routes.

        ``/`` uses the server-wide CORS behaviour. ``/check_origin`` restricts
        the allowed origin to ``http://example.com/``.
        """
        self.loop = asyncio.new_event_loop()
        # No default loop is installed; the tests pass ``loop=`` explicitly.
        asyncio.set_event_loop(None)
        self.server = RESTServer(hostname='localhost', debug=True,
                                 enable_cors=True)
        register = self.server.add_url

        handlers = REST(self)
        example_only = {'allow-origin': 'http://example.com/'}
        register('GET', '/', handlers.index)
        register('GET', '/check_origin', handlers.check_origin,
                 cors_options=example_only)

    def tearDown(self):
        """Close the per-test event loop."""
        loop = self.loop
        loop.close()

    @contextlib.contextmanager
    def run_server(self):
        """Serve ``self.server`` on an ephemeral 127.0.0.1 port for a ``with`` body.

        Yields the base URL ``http://127.0.0.1:<port>``. The listening server
        is closed in ``finally``, so it is released even when the ``with``
        body raises (for example, a failed assertion). Previously the cleanup
        after ``yield`` was skipped in that case.
        """
        self.assertTrue(self.server.cors_enabled)

        srv = self.loop.run_until_complete(self.loop.create_server(
            self.server.make_handler(loop=self.loop),
            '127.0.0.1', 0))
        try:
            self.assertEqual(len(srv.sockets), 1)
            sock = next(iter(srv.sockets))
            host, port = sock.getsockname()
            self.assertEqual('127.0.0.1', host)
            self.assertGreater(port, 0)
            url = 'http://{}:{}'.format(host, port)
            yield url
        finally:
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_simple_GET(self):
        """A GET carrying an Origin header gets ``ACCESS-CONTROL-ALLOW-ORIGIN: *``."""
        with self.run_server() as url:

            @asyncio.coroutine
            def query():
                request_headers = {'ORIGIN': 'localhost'}
                response = yield from aiohttp.request(
                    'GET', url, headers=request_headers, loop=self.loop)
                yield from response.read()
                self.assertEqual(response.status, 200)
                allow_origin = 'ACCESS-CONTROL-ALLOW-ORIGIN'
                self.assertIn(allow_origin, response.headers)
                self.assertEqual(response.headers[allow_origin], '*')

            self.loop.run_until_complete(query())

    def test_preflight(self):
        """An OPTIONS preflight for GET is answered 200 with a wildcard allow-origin."""
        with self.run_server() as url:

            @asyncio.coroutine
            def query():
                preflight = {
                    'ACCESS-CONTROL-REQUEST-METHOD': 'GET',
                    'ORIGIN': 'localhost',
                }
                response = yield from aiohttp.request(
                    'OPTIONS', url, headers=preflight, loop=self.loop)
                yield from response.read()
                self.assertEqual(response.status, 200)
                allow_origin = 'ACCESS-CONTROL-ALLOW-ORIGIN'
                self.assertIn(allow_origin, response.headers)
                self.assertEqual(response.headers[allow_origin], '*')

            self.loop.run_until_complete(query())

    def test_preflight_404(self):
        """A bare OPTIONS request (no preflight headers) is a 404 without CORS headers."""
        with self.run_server() as url:

            @asyncio.coroutine
            def query():
                response = yield from aiohttp.request(
                    'OPTIONS', url, loop=self.loop)
                yield from response.read()
                self.assertEqual(response.status, 404)
                self.assertNotIn('ACCESS-CONTROL-ALLOW-ORIGIN',
                                 response.headers)

            self.loop.run_until_complete(query())

    def test_check_origin(self):
        """Only a request whose Origin matches the route's allow-origin gets CORS headers.

        Requests to ``/check_origin`` with no Origin header, or with Origin
        ``localhost``, must carry none of the checked CORS response headers.
        A request with Origin ``http://example.com/`` must carry
        ``ACCESS-CONTROL-ALLOW-ORIGIN``.
        """
        with self.run_server() as url:
            endpoint = url + '/check_origin'
            # NOTE(review): 'ACCESS-CONTROL-ALLOW-METHOD' is singular, while the
            # standard CORS header is 'Access-Control-Allow-Methods'; if the
            # server emits the plural form this check cannot catch it. Confirm.
            cors_headers = ('ACCESS-CONTROL-ALLOW-ORIGIN',
                            'ACCESS-CONTROL-ALLOW-METHOD',
                            'ACCESS-CONTROL-ALLOW-HEADERS',
                            'ACCESS-CONTROL-ALLOW-CREDENTIALS')

            @asyncio.coroutine
            def fetch(request_headers):
                # GET the endpoint, drain the body and require a 200.
                response = yield from aiohttp.request('GET', endpoint,
                                                      headers=request_headers,
                                                      loop=self.loop)
                yield from response.read()
                self.assertEqual(response.status, 200)
                return response

            @asyncio.coroutine
            def query():
                for request_headers in ({}, {'ORIGIN': 'localhost'}):
                    response = yield from fetch(request_headers)
                    for name in cors_headers:
                        self.assertNotIn(name, response.headers)

                response = yield from fetch({'ORIGIN': 'http://example.com/'})
                self.assertIn('ACCESS-CONTROL-ALLOW-ORIGIN', response.headers)

            self.loop.run_until_complete(query())