Ejemplo n.º 1
0
class _TestSSL(tb.SSLTestCase):
    """SSL stream tests that run over UNIX domain sockets."""

    # Certificate and private-key locations, obtained by passing this
    # module's ``__file__`` and a file name to ``tb._cert_fullname``.
    ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem')
    ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem')

    def test_create_unix_server_ssl_1(self):
        """Serve threaded SSL clients through ``asyncio.start_unix_server``.

        A UNIX-socket SSL server is started in a temporary directory and
        ``TOTAL_CNT`` threaded clients each send two 1 MiB payloads,
        expecting ``b'OK'`` and ``b'SPAM'`` back.  The test passes when
        the server-side handler completed once per client.
        """
        CNT = 0  # clients whose handler ran to completion
        TOTAL_CNT = 25  # clients spawned by the test
        TIMEOUT = 10.0  # seconds allowed for all clients together

        MEGABYTE = 1024 * 1024
        A_DATA = b'A' * MEGABYTE
        B_DATA = b'B' * MEGABYTE

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        clients = []

        async def handle_client(reader, writer):
            # Server side of the exchange: read A, ack, read B, reply.
            nonlocal CNT

            first = await reader.readexactly(len(A_DATA))
            self.assertEqual(first, A_DATA)
            writer.write(b'OK')

            second = await reader.readexactly(len(B_DATA))
            self.assertEqual(second, B_DATA)
            # The reply is deliberately split across three buffer types.
            writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

            await writer.drain()
            writer.close()

            CNT += 1

        async def test_client(addr):
            done = asyncio.Future(loop=self.loop)

            def resolve(error=None):
                # Runs on the loop thread.  A future that was already
                # cancelled (e.g. by the timeout) is left untouched.
                if done.cancelled():
                    return
                if error is None:
                    done.set_result(None)
                else:
                    done.set_exception(error)

            def prog(sock):
                # Runs in the client thread; reports back to the loop.
                try:
                    sock.starttls(client_sslctx)

                    sock.connect(addr)
                    sock.send(A_DATA)

                    ack = sock.recv_all(2)
                    self.assertEqual(ack, b'OK')

                    sock.send(B_DATA)
                    reply = sock.recv_all(4)
                    self.assertEqual(reply, b'SPAM')

                    sock.close()

                except Exception as ex:
                    self.loop.call_soon_threadsafe(resolve, ex)
                else:
                    self.loop.call_soon_threadsafe(resolve)

            worker = self.unix_client(prog)
            worker.start()
            clients.append(worker)

            await done

        async def start_server():
            # ssl_handshake_timeout is only passed where it is supported.
            if self.implementation != 'asyncio' or self.PY37:
                extras = dict(ssl_handshake_timeout=10.0)
            else:
                extras = {}

            with tempfile.TemporaryDirectory() as td:
                sock_name = os.path.join(td, 'sock')

                srv = await asyncio.start_unix_server(
                    handle_client, sock_name,
                    ssl=sslctx, loop=self.loop, **extras)

                try:
                    pending = [test_client(sock_name)
                               for _ in range(TOTAL_CNT)]

                    await asyncio.wait_for(
                        asyncio.gather(*pending, loop=self.loop),
                        TIMEOUT,
                        loop=self.loop)

                finally:
                    self.loop.call_soon(srv.close)
                    await srv.wait_closed()

        try:
            with self._silence_eof_received_warning():
                self.loop.run_until_complete(start_server())
        except asyncio.TimeoutError:
            if os.environ.get('TRAVIS_OS_NAME') != 'osx':
                raise
            # XXX: figure out why this fails on macOS on Travis
            raise unittest.SkipTest('unexplained error on Travis macOS')

        self.assertEqual(CNT, TOTAL_CNT)

        for worker in clients:
            worker.stop()

    def test_create_unix_connection_ssl_1(self):
        """Connect asyncio SSL clients to a threaded UNIX-socket server.

        ``TOTAL_CNT`` coroutines open an SSL connection with
        ``asyncio.open_unix_connection``, send two 1 MiB payloads and
        check the ``b'OK'`` / ``b'SPAM'`` replies from the server.
        """
        CNT = 0
        TOTAL_CNT = 25

        MEGABYTE = 1024 * 1024
        A_DATA = b'A' * MEGABYTE
        B_DATA = b'B' * MEGABYTE

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        def server(sock):
            # Threaded server side: upgrade to TLS, then answer each
            # expected payload with its reply.
            sock.starttls(sslctx, server_side=True)

            for expected, reply in ((A_DATA, b'OK'), (B_DATA, b'SPAM')):
                received = sock.recv_all(len(expected))
                self.assertEqual(received, expected)
                sock.send(reply)

            sock.close()

        async def client(addr):
            nonlocal CNT

            # ssl_handshake_timeout is only passed where it is supported.
            if self.implementation != 'asyncio' or self.PY37:
                extras = dict(ssl_handshake_timeout=10.0)
            else:
                extras = {}

            reader, writer = await asyncio.open_unix_connection(
                addr,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop,
                **extras)

            writer.write(A_DATA)
            ack = await reader.readexactly(2)
            self.assertEqual(ack, b'OK')

            writer.write(B_DATA)
            reply = await reader.readexactly(4)
            self.assertEqual(reply, b'SPAM')

            CNT += 1

            writer.close()

        def run(coro):
            # Reset the success counter, then run TOTAL_CNT clients
            # against a fresh server.
            nonlocal CNT
            CNT = 0

            with self.unix_server(server,
                                  max_clients=TOTAL_CNT,
                                  backlog=TOTAL_CNT) as srv:
                pending = [coro(srv.addr) for _ in range(TOTAL_CNT)]

                self.loop.run_until_complete(
                    asyncio.gather(*pending, loop=self.loop))

            self.assertEqual(CNT, TOTAL_CNT)

        with self._silence_eof_received_warning():
            run(client)
Ejemplo n.º 2
0
class _ContextBaseTests(tb.SSLTestCase):
    """Tests for how ``contextvars`` values are seen by tasks and callbacks."""

    # Certificate and private-key locations, obtained by passing this
    # module's ``__file__`` and a file name to ``tb._cert_fullname``.
    ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem')
    ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem')

    def test_task_decimal_context(self):
        """Check that each coroutine keeps its own decimal precision.

        Two coroutines set different precisions in a local decimal
        context, sleep, and divide again; each must see its own
        precision both before and after the sleep.
        """
        async def divide_twice(delay, digits, numerator, denominator):
            with decimal.localcontext() as local_ctx:
                local_ctx.prec = digits
                before = (decimal.Decimal(numerator) /
                          decimal.Decimal(denominator))
                await asyncio.sleep(delay)
                after = (decimal.Decimal(numerator) /
                         decimal.Decimal(denominator ** 2))
            return before, after

        async def main():
            return await asyncio.gather(divide_twice(0.1, 3, 1, 3),
                                        divide_twice(0.2, 6, 1, 3))

        results = self.loop.run_until_complete(main())
        coarse, fine = results

        expectations = (
            (coarse, ('0.333', '0.111')),
            (fine, ('0.333333', '0.111111')),
        )
        for (before, after), (want_before, want_after) in expectations:
            self.assertEqual(str(before), want_before)
            self.assertEqual(str(after), want_after)

    def test_task_context_1(self):
        """Check that parent and child tasks do not see each other's changes.

        The child task is created before the parent changes ``cvar``, so
        the child must still read the default; the child's own change
        must not be visible to the parent after it is awaited.
        """
        cvar = contextvars.ContextVar('cvar', default='nope')

        async def child():
            await asyncio.sleep(0.01)
            self.assertEqual(cvar.get(), 'nope')
            cvar.set('something else')

        async def parent():
            self.assertEqual(cvar.get(), 'nope')
            child_task = self.loop.create_task(child())
            # Set after the child task was created.
            cvar.set('yes')
            self.assertEqual(cvar.get(), 'yes')
            await child_task
            self.assertEqual(cvar.get(), 'yes')

        self.loop.run_until_complete(self.loop.create_task(parent()))

    def test_task_context_2(self):
        """Check that a future's done-callback cannot change the task's context.

        The callback registered with ``add_done_callback`` sets ``cvar``;
        the awaiting task must keep reading the value it set itself, and
        the test method's own context must still hold the default.
        """
        cvar = contextvars.ContextVar('cvar', default='nope')

        async def main():
            def fut_on_done(fut):
                # This change must not pollute the context
                # of the "main()" task.
                cvar.set('something else')

            self.assertEqual(cvar.get(), 'nope')

            for j in range(2):
                fut = self.loop.create_future()
                fut.add_done_callback(fut_on_done)
                # Set after the callback was registered, before the
                # future is resolved.
                cvar.set('yes{}'.format(j))
                self.loop.call_soon(fut.set_result, None)
                await fut
                self.assertEqual(cvar.get(), 'yes{}'.format(j))

                for i in range(3):
                    # Test that task passed its context to add_done_callback:
                    cvar.set('yes{}-{}'.format(i, j))
                    await asyncio.sleep(0.001)
                    self.assertEqual(cvar.get(), 'yes{}-{}'.format(i, j))

        task = self.loop.create_task(main())
        self.loop.run_until_complete(task)

        # Nothing done inside the task may leak into the caller's context.
        self.assertEqual(cvar.get(), 'nope')

    def test_task_context_3(self):
        """Check that 100 concurrent tasks each keep their own ``cvar`` value.

        Every task repeatedly sets ``cvar``, sleeps for a random short
        interval and verifies it still reads its own value.  All tasks
        are awaited to completion; the first failure recorded by any of
        them is then re-raised so it actually fails the test.
        """
        cvar = contextvars.ContextVar('cvar', default=-1)

        # Run 100 Tasks in parallel, each modifying cvar.
        async def sub(num):
            for i in range(10):
                cvar.set(num + i)
                await asyncio.sleep(random.uniform(0.001, 0.05))
                self.assertEqual(cvar.get(), num + i)

        async def main():
            tasks = [self.loop.create_task(sub(random.randint(0, 10)))
                     for _ in range(100)]

            # return_exceptions=True lets every task finish before the
            # outcome is inspected; on its own it would swallow assertion
            # failures raised inside sub(), so they are re-raised below.
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, BaseException):
                    raise result

        self.loop.run_until_complete(main())

        # The tasks' changes must not leak into the caller's context.
        self.assertEqual(cvar.get(), -1)

    def test_task_context_4(self):
        """Check that a value stored in a task's context is released.

        A task stores an object in ``cvar`` and schedules a callback.
        Once the tasks have finished and the local name is deleted, the
        weak reference to the object must be dead.
        """
        cvar = contextvars.ContextVar('cvar', default='nope')

        class TrackMe:
            pass

        tracked = TrackMe()
        ref = weakref.ref(tracked)

        async def sub():
            cvar.set(tracked)  # NoQA
            # Schedule a no-op callback while ``tracked`` is in the context.
            self.loop.call_soon(lambda: None)

        async def main():
            await self.loop.create_task(sub())
            await asyncio.sleep(0.01)

        task = self.loop.create_task(main())
        self.loop.run_until_complete(task)

        # Drop the local name; the assertion below requires that no
        # other reference to the object remains.
        del tracked
        self.assertIsNone(ref())

    def _run_test(self, method, **switches):
        """Run ``method`` once per combination of protocol class and switches.

        Each keyword in ``switches`` must be ``'yes'``, ``'no'`` or
        ``'both'`` and expands to ``True``, ``False`` or both values.
        ``use_tcp`` defaults to ``'both'`` and ``use_ssl`` to ``'no'``.
        Every combination runs in its own ``subTest`` with a fresh
        ContextVar whose default is ``'outer'`` and which is set to
        ``'inner'`` before ``method`` is awaited.

        ``method`` is awaited with keyword arguments ``cvar``, ``proto``,
        ``sslctx``, ``client_sslctx``, ``addr``, ``family`` and one
        boolean per switch.

        Raises:
            ValueError: if a switch value is not ``yes``/``no``/``both``.
        """
        switches.setdefault('use_tcp', 'both')
        # SSL contexts are built whenever SSL may be used by any combination.
        use_ssl = switches.setdefault('use_ssl', 'no') in {'yes', 'both'}
        names = ['factory']
        options = [(_Protocol, _BufferedProtocol)]
        for k, v in switches.items():
            if v == 'yes':
                options.append((True, ))
            elif v == 'no':
                options.append((False, ))
            elif v == 'both':
                options.append((True, False))
            else:
                raise ValueError(f"Illegal {k}={v}, can only be yes/no/both")
            names.append(k)

        for combo in itertools.product(*options):
            values = dict(zip(names, combo))
            with self.subTest(**values):
                cvar = contextvars.ContextVar('cvar', default='outer')
                # Replace the protocol class with an instance bound to cvar.
                values['proto'] = values.pop('factory')(cvar, loop=self.loop)

                async def test():
                    self.assertEqual(cvar.get(), 'outer')
                    cvar.set('inner')
                    # The context manager removes the directory even when
                    # the setup below raises before ``method`` is awaited.
                    with tempfile.TemporaryDirectory() as tmp_path:
                        if use_ssl:
                            values['sslctx'] = \
                                self._create_server_ssl_context(
                                    self.ONLYCERT, self.ONLYKEY)
                            values['client_sslctx'] = \
                                self._create_client_ssl_context()
                        else:
                            values['sslctx'] = None
                            values['client_sslctx'] = None

                        if values['use_tcp']:
                            values['addr'] = ('127.0.0.1',
                                              tb.find_free_port())
                            values['family'] = socket.AF_INET
                        else:
                            values['addr'] = tmp_path + '/test.sock'
                            values['family'] = socket.AF_UNIX

                        await method(cvar=cvar, **values)

                self.loop.run_until_complete(test())

    def _run_server_test(self, method, async_sock=False, **switches):
        """Run ``method`` against a server that uses the test protocol.

        For each combination produced by ``_run_test`` a TCP or UNIX
        server is created whose protocol factory returns ``proto``, and
        a client socket is connected to it.  With ``async_sock`` the
        client is non-blocking and connected through the loop; otherwise
        it is connected in an executor and, when ``use_ssl`` is true,
        wrapped with the client SSL context and passed on as
        ``ssl_sock``.

        ``method`` is awaited with ``s`` (the client socket) plus the
        remaining values.  The sockets and the server are closed
        afterwards, including when connecting or the TLS wrap fails.
        """
        async def test(sslctx, client_sslctx, addr, family, **values):
            if values['use_tcp']:
                srv = await self.loop.create_server(lambda: values['proto'],
                                                    *addr,
                                                    ssl=sslctx)
            else:
                srv = await self.loop.create_unix_server(
                    lambda: values['proto'], addr, ssl=sslctx)
            s = socket.socket(family)

            try:
                # Connecting happens inside the try block so that a
                # failure here still releases the socket and the server.
                if async_sock:
                    s.setblocking(False)
                    await self.loop.sock_connect(s, addr)
                else:
                    await self.loop.run_in_executor(None, s.connect, addr)
                    if values['use_ssl']:
                        values['ssl_sock'] = await self.loop.run_in_executor(
                            None, client_sslctx.wrap_socket, s)

                await method(s=s, **values)
            finally:
                # Only close the wrapped socket if one was created; this
                # avoids a KeyError hiding the original failure.
                ssl_sock = values.get('ssl_sock')
                if ssl_sock is not None:
                    ssl_sock.close()
                s.close()
                srv.close()
                await srv.wait_closed()

        return self._run_test(test, **switches)

    def test_create_server_protocol_factory_context(self):
        """Check that the server's protocol factory sees ``cvar == 'inner'``.

        The factory records the outcome of its assertion in a future,
        which is awaited after a client has connected.
        """
        async def test(cvar, proto, use_tcp, family, addr, **_):
            factory_called_future = self.loop.create_future()

            def factory():
                # Report the assertion result through the future instead
                # of raising inside the factory call.
                try:
                    self.assertEqual(cvar.get(), 'inner')
                except Exception as e:
                    factory_called_future.set_exception(e)
                else:
                    factory_called_future.set_result(None)

                return proto

            if use_tcp:
                srv = await self.loop.create_server(factory, *addr)
            else:
                srv = await self.loop.create_unix_server(factory, addr)
            s = socket.socket(family)
            # The client socket is closed as soon as the connect completes.
            with s:
                s.setblocking(False)
                await self.loop.sock_connect(s, addr)

            try:
                await factory_called_future
            finally:
                srv.close()
                await proto.done
                await srv.wait_closed()

        self._run_test(test)

    def test_create_server_connection_protocol(self):
        """Check the values reported by the server-side protocol's futures.

        Each future on ``proto`` is expected to resolve to ``"inner"``.
        NOTE(review): the protocol classes are defined outside this
        view; presumably each future carries the ContextVar value seen
        by the matching callback -- confirm against their definition.
        """
        async def test(proto, s, **_):
            inner = await proto.connection_made_fut
            self.assertEqual(inner, "inner")

            await self.loop.sock_sendall(s, b'data')
            inner = await proto.data_received_fut
            self.assertEqual(inner, "inner")

            # Half-close the client side before waiting on eof_received_fut.
            s.shutdown(socket.SHUT_WR)
            inner = await proto.eof_received_fut
            self.assertEqual(inner, "inner")

            s.close()
            await proto.done
            self.assertEqual(proto.connection_lost_ctx, "inner")

        self._run_server_test(test, async_sock=True)

    def test_create_ssl_server_connection_protocol(self):
        """SSL variant of the server-side protocol callback check.

        Each future on ``proto`` is expected to resolve to ``"inner"``,
        including after reading was paused and then resumed from a
        callback that changed ``cvar``.
        NOTE(review): the protocol classes are defined outside this
        view; presumably each future carries the ContextVar value seen
        by the matching callback -- confirm against their definition.
        """
        async def test(cvar, proto, ssl_sock, **_):
            def resume_reading(transport):
                # Changes cvar before resuming; the later assertion
                # still expects "inner".
                cvar.set("resume_reading")
                transport.resume_reading()

            try:
                inner = await proto.connection_made_fut
                self.assertEqual(inner, "inner")

                # ssl_sock is a blocking socket, so it is used in an executor.
                await self.loop.run_in_executor(None, ssl_sock.send, b'data')
                inner = await proto.data_received_fut
                self.assertEqual(inner, "inner")

                if self.implementation != 'asyncio':
                    # this seems to be a bug in asyncio
                    # Fresh future for the second data_received check.
                    proto.data_received_fut = self.loop.create_future()
                    proto.transport.pause_reading()
                    await self.loop.run_in_executor(None, ssl_sock.send,
                                                    b'data')
                    self.loop.call_soon(resume_reading, proto.transport)
                    inner = await proto.data_received_fut
                    self.assertEqual(inner, "inner")

                    await self.loop.run_in_executor(None, ssl_sock.unwrap)
                else:
                    ssl_sock.shutdown(socket.SHUT_WR)
                inner = await proto.eof_received_fut
                self.assertEqual(inner, "inner")

                await self.loop.run_in_executor(None, ssl_sock.close)
                await proto.done
                self.assertEqual(proto.connection_lost_ctx, "inner")
            finally:
                if self.implementation == 'asyncio':
                    # mute resource warning in asyncio
                    proto.transport.close()

        self._run_server_test(test, use_ssl='yes')

    def test_create_server_manual_connection_lost(self):
        """Check ``connection_lost_ctx`` after the transport is closed by hand.

        The transport is closed from a callback that first sets ``cvar``
        to ``'closing'``; ``proto.connection_lost_ctx`` must still be
        ``"inner"``.  Skipped when the implementation is asyncio.
        """
        if self.implementation == 'asyncio':
            raise unittest.SkipTest('this seems to be a bug in asyncio')

        async def test(proto, cvar, **_):
            def close():
                cvar.set('closing')
                proto.transport.close()

            inner = await proto.connection_made_fut
            self.assertEqual(inner, "inner")

            self.loop.call_soon(close)

            await proto.done
            self.assertEqual(proto.connection_lost_ctx, "inner")

        self._run_server_test(test, async_sock=True)

    def test_create_ssl_server_manual_connection_lost(self):
        """SSL variant of the manual transport-close check.

        The transport is closed from a callback that first sets ``cvar``
        to ``'closing'``.  ``proto.connection_lost_ctx`` must still be
        ``"inner"`` and ``data_received_fut`` must not have completed.
        """
        async def test(proto, cvar, ssl_sock, **_):
            def close():
                cvar.set('closing')
                proto.transport.close()

            inner = await proto.connection_made_fut
            self.assertEqual(inner, "inner")

            if self.implementation == 'asyncio':
                self.loop.call_soon(close)
            else:
                # asyncio doesn't have the flushing phase

                # put the incoming data on-hold
                proto.transport.pause_reading()
                # send data
                await self.loop.run_in_executor(None, ssl_sock.send, b'hello')
                # schedule a proactive transport close which will trigger
                # the flushing process to retrieve the remaining data
                self.loop.call_soon(close)
                # turn off the reading lock now (this also schedules a
                # resume operation after transport.close, therefore it
                # won't affect our test)
                proto.transport.resume_reading()

            # Yield to the loop once so the scheduled callbacks can run.
            await asyncio.sleep(0)
            await self.loop.run_in_executor(None, ssl_sock.unwrap)
            await proto.done
            self.assertEqual(proto.connection_lost_ctx, "inner")
            self.assertFalse(proto.data_received_fut.done())

        self._run_server_test(test, use_ssl='yes')

    def test_create_connection_protocol(self):
        """Check the values reported by a client-side protocol's futures.

        A blocking listening socket accepts one connection in an
        executor while the loop connects with ``create_connection`` or
        ``create_unix_connection``, either by address or with a
        pre-connected socket, with and without SSL.  Every future on
        ``proto`` is expected to resolve to ``"inner"``.
        """
        async def test(cvar, proto, addr, sslctx, client_sslctx, family,
                       use_sock, use_ssl, use_tcp):
            ss = socket.socket(family)
            ss.bind(addr)
            ss.listen(1)

            def accept():
                # Runs in an executor thread: blocking accept plus
                # optional server-side TLS wrap.
                sock, _ = ss.accept()
                if use_ssl:
                    sock = sslctx.wrap_socket(sock, server_side=True)
                return sock

            async def write_over():
                # Runs as its own task and changes cvar; returns the
                # number of 16 KiB chunks that were written.
                cvar.set("write_over")
                count = 0
                if use_ssl:
                    proto.transport.set_write_buffer_limits(high=256, low=128)
                    while not proto.transport.get_write_buffer_size():
                        proto.transport.write(b'q' * 16384)
                        count += 1
                else:
                    proto.transport.write(b'q' * 16384)
                    proto.transport.set_write_buffer_limits(high=256, low=128)
                    count += 1
                return count

            # ``s`` is the pending accept future until it is awaited below.
            s = self.loop.run_in_executor(None, accept)

            try:
                method = ('create_connection'
                          if use_tcp else 'create_unix_connection')
                params = {}
                if use_sock:
                    cs = socket.socket(family)
                    cs.connect(addr)
                    params['sock'] = cs
                    if use_ssl:
                        params['server_hostname'] = '127.0.0.1'
                elif use_tcp:
                    params['host'] = addr[0]
                    params['port'] = addr[1]
                else:
                    params['path'] = addr
                    if use_ssl:
                        params['server_hostname'] = '127.0.0.1'
                if use_ssl:
                    params['ssl'] = client_sslctx
                await getattr(self.loop, method)(lambda: proto, **params)
                # From here on ``s`` is the accepted (possibly TLS) socket.
                s = await s

                inner = await proto.connection_made_fut
                self.assertEqual(inner, "inner")

                await self.loop.run_in_executor(None, s.send, b'data')
                inner = await proto.data_received_fut
                self.assertEqual(inner, "inner")

                if self.implementation != 'asyncio':
                    # asyncio bug
                    count = await self.loop.create_task(write_over())
                    inner = await proto.pause_writing_fut
                    self.assertEqual(inner, "inner")

                    # Read back one chunk per write on the peer side.
                    for i in range(count):
                        await self.loop.run_in_executor(None, s.recv, 16384)
                    inner = await proto.resume_writing_fut
                    self.assertEqual(inner, "inner")

                if use_ssl and self.implementation != 'asyncio':
                    await self.loop.run_in_executor(None, s.unwrap)
                else:
                    s.shutdown(socket.SHUT_WR)
                inner = await proto.eof_received_fut
                self.assertEqual(inner, "inner")

                s.close()
                await proto.done
                self.assertEqual(proto.connection_lost_ctx, "inner")
            finally:
                ss.close()
                proto.transport.close()

        self._run_test(test, use_sock='both', use_ssl='both')

    def test_start_tls(self):
        """Check protocol futures after upgrading with ``loop.start_tls``.

        ``cvar`` is changed right before each ``start_tls`` call; the
        protocol's futures must still resolve to ``"inner"``.  With
        ``ssl_over_ssl`` a second TLS layer is started on top of the
        first.  Skipped when the implementation is asyncio.
        """
        if self.implementation == 'asyncio':
            raise unittest.SkipTest('this seems to be a bug in asyncio')

        async def test(cvar, proto, addr, sslctx, client_sslctx, family,
                       ssl_over_ssl, use_tcp, **_):
            ss = socket.socket(family)
            ss.bind(addr)
            ss.listen(1)

            def accept():
                # Runs in an executor thread: blocking accept and
                # server-side TLS wrap (twice for ssl_over_ssl).
                sock, _ = ss.accept()
                sock = sslctx.wrap_socket(sock, server_side=True)
                if ssl_over_ssl:
                    sock = _SSLSocketOverSSL(sock, sslctx, server_side=True)
                return sock

            # ``s`` is the pending accept future until it is awaited below.
            s = self.loop.run_in_executor(None, accept)
            transport = None

            try:
                if use_tcp:
                    await self.loop.create_connection(lambda: proto, *addr)
                else:
                    await self.loop.create_unix_connection(lambda: proto, addr)
                inner = await proto.connection_made_fut
                self.assertEqual(inner, "inner")

                cvar.set('start_tls')
                transport = await self.loop.start_tls(
                    proto.transport,
                    proto,
                    client_sslctx,
                    server_hostname='127.0.0.1',
                )

                if ssl_over_ssl:
                    cvar.set('start_tls_over_tls')
                    transport = await self.loop.start_tls(
                        transport,
                        proto,
                        client_sslctx,
                        server_hostname='127.0.0.1',
                    )

                s = await s

                await self.loop.run_in_executor(None, s.send, b'data')
                inner = await proto.data_received_fut
                self.assertEqual(inner, "inner")

                await self.loop.run_in_executor(None, s.unwrap)
                inner = await proto.eof_received_fut
                self.assertEqual(inner, "inner")

                s.close()
                await proto.done
                self.assertEqual(proto.connection_lost_ctx, "inner")
            finally:
                ss.close()
                # transport stays None if start_tls was never reached.
                if transport:
                    transport.close()

        self._run_test(test, use_ssl='yes', ssl_over_ssl='both')

    def test_connect_accepted_socket(self):
        """Check protocol futures for ``loop.connect_accepted_socket``.

        A connection is accepted on a plain listening socket and handed
        to the loop, with and without SSL.  Every future on ``proto`` is
        expected to resolve to ``"inner"``.
        """
        async def test(proto, addr, family, sslctx, client_sslctx, use_ssl,
                       **_):
            ss = socket.socket(family)
            ss.bind(addr)
            ss.listen(1)
            # Accept in an executor while the client connects here.
            s = self.loop.run_in_executor(None, ss.accept)
            cs = socket.socket(family)
            cs.connect(addr)
            s, _ = await s

            try:
                if use_ssl:
                    # The client-side wrap is started in an executor and
                    # awaited only after connect_accepted_socket returns.
                    cs = self.loop.run_in_executor(None,
                                                   client_sslctx.wrap_socket,
                                                   cs)
                    await self.loop.connect_accepted_socket(lambda: proto,
                                                            s,
                                                            ssl=sslctx)
                    cs = await cs
                else:
                    await self.loop.connect_accepted_socket(lambda: proto, s)

                inner = await proto.connection_made_fut
                self.assertEqual(inner, "inner")

                await self.loop.run_in_executor(None, cs.send, b'data')
                inner = await proto.data_received_fut
                self.assertEqual(inner, "inner")

                if use_ssl and self.implementation != 'asyncio':
                    await self.loop.run_in_executor(None, cs.unwrap)
                else:
                    cs.shutdown(socket.SHUT_WR)
                inner = await proto.eof_received_fut
                self.assertEqual(inner, "inner")

                cs.close()
                await proto.done
                self.assertEqual(proto.connection_lost_ctx, "inner")
            finally:
                proto.transport.close()
                ss.close()

        self._run_test(test, use_ssl='both')

    def test_subprocess_protocol(self):
        """Check the futures reported by a subprocess protocol.

        A subprocess is started with ``_AsyncioTests.PROGRAM_CAT``, data
        is written to its pipe 0 followed by EOF, and the protocol's
        futures are expected to resolve to ``"inner"``.  Two of the
        checks are skipped when the implementation is asyncio.
        """
        cvar = contextvars.ContextVar('cvar', default='outer')
        proto = _SubprocessProtocol(cvar, loop=self.loop)

        async def test():
            self.assertEqual(cvar.get(), 'outer')
            cvar.set('inner')
            await self.loop.subprocess_exec(lambda: proto,
                                            *_AsyncioTests.PROGRAM_CAT)

            try:
                inner = await proto.connection_made_fut
                self.assertEqual(inner, "inner")

                proto.transport.get_pipe_transport(0).write(b'data')
                proto.transport.get_pipe_transport(0).write_eof()
                inner = await proto.data_received_fut
                self.assertEqual(inner, "inner")

                inner = await proto.pipe_connection_lost_fut
                self.assertEqual(inner, "inner")

                # Awaited under every implementation; the value is only
                # checked when the implementation is not asyncio.
                inner = await proto.process_exited_fut
                if self.implementation != 'asyncio':
                    # bug in asyncio
                    self.assertEqual(inner, "inner")

                await proto.done
                if self.implementation != 'asyncio':
                    # bug in asyncio
                    self.assertEqual(proto.connection_lost_ctx, "inner")
            finally:
                proto.transport.close()

        self.loop.run_until_complete(test())

    def test_datagram_protocol(self):
        """Check the futures reported by a datagram protocol.

        A datagram endpoint is bound to ``127.0.0.1:8888``, one datagram
        is sent to it from a plain UDP socket, and the transport is then
        closed from a callback that first sets ``cvar`` to
        ``'closing'``.  The protocol's futures are expected to resolve
        to ``"inner"``; the ``connection_lost_ctx`` check is skipped
        when the implementation is asyncio.
        """
        cvar = contextvars.ContextVar('cvar', default='outer')
        proto = _DatagramProtocol(cvar, loop=self.loop)
        server_addr = ('127.0.0.1', 8888)
        client_addr = ('127.0.0.1', 0)

        async def run():
            self.assertEqual(cvar.get(), 'outer')
            cvar.set('inner')

            def close():
                cvar.set('closing')
                proto.transport.close()

            # Bound before the ``try`` so the cleanup below can tell
            # whether the client socket was ever created; otherwise an
            # early failure would surface as UnboundLocalError instead.
            s = None
            try:
                await self.loop.create_datagram_endpoint(
                    lambda: proto, local_addr=server_addr)
                inner = await proto.connection_made_fut
                self.assertEqual(inner, "inner")

                s = socket.socket(socket.AF_INET, type=socket.SOCK_DGRAM)
                s.bind(client_addr)
                s.sendto(b'data', server_addr)
                inner = await proto.data_received_fut
                self.assertEqual(inner, "inner")

                self.loop.call_soon(close)
                await proto.done
                if self.implementation != 'asyncio':
                    # bug in asyncio
                    self.assertEqual(proto.connection_lost_ctx, "inner")
            finally:
                try:
                    proto.transport.close()
                finally:
                    # Close the client socket even if closing the
                    # transport failed.
                    if s is not None:
                        s.close()
                    # let transports close
                    await asyncio.sleep(0.1)

        self.loop.run_until_complete(run())
Ejemplo n.º 3
0
class _TestSSL(tb.SSLTestCase):
    """SSL stream tests that run over TCP sockets."""

    # Certificate and private-key locations, obtained by passing this
    # module's ``__file__`` and a file name to ``tb._cert_fullname``.
    ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem')
    ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem')

    def test_create_server_ssl_1(self):
        """Serve threaded SSL clients through ``asyncio.start_server``.

        A TCP SSL server is started on ``127.0.0.1`` with port 0 and
        ``TOTAL_CNT`` threaded clients each send two 1 MiB payloads,
        expecting ``b'OK'`` and ``b'SPAM'`` back.  The test passes when
        the server-side handler completed once per client.
        """
        CNT = 0  # number of clients that were successful
        TOTAL_CNT = 25  # total number of clients that test will create
        TIMEOUT = 10.0  # timeout for this test

        A_DATA = b'A' * 1024 * 1024
        B_DATA = b'B' * 1024 * 1024

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        clients = []

        async def handle_client(reader, writer):
            # Server side of the exchange: read A, ack, read B, reply.
            nonlocal CNT

            data = await reader.readexactly(len(A_DATA))
            self.assertEqual(data, A_DATA)
            writer.write(b'OK')

            data = await reader.readexactly(len(B_DATA))
            self.assertEqual(data, B_DATA)
            # The reply is deliberately split across three buffer types.
            writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

            await writer.drain()
            writer.close()

            CNT += 1

        async def test_client(addr):
            fut = asyncio.Future(loop=self.loop)

            def prog(sock):
                # Runs in the client thread; reports back to the loop.
                try:
                    sock.starttls(client_sslctx)
                    sock.connect(addr)
                    sock.send(A_DATA)

                    data = sock.recv_all(2)
                    self.assertEqual(data, b'OK')

                    sock.send(B_DATA)
                    data = sock.recv_all(4)
                    self.assertEqual(data, b'SPAM')

                    sock.close()

                except Exception as ex:
                    # The future may already be cancelled (wait_for
                    # timed out); resolving it then would raise
                    # InvalidStateError on the loop, so check first.
                    self.loop.call_soon_threadsafe(lambda ex=ex: (
                        fut.cancelled() or fut.set_exception(ex)))
                else:
                    self.loop.call_soon_threadsafe(
                        lambda: (fut.cancelled() or fut.set_result(None)))

            client = self.tcp_client(prog)
            client.start()
            clients.append(client)

            await fut

        async def start_server():
            srv = await asyncio.start_server(handle_client,
                                             '127.0.0.1',
                                             0,
                                             family=socket.AF_INET,
                                             ssl=sslctx,
                                             loop=self.loop)

            try:
                srv_socks = srv.sockets
                self.assertTrue(srv_socks)

                # Port 0 was requested, so read the actual address back.
                addr = srv_socks[0].getsockname()

                tasks = []
                for _ in range(TOTAL_CNT):
                    tasks.append(test_client(addr))

                await asyncio.wait_for(asyncio.gather(*tasks, loop=self.loop),
                                       TIMEOUT,
                                       loop=self.loop)

            finally:
                self.loop.call_soon(srv.close)
                await srv.wait_closed()

        with self._silence_eof_received_warning():
            self.loop.run_until_complete(start_server())

        self.assertEqual(CNT, TOTAL_CNT)

        for client in clients:
            client.stop()

    def test_create_connection_ssl_1(self):
        """Connect asyncio SSL clients to a threaded TCP server.

        The scenario runs twice: once connecting by address and once
        with a pre-connected socket.  In each run ``TOTAL_CNT``
        coroutines send two 1 MiB payloads and check the ``b'OK'`` /
        ``b'SPAM'`` replies from the server.
        """
        CNT = 0
        TOTAL_CNT = 25

        MEGABYTE = 1024 * 1024
        A_DATA = b'A' * MEGABYTE
        B_DATA = b'B' * MEGABYTE

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        def server(sock):
            # Threaded server side: upgrade to TLS, then answer each
            # expected payload with its reply.
            sock.starttls(sslctx, server_side=True)

            for expected, reply in ((A_DATA, b'OK'), (B_DATA, b'SPAM')):
                received = sock.recv_all(len(expected))
                self.assertEqual(received, expected)
                sock.send(reply)

            sock.close()

        async def exchange(reader, writer):
            # Client side of the exchange, shared by both connection
            # styles; counts the client as successful before closing.
            nonlocal CNT

            writer.write(A_DATA)
            self.assertEqual(await reader.readexactly(2), b'OK')

            writer.write(B_DATA)
            self.assertEqual(await reader.readexactly(4), b'SPAM')

            CNT += 1

            writer.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop)
            await exchange(reader, writer)

        async def client_sock(addr):
            sock = socket.socket()
            sock.connect(addr)
            reader, writer = await asyncio.open_connection(
                sock=sock,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop)
            await exchange(reader, writer)
            sock.close()

        def run(coro):
            # Reset the success counter, then run TOTAL_CNT clients
            # against a fresh server.
            nonlocal CNT
            CNT = 0

            with self.tcp_server(server,
                                 max_clients=TOTAL_CNT,
                                 backlog=TOTAL_CNT) as srv:
                pending = [coro(srv.addr) for _ in range(TOTAL_CNT)]

                self.loop.run_until_complete(
                    asyncio.gather(*pending, loop=self.loop))

            self.assertEqual(CNT, TOTAL_CNT)

        for scenario in (client, client_sock):
            with self._silence_eof_received_warning():
                run(scenario)
Ejemplo n.º 4
0
class _TestSSL(tb.SSLTestCase):

    ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem')
    ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem')

    def test_create_server_ssl_1(self):
        """Serve ``TOTAL_CNT`` TLS clients from an ``asyncio.start_server``.

        Every client is started through ``self.tcp_client``; it sends
        ``A_DATA`` and ``B_DATA`` and expects ``b'OK'`` and ``b'SPAM'`` in
        return.  The test passes when ``handle_client`` ran to completion
        once per client within ``TIMEOUT`` seconds.
        """
        CNT = 0  # number of clients that were successful
        TOTAL_CNT = 25  # total number of clients that test will create
        TIMEOUT = 10.0  # timeout for this test

        A_DATA = b'A' * 1024 * 1024
        B_DATA = b'B' * 1024 * 1024

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        clients = []

        async def handle_client(reader, writer):
            # Server side of the exchange; counts a client as successful
            # only after both payloads were verified and the reply flushed.
            nonlocal CNT

            data = await reader.readexactly(len(A_DATA))
            self.assertEqual(data, A_DATA)
            writer.write(b'OK')

            data = await reader.readexactly(len(B_DATA))
            self.assertEqual(data, B_DATA)
            # Mixed bytes-like types on purpose: together they spell b'SPAM'.
            writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

            await writer.drain()
            writer.close()

            CNT += 1

        async def test_client(addr):
            # Bridges the client program's outcome back into the event loop.
            fut = asyncio.Future(loop=self.loop)

            def settle(exc=None):
                # Scheduled onto the loop via call_soon_threadsafe.  The
                # future may already be cancelled (e.g. the surrounding
                # wait_for timed out); resolving it then would raise
                # InvalidStateError inside a loop callback, so skip it.
                if fut.cancelled():
                    return
                if exc is None:
                    fut.set_result(None)
                else:
                    fut.set_exception(exc)

            def prog(sock):
                try:
                    sock.starttls(client_sslctx)
                    sock.connect(addr)
                    sock.send(A_DATA)

                    data = sock.recv_all(2)
                    self.assertEqual(data, b'OK')

                    sock.send(B_DATA)
                    data = sock.recv_all(4)
                    self.assertEqual(data, b'SPAM')

                    sock.close()

                except Exception as ex:
                    self.loop.call_soon_threadsafe(settle, ex)
                else:
                    self.loop.call_soon_threadsafe(settle)

            client = self.tcp_client(prog)
            client.start()
            clients.append(client)

            await fut

        async def start_server():
            # ssl_handshake_timeout is only passed when not running stock
            # asyncio, or when self.PY37 is set (presumably Python >= 3.7,
            # where asyncio accepts the argument -- confirm in the test base).
            extras = {}
            if self.implementation != 'asyncio' or self.PY37:
                extras = dict(ssl_handshake_timeout=10.0)

            srv = await asyncio.start_server(handle_client,
                                             '127.0.0.1',
                                             0,
                                             family=socket.AF_INET,
                                             ssl=sslctx,
                                             loop=self.loop,
                                             **extras)

            try:
                srv_socks = srv.sockets
                self.assertTrue(srv_socks)

                # Port 0 was requested, so read back the bound address.
                addr = srv_socks[0].getsockname()

                tasks = []
                for _ in range(TOTAL_CNT):
                    tasks.append(test_client(addr))

                await asyncio.wait_for(asyncio.gather(*tasks, loop=self.loop),
                                       TIMEOUT,
                                       loop=self.loop)

            finally:
                # close() is scheduled rather than called directly, so
                # wait_closed() is awaited while the server is still open.
                self.loop.call_soon(srv.close)
                await srv.wait_closed()

        with self._silence_eof_received_warning():
            self.loop.run_until_complete(start_server())

        self.assertEqual(CNT, TOTAL_CNT)

        for client in clients:
            client.stop()

    def test_create_connection_ssl_1(self):
        """Connect ``TOTAL_CNT`` asyncio TLS clients to a socket-based server.

        The scenario is run twice: once with clients that let
        ``asyncio.open_connection`` connect by address, and once with clients
        that pass in an already-connected socket.  Each run must end with
        ``CNT == TOTAL_CNT``.
        """
        if self.implementation == 'asyncio':
            # Don't crash on asyncio errors
            self.loop.set_exception_handler(None)

        CNT = 0
        TOTAL_CNT = 25

        A_DATA = b'A' * 1024 * 1024
        B_DATA = b'B' * 1024 * 1024

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        def server(sock):
            # Server program: upgrade to TLS, then answer each expected
            # payload with its fixed reply.
            sock.starttls(sslctx, server_side=True)

            for expected, reply in ((A_DATA, b'OK'), (B_DATA, b'SPAM')):
                data = sock.recv_all(len(expected))
                self.assertEqual(data, expected)
                sock.send(reply)

            sock.close()

        async def exchange(reader, writer):
            # Client half of the protocol, shared by both client flavours.
            # Counts a success only after both replies matched.
            nonlocal CNT

            writer.write(A_DATA)
            self.assertEqual(await reader.readexactly(2), b'OK')

            writer.write(B_DATA)
            self.assertEqual(await reader.readexactly(4), b'SPAM')

            CNT += 1

            writer.close()

        async def client(addr):
            # Client that connects by address.
            if self.implementation != 'asyncio' or self.PY37:
                extras = dict(ssl_handshake_timeout=10.0)
            else:
                extras = {}

            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop,
                **extras)

            await exchange(reader, writer)

        async def client_sock(addr):
            # Client that hands a pre-connected socket to open_connection.
            sock = socket.socket()
            sock.connect(addr)

            reader, writer = await asyncio.open_connection(
                sock=sock,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop)

            await exchange(reader, writer)
            sock.close()

        def run(coro):
            # Reset the counter, start a server and run TOTAL_CNT clients.
            nonlocal CNT
            CNT = 0

            with self.tcp_server(server,
                                 max_clients=TOTAL_CNT,
                                 backlog=TOTAL_CNT) as srv:
                pending = [coro(srv.addr) for _ in range(TOTAL_CNT)]

                self.loop.run_until_complete(
                    asyncio.gather(*pending, loop=self.loop))

            self.assertEqual(CNT, TOTAL_CNT)

        for make_client in (client, client_sock):
            with self._silence_eof_received_warning():
                run(make_client)

    def test_create_connection_ssl_slow_handshake(self):
        """A stalled TLS handshake must abort the connection attempt.

        The server never starts TLS; it only reads from the socket.  The
        client uses ``ssl_handshake_timeout=1.0`` and is expected to fail
        with ``ConnectionAbortedError`` whose message matches
        ``SSL handshake.*is taking longer``.  Skipped on stock asyncio.
        """
        if self.implementation == 'asyncio':
            raise unittest.SkipTest()

        client_sslctx = self._create_client_ssl_context()

        def ignore_loop_errors(*args):
            # silence error logger
            return None

        self.loop.set_exception_handler(ignore_loop_errors)

        def server(sock):
            # Read without ever answering the handshake.
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        async def client(addr):
            # Expected to raise before the streams are ever used.
            _reader, _writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop,
                ssl_handshake_timeout=1.0)

        expected_msg = r'SSL handshake.*is taking longer'

        with self.tcp_server(server, max_clients=1, backlog=1) as srv:
            with self.assertRaisesRegex(ConnectionAbortedError, expected_msg):
                self.loop.run_until_complete(client(srv.addr))

    def test_create_connection_ssl_failed_certificate(self):
        """Connecting with certificate verification enabled must fail.

        The client context is created with ``disable_verify=False``.  The
        connection attempt is expected to raise
        ``ssl.SSLCertVerificationError`` when ``self.PY37`` is set and
        ``ssl.SSLError`` otherwise.  Skipped on stock asyncio.
        """
        if self.implementation == 'asyncio':
            raise unittest.SkipTest()

        def ignore_loop_errors(*args):
            # silence error logger
            return None

        self.loop.set_exception_handler(ignore_loop_errors)

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context(disable_verify=False)

        def server(sock):
            try:
                sock.starttls(sslctx, server_side=True)
                # NOTE(review): connect() is called without an address;
                # presumably this line is never reached because starttls()
                # fails first -- confirm against the socket wrapper.
                sock.connect()
            except ssl.SSLError:
                pass
            finally:
                sock.close()

        async def client(addr):
            # Expected to raise during the handshake.
            _reader, _writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop,
                ssl_handshake_timeout=1.0)

        with self.tcp_server(server, max_clients=1, backlog=1) as srv:

            if self.PY37:
                expected_exc = ssl.SSLCertVerificationError
            else:
                expected_exc = ssl.SSLError

            with self.assertRaises(expected_exc):
                self.loop.run_until_complete(client(srv.addr))

    def test_ssl_handshake_timeout(self):
        """Cancelling a pending TLS handshake must abort the connection.

        The client wraps ``create_connection`` (handshake timeout 10.0) in
        ``asyncio.wait_for`` with a 0.5 second limit, so the outer timeout
        fires first.  The test then checks that the server saw a
        ``ConnectionAbortedError`` and that nothing was reported to the
        loop's exception handler.  Skipped on stock asyncio.
        """
        if self.implementation == 'asyncio':
            raise unittest.SkipTest()

        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        client_sslctx = self._create_client_ssl_context()

        # Record every context passed to the loop's exception handler so
        # the test can assert that none were reported.
        messages = []

        def record_context(loop, ctx):
            messages.append(ctx)

        self.loop.set_exception_handler(record_context)

        server_side_aborted = False

        def server(sock):
            # Never answers the handshake; notes whether the peer aborted.
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            connecting = self.loop.create_connection(
                asyncio.Protocol,
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=10.0)

            await asyncio.wait_for(connecting, 0.5, loop=self.loop)

        with self.tcp_server(server, max_clients=1, backlog=1) as srv:
            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])

    def test_ssl_connect_accepted_socket(self):
        """Run the accepted-socket connection test over TLS.

        Builds a server context (with the test certificate loaded) and a
        client context, both with verification turned off, and delegates to
        ``Test_UV_TCP.test_connect_accepted_socket``.
        """
        def disable_verification(ctx):
            # Inspect the context that is actually being modified.  The
            # ssl module rejects CERT_NONE while check_hostname is enabled,
            # so hostname checking is cleared first.
            if hasattr(ctx, 'check_hostname'):
                ctx.check_hostname = False
            ctx.verify_mode = ssl.CERT_NONE

        server_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        server_context.load_cert_chain(self.ONLYCERT, self.ONLYKEY)
        disable_verification(server_context)

        client_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        disable_verification(client_context)

        Test_UV_TCP.test_connect_accepted_socket(self, server_context,
                                                 client_context)