Esempio n. 1
0
class SessionBase(asyncio.Protocol):
    """Base class of networking sessions.

    There is no client / server distinction other than who initiated
    the connection.

    This class implements the asyncio.Protocol callbacks, send-side
    flow control and basic traffic statistics.  Its constructor accepts
    only a framer and an event loop.

    NOTE(review): earlier documentation described passing host, port
    and proxy to the constructor, calling create_connection() (each
    successful call paired with a call to close()), and use in a with
    statement.  None of that is implemented in this class; presumably a
    subclass provides it - confirm.

    Subclasses must override default_framer() unless a framer is always
    passed in, and must supply the _receive_messages() coroutine that
    connection_made() schedules.
    """

    # _bump_errors() closes the connection once this many errors are counted
    max_errors = 10

    def __init__(self, *, framer=None, loop=None):
        """Initialise an unconnected session.

        framer: object used to frame outgoing messages and to receive
                incoming bytes; defaults to self.default_framer().
        loop:   event loop to use; defaults to asyncio.get_event_loop().
        """
        self.framer = framer or self.default_framer()
        self.loop = loop or asyncio.get_event_loop()
        self.logger = logging.getLogger(self.__class__.__name__)
        self.transport = None
        # Set when a connection is made
        self._address = None
        self._proxy_address = None
        # Framed messages are debug-logged when verbosity >= 4
        self.verbosity = 0
        # Cleared when the send socket is full
        self._can_send = Event()
        self._can_send.set()
        # Task running _receive_messages(); created in connection_made()
        self._pm_task = None
        self._task_group = TaskGroup(self.loop)
        # Force-close a connection if a send doesn't succeed in this time
        self.max_send_delay = 60
        # Statistics.  The RPC object also keeps its own statistics.
        self.start_time = time.perf_counter()
        self.errors = 0
        self.send_count = 0
        self.send_size = 0
        self.last_send = self.start_time
        self.recv_count = 0
        self.recv_size = 0
        self.last_recv = self.start_time
        self.last_packet_received = self.start_time

    async def _limited_wait(self, secs):
        """Wait at most secs seconds until sending is permitted again.

        On timeout the connection is aborted and asyncio.TimeoutError is
        raised.
        """
        try:
            await asyncio.wait_for(self._can_send.wait(), secs)
        except asyncio.TimeoutError:
            self.abort()
            raise asyncio.TimeoutError(f'task timed out after {secs}s')

    async def _send_message(self, message):
        """Frame message and write it to the transport.

        If sending is currently paused, first waits up to max_send_delay
        seconds.  The message is silently dropped if the connection is
        closing or closed.
        """
        if not self._can_send.is_set():
            await self._limited_wait(self.max_send_delay)
        # connection_lost() also releases waiters, so re-check the state
        if not self.is_closing():
            framed_message = self.framer.frame(message)
            self.send_size += len(framed_message)
            self.send_count += 1
            self.last_send = time.perf_counter()
            if self.verbosity >= 4:
                self.logger.debug(f'Sending framed message {framed_message}')
            self.transport.write(framed_message)

    def _bump_errors(self):
        """Count an error; close the connection at max_errors errors."""
        self.errors += 1
        if self.errors >= self.max_errors:
            # Don't await self.close() because that is self-cancelling
            self._close()

    def _close(self):
        """Close the transport, if there is one."""
        if self.transport:
            self.transport.close()

    # asyncio framework
    def data_received(self, framed_message):
        """Called by asyncio when a message comes in.

        Updates the receive statistics and hands the bytes to the framer.
        """
        self.last_packet_received = time.perf_counter()
        if self.verbosity >= 4:
            self.logger.debug(f'Received framed message {framed_message}')
        self.recv_size += len(framed_message)
        self.framer.received_bytes(framed_message)

    def pause_writing(self):
        """Transport calls when the send buffer is full.

        Blocks further sends and also pauses reading from the transport.
        """
        if not self.is_closing():
            self._can_send.clear()
            self.transport.pause_reading()

    def resume_writing(self):
        """Transport calls when the send buffer has room.

        Undoes pause_writing(): permits sends and resumes reading.
        """
        if not self._can_send.is_set():
            self._can_send.set()
            self.transport.resume_reading()

    def connection_made(self, transport):
        """Called by asyncio when a connection is established.

        Derived classes overriding this method must call this first."""
        self.transport = transport
        # This would throw if called on a closed SSL transport.  Fixed
        # in asyncio in Python 3.6.1 and 3.5.4
        peer_address = transport.get_extra_info('peername')
        # If the Socks proxy was used then _address is already set to
        # the remote address
        if self._address:
            self._proxy_address = peer_address
        else:
            self._address = peer_address
        self._pm_task = self.loop.create_task(self._receive_messages())

    def connection_lost(self, exc):
        """Called by asyncio when the connection closes.

        Tear down things done in connection_made."""
        self._address = None
        # connection_made() may have recorded a proxy address; clear it so
        # it cannot be mistaken for belonging to a later connection.
        self._proxy_address = None
        self.transport = None
        self._task_group.cancel()
        if self._pm_task:
            self._pm_task.cancel()
        # Release waiting tasks
        self._can_send.set()

    # External API
    def default_framer(self):
        """Return a default framer.

        Called by the constructor when no framer is passed in; this base
        implementation raises NotImplementedError.
        """
        raise NotImplementedError

    def peer_address(self):
        """Returns the peer's address (Python networking address), or None if
        no connection or an error.

        This is the result of socket.getpeername() when the connection
        was made.
        """
        return self._address

    def peer_address_str(self):
        """Returns the peer's IP address and port as a human-readable
        string, or 'unknown' if there is no peer address."""
        if not self._address:
            return 'unknown'
        # Only the first two elements are used; longer address tuples are
        # accepted.
        ip_addr_str, port = self._address[:2]
        # An address containing a colon is bracketed to set off the port
        if ':' in ip_addr_str:
            return f'[{ip_addr_str}]:{port}'
        else:
            return f'{ip_addr_str}:{port}'

    def is_closing(self):
        """Return True if the connection is closing or there is none."""
        return not self.transport or self.transport.is_closing()

    def abort(self):
        """Forcefully close the connection."""
        if self.transport:
            self.transport.abort()

    # TODO: replace with synchronous_close
    async def close(self, *, force_after=30):
        """Close the connection and return when closed.

        force_after: seconds to wait for the message-processing task to
                     finish before the transport is aborted.
        """
        self._close()
        if self._pm_task:
            # connection_lost() cancels _pm_task, so awaiting it is
            # expected to raise CancelledError.
            with suppress(CancelledError):
                await asyncio.wait([self._pm_task], timeout=force_after)
                # A no-op if the transport has already gone
                self.abort()
                await self._pm_task

    def synchronous_close(self):
        """Close the transport and cancel the message-processing task.

        Unlike close() this does not wait for either to complete.
        """
        self._close()
        if self._pm_task and not self._pm_task.done():
            self._pm_task.cancel()
Esempio n. 2
0
 async def test_cancel_sets_it_done(self):
     """Cancelling a freshly created TaskGroup leaves its done flag set."""
     task_group = TaskGroup()
     task_group.cancel()
     done_flag = task_group.done
     self.assertTrue(done_flag.is_set())