def test_start_tls_slow_client_cancel(self):
        # start_tls() must be cancellable: the server below reads the
        # plaintext greeting and then keeps reading without ever starting a
        # TLS handshake, so the client-side upgrade is expected to end only
        # through the 0.5s wait_for() timeout.
        greeting = b'1' * self.PAYLOAD_SIZE

        client_context = test_utils.simple_client_sslcontext()
        server_waits_on_handshake = self.loop.create_future()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            received = sock.recv_all(len(greeting))
            self.assertEqual(len(received), len(greeting))

            try:
                # Release the waiting client; the thread-safe variant is used
                # because serve() presumably runs outside the loop thread.
                self.loop.call_soon_threadsafe(
                    server_waits_on_handshake.set_result, None)
                # Keep reading raw bytes instead of handshaking.
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.con_made_cnt = 0
                self.on_data = on_data
                self.on_eof = on_eof

            # The instance is named ``proto`` so that ``self`` still refers
            # to the enclosing test case inside this method.
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5, loop=self.loop)

            data_fut = self.loop.create_future()
            eof_fut = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(data_fut, eof_fut), *addr)

            tr.write(greeting)

            await server_waits_on_handshake

            tls_upgrade = self.loop.start_tls(tr, proto, client_context)
            with self.assertRaises(asyncio.TimeoutError):
                await asyncio.wait_for(tls_upgrade, 0.5, loop=self.loop)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            bounded_client = asyncio.wait_for(
                client(srv.addr), loop=self.loop, timeout=10)
            self.loop.run_until_complete(bounded_client)
示例#2
0
    def test_start_tls_client_corrupted_ssl(self):
        # A TLS client must report ssl.SSLError when the peer injects
        # plaintext bytes into an established TLS stream.
        #
        # Loop-level error reports are silenced for this test.
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext()

        def server(sock):
            # Duplicate of the socket taken before the TLS upgrade; it is
            # used to write raw bytes straight onto the connection.
            orig_sock = sock.dup()
            try:
                sock.start_tls(sslctx, server_side=True)
                sock.sendall(b'A\n')
                sock.recv_all(1)
                orig_sock.send(b'please corrupt the SSL connection')
            except ssl.SSLError:
                pass
            finally:
                # The duplicate is a separate socket object and must be
                # closed on its own, otherwise it leaks.
                orig_sock.close()
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(*addr,
                                                           ssl=client_sslctx,
                                                           server_hostname='',
                                                           loop=self.loop)

            self.assertEqual(await reader.readline(), b'A\n')
            writer.write(b'B')
            with self.assertRaises(ssl.SSLError):
                await reader.readline()

            # Close the client transport so it is not left open when the
            # test finishes.
            writer.close()
            return 'OK'

        with self.tcp_server(server, max_clients=1, backlog=1) as srv:

            res = self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(res, 'OK')
示例#3
0
    def test_create_connection_ssl_slow_handshake(self):
        # The server below only reads raw bytes and never answers the TLS
        # handshake, so open_connection() is expected to fail once the 1s
        # ssl_handshake_timeout expires.
        client_sslctx = test_utils.simple_client_sslcontext()

        logged_contexts = []

        def remember_context(loop, context):
            logged_contexts.append(context)

        self.loop.set_exception_handler(remember_context)

        def server(sock):
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        async def client(addr):
            tls_options = {
                'ssl': client_sslctx,
                'server_hostname': '',
                'loop': self.loop,
                'ssl_handshake_timeout': 1.0,
            }
            reader, writer = await asyncio.open_connection(
                *addr, **tls_options)

        with self.tcp_server(server, max_clients=1, backlog=1) as srv:
            expected_failure = self.assertRaisesRegex(
                ConnectionAbortedError,
                r'SSL handshake.*is taking longer')
            with expected_failure:
                self.loop.run_until_complete(client(srv.addr))

        # Nothing may have been reported through the loop exception handler.
        self.assertEqual(logged_contexts, [])
示例#4
0
    def test_create_connection_ssl_failed_certificate(self):
        # The client context is created with disable_verify=False, and the
        # connection attempt is expected to raise SSLCertVerificationError.
        #
        # Loop-level error reports are silenced for this test.
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext(
            disable_verify=False)

        def server(sock):
            try:
                sock.start_tls(sslctx, server_side=True)
            except OSError:
                # ssl.SSLError is a subclass of OSError, so this single
                # clause covers both handshake and socket-level failures.
                pass
            finally:
                sock.close()

        async def client(addr):
            connect_options = dict(
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop,
                ssl_handshake_timeout=1.0)
            reader, writer = await asyncio.open_connection(
                *addr, **connect_options)

        with self.tcp_server(server, max_clients=1, backlog=1) as srv:
            with self.assertRaises(ssl.SSLCertVerificationError):
                self.loop.run_until_complete(client(srv.addr))
示例#5
0
    def test_create_connection_ssl_failed_certificate(self):
        # With verification left enabled on the client context, the TLS
        # connection attempt must end in SSLCertVerificationError.
        def ignore_loop_errors(loop, context):
            return None

        self.loop.set_exception_handler(ignore_loop_errors)

        server_ctx = test_utils.simple_server_sslcontext()
        verifying_ctx = test_utils.simple_client_sslcontext(
            disable_verify=False)

        def server(sock):
            # Any handshake or socket error on the server side is expected
            # here and deliberately ignored; the socket is always closed.
            try:
                sock.start_tls(server_ctx, server_side=True)
            except (ssl.SSLError, OSError):
                pass
            finally:
                sock.close()

        async def client(addr):
            streams = await asyncio.open_connection(
                *addr,
                ssl=verifying_ctx,
                server_hostname='',
                loop=self.loop,
                ssl_handshake_timeout=1.0)
            reader, writer = streams

        with self.tcp_server(server, max_clients=1, backlog=1) as srv:
            attempt = client(srv.addr)
            with self.assertRaises(ssl.SSLCertVerificationError):
                self.loop.run_until_complete(attempt)
示例#6
0
    def test_create_connection_ssl_slow_handshake(self):
        # A peer that never completes the TLS handshake must make the
        # connection attempt fail with ConnectionAbortedError after the
        # configured ssl_handshake_timeout (1s here).
        handler_calls = []
        self.loop.set_exception_handler(
            lambda loop, ctx: handler_calls.append(ctx))

        tls_ctx = test_utils.simple_client_sslcontext()

        def server(sock):
            # Read raw bytes only; no TLS handshake is ever started here.
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr, ssl=tls_ctx, server_hostname='',
                loop=self.loop, ssl_handshake_timeout=1.0)

        timeout_pattern = r'SSL handshake.*is taking longer'

        with self.tcp_server(server, max_clients=1, backlog=1) as srv:
            with self.assertRaisesRegex(ConnectionAbortedError,
                                        timeout_pattern):
                self.loop.run_until_complete(client(srv.addr))

        # The failure must not also be reported to the exception handler.
        self.assertEqual(handler_calls, [])
示例#7
0
    def test_start_tls_server_1(self):
        # Server-side start_tls(): the server writes a plaintext payload,
        # upgrades the accepted transport to TLS and must then receive the
        # same payload back from the client over the encrypted channel.
        payload = b'1' * (5 * 1024 * 1024)

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def client(sock, addr):
            sock.connect(addr)
            received = sock.recv_all(len(payload))
            self.assertEqual(len(received), len(payload))

            sock.start_tls(client_context)
            sock.sendall(payload)
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_eof):
                self.data = b''
                self.on_con = on_con
                self.on_eof = on_eof

            def connection_made(self, tr):
                self.on_con.set_result(tr)

            def data_received(self, data):
                self.data += data

            def eof_received(self):
                self.on_eof.set_result(1)

        on_con = self.loop.create_future()
        on_eof = self.loop.create_future()
        proto = ServerProto(on_con, on_eof)

        listening = self.loop.create_server(lambda: proto, '127.0.0.1', 0)
        server = self.loop.run_until_complete(listening)
        addr = server.sockets[0].getsockname()

        async def main():
            plain_tr = await on_con
            plain_tr.write(payload)

            # Nothing has been received before the upgrade.
            self.assertEqual(proto.data, b'')

            tls_tr = await self.loop.start_tls(
                plain_tr, proto, server_context, server_side=True)

            await on_eof
            self.assertEqual(proto.data, payload)
            tls_tr.close()

            server.close()
            await server.wait_closed()

        def run_client(sock):
            return client(sock, addr)

        with self.tcp_client(run_client):
            bounded_main = asyncio.wait_for(
                main(), loop=self.loop, timeout=10)
            self.loop.run_until_complete(bounded_main)
    def test_start_tls_client_reg_proto_1(self):
        # Client-side start_tls() with a regular asyncio.Protocol: send a
        # plaintext payload, upgrade to TLS, then exchange data over the
        # upgraded transport until the server signals EOF.
        payload = b'1' * self.PAYLOAD_SIZE
        payload_len = len(payload)

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            plain = sock.recv_all(payload_len)
            self.assertEqual(len(plain), payload_len)

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            encrypted = sock.recv_all(payload_len)
            self.assertEqual(len(encrypted), payload_len)

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.con_made_cnt = 0
                self.on_data = on_data
                self.on_eof = on_eof

            # The instance is named ``proto`` so that ``self`` still refers
            # to the enclosing test case inside this method.
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5, loop=self.loop)

            data_fut = self.loop.create_future()
            eof_fut = self.loop.create_future()

            plain_tr, proto = await self.loop.create_connection(
                lambda: ClientProto(data_fut, eof_fut), *addr)

            plain_tr.write(payload)
            tls_tr = await self.loop.start_tls(
                plain_tr, proto, client_context)

            first_byte = await data_fut
            self.assertEqual(first_byte, b'O')
            tls_tr.write(payload)
            await eof_fut

            tls_tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            bounded_client = asyncio.wait_for(
                client(srv.addr), loop=self.loop, timeout=10)
            self.loop.run_until_complete(bounded_client)
示例#9
0
 async def client(addr):
     # Exchange one line in plaintext, upgrade the stream to TLS with
     # StreamWriter.start_tls(), exchange a second line, then close.
     # Returns both reply lines as a tuple.
     reader, writer = await asyncio.open_connection(*addr)

     async def exchange(line):
         # Send one line and wait for the peer's one-line reply.
         writer.write(line)
         await writer.drain()
         return await reader.readline()

     first_reply = await exchange(b"hello world 1!\n")

     assert writer.get_extra_info('sslcontext') is None
     await writer.start_tls(test_utils.simple_client_sslcontext())
     assert writer.get_extra_info('sslcontext') is not None

     second_reply = await exchange(b"hello world 2!\n")

     writer.close()
     await writer.wait_closed()
     return first_reply, second_reply
示例#10
0
    def test_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        client_sslctx = test_utils.simple_client_sslcontext()

        # Collect every context passed to the loop exception handler so the
        # test can assert that nothing was reported.
        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        server_side_aborted = False

        def server(sock):
            # Never performs a TLS handshake: it only reads raw bytes and
            # records whether the connection was aborted by the peer.
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            # The outer 0.5s wait_for() is what ends the attempt here; the
            # handshake timeout passed to create_connection() is
            # support.SHORT_TIMEOUT.
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=support.SHORT_TIMEOUT),
                0.5)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])

        # The pending handshake timeout (support.SHORT_TIMEOUT) should be
        # cancelled to free related objects without really waiting for it
        # to expire.  The name is rebound to a weak reference on purpose:
        # this drops the strong reference held by this function (and by the
        # client() closure), so the context can only survive if something
        # else still holds it.
        client_sslctx = weakref.ref(client_sslctx)
        support.gc_collect()
        self.assertIsNone(client_sslctx())
示例#11
0
    def test_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        client_sslctx = test_utils.simple_client_sslcontext()

        # Collect every context passed to the loop exception handler so the
        # test can assert that nothing was reported.
        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        server_side_aborted = False

        def server(sock):
            # Never performs a TLS handshake: it only reads raw bytes and
            # records whether the connection was aborted by the peer.
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            # The outer 0.5s wait_for() fires long before the 10s handshake
            # timeout given to create_connection().
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=10.0),
                0.5)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])

        # The 10s handshake timeout should be cancelled to free related
        # objects without really waiting for 10s.  The name is rebound to a
        # weak reference on purpose: this drops the strong reference held by
        # this function (and by the client() closure).
        # NOTE(review): no explicit garbage collection is run before the
        # check, so this presumably relies on immediate reference-count
        # deallocation -- confirm for non-refcounting interpreters.
        client_sslctx = weakref.ref(client_sslctx)
        self.assertIsNone(client_sslctx())
    def prepare_sendfile(self, *, is_ssl=False, close_after=0):
        # Create a connected server/client pair of MySendfileProto
        # protocols, optionally over TLS, and register their teardown with
        # addCleanup().
        #
        # is_ssl:      wrap both ends in TLS (the test is skipped when the
        #              ssl module is unavailable).
        # close_after: forwarded unchanged to the server-side
        #              MySendfileProto.
        #
        # Returns (srv_proto, cli_proto).
        srv_proto = MySendfileProto(loop=self.loop, close_after=close_after)
        if is_ssl:
            if not ssl:
                self.skipTest("No ssl module")
            srv_ctx = test_utils.simple_server_sslcontext()
            cli_ctx = test_utils.simple_client_sslcontext()
        else:
            srv_ctx = None
            cli_ctx = None
        srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Bind to port 0 and read the assigned port back.  Probing for a
        # free port first and binding later leaves a window in which
        # another process can grab that port.
        srv_sock.bind((support.HOST, 0))
        port = srv_sock.getsockname()[1]
        server = self.run_loop(
            self.loop.create_server(lambda: srv_proto,
                                    sock=srv_sock,
                                    ssl=srv_ctx))
        self.reduce_receive_buffer_size(srv_sock)

        if is_ssl:
            server_hostname = support.HOST
        else:
            server_hostname = None
        cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        cli_sock.connect((support.HOST, port))

        cli_proto = MySendfileProto(loop=self.loop)
        tr, pr = self.run_loop(
            self.loop.create_connection(lambda: cli_proto,
                                        sock=cli_sock,
                                        ssl=cli_ctx,
                                        server_hostname=server_hostname))
        self.reduce_send_buffer_size(cli_sock, transport=tr)

        def cleanup():
            # Close both transports, wait for both protocols to finish,
            # then shut the listening server down.
            srv_proto.transport.close()
            cli_proto.transport.close()
            self.run_loop(srv_proto.done)
            self.run_loop(cli_proto.done)

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)
        return srv_proto, cli_proto
示例#13
0
    def test_start_tls_client_corrupted_ssl(self):
        # A TLS client must report ssl.SSLError when the peer injects
        # plaintext bytes into an established TLS stream.
        def ignore_loop_errors(loop, context):
            return None

        self.loop.set_exception_handler(ignore_loop_errors)

        server_ctx = test_utils.simple_server_sslcontext()
        client_ctx = test_utils.simple_client_sslcontext()

        def server(sock):
            # Duplicate taken before the TLS upgrade; used to write raw
            # bytes straight onto the connection.
            raw_sock = sock.dup()
            try:
                sock.start_tls(server_ctx, server_side=True)
                sock.sendall(b'A\n')
                sock.recv_all(1)
                raw_sock.send(b'please corrupt the SSL connection')
            except ssl.SSLError:
                pass
            finally:
                raw_sock.close()
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr, ssl=client_ctx, server_hostname='', loop=self.loop)

            greeting = await reader.readline()
            self.assertEqual(greeting, b'A\n')

            writer.write(b'B')
            with self.assertRaises(ssl.SSLError):
                await reader.readline()

            writer.close()
            return 'OK'

        with self.tcp_server(server, max_clients=1, backlog=1) as srv:
            outcome = self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(outcome, 'OK')
示例#14
0
    def prepare_sendfile(self, *, is_ssl=False, close_after=0):
        # Create a connected server/client pair of MySendfileProto
        # protocols (optionally over TLS), register their teardown with
        # addCleanup() and return (srv_proto, cli_proto).
        port = support.find_unused_port()
        srv_proto = MySendfileProto(loop=self.loop, close_after=close_after)

        srv_ctx = cli_ctx = None
        if is_ssl:
            if not ssl:
                self.skipTest("No ssl module")
            srv_ctx = test_utils.simple_server_sslcontext()
            cli_ctx = test_utils.simple_client_sslcontext()

        endpoint = (support.HOST, port)

        srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        srv_sock.bind(endpoint)
        serving = self.loop.create_server(
            lambda: srv_proto, sock=srv_sock, ssl=srv_ctx)
        server = self.run_loop(serving)
        self.reduce_receive_buffer_size(srv_sock)

        # A server hostname is only passed for TLS connections.
        server_hostname = support.HOST if is_ssl else None

        cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        cli_sock.connect(endpoint)

        cli_proto = MySendfileProto(loop=self.loop)
        connecting = self.loop.create_connection(
            lambda: cli_proto,
            sock=cli_sock,
            ssl=cli_ctx,
            server_hostname=server_hostname)
        tr, pr = self.run_loop(connecting)
        self.reduce_send_buffer_size(cli_sock, transport=tr)

        def cleanup():
            # Close both transports first, then wait for both protocols,
            # and finally shut the listening server down.
            both = (srv_proto, cli_proto)
            for proto in both:
                proto.transport.close()
            for proto in both:
                self.run_loop(proto.done)

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)
        return srv_proto, cli_proto
示例#15
0
    def test_start_tls_server_1(self):
        # Server-side start_tls(): the server writes a plaintext payload,
        # upgrades the accepted transport to TLS and must then receive the
        # same payload back, followed by EOF and connection loss.
        payload = b'1' * self.PAYLOAD_SIZE
        payload_len = len(payload)

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        if sys.platform.startswith('freebsd'):
            # bpo-35031: Some FreeBSD buildbots fail to run this test
            # as the eof was not being received by the server if the payload
            # size is not big enough. This behaviour only appears if the
            # client is using TLS1.3.
            client_context.options |= ssl.OP_NO_TLSv1_3

        def client(sock, addr):
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            received = sock.recv_all(payload_len)
            self.assertEqual(len(received), payload_len)

            sock.start_tls(client_context)
            sock.sendall(payload)

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_eof, on_con_lost):
                self.data = b''
                self.on_con = on_con
                self.on_eof = on_eof
                self.on_con_lost = on_con_lost

            def connection_made(self, tr):
                self.on_con.set_result(tr)

            def data_received(self, data):
                self.data += data

            def eof_received(self):
                self.on_eof.set_result(1)

            def connection_lost(self, exc):
                if exc is not None:
                    self.on_con_lost.set_exception(exc)
                else:
                    self.on_con_lost.set_result(None)

        async def main(proto, on_con, on_eof, on_con_lost):
            plain_tr = await on_con
            plain_tr.write(payload)

            # Nothing has been received before the upgrade.
            self.assertEqual(proto.data, b'')

            tls_tr = await self.loop.start_tls(
                plain_tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)

            await on_eof
            await on_con_lost
            self.assertEqual(proto.data, payload)
            tls_tr.close()

        async def run_main():
            on_con, on_eof, on_con_lost = (
                self.loop.create_future(),
                self.loop.create_future(),
                self.loop.create_future())
            proto = ServerProto(on_con, on_eof, on_con_lost)

            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            def run_client(sock):
                return client(sock, addr)

            with self.tcp_client(run_client, timeout=self.TIMEOUT):
                server_side = main(proto, on_con, on_eof, on_con_lost)
                await asyncio.wait_for(server_side, timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()

        self.loop.run_until_complete(run_main())
示例#16
0
    def test_start_tls_client_buf_proto_1(self):
        # Client-side start_tls() with a BufferedProtocol, followed by a
        # switch to a regular Protocol on the upgraded transport.
        payload = b'1' * self.PAYLOAD_SIZE
        payload_len = len(payload)

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        client_con_made_calls = 0

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            def expect_payload():
                # Read one full payload and check its length.
                chunk = sock.recv_all(payload_len)
                self.assertEqual(len(chunk), payload_len)

            expect_payload()

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            expect_payload()

            sock.sendall(b'2')
            expect_payload()

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProtoFirst(asyncio.BufferedProtocol):
            def __init__(self, on_data):
                self.buf = bytearray(1)
                self.on_data = on_data

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def get_buffer(self, sizehint):
                return self.buf

            def buffer_updated(self, nsize):
                assert nsize == 1
                filled = self.buf[:nsize]
                self.on_data.set_result(bytes(filled))

        class ClientProtoSecond(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.con_made_cnt = 0
                self.on_data = on_data
                self.on_eof = on_eof

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            first_data = self.loop.create_future()
            second_data = self.loop.create_future()
            eof_seen = self.loop.create_future()

            plain_tr, proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(first_data), *addr)

            plain_tr.write(payload)
            tls_tr = await self.loop.start_tls(
                plain_tr, proto, client_context)

            self.assertEqual(await first_data, b'O')
            tls_tr.write(payload)

            second_proto = ClientProtoSecond(second_data, eof_seen)
            tls_tr.set_protocol(second_proto)
            self.assertEqual(await second_data, b'2')
            tls_tr.write(payload)
            await eof_seen

            tls_tr.close()

            # connection_made() should be called only once -- when
            # we establish connection for the first time. Start TLS
            # doesn't call connection_made() on application protocols.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            bounded_client = asyncio.wait_for(
                client(srv.addr), timeout=self.TIMEOUT)
            self.loop.run_until_complete(bounded_client)
示例#17
0
    def test_create_connection_memory_leak(self):
        # Check that loop.create_connection(ssl=...) leaves no lingering
        # reference to the client SSL context once the connection is done,
        # even when the protocol keeps the transport as an attribute.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            # Server side: upgrade to TLS, send one byte, read the client's
            # payload back and shut the connection down.
            sock.settimeout(self.TIMEOUT)

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # The instance is named ``proto`` so that ``self`` still refers
            # to the enclosing test case inside this method.
            def connection_made(proto, tr):
                # XXX: We assume user stores the transport in protocol
                proto.tr = tr
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof),
                *addr,
                ssl=client_context)

            self.assertEqual(await on_data, b'O')
            tr.write(HELLO_MSG)
            await on_eof

            tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr), timeout=10))

        # No garbage is left for SSL client from loop.create_connection, even
        # if user stores the SSLTransport in corresponding protocol instance.
        # The name is rebound to a weak reference on purpose: this drops the
        # strong reference held by this function and by the client() closure.
        client_context = weakref.ref(client_context)
        self.assertIsNone(client_context())
示例#18
0
    def test_start_tls_server_1(self):
        # Upgrade an accepted plaintext connection to TLS on the server
        # side and verify that the payload sent by the client afterwards
        # arrives intact, followed by EOF and a clean connection loss.
        message = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def client(sock, addr):
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            plain = sock.recv_all(len(message))
            self.assertEqual(len(plain), len(message))

            sock.start_tls(client_context)
            sock.sendall(message)

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_eof, on_con_lost):
                self.on_con, self.on_eof = on_con, on_eof
                self.on_con_lost = on_con_lost
                self.data = b''

            def connection_made(self, tr):
                self.on_con.set_result(tr)

            def data_received(self, data):
                self.data += data

            def eof_received(self):
                self.on_eof.set_result(1)

            def connection_lost(self, exc):
                lost = self.on_con_lost
                if exc is None:
                    lost.set_result(None)
                    return
                lost.set_exception(exc)

        async def main(proto, on_con, on_eof, on_con_lost):
            tr = await on_con
            tr.write(message)

            # Nothing has been received before the upgrade.
            self.assertEqual(proto.data, b'')

            upgraded = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)

            await on_eof
            await on_con_lost
            self.assertEqual(proto.data, message)
            upgraded.close()

        async def run_main():
            new_future = self.loop.create_future
            on_con = new_future()
            on_eof = new_future()
            on_con_lost = new_future()
            proto = ServerProto(on_con, on_eof, on_con_lost)

            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                server_side = main(proto, on_con, on_eof, on_con_lost)
                await asyncio.wait_for(
                    server_side, loop=self.loop, timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()

        self.loop.run_until_complete(run_main())
示例#19
0
    def test_create_connection_memory_leak(self):
        # Check that loop.create_connection(ssl=...) leaves no lingering
        # reference to the client SSL context once the connection is done,
        # even when the protocol keeps the transport as an attribute.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            # Server side: upgrade to TLS, send one byte, read the client's
            # payload back and shut the connection down.
            sock.settimeout(self.TIMEOUT)

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # The instance is named ``proto`` so that ``self`` still refers
            # to the enclosing test case inside this method.
            def connection_made(proto, tr):
                # XXX: We assume user stores the transport in protocol
                proto.tr = tr
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr,
                ssl=client_context)

            self.assertEqual(await on_data, b'O')
            tr.write(HELLO_MSG)
            await on_eof

            tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr), timeout=10))

        # No garbage is left for SSL client from loop.create_connection, even
        # if user stores the SSLTransport in corresponding protocol instance.
        # The name is rebound to a weak reference on purpose: this drops the
        # strong reference held by this function and by the client() closure.
        client_context = weakref.ref(client_context)
        self.assertIsNone(client_context())
示例#20
0
    def test_start_tls_server_1(self):
        # Server-side start_tls(): the asyncio server greets a blocking
        # client in plaintext, upgrades the connection to TLS, and must then
        # receive the client's encrypted HELLO_MSG followed by EOF.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        if sys.platform.startswith('freebsd'):
            # bpo-35031: Some FreeBSD buildbots fail to run this test
            # as the eof was not being received by the server if the payload
            # size is not big enough. This behaviour only appears if the
            # client is using TLS1.3.
            client_context.options |= ssl.OP_NO_TLSv1_3

        def client(sock, addr):
            # Blocking peer handed to self.tcp_client(): connect, read the
            # plaintext greeting, switch to TLS, send HELLO_MSG, shut down.
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ServerProto(asyncio.Protocol):
            # Accumulates received bytes in `data` and resolves the supplied
            # futures on connection, EOF and connection loss.
            def __init__(self, on_con, on_eof, on_con_lost):
                self.on_con = on_con
                self.on_eof = on_eof
                self.on_con_lost = on_con_lost
                self.data = b''

            def connection_made(self, tr):
                self.on_con.set_result(tr)

            def data_received(self, data):
                self.data += data

            def eof_received(self):
                self.on_eof.set_result(1)

            def connection_lost(self, exc):
                if exc is None:
                    self.on_con_lost.set_result(None)
                else:
                    self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_eof, on_con_lost):
            # Send the plaintext greeting over the raw transport.
            tr = await on_con
            tr.write(HELLO_MSG)

            # Nothing has been received from the client before the upgrade.
            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)

            await on_eof
            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_eof = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            proto = ServerProto(on_con, on_eof, on_con_lost)

            # The factory always returns the same protocol instance.
            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                # The `loop` argument of wait_for() was removed in
                # Python 3.10; the running loop is used implicitly.
                await asyncio.wait_for(
                    main(proto, on_con, on_eof, on_con_lost),
                    timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()

        self.loop.run_until_complete(run_main())
示例#21
0
    def test_start_tls_server_1(self):
        # Server-side start_tls() where the server replies with ANSWER over
        # the upgraded transport once the client's HELLO_MSG has arrived.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
        ANSWER = b'answer'

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        answer = None

        def run_client(sock, addr):
            # Blocking peer: read the plaintext greeting, upgrade to TLS,
            # send HELLO_MSG and store the server's reply in `answer`.
            nonlocal answer
            sock.settimeout(self.TIMEOUT)
            sock.connect(addr)

            greeting = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(greeting), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)
            answer = sock.recv_all(len(ANSWER))
            sock.close()

        class ServerProto(asyncio.Protocol):
            # Tracks the current transport, buffers incoming bytes and
            # resolves the futures it was constructed with.
            def __init__(self, on_con, on_con_lost, on_got_hello):
                self.transport = None
                self.data = b''
                self.on_con = on_con
                self.on_con_lost = on_con_lost
                self.on_got_hello = on_got_hello

            def connection_made(self, tr):
                self.transport = tr
                self.on_con.set_result(tr)

            def replace_transport(self, tr):
                self.transport = tr

            def data_received(self, data):
                self.data += data
                if len(self.data) < len(HELLO_MSG):
                    return
                # The complete hello message has been received.
                self.on_got_hello.set_result(None)

            def connection_lost(self, exc):
                self.transport = None
                if exc is not None:
                    self.on_con_lost.set_exception(exc)
                    return
                self.on_con_lost.set_result(None)

        async def serve_one(proto, on_con, on_con_lost, on_got_hello):
            # Plaintext greeting goes out on the raw transport.
            plain_tr = await on_con
            plain_tr.write(HELLO_MSG)

            self.assertEqual(proto.data, b'')

            tls_tr = await self.loop.start_tls(
                plain_tr,
                proto,
                server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT,
            )
            proto.replace_transport(tls_tr)

            # Reply only after the client's encrypted hello is complete.
            await on_got_hello
            tls_tr.write(ANSWER)

            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            tls_tr.close()

        async def scenario():
            on_con = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            on_got_hello = self.loop.create_future()
            proto = ServerProto(on_con, on_con_lost, on_got_hello)

            # Every accepted connection gets the same protocol instance.
            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: run_client(sock, addr),
                                 timeout=self.TIMEOUT):
                exchange = serve_one(proto, on_con, on_con_lost, on_got_hello)
                await asyncio.wait_for(exchange, timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()
            self.assertEqual(answer, ANSWER)

        self.loop.run_until_complete(scenario())
示例#22
0
    def test_start_tls_server_1(self):
        # Server-side start_tls(): plaintext greeting from the asyncio
        # server, TLS upgrade, then the client's encrypted HELLO_MSG and EOF.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def client(sock, addr):
            # Blocking peer handed to self.tcp_client(): connect, read the
            # plaintext greeting, switch to TLS, send HELLO_MSG, shut down.
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ServerProto(asyncio.Protocol):
            # Accumulates received bytes in `data` and resolves the supplied
            # futures on connection, EOF and connection loss.
            def __init__(self, on_con, on_eof, on_con_lost):
                self.on_con = on_con
                self.on_eof = on_eof
                self.on_con_lost = on_con_lost
                self.data = b''

            def connection_made(self, tr):
                self.on_con.set_result(tr)

            def data_received(self, data):
                self.data += data

            def eof_received(self):
                self.on_eof.set_result(1)

            def connection_lost(self, exc):
                if exc is None:
                    self.on_con_lost.set_result(None)
                else:
                    self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_eof, on_con_lost):
            # Send the plaintext greeting over the raw transport.
            tr = await on_con
            tr.write(HELLO_MSG)

            # Nothing has been received from the client before the upgrade.
            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr,
                proto,
                server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)

            await on_eof
            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_eof = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            proto = ServerProto(on_con, on_eof, on_con_lost)

            # The factory always returns the same protocol instance.
            server = await self.loop.create_server(lambda: proto, '127.0.0.1',
                                                   0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                # The `loop` argument of wait_for() was removed in
                # Python 3.10; the running loop is used implicitly.
                await asyncio.wait_for(main(proto, on_con, on_eof,
                                            on_con_lost),
                                       timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()

        self.loop.run_until_complete(run_main())
示例#23
0
    def test_start_tls_server_1(self):
        # Server-side start_tls() with a fixed 1 MiB payload: plaintext
        # greeting, TLS upgrade, then the client's encrypted HELLO_MSG and
        # EOF.
        HELLO_MSG = b'1' * 1024 * 1024

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def client(sock, addr):
            # Blocking peer handed to self.tcp_client(). It waits half a
            # second before connecting, reads the plaintext greeting,
            # switches to TLS, sends HELLO_MSG and closes.
            time.sleep(0.5)
            sock.settimeout(5)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)
            sock.close()

        class ServerProto(asyncio.Protocol):
            # Accumulates received bytes in `data` and resolves the supplied
            # futures on connection and EOF.
            def __init__(self, on_con, on_eof):
                self.on_con = on_con
                self.on_eof = on_eof
                self.data = b''

            def connection_made(self, tr):
                self.on_con.set_result(tr)

            def data_received(self, data):
                self.data += data

            def eof_received(self):
                self.on_eof.set_result(1)

        async def main():
            # Uses `on_con`, `on_eof`, `proto` and `server` from the
            # enclosing scope; they are bound below before main() runs.
            tr = await on_con
            tr.write(HELLO_MSG)

            # Nothing has been received from the client before the upgrade.
            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True)

            await on_eof
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

            server.close()
            await server.wait_closed()

        on_con = self.loop.create_future()
        on_eof = self.loop.create_future()
        proto = ServerProto(on_con, on_eof)

        # The factory always returns the same protocol instance.
        server = self.loop.run_until_complete(
            self.loop.create_server(
                lambda: proto, '127.0.0.1', 0))
        addr = server.sockets[0].getsockname()

        with self.tcp_client(lambda sock: client(sock, addr)):
            # The `loop` argument of wait_for() was removed in Python 3.10;
            # the coroutine picks up the running loop implicitly.
            self.loop.run_until_complete(
                asyncio.wait_for(main(), timeout=10))
示例#24
0
    def test_start_tls_client_buf_proto_1(self):
        # Client-side start_tls() with a BufferedProtocol, followed by a
        # switch to a regular Protocol via set_protocol() on the TLS
        # transport. connection_made() must be invoked exactly once overall.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        client_con_made_calls = 0

        def serve(sock):
            # Blocking server handed to self.tcp_server(): read the plaintext
            # HELLO_MSG, upgrade to TLS, then run two request/response
            # rounds (b'O' and b'2', each answered by HELLO_MSG).
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.sendall(b'2')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProtoFirst(asyncio.BufferedProtocol):
            # Buffered protocol with a one-byte receive buffer; resolves
            # `on_data` with the byte that was written into it.
            def __init__(self, on_data):
                self.on_data = on_data
                self.buf = bytearray(1)

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def get_buffer(self, sizehint):
                return self.buf

            def buffer_updated(self, nsize):
                assert nsize == 1
                self.on_data.set_result(bytes(self.buf[:nsize]))

        class ClientProtoSecond(asyncio.Protocol):
            # Streaming protocol installed after the TLS upgrade; resolves
            # `on_data` with received bytes and `on_eof` on end-of-file.
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            # The `loop` argument of sleep() was removed in Python 3.10;
            # the running loop is used implicitly.
            await asyncio.sleep(0.5)

            on_data1 = self.loop.create_future()
            on_data2 = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(on_data1), *addr)

            # Plaintext hello first, then upgrade the same connection.
            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data1, b'O')
            new_tr.write(HELLO_MSG)

            # Swap protocols on the TLS transport for the second round.
            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
            self.assertEqual(await on_data2, b'2')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

            # connection_made() should be called only once -- when
            # we establish connection for the first time. Start TLS
            # doesn't call connection_made() on application protocols.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            # wait_for() likewise no longer accepts `loop` (removed in 3.10).
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=self.TIMEOUT))