class SessionBase(asyncio.Protocol):
    """Base class of networking sessions.

    There is no client / server distinction other than who initiated
    the connection.

    This class implements the asyncio protocol callbacks, send-side
    flow control, error counting, basic traffic statistics and
    connection teardown.  Message processing is delegated to a
    ``_receive_messages()`` coroutine that is started in
    connection_made() but is not defined in this class.

    NOTE(review): earlier notes on this class describe initiating a
    connection by passing host, port and proxy to the constructor and
    then calling create_connection(), pairing each successful call
    with a call to close(), or alternatively using the session in a
    with statement so the connection is made on entry and closed on
    exit.  None of that is implemented here; presumably derived
    classes provide it - confirm.
    """

    # _bump_errors() closes the transport once this many errors accrue.
    max_errors = 10

    def __init__(self, *, framer=None, loop=None):
        """Initialise the session state; no connection is made here.

        framer - object used to frame outgoing messages and to consume
                 incoming bytes; default_framer() is called when falsy.
        loop   - event loop to use; asyncio.get_event_loop() when falsy.
        """
        if not framer:
            framer = self.default_framer()
        self.framer = framer
        if not loop:
            loop = asyncio.get_event_loop()
        self.loop = loop
        self.logger = logging.getLogger(type(self).__name__)
        self.transport = None
        # Set when a connection is made
        self._address = None
        self._proxy_address = None
        # For logger.debug messages
        self.verbosity = 0
        # Cleared when the send socket is full
        self._can_send = Event()
        self._can_send.set()
        self._pm_task = None
        self._task_group = TaskGroup(self.loop)
        # Force-close a connection if a send doesn't succeed in this time
        self.max_send_delay = 60
        # Statistics.  The RPC object also keeps its own statistics.
        now = time.perf_counter()
        self.start_time = now
        self.errors = 0
        self.send_count = 0
        self.send_size = 0
        self.last_send = now
        self.recv_count = 0
        self.recv_size = 0
        self.last_recv = now
        self.last_packet_received = now

    async def _limited_wait(self, secs):
        """Wait up to secs seconds for sending to become possible.

        On timeout the connection is aborted and asyncio.TimeoutError
        is raised.
        """
        try:
            await asyncio.wait_for(self._can_send.wait(), timeout=secs)
        except asyncio.TimeoutError:
            self.abort()
            raise asyncio.TimeoutError(f'task timed out after {secs}s')

    async def _send_message(self, message):
        """Frame message and write it to the transport.

        Waits (bounded by max_send_delay) if sending is currently
        paused.  Nothing is sent if the connection is closing.
        """
        if not self._can_send.is_set():
            await self._limited_wait(self.max_send_delay)
        if self.is_closing():
            return
        payload = self.framer.frame(message)
        self.send_size += len(payload)
        self.send_count += 1
        self.last_send = time.perf_counter()
        if self.verbosity >= 4:
            self.logger.debug(f'Sending framed message {payload}')
        self.transport.write(payload)

    def _bump_errors(self):
        """Count an error, closing the transport at max_errors."""
        self.errors += 1
        if self.errors < self.max_errors:
            return
        # Don't await self.close() because that is self-cancelling
        self._close()

    def _close(self):
        """Ask the transport, if there is one, to close."""
        transport = self.transport
        if transport:
            transport.close()

    # asyncio framework
    def data_received(self, framed_message):
        """Called by asyncio when a message comes in."""
        self.last_packet_received = time.perf_counter()
        if self.verbosity >= 4:
            self.logger.debug(f'Received framed message {framed_message}')
        self.recv_size += len(framed_message)
        self.framer.received_bytes(framed_message)

    def pause_writing(self):
        """Transport calls when the send buffer is full."""
        if self.is_closing():
            return
        # Block senders and stop reading until the buffer has room again
        self._can_send.clear()
        self.transport.pause_reading()

    def resume_writing(self):
        """Transport calls when the send buffer has room."""
        if self._can_send.is_set():
            return
        self._can_send.set()
        self.transport.resume_reading()

    def connection_made(self, transport):
        """Called by asyncio when a connection is established.

        Derived classes overriding this method must call this first.
        """
        self.transport = transport
        # This would throw if called on a closed SSL transport.  Fixed
        # in asyncio in Python 3.6.1 and 3.5.4
        peername = transport.get_extra_info('peername')
        # If the Socks proxy was used then _address is already set to
        # the remote address
        if self._address:
            self._proxy_address = peername
        else:
            self._address = peername
        self._pm_task = self.loop.create_task(self._receive_messages())

    def connection_lost(self, exc):
        """Called by asyncio when the connection closes.

        Tear down things done in connection_made.
        """
        self._address = None
        self.transport = None
        self._task_group.cancel()
        if self._pm_task:
            self._pm_task.cancel()
        # Release waiting tasks
        self._can_send.set()

    # External API
    def default_framer(self):
        """Return a default framer."""
        raise NotImplementedError

    def peer_address(self):
        """Returns the peer's address (Python networking address), or
        None if no connection or an error.

        This is the result of socket.getpeername() when the connection
        was made.
        """
        return self._address

    def peer_address_str(self):
        """Returns the peer's IP address and port as a human-readable
        string, or 'unknown' if there is no address."""
        if not self._address:
            return 'unknown'
        host, port = self._address[:2]
        # A colon in the host marks it for bracketing, as for IPv6
        host_part = f'[{host}]' if ':' in host else host
        return f'{host_part}:{port}'

    def is_closing(self):
        """Return True if the connection is closing."""
        if not self.transport:
            return True
        return self.transport.is_closing()

    def abort(self):
        """Forcefully close the connection."""
        transport = self.transport
        if transport:
            transport.abort()

    # TODO: replace with synchronous_close
    async def close(self, *, force_after=30):
        """Close the connection and return when closed.

        The connection is aborted after waiting at most force_after
        seconds for the message-processing task to finish.
        """
        self._close()
        if not self._pm_task:
            return
        with suppress(CancelledError):
            await asyncio.wait([self._pm_task], timeout=force_after)
            self.abort()
            await self._pm_task

    def synchronous_close(self):
        """Close the transport and cancel the message-processing task
        if it has not finished; does not wait for either."""
        self._close()
        task = self._pm_task
        if task and not task.done():
            task.cancel()
async def test_cancel_sets_it_done(self):
    """Cancelling a freshly created TaskGroup leaves its ``done`` event set."""
    task_group = TaskGroup()
    task_group.cancel()
    done_event = task_group.done
    self.assertTrue(done_event.is_set())