Example #1
class BasePeerPool(BaseService, AsyncIterable[BasePeer]):
    """
    PeerPool maintains connections to up to max_peers peers on a given network.
    """
    _report_interval = 60
    _peer_boot_timeout = DEFAULT_PEER_BOOT_TIMEOUT
    _event_bus: EndpointAPI = None

    def __init__(
        self,
        privkey: datatypes.PrivateKey,
        context: BasePeerContext,
        max_peers: int = DEFAULT_MAX_PEERS,
        token: CancelToken = None,
        event_bus: EndpointAPI = None,
    ) -> None:
        super().__init__(token)

        self.privkey = privkey
        self.max_peers = max_peers
        self.context = context

        self.connected_nodes: Dict[SessionAPI, BasePeer] = {}

        self._subscribers: List[PeerSubscriber] = []
        self._event_bus = event_bus

        # Restricts the number of concurrent connection attempts that can be made
        self._connection_attempt_lock = asyncio.BoundedSemaphore(
            MAX_CONCURRENT_CONNECTION_ATTEMPTS)

        # Ensure we can only have a single concurrent handshake in flight per remote
        self._handshake_locks = ResourceLock()

        self.peer_backends = self.setup_peer_backends()
        self.connection_tracker = self.setup_connection_tracker()

    @cached_property
    def has_event_bus(self) -> bool:
        return self._event_bus is not None

    def get_event_bus(self) -> EndpointAPI:
        if self._event_bus is None:
            raise AttributeError("No event bus configured for this peer pool")
        return self._event_bus

    def setup_connection_tracker(self) -> BaseConnectionTracker:
        """
        Return an instance of `p2p.tracking.connection.BaseConnectionTracker`
        which will be used to track peer connection failures.
        """
        return NoopConnectionTracker()

    def setup_peer_backends(self) -> Tuple[BasePeerBackend, ...]:
        if self.has_event_bus:
            return (
                DiscoveryPeerBackend(self.get_event_bus()),
                BootnodesPeerBackend(self.get_event_bus()),
            )
        else:
            self.logger.warning("No event bus configured for peer pool.")
            return ()

    async def _add_peers_from_backend(self, backend: BasePeerBackend) -> None:
        available_slots = self.max_peers - len(self)

        try:
            connected_remotes = {
                peer.remote
                for peer in self.connected_nodes.values()
            }
            candidates = await self.wait(
                backend.get_peer_candidates(
                    num_requested=available_slots,
                    connected_remotes=connected_remotes,
                ),
                timeout=REQUEST_PEER_CANDIDATE_TIMEOUT,
            )
        except asyncio.TimeoutError:
            self.logger.warning("PeerCandidateRequest timed out to backend %s",
                                backend)
            return
        else:
            self.logger.debug2(
                "Got candidates from backend %s (%s)",
                backend,
                candidates,
            )
            if candidates:
                await self.connect_to_nodes(iter(candidates))

    async def maybe_connect_more_peers(self) -> None:
        rate_limiter = TokenBucket(
            rate=1 / PEER_CONNECT_INTERVAL,
            capacity=MAX_SEQUENTIAL_PEER_CONNECT,
        )

        while self.is_operational:
            if self.is_full:
                await self.sleep(PEER_CONNECT_INTERVAL)
                continue

            await self.wait(rate_limiter.take())

            try:
                await asyncio.gather(*(self._add_peers_from_backend(backend)
                                       for backend in self.peer_backends))
            except OperationCancelled:
                break
            except asyncio.CancelledError:
                # no need to log this exception, this is expected
                raise
            except Exception:
                self.logger.exception(
                    "unexpected error during peer connection")
                # Continue trying to connect to peers, even if there was a
                # surprising failure during one of the attempts.
                continue

    def __len__(self) -> int:
        return len(self.connected_nodes)

    @property
    @abstractmethod
    def peer_factory_class(self) -> Type[BasePeerFactory]:
        ...

    def get_peer_factory(self) -> BasePeerFactory:
        return self.peer_factory_class(
            privkey=self.privkey,
            context=self.context,
            event_bus=self._event_bus,
            token=self.cancel_token,
        )

    @property
    def is_full(self) -> bool:
        return len(self) >= self.max_peers

    def is_valid_connection_candidate(self, candidate: NodeAPI) -> bool:
        # connect to no more than 2 nodes with the same IP
        nodes_by_ip = groupby(
            operator.attrgetter('remote.address.ip'),
            self.connected_nodes.keys(),
        )
        matching_ip_nodes = nodes_by_ip.get(candidate.address.ip, [])
        return len(matching_ip_nodes) <= 2

    def subscribe(self, subscriber: PeerSubscriber) -> None:
        self._subscribers.append(subscriber)
        for peer in self.connected_nodes.values():
            subscriber.register_peer(peer)
            peer.add_subscriber(subscriber)

    def unsubscribe(self, subscriber: PeerSubscriber) -> None:
        if subscriber in self._subscribers:
            self._subscribers.remove(subscriber)
        for peer in self.connected_nodes.values():
            peer.remove_subscriber(subscriber)

    async def start_peer(self, peer: BasePeer) -> None:
        self.run_child_service(peer.connection)
        await self.wait(peer.connection.events.started.wait(), timeout=1)

        self.run_child_service(peer)
        await self.wait(peer.events.started.wait(), timeout=1)
        await self.wait(peer.ready.wait(), timeout=1)
        if peer.is_operational:
            self._add_peer(peer, ())
        else:
            self.logger.debug(
                "%s was cancelled immediately, not adding to pool", peer)

        try:
            await self.wait(peer.boot_manager.events.finished.wait(),
                            timeout=self._peer_boot_timeout)
        except asyncio.TimeoutError as err:
            self.logger.debug('Timeout waiting for peer to boot: %s', err)
            await peer.disconnect(DisconnectReason.TIMEOUT)
            return
        except HandshakeFailure as err:
            self.connection_tracker.record_failure(peer.remote, err)
            raise
        else:
            if not peer.is_operational:
                self.logger.debug(
                    '%s disconnected during boot-up, dropped from pool', peer)

    def _add_peer(self, peer: BasePeer, msgs: Tuple[PeerMessage, ...]) -> None:
        """Add the given peer to the pool.

        Apart from adding it to our list of connected nodes and adding each of our subscribers
        to the peer, we also add the given messages to our subscribers' queues.
        """
        self.logger.info('Adding %s to pool', peer)
        self.connected_nodes[peer.session] = peer
        peer.add_finished_callback(self._peer_finished)
        for subscriber in self._subscribers:
            subscriber.register_peer(peer)
            peer.add_subscriber(subscriber)
            for msg in msgs:
                subscriber.add_msg(msg)

    async def _run(self) -> None:
        # FIXME: PeerPool should probably no longer be a BaseService, but for now we're keeping it
        # that way in order to ensure we cancel all peers when we terminate.
        if self.has_event_bus:
            self.run_daemon_task(self.maybe_connect_more_peers())

        self.run_daemon_task(self._periodically_report_stats())
        await self.cancel_token.wait()

    async def stop_all_peers(self) -> None:
        self.logger.info("Stopping all peers ...")
        peers = self.connected_nodes.values()
        disconnections = (peer.disconnect(DisconnectReason.CLIENT_QUITTING)
                          for peer in peers if peer.is_running)
        await asyncio.gather(*disconnections)

    async def _cleanup(self) -> None:
        await self.stop_all_peers()

    async def connect(self, remote: NodeAPI) -> BasePeer:
        """
        Connect to the given remote and return a Peer instance when successful.
        Raises IneligiblePeer if we're already connected to the remote or the
        connection tracker rejects it; handshake and connection failures
        propagate to the caller.
        """
        if not self._handshake_locks.is_locked(remote):
            self.logger.warning(
                "Tried to connect to %s without acquiring lock first!", remote)

        if any(peer.remote == remote
               for peer in self.connected_nodes.values()):
            self.logger.debug2("Skipping %s; already connected to it", remote)
            raise IneligiblePeer(f"Already connected to {remote}")

        try:
            should_connect = await self.wait(
                self.connection_tracker.should_connect_to(remote),
                timeout=1,
            )
        except asyncio.TimeoutError:
            self.logger.warning(
                "ConnectionTracker.should_connect_to request timed out.")
            raise

        if not should_connect:
            raise IneligiblePeer(
                f"Peer database rejected peer candidate: {remote}")

        try:
            self.logger.debug2("Connecting to %s...", remote)
            return await self.wait(
                self.get_peer_factory().handshake(remote),
                timeout=HANDSHAKE_TIMEOUT,
            )
        except OperationCancelled:
            # Pass it on to instruct our main loop to stop.
            raise
        except BadAckMessage:
            # This is kept separate from the
            # `COMMON_PEER_CONNECTION_EXCEPTIONS` to be sure that we aren't
            # silencing an error in our authentication code.
            self.logger.error('Got bad auth ack from %r', remote)
            # dump the full stacktrace in the debug logs
            self.logger.debug('Got bad auth ack from %r',
                              remote,
                              exc_info=True)
            raise
        except MalformedMessage:
            # This is kept separate from the
            # `COMMON_PEER_CONNECTION_EXCEPTIONS` to be sure that we aren't
            # silencing an error in how we decode messages during handshake.
            self.logger.error(
                'Got malformed response from %r during handshake', remote)
            # dump the full stacktrace in the debug logs
            self.logger.debug('Got malformed response from %r',
                              remote,
                              exc_info=True)
            raise
        except HandshakeFailure as e:
            self.logger.debug("Could not complete handshake with %r: %s",
                              remote, repr(e))
            self.connection_tracker.record_failure(remote, e)
            raise
        except COMMON_PEER_CONNECTION_EXCEPTIONS as e:
            self.logger.debug("Could not complete handshake with %r: %s",
                              remote, repr(e))
            raise

    async def connect_to_nodes(self, nodes: Iterator[NodeAPI]) -> None:
        # create an iterator over the nodes
        nodes_iter = iter(nodes)

        while True:
            if self.is_full or not self.is_operational:
                return

            # only attempt as many connections as there are open peer slots.
            available_peer_slots = self.max_peers - len(self)
            batch_size = clamp(1, 10, available_peer_slots)
            batch = tuple(take(batch_size, nodes_iter))

            # There are no more *known* nodes to connect to.
            if not batch:
                return

            self.logger.debug(
                'Initiating %d peer connection attempts with %d open peer slots',
                len(batch),
                available_peer_slots,
            )
            # Try to connect to the peers concurrently.
            await asyncio.gather(
                *(self.connect_to_node(node) for node in batch),
                loop=self.get_event_loop(),
            )

    def lock_node_for_handshake(self, node: NodeAPI) -> asyncio.Lock:
        return self._handshake_locks.lock(node)

    def is_connected_to_node(self, node: NodeAPI) -> bool:
        return any(node == session.remote
                   for session in self.connected_nodes.keys())

    async def connect_to_node(self, node: NodeAPI) -> None:
        """
        Connect to a single node quietly aborting if the peer pool is full or
        shutting down, or one of the expected peer level exceptions is raised
        while connecting.
        """
        if self.is_full or not self.is_operational:
            return

        async with self.lock_node_for_handshake(node):
            if self.is_connected_to_node(node):
                self.logger.debug(
                    "Aborting outbound connection attempt to %s. Already connected!",
                    node,
                )
                return

            try:
                async with self._connection_attempt_lock:
                    peer = await self.connect(node)
            except ALLOWED_PEER_CONNECTION_EXCEPTIONS:
                return

            # Check again to see if we have *become* full since the previous
            # check.
            if self.is_full:
                self.logger.debug(
                    "Successfully connected to %s but peer pool is full.  Disconnecting.",
                    peer,
                )
                await peer.disconnect(DisconnectReason.TOO_MANY_PEERS)
                return
            elif not self.is_operational:
                self.logger.debug(
                    "Successfully connected to %s but peer pool is closing.  Disconnecting.",
                    peer,
                )
                await peer.disconnect(DisconnectReason.CLIENT_QUITTING)
                return
            else:
                await self.start_peer(peer)

    def _peer_finished(self, peer: AsyncioServiceAPI) -> None:
        """
        Remove the given peer from our list of connected nodes.
        This is passed as a callback to be called when a peer finishes.
        """
        peer = cast(BasePeer, peer)
        if peer.session in self.connected_nodes:
            self.logger.debug(
                "Removing %s from pool: local_reason=%s remote_reason=%s",
                peer,
                peer.p2p_api.local_disconnect_reason,
                peer.p2p_api.remote_disconnect_reason,
            )
            self.connected_nodes.pop(peer.session)
        else:
            self.logger.warning(
                "%s finished but was not found in connected_nodes (%s)",
                peer,
                tuple(self.connected_nodes.values()),
            )

        for subscriber in self._subscribers:
            subscriber.deregister_peer(peer)

    async def __aiter__(self) -> AsyncIterator[BasePeer]:
        for peer in tuple(self.connected_nodes.values()):
            # Yield control to ensure we process any disconnection requests from peers. Otherwise
            # we could return peers that should have been disconnected already.
            await asyncio.sleep(0)
            if peer.is_operational and not peer.is_closing:
                yield peer

    async def _periodically_report_stats(self) -> None:
        while self.is_operational:
            inbound_peers = len([
                peer for peer in self.connected_nodes.values() if peer.inbound
            ])
            self.logger.info("Connected peers: %d inbound, %d outbound",
                             inbound_peers,
                             (len(self.connected_nodes) - inbound_peers))
            subscribers = len(self._subscribers)
            if subscribers:
                longest_queue = max(self._subscribers,
                                    key=operator.attrgetter('queue_size'))
                self.logger.debug(
                    "Peer subscribers: %d, longest queue: %s(%d)", subscribers,
                    longest_queue.__class__.__name__, longest_queue.queue_size)

            self.logger.debug("== Peer details == ")
            # make a copy, because we might edit the original during iteration
            peers = tuple(self.connected_nodes.values())
            for peer in peers:
                if not peer.is_running:
                    self.logger.debug(
                        "%s is no longer alive but had not been removed from pool",
                        peer)
                    continue
                self.logger.debug(
                    "%s: uptime=%s, received_msgs=%d",
                    peer,
                    humanize_seconds(peer.uptime),
                    peer.received_msgs_count,
                )
                self.logger.debug(
                    "client_version_string='%s'",
                    peer.p2p_api.safe_client_version_string,
                )
                if not hasattr(peer, "eth_api"):
                    self.logger.warning("Huh? %s doesn't have an eth API",
                                        peer)
                for line in peer.get_extra_stats():
                    self.logger.debug("    %s", line)
            self.logger.debug("== End peer details == ")
            await self.sleep(self._report_interval)
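
The pacing in `maybe_connect_more_peers()` comes from a token bucket: attempts may burst up to MAX_SEQUENTIAL_PEER_CONNECT and then settle to one per PEER_CONNECT_INTERVAL seconds. Below is a minimal, stdlib-only sketch of that refill/consume logic, assuming nothing about the real `TokenBucket` beyond a `rate`/`capacity` constructor and an awaitable `take()`; the `SimpleTokenBucket` name and the demo constants are invented for illustration.

import asyncio
import time


class SimpleTokenBucket:
    def __init__(self, rate: float, capacity: float) -> None:
        self._rate = rate          # tokens replenished per second
        self._capacity = capacity  # maximum burst size
        self._tokens = capacity
        self._last_refill = time.monotonic()

    def _refill(self) -> None:
        now = time.monotonic()
        self._tokens = min(
            self._capacity,
            self._tokens + (now - self._last_refill) * self._rate,
        )
        self._last_refill = now

    async def take(self) -> None:
        self._refill()
        if self._tokens < 1:
            # Sleep just long enough for the missing fraction of a token.
            await asyncio.sleep((1 - self._tokens) / self._rate)
            self._refill()
        self._tokens -= 1


async def demo() -> None:
    # A burst of 3 immediate attempts, then one attempt every 2 seconds,
    # mirroring rate=1 / PEER_CONNECT_INTERVAL, capacity=MAX_SEQUENTIAL_PEER_CONNECT.
    bucket = SimpleTokenBucket(rate=0.5, capacity=3)
    for attempt in range(5):
        await bucket.take()
        print("connection attempt", attempt)


if __name__ == '__main__':
    asyncio.run(demo())
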
Example #2
class Multiplexer(CancellableMixin, MultiplexerAPI):
    logger = get_extended_debug_logger('p2p.multiplexer.Multiplexer')

    _multiplex_token: CancelToken

    _transport: TransportAPI
    _msg_counts: DefaultDict[Type[CommandAPI], int]

    _protocol_locks: ResourceLock
    _protocol_queues: Dict[Type[ProtocolAPI], 'asyncio.Queue[Tuple[CommandAPI, Payload]]']

    def __init__(self,
                 transport: TransportAPI,
                 base_protocol: BaseP2PProtocol,
                 protocols: Sequence[ProtocolAPI],
                 token: CancelToken = None,
                 max_queue_size: int = 4096) -> None:
        if token is None:
            loop = None
        else:
            loop = token.loop
        base_token = CancelToken(f'multiplexer[{transport.remote}]', loop=loop)

        if token is None:
            self.cancel_token = base_token
        else:
            self.cancel_token = base_token.chain(token)

        self._transport = transport
        # the base `p2p` protocol instance.
        self._base_protocol = base_protocol

        # the sub-protocol instances
        self._protocols = protocols

        # Lock to ensure that multiple call sites cannot concurrently stream
        # messages.
        self._multiplex_lock = asyncio.Lock()

        # Lock management on a per-protocol basis to ensure we only have one
        # stream consumer for each protocol.
        self._protocol_locks = ResourceLock()

        # Each protocol gets a queue where messages for the individual protocol
        # are placed when streamed from the transport
        self._protocol_queues = {
            type(protocol): asyncio.Queue(max_queue_size)
            for protocol
            in self.get_protocols()
        }

        self._msg_counts = collections.defaultdict(int)

    def __str__(self) -> str:
        protocol_infos = ','.join(tuple(
            f"{proto.name}:{proto.version}"
            for proto
            in self.get_protocols()
        ))
        return f"Multiplexer[{protocol_infos}]"

    def __repr__(self) -> str:
        return f"<{self}>"

    #
    # Transport API
    #
    def get_transport(self) -> TransportAPI:
        return self._transport

    #
    # Message Counts
    #
    def get_total_msg_count(self) -> int:
        return sum(self._msg_counts.values())

    #
    # Proxy Transport methods
    #
    @property
    def remote(self) -> NodeAPI:
        return self._transport.remote

    @property
    def is_closing(self) -> bool:
        return self._transport.is_closing

    def close(self) -> None:
        self._transport.close()
        self.cancel_token.trigger()

    #
    # Protocol API
    #
    def has_protocol(self, protocol_identifier: Union[ProtocolAPI, Type[ProtocolAPI]]) -> bool:
        try:
            if isinstance(protocol_identifier, Protocol):
                self.get_protocol_by_type(type(protocol_identifier))
                return True
            elif isinstance(protocol_identifier, type):
                self.get_protocol_by_type(protocol_identifier)
                return True
            else:
                raise TypeError(
                    f"Unsupported protocol value: {protocol_identifier} of type "
                    f"{type(protocol_identifier)}"
                )
        except UnknownProtocol:
            return False

    def get_protocol_by_type(self, protocol_class: Type[TProtocol]) -> TProtocol:
        if issubclass(protocol_class, BaseP2PProtocol):
            return cast(TProtocol, self._base_protocol)

        for protocol in self._protocols:
            if type(protocol) is protocol_class:
                return cast(TProtocol, protocol)
        raise UnknownProtocol(f"No protocol found with type {protocol_class}")

    def get_base_protocol(self) -> BaseP2PProtocol:
        return self._base_protocol

    def get_protocols(self) -> Tuple[ProtocolAPI, ...]:
        return tuple(cons(self._base_protocol, self._protocols))

    #
    # Streaming API
    #
    def stream_protocol_messages(self,
                                 protocol_identifier: Union[ProtocolAPI, Type[ProtocolAPI]],
                                 ) -> AsyncIterator[Tuple[CommandAPI, Payload]]:
        """
        Stream the messages for the specified protocol.
        """
        if isinstance(protocol_identifier, Protocol):
            protocol_class = type(protocol_identifier)
        elif isinstance(protocol_identifier, type) and issubclass(protocol_identifier, Protocol):
            protocol_class = protocol_identifier
        else:
            raise TypeError(f"Unknown protocol identifier: {protocol_identifier}")

        if not self.has_protocol(protocol_class):
            raise UnknownProtocol(f"Unknown protocol '{protocol_class}'")

        if self._protocol_locks.is_locked(protocol_class):
            raise Exception(f"Streaming lock for {protocol_class} is not free.")
        elif not self._multiplex_lock.locked():
            raise Exception("Not multiplexed.")

        # Mostly a sanity check, but this ensures we raise a meaningful error
        # rather than an AttributeError in whatever race conditions or edge
        # cases leave `_multiplex_token` unset.
        if not hasattr(self, '_multiplex_token'):
            raise Exception("No cancel token found for multiplexing.")

        # We do the wait_iter here so that the call sites in the handshakers
        # that use this don't need to be aware of cancellation tokens.
        return self.wait_iter(
            self._stream_protocol_messages(protocol_class),
            token=self._multiplex_token,
        )

    async def _stream_protocol_messages(self,
                                        protocol_class: Type[Protocol],
                                        ) -> AsyncIterator[Tuple[CommandAPI, Payload]]:
        """
        Stream the messages for the specified protocol.
        """
        async with self._protocol_locks.lock(protocol_class):
            msg_queue = self._protocol_queues[protocol_class]
            if not hasattr(self, '_multiplex_token'):
                raise Exception("Multiplexer is not multiplexed")
            token = self._multiplex_token

            while not self.is_closing and not token.triggered:
                try:
                    # We use an optimistic strategy here of using
                    # `get_nowait()` to reduce the number of times we yield to
                    # the event loop.  Since this is an async generator it will
                    # yield to the loop each time it returns a value so we
                    # don't have to worry about this blocking other processes.
                    yield msg_queue.get_nowait()
                except asyncio.QueueEmpty:
                    yield await self.wait(msg_queue.get(), token=token)

    #
    # Message reading and streaming API
    #
    @asynccontextmanager
    async def multiplex(self) -> AsyncIterator[None]:
        """
        API for running the background task that feeds the individual protocol
        queues, allowing each protocol to stream only its own messages.
        """
        # We generate a new token for each time the multiplexer is used to
        # multiplex so that we can reliably cancel it without requiring the
        # master token for the multiplexer to be cancelled.
        async with self._multiplex_lock:
            multiplex_token = CancelToken(
                'multiplex',
                loop=self.cancel_token.loop,
            ).chain(self.cancel_token)

            stop = asyncio.Event()
            self._multiplex_token = multiplex_token
            fut = asyncio.ensure_future(self._do_multiplexing(stop, multiplex_token))
            try:
                yield
            finally:
                #
                # Prevent corruption of the Transport:
                #
                # On exit the `Transport` can be in a few states:
                #
                # 1. IDLE: between reads
                # 2. HEADER: waiting to read the bytes for the message header
                # 3. BODY: already read the header, waiting for body bytes.
                #
                # In the IDLE case we get a clean shutdown by simply signaling
                # to `_do_multiplexing` that it should exit, which is done with
                # an `asyncio.Event`.
                #
                # In the HEADER case we can issue a hard stop either via
                # cancellation or the cancel token.  The read *should* be
                # interrupted without consuming any bytes from the
                # `StreamReader`.
                #
                # In the BODY case we want to give the `Transport.recv` call a
                # moment to finish reading the body after which it will be IDLE
                # and will exit via the IDLE exit mechanism.
                stop.set()

                # If the transport is waiting to read the body of the message
                # we want to give it a moment to finish that read.  Otherwise
                # this leaves the transport in a corrupt state.
                if self._transport.read_state is TransportState.BODY:
                    try:
                        await asyncio.wait_for(fut, timeout=1)
                    except asyncio.TimeoutError:
                        pass

                # After giving the transport an opportunity to shutdown
                # cleanly, we issue a hard shutdown, first via cancellation and
                # then via the cancel token.  This should only end up
                # corrupting the transport in the case where the header data is
                # read but the body data takes too long to arrive which should
                # be very rare and would likely indicate a malicious or broken
                # peer.
                if fut.done():
                    fut.result()
                else:
                    fut.cancel()
                    try:
                        await fut
                    except asyncio.CancelledError:
                        pass

                multiplex_token.trigger()
                del self._multiplex_token

    async def _do_multiplexing(self, stop: asyncio.Event, token: CancelToken) -> None:
        """
        Background task that reads messages from the transport and feeds them
        into individual queues for each of the protocols.
        """
        msg_stream = self.wait_iter(stream_transport_messages(
            self._transport,
            self._base_protocol,
            *self._protocols,
            token=token,
        ), token=token)
        async for protocol, cmd, msg in msg_stream:
            # track total number of messages received for each command type.
            self._msg_counts[type(cmd)] += 1

            queue = self._protocol_queues[type(protocol)]
            try:
                # We must use `put_nowait` here to ensure that in the event
                # that a single protocol queue is full that we don't block
                # other protocol messages getting through.
                queue.put_nowait((cmd, msg))
            except asyncio.QueueFull:
                self.logger.error(
                    (
                        "Multiplexing queue for protocol '%s' full. "
                        "discarding message: %s"
                    ),
                    protocol,
                    cmd,
                )

            if stop.is_set():
                break
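
The `finally` block in `multiplex()` encodes a graceful-then-hard shutdown: signal a cooperative stop via an `asyncio.Event`, grant a bounded grace period for any in-flight read, then cancel outright. A minimal stdlib-only sketch of that pattern follows, with `worker` standing in for `_do_multiplexing()`; all names here are illustrative assumptions.

import asyncio


async def worker(stop: asyncio.Event) -> None:
    while not stop.is_set():
        # Stand-in for one transport read; may be mid-read when stop is set.
        await asyncio.sleep(0.2)
        print("processed one message")


async def demo() -> None:
    stop = asyncio.Event()
    fut = asyncio.ensure_future(worker(stop))
    await asyncio.sleep(0.5)

    # 1. Cooperative stop: let the worker exit between reads (the IDLE case).
    stop.set()
    try:
        # 2. Grace period: allow an in-flight read to complete (the BODY case).
        await asyncio.wait_for(fut, timeout=1)
    except asyncio.TimeoutError:
        # 3. Hard stop via cancellation (analogous to the HEADER case).
        fut.cancel()
        try:
            await fut
        except asyncio.CancelledError:
            pass


if __name__ == '__main__':
    asyncio.run(demo())
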
Example #3
class BasePeerPool(Service, AsyncIterable[BasePeer]):
    """
    PeerPool maintains connections to up to max_peers peers on a given network.
    """
    _report_stats_interval = 60
    _report_metrics_interval = 15  # for influxdb/grafana reporting
    _peer_boot_timeout = DEFAULT_PEER_BOOT_TIMEOUT
    _event_bus: EndpointAPI = None
    # If set to True, exceptions raised when connecting to a remote peer will be logged (DEBUG if
    # it's in ALLOWED_PEER_CONNECTION_EXCEPTIONS or ERROR if not) and suppressed. Should only be
    # set to False in tests if we want to ensure a connection is successful.
    _suppress_connection_exceptions: bool = True

    _handshake_locks: ResourceLock[NodeAPI]
    peer_reporter_registry_class: Type[
        PeerReporterRegistry[Any]] = PeerReporterRegistry[BasePeer]

    def __init__(
        self,
        privkey: keys.PrivateKey,
        context: BasePeerContext,
        max_peers: int = DEFAULT_MAX_PEERS,
        event_bus: EndpointAPI = None,
        metrics_registry: MetricsRegistry = None,
    ) -> None:
        self.logger = get_logger(self.__module__ + '.' +
                                 self.__class__.__name__)

        self.privkey = privkey
        self.max_peers = max_peers
        self.context = context

        self.connected_nodes: Dict[SessionAPI, BasePeer] = {}

        self._subscribers: List[PeerSubscriber] = []
        self._event_bus = event_bus

        if metrics_registry is None:
            # Initialize with a MetricsRegistry from pyformance as p2p can not depend on Trinity
            # This is so that we don't need to pass a MetricsRegistry in tests and mocked pools.
            metrics_registry = MetricsRegistry()
        self._active_peer_counter = metrics_registry.counter(
            'trinity.p2p/peers.counter')
        self._peer_reporter_registry = self.get_peer_reporter_registry(
            metrics_registry)

        # Restricts the number of concurrent connection attempts that can be made
        self._connection_attempt_lock = asyncio.BoundedSemaphore(
            MAX_CONCURRENT_CONNECTION_ATTEMPTS)

        # Ensure we can only have a single concurrent handshake in flight per remote
        self._handshake_locks = ResourceLock()

        self.peer_backends = self.setup_peer_backends()
        self.connection_tracker = self.setup_connection_tracker()

    @cached_property
    def has_event_bus(self) -> bool:
        return self._event_bus is not None

    def get_event_bus(self) -> EndpointAPI:
        if self._event_bus is None:
            raise AttributeError("No event bus configured for this peer pool")
        return self._event_bus

    def setup_connection_tracker(self) -> BaseConnectionTracker:
        """
        Return an instance of `p2p.tracking.connection.BaseConnectionTracker`
        which will be used to track peer connection failures.
        """
        return NoopConnectionTracker()

    def setup_peer_backends(self) -> Tuple[BasePeerBackend, ...]:
        if self.has_event_bus:
            return (
                DiscoveryPeerBackend(self.get_event_bus()),
                BootnodesPeerBackend(self.get_event_bus()),
            )
        else:
            self.logger.warning("No event bus configured for peer pool.")
            return ()

    @property
    def available_slots(self) -> int:
        return self.max_peers - len(self)

    async def _add_peers_from_backend(
            self, backend: BasePeerBackend,
            should_skip_fn: Callable[[Tuple[NodeID, ...], NodeAPI],
                                     bool]) -> int:

        connected_node_ids = {
            peer.remote.id
            for peer in self.connected_nodes.values()
        }
        # Only ask for random bootnodes if we're not connected to any peers.
        if isinstance(backend,
                      BootnodesPeerBackend) and len(connected_node_ids) > 0:
            return 0

        try:
            blacklisted_node_ids = await asyncio.wait_for(
                self.connection_tracker.get_blacklisted(), timeout=1)
        except asyncio.TimeoutError:
            self.logger.warning(
                "Timed out getting blacklisted peers from connection tracker, pausing peer "
                "addition until we can get that.")
            return 0

        skip_list = connected_node_ids.union(blacklisted_node_ids)
        should_skip_fn = functools.partial(should_skip_fn, skip_list)
        # Request a large batch on every iteration as that will effectively push DiscoveryService
        # to trigger new peer lookups in order to find enough compatible peers to fulfill our
        # request. There's probably some room for experimentation here in order to find an optimal
        # value.
        max_candidates = self.available_slots * OVER_PROVISION_MISSING_PEERS
        try:
            candidates = await asyncio.wait_for(
                backend.get_peer_candidates(
                    max_candidates=max_candidates,
                    should_skip_fn=should_skip_fn,
                ),
                timeout=REQUEST_PEER_CANDIDATE_TIMEOUT,
            )
        except asyncio.TimeoutError:
            self.logger.warning("PeerCandidateRequest timed out to backend %s",
                                backend)
            return 0
        else:
            self.logger.debug("Got %d peer candidates from backend %s",
                              len(candidates), backend)
            if candidates:
                await self.connect_to_nodes(candidates)
            return len(candidates)

    def __len__(self) -> int:
        return len(self.connected_nodes)

    @property
    @abstractmethod
    def peer_factory_class(self) -> Type[BasePeerFactory]:
        ...

    def get_peer_factory(self) -> BasePeerFactory:
        return self.peer_factory_class(
            privkey=self.privkey,
            context=self.context,
            event_bus=self._event_bus,
        )

    def get_peer_reporter_registry(
            self, metrics_registry: MetricsRegistry
    ) -> PeerReporterRegistry[BasePeer]:
        return self.peer_reporter_registry_class(metrics_registry)

    async def get_protocol_capabilities(self) -> Tuple[Tuple[str, int], ...]:
        return tuple(
            (handshaker.protocol_class.name, handshaker.protocol_class.version)
            for handshaker in await self.get_peer_factory().get_handshakers())

    @property
    def is_full(self) -> bool:
        return len(self) >= self.max_peers

    def is_valid_connection_candidate(self, candidate: NodeAPI) -> bool:
        # connect to no more than 2 nodes with the same IP
        nodes_by_ip = groupby(
            operator.attrgetter('remote.address.ip'),
            self.connected_nodes.keys(),
        )
        matching_ip_nodes = nodes_by_ip.get(candidate.address.ip, [])
        return len(matching_ip_nodes) <= 2

    def subscribe(self, subscriber: PeerSubscriber) -> None:
        self._subscribers.append(subscriber)
        for peer in self.connected_nodes.values():
            subscriber.register_peer(peer)
            peer.add_subscriber(subscriber)

    def unsubscribe(self, subscriber: PeerSubscriber) -> None:
        if subscriber in self._subscribers:
            self._subscribers.remove(subscriber)
        for peer in self.connected_nodes.values():
            peer.remove_subscriber(subscriber)

    async def _start_peer(self, peer: BasePeer) -> None:
        self.manager.run_child_service(peer.connection)
        await asyncio.wait_for(peer.connection.get_manager().wait_started(),
                               timeout=PEER_READY_TIMEOUT)
        await peer.connection.run_peer(peer)

    async def add_inbound_peer(self, peer: BasePeer) -> None:
        try:
            await self._start_peer(peer)
        except asyncio.TimeoutError as err:
            self.logger.debug('Timeout waiting for %s to start: %s', peer, err)
            return

        if self.is_connected_to_node(peer.remote):
            self.logger.debug(
                "Aborting inbound connection attempt by %s. Already connected!",
                peer)
            await peer.disconnect(DisconnectReason.ALREADY_CONNECTED)
            return

        if self.is_full:
            self.logger.debug(
                "Aborting inbound connection attempt by %s. PeerPool is full",
                peer)
            await peer.disconnect(DisconnectReason.TOO_MANY_PEERS)
            return
        elif not self.is_valid_connection_candidate(peer.remote):
            self.logger.debug(
                "Aborting inbound connection attempt by %s. Not a valid candidate",
                peer)
            # XXX: Currently, is_valid_connection_candidate() only checks that we're connected
            # to 2 or fewer nodes with the same IP, so TOO_MANY_PEERS is what makes the most sense
            # here.
            await peer.disconnect(DisconnectReason.TOO_MANY_PEERS)
            return

        total_peers = len(self)
        inbound_peer_count = len(
            tuple(peer for peer in self.connected_nodes.values()
                  if peer.inbound))
        if total_peers > 1 and inbound_peer_count / total_peers > DIAL_IN_OUT_RATIO:
            self.logger.debug(
                "Aborting inbound connection attempt by %s. Too many inbound peers",
                peer)
            await peer.disconnect(DisconnectReason.TOO_MANY_PEERS)
            return

        await self._add_peer_and_bootstrap(peer)

    async def add_outbound_peer(self, peer: BasePeer) -> None:
        try:
            await self._start_peer(peer)
        except asyncio.TimeoutError as err:
            self.logger.debug('Timeout waiting for %s to start: %s', peer, err)
            return

        # Check again to see if we have *become* full since the previous check.
        if self.is_full:
            self.logger.debug(
                "Successfully connected to %s but peer pool is full.  Disconnecting.",
                peer,
            )
            await peer.disconnect(DisconnectReason.TOO_MANY_PEERS)
            return
        elif not self.manager.is_running:
            self.logger.debug(
                "Successfully connected to %s but peer pool is closing.  Disconnecting.",
                peer,
            )
            await peer.disconnect(DisconnectReason.CLIENT_QUITTING)
            return

        await self._add_peer_and_bootstrap(peer)

    async def _add_peer_and_bootstrap(self, peer: BasePeer) -> None:
        # Add the peer to ourselves, ensuring it has subscribers before we start the protocol
        # streams.
        self._add_peer(peer)
        peer.start_protocol_streams()

        try:
            await asyncio.wait_for(
                peer.boot_manager.get_manager().wait_finished(),
                timeout=self._peer_boot_timeout)
        except asyncio.TimeoutError as err:
            self.logger.debug('Timeout waiting for peer to boot: %s', err)
            await peer.disconnect(DisconnectReason.TIMEOUT)
            return

    def _add_peer(self, peer: BasePeer) -> None:
        """Add the given peer to the pool.

        Add the peer to our list of connected nodes and add each of our subscribers
        to the peer.
        """
        if len(self) < QUIET_PEER_POOL_SIZE:
            logger = self.logger.info
        else:
            logger = self.logger.debug
        logger("Adding %s to pool", peer)
        self.connected_nodes[peer.session] = peer
        self._active_peer_counter.inc()
        self._peer_reporter_registry.assign_peer_reporter(peer)
        peer.add_finished_callback(self._peer_finished)
        for subscriber in self._subscribers:
            subscriber.register_peer(peer)
            peer.add_subscriber(subscriber)

    @abstractmethod
    async def maybe_connect_more_peers(self) -> None:
        ...

    async def run(self) -> None:
        if self.has_event_bus:
            self.manager.run_daemon_task(self.maybe_connect_more_peers)

        self.manager.run_daemon_task(self._periodically_report_stats)
        self.manager.run_daemon_task(self._periodically_report_metrics)
        await self.manager.wait_finished()

    async def connect(self, remote: NodeAPI) -> BasePeer:
        """
        Connect to the given remote and return a Peer instance when successful.
        Raises IneligiblePeer if we're already connected to the remote;
        handshake and connection failures propagate to the caller.
        """
        if not self._handshake_locks.is_locked(remote):
            self.logger.warning(
                "Tried to connect to %s without acquiring lock first!", remote)

        if any(peer.remote == remote
               for peer in self.connected_nodes.values()):
            self.logger.warning(
                "Attempted to connect to peer we are already connected to: %s",
                remote)
            raise IneligiblePeer(f"Already connected to {remote}")

        try:
            self.logger.debug("Connecting to %s...", remote)
            task_name = f'PeerPool/Handshake/{remote}'
            task = create_task(self.get_peer_factory().handshake(remote),
                               task_name)
            return await asyncio.wait_for(task, timeout=HANDSHAKE_TIMEOUT)
        except BadAckMessage:
            # This is kept separate from the
            # `COMMON_PEER_CONNECTION_EXCEPTIONS` to be sure that we aren't
            # silencing an error in our authentication code.
            self.logger.error('Got bad auth ack from %r', remote)
            # dump the full stacktrace in the debug logs
            self.logger.debug('Got bad auth ack from %r',
                              remote,
                              exc_info=True)
            raise
        except MalformedMessage as e:
            # This is kept separate from the
            # `COMMON_PEER_CONNECTION_EXCEPTIONS` to be sure that we aren't
            # silencing an error in how we decode messages during handshake.
            self.logger.error(
                'Got malformed response from %r during handshake', remote)
            # dump the full stacktrace in the debug logs
            self.logger.debug('Got malformed response from %r',
                              remote,
                              exc_info=True)
            self.connection_tracker.record_failure(remote, e)
            raise
        except HandshakeFailure as e:
            self.logger.debug("Could not complete handshake with %r: %s",
                              remote, repr(e))
            self.connection_tracker.record_failure(remote, e)
            raise
        except COMMON_PEER_CONNECTION_EXCEPTIONS as e:
            self.logger.debug("Could not complete handshake with %r: %s",
                              remote, repr(e))
            self.connection_tracker.record_failure(remote, e)
            raise
        finally:
            # XXX: We sometimes get an exception here while the task has also
            # finished with an exception of its own. It's unclear how that happens,
            # but if we don't consume the task's exception, asyncio complains.
            if not task.done():
                task.cancel()
            try:
                await task
            except (Exception, asyncio.CancelledError):
                pass

    async def connect_to_nodes(self, nodes: Sequence[NodeAPI]) -> None:
        # create an iterator over the nodes
        nodes_iter = iter(nodes)

        while True:
            if self.is_full or not self.manager.is_running:
                return

            # only attempt as many connections as there are open peer slots.
            batch_size = clamp(1, 10, self.available_slots)
            batch = tuple(take(batch_size, nodes_iter))

            # There are no more *known* nodes to connect to.
            if not batch:
                return

            self.logger.debug(
                'Initiating %d peer connection attempts with %d open peer slots',
                len(batch),
                self.available_slots,
            )
            # Try to connect to the peers concurrently.
            await asyncio.gather(*(self.connect_to_node(node)
                                   for node in batch))

    def lock_node_for_handshake(self,
                                node: NodeAPI) -> AsyncContextManager[None]:
        return self._handshake_locks.lock(node)

    def is_connected_to_node(self, node: NodeAPI) -> bool:
        return any(node == session.remote
                   for session in self.connected_nodes.keys())

    async def connect_to_node(self, node: NodeAPI) -> None:
        """
        Connect to a single node quietly aborting if the peer pool is full or
        shutting down, or one of the expected peer level exceptions is raised
        while connecting.
        """
        if self.is_full or not self.manager.is_running:
            self.logger.warning(
                "Asked to connect to node when either full or not operational")
            return

        if self._handshake_locks.is_locked(node):
            self.logger.info(
                "Asked to connect to %s when handshake lock is already locked, will wait",
                node)

        async with self.lock_node_for_handshake(node):
            if self.is_connected_to_node(node):
                self.logger.debug(
                    "Aborting outbound connection attempt to %s. Already connected!",
                    node,
                )
                return

            try:
                async with self._connection_attempt_lock:
                    peer = await self.connect(node)
            except ALLOWED_PEER_CONNECTION_EXCEPTIONS:
                if self._suppress_connection_exceptions:
                    # These are all logged in self.connect(), so we simply return.
                    return
                else:
                    raise
            except Exception:
                self.logger.exception("Unexpected error connecting to %s",
                                      node)
                if self._suppress_connection_exceptions:
                    return
                else:
                    raise

            await self.add_outbound_peer(peer)

    def _peer_finished(self, peer: BasePeer) -> None:
        """
        Remove the given peer from our list of connected nodes.
        This is passed as a callback to be called when a peer finishes.
        """
        if peer.session in self.connected_nodes:
            self.logger.debug(
                "Removing %s from pool: local_reason=%s remote_reason=%s",
                peer,
                peer.local_disconnect_reason,
                peer.remote_disconnect_reason,
            )
            self.connected_nodes.pop(peer.session)
        else:
            self.logger.warning(
                "%s finished but was not found in connected_nodes (%s)",
                peer,
                tuple(self.connected_nodes.values()),
            )

        for subscriber in self._subscribers:
            subscriber.deregister_peer(peer)

        self._active_peer_counter.dec()
        self._peer_reporter_registry.unassign_peer_reporter(peer)

    async def __aiter__(self) -> AsyncIterator[BasePeer]:
        for peer in tuple(self.connected_nodes.values()):
            # Yield control to ensure we process any disconnection requests from peers. Otherwise
            # we could return peers that should have been disconnected already.
            await asyncio.sleep(0)
            if peer.is_alive:
                yield peer

    async def _periodically_report_metrics(self) -> None:
        while self.manager.is_running:
            self._peer_reporter_registry.trigger_peer_reports()
            await asyncio.sleep(self._report_metrics_interval)

    async def _periodically_report_stats(self) -> None:
        while self.manager.is_running:
            inbound_peers = len([
                peer for peer in self.connected_nodes.values() if peer.inbound
            ])
            self.logger.info("Connected peers: %d inbound, %d outbound",
                             inbound_peers,
                             (len(self.connected_nodes) - inbound_peers))
            subscribers = len(self._subscribers)
            if subscribers:
                longest_queue = max(self._subscribers,
                                    key=operator.attrgetter('queue_size'))
                self.logger.debug(
                    "Peer subscribers: %d, longest queue: %s(%d)", subscribers,
                    longest_queue.__class__.__name__, longest_queue.queue_size)

            self.logger.debug("== Peer details == ")
            # make a copy, because we might edit the original during iteration
            peers = tuple(self.connected_nodes.values())
            for peer in peers:
                if not peer.get_manager().is_running:
                    self.logger.debug(
                        "%s is no longer alive but had not been removed from pool",
                        peer)
                    continue
                self.logger.debug(
                    "%s: uptime=%s, received_msgs=%d",
                    peer,
                    humanize_seconds(peer.uptime),
                    peer.received_msgs_count,
                )
                self.logger.debug(
                    "client_version_string='%s'",
                    peer.safe_client_version_string,
                )
                try:
                    for line in peer.get_extra_stats():
                        self.logger.debug("    %s", line)
                except (UnknownAPI, PeerConnectionLost) as exc:
                    self.logger.debug("    Failure during stats lookup: %r",
                                      exc)

            self.logger.debug("== End peer details == ")
            await asyncio.sleep(self._report_stats_interval)
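
Both pool variants serialize handshakes per remote through `ResourceLock` (see `lock_node_for_handshake()`). A stdlib-only sketch of such a primitive: one lazily created `asyncio.Lock` per key, so concurrent dials to the same node queue behind each other while dials to different nodes overlap. The `PerResourceLock` name and the absence of lock cleanup are simplifying assumptions, not the library's implementation.

import asyncio
from typing import Dict, Hashable


class PerResourceLock:
    def __init__(self) -> None:
        self._locks: Dict[Hashable, asyncio.Lock] = {}

    def lock(self, resource: Hashable) -> asyncio.Lock:
        # An asyncio.Lock is itself an async context manager, so callers can
        # write `async with locks.lock(remote): ...`
        if resource not in self._locks:
            self._locks[resource] = asyncio.Lock()
        return self._locks[resource]

    def is_locked(self, resource: Hashable) -> bool:
        return resource in self._locks and self._locks[resource].locked()


async def dial(locks: PerResourceLock, remote: str) -> None:
    async with locks.lock(remote):
        print(f"handshaking with {remote}")
        await asyncio.sleep(0.1)


async def demo() -> None:
    locks = PerResourceLock()
    # The two dials to 'node-a' run one after the other; 'node-b' overlaps.
    await asyncio.gather(dial(locks, 'node-a'), dial(locks, 'node-a'),
                         dial(locks, 'node-b'))


if __name__ == '__main__':
    asyncio.run(demo())
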
Example #4
class Multiplexer(CancellableMixin, MultiplexerAPI):
    logger = cast(ExtendedDebugLogger,
                  logging.getLogger('p2p.multiplexer.Multiplexer'))

    _multiplex_token: CancelToken

    _transport: TransportAPI
    _msg_counts: DefaultDict[Type[CommandAPI], int]

    _protocol_locks: ResourceLock
    _protocol_queues: Dict[Type[ProtocolAPI],
                           'asyncio.Queue[Tuple[CommandAPI, Payload]]']

    def __init__(self,
                 transport: TransportAPI,
                 base_protocol: P2PProtocol,
                 protocols: Sequence[ProtocolAPI],
                 token: CancelToken = None,
                 max_queue_size: int = 4096) -> None:
        if token is None:
            loop = None
        else:
            loop = token.loop
        base_token = CancelToken(f'multiplexer[{transport.remote}]', loop=loop)

        if token is None:
            self.cancel_token = base_token
        else:
            self.cancel_token = base_token.chain(token)

        self._transport = transport
        # the base `p2p` protocol instance.
        self._base_protocol = base_protocol

        # the sub-protocol instances
        self._protocols = protocols

        # Lock to ensure that multiple call sites cannot concurrently stream
        # messages.
        self._multiplex_lock = asyncio.Lock()

        # Lock management on a per-protocol basis to ensure we only have one
        # stream consumer for each protocol.
        self._protocol_locks = ResourceLock()

        # Each protocol gets a queue where messages for the individual protocol
        # are placed when streamed from the transport
        self._protocol_queues = {
            type(protocol): asyncio.Queue(max_queue_size)
            for protocol in self.get_protocols()
        }

        self._msg_counts = collections.defaultdict(int)

    def __str__(self) -> str:
        protocol_infos = ','.join(
            tuple(f"{proto.name}:{proto.version}"
                  for proto in self.get_protocols()))
        return f"Multiplexer[{protocol_infos}]"

    def __repr__(self) -> str:
        return f"<{self}>"

    #
    # Transport API
    #
    def get_transport(self) -> TransportAPI:
        return self._transport

    #
    # Message Counts
    #
    def get_total_msg_count(self) -> int:
        return sum(self._msg_counts.values())

    #
    # Proxy Transport methods
    #
    @property
    def remote(self) -> NodeAPI:
        return self._transport.remote

    @property
    def is_closing(self) -> bool:
        return self._transport.is_closing

    def close(self) -> None:
        self._transport.close()
        self.cancel_token.trigger()

    #
    # Protocol API
    #
    def has_protocol(
            self, protocol_identifier: Union[ProtocolAPI,
                                             Type[ProtocolAPI]]) -> bool:
        try:
            if isinstance(protocol_identifier, Protocol):
                self.get_protocol_by_type(type(protocol_identifier))
                return True
            elif isinstance(protocol_identifier, type):
                self.get_protocol_by_type(protocol_identifier)
                return True
            else:
                raise TypeError(
                    f"Unsupported protocol value: {protocol_identifier} of type "
                    f"{type(protocol_identifier)}")
        except UnknownProtocol:
            return False

    def get_protocol_by_type(self,
                             protocol_class: Type[TProtocol]) -> TProtocol:
        if protocol_class is P2PProtocol:
            return cast(TProtocol, self._base_protocol)

        for protocol in self._protocols:
            if type(protocol) is protocol_class:
                return cast(TProtocol, protocol)
        raise UnknownProtocol(f"No protocol found with type {protocol_class}")

    def get_base_protocol(self) -> P2PProtocol:
        return self._base_protocol

    def get_protocols(self) -> Tuple[ProtocolAPI, ...]:
        return tuple(cons(self._base_protocol, self._protocols))

    #
    # Streaming API
    #
    async def stream_protocol_messages(
        self,
        protocol_identifier: Union[ProtocolAPI, Type[ProtocolAPI]],
    ) -> AsyncIterator[Tuple[CommandAPI, Payload]]:
        """
        Stream the messages for the specified protocol.
        """
        if isinstance(protocol_identifier, Protocol):
            protocol_class = type(protocol_identifier)
        elif isinstance(protocol_identifier, type) and issubclass(
                protocol_identifier, Protocol):
            protocol_class = protocol_identifier
        else:
            raise TypeError(f"Unknown protocol identifier: {protocol_identifier}")

        if not self.has_protocol(protocol_class):
            raise UnknownProtocol(f"Unknown protocol '{protocol_class}'")

        if self._protocol_locks.is_locked(protocol_class):
            raise Exception(
                f"Streaming lock for {protocol_class} is not free.")

        async with self._protocol_locks.lock(protocol_class):
            msg_queue = self._protocol_queues[protocol_class]
            if not hasattr(self, '_multiplex_token'):
                raise Exception("Multiplexer is not multiplexed")
            token = self._multiplex_token

            while not self.is_closing and not token.triggered:
                try:
                    # We use an optimistic strategy here of using
                    # `get_nowait()` to reduce the number of times we yield to
                    # the event loop.  Since this is an async generator it will
                    # yield to the loop each time it returns a value so we
                    # don't have to worry about this blocking other processes.
                    yield msg_queue.get_nowait()
                except asyncio.QueueEmpty:
                    yield await self.wait(msg_queue.get(), token=token)

    #
    # Message reading and streaming API
    #
    @asynccontextmanager
    async def multiplex(self) -> AsyncIterator[None]:
        """
        API for running the background task that feeds the individual protocol
        queues, allowing each protocol to stream only its own messages.
        """
        # We generate a new token for each time the multiplexer is used to
        # multiplex so that we can reliably cancel it without requiring the
        # master token for the multiplexer to be cancelled.
        async with self._multiplex_lock:
            multiplex_token = CancelToken(
                'multiplex',
                loop=self.cancel_token.loop,
            ).chain(self.cancel_token)
            self._multiplex_token = multiplex_token
            fut = asyncio.ensure_future(self._do_multiplexing(multiplex_token))
            try:
                yield
            finally:
                multiplex_token.trigger()
                del self._multiplex_token
                if fut.done():
                    fut.result()
                else:
                    fut.cancel()
                    try:
                        await fut
                    except asyncio.CancelledError:
                        pass

    async def _do_multiplexing(self, token: CancelToken) -> None:
        """
        Background task that reads messages from the transport and feeds them
        into individual queues for each of the protocols.
        """
        msg_stream = self.wait_iter(stream_transport_messages(
            self._transport,
            self._base_protocol,
            *self._protocols,
            token=token,
        ), token=token)
        async for protocol, cmd, msg in msg_stream:
            # track total number of messages received for each command type.
            self._msg_counts[type(cmd)] += 1

            queue = self._protocol_queues[type(protocol)]
            try:
                # We must use `put_nowait` here to ensure that in the event
                # that a single protocol queue is full that we don't block
                # other protocol messages getting through.
                queue.put_nowait((cmd, msg))
            except asyncio.QueueFull:
                self.logger.error(
                    ("Multiplexing queue for protocol '%s' full. "
                     "discarding message: %s"),
                    protocol,
                    cmd,
                )
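
Finally, a sketch of the queueing discipline shared by both multiplexer variants: the reader routes each message into a bounded per-protocol queue with `put_nowait()`, discarding on overflow so one slow protocol cannot stall the rest, while consumers drain optimistically with `get_nowait()` before falling back to an awaited `get()`. Protocol names and messages below are invented for the example; only stdlib asyncio is used.

import asyncio
from typing import Dict


def feed(queues: Dict[str, 'asyncio.Queue[str]'], messages) -> None:
    # Stand-in for _do_multiplexing(): route each message to its protocol's
    # queue, discarding when that queue is full.
    for proto, msg in messages:
        try:
            queues[proto].put_nowait(msg)
        except asyncio.QueueFull:
            print(f"queue for {proto!r} full, discarding {msg!r}")


def drain(name: str, queue: 'asyncio.Queue[str]') -> None:
    # Stand-in for the consumer loop: optimistic get_nowait() avoids yielding
    # to the event loop while messages are already buffered; the real code
    # falls back to `await queue.get()` on QueueEmpty instead of stopping.
    while True:
        try:
            msg = queue.get_nowait()
        except asyncio.QueueEmpty:
            break
        print(f"{name} got {msg!r}")


async def demo() -> None:
    queues = {'eth': asyncio.Queue(2), 'les': asyncio.Queue(2)}
    feed(queues, [('eth', 'NewBlock'), ('eth', 'Transactions'),
                  ('eth', 'NewBlockHashes'),  # dropped: 'eth' queue is full
                  ('les', 'Status')])
    drain('eth', queues['eth'])
    drain('les', queues['les'])


if __name__ == '__main__':
    asyncio.run(demo())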