class _TestSSL(tb.SSLTestCase):
    # SSL tests over UNIX sockets: an asyncio stream server talking to
    # threaded clients, and asyncio stream clients talking to a threaded
    # server.

    ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem')
    ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem')

    def test_create_unix_server_ssl_1(self):
        # Start an SSL stream server on a UNIX socket, connect a batch of
        # threaded clients to it and check that every one of them completes
        # the A_DATA -> OK -> B_DATA -> SPAM exchange.
        served = 0          # clients whose exchange finished on the server
        n_clients = 25      # how many clients get started
        deadline = 10.0     # seconds allowed for the whole exchange

        first_blob = b'A' * 1024 * 1024
        second_blob = b'B' * 1024 * 1024

        server_ctx = self._create_server_ssl_context(
            self.ONLYCERT, self.ONLYKEY)
        client_ctx = self._create_client_ssl_context()

        workers = []

        async def handle_client(reader, writer):
            nonlocal served

            received = await reader.readexactly(len(first_blob))
            self.assertEqual(received, first_blob)
            writer.write(b'OK')

            received = await reader.readexactly(len(second_blob))
            self.assertEqual(received, second_blob)
            # The three pieces concatenate to the b'SPAM' the client expects.
            writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

            await writer.drain()
            writer.close()

            served += 1

        async def test_client(addr):
            outcome = asyncio.Future(loop=self.loop)

            # Both reporters run on the loop thread and leave an already
            # cancelled future untouched.
            def report_failure(exc):
                if not outcome.cancelled():
                    outcome.set_exception(exc)

            def report_success():
                if not outcome.cancelled():
                    outcome.set_result(None)

            def prog(sock):
                try:
                    sock.starttls(client_ctx)
                    sock.connect(addr)
                    sock.send(first_blob)

                    reply = sock.recv_all(2)
                    self.assertEqual(reply, b'OK')

                    sock.send(second_blob)
                    reply = sock.recv_all(4)
                    self.assertEqual(reply, b'SPAM')

                    sock.close()
                except Exception as exc:
                    self.loop.call_soon_threadsafe(report_failure, exc)
                else:
                    self.loop.call_soon_threadsafe(report_success)

            worker = self.unix_client(prog)
            worker.start()
            workers.append(worker)

            await outcome

        async def start_server():
            if self.implementation != 'asyncio' or self.PY37:
                extras = {'ssl_handshake_timeout': 10.0}
            else:
                extras = {}

            with tempfile.TemporaryDirectory() as tmp:
                path = os.path.join(tmp, 'sock')

                listener = await asyncio.start_unix_server(
                    handle_client,
                    path,
                    ssl=server_ctx,
                    loop=self.loop,
                    **extras)

                try:
                    pending = [test_client(path) for _ in range(n_clients)]
                    everything = asyncio.gather(*pending, loop=self.loop)
                    await asyncio.wait_for(
                        everything, deadline, loop=self.loop)
                finally:
                    self.loop.call_soon(listener.close)
                    await listener.wait_closed()

        try:
            with self._silence_eof_received_warning():
                self.loop.run_until_complete(start_server())
        except asyncio.TimeoutError:
            if os.environ.get('TRAVIS_OS_NAME') != 'osx':
                raise
            # XXX: figure out why this fails on macOS on Travis
            raise unittest.SkipTest('unexplained error on Travis macOS')

        self.assertEqual(served, n_clients)

        for worker in workers:
            worker.stop()

    def test_create_unix_connection_ssl_1(self):
        # Open a batch of asyncio SSL stream connections to a threaded UNIX
        # socket server and check that each one completes the
        # A_DATA -> OK -> B_DATA -> SPAM exchange.
        finished = 0
        n_clients = 25

        first_blob = b'A' * 1024 * 1024
        second_blob = b'B' * 1024 * 1024

        server_ctx = self._create_server_ssl_context(
            self.ONLYCERT, self.ONLYKEY)
        client_ctx = self._create_client_ssl_context()

        def server(sock):
            sock.starttls(server_ctx, server_side=True)

            received = sock.recv_all(len(first_blob))
            self.assertEqual(received, first_blob)
            sock.send(b'OK')

            received = sock.recv_all(len(second_blob))
            self.assertEqual(received, second_blob)
            sock.send(b'SPAM')

            sock.close()

        async def client(addr):
            nonlocal finished

            if self.implementation != 'asyncio' or self.PY37:
                extras = {'ssl_handshake_timeout': 10.0}
            else:
                extras = {}

            reader, writer = await asyncio.open_unix_connection(
                addr,
                ssl=client_ctx,
                server_hostname='',
                loop=self.loop,
                **extras)

            writer.write(first_blob)
            ack = await reader.readexactly(2)
            self.assertEqual(ack, b'OK')

            writer.write(second_blob)
            tail = await reader.readexactly(4)
            self.assertEqual(tail, b'SPAM')

            finished += 1

            writer.close()

        def run(coro):
            nonlocal finished
            finished = 0

            with self.unix_server(server,
                                  max_clients=n_clients,
                                  backlog=n_clients) as srv:
                pending = [coro(srv.addr) for _ in range(n_clients)]
                self.loop.run_until_complete(
                    asyncio.gather(*pending, loop=self.loop))

            self.assertEqual(finished, n_clients)

        with self._silence_eof_received_warning():
            run(client)
class _ContextBaseTests(tb.SSLTestCase):
    # Checks which contextvars context tasks, future callbacks and protocol
    # callbacks observe.  The protocol objects used here record the value of
    # ``cvar`` seen by each callback in futures/attributes such as
    # ``connection_made_fut`` and ``connection_lost_ctx``.

    ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem')
    ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem')

    def test_task_decimal_context(self):
        # Two coroutines set different decimal precisions and overlap in
        # time; each must keep computing with its own precision.
        async def fractions(t, precision, x, y):
            with decimal.localcontext() as ctx:
                ctx.prec = precision
                a = decimal.Decimal(x) / decimal.Decimal(y)
                await asyncio.sleep(t)
                b = decimal.Decimal(x) / decimal.Decimal(y**2)
                return a, b

        async def main():
            r1, r2 = await asyncio.gather(
                fractions(0.1, 3, 1, 3), fractions(0.2, 6, 1, 3))

            return r1, r2

        r1, r2 = self.loop.run_until_complete(main())

        self.assertEqual(str(r1[0]), '0.333')
        self.assertEqual(str(r1[1]), '0.111')

        self.assertEqual(str(r2[0]), '0.333333')
        self.assertEqual(str(r2[1]), '0.111111')

    def test_task_context_1(self):
        # A sub-task created before the parent changes cvar must not see the
        # parent's later change, and its own change must not leak back.
        cvar = contextvars.ContextVar('cvar', default='nope')

        async def sub():
            await asyncio.sleep(0.01)
            self.assertEqual(cvar.get(), 'nope')
            cvar.set('something else')

        async def main():
            self.assertEqual(cvar.get(), 'nope')
            subtask = self.loop.create_task(sub())
            cvar.set('yes')
            self.assertEqual(cvar.get(), 'yes')
            await subtask
            self.assertEqual(cvar.get(), 'yes')

        task = self.loop.create_task(main())
        self.loop.run_until_complete(task)

    def test_task_context_2(self):
        # Changes made to cvar inside a future's done-callback must not be
        # visible to the task that registered the callback, nor to the
        # caller of run_until_complete().
        cvar = contextvars.ContextVar('cvar', default='nope')

        async def main():
            def fut_on_done(fut):
                # This change must not pollute the context
                # of the "main()" task.
                cvar.set('something else')

            self.assertEqual(cvar.get(), 'nope')

            for j in range(2):
                fut = self.loop.create_future()
                fut.add_done_callback(fut_on_done)
                cvar.set('yes{}'.format(j))
                self.loop.call_soon(fut.set_result, None)
                await fut
                self.assertEqual(cvar.get(), 'yes{}'.format(j))

                for i in range(3):
                    # Test that task passed its context to add_done_callback:
                    cvar.set('yes{}-{}'.format(i, j))
                    await asyncio.sleep(0.001)
                    self.assertEqual(cvar.get(), 'yes{}-{}'.format(i, j))

        task = self.loop.create_task(main())
        self.loop.run_until_complete(task)

        self.assertEqual(cvar.get(), 'nope')

    def test_task_context_3(self):
        cvar = contextvars.ContextVar('cvar', default=-1)

        # Run 100 Tasks in parallel, each modifying cvar.
        async def sub(num):
            for i in range(10):
                cvar.set(num + i)
                await asyncio.sleep(random.uniform(0.001, 0.05))
                self.assertEqual(cvar.get(), num + i)

        async def main():
            tasks = []
            for i in range(100):
                task = self.loop.create_task(sub(random.randint(0, 10)))
                tasks.append(task)

            # return_exceptions=True lets every task run to completion, but
            # the collected results have to be inspected: otherwise a failed
            # assertion inside sub() would be silently discarded.
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, BaseException):
                    raise result

        self.loop.run_until_complete(main())

        self.assertEqual(cvar.get(), -1)

    def test_task_context_4(self):
        # An object stored in cvar by a finished task must not be kept alive
        # once the last external reference to it is dropped.
        cvar = contextvars.ContextVar('cvar', default='nope')

        class TrackMe:
            pass

        tracked = TrackMe()
        ref = weakref.ref(tracked)

        async def sub():
            cvar.set(tracked)  # NoQA
            self.loop.call_soon(lambda: None)

        async def main():
            await self.loop.create_task(sub())
            await asyncio.sleep(0.01)

        task = self.loop.create_task(main())
        self.loop.run_until_complete(task)

        del tracked
        self.assertIsNone(ref())

    def _run_test(self, method, **switches):
        # Run ``method`` once (as a subTest) for every combination of
        # protocol class and boolean switch.  Each switch is 'yes', 'no' or
        # 'both'; 'both' expands to True and False.  ``method`` is awaited
        # with cvar, proto, sslctx, client_sslctx, addr, family and the
        # expanded switches as keyword arguments, after cvar has been set to
        # 'inner' inside the coroutine.
        #
        # Raises ValueError for a switch value other than yes/no/both.
        switches.setdefault('use_tcp', 'both')
        # SSL contexts are created whenever SSL may be used by any combo.
        use_ssl = switches.setdefault('use_ssl', 'no') in {'yes', 'both'}
        names = ['factory']
        options = [(_Protocol, _BufferedProtocol)]
        for k, v in switches.items():
            if v == 'yes':
                options.append((True, ))
            elif v == 'no':
                options.append((False, ))
            elif v == 'both':
                options.append((True, False))
            else:
                raise ValueError(f"Illegal {k}={v}, can only be yes/no/both")
            names.append(k)

        for combo in itertools.product(*options):
            values = dict(zip(names, combo))
            with self.subTest(**values):
                cvar = contextvars.ContextVar('cvar', default='outer')
                values['proto'] = values.pop('factory')(cvar, loop=self.loop)

                async def test():
                    self.assertEqual(cvar.get(), 'outer')
                    cvar.set('inner')
                    tmp_dir = tempfile.TemporaryDirectory()
                    if use_ssl:
                        values['sslctx'] = self._create_server_ssl_context(
                            self.ONLYCERT, self.ONLYKEY)
                        values['client_sslctx'] = \
                            self._create_client_ssl_context()
                    else:
                        values['sslctx'] = values['client_sslctx'] = None

                    if values['use_tcp']:
                        values['addr'] = ('127.0.0.1', tb.find_free_port())
                        values['family'] = socket.AF_INET
                    else:
                        values['addr'] = tmp_dir.name + '/test.sock'
                        values['family'] = socket.AF_UNIX

                    try:
                        await method(cvar=cvar, **values)
                    finally:
                        tmp_dir.cleanup()

                self.loop.run_until_complete(test())

    def _run_server_test(self, method, async_sock=False, **switches):
        # Like _run_test(), but first starts a server whose protocol factory
        # returns ``proto`` and connects a client socket to it.  ``method``
        # additionally receives the client socket as ``s`` and, when SSL is
        # in use, the wrapped socket as ``ssl_sock``.  With async_sock=True
        # the client socket is non-blocking and connected through the loop.
        async def test(sslctx, client_sslctx, addr, family, **values):
            if values['use_tcp']:
                srv = await self.loop.create_server(
                    lambda: values['proto'], *addr, ssl=sslctx)
            else:
                srv = await self.loop.create_unix_server(
                    lambda: values['proto'], addr, ssl=sslctx)
            s = socket.socket(family)

            if async_sock:
                s.setblocking(False)
                await self.loop.sock_connect(s, addr)
            else:
                await self.loop.run_in_executor(
                    None, s.connect, addr)
                if values['use_ssl']:
                    values['ssl_sock'] = await self.loop.run_in_executor(
                        None, client_sslctx.wrap_socket, s)

            try:
                await method(s=s, **values)
            finally:
                if values['use_ssl']:
                    values['ssl_sock'].close()
                s.close()
                srv.close()
                await srv.wait_closed()
        return self._run_test(test, **switches)

    def test_create_server_protocol_factory_context(self):
        # The protocol factory passed to create_server() must be called with
        # cvar == 'inner'.
        async def test(cvar, proto, use_tcp, family, addr, **_):
            factory_called_future = self.loop.create_future()

            def factory():
                # Report the assertion outcome through the future instead of
                # raising inside the factory call.
                try:
                    self.assertEqual(cvar.get(), 'inner')
                except Exception as e:
                    factory_called_future.set_exception(e)
                else:
                    factory_called_future.set_result(None)

                return proto

            if use_tcp:
                srv = await self.loop.create_server(factory, *addr)
            else:
                srv = await self.loop.create_unix_server(factory, addr)
            s = socket.socket(family)
            with s:
                s.setblocking(False)
                await self.loop.sock_connect(s, addr)

            try:
                await factory_called_future
            finally:
                srv.close()
                await proto.done
                await srv.wait_closed()

        self._run_test(test)

    def test_create_server_connection_protocol(self):
        # connection_made, data_received, eof_received and connection_lost
        # of a server-side protocol must all observe 'inner'.
        async def test(proto, s, **_):
            inner = await proto.connection_made_fut
            self.assertEqual(inner, "inner")

            await self.loop.sock_sendall(s, b'data')
            inner = await proto.data_received_fut
            self.assertEqual(inner, "inner")

            s.shutdown(socket.SHUT_WR)
            inner = await proto.eof_received_fut
            self.assertEqual(inner, "inner")

            s.close()
            await proto.done
            self.assertEqual(proto.connection_lost_ctx, "inner")

        self._run_server_test(test, async_sock=True)

    def test_create_ssl_server_connection_protocol(self):
        # Same as the plain server test, but over SSL; on non-asyncio loops
        # it also checks data delivered after a resume_reading() issued from
        # a callback that changed cvar.
        async def test(cvar, proto, ssl_sock, **_):
            def resume_reading(transport):
                cvar.set("resume_reading")
                transport.resume_reading()

            try:
                inner = await proto.connection_made_fut
                self.assertEqual(inner, "inner")

                await self.loop.run_in_executor(None, ssl_sock.send, b'data')
                inner = await proto.data_received_fut
                self.assertEqual(inner, "inner")

                if self.implementation != 'asyncio':
                    # this seems to be a bug in asyncio
                    proto.data_received_fut = self.loop.create_future()
                    proto.transport.pause_reading()
                    await self.loop.run_in_executor(None,
                                                    ssl_sock.send, b'data')
                    self.loop.call_soon(resume_reading, proto.transport)
                    inner = await proto.data_received_fut
                    self.assertEqual(inner, "inner")

                    await self.loop.run_in_executor(None, ssl_sock.unwrap)
                else:
                    ssl_sock.shutdown(socket.SHUT_WR)
                inner = await proto.eof_received_fut
                self.assertEqual(inner, "inner")

                await self.loop.run_in_executor(None, ssl_sock.close)
                await proto.done
                self.assertEqual(proto.connection_lost_ctx, "inner")
            finally:
                if self.implementation == 'asyncio':
                    # mute resource warning in asyncio
                    proto.transport.close()

        self._run_server_test(test, use_ssl='yes')

    def test_create_server_manual_connection_lost(self):
        # Closing the transport from a callback that changed cvar must still
        # deliver connection_lost with 'inner'.
        if self.implementation == 'asyncio':
            raise unittest.SkipTest('this seems to be a bug in asyncio')

        async def test(proto, cvar, **_):
            def close():
                cvar.set('closing')
                proto.transport.close()

            inner = await proto.connection_made_fut
            self.assertEqual(inner, "inner")

            self.loop.call_soon(close)

            await proto.done
            self.assertEqual(proto.connection_lost_ctx, "inner")

        self._run_server_test(test, async_sock=True)

    def test_create_ssl_server_manual_connection_lost(self):
        # SSL variant of the manual-close test; data sent while reading is
        # paused must not reach data_received.
        async def test(proto, cvar, ssl_sock, **_):
            def close():
                cvar.set('closing')
                proto.transport.close()

            inner = await proto.connection_made_fut
            self.assertEqual(inner, "inner")

            if self.implementation == 'asyncio':
                self.loop.call_soon(close)
            else:
                # asyncio doesn't have the flushing phase

                # put the incoming data on-hold
                proto.transport.pause_reading()
                # send data
                await self.loop.run_in_executor(None,
                                                ssl_sock.send, b'hello')
                # schedule a proactive transport close which will trigger
                # the flushing process to retrieve the remaining data
                self.loop.call_soon(close)
                # turn off the reading lock now (this also schedules a
                # resume operation after transport.close, therefore it
                # won't affect our test)
                proto.transport.resume_reading()

            await asyncio.sleep(0)
            await self.loop.run_in_executor(None, ssl_sock.unwrap)
            await proto.done
            self.assertEqual(proto.connection_lost_ctx, "inner")
            self.assertFalse(proto.data_received_fut.done())

        self._run_server_test(test, use_ssl='yes')

    def test_create_connection_protocol(self):
        # Client-side protocol callbacks (including pause_writing and
        # resume_writing on non-asyncio loops) must observe 'inner'.
        async def test(cvar, proto, addr, sslctx, client_sslctx, family,
                       use_sock, use_ssl, use_tcp):
            ss = socket.socket(family)
            ss.bind(addr)
            ss.listen(1)

            def accept():
                sock, _ = ss.accept()
                if use_ssl:
                    sock = sslctx.wrap_socket(sock, server_side=True)
                return sock

            async def write_over():
                # Write until the transport buffers data, returning how many
                # 16384-byte chunks were written.
                cvar.set("write_over")
                count = 0
                if use_ssl:
                    proto.transport.set_write_buffer_limits(high=256, low=128)
                    while not proto.transport.get_write_buffer_size():
                        proto.transport.write(b'q' * 16384)
                        count += 1
                else:
                    proto.transport.write(b'q' * 16384)
                    proto.transport.set_write_buffer_limits(high=256, low=128)
                    count += 1
                return count

            s = self.loop.run_in_executor(None, accept)

            try:
                method = ('create_connection' if use_tcp
                          else 'create_unix_connection')
                params = {}
                if use_sock:
                    cs = socket.socket(family)
                    cs.connect(addr)
                    params['sock'] = cs
                    if use_ssl:
                        params['server_hostname'] = '127.0.0.1'
                elif use_tcp:
                    params['host'] = addr[0]
                    params['port'] = addr[1]
                else:
                    params['path'] = addr
                    if use_ssl:
                        params['server_hostname'] = '127.0.0.1'
                if use_ssl:
                    params['ssl'] = client_sslctx
                await getattr(self.loop, method)(lambda: proto, **params)
                s = await s

                inner = await proto.connection_made_fut
                self.assertEqual(inner, "inner")

                await self.loop.run_in_executor(None, s.send, b'data')
                inner = await proto.data_received_fut
                self.assertEqual(inner, "inner")

                if self.implementation != 'asyncio':
                    # asyncio bug
                    count = await self.loop.create_task(write_over())
                    inner = await proto.pause_writing_fut
                    self.assertEqual(inner, "inner")

                    # Drain what write_over() sent so writing can resume.
                    for i in range(count):
                        await self.loop.run_in_executor(None, s.recv, 16384)
                    inner = await proto.resume_writing_fut
                    self.assertEqual(inner, "inner")

                if use_ssl and self.implementation != 'asyncio':
                    await self.loop.run_in_executor(None, s.unwrap)
                else:
                    s.shutdown(socket.SHUT_WR)
                inner = await proto.eof_received_fut
                self.assertEqual(inner, "inner")

                s.close()
                await proto.done
                self.assertEqual(proto.connection_lost_ctx, "inner")
            finally:
                ss.close()
                proto.transport.close()

        self._run_test(test, use_sock='both', use_ssl='both')

    def test_start_tls(self):
        # Protocol callbacks must keep observing 'inner' after the
        # connection is upgraded with start_tls() (optionally twice) from a
        # context where cvar was changed.
        if self.implementation == 'asyncio':
            raise unittest.SkipTest('this seems to be a bug in asyncio')

        async def test(cvar, proto, addr, sslctx, client_sslctx, family,
                       ssl_over_ssl, use_tcp, **_):
            ss = socket.socket(family)
            ss.bind(addr)
            ss.listen(1)

            def accept():
                sock, _ = ss.accept()
                sock = sslctx.wrap_socket(sock, server_side=True)
                if ssl_over_ssl:
                    sock = _SSLSocketOverSSL(sock, sslctx, server_side=True)
                return sock

            s = self.loop.run_in_executor(None, accept)
            transport = None

            try:
                if use_tcp:
                    await self.loop.create_connection(lambda: proto, *addr)
                else:
                    await self.loop.create_unix_connection(lambda: proto, addr)
                inner = await proto.connection_made_fut
                self.assertEqual(inner, "inner")

                cvar.set('start_tls')
                transport = await self.loop.start_tls(
                    proto.transport, proto, client_sslctx,
                    server_hostname='127.0.0.1',
                )

                if ssl_over_ssl:
                    cvar.set('start_tls_over_tls')
                    transport = await self.loop.start_tls(
                        transport, proto, client_sslctx,
                        server_hostname='127.0.0.1',
                    )

                s = await s
                await self.loop.run_in_executor(None, s.send, b'data')
                inner = await proto.data_received_fut
                self.assertEqual(inner, "inner")

                await self.loop.run_in_executor(None, s.unwrap)
                inner = await proto.eof_received_fut
                self.assertEqual(inner, "inner")

                s.close()
                await proto.done
                self.assertEqual(proto.connection_lost_ctx, "inner")
            finally:
                ss.close()
                if transport:
                    transport.close()

        self._run_test(test, use_ssl='yes', ssl_over_ssl='both')

    def test_connect_accepted_socket(self):
        # Protocol callbacks of a transport created with
        # connect_accepted_socket() must observe 'inner'.
        async def test(proto, addr, family, sslctx, client_sslctx,
                       use_ssl, **_):
            ss = socket.socket(family)
            ss.bind(addr)
            ss.listen(1)
            s = self.loop.run_in_executor(None, ss.accept)
            cs = socket.socket(family)
            cs.connect(addr)
            s, _ = await s

            try:
                if use_ssl:
                    # The client handshake runs in the executor while the
                    # loop performs the server side of it.
                    cs = self.loop.run_in_executor(
                        None, client_sslctx.wrap_socket, cs)
                    await self.loop.connect_accepted_socket(lambda: proto, s,
                                                            ssl=sslctx)
                    cs = await cs
                else:
                    await self.loop.connect_accepted_socket(lambda: proto, s)

                inner = await proto.connection_made_fut
                self.assertEqual(inner, "inner")

                await self.loop.run_in_executor(None, cs.send, b'data')
                inner = await proto.data_received_fut
                self.assertEqual(inner, "inner")

                if use_ssl and self.implementation != 'asyncio':
                    await self.loop.run_in_executor(None, cs.unwrap)
                else:
                    cs.shutdown(socket.SHUT_WR)
                inner = await proto.eof_received_fut
                self.assertEqual(inner, "inner")

                cs.close()
                await proto.done
                self.assertEqual(proto.connection_lost_ctx, "inner")
            finally:
                proto.transport.close()
                ss.close()

        self._run_test(test, use_ssl='both')

    def test_subprocess_protocol(self):
        # Subprocess protocol callbacks must observe 'inner' (some of the
        # checks are skipped on asyncio).
        cvar = contextvars.ContextVar('cvar', default='outer')
        proto = _SubprocessProtocol(cvar, loop=self.loop)

        async def test():
            self.assertEqual(cvar.get(), 'outer')
            cvar.set('inner')
            await self.loop.subprocess_exec(lambda: proto,
                                            *_AsyncioTests.PROGRAM_CAT)

            try:
                inner = await proto.connection_made_fut
                self.assertEqual(inner, "inner")

                proto.transport.get_pipe_transport(0).write(b'data')
                proto.transport.get_pipe_transport(0).write_eof()
                inner = await proto.data_received_fut
                self.assertEqual(inner, "inner")

                inner = await proto.pipe_connection_lost_fut
                self.assertEqual(inner, "inner")

                inner = await proto.process_exited_fut
                if self.implementation != 'asyncio':
                    # bug in asyncio
                    self.assertEqual(inner, "inner")

                await proto.done
                if self.implementation != 'asyncio':
                    # bug in asyncio
                    self.assertEqual(proto.connection_lost_ctx, "inner")
            finally:
                proto.transport.close()

        self.loop.run_until_complete(test())

    def test_datagram_protocol(self):
        # Datagram protocol callbacks must observe 'inner', even when the
        # transport is closed from a callback that changed cvar.
        cvar = contextvars.ContextVar('cvar', default='outer')
        proto = _DatagramProtocol(cvar, loop=self.loop)
        server_addr = ('127.0.0.1', 8888)
        client_addr = ('127.0.0.1', 0)

        async def run():
            self.assertEqual(cvar.get(), 'outer')
            cvar.set('inner')

            def close():
                cvar.set('closing')
                proto.transport.close()

            # Bound before the try block so that the cleanup below cannot
            # fail with UnboundLocalError (and hide the real error) when the
            # endpoint cannot be created or an early assertion fails.
            s = None
            try:
                await self.loop.create_datagram_endpoint(
                    lambda: proto, local_addr=server_addr)
                inner = await proto.connection_made_fut
                self.assertEqual(inner, "inner")

                s = socket.socket(socket.AF_INET, type=socket.SOCK_DGRAM)
                s.bind(client_addr)
                s.sendto(b'data', server_addr)
                inner = await proto.data_received_fut
                self.assertEqual(inner, "inner")

                self.loop.call_soon(close)
                await proto.done
                if self.implementation != 'asyncio':
                    # bug in asyncio
                    self.assertEqual(proto.connection_lost_ctx, "inner")
            finally:
                # The transport is only available once the endpoint exists.
                transport = getattr(proto, 'transport', None)
                if transport is not None:
                    transport.close()
                if s is not None:
                    s.close()
                # let transports close
                await asyncio.sleep(0.1)

        self.loop.run_until_complete(run())
class _TestSSL(tb.SSLTestCase):
    # SSL tests over TCP: an asyncio stream server talking to threaded
    # clients, and asyncio stream clients talking to a threaded server.

    ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem')
    ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem')

    def test_create_server_ssl_1(self):
        # Start an SSL stream server on 127.0.0.1, connect TOTAL_CNT
        # threaded clients and check that each completes the
        # A_DATA -> OK -> B_DATA -> SPAM exchange within TIMEOUT seconds.
        CNT = 0           # number of clients that were successful
        TOTAL_CNT = 25    # total number of clients that test will create
        TIMEOUT = 10.0    # timeout for this test

        A_DATA = b'A' * 1024 * 1024
        B_DATA = b'B' * 1024 * 1024

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        clients = []

        async def handle_client(reader, writer):
            nonlocal CNT

            data = await reader.readexactly(len(A_DATA))
            self.assertEqual(data, A_DATA)
            writer.write(b'OK')

            data = await reader.readexactly(len(B_DATA))
            self.assertEqual(data, B_DATA)
            # The three pieces concatenate to the b'SPAM' the client expects.
            writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

            await writer.drain()
            writer.close()

            CNT += 1

        async def test_client(addr):
            fut = asyncio.Future(loop=self.loop)

            def prog(sock):
                try:
                    sock.starttls(client_sslctx)
                    sock.connect(addr)
                    sock.send(A_DATA)

                    data = sock.recv_all(2)
                    self.assertEqual(data, b'OK')

                    sock.send(B_DATA)
                    data = sock.recv_all(4)
                    self.assertEqual(data, b'SPAM')

                    sock.close()

                except Exception as ex:
                    # fut may already be cancelled by the time this thread
                    # finishes (e.g. after wait_for() timed out); completing
                    # a cancelled future raises InvalidStateError, so skip
                    # it in that case.
                    self.loop.call_soon_threadsafe(
                        lambda ex=ex:
                            (fut.cancelled() or fut.set_exception(ex)))
                else:
                    self.loop.call_soon_threadsafe(
                        lambda: (fut.cancelled() or fut.set_result(None)))

            client = self.tcp_client(prog)
            client.start()
            clients.append(client)

            await fut

        async def start_server():
            srv = await asyncio.start_server(
                handle_client,
                '127.0.0.1', 0,
                family=socket.AF_INET,
                ssl=sslctx,
                loop=self.loop)

            try:
                srv_socks = srv.sockets
                self.assertTrue(srv_socks)

                # Port 0 was requested, so read back the actual address.
                addr = srv_socks[0].getsockname()

                tasks = []
                for _ in range(TOTAL_CNT):
                    tasks.append(test_client(addr))

                await asyncio.wait_for(
                    asyncio.gather(*tasks, loop=self.loop),
                    TIMEOUT, loop=self.loop)

            finally:
                self.loop.call_soon(srv.close)
                await srv.wait_closed()

        with self._silence_eof_received_warning():
            self.loop.run_until_complete(start_server())

        self.assertEqual(CNT, TOTAL_CNT)

        for client in clients:
            client.stop()

    def test_create_connection_ssl_1(self):
        # Open TOTAL_CNT asyncio SSL stream connections to a threaded TCP
        # server, first by address and then over pre-connected sockets, and
        # check that each completes the A_DATA/OK/B_DATA/SPAM exchange.
        CNT = 0
        TOTAL_CNT = 25

        A_DATA = b'A' * 1024 * 1024
        B_DATA = b'B' * 1024 * 1024

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        def server(sock):
            sock.starttls(
                sslctx,
                server_side=True)

            data = sock.recv_all(len(A_DATA))
            self.assertEqual(data, A_DATA)
            sock.send(b'OK')

            data = sock.recv_all(len(B_DATA))
            self.assertEqual(data, B_DATA)
            sock.send(b'SPAM')

            sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop)

            writer.write(A_DATA)
            self.assertEqual(await reader.readexactly(2), b'OK')

            writer.write(B_DATA)
            self.assertEqual(await reader.readexactly(4), b'SPAM')

            nonlocal CNT
            CNT += 1

            writer.close()

        async def client_sock(addr):
            sock = socket.socket()
            sock.connect(addr)
            reader, writer = await asyncio.open_connection(
                sock=sock,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop)

            writer.write(A_DATA)
            self.assertEqual(await reader.readexactly(2), b'OK')

            writer.write(B_DATA)
            self.assertEqual(await reader.readexactly(4), b'SPAM')

            nonlocal CNT
            CNT += 1

            writer.close()
            sock.close()

        def run(coro):
            # Reset the counter so each client flavour is counted separately.
            nonlocal CNT
            CNT = 0

            with self.tcp_server(server,
                                 max_clients=TOTAL_CNT,
                                 backlog=TOTAL_CNT) as srv:
                tasks = []
                for _ in range(TOTAL_CNT):
                    tasks.append(coro(srv.addr))

                self.loop.run_until_complete(
                    asyncio.gather(*tasks, loop=self.loop))

            self.assertEqual(CNT, TOTAL_CNT)

        with self._silence_eof_received_warning():
            run(client)

        with self._silence_eof_received_warning():
            run(client_sock)
class _TestSSL(tb.SSLTestCase):
    # SSL tests over TCP: stream server/client round trips, handshake
    # timeout and certificate failure handling, and connect_accepted_socket.

    ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem')
    ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem')

    def test_create_server_ssl_1(self):
        # Start an SSL stream server on 127.0.0.1, connect TOTAL_CNT
        # threaded clients and check that each completes the
        # A_DATA -> OK -> B_DATA -> SPAM exchange within TIMEOUT seconds.
        CNT = 0           # number of clients that were successful
        TOTAL_CNT = 25    # total number of clients that test will create
        TIMEOUT = 10.0    # timeout for this test

        A_DATA = b'A' * 1024 * 1024
        B_DATA = b'B' * 1024 * 1024

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        clients = []

        async def handle_client(reader, writer):
            nonlocal CNT

            data = await reader.readexactly(len(A_DATA))
            self.assertEqual(data, A_DATA)
            writer.write(b'OK')

            data = await reader.readexactly(len(B_DATA))
            self.assertEqual(data, B_DATA)
            # The three pieces concatenate to the b'SPAM' the client expects.
            writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

            await writer.drain()
            writer.close()

            CNT += 1

        async def test_client(addr):
            fut = asyncio.Future(loop=self.loop)

            def prog(sock):
                try:
                    sock.starttls(client_sslctx)
                    sock.connect(addr)
                    sock.send(A_DATA)

                    data = sock.recv_all(2)
                    self.assertEqual(data, b'OK')

                    sock.send(B_DATA)
                    data = sock.recv_all(4)
                    self.assertEqual(data, b'SPAM')

                    sock.close()

                except Exception as ex:
                    # fut may already be cancelled by the time this thread
                    # finishes (e.g. after wait_for() timed out); completing
                    # a cancelled future raises InvalidStateError, so skip
                    # it in that case.
                    self.loop.call_soon_threadsafe(
                        lambda ex=ex:
                            (fut.cancelled() or fut.set_exception(ex)))
                else:
                    self.loop.call_soon_threadsafe(
                        lambda: (fut.cancelled() or fut.set_result(None)))

            client = self.tcp_client(prog)
            client.start()
            clients.append(client)

            await fut

        async def start_server():
            extras = {}
            if self.implementation != 'asyncio' or self.PY37:
                extras = dict(ssl_handshake_timeout=10.0)

            srv = await asyncio.start_server(
                handle_client,
                '127.0.0.1', 0,
                family=socket.AF_INET,
                ssl=sslctx,
                loop=self.loop,
                **extras)

            try:
                srv_socks = srv.sockets
                self.assertTrue(srv_socks)

                # Port 0 was requested, so read back the actual address.
                addr = srv_socks[0].getsockname()

                tasks = []
                for _ in range(TOTAL_CNT):
                    tasks.append(test_client(addr))

                await asyncio.wait_for(
                    asyncio.gather(*tasks, loop=self.loop),
                    TIMEOUT, loop=self.loop)

            finally:
                self.loop.call_soon(srv.close)
                await srv.wait_closed()

        with self._silence_eof_received_warning():
            self.loop.run_until_complete(start_server())

        self.assertEqual(CNT, TOTAL_CNT)

        for client in clients:
            client.stop()

    def test_create_connection_ssl_1(self):
        # Open TOTAL_CNT asyncio SSL stream connections to a threaded TCP
        # server, first by address and then over pre-connected sockets, and
        # check that each completes the A_DATA/OK/B_DATA/SPAM exchange.
        if self.implementation == 'asyncio':
            # Don't crash on asyncio errors
            self.loop.set_exception_handler(None)

        CNT = 0
        TOTAL_CNT = 25

        A_DATA = b'A' * 1024 * 1024
        B_DATA = b'B' * 1024 * 1024

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        def server(sock):
            sock.starttls(
                sslctx,
                server_side=True)

            data = sock.recv_all(len(A_DATA))
            self.assertEqual(data, A_DATA)
            sock.send(b'OK')

            data = sock.recv_all(len(B_DATA))
            self.assertEqual(data, B_DATA)
            sock.send(b'SPAM')

            sock.close()

        async def client(addr):
            extras = {}
            if self.implementation != 'asyncio' or self.PY37:
                extras = dict(ssl_handshake_timeout=10.0)

            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop,
                **extras)

            writer.write(A_DATA)
            self.assertEqual(await reader.readexactly(2), b'OK')

            writer.write(B_DATA)
            self.assertEqual(await reader.readexactly(4), b'SPAM')

            nonlocal CNT
            CNT += 1

            writer.close()

        async def client_sock(addr):
            sock = socket.socket()
            sock.connect(addr)
            reader, writer = await asyncio.open_connection(
                sock=sock,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop)

            writer.write(A_DATA)
            self.assertEqual(await reader.readexactly(2), b'OK')

            writer.write(B_DATA)
            self.assertEqual(await reader.readexactly(4), b'SPAM')

            nonlocal CNT
            CNT += 1

            writer.close()
            sock.close()

        def run(coro):
            # Reset the counter so each client flavour is counted separately.
            nonlocal CNT
            CNT = 0

            with self.tcp_server(server,
                                 max_clients=TOTAL_CNT,
                                 backlog=TOTAL_CNT) as srv:
                tasks = []
                for _ in range(TOTAL_CNT):
                    tasks.append(coro(srv.addr))

                self.loop.run_until_complete(
                    asyncio.gather(*tasks, loop=self.loop))

            self.assertEqual(CNT, TOTAL_CNT)

        with self._silence_eof_received_warning():
            run(client)

        with self._silence_eof_received_warning():
            run(client_sock)

    def test_create_connection_ssl_slow_handshake(self):
        # The server never performs a handshake, so open_connection() with
        # ssl_handshake_timeout=1.0 is expected to fail with
        # ConnectionAbortedError.
        if self.implementation == 'asyncio':
            raise unittest.SkipTest()

        client_sslctx = self._create_client_ssl_context()

        # silence error logger
        self.loop.set_exception_handler(lambda *args: None)

        def server(sock):
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop,
                ssl_handshake_timeout=1.0)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaisesRegex(
                    ConnectionAbortedError,
                    r'SSL handshake.*is taking longer'):

                self.loop.run_until_complete(client(srv.addr))

    def test_create_connection_ssl_failed_certificate(self):
        # The client context is created with disable_verify=False, so the
        # connection attempt is expected to fail with an SSL error.
        if self.implementation == 'asyncio':
            raise unittest.SkipTest()

        # silence error logger
        self.loop.set_exception_handler(lambda *args: None)

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context(disable_verify=False)

        def server(sock):
            try:
                sock.starttls(
                    sslctx,
                    server_side=True)
                sock.connect()
            except ssl.SSLError:
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop,
                ssl_handshake_timeout=1.0)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            exc_type = ssl.SSLError
            if self.PY37:
                exc_type = ssl.SSLCertVerificationError
            with self.assertRaises(exc_type):
                self.loop.run_until_complete(client(srv.addr))

    def test_ssl_handshake_timeout(self):
        if self.implementation == 'asyncio':
            raise unittest.SkipTest()

        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        client_sslctx = self._create_client_ssl_context()

        # silence error logger
        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        server_side_aborted = False

        def server(sock):
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            # The outer 0.5s wait_for() fires long before the 10s handshake
            # timeout, cancelling the connection attempt mid-handshake.
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=10.0),
                0.5,
                loop=self.loop)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])

    def test_ssl_connect_accepted_socket(self):
        # Re-run the plain TCP connect_accepted_socket test with server and
        # client SSL contexts that skip certificate verification.
        server_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        server_context.load_cert_chain(self.ONLYCERT, self.ONLYKEY)
        if hasattr(server_context, 'check_hostname'):
            server_context.check_hostname = False
        server_context.verify_mode = ssl.CERT_NONE

        client_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        # Probe the context that is actually being configured.
        if hasattr(client_context, 'check_hostname'):
            client_context.check_hostname = False
        client_context.verify_mode = ssl.CERT_NONE

        Test_UV_TCP.test_connect_accepted_socket(
            self, server_context, client_context)