예제 #1
0
    def test_create_connection_ssl_failed_certificate(self):
        # The client context is created with certificate verification
        # enabled (disable_verify=False), so connecting to the test server
        # is expected to fail with ssl.SSLCertVerificationError.
        self.loop.set_exception_handler(lambda loop, ctx: None)

        server_ctx = test_utils.simple_server_sslcontext()
        verifying_client_ctx = test_utils.simple_client_sslcontext(
            disable_verify=False)

        def server(sock):
            # The handshake is expected to break on this side as well; any
            # SSL or socket error is acceptable, the socket is always closed.
            try:
                sock.start_tls(server_ctx, server_side=True)
            except (ssl.SSLError, OSError):
                pass
            finally:
                sock.close()

        async def client(addr):
            await asyncio.open_connection(
                *addr,
                ssl=verifying_client_ctx,
                server_hostname='',
                loop=self.loop,
                ssl_handshake_timeout=1.0)

        with self.tcp_server(server, max_clients=1, backlog=1) as srv:
            with self.assertRaises(ssl.SSLCertVerificationError):
                self.loop.run_until_complete(client(srv.addr))
예제 #2
0
    def test_start_tls_client_corrupted_ssl(self):
        # After a good TLS exchange the server writes raw bytes around the
        # TLS layer; the client's next read must fail with ssl.SSLError.
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext()

        def server(sock):
            # Duplicate of the underlying socket, used to inject bytes that
            # bypass the TLS layer once the handshake is done.
            orig_sock = sock.dup()
            try:
                sock.start_tls(sslctx, server_side=True)
                sock.sendall(b'A\n')
                sock.recv_all(1)
                # Plaintext on the raw duplicate is not a valid TLS record,
                # which corrupts the client's view of the stream.
                orig_sock.send(b'please corrupt the SSL connection')
            except ssl.SSLError:
                pass
            finally:
                # The duplicate owns its own file descriptor; close it as
                # well, otherwise that descriptor leaks.
                orig_sock.close()
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop)

            self.assertEqual(await reader.readline(), b'A\n')
            writer.write(b'B')
            with self.assertRaises(ssl.SSLError):
                await reader.readline()

            # Release the client transport explicitly instead of relying
            # on garbage collection.
            writer.close()
            return 'OK'

        with self.tcp_server(server, max_clients=1, backlog=1) as srv:

            res = self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(res, 'OK')
예제 #3
0
    def test_create_connection_ssl_failed_certificate(self):
        # A verifying client (disable_verify=False) must reject the test
        # server's certificate: open_connection() is expected to raise
        # ssl.SSLCertVerificationError.
        self.loop.set_exception_handler(lambda loop, ctx: None)

        srv_ctx = test_utils.simple_server_sslcontext()
        cli_ctx = test_utils.simple_client_sslcontext(disable_verify=False)

        def server(sock):
            try:
                sock.start_tls(srv_ctx, server_side=True)
            except OSError:
                # ssl.SSLError derives from OSError, so this covers both a
                # failed handshake and a plain socket error.
                pass
            finally:
                sock.close()

        async def client(addr):
            connect_kwargs = dict(
                ssl=cli_ctx,
                server_hostname='',
                loop=self.loop,
                ssl_handshake_timeout=1.0,
            )
            await asyncio.open_connection(*addr, **connect_kwargs)

        with self.tcp_server(server, max_clients=1, backlog=1) as srv:
            expect_failure = self.assertRaises(ssl.SSLCertVerificationError)
            with expect_failure:
                self.loop.run_until_complete(client(srv.addr))
예제 #4
0
        async def main():
            # Passing None as the SSL context must raise a TypeError whose
            # message mentions 'SSLContext, got'.
            with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
                await self.loop.start_tls(None, None, None)

            # A real context paired with a None transport must raise a
            # TypeError whose message mentions 'is not supported'.
            server_ctx = test_utils.simple_server_sslcontext()
            with self.assertRaisesRegex(TypeError, 'is not supported'):
                await self.loop.start_tls(None, None, server_ctx)
예제 #5
0
    def test_start_tls_server_1(self):
        # Server-side STARTTLS: the server writes HELLO_MSG in plaintext,
        # upgrades the accepted transport with loop.start_tls(), and then
        # must receive the client's HELLO_MSG over TLS followed by EOF.
        HELLO_MSG = b'1' * 1024 * 1024 * 5

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def client(sock, addr):
            # Bound every blocking socket call so that a stuck server cannot
            # hang the client thread (and therefore the test) forever.
            sock.settimeout(10)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_eof):
                self.on_con = on_con
                self.on_eof = on_eof
                self.data = b''

            def connection_made(self, tr):
                self.on_con.set_result(tr)

            def data_received(self, data):
                self.data += data

            def eof_received(self):
                self.on_eof.set_result(1)

        async def main():
            tr = await on_con
            tr.write(HELLO_MSG)

            # No client data may have arrived before the upgrade.
            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True)

            await on_eof
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

            server.close()
            await server.wait_closed()

        on_con = self.loop.create_future()
        on_eof = self.loop.create_future()
        proto = ServerProto(on_con, on_eof)

        server = self.loop.run_until_complete(
            self.loop.create_server(
                lambda: proto, '127.0.0.1', 0))
        addr = server.sockets[0].getsockname()

        try:
            with self.tcp_client(lambda sock: client(sock, addr)):
                self.loop.run_until_complete(
                    asyncio.wait_for(main(), loop=self.loop, timeout=10))
        finally:
            # main() closes the server only on success; close it here too so
            # a failure or timeout does not leak the listening socket.
            # Closing again after a successful main() is harmless.
            server.close()
    def test_start_tls_client_reg_proto_1(self):
        # Client-side STARTTLS with a regular asyncio.Protocol: send
        # plaintext, upgrade with loop.start_tls(), then exchange data over
        # TLS, checking that connection_made() fires exactly once.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
        expected_len = len(HELLO_MSG)

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            # Plaintext phase.
            received = sock.recv_all(expected_len)
            self.assertEqual(len(received), expected_len)

            # TLS phase.
            sock.start_tls(server_context, server_side=True)
            sock.sendall(b'O')
            received = sock.recv_all(expected_len)
            self.assertEqual(len(received), expected_len)

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        testcase = self

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(self, tr):
                self.con_made_cnt += 1
                # connection_made() must not be called a second time.
                testcase.assertEqual(self.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5, loop=self.loop)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            transport, protocol = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            transport.write(HELLO_MSG)
            tls_transport = await self.loop.start_tls(
                transport, protocol, client_context)

            self.assertEqual(await on_data, b'O')
            tls_transport.write(HELLO_MSG)
            await on_eof

            tls_transport.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            runner = asyncio.wait_for(
                client(srv.addr), loop=self.loop, timeout=10)
            self.loop.run_until_complete(runner)
예제 #7
0
파일: test_streams.py 프로젝트: za/cpython
 async def handle_client(self, client_reader, client_writer):
     # Echo one line in plaintext, upgrade the stream to TLS with
     # StreamWriter.start_tls(), then echo one more line over TLS.
     plain_line = await client_reader.readline()
     client_writer.write(plain_line)
     await client_writer.drain()

     # The 'sslcontext' extra info only appears after the upgrade.
     assert client_writer.get_extra_info('sslcontext') is None
     tls_ctx = test_utils.simple_server_sslcontext()
     await client_writer.start_tls(tls_ctx)
     assert client_writer.get_extra_info('sslcontext') is not None

     tls_line = await client_reader.readline()
     client_writer.write(tls_line)
     await client_writer.drain()

     client_writer.close()
     await client_writer.wait_closed()
    def prepare_sendfile(self, *, is_ssl=False, close_after=0):
        # Create a connected server/client pair of MySendfileProto
        # instances, optionally over TLS, and register a cleanup that closes
        # both transports and the server. Returns (srv_proto, cli_proto).
        srv_proto = MySendfileProto(loop=self.loop, close_after=close_after)
        if is_ssl:
            if not ssl:
                self.skipTest("No ssl module")
            srv_ctx = test_utils.simple_server_sslcontext()
            cli_ctx = test_utils.simple_client_sslcontext()
        else:
            srv_ctx = None
            cli_ctx = None
        srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Let the kernel pick a free port at bind() time. Probing with
        # support.find_unused_port() and binding later is racy: another
        # process can take the port in between.
        srv_sock.bind((support.HOST, 0))
        port = srv_sock.getsockname()[1]
        server = self.run_loop(
            self.loop.create_server(lambda: srv_proto,
                                    sock=srv_sock,
                                    ssl=srv_ctx))
        self.reduce_receive_buffer_size(srv_sock)

        if is_ssl:
            server_hostname = support.HOST
        else:
            server_hostname = None
        cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        cli_sock.connect((support.HOST, port))

        cli_proto = MySendfileProto(loop=self.loop)
        tr, _ = self.run_loop(
            self.loop.create_connection(lambda: cli_proto,
                                        sock=cli_sock,
                                        ssl=cli_ctx,
                                        server_hostname=server_hostname))
        self.reduce_send_buffer_size(cli_sock, transport=tr)

        def cleanup():
            srv_proto.transport.close()
            cli_proto.transport.close()
            self.run_loop(srv_proto.done)
            self.run_loop(cli_proto.done)

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)
        return srv_proto, cli_proto
예제 #9
0
    def test_start_tls_client_corrupted_ssl(self):
        # The server writes raw, non-TLS bytes on a duplicate of the
        # connection after a good exchange; the client's next readline()
        # must then raise ssl.SSLError.
        self.loop.set_exception_handler(lambda loop, ctx: None)

        server_ctx = test_utils.simple_server_sslcontext()
        client_ctx = test_utils.simple_client_sslcontext()

        def server(sock):
            raw = sock.dup()
            try:
                sock.start_tls(server_ctx, server_side=True)
                sock.sendall(b'A\n')
                sock.recv_all(1)
                # These bytes bypass the TLS layer entirely.
                raw.send(b'please corrupt the SSL connection')
            except ssl.SSLError:
                pass
            finally:
                for s in (raw, sock):
                    s.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr, ssl=client_ctx, server_hostname='', loop=self.loop)

            first = await reader.readline()
            self.assertEqual(first, b'A\n')
            writer.write(b'B')
            with self.assertRaises(ssl.SSLError):
                await reader.readline()

            writer.close()
            return 'OK'

        with self.tcp_server(server, max_clients=1, backlog=1) as srv:
            result = self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(result, 'OK')
예제 #10
0
    def prepare_sendfile(self, *, is_ssl=False, close_after=0):
        # Create a connected server/client pair of MySendfileProto
        # instances, optionally over TLS, and register a cleanup that closes
        # both transports and the server. Returns (srv_proto, cli_proto).
        srv_proto = MySendfileProto(loop=self.loop,
                                    close_after=close_after)
        if is_ssl:
            if not ssl:
                self.skipTest("No ssl module")
            srv_ctx = test_utils.simple_server_sslcontext()
            cli_ctx = test_utils.simple_client_sslcontext()
        else:
            srv_ctx = None
            cli_ctx = None
        srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Bind to port 0 and read back the kernel-assigned port. Probing
        # with support.find_unused_port() first is racy: another process
        # may grab the port before bind() runs.
        srv_sock.bind((support.HOST, 0))
        port = srv_sock.getsockname()[1]
        server = self.run_loop(self.loop.create_server(
            lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
        self.reduce_receive_buffer_size(srv_sock)

        if is_ssl:
            server_hostname = support.HOST
        else:
            server_hostname = None
        cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        cli_sock.connect((support.HOST, port))

        cli_proto = MySendfileProto(loop=self.loop)
        tr, _ = self.run_loop(self.loop.create_connection(
            lambda: cli_proto, sock=cli_sock,
            ssl=cli_ctx, server_hostname=server_hostname))
        self.reduce_send_buffer_size(cli_sock, transport=tr)

        def cleanup():
            srv_proto.transport.close()
            cli_proto.transport.close()
            self.run_loop(srv_proto.done)
            self.run_loop(cli_proto.done)

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)
        return srv_proto, cli_proto
예제 #11
0
    def test_start_tls_client_buf_proto_1(self):
        # Client-side STARTTLS with a BufferedProtocol, then switching the
        # TLS transport to a plain Protocol via set_protocol(). Across the
        # whole exchange connection_made() may be called only once.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        client_con_made_calls = 0

        def serve(sock):
            def expect_hello():
                data = sock.recv_all(len(HELLO_MSG))
                self.assertEqual(len(data), len(HELLO_MSG))

            sock.settimeout(self.TIMEOUT)

            expect_hello()  # plaintext payload
            sock.start_tls(server_context, server_side=True)

            # Two TLS round trips, one per client protocol.
            for marker in (b'O', b'2'):
                sock.sendall(marker)
                expect_hello()

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        def count_connection_made():
            nonlocal client_con_made_calls
            client_con_made_calls += 1

        class ClientProtoFirst(asyncio.BufferedProtocol):
            def __init__(self, on_data):
                self.on_data = on_data
                self.buf = bytearray(1)

            def connection_made(self, tr):
                count_connection_made()

            def get_buffer(self, sizehint):
                return self.buf

            def buffer_updated(self, nsize):
                assert nsize == 1
                self.on_data.set_result(bytes(self.buf[:nsize]))

        class ClientProtoSecond(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(self, tr):
                count_connection_made()

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5, loop=self.loop)

            on_data1, on_data2, on_eof = (
                self.loop.create_future() for _ in range(3))

            tr, proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(on_data1), *addr)

            tr.write(HELLO_MSG)
            tls_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data1, b'O')
            tls_tr.write(HELLO_MSG)

            tls_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
            self.assertEqual(await on_data2, b'2')
            tls_tr.write(HELLO_MSG)
            await on_eof

            tls_tr.close()

            # Only the initial connect counts: neither start_tls() nor
            # set_protocol() may invoke connection_made() again.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 loop=self.loop, timeout=self.TIMEOUT))
예제 #12
0
    def test_start_tls_server_1(self):
        # Server-side STARTTLS: write HELLO_MSG in plaintext, upgrade the
        # accepted transport with loop.start_tls(), then expect the client's
        # HELLO_MSG over TLS, followed by EOF and a clean connection_lost().
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def client(sock, addr):
            sock.settimeout(self.TIMEOUT)
            sock.connect(addr)

            greeting = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(greeting), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_eof, on_con_lost):
                self.on_con = on_con
                self.on_eof = on_eof
                self.on_con_lost = on_con_lost
                self.data = b''

            def connection_made(self, tr):
                self.on_con.set_result(tr)

            def data_received(self, data):
                self.data += data

            def eof_received(self):
                self.on_eof.set_result(1)

            def connection_lost(self, exc):
                if exc is not None:
                    self.on_con_lost.set_exception(exc)
                    return
                self.on_con_lost.set_result(None)

        async def main(proto, on_con, on_eof, on_con_lost):
            plain_tr = await on_con
            plain_tr.write(HELLO_MSG)

            # No client data may have arrived before the upgrade.
            self.assertEqual(proto.data, b'')

            tls_tr = await self.loop.start_tls(
                plain_tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)

            await on_eof
            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            tls_tr.close()

        async def run_main():
            futures = [self.loop.create_future() for _ in range(3)]
            on_con, on_eof, on_con_lost = futures
            proto = ServerProto(on_con, on_eof, on_con_lost)

            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                await asyncio.wait_for(
                    main(proto, on_con, on_eof, on_con_lost),
                    loop=self.loop, timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()

        self.loop.run_until_complete(run_main())
예제 #13
0
    def test_start_tls_server_1(self):
        # Server-side STARTTLS: write HELLO_MSG in plaintext, upgrade the
        # accepted transport with loop.start_tls(), then expect the client's
        # HELLO_MSG over TLS followed by EOF.
        HELLO_MSG = b'1' * 1024 * 1024

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def client(sock, addr):
            time.sleep(0.5)
            sock.settimeout(5)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_eof):
                self.on_con = on_con
                self.on_eof = on_eof
                self.data = b''

            def connection_made(self, tr):
                self.on_con.set_result(tr)

            def data_received(self, data):
                self.data += data

            def eof_received(self):
                self.on_eof.set_result(1)

        async def main():
            tr = await on_con
            tr.write(HELLO_MSG)

            # No client data may have arrived before the upgrade.
            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True)

            await on_eof
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

            server.close()
            await server.wait_closed()

        on_con = self.loop.create_future()
        on_eof = self.loop.create_future()
        proto = ServerProto(on_con, on_eof)

        server = self.loop.run_until_complete(
            self.loop.create_server(
                lambda: proto, '127.0.0.1', 0))
        addr = server.sockets[0].getsockname()

        try:
            with self.tcp_client(lambda sock: client(sock, addr)):
                self.loop.run_until_complete(
                    asyncio.wait_for(main(), loop=self.loop, timeout=10))
        finally:
            # main() closes the server only on success; close it here too so
            # a failure or timeout does not leak the listening socket.
            # Closing again after a successful main() is harmless.
            server.close()
예제 #14
0
    def test_start_tls_client_buf_proto_1(self):
        # STARTTLS from the client using a BufferedProtocol first and a
        # regular Protocol afterwards (installed with set_protocol()).
        # connection_made() must run exactly once for the whole test.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        client_con_made_calls = 0

        def serve(sock):
            sock.settimeout(self.TIMEOUT)
            expected = len(HELLO_MSG)

            self.assertEqual(len(sock.recv_all(expected)), expected)
            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            self.assertEqual(len(sock.recv_all(expected)), expected)

            sock.sendall(b'2')
            self.assertEqual(len(sock.recv_all(expected)), expected)

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProtoFirst(asyncio.BufferedProtocol):
            def __init__(self, on_data):
                self.on_data = on_data
                self.buf = bytearray(1)

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def get_buffer(self, sizehint):
                return self.buf

            def buffer_updated(self, nsize):
                assert nsize == 1
                chunk = bytes(self.buf[:nsize])
                self.on_data.set_result(chunk)

        class ClientProtoSecond(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            first_data = self.loop.create_future()
            second_data = self.loop.create_future()
            eof_seen = self.loop.create_future()

            plain_tr, first_proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(first_data), *addr)

            plain_tr.write(HELLO_MSG)
            tls_tr = await self.loop.start_tls(
                plain_tr, first_proto, client_context)

            self.assertEqual(await first_data, b'O')
            tls_tr.write(HELLO_MSG)

            tls_tr.set_protocol(ClientProtoSecond(second_data, eof_seen))
            self.assertEqual(await second_data, b'2')
            tls_tr.write(HELLO_MSG)
            await eof_seen

            tls_tr.close()

            # Only the initial connect should have triggered
            # connection_made(); the TLS upgrade and the protocol switch
            # must not call it again.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr), timeout=self.TIMEOUT))
예제 #15
0
    def test_start_tls_server_1(self):
        # Server-side STARTTLS round trip: plaintext greeting from the
        # server, in-place upgrade with loop.start_tls(), then the client's
        # payload over TLS followed by EOF and a clean connection close.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        on_freebsd = sys.platform.startswith('freebsd')
        if on_freebsd:
            # bpo-35031: on some FreeBSD buildbots the server did not see
            # EOF when the payload was small and the client used TLS 1.3,
            # so TLS 1.3 is disabled on the client there.
            client_context.options |= ssl.OP_NO_TLSv1_3

        def client(sock, addr):
            expected = len(HELLO_MSG)
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            self.assertEqual(len(sock.recv_all(expected)), expected)

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_eof, on_con_lost):
                self.on_con = on_con
                self.on_eof = on_eof
                self.on_con_lost = on_con_lost
                self.data = b''

            def connection_made(self, tr):
                self.on_con.set_result(tr)

            def data_received(self, data):
                self.data += data

            def eof_received(self):
                self.on_eof.set_result(1)

            def connection_lost(self, exc):
                waiter = self.on_con_lost
                if exc is None:
                    waiter.set_result(None)
                else:
                    waiter.set_exception(exc)

        async def main(proto, on_con, on_eof, on_con_lost):
            transport = await on_con
            transport.write(HELLO_MSG)

            self.assertEqual(proto.data, b'')

            upgraded = await self.loop.start_tls(
                transport, proto, server_context,
                server_side=True, ssl_handshake_timeout=self.TIMEOUT)

            for waiter in (on_eof, on_con_lost):
                await waiter
            self.assertEqual(proto.data, HELLO_MSG)
            upgraded.close()

        async def run_main():
            loop = self.loop
            on_con = loop.create_future()
            on_eof = loop.create_future()
            on_con_lost = loop.create_future()
            proto = ServerProto(on_con, on_eof, on_con_lost)

            server = await loop.create_server(lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            def client_prog(sock):
                client(sock, addr)

            with self.tcp_client(client_prog, timeout=self.TIMEOUT):
                await asyncio.wait_for(
                    main(proto, on_con, on_eof, on_con_lost),
                    timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()

        self.loop.run_until_complete(run_main())
예제 #16
0
    def test_create_connection_memory_leak(self):
        # A TLS client made with loop.create_connection() whose protocol
        # stores its transport must not keep the client SSLContext alive
        # once the connection has been closed.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)
            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            payload = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(payload), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        testcase = self

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(self, tr):
                # Deliberately keep the transport on the protocol, as user
                # code may do; the leak check below must still pass.
                self.tr = tr
                self.con_made_cnt += 1
                # connection_made() has to fire exactly once.
                testcase.assertEqual(self.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            transport, _ = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof),
                *addr,
                ssl=client_context)

            self.assertEqual(await on_data, b'O')
            transport.write(HELLO_MSG)
            await on_eof

            transport.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr), timeout=10))

        # Replace the test's own strong reference with a weak one; the
        # context should then be freed unless something else still holds it.
        client_context = weakref.ref(client_context)
        self.assertIsNone(client_context())
예제 #17
0
    def test_create_connection_memory_leak(self):
        # Regression check: after an SSL create_connection() round trip, the
        # client SSLContext must be collectable even though the protocol
        # keeps a reference to its SSL transport.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            n = len(HELLO_MSG)
            sock.settimeout(self.TIMEOUT)

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            self.assertEqual(len(sock.recv_all(n)), n)

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(protocol, tr):
                # The protocol stores its transport on purpose (see the
                # leak check at the end of the test).
                protocol.tr = tr
                protocol.con_made_cnt += 1
                # connection_made() must be delivered only once.
                self.assertEqual(protocol.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            fut_data, fut_eof = (self.loop.create_future(),
                                 self.loop.create_future())

            tr, _ = await self.loop.create_connection(
                lambda: ClientProto(fut_data, fut_eof), *addr,
                ssl=client_context)

            self.assertEqual(await fut_data, b'O')
            tr.write(HELLO_MSG)
            await fut_eof

            tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr), timeout=10))

        # Drop the test's strong reference and keep only a weak one; the
        # context should be gone if nothing else retained it.
        ctx_ref = weakref.ref(client_context)
        del client_context
        self.assertIsNone(ctx_ref())
예제 #18
0
    def test_start_tls_server_1(self):
        # STARTTLS from the server's point of view: send a plaintext
        # greeting, upgrade with loop.start_tls(), receive the client's
        # HELLO_MSG over TLS and wait for EOF plus connection_lost().
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
        hello_len = len(HELLO_MSG)

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def client(sock, addr):
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            received = sock.recv_all(hello_len)
            self.assertEqual(len(received), hello_len)

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_eof, on_con_lost):
                self.on_con = on_con
                self.on_eof = on_eof
                self.on_con_lost = on_con_lost
                self.data = b''

            def connection_made(self, tr):
                self.on_con.set_result(tr)

            def data_received(self, data):
                self.data += data

            def eof_received(self):
                self.on_eof.set_result(1)

            def connection_lost(self, exc):
                if exc is None:
                    self.on_con_lost.set_result(None)
                    return
                self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_eof, on_con_lost):
            raw_transport = await on_con
            raw_transport.write(HELLO_MSG)

            # Nothing from the client may be buffered before the upgrade.
            self.assertEqual(proto.data, b'')

            secure_transport = await self.loop.start_tls(
                raw_transport, proto, server_context,
                server_side=True, ssl_handshake_timeout=self.TIMEOUT)

            await on_eof
            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            secure_transport.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_eof = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            proto = ServerProto(on_con, on_eof, on_con_lost)

            def protocol_factory():
                return proto

            server = await self.loop.create_server(
                protocol_factory, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                main_coro = main(proto, on_con, on_eof, on_con_lost)
                await asyncio.wait_for(main_coro,
                                       loop=self.loop,
                                       timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()

        self.loop.run_until_complete(run_main())
예제 #19
0
    def test_start_tls_server_1(self):
        """Upgrade an accepted plain TCP connection to TLS on the server side.

        The server writes HELLO_MSG in cleartext, then calls
        ``loop.start_tls(..., server_side=True)`` on the accepted transport.
        The client reads the cleartext payload, performs its own TLS
        handshake, sends HELLO_MSG over TLS and shuts the socket down.  The
        test checks that the server protocol receives exactly HELLO_MSG over
        TLS, followed by EOF and connection loss.
        """
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        if sys.platform.startswith('freebsd'):
            # bpo-35031: Some FreeBSD buildbots fail to run this test
            # as the eof was not being received by the server if the payload
            # size is not big enough. This behaviour only appears if the
            # client is using TLS1.3.
            client_context.options |= ssl.OP_NO_TLSv1_3

        def client(sock, addr):
            # Client side; driven by self.tcp_client() in run_main().
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            # Cleartext phase: consume what the server wrote before start_tls().
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            # TLS phase: handshake, send the payload, then close both
            # directions so the server sees EOF.
            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ServerProto(asyncio.Protocol):
            """Server protocol used both before and after start_tls().

            Each lifecycle event resolves the matching future so that main()
            can await it; received bytes accumulate in ``self.data``.
            """

            def __init__(self, on_con, on_eof, on_con_lost):
                self.on_con = on_con
                self.on_eof = on_eof
                self.on_con_lost = on_con_lost
                self.data = b''

            def connection_made(self, tr):
                # Hand the plain transport to main() so it can upgrade it.
                self.on_con.set_result(tr)

            def data_received(self, data):
                self.data += data

            def eof_received(self):
                self.on_eof.set_result(1)

            def connection_lost(self, exc):
                if exc is None:
                    self.on_con_lost.set_result(None)
                else:
                    self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_eof, on_con_lost):
            tr = await on_con
            tr.write(HELLO_MSG)

            # The client sends nothing before the TLS upgrade.
            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)

            await on_eof
            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_eof = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            proto = ServerProto(on_con, on_eof, on_con_lost)

            # Every accepted connection gets the same protocol instance.
            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                # asyncio.wait_for() no longer accepts ``loop`` (deprecated
                # in 3.8, removed in 3.10); it uses the running loop.
                await asyncio.wait_for(
                    main(proto, on_con, on_eof, on_con_lost),
                    timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()

        self.loop.run_until_complete(run_main())
예제 #20
0
    def test_start_tls_server_1(self):
        """Server-side start_tls() followed by a reply over the TLS transport.

        The server sends HELLO_MSG in cleartext and then upgrades the
        accepted connection with ``loop.start_tls(..., server_side=True)``.
        The client sends HELLO_MSG over TLS; once the server has collected
        at least that many bytes it writes ANSWER on the new transport,
        which the client reads before closing.  The test finally checks
        that the client saw exactly ANSWER.
        """
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
        ANSWER = b'answer'

        server_ctx = test_utils.simple_server_sslcontext()
        client_ctx = test_utils.simple_client_sslcontext()
        # Written by client(), compared against ANSWER at the end of run_main().
        client_got = None

        def client(sock, addr):
            nonlocal client_got
            sock.settimeout(self.TIMEOUT)
            sock.connect(addr)

            # Cleartext phase.
            cleartext = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(cleartext), len(HELLO_MSG))

            # TLS phase: send the payload, then wait for the server's reply.
            sock.start_tls(client_ctx)
            sock.sendall(HELLO_MSG)
            client_got = sock.recv_all(len(ANSWER))
            sock.close()

        class HelloServerProtocol(asyncio.Protocol):
            """Protocol shared by the plain and the TLS server transports.

            Resolves ``on_got_hello`` once at least len(HELLO_MSG) bytes have
            arrived and ``on_con_lost`` when the connection goes away.
            """

            def __init__(self, on_con, on_con_lost, on_got_hello):
                self.transport = None
                self.received = b''
                self.on_got_hello = on_got_hello
                self.on_con_lost = on_con_lost
                self.on_con = on_con

            def connection_made(self, tr):
                self.transport = tr
                self.on_con.set_result(tr)

            def data_received(self, data):
                self.received += data
                if len(self.received) < len(HELLO_MSG):
                    return
                self.on_got_hello.set_result(None)

            def connection_lost(self, exc):
                self.transport = None
                if exc is not None:
                    self.on_con_lost.set_exception(exc)
                    return
                self.on_con_lost.set_result(None)

        async def main(proto, on_con, on_con_lost, on_got_hello):
            plain_tr = await on_con
            plain_tr.write(HELLO_MSG)
            # Nothing may arrive before the TLS upgrade.
            self.assertEqual(proto.received, b'')

            tls_tr = await self.loop.start_tls(
                plain_tr, proto, server_ctx,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)
            # From here on the protocol talks through the TLS transport.
            proto.transport = tls_tr

            await on_got_hello
            tls_tr.write(ANSWER)

            await on_con_lost
            self.assertEqual(proto.received, HELLO_MSG)
            tls_tr.close()

        async def run_main():
            on_con, on_con_lost, on_got_hello = (
                self.loop.create_future() for _ in range(3))
            proto = HelloServerProtocol(on_con, on_con_lost, on_got_hello)

            # Every accepted connection gets this single protocol instance.
            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda s: client(s, addr),
                                 timeout=self.TIMEOUT):
                await asyncio.wait_for(
                    main(proto, on_con, on_con_lost, on_got_hello),
                    timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()
            self.assertEqual(client_got, ANSWER)

        self.loop.run_until_complete(run_main())